Skip to content

Commit ab7b3c8

Browse files
committed
fix: address review comments (iteration #1)
1 parent 22e4363 commit ab7b3c8

File tree

1 file changed

+80
-41
lines changed

1 file changed

+80
-41
lines changed

sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py

Lines changed: 80 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import pytest
88

9-
from sagemaker.serve.model_builder_servers import _ModelBuilderServers
9+
from sagemaker.serve.model_builder_servers import (
10+
_ModelBuilderServers,
11+
)
1012
from sagemaker.serve.utils.types import ModelServer
1113
from sagemaker.serve.mode.function_pointers import Mode
1214

@@ -22,7 +24,9 @@ def _create_mock_builder(
2224
"""Create a mock builder with common attributes set."""
2325
builder = MagicMock(spec=_ModelBuilderServers)
2426
builder.model = model
25-
builder.env_vars = env_vars if env_vars is not None else {}
27+
builder.env_vars = (
28+
env_vars if env_vars is not None else {}
29+
)
2630
builder.model_path = "/tmp/test_model_path"
2731
builder.mode = Mode.SAGEMAKER_ENDPOINT
2832
builder.model_server = ModelServer.DJL_SERVING
@@ -43,7 +47,9 @@ def _create_mock_builder(
4347
builder.hf_model_config = {}
4448
builder.model_data_download_timeout = None
4549
builder._user_provided_instance_type = True
46-
builder._is_jumpstart_model_id = Mock(return_value=False)
50+
builder._is_jumpstart_model_id = Mock(
51+
return_value=False
52+
)
4753
builder._auto_detect_image_uri = Mock()
4854
builder._prepare_for_mode = Mock(
4955
return_value=("s3://model-data", None)
@@ -68,25 +74,27 @@ def _create_mock_builder(
6874

6975
@pytest.fixture
7076
def mock_builder() -> MagicMock:
71-
"""Create a mock builder with default (empty) env_vars."""
77+
"""Create a mock builder with default env_vars."""
7278
return _create_mock_builder(env_vars={})
7379

7480

7581
@pytest.fixture
7682
def mock_builder_with_s3() -> MagicMock:
77-
"""Create a mock builder with user-provided S3 HF_MODEL_ID."""
83+
"""Mock builder with user-provided S3 HF_MODEL_ID."""
7884
return _create_mock_builder(
7985
env_vars={"HF_MODEL_ID": S3_PATH}
8086
)
8187

8288

83-
# -- Patch sets for each server type ----------------------------------
89+
# -- Patch targets for each server type --------------------------
8490

8591
_DJL_PATCHES: List[str] = [
8692
"sagemaker.serve.model_builder_servers"
8793
"._get_default_tensor_parallel_degree",
88-
"sagemaker.serve.model_builder_servers._get_gpu_info",
89-
"sagemaker.serve.model_builder_servers._get_nb_instance",
94+
"sagemaker.serve.model_builder_servers"
95+
"._get_gpu_info",
96+
"sagemaker.serve.model_builder_servers"
97+
"._get_nb_instance",
9098
"sagemaker.serve.model_builder_servers"
9199
"._get_default_djl_configurations",
92100
"sagemaker.serve.model_builder_servers"
@@ -96,19 +104,21 @@ def mock_builder_with_s3() -> MagicMock:
96104
]
97105

98106
_DJL_RETURN_VALUES = [
99-
1, # tensor_parallel_degree
100-
1, # gpu_info
101-
None, # nb_instance
107+
1, # tensor_parallel_degree
108+
1, # gpu_info
109+
None, # nb_instance
102110
({}, 256), # djl_configurations
103-
{}, # hf_model_config
104-
None, # _create_dir_structure
111+
{}, # hf_model_config
112+
None, # _create_dir_structure
105113
]
106114

107115
_TGI_PATCHES: List[str] = [
108116
"sagemaker.serve.model_builder_servers"
109117
"._get_default_tensor_parallel_degree",
110-
"sagemaker.serve.model_builder_servers._get_gpu_info",
111-
"sagemaker.serve.model_builder_servers._get_nb_instance",
118+
"sagemaker.serve.model_builder_servers"
119+
"._get_gpu_info",
120+
"sagemaker.serve.model_builder_servers"
121+
"._get_nb_instance",
112122
"sagemaker.serve.model_builder_servers"
113123
"._get_default_tgi_configurations",
114124
"sagemaker.serve.model_builder_servers"
@@ -118,16 +128,17 @@ def mock_builder_with_s3() -> MagicMock:
118128
]
119129

120130
_TGI_RETURN_VALUES = [
121-
1, # tensor_parallel_degree
122-
1, # gpu_info
123-
None, # nb_instance
131+
1, # tensor_parallel_degree
132+
1, # gpu_info
133+
None, # nb_instance
124134
({}, 256), # tgi_configurations
125-
{}, # hf_model_config
126-
None, # _create_dir_structure
135+
{}, # hf_model_config
136+
None, # _create_dir_structure
127137
]
128138

129139
_TEI_PATCHES: List[str] = [
130-
"sagemaker.serve.model_builder_servers._get_nb_instance",
140+
"sagemaker.serve.model_builder_servers"
141+
"._get_nb_instance",
131142
"sagemaker.serve.model_builder_servers"
132143
"._get_model_config_properties_from_hf",
133144
"sagemaker.serve.model_server.tgi"
@@ -140,8 +151,21 @@ def mock_builder_with_s3() -> MagicMock:
140151
None, # _create_dir_structure
141152
]
142153

154+
_TORCHSERVE_PATCHES: List[str] = [
155+
"sagemaker.serve.model_builder_servers"
156+
".prepare_for_torchserve",
157+
]
158+
159+
_TORCHSERVE_RETURN_VALUES = [
160+
"mock-secret-key", # prepare_for_torchserve
161+
]
162+
163+
_TRITON_PATCHES: List[str] = []
164+
_TRITON_RETURN_VALUES: list = []
165+
143166
_MMS_PATCHES: List[str] = [
144-
"sagemaker.serve.model_builder_servers._get_nb_instance",
167+
"sagemaker.serve.model_builder_servers"
168+
"._get_nb_instance",
145169
"sagemaker.serve.model_builder_servers"
146170
"._get_model_config_properties_from_hf",
147171
"sagemaker.serve.model_server.multi_model_server"
@@ -201,14 +225,14 @@ def _stop_patches(patchers: List) -> None:
201225
(
202226
"_build_for_torchserve",
203227
ModelServer.TORCHSERVE,
204-
[],
205-
[],
228+
_TORCHSERVE_PATCHES,
229+
_TORCHSERVE_RETURN_VALUES,
206230
),
207231
(
208232
"_build_for_triton",
209233
ModelServer.TRITON,
210-
[],
211-
[],
234+
_TRITON_PATCHES,
235+
_TRITON_RETURN_VALUES,
212236
),
213237
],
214238
ids=[
@@ -231,7 +255,9 @@ def test_preserves_user_provided_hf_model_id(
231255
builder.model_server = server_type
232256
patchers = _apply_patches(patch_targets, patch_rvs)
233257
try:
234-
getattr(_ModelBuilderServers, build_method)(builder)
258+
getattr(
259+
_ModelBuilderServers, build_method
260+
)(builder)
235261
finally:
236262
_stop_patches(patchers)
237263
assert builder.env_vars["HF_MODEL_ID"] == S3_PATH
@@ -261,14 +287,14 @@ def test_preserves_user_provided_hf_model_id(
261287
(
262288
"_build_for_torchserve",
263289
ModelServer.TORCHSERVE,
264-
[],
265-
[],
290+
_TORCHSERVE_PATCHES,
291+
_TORCHSERVE_RETURN_VALUES,
266292
),
267293
(
268294
"_build_for_triton",
269295
ModelServer.TRITON,
270-
[],
271-
[],
296+
_TRITON_PATCHES,
297+
_TRITON_RETURN_VALUES,
272298
),
273299
],
274300
ids=[
@@ -291,21 +317,25 @@ def test_sets_default_hf_model_id_when_not_provided(
291317
builder.model_server = server_type
292318
patchers = _apply_patches(patch_targets, patch_rvs)
293319
try:
294-
getattr(_ModelBuilderServers, build_method)(builder)
320+
getattr(
321+
_ModelBuilderServers, build_method
322+
)(builder)
295323
finally:
296324
_stop_patches(patchers)
297325
assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL
298326

299327

300328
# ---------------------------------------------------------------
301-
# Transformers (MMS) — needs extra patches for _create_dir_structure
329+
# Transformers (MMS) — needs extra patches
302330
# ---------------------------------------------------------------
303331
class TestBuildForTransformersHfModelId:
304-
"""_build_for_transformers preserves user-provided HF_MODEL_ID."""
332+
"""_build_for_transformers preserves HF_MODEL_ID."""
305333

306334
def test_preserves_user_provided_s3_uri(
307-
self, mock_builder_with_s3: MagicMock
335+
self,
336+
mock_builder_with_s3: MagicMock,
308337
) -> None:
338+
"""User S3 URI is preserved."""
309339
builder = mock_builder_with_s3
310340
builder.model_server = ModelServer.MMS
311341
patchers = _apply_patches(
@@ -320,8 +350,10 @@ def test_preserves_user_provided_s3_uri(
320350
assert builder.env_vars["HF_MODEL_ID"] == S3_PATH
321351

322352
def test_sets_default_when_not_provided(
323-
self, mock_builder: MagicMock
353+
self,
354+
mock_builder: MagicMock,
324355
) -> None:
356+
"""HF_MODEL_ID defaults to self.model."""
325357
builder = mock_builder
326358
builder.model_server = ModelServer.MMS
327359
patchers = _apply_patches(
@@ -333,9 +365,13 @@ def test_sets_default_when_not_provided(
333365
)
334366
finally:
335367
_stop_patches(patchers)
336-
assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL
368+
assert (
369+
builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL
370+
)
337371

338-
@patch("sagemaker.serve.model_builder_servers.save_pkl")
372+
@patch(
373+
"sagemaker.serve.model_builder_servers.save_pkl"
374+
)
339375
@patch(
340376
"sagemaker.serve.model_builder_servers"
341377
"._get_model_config_properties_from_hf",
@@ -347,7 +383,8 @@ def test_sets_default_when_not_provided(
347383
return_value=None,
348384
)
349385
@patch(
350-
"sagemaker.serve.model_server.multi_model_server"
386+
"sagemaker.serve.model_server"
387+
".multi_model_server"
351388
".prepare._create_dir_structure",
352389
)
353390
@patch("os.makedirs")
@@ -359,7 +396,7 @@ def test_preserves_with_inference_spec(
359396
_mock_hf: Mock,
360397
_mock_pkl: Mock,
361398
) -> None:
362-
"""User-provided HF_MODEL_ID preserved with inference_spec."""
399+
"""User HF_MODEL_ID preserved with inference_spec."""
363400
builder = _create_mock_builder(
364401
env_vars={"HF_MODEL_ID": S3_PATH}
365402
)
@@ -373,5 +410,7 @@ def test_preserves_with_inference_spec(
373410
builder._is_jumpstart_model_id = Mock(
374411
return_value=False
375412
)
376-
_ModelBuilderServers._build_for_transformers(builder)
413+
_ModelBuilderServers._build_for_transformers(
414+
builder
415+
)
377416
assert builder.env_vars["HF_MODEL_ID"] == S3_PATH

0 commit comments

Comments
 (0)