Skip to content

Commit 228240a

Browse files
committed
fix: bug: ModelBuilder overwrites user-provided HF_MODEL_ID for DJL Serving, preventi (5529)
1 parent 272fdbf commit 228240a

File tree

2 files changed

+282
-7
lines changed

2 files changed

+282
-7
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _build_for_torchserve(self) -> Model:
136136
if isinstance(self.model, str):
137137
# Configure HuggingFace model support
138138
if not self._is_jumpstart_model_id():
139-
self.env_vars.update({"HF_MODEL_ID": self.model})
139+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
140140

141141
# Add HuggingFace token if available
142142
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
@@ -212,7 +212,7 @@ def _build_for_tgi(self) -> Model:
212212

213213
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
214214
# Configure HuggingFace model for TGI
215-
self.env_vars.update({"HF_MODEL_ID": self.model})
215+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
216216

217217
self.hf_model_config = _get_model_config_properties_from_hf(
218218
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
@@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model:
320320

321321
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
322322
# Configure HuggingFace model for DJL
323-
self.env_vars.update({"HF_MODEL_ID": self.model})
323+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
324324

325325
# Get model configuration for DJL optimization
326326
self.hf_model_config = _get_model_config_properties_from_hf(
@@ -426,7 +426,7 @@ def _build_for_triton(self) -> Model:
426426
self.env_vars.update({"HF_TASK": model_task})
427427

428428
# Configure HuggingFace authentication
429-
self.env_vars.update({"HF_MODEL_ID": self.model})
429+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
430430
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
431431
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
432432

@@ -532,7 +532,7 @@ def _build_for_tei(self) -> Model:
532532

533533
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
534534
# Configure HuggingFace model for TEI
535-
self.env_vars.update({"HF_MODEL_ID": self.model})
535+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
536536

537537
self.hf_model_config = _get_model_config_properties_from_hf(
538538
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
@@ -676,7 +676,7 @@ def _build_for_transformers(self) -> Model:
676676
if self.inference_spec is not None:
677677
hf_model_id = self.inference_spec.get_model()
678678
if isinstance(hf_model_id, str): # Only if it's a valid HF model ID
679-
self.env_vars.update({"HF_MODEL_ID": hf_model_id})
679+
self.env_vars.setdefault("HF_MODEL_ID", hf_model_id)
680680
# Get HF config only for string model IDs
681681
if hasattr(self.env_vars, "HF_API_TOKEN"):
682682
self.hf_model_config = _get_model_config_properties_from_hf(
@@ -687,7 +687,7 @@ def _build_for_transformers(self) -> Model:
687687
hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
688688
)
689689
elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string
690-
self.env_vars.update({"HF_MODEL_ID": self.model})
690+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
691691
# Get HF config for string model IDs
692692
if hasattr(self.env_vars, "HF_API_TOKEN"):
693693
self.hf_model_config = _get_model_config_properties_from_hf(
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
"""Unit tests to verify HF_MODEL_ID is not overwritten when user provides it."""
2+
import unittest
3+
from unittest.mock import Mock, patch, MagicMock, PropertyMock
4+
5+
from sagemaker.serve.model_builder_servers import _ModelBuilderServers
6+
from sagemaker.serve.utils.types import ModelServer
7+
from sagemaker.serve.mode.function_pointers import Mode
8+
9+
10+
def _create_mock_builder(env_vars=None, model="Qwen/Qwen3-VL-4B-Instruct"):
    """Build a ``MagicMock`` that stands in for a ``_ModelBuilderServers`` instance.

    Args:
        env_vars: Dict to attach as ``builder.env_vars``. ``None`` yields a new
            empty dict; a supplied dict is attached as-is (it is not copied).
        model: Value assigned to ``builder.model``.

    Returns:
        The mock builder with every attribute listed below pre-populated.
    """
    # Prepare the schema builder up front so it can live in the table below.
    schema_builder = MagicMock()
    schema_builder.sample_input = {"inputs": "Hello", "parameters": {}}

    attributes = {
        "model": model,
        "env_vars": {} if env_vars is None else env_vars,
        "model_path": "/tmp/test_model_path",
        "mode": Mode.SAGEMAKER_ENDPOINT,
        "model_server": ModelServer.DJL_SERVING,
        "secret_key": "",
        "s3_upload_path": None,
        "s3_model_data_url": None,
        "shared_libs": [],
        "dependencies": {},
        "image_uri": "test-image-uri",
        "instance_type": "ml.g5.2xlarge",
        "sagemaker_session": Mock(),
        "schema_builder": schema_builder,
        "inference_spec": None,
        "hf_model_config": {},
        "model_data_download_timeout": None,
        "_user_provided_instance_type": True,
        "_is_jumpstart_model_id": Mock(return_value=False),
        "_auto_detect_image_uri": Mock(),
        "_prepare_for_mode": Mock(return_value=("s3://model-data", None)),
        "_create_model": Mock(return_value=Mock()),
        "_optimizing": False,
        "_validate_djl_serving_sample_data": Mock(),
        "_validate_tgi_serving_sample_data": Mock(),
        "_validate_for_triton": Mock(),
        "get_huggingface_model_metadata": Mock(
            return_value={"pipeline_tag": "text-generation"}
        ),
        "role_arn": "arn:aws:iam::123456789012:role/SageMakerRole",
    }

    builder = MagicMock(spec=_ModelBuilderServers)
    for attr_name, attr_value in attributes.items():
        setattr(builder, attr_name, attr_value)
    return builder
43+
44+
45+
@patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
@patch("sagemaker.serve.model_builder_servers._get_default_djl_configurations")
@patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
@patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1)
@patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1)
class TestDjlPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_djl`` treats the ``HF_MODEL_ID`` environment entry.

    The patches are declared once at class level; ``patch`` wraps every
    ``test*`` method, so each one receives the mocks bottom-up in the same
    order the per-method decorators would have supplied them.
    """

    def test_preserves_user_provided_s3_uri(
        self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config
    ):
        """An S3 URI already stored under HF_MODEL_ID must still be there after the build."""
        mock_djl_config.return_value = ({}, 256)
        mock_hf_config.return_value = {}

        user_value = "s3://my-bucket/models/Qwen/"
        mock_builder = _create_mock_builder(env_vars={"HF_MODEL_ID": user_value})

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_djl(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], user_value)

    def test_sets_hf_model_id_when_not_provided(
        self, mock_tp, mock_gpu, mock_nb, mock_djl_config, mock_hf_config
    ):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_djl_config.return_value = ({}, 256)
        mock_hf_config.return_value = {}

        mock_builder = _create_mock_builder(env_vars={})

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.djl_serving.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_djl(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")
82+
83+
84+
@patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
@patch("sagemaker.serve.model_builder_servers._get_default_tgi_configurations")
@patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
@patch("sagemaker.serve.model_builder_servers._get_gpu_info", return_value=1)
@patch("sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree", return_value=1)
class TestTgiPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_tgi`` treats the ``HF_MODEL_ID`` environment entry.

    The patches are declared once at class level; ``patch`` wraps every
    ``test*`` method, so each one receives the mocks bottom-up in the same
    order the per-method decorators would have supplied them.
    """

    def test_preserves_user_provided_s3_uri(
        self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config
    ):
        """An S3 URI already stored under HF_MODEL_ID must still be there after the build."""
        mock_tgi_config.return_value = ({}, 256)
        mock_hf_config.return_value = {}

        user_value = "s3://my-bucket/models/Qwen/"
        mock_builder = _create_mock_builder(env_vars={"HF_MODEL_ID": user_value})
        mock_builder.model_server = ModelServer.TGI

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.tgi.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_tgi(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], user_value)

    def test_sets_hf_model_id_when_not_provided(
        self, mock_tp, mock_gpu, mock_nb, mock_tgi_config, mock_hf_config
    ):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_tgi_config.return_value = ({}, 256)
        mock_hf_config.return_value = {}

        mock_builder = _create_mock_builder(env_vars={})
        mock_builder.model_server = ModelServer.TGI

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.tgi.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_tgi(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")
123+
124+
125+
@patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
@patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
class TestTeiPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_tei`` treats the ``HF_MODEL_ID`` environment entry.

    Both patches are declared once at class level; ``patch`` wraps every
    ``test*`` method, so each one receives ``mock_nb`` then ``mock_hf_config``.
    """

    def test_preserves_user_provided_s3_uri(self, mock_nb, mock_hf_config):
        """An S3 URI already stored under HF_MODEL_ID must still be there after the build."""
        mock_hf_config.return_value = {}

        user_value = "s3://my-bucket/models/embedding-model/"
        mock_builder = _create_mock_builder(env_vars={"HF_MODEL_ID": user_value})
        mock_builder.model_server = ModelServer.TEI

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.tgi.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_tei(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], user_value)

    def test_sets_hf_model_id_when_not_provided(self, mock_nb, mock_hf_config):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_hf_config.return_value = {}

        mock_builder = _create_mock_builder(env_vars={})
        mock_builder.model_server = ModelServer.TEI

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.tgi.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_tei(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")
156+
157+
158+
class TestTorchservePreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_torchserve`` treats the ``HF_MODEL_ID`` environment entry."""

    @staticmethod
    def _env_vars_after_build(env_vars):
        """Run ``_build_for_torchserve`` on a mock builder and return its ``env_vars``."""
        mock_builder = _create_mock_builder(env_vars=env_vars)
        mock_builder.model_server = ModelServer.TORCHSERVE
        mock_builder.mode = Mode.SAGEMAKER_ENDPOINT
        mock_builder._save_model_inference_spec = Mock()

        _ModelBuilderServers._build_for_torchserve(mock_builder)
        return mock_builder.env_vars

    def test_preserves_user_provided_s3_uri(self):
        """An S3 URI already stored under HF_MODEL_ID must still be there after the build."""
        user_value = "s3://my-bucket/models/my-model/"

        resulting_env = self._env_vars_after_build({"HF_MODEL_ID": user_value})

        self.assertEqual(resulting_env["HF_MODEL_ID"], user_value)

    def test_sets_hf_model_id_when_not_provided(self):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        resulting_env = self._env_vars_after_build({})

        self.assertEqual(resulting_env["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")
183+
184+
185+
class TestTritonPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_triton`` treats the ``HF_MODEL_ID`` environment entry."""

    @staticmethod
    def _env_vars_after_build(env_vars):
        """Run ``_build_for_triton`` on a mock builder and return its ``env_vars``."""
        mock_builder = _create_mock_builder(env_vars=env_vars)
        mock_builder.model_server = ModelServer.TRITON
        # Stub the Triton-specific collaborators on the mock builder.
        mock_builder._save_inference_spec = Mock()
        mock_builder._prepare_for_triton = Mock()
        mock_builder._auto_detect_image_for_triton = Mock()

        _ModelBuilderServers._build_for_triton(mock_builder)
        return mock_builder.env_vars

    def test_preserves_user_provided_s3_uri(self):
        """An S3 URI already stored under HF_MODEL_ID must still be there after the build."""
        user_value = "s3://my-bucket/models/my-model/"

        resulting_env = self._env_vars_after_build({"HF_MODEL_ID": user_value})

        self.assertEqual(resulting_env["HF_MODEL_ID"], user_value)

    def test_sets_hf_model_id_when_not_provided(self):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        resulting_env = self._env_vars_after_build({})

        self.assertEqual(resulting_env["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")
212+
213+
214+
@patch("sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf")
@patch("sagemaker.serve.model_builder_servers._get_nb_instance", return_value=None)
class TestTransformersPreservesHfModelId(unittest.TestCase):
    """Check how ``_build_for_transformers`` treats the ``HF_MODEL_ID`` environment entry.

    The two patches shared by every test are declared at class level;
    ``patch`` wraps each ``test*`` method, so the mocks arrive after any
    mock injected by a method-level decorator.
    """

    @staticmethod
    def _mms_builder(env_vars):
        """Return a mock builder configured for the MMS server in endpoint mode."""
        mock_builder = _create_mock_builder(env_vars=env_vars)
        mock_builder.model_server = ModelServer.MMS
        mock_builder.mode = Mode.SAGEMAKER_ENDPOINT
        mock_builder.model_data_download_timeout = None
        return mock_builder

    def test_preserves_user_provided_s3_uri_with_model_string(self, mock_nb, mock_hf_config):
        """A user-supplied S3 URI stays in HF_MODEL_ID when the model is a string."""
        mock_hf_config.return_value = {}

        user_value = "s3://my-bucket/models/my-model/"
        mock_builder = self._mms_builder({"HF_MODEL_ID": user_value})

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_transformers(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], user_value)

    def test_sets_hf_model_id_when_not_provided_with_model_string(self, mock_nb, mock_hf_config):
        """With no HF_MODEL_ID supplied, the build fills it in from the builder's model."""
        mock_hf_config.return_value = {}

        mock_builder = self._mms_builder({})

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"
        )
        with dir_structure_patch:
            _ModelBuilderServers._build_for_transformers(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], "Qwen/Qwen3-VL-4B-Instruct")

    @patch("sagemaker.serve.model_builder_servers.save_pkl")
    def test_preserves_user_provided_hf_model_id_with_inference_spec(
        self, mock_pkl, mock_nb, mock_hf_config
    ):
        """A user-supplied HF_MODEL_ID stays put even when the inference spec reports a model ID."""
        mock_hf_config.return_value = {}

        user_value = "s3://my-bucket/models/my-model/"
        mock_builder = self._mms_builder({"HF_MODEL_ID": user_value})
        mock_builder.model = None  # No model string, using inference_spec
        mock_builder.inference_spec = Mock()
        mock_builder.inference_spec.get_model.return_value = "some-hf-model-id"
        mock_builder._is_jumpstart_model_id = Mock(return_value=False)

        dir_structure_patch = patch(
            "sagemaker.serve.model_server.multi_model_server.prepare._create_dir_structure"
        )
        with dir_structure_patch, patch("os.makedirs"):
            _ModelBuilderServers._build_for_transformers(mock_builder)

        self.assertEqual(mock_builder.env_vars["HF_MODEL_ID"], user_value)
272+
273+
274+
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
275+
unittest.main()

0 commit comments

Comments
 (0)