66
77import pytest
88
9- from sagemaker .serve .model_builder_servers import _ModelBuilderServers
9+ from sagemaker .serve .model_builder_servers import (
10+ _ModelBuilderServers ,
11+ )
1012from sagemaker .serve .utils .types import ModelServer
1113from sagemaker .serve .mode .function_pointers import Mode
1214
@@ -22,7 +24,9 @@ def _create_mock_builder(
2224 """Create a mock builder with common attributes set."""
2325 builder = MagicMock (spec = _ModelBuilderServers )
2426 builder .model = model
25- builder .env_vars = env_vars if env_vars is not None else {}
27+ builder .env_vars = (
28+ env_vars if env_vars is not None else {}
29+ )
2630 builder .model_path = "/tmp/test_model_path"
2731 builder .mode = Mode .SAGEMAKER_ENDPOINT
2832 builder .model_server = ModelServer .DJL_SERVING
@@ -43,7 +47,9 @@ def _create_mock_builder(
4347 builder .hf_model_config = {}
4448 builder .model_data_download_timeout = None
4549 builder ._user_provided_instance_type = True
46- builder ._is_jumpstart_model_id = Mock (return_value = False )
50+ builder ._is_jumpstart_model_id = Mock (
51+ return_value = False
52+ )
4753 builder ._auto_detect_image_uri = Mock ()
4854 builder ._prepare_for_mode = Mock (
4955 return_value = ("s3://model-data" , None )
@@ -68,25 +74,27 @@ def _create_mock_builder(
6874
6975@pytest .fixture
7076def mock_builder () -> MagicMock :
71- """Create a mock builder with default (empty) env_vars."""
77+ """Create a mock builder with default env_vars."""
7278 return _create_mock_builder (env_vars = {})
7379
7480
7581@pytest .fixture
7682def mock_builder_with_s3 () -> MagicMock :
77- """Create a mock builder with user-provided S3 HF_MODEL_ID."""
83+ """Mock builder with user-provided S3 HF_MODEL_ID."""
7884 return _create_mock_builder (
7985 env_vars = {"HF_MODEL_ID" : S3_PATH }
8086 )
8187
8288
83- # -- Patch sets for each server type -------- --------------------------
89+ # -- Patch targets for each server type --------------------------
8490
8591_DJL_PATCHES : List [str ] = [
8692 "sagemaker.serve.model_builder_servers"
8793 "._get_default_tensor_parallel_degree" ,
88- "sagemaker.serve.model_builder_servers._get_gpu_info" ,
89- "sagemaker.serve.model_builder_servers._get_nb_instance" ,
94+ "sagemaker.serve.model_builder_servers"
95+ "._get_gpu_info" ,
96+ "sagemaker.serve.model_builder_servers"
97+ "._get_nb_instance" ,
9098 "sagemaker.serve.model_builder_servers"
9199 "._get_default_djl_configurations" ,
92100 "sagemaker.serve.model_builder_servers"
@@ -96,19 +104,21 @@ def mock_builder_with_s3() -> MagicMock:
96104]
97105
98106_DJL_RETURN_VALUES = [
99- 1 , # tensor_parallel_degree
100- 1 , # gpu_info
101- None , # nb_instance
107+ 1 , # tensor_parallel_degree
108+ 1 , # gpu_info
109+ None , # nb_instance
102110 ({}, 256 ), # djl_configurations
103- {}, # hf_model_config
104- None , # _create_dir_structure
111+ {}, # hf_model_config
112+ None , # _create_dir_structure
105113]
106114
107115_TGI_PATCHES : List [str ] = [
108116 "sagemaker.serve.model_builder_servers"
109117 "._get_default_tensor_parallel_degree" ,
110- "sagemaker.serve.model_builder_servers._get_gpu_info" ,
111- "sagemaker.serve.model_builder_servers._get_nb_instance" ,
118+ "sagemaker.serve.model_builder_servers"
119+ "._get_gpu_info" ,
120+ "sagemaker.serve.model_builder_servers"
121+ "._get_nb_instance" ,
112122 "sagemaker.serve.model_builder_servers"
113123 "._get_default_tgi_configurations" ,
114124 "sagemaker.serve.model_builder_servers"
@@ -118,16 +128,17 @@ def mock_builder_with_s3() -> MagicMock:
118128]
119129
120130_TGI_RETURN_VALUES = [
121- 1 , # tensor_parallel_degree
122- 1 , # gpu_info
123- None , # nb_instance
131+ 1 , # tensor_parallel_degree
132+ 1 , # gpu_info
133+ None , # nb_instance
124134 ({}, 256 ), # tgi_configurations
125- {}, # hf_model_config
126- None , # _create_dir_structure
135+ {}, # hf_model_config
136+ None , # _create_dir_structure
127137]
128138
129139_TEI_PATCHES : List [str ] = [
130- "sagemaker.serve.model_builder_servers._get_nb_instance" ,
140+ "sagemaker.serve.model_builder_servers"
141+ "._get_nb_instance" ,
131142 "sagemaker.serve.model_builder_servers"
132143 "._get_model_config_properties_from_hf" ,
133144 "sagemaker.serve.model_server.tgi"
@@ -140,8 +151,21 @@ def mock_builder_with_s3() -> MagicMock:
140151 None , # _create_dir_structure
141152]
142153
154+ _TORCHSERVE_PATCHES : List [str ] = [
155+ "sagemaker.serve.model_builder_servers"
156+ ".prepare_for_torchserve" ,
157+ ]
158+
159+ _TORCHSERVE_RETURN_VALUES = [
160+ "mock-secret-key" , # prepare_for_torchserve
161+ ]
162+
163+ _TRITON_PATCHES : List [str ] = []
164+ _TRITON_RETURN_VALUES : list = []
165+
143166_MMS_PATCHES : List [str ] = [
144- "sagemaker.serve.model_builder_servers._get_nb_instance" ,
167+ "sagemaker.serve.model_builder_servers"
168+ "._get_nb_instance" ,
145169 "sagemaker.serve.model_builder_servers"
146170 "._get_model_config_properties_from_hf" ,
147171 "sagemaker.serve.model_server.multi_model_server"
@@ -201,14 +225,14 @@ def _stop_patches(patchers: List) -> None:
201225 (
202226 "_build_for_torchserve" ,
203227 ModelServer .TORCHSERVE ,
204- [] ,
205- [] ,
228+ _TORCHSERVE_PATCHES ,
229+ _TORCHSERVE_RETURN_VALUES ,
206230 ),
207231 (
208232 "_build_for_triton" ,
209233 ModelServer .TRITON ,
210- [] ,
211- [] ,
234+ _TRITON_PATCHES ,
235+ _TRITON_RETURN_VALUES ,
212236 ),
213237 ],
214238 ids = [
@@ -231,7 +255,9 @@ def test_preserves_user_provided_hf_model_id(
231255 builder .model_server = server_type
232256 patchers = _apply_patches (patch_targets , patch_rvs )
233257 try :
234- getattr (_ModelBuilderServers , build_method )(builder )
258+ getattr (
259+ _ModelBuilderServers , build_method
260+ )(builder )
235261 finally :
236262 _stop_patches (patchers )
237263 assert builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
@@ -261,14 +287,14 @@ def test_preserves_user_provided_hf_model_id(
261287 (
262288 "_build_for_torchserve" ,
263289 ModelServer .TORCHSERVE ,
264- [] ,
265- [] ,
290+ _TORCHSERVE_PATCHES ,
291+ _TORCHSERVE_RETURN_VALUES ,
266292 ),
267293 (
268294 "_build_for_triton" ,
269295 ModelServer .TRITON ,
270- [] ,
271- [] ,
296+ _TRITON_PATCHES ,
297+ _TRITON_RETURN_VALUES ,
272298 ),
273299 ],
274300 ids = [
@@ -291,21 +317,25 @@ def test_sets_default_hf_model_id_when_not_provided(
291317 builder .model_server = server_type
292318 patchers = _apply_patches (patch_targets , patch_rvs )
293319 try :
294- getattr (_ModelBuilderServers , build_method )(builder )
320+ getattr (
321+ _ModelBuilderServers , build_method
322+ )(builder )
295323 finally :
296324 _stop_patches (patchers )
297325 assert builder .env_vars ["HF_MODEL_ID" ] == DEFAULT_MODEL
298326
299327
300328# ---------------------------------------------------------------
301- # Transformers (MMS) — needs extra patches for _create_dir_structure
329+ # Transformers (MMS) — needs extra patches
302330# ---------------------------------------------------------------
303331class TestBuildForTransformersHfModelId :
304- """_build_for_transformers preserves user-provided HF_MODEL_ID."""
332+ """_build_for_transformers preserves HF_MODEL_ID."""
305333
306334 def test_preserves_user_provided_s3_uri (
307- self , mock_builder_with_s3 : MagicMock
335+ self ,
336+ mock_builder_with_s3 : MagicMock ,
308337 ) -> None :
338+ """User S3 URI is preserved."""
309339 builder = mock_builder_with_s3
310340 builder .model_server = ModelServer .MMS
311341 patchers = _apply_patches (
@@ -320,8 +350,10 @@ def test_preserves_user_provided_s3_uri(
320350 assert builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
321351
322352 def test_sets_default_when_not_provided (
323- self , mock_builder : MagicMock
353+ self ,
354+ mock_builder : MagicMock ,
324355 ) -> None :
356+ """HF_MODEL_ID defaults to self.model."""
325357 builder = mock_builder
326358 builder .model_server = ModelServer .MMS
327359 patchers = _apply_patches (
@@ -333,9 +365,13 @@ def test_sets_default_when_not_provided(
333365 )
334366 finally :
335367 _stop_patches (patchers )
336- assert builder .env_vars ["HF_MODEL_ID" ] == DEFAULT_MODEL
368+ assert (
369+ builder .env_vars ["HF_MODEL_ID" ] == DEFAULT_MODEL
370+ )
337371
338- @patch ("sagemaker.serve.model_builder_servers.save_pkl" )
372+ @patch (
373+ "sagemaker.serve.model_builder_servers.save_pkl"
374+ )
339375 @patch (
340376 "sagemaker.serve.model_builder_servers"
341377 "._get_model_config_properties_from_hf" ,
@@ -347,7 +383,8 @@ def test_sets_default_when_not_provided(
347383 return_value = None ,
348384 )
349385 @patch (
350- "sagemaker.serve.model_server.multi_model_server"
386+ "sagemaker.serve.model_server"
387+ ".multi_model_server"
351388 ".prepare._create_dir_structure" ,
352389 )
353390 @patch ("os.makedirs" )
@@ -359,7 +396,7 @@ def test_preserves_with_inference_spec(
359396 _mock_hf : Mock ,
360397 _mock_pkl : Mock ,
361398 ) -> None :
362- """User-provided HF_MODEL_ID preserved with inference_spec."""
399+ """User HF_MODEL_ID preserved with inference_spec."""
363400 builder = _create_mock_builder (
364401 env_vars = {"HF_MODEL_ID" : S3_PATH }
365402 )
@@ -373,5 +410,7 @@ def test_preserves_with_inference_spec(
373410 builder ._is_jumpstart_model_id = Mock (
374411 return_value = False
375412 )
376- _ModelBuilderServers ._build_for_transformers (builder )
413+ _ModelBuilderServers ._build_for_transformers (
414+ builder
415+ )
377416 assert builder .env_vars ["HF_MODEL_ID" ] == S3_PATH
0 commit comments