Skip to content

Commit ad0776f

Browse files
authored
Merge branch 'master' into fix/bug-modelbuilder-overwrites-user-provided-hf-5529
2 parents 603d9f3 + f20a7e2 commit ad0776f

File tree

6 files changed

+283
-18
lines changed

6 files changed

+283
-18
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,36 +326,36 @@ def _build_for_djl(self) -> Model:
326326
self.hf_model_config = _get_model_config_properties_from_hf(
327327
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
328328
)
329-
329+
330330
# Apply DJL-specific configurations
331331
default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
332332
self.model, self.hf_model_config, self.schema_builder
333333
)
334334
self.env_vars.update(default_djl_configurations)
335-
335+
336336
# Configure schema builder for text generation
337337
if "parameters" not in self.schema_builder.sample_input:
338338
self.schema_builder.sample_input["parameters"] = {}
339339
self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens
340-
341-
# Set DJL serving defaults
340+
341+
# Set DJL serving defaults (only if not already set by user)
342342
djl_env_vars = {
343343
"OPTION_ENGINE": "Python",
344344
"SERVING_MIN_WORKERS": "1",
345-
"SERVING_MAX_WORKERS": "1",
345+
"SERVING_MAX_WORKERS": "1",
346346
"OPTION_MODEL_LOADING_TIMEOUT": "240",
347347
"OPTION_PREDICT_TIMEOUT": "60",
348-
"TENSOR_PARALLEL_DEGREE": "1" # Default, will be overridden below
348+
"TENSOR_PARALLEL_DEGREE": "1",
349349
}
350-
350+
351351
# Add HuggingFace authentication
352352
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
353353
djl_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
354-
354+
355355
# Update with defaults only if not already set
356356
for key, value in djl_env_vars.items():
357357
self.env_vars.setdefault(key, value)
358-
358+
359359
# DJL downloads models directly from HuggingFace Hub
360360
self.s3_upload_path = None
361361

@@ -367,6 +367,12 @@ def _build_for_djl(self) -> Model:
367367
else:
368368
self.s3_model_data_url, _ = self._prepare_for_mode()
369369

370+
# Always default the HF cache env vars to a writable location. setdefault is used
371+
# so that user-provided values are preserved. This is needed because /opt/ml/model/
372+
# may be read-only when source_code artifacts are mounted there.
373+
self.env_vars.setdefault("HF_HOME", "/tmp")
374+
self.env_vars.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp")
375+
370376
# Cache management based on mode
371377
if self.mode in LOCAL_MODES:
372378
self.env_vars.update({"HF_HUB_OFFLINE": "1"})

sagemaker-serve/tests/unit/servers/__init__.py

Whitespace-only changes.
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Tests for DJL builder HF cache environment variables and HF_MODEL_ID handling.
2+
3+
Verifies that _build_for_djl() correctly:
4+
- Sets HF_HOME and HUGGINGFACE_HUB_CACHE to /tmp for writable cache
5+
- Preserves user-provided HF_MODEL_ID values (uses setdefault)
6+
- Sets HF_MODEL_ID from model param when not provided by user
7+
- Preserves user-provided HF_HOME and HUGGINGFACE_HUB_CACHE values
8+
"""
9+
10+
import pytest
11+
from unittest.mock import Mock, patch
12+
13+
from sagemaker.serve.model_builder import ModelBuilder
14+
from sagemaker.serve.utils.types import ModelServer
15+
from sagemaker.serve.mode.function_pointers import Mode
16+
from sagemaker.core.resources import Model
17+
18+
19+
MOCK_ROLE_ARN = "arn:aws:iam::000000000000:role/SageMakerRole"
20+
MOCK_IMAGE_URI = "000000000000.dkr.ecr.us-east-1.amazonaws.com/djl-inference:latest"
21+
MOCK_HF_MODEL_CONFIG = {"model_type": "gpt2", "architectures": ["GPT2LMHeadModel"]}
22+
23+
24+
# Common patches needed for _build_for_djl
25+
_DJL_PATCHES = [
26+
"sagemaker.serve.model_builder_servers._get_nb_instance",
27+
"sagemaker.serve.model_builder_servers._get_default_djl_configurations",
28+
"sagemaker.serve.model_builder_servers._get_model_config_properties_from_hf",
29+
"sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id",
30+
"sagemaker.serve.model_builder.ModelBuilder._validate_djl_serving_sample_data",
31+
"sagemaker.serve.model_builder.ModelBuilder._auto_detect_image_uri",
32+
"sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode",
33+
"sagemaker.serve.model_builder.ModelBuilder._create_model",
34+
"sagemaker.serve.model_builder_servers._get_default_tensor_parallel_degree",
35+
"sagemaker.serve.model_builder_servers._get_gpu_info",
36+
]
37+
38+
39+
def _mock_sagemaker_session():
40+
"""Create a mock SageMaker session."""
41+
session = Mock()
42+
session.boto_region_name = "us-east-1"
43+
session.sagemaker_config = {}
44+
session.default_bucket.return_value = "mock-bucket"
45+
session.upload_data.return_value = "s3://mock-bucket/model.tar.gz"
46+
return session
47+
48+
49+
def _create_djl_builder(tmp_path, env_vars=None, mode=Mode.SAGEMAKER_ENDPOINT):
    """Build a ``ModelBuilder`` pre-configured for the DJL serving tests.

    Args:
        tmp_path: Directory used as ``model_path`` (converted with ``str``).
        env_vars: Environment variables handed to the builder; a falsy value
            is replaced by an empty dict.
        mode: Deployment mode passed to the builder.

    Returns:
        The builder, with a mock schema builder whose ``sample_input`` is
        ``{"inputs": "Hello"}``, ``_optimizing`` set to ``False`` and
        ``hf_model_config`` set to ``MOCK_HF_MODEL_CONFIG``.
    """
    builder_kwargs = {
        "model": "test-org/test-model",
        "role_arn": MOCK_ROLE_ARN,
        "sagemaker_session": _mock_sagemaker_session(),
        "model_path": str(tmp_path),
        "mode": mode,
        "image_uri": MOCK_IMAGE_URI,
        "model_server": ModelServer.DJL_SERVING,
        "instance_type": "ml.g6e.12xlarge",
        "env_vars": env_vars if env_vars else {},
    }
    djl_builder = ModelBuilder(**builder_kwargs)

    # Attributes the DJL build path reads but the constructor does not set up
    # for these tests.
    djl_builder.schema_builder = Mock()
    djl_builder.schema_builder.sample_input = {"inputs": "Hello"}
    djl_builder._optimizing = False
    djl_builder.hf_model_config = MOCK_HF_MODEL_CONFIG
    return djl_builder
67+
68+
69+
def _setup_mocks(mocks):
    """Configure common mock return values for a DJL build.

    Args:
        mocks: Mock objects in the SAME order as ``_DJL_PATCHES`` (one per
            patch target). Only the last ten entries are used.

    Raises:
        ValueError: If fewer than ten mocks are supplied.
    """
    # One ordered unpacking keeps the mapping between patch target and mock
    # readable; the names follow _DJL_PATCHES from first to last entry.
    (
        mock_nb,
        mock_djl_config,
        mock_hf_config,
        mock_is_js,
        _mock_validate,  # no setup needed
        _mock_auto_detect,  # no setup needed
        mock_prepare,
        mock_create,
        mock_tp_degree,
        mock_gpu_info,
    ) = mocks[-10:]

    mock_nb.return_value = None
    mock_djl_config.return_value = ({}, 256)
    mock_hf_config.return_value = MOCK_HF_MODEL_CONFIG
    mock_is_js.return_value = False
    mock_prepare.return_value = ("s3://bucket/model", None)
    mock_create.return_value = Mock(spec=Model)
    mock_tp_degree.return_value = 4
    mock_gpu_info.return_value = 4
92+
93+
class TestDjlHfCacheAndModelId:
94+
"""Tests for DJL builder HF cache env vars and HF_MODEL_ID handling."""
95+
96+
@pytest.fixture(autouse=True)
97+
def _patch_djl(self):
98+
"""Apply all DJL-related patches for each test."""
99+
patchers = [patch(p) for p in _DJL_PATCHES]
100+
self._mocks = [p.start() for p in patchers]
101+
_setup_mocks(self._mocks)
102+
yield
103+
for p in patchers:
104+
p.stop()
105+
106+
def test_sets_hf_cache_env_vars_to_tmp(self, tmp_path):
107+
"""HF_HOME and HUGGINGFACE_HUB_CACHE should be /tmp in endpoint mode."""
108+
builder = _create_djl_builder(tmp_path)
109+
builder._build_for_djl()
110+
111+
assert builder.env_vars["HF_HOME"] == "/tmp"
112+
assert builder.env_vars["HUGGINGFACE_HUB_CACHE"] == "/tmp"
113+
114+
def test_preserves_user_provided_hf_model_id(self, tmp_path):
115+
"""User-provided HF_MODEL_ID must NOT be overridden by model param."""
116+
builder = _create_djl_builder(
117+
tmp_path, env_vars={"HF_MODEL_ID": "/opt/ml/model"}
118+
)
119+
builder._build_for_djl()
120+
121+
assert builder.env_vars["HF_MODEL_ID"] == "/opt/ml/model"
122+
123+
def test_sets_hf_model_id_from_model_param_when_not_provided(self, tmp_path):
124+
"""When no user-provided HF_MODEL_ID, it should come from model param."""
125+
builder = _create_djl_builder(tmp_path)
126+
builder._build_for_djl()
127+
128+
assert builder.env_vars["HF_MODEL_ID"] == "test-org/test-model"
129+
130+
def test_preserves_user_provided_hf_cache_dirs(self, tmp_path):
131+
"""User-provided HF_HOME and HUGGINGFACE_HUB_CACHE should be preserved."""
132+
builder = _create_djl_builder(
133+
tmp_path,
134+
env_vars={
135+
"HF_HOME": "/my/custom/cache",
136+
"HUGGINGFACE_HUB_CACHE": "/my/custom/hub",
137+
},
138+
)
139+
builder._build_for_djl()
140+
141+
assert builder.env_vars["HF_HOME"] == "/my/custom/cache"
142+
assert builder.env_vars["HUGGINGFACE_HUB_CACHE"] == "/my/custom/hub"
143+
144+
def test_local_mode_sets_hf_hub_offline(self, tmp_path):
145+
"""HF_HUB_OFFLINE=1 should be set in LOCAL_CONTAINER mode."""
146+
builder = _create_djl_builder(tmp_path, mode=Mode.LOCAL_CONTAINER)
147+
# Local mode doesn't need GPU info mocks for instance_type validation
148+
builder.instance_type = None
149+
builder._build_for_djl()
150+
151+
assert builder.env_vars["HF_HUB_OFFLINE"] == "1"

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,13 @@ def _build_training_job_definition(self, inputs):
15041504
model_trainer.stopping_condition.max_wait_time_in_seconds
15051505
)
15061506

1507-
definition = HyperParameterTrainingJobDefinition(
1507+
# Propagate environment variables from ModelTrainer.
1508+
# Only include when it's a dict (even empty); omit otherwise so the
1509+
# Pydantic field stays Unassigned and is excluded during serialization.
1510+
env = model_trainer.environment
1511+
1512+
# Build base kwargs for the definition
1513+
definition_kwargs = dict(
15081514
algorithm_specification=algorithm_spec,
15091515
role_arn=model_trainer.role,
15101516
input_data_config=input_data_config if input_data_config else None,
@@ -1515,10 +1521,11 @@ def _build_training_job_definition(self, inputs):
15151521
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
15161522
)
15171523

1518-
# Pass through environment variables from model_trainer
1519-
env = getattr(model_trainer, "environment", None)
1520-
if env and isinstance(env, dict):
1521-
definition.environment = env
1524+
# Include environment only when it's a dict (including empty).
1525+
if isinstance(env, dict):
1526+
definition_kwargs["environment"] = env
1527+
1528+
definition = HyperParameterTrainingJobDefinition(**definition_kwargs)
15221529

15231530
# Pass through VPC config from model_trainer
15241531
networking = getattr(model_trainer, "networking", None)

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,3 +596,73 @@ def test_build_training_job_definition_includes_spot_params(self):
596596
assert isinstance(
597597
definition.stopping_condition.max_wait_time_in_seconds, int
598598
), "Max wait time should be set"
599+
600+
def test_build_training_job_definition_includes_environment_variables(self):
    """Check that ModelTrainer environment variables reach the job definition.

    Regression test for GitHub issue #5613, where tuning jobs were missing
    the environment variables that had been set on the ModelTrainer.
    """
    expected_env = {"FOO": "bar", "RANDOM_STATE": "42"}

    trainer = _create_mock_model_trainer()
    trainer.environment = dict(expected_env)

    hp_tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )
    job_definition = hp_tuner._build_training_job_definition(None)

    assert job_definition.environment is not None, "Environment should not be None"
    assert (
        job_definition.environment == expected_env
    ), "Environment variables should match those set on ModelTrainer"
625+
626+
def test_build_training_job_definition_with_none_environment(self):
    """Check that a ``None`` environment leaves the definition field Unassigned.

    With ``model_trainer.environment`` set to ``None`` the resulting
    definition's ``environment`` is expected to be an ``Unassigned`` instance.
    """
    from sagemaker.core.utils.utils import Unassigned

    trainer = _create_mock_model_trainer()
    trainer.environment = None

    hp_tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )
    env_field = hp_tuner._build_training_job_definition(None).environment

    assert isinstance(env_field, Unassigned), (
        "Environment should be Unassigned when model_trainer.environment is None"
    )
648+
649+
def test_build_training_job_definition_with_empty_environment(self):
    """Check that an empty environment dict is passed through unchanged.

    The definition built from a trainer whose ``environment`` is ``{}`` is
    expected to expose ``{}`` as well, not ``None`` or an omitted value.
    """
    trainer = _create_mock_model_trainer()
    trainer.environment = {}

    hp_tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )
    env_field = hp_tuner._build_training_job_definition(None).environment

    assert env_field == {}, (
        "Empty dict environment should be passed through as-is"
    )

sagemaker-train/tests/unit/train/test_tuner_driver_channels.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,31 @@ def test_passes_environment_variables(self):
405405
definition = tuner._build_training_job_definition(inputs=None)
406406
assert definition.environment == {"MY_VAR": "value", "OTHER": "123"}
407407

408+
def test_passes_empty_environment(self):
    """An empty dict environment should be passed through as-is.

    The definition built from a trainer created with ``environment={}`` is
    expected to expose ``{}``, not an Unassigned or ``None`` value.
    """
    empty_env_trainer = _mock_model_trainer(environment={})

    hp_tuner = HyperparameterTuner(
        model_trainer=empty_env_trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_hp_ranges(),
    )
    env_field = hp_tuner._build_training_job_definition(inputs=None).environment

    assert env_field == {}, (
        "Empty dict environment should be passed through as-is"
    )
426+
408427
def test_skips_environment_when_none(self):
409-
"""Should not set environment when model_trainer.environment is None."""
428+
"""Should not set environment when model_trainer.environment is None.
429+
430+
When environment is None, it is not passed to the Pydantic constructor,
431+
so the field stays as Unassigned (excluded from serialization).
432+
"""
410433
trainer = _mock_model_trainer(environment=None)
411434

412435
tuner = HyperparameterTuner(
@@ -416,10 +439,16 @@ def test_skips_environment_when_none(self):
416439
)
417440

418441
definition = tuner._build_training_job_definition(inputs=None)
419-
assert _is_unassigned(definition.environment)
442+
assert _is_unassigned(definition.environment), (
443+
"Environment should be Unassigned when model_trainer.environment is None"
444+
)
420445

421446
def test_skips_environment_when_not_dict(self):
422-
"""Should not set environment when it's not a dict (e.g. MagicMock)."""
447+
"""Should not set environment when it's not a dict (e.g. MagicMock).
448+
449+
Non-dict values are not passed to the Pydantic constructor to avoid
450+
validation errors. The field stays as Unassigned.
451+
"""
423452
trainer = _mock_model_trainer(environment=MagicMock())
424453

425454
tuner = HyperparameterTuner(
@@ -429,7 +458,9 @@ def test_skips_environment_when_not_dict(self):
429458
)
430459

431460
definition = tuner._build_training_job_definition(inputs=None)
432-
assert _is_unassigned(definition.environment)
461+
assert _is_unassigned(definition.environment), (
462+
"Environment should be Unassigned when model_trainer.environment is not a dict"
463+
)
433464

434465
def test_passes_vpc_config(self):
435466
"""Should set definition.vpc_config from model_trainer.networking._to_vpc_config()."""

0 commit comments

Comments
 (0)