Skip to content

Commit 3050387

Browse files
committed
fix: address review comments (iteration #2)
1 parent ab7b3c8 commit 3050387

File tree

1 file changed

+105
-139
lines changed

1 file changed

+105
-139
lines changed

sagemaker-serve/tests/unit/test_model_builder_servers_hf_model_id.py

Lines changed: 105 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@
1616
S3_PATH = "s3://my-bucket/models/Qwen/"
1717
DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct"
1818

19+
_MOD = "sagemaker.serve.model_builder_servers"
20+
_DJL_PREP = (
21+
"sagemaker.serve.model_server"
22+
".djl_serving.prepare._create_dir_structure"
23+
)
24+
_TGI_PREP = (
25+
"sagemaker.serve.model_server"
26+
".tgi.prepare._create_dir_structure"
27+
)
28+
_MMS_PREP = (
29+
"sagemaker.serve.model_server"
30+
".multi_model_server.prepare._create_dir_structure"
31+
)
32+
1933

2034
def _create_mock_builder(
2135
env_vars: Optional[Dict[str, str]] = None,
@@ -54,7 +68,9 @@ def _create_mock_builder(
5468
builder._prepare_for_mode = Mock(
5569
return_value=("s3://model-data", None)
5670
)
57-
builder._create_model = Mock(return_value=Mock())
71+
builder._create_model = Mock(
72+
return_value=Mock()
73+
)
5874
builder._optimizing = False
5975
builder._validate_djl_serving_sample_data = Mock()
6076
builder._validate_tgi_serving_sample_data = Mock()
@@ -86,21 +102,15 @@ def mock_builder_with_s3() -> MagicMock:
86102
)
87103

88104

89-
# -- Patch targets for each server type --------------------------
105+
# -- Patch targets for each server type ----------------------
90106

91107
_DJL_PATCHES: List[str] = [
92-
"sagemaker.serve.model_builder_servers"
93-
"._get_default_tensor_parallel_degree",
94-
"sagemaker.serve.model_builder_servers"
95-
"._get_gpu_info",
96-
"sagemaker.serve.model_builder_servers"
97-
"._get_nb_instance",
98-
"sagemaker.serve.model_builder_servers"
99-
"._get_default_djl_configurations",
100-
"sagemaker.serve.model_builder_servers"
101-
"._get_model_config_properties_from_hf",
102-
"sagemaker.serve.model_server.djl_serving"
103-
".prepare._create_dir_structure",
108+
f"{_MOD}._get_default_tensor_parallel_degree",
109+
f"{_MOD}._get_gpu_info",
110+
f"{_MOD}._get_nb_instance",
111+
f"{_MOD}._get_default_djl_configurations",
112+
f"{_MOD}._get_model_config_properties_from_hf",
113+
_DJL_PREP,
104114
]
105115

106116
_DJL_RETURN_VALUES = [
@@ -113,18 +123,12 @@ def mock_builder_with_s3() -> MagicMock:
113123
]
114124

115125
_TGI_PATCHES: List[str] = [
116-
"sagemaker.serve.model_builder_servers"
117-
"._get_default_tensor_parallel_degree",
118-
"sagemaker.serve.model_builder_servers"
119-
"._get_gpu_info",
120-
"sagemaker.serve.model_builder_servers"
121-
"._get_nb_instance",
122-
"sagemaker.serve.model_builder_servers"
123-
"._get_default_tgi_configurations",
124-
"sagemaker.serve.model_builder_servers"
125-
"._get_model_config_properties_from_hf",
126-
"sagemaker.serve.model_server.tgi"
127-
".prepare._create_dir_structure",
126+
f"{_MOD}._get_default_tensor_parallel_degree",
127+
f"{_MOD}._get_gpu_info",
128+
f"{_MOD}._get_nb_instance",
129+
f"{_MOD}._get_default_tgi_configurations",
130+
f"{_MOD}._get_model_config_properties_from_hf",
131+
_TGI_PREP,
128132
]
129133

130134
_TGI_RETURN_VALUES = [
@@ -137,12 +141,9 @@ def mock_builder_with_s3() -> MagicMock:
137141
]
138142

139143
_TEI_PATCHES: List[str] = [
140-
"sagemaker.serve.model_builder_servers"
141-
"._get_nb_instance",
142-
"sagemaker.serve.model_builder_servers"
143-
"._get_model_config_properties_from_hf",
144-
"sagemaker.serve.model_server.tgi"
145-
".prepare._create_dir_structure",
144+
f"{_MOD}._get_nb_instance",
145+
f"{_MOD}._get_model_config_properties_from_hf",
146+
_TGI_PREP,
146147
]
147148

148149
_TEI_RETURN_VALUES = [
@@ -152,24 +153,20 @@ def mock_builder_with_s3() -> MagicMock:
152153
]
153154

154155
_TORCHSERVE_PATCHES: List[str] = [
155-
"sagemaker.serve.model_builder_servers"
156-
".prepare_for_torchserve",
156+
f"{_MOD}.prepare_for_torchserve",
157157
]
158158

159159
_TORCHSERVE_RETURN_VALUES = [
160-
"mock-secret-key", # prepare_for_torchserve
160+
"mock-secret-key",
161161
]
162162

163163
_TRITON_PATCHES: List[str] = []
164164
_TRITON_RETURN_VALUES: list = []
165165

166166
_MMS_PATCHES: List[str] = [
167-
"sagemaker.serve.model_builder_servers"
168-
"._get_nb_instance",
169-
"sagemaker.serve.model_builder_servers"
170-
"._get_model_config_properties_from_hf",
171-
"sagemaker.serve.model_server.multi_model_server"
172-
".prepare._create_dir_structure",
167+
f"{_MOD}._get_nb_instance",
168+
f"{_MOD}._get_model_config_properties_from_hf",
169+
_MMS_PREP,
173170
]
174171

175172
_MMS_RETURN_VALUES = [
@@ -198,50 +195,55 @@ def _stop_patches(patchers: List) -> None:
198195
p.stop()
199196

200197

201-
# ---------------------------------------------------------------
198+
# -----------------------------------------------------------
202199
# Parametrised tests: preserve user-provided HF_MODEL_ID
203-
# ---------------------------------------------------------------
200+
# -----------------------------------------------------------
201+
_SERVER_PARAMS = [
202+
(
203+
"_build_for_djl",
204+
ModelServer.DJL_SERVING,
205+
_DJL_PATCHES,
206+
_DJL_RETURN_VALUES,
207+
),
208+
(
209+
"_build_for_tgi",
210+
ModelServer.TGI,
211+
_TGI_PATCHES,
212+
_TGI_RETURN_VALUES,
213+
),
214+
(
215+
"_build_for_tei",
216+
ModelServer.TEI,
217+
_TEI_PATCHES,
218+
_TEI_RETURN_VALUES,
219+
),
220+
(
221+
"_build_for_torchserve",
222+
ModelServer.TORCHSERVE,
223+
_TORCHSERVE_PATCHES,
224+
_TORCHSERVE_RETURN_VALUES,
225+
),
226+
(
227+
"_build_for_triton",
228+
ModelServer.TRITON,
229+
_TRITON_PATCHES,
230+
_TRITON_RETURN_VALUES,
231+
),
232+
]
233+
234+
_SERVER_IDS = [
235+
"djl",
236+
"tgi",
237+
"tei",
238+
"torchserve",
239+
"triton",
240+
]
241+
242+
204243
@pytest.mark.parametrize(
205244
"build_method, server_type, patch_targets, patch_rvs",
206-
[
207-
(
208-
"_build_for_djl",
209-
ModelServer.DJL_SERVING,
210-
_DJL_PATCHES,
211-
_DJL_RETURN_VALUES,
212-
),
213-
(
214-
"_build_for_tgi",
215-
ModelServer.TGI,
216-
_TGI_PATCHES,
217-
_TGI_RETURN_VALUES,
218-
),
219-
(
220-
"_build_for_tei",
221-
ModelServer.TEI,
222-
_TEI_PATCHES,
223-
_TEI_RETURN_VALUES,
224-
),
225-
(
226-
"_build_for_torchserve",
227-
ModelServer.TORCHSERVE,
228-
_TORCHSERVE_PATCHES,
229-
_TORCHSERVE_RETURN_VALUES,
230-
),
231-
(
232-
"_build_for_triton",
233-
ModelServer.TRITON,
234-
_TRITON_PATCHES,
235-
_TRITON_RETURN_VALUES,
236-
),
237-
],
238-
ids=[
239-
"djl",
240-
"tgi",
241-
"tei",
242-
"torchserve",
243-
"triton",
244-
],
245+
_SERVER_PARAMS,
246+
ids=_SERVER_IDS,
245247
)
246248
def test_preserves_user_provided_hf_model_id(
247249
build_method: str,
@@ -265,45 +267,8 @@ def test_preserves_user_provided_hf_model_id(
265267

266268
@pytest.mark.parametrize(
267269
"build_method, server_type, patch_targets, patch_rvs",
268-
[
269-
(
270-
"_build_for_djl",
271-
ModelServer.DJL_SERVING,
272-
_DJL_PATCHES,
273-
_DJL_RETURN_VALUES,
274-
),
275-
(
276-
"_build_for_tgi",
277-
ModelServer.TGI,
278-
_TGI_PATCHES,
279-
_TGI_RETURN_VALUES,
280-
),
281-
(
282-
"_build_for_tei",
283-
ModelServer.TEI,
284-
_TEI_PATCHES,
285-
_TEI_RETURN_VALUES,
286-
),
287-
(
288-
"_build_for_torchserve",
289-
ModelServer.TORCHSERVE,
290-
_TORCHSERVE_PATCHES,
291-
_TORCHSERVE_RETURN_VALUES,
292-
),
293-
(
294-
"_build_for_triton",
295-
ModelServer.TRITON,
296-
_TRITON_PATCHES,
297-
_TRITON_RETURN_VALUES,
298-
),
299-
],
300-
ids=[
301-
"djl",
302-
"tgi",
303-
"tei",
304-
"torchserve",
305-
"triton",
306-
],
270+
_SERVER_PARAMS,
271+
ids=_SERVER_IDS,
307272
)
308273
def test_sets_default_hf_model_id_when_not_provided(
309274
build_method: str,
@@ -322,12 +287,14 @@ def test_sets_default_hf_model_id_when_not_provided(
322287
)(builder)
323288
finally:
324289
_stop_patches(patchers)
325-
assert builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL
290+
assert (
291+
builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL
292+
)
326293

327294

328-
# ---------------------------------------------------------------
295+
# -----------------------------------------------------------
329296
# Transformers (MMS) — needs extra patches
330-
# ---------------------------------------------------------------
297+
# -----------------------------------------------------------
331298
class TestBuildForTransformersHfModelId:
332299
"""_build_for_transformers preserves HF_MODEL_ID."""
333300

@@ -347,7 +314,9 @@ def test_preserves_user_provided_s3_uri(
347314
)
348315
finally:
349316
_stop_patches(patchers)
350-
assert builder.env_vars["HF_MODEL_ID"] == S3_PATH
317+
assert (
318+
builder.env_vars["HF_MODEL_ID"] == S3_PATH
319+
)
351320

352321
def test_sets_default_when_not_provided(
353322
self,
@@ -366,27 +335,21 @@ def test_sets_default_when_not_provided(
366335
finally:
367336
_stop_patches(patchers)
368337
assert (
369-
builder.env_vars["HF_MODEL_ID"] == DEFAULT_MODEL
338+
builder.env_vars["HF_MODEL_ID"]
339+
== DEFAULT_MODEL
370340
)
371341

342+
@patch(f"{_MOD}.prepare_for_mms")
343+
@patch(f"{_MOD}.save_pkl")
372344
@patch(
373-
"sagemaker.serve.model_builder_servers.save_pkl"
374-
)
375-
@patch(
376-
"sagemaker.serve.model_builder_servers"
377-
"._get_model_config_properties_from_hf",
345+
f"{_MOD}._get_model_config_properties_from_hf",
378346
return_value={},
379347
)
380348
@patch(
381-
"sagemaker.serve.model_builder_servers"
382-
"._get_nb_instance",
349+
f"{_MOD}._get_nb_instance",
383350
return_value=None,
384351
)
385-
@patch(
386-
"sagemaker.serve.model_server"
387-
".multi_model_server"
388-
".prepare._create_dir_structure",
389-
)
352+
@patch(_MMS_PREP)
390353
@patch("os.makedirs")
391354
def test_preserves_with_inference_spec(
392355
self,
@@ -395,6 +358,7 @@ def test_preserves_with_inference_spec(
395358
_mock_nb: Mock,
396359
_mock_hf: Mock,
397360
_mock_pkl: Mock,
361+
_mock_mms: Mock,
398362
) -> None:
399363
"""User HF_MODEL_ID preserved with inference_spec."""
400364
builder = _create_mock_builder(
@@ -413,4 +377,6 @@ def test_preserves_with_inference_spec(
413377
_ModelBuilderServers._build_for_transformers(
414378
builder
415379
)
416-
assert builder.env_vars["HF_MODEL_ID"] == S3_PATH
380+
assert (
381+
builder.env_vars["HF_MODEL_ID"] == S3_PATH
382+
)

0 commit comments

Comments
 (0)