1616S3_PATH = "s3://my-bucket/models/Qwen/"
1717DEFAULT_MODEL = "Qwen/Qwen3-VL-4B-Instruct"
1818
19+ _MOD = "sagemaker.serve.model_builder_servers"
20+ _DJL_PREP = (
21+ "sagemaker.serve.model_server"
22+ ".djl_serving.prepare._create_dir_structure"
23+ )
24+ _TGI_PREP = (
25+ "sagemaker.serve.model_server"
26+ ".tgi.prepare._create_dir_structure"
27+ )
28+ _MMS_PREP = (
29+ "sagemaker.serve.model_server"
30+ ".multi_model_server.prepare._create_dir_structure"
31+ )
32+
1933
2034def _create_mock_builder (
2135 env_vars : Optional [Dict [str , str ]] = None ,
@@ -54,7 +68,9 @@ def _create_mock_builder(
5468 builder ._prepare_for_mode = Mock (
5569 return_value = ("s3://model-data" , None )
5670 )
57- builder ._create_model = Mock (return_value = Mock ())
71+ builder ._create_model = Mock (
72+ return_value = Mock ()
73+ )
5874 builder ._optimizing = False
5975 builder ._validate_djl_serving_sample_data = Mock ()
6076 builder ._validate_tgi_serving_sample_data = Mock ()
@@ -86,21 +102,15 @@ def mock_builder_with_s3() -> MagicMock:
86102 )
87103
88104
89- # -- Patch targets for each server type --------------------------
105+ # -- Patch targets for each server type ----------------------
90106
91107_DJL_PATCHES : List [str ] = [
92- "sagemaker.serve.model_builder_servers"
93- "._get_default_tensor_parallel_degree" ,
94- "sagemaker.serve.model_builder_servers"
95- "._get_gpu_info" ,
96- "sagemaker.serve.model_builder_servers"
97- "._get_nb_instance" ,
98- "sagemaker.serve.model_builder_servers"
99- "._get_default_djl_configurations" ,
100- "sagemaker.serve.model_builder_servers"
101- "._get_model_config_properties_from_hf" ,
102- "sagemaker.serve.model_server.djl_serving"
103- ".prepare._create_dir_structure" ,
108+ f"{ _MOD } ._get_default_tensor_parallel_degree" ,
109+ f"{ _MOD } ._get_gpu_info" ,
110+ f"{ _MOD } ._get_nb_instance" ,
111+ f"{ _MOD } ._get_default_djl_configurations" ,
112+ f"{ _MOD } ._get_model_config_properties_from_hf" ,
113+ _DJL_PREP ,
104114]
105115
106116_DJL_RETURN_VALUES = [
@@ -113,18 +123,12 @@ def mock_builder_with_s3() -> MagicMock:
113123]
114124
115125_TGI_PATCHES : List [str ] = [
116- "sagemaker.serve.model_builder_servers"
117- "._get_default_tensor_parallel_degree" ,
118- "sagemaker.serve.model_builder_servers"
119- "._get_gpu_info" ,
120- "sagemaker.serve.model_builder_servers"
121- "._get_nb_instance" ,
122- "sagemaker.serve.model_builder_servers"
123- "._get_default_tgi_configurations" ,
124- "sagemaker.serve.model_builder_servers"
125- "._get_model_config_properties_from_hf" ,
126- "sagemaker.serve.model_server.tgi"
127- ".prepare._create_dir_structure" ,
126+ f"{ _MOD } ._get_default_tensor_parallel_degree" ,
127+ f"{ _MOD } ._get_gpu_info" ,
128+ f"{ _MOD } ._get_nb_instance" ,
129+ f"{ _MOD } ._get_default_tgi_configurations" ,
130+ f"{ _MOD } ._get_model_config_properties_from_hf" ,
131+ _TGI_PREP ,
128132]
129133
130134_TGI_RETURN_VALUES = [
@@ -137,12 +141,9 @@ def mock_builder_with_s3() -> MagicMock:
137141]
138142
139143_TEI_PATCHES : List [str ] = [
140- "sagemaker.serve.model_builder_servers"
141- "._get_nb_instance" ,
142- "sagemaker.serve.model_builder_servers"
143- "._get_model_config_properties_from_hf" ,
144- "sagemaker.serve.model_server.tgi"
145- ".prepare._create_dir_structure" ,
144+ f"{ _MOD } ._get_nb_instance" ,
145+ f"{ _MOD } ._get_model_config_properties_from_hf" ,
146+ _TGI_PREP ,
146147]
147148
148149_TEI_RETURN_VALUES = [
@@ -152,24 +153,20 @@ def mock_builder_with_s3() -> MagicMock:
152153]
153154
154155_TORCHSERVE_PATCHES : List [str ] = [
155- "sagemaker.serve.model_builder_servers"
156- ".prepare_for_torchserve" ,
156+ f"{ _MOD } .prepare_for_torchserve" ,
157157]
158158
159159_TORCHSERVE_RETURN_VALUES = [
160- "mock-secret-key" , # prepare_for_torchserve
160+ "mock-secret-key" ,
161161]
162162
163163_TRITON_PATCHES : List [str ] = []
164164_TRITON_RETURN_VALUES : list = []
165165
166166_MMS_PATCHES : List [str ] = [
167- "sagemaker.serve.model_builder_servers"
168- "._get_nb_instance" ,
169- "sagemaker.serve.model_builder_servers"
170- "._get_model_config_properties_from_hf" ,
171- "sagemaker.serve.model_server.multi_model_server"
172- ".prepare._create_dir_structure" ,
167+ f"{ _MOD } ._get_nb_instance" ,
168+ f"{ _MOD } ._get_model_config_properties_from_hf" ,
169+ _MMS_PREP ,
173170]
174171
175172_MMS_RETURN_VALUES = [
@@ -198,50 +195,55 @@ def _stop_patches(patchers: List) -> None:
198195 p .stop ()
199196
200197
201- # ---------------------------------------------------------------
198+ # -----------------------------------------------------------
202199# Parametrised tests: preserve user-provided HF_MODEL_ID
203- # ---------------------------------------------------------------
200+ # -----------------------------------------------------------
201+ _SERVER_PARAMS = [
202+ (
203+ "_build_for_djl" ,
204+ ModelServer .DJL_SERVING ,
205+ _DJL_PATCHES ,
206+ _DJL_RETURN_VALUES ,
207+ ),
208+ (
209+ "_build_for_tgi" ,
210+ ModelServer .TGI ,
211+ _TGI_PATCHES ,
212+ _TGI_RETURN_VALUES ,
213+ ),
214+ (
215+ "_build_for_tei" ,
216+ ModelServer .TEI ,
217+ _TEI_PATCHES ,
218+ _TEI_RETURN_VALUES ,
219+ ),
220+ (
221+ "_build_for_torchserve" ,
222+ ModelServer .TORCHSERVE ,
223+ _TORCHSERVE_PATCHES ,
224+ _TORCHSERVE_RETURN_VALUES ,
225+ ),
226+ (
227+ "_build_for_triton" ,
228+ ModelServer .TRITON ,
229+ _TRITON_PATCHES ,
230+ _TRITON_RETURN_VALUES ,
231+ ),
232+ ]
233+
234+ _SERVER_IDS = [
235+ "djl" ,
236+ "tgi" ,
237+ "tei" ,
238+ "torchserve" ,
239+ "triton" ,
240+ ]
241+
242+
204243@pytest .mark .parametrize (
205244 "build_method, server_type, patch_targets, patch_rvs" ,
206- [
207- (
208- "_build_for_djl" ,
209- ModelServer .DJL_SERVING ,
210- _DJL_PATCHES ,
211- _DJL_RETURN_VALUES ,
212- ),
213- (
214- "_build_for_tgi" ,
215- ModelServer .TGI ,
216- _TGI_PATCHES ,
217- _TGI_RETURN_VALUES ,
218- ),
219- (
220- "_build_for_tei" ,
221- ModelServer .TEI ,
222- _TEI_PATCHES ,
223- _TEI_RETURN_VALUES ,
224- ),
225- (
226- "_build_for_torchserve" ,
227- ModelServer .TORCHSERVE ,
228- _TORCHSERVE_PATCHES ,
229- _TORCHSERVE_RETURN_VALUES ,
230- ),
231- (
232- "_build_for_triton" ,
233- ModelServer .TRITON ,
234- _TRITON_PATCHES ,
235- _TRITON_RETURN_VALUES ,
236- ),
237- ],
238- ids = [
239- "djl" ,
240- "tgi" ,
241- "tei" ,
242- "torchserve" ,
243- "triton" ,
244- ],
245+ _SERVER_PARAMS ,
246+ ids = _SERVER_IDS ,
245247)
246248def test_preserves_user_provided_hf_model_id (
247249 build_method : str ,
@@ -265,45 +267,8 @@ def test_preserves_user_provided_hf_model_id(
265267
266268@pytest .mark .parametrize (
267269 "build_method, server_type, patch_targets, patch_rvs" ,
268- [
269- (
270- "_build_for_djl" ,
271- ModelServer .DJL_SERVING ,
272- _DJL_PATCHES ,
273- _DJL_RETURN_VALUES ,
274- ),
275- (
276- "_build_for_tgi" ,
277- ModelServer .TGI ,
278- _TGI_PATCHES ,
279- _TGI_RETURN_VALUES ,
280- ),
281- (
282- "_build_for_tei" ,
283- ModelServer .TEI ,
284- _TEI_PATCHES ,
285- _TEI_RETURN_VALUES ,
286- ),
287- (
288- "_build_for_torchserve" ,
289- ModelServer .TORCHSERVE ,
290- _TORCHSERVE_PATCHES ,
291- _TORCHSERVE_RETURN_VALUES ,
292- ),
293- (
294- "_build_for_triton" ,
295- ModelServer .TRITON ,
296- _TRITON_PATCHES ,
297- _TRITON_RETURN_VALUES ,
298- ),
299- ],
300- ids = [
301- "djl" ,
302- "tgi" ,
303- "tei" ,
304- "torchserve" ,
305- "triton" ,
306- ],
270+ _SERVER_PARAMS ,
271+ ids = _SERVER_IDS ,
307272)
308273def test_sets_default_hf_model_id_when_not_provided (
309274 build_method : str ,
@@ -322,12 +287,14 @@ def test_sets_default_hf_model_id_when_not_provided(
322287 )(builder )
323288 finally :
324289 _stop_patches (patchers )
325- assert builder .env_vars ["HF_MODEL_ID" ] == DEFAULT_MODEL
290+ assert (
291+ builder .env_vars ["HF_MODEL_ID" ] == DEFAULT_MODEL
292+ )
326293
327294
328- # ---------------------------------------------------------------
295+ # -----------------------------------------------------------
329296# Transformers (MMS) — needs extra patches
330- # ---------------------------------------------------------------
297+ # -----------------------------------------------------------
331298class TestBuildForTransformersHfModelId :
332299 """_build_for_transformers preserves HF_MODEL_ID."""
333300
@@ -347,7 +314,9 @@ def test_preserves_user_provided_s3_uri(
347314 )
348315 finally :
349316 _stop_patches (patchers )
350- assert builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
317+ assert (
318+ builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
319+ )
351320
352321 def test_sets_default_when_not_provided (
353322 self ,
@@ -366,27 +335,21 @@ def test_sets_default_when_not_provided(
366335 finally :
367336 _stop_patches (patchers )
368337 assert (
369- builder .env_vars ["HF_MODEL_ID" ] == DEFAULT_MODEL
338+ builder .env_vars ["HF_MODEL_ID" ]
339+ == DEFAULT_MODEL
370340 )
371341
342+ @patch (f"{ _MOD } .prepare_for_mms" )
343+ @patch (f"{ _MOD } .save_pkl" )
372344 @patch (
373- "sagemaker.serve.model_builder_servers.save_pkl"
374- )
375- @patch (
376- "sagemaker.serve.model_builder_servers"
377- "._get_model_config_properties_from_hf" ,
345+ f"{ _MOD } ._get_model_config_properties_from_hf" ,
378346 return_value = {},
379347 )
380348 @patch (
381- "sagemaker.serve.model_builder_servers"
382- "._get_nb_instance" ,
349+ f"{ _MOD } ._get_nb_instance" ,
383350 return_value = None ,
384351 )
385- @patch (
386- "sagemaker.serve.model_server"
387- ".multi_model_server"
388- ".prepare._create_dir_structure" ,
389- )
352+ @patch (_MMS_PREP )
390353 @patch ("os.makedirs" )
391354 def test_preserves_with_inference_spec (
392355 self ,
@@ -395,6 +358,7 @@ def test_preserves_with_inference_spec(
395358 _mock_nb : Mock ,
396359 _mock_hf : Mock ,
397360 _mock_pkl : Mock ,
361+ _mock_mms : Mock ,
398362 ) -> None :
399363 """User HF_MODEL_ID preserved with inference_spec."""
400364 builder = _create_mock_builder (
@@ -413,4 +377,6 @@ def test_preserves_with_inference_spec(
413377 _ModelBuilderServers ._build_for_transformers (
414378 builder
415379 )
416- assert builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
380+ assert (
381+ builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
382+ )
0 commit comments