Skip to content

Commit f20a7e2

Browse files
authored
fix: ModelBuilder with source_code + DJL LMI: /opt/ml/model becomes read-only, breaking (#5698) (aws#5733)
* fix: ModelBuilder with source_code + DJL LMI: /opt/ml/model becomes read-only, breaking (#5698) * fix: address review comments (iteration #1)
1 parent a65cf35 commit f20a7e2

File tree

3 files changed

+169
-12
lines changed

3 files changed

+169
-12
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,43 +319,43 @@ def _build_for_djl(self) -> Model:
319319
logger.debug(f"Using detected notebook instance type: {nb_instance}")
320320

321321
if isinstance(self.model, str) and not self._is_jumpstart_model_id():
322-
# Configure HuggingFace model for DJL
323-
self.env_vars.update({"HF_MODEL_ID": self.model})
324-
322+
# Configure HuggingFace model for DJL (preserve user-provided HF_MODEL_ID)
323+
self.env_vars.setdefault("HF_MODEL_ID", self.model)
324+
325325
# Get model configuration for DJL optimization
326326
self.hf_model_config = _get_model_config_properties_from_hf(
327327
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
328328
)
329-
329+
330330
# Apply DJL-specific configurations
331331
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
332332
self.model, self.hf_model_config, self.schema_builder
333333
)
334334
self.env_vars.update(default_djl_configurations)
335-
335+
336336
# Configure schema builder for text generation
337337
if "parameters" not in self.schema_builder.sample_input:
338338
self.schema_builder.sample_input["parameters"] = {}
339339
self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens
340-
341-
# Set DJL serving defaults
340+
341+
# Set DJL serving defaults (only if not already set by user)
342342
djl_env_vars = {
343343
"OPTION_ENGINE": "Python",
344344
"SERVING_MIN_WORKERS": "1",
345-
"SERVING_MAX_WORKERS": "1",
345+
"SERVING_MAX_WORKERS": "1",
346346
"OPTION_MODEL_LOADING_TIMEOUT": "240",
347347
"OPTION_PREDICT_TIMEOUT": "60",
348-
"TENSOR_PARALLEL_DEGREE": "1" # Default, will be overridden below
348+
"TENSOR_PARALLEL_DEGREE": "1",
349349
}
350-
350+
351351
# Add HuggingFace authentication
352352
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
353353
djl_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
354-
354+
355355
# Update with defaults only if not already set
356356
for key, value in djl_env_vars.items():
357357
self.env_vars.setdefault(key, value)
358-
358+
359359
# DJL downloads models directly from HuggingFace Hub
360360
self.s3_upload_path = None
361361

@@ -367,6 +367,12 @@ def _build_for_djl(self) -> Model:
367367
else:
368368
self.s3_model_data_url, _ = self._prepare_for_mode()
369369

370+
# Set HF cache env vars to a writable location for every mode, using setdefault
371+
# so any user-provided values are preserved. This is needed because /opt/ml/model/ may be
372+
# read-only when source_code artifacts are mounted there.
373+
self.env_vars.setdefault("HF_HOME", "/tmp")
374+
self.env_vars.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp")
375+
370376
# Cache management based on mode
371377
if self.mode in LOCAL_MODES:
372378
self.env_vars.update({"HF_HUB_OFFLINE": "1"})

sagemaker-serve/tests/unit/servers/__init__.py

Whitespace-only changes.
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Tests for DJL builder HF cache environment variables and HF_MODEL_ID handling.
2+
3+
Verifies that _build_for_djl() correctly:
4+
- Sets HF_HOME and HUGGINGFACE_HUB_CACHE to /tmp for writable cache
5+
- Preserves user-provided HF_MODEL_ID values (uses setdefault)
6+
- Sets HF_MODEL_ID from model param when not provided by user
7+
- Preserves user-provided HF_HOME and HUGGINGFACE_HUB_CACHE values
8+
"""
9+
10+
import pytest
11+
from unittest.mock import Mock, patch
12+
13+
from sagemaker.serve.model_builder import ModelBuilder
14+
from sagemaker.serve.utils.types import ModelServer
15+
from sagemaker.serve.mode.function_pointers import Mode
16+
from sagemaker.core.resources import Model
17+
18+
19+
# Shared fake identifiers used across the DJL builder tests.
MOCK_ROLE_ARN = "arn:aws:iam::000000000000:role/SageMakerRole"
MOCK_IMAGE_URI = "000000000000.dkr.ecr.us-east-1.amazonaws.com/djl-inference:latest"
# Minimal HuggingFace model config returned by the patched HF config lookup.
MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]}


# Common patches needed for _build_for_djl.
# NOTE: _setup_mocks() accesses the started mocks by position, so the order of
# this list is significant — do not reorder entries without updating
# _setup_mocks() to match.
_DJL_PATCHES = [
    "sagemaker.serve.model_builder_servers._get_nb_instance",
    "sagemaker.serve.model_builder_servers._get_default_djl_configurations",
    "sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf",
    "sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id",
    "sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data",
    "sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri",
    "sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode",
    "sagemaker.serve.model_builder.ModelBuilder._create_model",
    "sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree",
    "sagemaker.serve.model_builder_servers._get_gpu_info",
]
37+
38+
39+
def _mock_sagemaker_session():
40+
"""Create a mock SageMaker session."""
41+
session = Mock()
42+
session.boto_region_name = "us-east-1"
43+
session.sagemaker_config = {}
44+
session.default_bucket.return_value = "mock-bucket"
45+
session.upload_data.return_value = "s3://mock-bucket/model.tar.gz"
46+
return session
47+
48+
49+
def _create_djl_builder(tmp_path, env_vars=None, mode=Mode.SAGEMAKER_ENDPOINT):
    """Return a ModelBuilder wired up for DJL serving unit tests.

    Targets a fake HF model id with a mocked session and a stubbed schema
    builder so that _build_for_djl() can run without network or AWS access.
    """
    djl_builder = ModelBuilder(
        model="test-org/test-model",
        role_arn=MOCK_ROLE_ARN,
        sagemaker_session=_mock_sagemaker_session(),
        model_path=str(tmp_path),
        mode=mode,
        image_uri=MOCK_IMAGE_URI,
        model_server=ModelServer.DJL_SERVING,
        instance_type="ml.g6e.12xlarge",
        env_vars=env_vars or {},
    )
    # Stub out sample-data handling so no real schema introspection happens.
    stub_schema = Mock()
    stub_schema.sample_input = {"inputs": "Hello"}
    djl_builder.schema_builder = stub_schema
    djl_builder._optimizing = False
    djl_builder.hf_model_config = MOCK_HF_MODEL_CONFIG
    return djl_builder
67+
68+
69+
def _setup_mocks(mocks):
    """Configure return values for the started _DJL_PATCHES mocks.

    ``mocks`` is positionally aligned with _DJL_PATCHES (same order, one
    started mock per patch target), so unpack the trailing ten entries in
    that order. The validate and auto-detect mocks need no configuration.
    """
    (
        mock_nb,
        mock_djl_config,
        mock_hf_config,
        mock_is_js,
        _mock_validate,      # no setup needed
        _mock_auto_detect,   # no setup needed
        mock_prepare,
        mock_create,
        mock_tp_degree,
        mock_gpu_info,
    ) = mocks[-10:]

    mock_nb.return_value = None
    mock_djl_config.return_value = ({}, 256)
    mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG
    mock_is_js.return_value = False
    mock_prepare.return_value = ("s3://bucket/model", None)
    mock_create.return_value = Mock(spec=Model)
    mock_tp_degree.return_value = 4
    mock_gpu_info.return_value = 4
92+
93+
class TestDjlHfCacheAndModelId:
    """Tests for DJL builder HF cache env vars and HF_MODEL_ID handling."""

    @pytest.fixture(autouse=True)
    def _patch_djl(self):
        """Apply all DJL-related patches for each test.

        Starts one patcher per _DJL_PATCHES entry, configures their return
        values via _setup_mocks, and stops every patcher on teardown.
        """
        # NOTE(review): if a p.start() call raises partway through, the
        # already-started patchers are never stopped — consider
        # patch.stopall() in teardown; confirm whether this matters here.
        patchers = [patch(p) for p in _DJL_PATCHES]
        self._mocks = [p.start() for p in patchers]
        _setup_mocks(self._mocks)
        yield
        # Teardown runs even when the test body fails (yield-fixture semantics).
        for p in patchers:
            p.stop()

    def test_sets_hf_cache_env_vars_to_tmp(self, tmp_path):
        """HF_HOME and HUGGINGFACE_HUB_CACHE should be /tmp in endpoint mode."""
        builder = _create_djl_builder(tmp_path)
        builder._build_for_djl()

        assert builder.env_vars["HF_HOME"] == "/tmp"
        assert builder.env_vars["HUGGINGFACE_HUB_CACHE"] == "/tmp"

    def test_preserves_user_provided_hf_model_id(self, tmp_path):
        """User-provided HF_MODEL_ID must NOT be overridden by model param."""
        # /opt/ml/model mimics a user pointing HF_MODEL_ID at mounted artifacts.
        builder = _create_djl_builder(
            tmp_path, env_vars={"HF_MODEL_ID": "/opt/ml/model"}
        )
        builder._build_for_djl()

        assert builder.env_vars["HF_MODEL_ID"] == "/opt/ml/model"

    def test_sets_hf_model_id_from_model_param_when_not_provided(self, tmp_path):
        """When no user-provided HF_MODEL_ID, it should come from model param."""
        builder = _create_djl_builder(tmp_path)
        builder._build_for_djl()

        assert builder.env_vars["HF_MODEL_ID"] == "test-org/test-model"

    def test_preserves_user_provided_hf_cache_dirs(self, tmp_path):
        """User-provided HF_HOME and HUGGINGFACE_HUB_CACHE should be preserved."""
        builder = _create_djl_builder(
            tmp_path,
            env_vars={
                "HF_HOME": "/my/custom/cache",
                "HUGGINGFACE_HUB_CACHE": "/my/custom/hub",
            },
        )
        builder._build_for_djl()

        assert builder.env_vars["HF_HOME"] == "/my/custom/cache"
        assert builder.env_vars["HUGGINGFACE_HUB_CACHE"] == "/my/custom/hub"

    def test_local_mode_sets_hf_hub_offline(self, tmp_path):
        """HF_HUB_OFFLINE=1 should be set in LOCAL_CONTAINER mode."""
        builder = _create_djl_builder(tmp_path, mode=Mode.LOCAL_CONTAINER)
        # Local mode doesn't need GPU info mocks for instance_type validation
        builder.instance_type = None
        builder._build_for_djl()

        assert builder.env_vars["HF_HUB_OFFLINE"] == "1"

0 commit comments

Comments
 (0)