-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathserver.py
More file actions
137 lines (113 loc) · 4.67 KB
/
server.py
File metadata and controls
137 lines (113 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Placeholder docerting"""
from __future__ import absolute_import
import uuid
import logging
import importlib
import platform
from sagemaker.core import fw_utils
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.common_utils import _is_s3_uri
from sagemaker.serve.utils.uploader import upload
from sagemaker.core.s3.utils import determine_bucket_and_prefix, parse_s3_url
from sagemaker.core.local.local_session import get_docker_host
import docker
from docker.types import DeviceRequest
logger = logging.getLogger(__name__)
# TODO: automatically update memory size
_SHM_SIZE = "2G"
class LocalTritonServer:
    """Run and invoke a local NVIDIA Triton Inference Server Docker container."""

    def __init__(self) -> None:
        # Lazily-created tritonclient.http.InferenceServerClient; built on first invoke.
        self.triton_client = None

    def _start_triton_server(
        self,
        docker_client: docker.DockerClient,
        model_path: str,
        secret_key: str,
        image_uri: str,
        env_vars: dict,
    ):
        """Start a detached Triton container serving ``<model_path>/model_repository``.

        Args:
            docker_client: Docker client used to launch the container.
            model_path: Local directory containing a ``model_repository`` subdirectory,
                which is bind-mounted read-write at ``/models`` inside the container.
            secret_key: Unused here; kept for interface parity with other local servers.
            image_uri: Triton image to run. Images whose URI does not contain ``"cpu"``
                are assumed to be GPU images and are granted access to all GPUs.
            env_vars: Environment passed to the container; mutated in place to add
                ``TRITON_MODEL_DIR`` and ``LOCAL_PYTHON``.
        """
        self.container_name = "triton" + uuid.uuid1().hex
        model_repository = model_path + "/model_repository"
        env_vars.update(
            {
                "TRITON_MODEL_DIR": "/models/model",
                "LOCAL_PYTHON": platform.python_version(),
            }
        )
        # Arguments common to the CPU and GPU launch paths; only the GPU device
        # request differs, so build the kwargs once instead of duplicating the call.
        run_kwargs = dict(
            image=image_uri,
            command=["tritonserver", "--model-repository=/models"],
            shm_size=_SHM_SIZE,
            network_mode="host",
            detach=True,
            auto_remove=True,
            volumes={model_repository: {"bind": "/models", "mode": "rw"}},
            environment=env_vars,
        )
        if "cpu" not in image_uri:
            # GPU image: expose all available GPUs to the container.
            run_kwargs["device_requests"] = [DeviceRequest(count=-1, capabilities=[["gpu"]])]
        self.container = docker_client.containers.run(**run_kwargs)

    def _invoke_triton_server(self, payload, *args, **kwargs):
        """Serialize *payload*, send it to the local Triton HTTP endpoint, and
        deserialize the response.

        NOTE(review): relies on ``self.schema_builder`` being attached by a
        collaborating class elsewhere — it is never set in this class.
        """
        httpClient = importlib.import_module("tritonclient.http")
        if not self.triton_client:
            # Connect lazily so the server has a chance to come up first.
            self.triton_client = httpClient.InferenceServerClient(url=f"{get_docker_host()}:8000")
        payload = self.schema_builder.input_serializer.serialize(payload)
        # e.g. "TYPE_FP32" -> "FP32"
        dtype = self.schema_builder._input_triton_dtype.split("_")[-1]
        input_request = httpClient.InferInput("input_1", payload.shape, datatype=dtype)
        input_request.set_data_from_numpy(payload, binary_data=True)
        response = self.triton_client.infer(model_name="model", inputs=[input_request])
        response_name = response.get_response().get("outputs")[0].get("name")
        return self.schema_builder.output_deserializer.deserialize(response.as_numpy(response_name))
class SageMakerTritonServer:
    """Prepare Triton model artifacts for SageMaker-hosted deployment."""

    def __init__(self) -> None:
        pass

    def _upload_triton_artifacts(
        self,
        model_path: str,
        sagemaker_session: Session,
        secret_key: str,
        s3_model_data_url: str = None,
        image: str = None,
        should_upload_artifacts: bool = False,
    ):
        """Tar triton artifacts and upload to s3"""
        # Environment the SageMaker Triton container needs at runtime.
        env_vars = {
            "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
            "TRITON_MODEL_DIR": "/opt/ml/model/model",
            "LOCAL_PYTHON": platform.python_version(),
        }

        if _is_s3_uri(model_path):
            # Artifacts already live in S3 — reuse them as-is.
            return model_path, env_vars

        if not should_upload_artifacts:
            # Caller opted out of uploading; no S3 location to report.
            return None, env_vars

        bucket, key_prefix = (
            parse_s3_url(url=s3_model_data_url) if s3_model_data_url else (None, None)
        )
        prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)
        bucket, prefix = determine_bucket_and_prefix(
            bucket=bucket, key_prefix=prefix, sagemaker_session=sagemaker_session
        )
        logger.debug(
            "Uploading the model resources to bucket=%s, key_prefix=%s.",
            bucket,
            prefix,
        )
        s3_upload_path = upload(
            sagemaker_session, model_path + "/model_repository", bucket, prefix
        )
        logger.debug("Model resources uploaded to: %s", s3_upload_path)
        return s3_upload_path, env_vars