Skip to content

Commit 857485f

Browse files
committed
fix: ModelBuilder with source_code + DJL LMI: /opt/ml/model becomes read-only, breaki… [title truncated] (5698)
1 parent 272fdbf commit 857485f

File tree

3 files changed

+325
-2
lines changed

3 files changed

+325
-2
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _build_for_djl(self) -> Model:
320320

321321
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
322322
# Configure HuggingFace model for DJL
323-
self.env_vars.update({"HF_MODEL_ID": self.model})
323+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
324324

325325
# Get model configuration for DJL optimization
326326
self.hf_model_config = _get_model_config_properties_from_hf(
@@ -345,7 +345,9 @@ def _build_for_djl(self) -> Model:
345345
"SERVING_MAX_WORKERS": "1",
346346
"OPTION_MODEL_LOADING_TIMEOUT": "240",
347347
"OPTION_PREDICT_TIMEOUT": "60",
348-
"TENSOR_PARALLEL_DEGREE": "1" # Default, will be overridden below
348+
"TENSOR_PARALLEL_DEGREE": "1", # Default, will be overridden below
349+
"HF_HOME": "/tmp",
350+
"HUGGINGFACE_HUB_CACHE": "/tmp",
349351
}
350352

351353
# Add HuggingFace authentication
@@ -370,6 +372,9 @@ def _build_for_djl(self) -> Model:
370372
# Cache management based on mode
371373
if self.mode in LOCAL_MODES:
372374
self.env_vars.update({"HF_HUB_OFFLINE": "1"})
375+
else:
376+
self.env_vars["HF_HOME"] = "/tmp"
377+
self.env_vars["HUGGINGFACE_HUB_CACHE"] = "/tmp"
373378

374379
# GPU-based tensor parallel calculation for SAGEMAKER_ENDPOINT mode
375380
if self.mode == Mode.SAGEMAKER_ENDPOINT:

sagemaker-serve/tests/unit/servers/__init__.py

Whitespace-only changes.
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
"""Tests for DJL builder HF cache environment variables and HF_MODEL_ID handling.
2+
3+
Verifies that _build_for_djl() correctly:
4+
- Sets HF_HOME and HUGGINGFACE_HUB_CACHE to /tmp for writable cache
5+
- Preserves user-provided HF_MODEL_ID values (uses setdefault)
6+
- Sets HF_MODEL_ID when not provided by user
7+
- Sets HF_HUB_OFFLINE in local modes
8+
"""
9+
10+
import unittest
11+
from unittest.mock import Mock, patch, MagicMock
12+
import tempfile
13+
import os
14+
import shutil
15+
16+
from sagemaker.serve.model_builder import ModelBuilder
17+
from sagemaker.serve.utils.types import ModelServer
18+
from sagemaker.serve.mode.function_pointers import Mode
19+
from sagemaker.core.resources import Model
20+
21+
22+
def _mock_sagemaker_session():
23+
"""Create a mock SageMaker session."""
24+
session = Mock()
25+
session.boto_region_name = "us-east-1"
26+
session.sagemaker_config = {}
27+
session.default_bucket.return_value = "mock-bucket"
28+
session.upload_data.return_value = "s3://mock-bucket/model.tar.gz"
29+
return session
30+
31+
32+
MOCK_ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole"
33+
MOCK_IMAGE_URI = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.36.0-lmi22.0.0-cu129"
34+
MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]}
35+
36+
37+
class TestDjlHfCacheEnv(unittest.TestCase):
    """Test DJL builder HF cache environment variable handling.

    Every test constructs a ``ModelBuilder`` for ``ModelServer.DJL_SERVING``,
    calls ``_build_for_djl()`` with its collaborators patched out, and then
    inspects ``builder.env_vars``.
    """

    _SERVERS = "sagemaker.serve.model_builder_servers"
    _BUILDER = "sagemaker.serve.model_builder.ModelBuilder"
    _MODEL_ID = "chromadb/context-1"

    # Collaborators patched for every test, keyed by a short alias.
    _COMMON_PATCHES = {
        "nb": f"{_SERVERS}._get_nb_instance",
        "djl_config": f"{_SERVERS}._get_default_djl_configurations",
        "hf_config": f"{_SERVERS}._get_model_config_properties_from_hf",
        "is_js": f"{_BUILDER}._is_jumpstart_model_id",
        "validate": f"{_BUILDER}._validate_djl_serving_sample_data",
        "auto_detect": f"{_BUILDER}._auto_detect_image_uri",
        "prepare": f"{_BUILDER}._prepare_for_mode",
        "create": f"{_BUILDER}._create_model",
    }

    # GPU helpers patched only by the SAGEMAKER_ENDPOINT tests; the local-mode
    # test leaves them untouched.
    _ENDPOINT_PATCHES = {
        "tp_degree": f"{_SERVERS}._get_default_tensor_parallel_degree",
        "gpu_info": f"{_SERVERS}._get_gpu_info",
    }

    def setUp(self):
        """Set up test fixtures and start the patches shared by all tests."""
        self.mock_session = _mock_sagemaker_session()
        self.temp_dir = tempfile.mkdtemp()
        # addCleanup runs even when a later setUp step raises.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)

        self.mocks = self._start_patches(self._COMMON_PATCHES)
        self.mocks["nb"].return_value = None
        self.mocks["is_js"].return_value = False
        self.mocks["hf_config"].return_value = MOCK_HF_MODEL_CONFIG
        self.mocks["djl_config"].return_value = ({}, 256)
        self.mocks["create"].return_value = Mock(spec=Model)

    def _start_patches(self, targets):
        """Start a patch per target, register its stop, and return the mocks.

        Args:
            targets: Mapping of alias to the dotted path handed to ``patch``.

        Returns:
            Dict mapping each alias to the mock object now in place.
        """
        started = {}
        for alias, target in targets.items():
            patcher = patch(target)
            started[alias] = patcher.start()
            self.addCleanup(patcher.stop)
        return started

    def _run_build(self, mode, **builder_kwargs):
        """Construct a DJL ``ModelBuilder``, run ``_build_for_djl()``, return it.

        Args:
            mode: Deployment mode passed to ``ModelBuilder``.
            **builder_kwargs: Extra keyword arguments for ``ModelBuilder``.

        Returns:
            The builder after ``_build_for_djl()`` has been called on it.
        """
        builder = ModelBuilder(
            model=self._MODEL_ID,
            role_arn=MOCK_ROLE_ARN,
            sagemaker_session=self.mock_session,
            model_path=self.temp_dir,
            mode=mode,
            image_uri=MOCK_IMAGE_URI,
            model_server=ModelServer.DJL_SERVING,
            **builder_kwargs,
        )
        builder.schema_builder = Mock()
        builder.schema_builder.sample_input = {"inputs": "Hello"}
        builder._optimizing = False
        builder.hf_model_config = MOCK_HF_MODEL_CONFIG

        builder._build_for_djl()
        return builder

    def _build_for_endpoint(self, **builder_kwargs):
        """Run the build in SAGEMAKER_ENDPOINT mode with the GPU helpers mocked.

        Args:
            **builder_kwargs: Extra keyword arguments for ``ModelBuilder``.

        Returns:
            The builder after ``_build_for_djl()`` has been called on it.
        """
        gpu_mocks = self._start_patches(self._ENDPOINT_PATCHES)
        gpu_mocks["gpu_info"].return_value = 4
        gpu_mocks["tp_degree"].return_value = 4
        self.mocks["prepare"].return_value = ("s3://bucket/model", None)
        return self._run_build(
            Mode.SAGEMAKER_ENDPOINT,
            instance_type="ml.g6e.12xlarge",
            **builder_kwargs,
        )

    def test_build_for_djl_sets_hf_home_to_tmp(self):
        """Verify HF_HOME=/tmp is set in SAGEMAKER_ENDPOINT mode."""
        builder = self._build_for_endpoint()

        self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp")

    def test_build_for_djl_sets_huggingface_hub_cache_to_tmp(self):
        """Verify HUGGINGFACE_HUB_CACHE=/tmp is set in SAGEMAKER_ENDPOINT mode."""
        builder = self._build_for_endpoint()

        self.assertEqual(builder.env_vars.get("HUGGINGFACE_HUB_CACHE"), "/tmp")

    def test_build_for_djl_preserves_user_provided_hf_model_id(self):
        """Verify user-provided HF_MODEL_ID is NOT overridden."""
        builder = self._build_for_endpoint(env_vars={"HF_MODEL_ID": "/opt/ml/model"})

        # User-provided value should be preserved, NOT overridden by model param
        self.assertEqual(builder.env_vars["HF_MODEL_ID"], "/opt/ml/model")

    def test_build_for_djl_sets_hf_model_id_when_not_provided(self):
        """Verify HF_MODEL_ID is set from model param when not user-provided."""
        builder = self._build_for_endpoint()

        # When no user-provided HF_MODEL_ID, it should be set from model param
        self.assertEqual(builder.env_vars["HF_MODEL_ID"], "chromadb/context-1")

    def test_build_for_djl_with_source_code_and_hf_model_id(self):
        """Verify the /tmp cache redirect coexists with a user HF_MODEL_ID.

        Mirrors the env-var half of the bug scenario: the user points
        HF_MODEL_ID at /opt/ml/model and the HF cache must still be
        redirected to /tmp.

        NOTE(review): ``source_code`` itself is not passed to ModelBuilder
        here, so the read-only /opt/ml/model condition is not reproduced --
        TODO add a source_code fixture to cover it.
        """
        builder = self._build_for_endpoint(env_vars={"HF_MODEL_ID": "/opt/ml/model"})

        # HF cache should be redirected to /tmp to avoid read-only /opt/ml/model
        self.assertEqual(builder.env_vars.get("HF_HOME"), "/tmp")
        self.assertEqual(builder.env_vars.get("HUGGINGFACE_HUB_CACHE"), "/tmp")
        # The cache redirect must not clobber the user's model location.
        self.assertEqual(builder.env_vars["HF_MODEL_ID"], "/opt/ml/model")

    def test_build_for_djl_local_mode_sets_hf_hub_offline(self):
        """Verify HF_HUB_OFFLINE=1 is set in LOCAL_CONTAINER mode."""
        builder = self._run_build(Mode.LOCAL_CONTAINER)

        self.assertEqual(builder.env_vars.get("HF_HUB_OFFLINE"), "1")
315+
316+
317+
if __name__ == "__main__":
318+
unittest.main()

0 commit comments

Comments
 (0)