
Commit 588329f

Fix HyperparameterTuner to include ModelTrainer internal channels
HyperparameterTuner._build_training_job_definition() was only copying user-provided input channels, missing ModelTrainer's internal channels (code, sm_drivers). This caused tuning jobs to fail when using custom training scripts.

- Add logic to include ModelTrainer's input_data_config channels
- Refactor test_tuner.py with factory functions to reduce duplication
- Add test for internal channels inclusion

Fixes #5508
1 parent 7a61bb6 commit 588329f

2 files changed: 131 additions & 52 deletions
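
The fix boils down to a name-based merge: channels from the trainer's input_data_config are appended only if no channel with the same channel_name is already present, so a user-supplied channel always wins over an internal one with the same name. A minimal standalone sketch of that rule (FakeChannel and merge_channels are hypothetical stand-ins, not SDK types):

from dataclasses import dataclass

@dataclass
class FakeChannel:
    # Hypothetical stand-in for sagemaker.core.shapes.Channel; only the name matters here.
    channel_name: str
    s3_uri: str

def merge_channels(user_channels, trainer_channels):
    """Append trainer channels whose names aren't already taken by user channels."""
    merged = list(user_channels)
    for ch in trainer_channels or []:  # trainer may have no input_data_config at all
        if not any(c.channel_name == ch.channel_name for c in merged):
            merged.append(ch)
    return merged

user = [FakeChannel("train", "s3://bucket/train")]
internal = [FakeChannel("code", "s3://bucket/code"), FakeChannel("train", "s3://bucket/ignored")]
print([c.channel_name for c in merge_channels(user, internal)])  # ['train', 'code']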


sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 9 additions & 1 deletion
@@ -1388,6 +1388,14 @@ def _build_training_job_definition(self, inputs):
             )
         ))
 
+        # Include ModelTrainer's internal channels (code, sm_drivers, etc.)
+        # These are created by ModelTrainer and are required for custom training logic
+        if hasattr(model_trainer, 'input_data_config') and model_trainer.input_data_config:
+            for channel in model_trainer.input_data_config:
+                # Add internal channels that aren't already in input_data_config
+                if not any(c.channel_name == channel.channel_name for c in input_data_config):
+                    input_data_config.append(channel)
+
         # Build output data config
         output_config = OutputDataConfig(
             s3_output_path=model_trainer.output_data_config.s3_output_path if model_trainer.output_data_config else None
@@ -1412,7 +1420,7 @@ def _build_training_job_definition(self, inputs):
             output_data_config=output_config,
             resource_config=resource_config,
             stopping_condition=stopping_condition,
-            static_hyper_parameters=self.static_hyperparameters or {}
+            static_hyper_parameters=self.static_hyperparameters_dict or {}
         )
 
         return definition
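
The second hunk also swaps self.static_hyperparameters for self.static_hyperparameters_dict, the attribute that actually carries the prepared mapping. Presumably the static set is the trainer's hyperparameters minus any key being tuned; a standalone sketch of that derivation under this assumption (the function name here is illustrative, not the SDK's):

def static_hyperparameters(trainer_hps: dict, tuned_ranges: dict) -> dict:
    """Keep only hyperparameters that the tuner does not sweep."""
    return {k: v for k, v in trainer_hps.items() if k not in tuned_ranges}

# Mirrors the values used in test_prepare_static_hyperparameters below
trainer_hps = {"learning_rate": 0.1, "batch_size": 32, "optimizer": "adam"}
tuned = {"learning_rate": "ContinuousParameter(0.001, 0.1)"}  # only the key matters
print(static_hyperparameters(trainer_hps, tuned))  # {'batch_size': 32, 'optimizer': 'adam'}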

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 122 additions & 51 deletions
@@ -14,7 +14,7 @@
 from __future__ import absolute_import
 
 import pytest
-from unittest.mock import MagicMock, patch, PropertyMock
+from unittest.mock import MagicMock, patch
 
 from sagemaker.train.tuner import (
     HyperparameterTuner,
@@ -26,9 +26,81 @@
     ContinuousParameter,
     IntegerParameter,
 )
-from sagemaker.core.shapes import HyperParameterTuningJobWarmStartConfig
+from sagemaker.core.shapes import (
+    HyperParameterTuningJobWarmStartConfig,
+    Channel,
+    DataSource,
+    S3DataSource,
+)
+
+
+# ---------------------------------------------------------------------------
+# Factory functions for creating test objects (reduces fixture duplication)
+# ---------------------------------------------------------------------------
+
+def _create_mock_model_trainer(with_internal_channels=False):
+    """Create a mock ModelTrainer with common attributes.
+
+    Args:
+        with_internal_channels: If True, adds internal channels (code, sm_drivers)
+            to input_data_config for testing channel inclusion in tuning jobs.
+    """
+    trainer = MagicMock()
+    trainer.sagemaker_session = MagicMock()
+    trainer.hyperparameters = {"learning_rate": 0.1, "batch_size": 32, "optimizer": "adam"}
+    trainer.training_image = "test-image:latest"
+    trainer.training_input_mode = "File"
+    trainer.role = "arn:aws:iam::123456789012:role/SageMakerRole"
+    trainer.output_data_config = MagicMock()
+    trainer.output_data_config.s3_output_path = "s3://bucket/output"
+    trainer.compute = MagicMock()
+    trainer.compute.instance_type = "ml.m5.xlarge"
+    trainer.compute.instance_count = 1
+    trainer.compute.volume_size_in_gb = 30
+    trainer.stopping_condition = MagicMock()
+    trainer.stopping_condition.max_runtime_in_seconds = 3600
+    trainer.input_data_config = None
+
+    if with_internal_channels:
+        trainer.input_data_config = [
+            _create_channel("code", "s3://bucket/code"),
+            _create_channel("sm_drivers", "s3://bucket/drivers"),
+        ]
+    return trainer
+
+
+def _create_hyperparameter_ranges():
+    """Create sample hyperparameter ranges."""
+    return {
+        "learning_rate": ContinuousParameter(0.001, 0.1),
+        "batch_size": IntegerParameter(32, 256),
+        "optimizer": CategoricalParameter(["sgd", "adam"]),
+    }
+
+
+def _create_single_hp_range():
+    """Create a single hyperparameter range for simple tests."""
+    return {"learning_rate": ContinuousParameter(0.001, 0.1)}
+
+
+def _create_channel(name: str, uri: str) -> Channel:
+    """Create a Channel with S3 data source."""
+    return Channel(
+        channel_name=name,
+        data_source=DataSource(
+            s3_data_source=S3DataSource(
+                s3_data_type="S3Prefix",
+                s3_uri=uri,
+                s3_data_distribution_type="FullyReplicated"
+            )
+        )
+    )
 
 
+# ---------------------------------------------------------------------------
+# Test Classes
+# ---------------------------------------------------------------------------
+
 class TestWarmStartTypes:
     """Test WarmStartTypes enum."""
 
@@ -47,19 +119,12 @@ class TestHyperparameterTunerInit:
     @pytest.fixture
     def mock_model_trainer(self):
         """Create a mock ModelTrainer."""
-        trainer = MagicMock()
-        trainer.sagemaker_session = MagicMock()
-        trainer.hyperparameters = {"learning_rate": 0.1}
-        return trainer
+        return _create_mock_model_trainer()
 
     @pytest.fixture
     def hyperparameter_ranges(self):
         """Create sample hyperparameter ranges."""
-        return {
-            "learning_rate": ContinuousParameter(0.001, 0.1),
-            "batch_size": IntegerParameter(32, 256),
-            "optimizer": CategoricalParameter(["sgd", "adam"]),
-        }
+        return _create_hyperparameter_ranges()
 
     def test_init_with_basic_params(self, mock_model_trainer, hyperparameter_ranges):
         """Test initialization with basic parameters."""
@@ -266,14 +331,10 @@ class TestHyperparameterTunerProperties:
     @pytest.fixture
     def tuner(self):
         """Create a basic tuner instance."""
-        mock_trainer = MagicMock()
-        mock_trainer.sagemaker_session = MagicMock()
         return HyperparameterTuner(
-            model_trainer=mock_trainer,
+            model_trainer=_create_mock_model_trainer(),
             objective_metric_name="accuracy",
-            hyperparameter_ranges={
-                "learning_rate": ContinuousParameter(0.001, 0.1),
-            },
+            hyperparameter_ranges=_create_single_hp_range(),
         )
 
     def test_sagemaker_session_property(self, tuner):
@@ -293,15 +354,10 @@ def test_hyperparameter_ranges_dict_property_returns_none(self, tuner):
 
     def test_hyperparameter_ranges_dict_property_with_dict(self):
         """Test hyperparameter_ranges_dict property with model_trainer_dict."""
-        mock_trainer = MagicMock()
-        mock_trainer.sagemaker_session = MagicMock()
-
         tuner = HyperparameterTuner(
-            model_trainer=mock_trainer,
+            model_trainer=_create_mock_model_trainer(),
             objective_metric_name="accuracy",
-            hyperparameter_ranges={
-                "learning_rate": ContinuousParameter(0.001, 0.1),
-            },
+            hyperparameter_ranges=_create_single_hp_range(),
             model_trainer_name="trainer1",
         )
 
@@ -316,28 +372,21 @@ class TestHyperparameterTunerMethods:
     @pytest.fixture
     def tuner_with_job(self):
         """Create a tuner with a latest_tuning_job."""
-        mock_trainer = MagicMock()
-        mock_trainer.sagemaker_session = MagicMock()
         tuner = HyperparameterTuner(
-            model_trainer=mock_trainer,
+            model_trainer=_create_mock_model_trainer(),
             objective_metric_name="accuracy",
-            hyperparameter_ranges={
-                "learning_rate": ContinuousParameter(0.001, 0.1),
-            },
+            hyperparameter_ranges=_create_single_hp_range(),
         )
         tuner.latest_tuning_job = MagicMock()
         tuner._current_job_name = "test-tuning-job"
         return tuner
 
     def test_ensure_last_tuning_job_raises_error_when_none(self):
         """Test _ensure_last_tuning_job raises error when no job exists."""
-        mock_trainer = MagicMock()
         tuner = HyperparameterTuner(
-            model_trainer=mock_trainer,
+            model_trainer=_create_mock_model_trainer(),
             objective_metric_name="accuracy",
-            hyperparameter_ranges={
-                "learning_rate": ContinuousParameter(0.001, 0.1),
-            },
+            hyperparameter_ranges=_create_single_hp_range(),
        )
 
         with pytest.raises(ValueError):
@@ -363,7 +412,7 @@ def test_best_training_job(self, tuner_with_job):
         mock_best_job = MagicMock()
         mock_best_job.training_job_name = "best-job-123"
         mock_best_job.training_job_definition_name = "training-def"
-
+
         mock_tuning_job = MagicMock()
         mock_tuning_job.best_training_job = mock_best_job
         tuner_with_job.latest_tuning_job.refresh.return_value = mock_tuning_job
@@ -426,16 +475,8 @@ class TestHyperparameterTunerStaticMethods:
 
     def test_prepare_static_hyperparameters(self):
         """Test _prepare_static_hyperparameters method."""
-        mock_trainer = MagicMock()
-        mock_trainer.hyperparameters = {
-            "learning_rate": 0.1,
-            "batch_size": 32,
-            "optimizer": "adam",
-        }
-
-        hyperparameter_ranges = {
-            "learning_rate": ContinuousParameter(0.001, 0.1),
-        }
+        mock_trainer = _create_mock_model_trainer()
+        hyperparameter_ranges = _create_single_hp_range()
 
         static_hps = HyperparameterTuner._prepare_static_hyperparameters(
             mock_trainer, hyperparameter_ranges
@@ -491,11 +532,7 @@ def test_extract_hyperparameters_from_parameter_ranges(self):
 
     def test_prepare_parameter_ranges_for_tuning(self):
         """Test _prepare_parameter_ranges_for_tuning method."""
-        parameter_ranges = {
-            "learning_rate": ContinuousParameter(0.001, 0.1),
-            "batch_size": IntegerParameter(32, 256),
-            "optimizer": CategoricalParameter(["sgd", "adam"]),
-        }
+        parameter_ranges = _create_hyperparameter_ranges()
 
         processed_ranges = HyperparameterTuner._prepare_parameter_ranges_for_tuning(
             parameter_ranges
@@ -507,3 +544,37 @@ def test_prepare_parameter_ranges_for_tuning(self):
         assert len(processed_ranges["ContinuousParameterRanges"]) == 1
         assert len(processed_ranges["IntegerParameterRanges"]) == 1
         assert len(processed_ranges["CategoricalParameterRanges"]) == 1
+
+    def test_build_training_job_definition_includes_internal_channels(self):
+        """Test that _build_training_job_definition includes ModelTrainer's internal channels.
+
+        This test verifies the fix for GitHub issue #5508 where tuning jobs were missing
+        internal channels (code, sm_drivers) that ModelTrainer creates for custom training.
+        """
+        from sagemaker.core.training.configs import InputData
+
+        # Create mock ModelTrainer with internal channels (code, sm_drivers)
+        mock_trainer = _create_mock_model_trainer(with_internal_channels=True)
+
+        # User-provided inputs
+        user_inputs = [
+            InputData(channel_name="train", data_source="s3://bucket/train"),
+            InputData(channel_name="validation", data_source="s3://bucket/val")
+        ]
+
+        tuner = HyperparameterTuner(
+            model_trainer=mock_trainer,
+            objective_metric_name="accuracy",
+            hyperparameter_ranges=_create_single_hp_range(),
+        )
+
+        # Build training job definition
+        definition = tuner._build_training_job_definition(user_inputs)
+
+        # Verify all channels are included
+        channel_names = [ch.channel_name for ch in definition.input_data_config]
+        assert "code" in channel_names, "Internal 'code' channel should be included"
+        assert "sm_drivers" in channel_names, "Internal 'sm_drivers' channel should be included"
+        assert "train" in channel_names, "User 'train' channel should be included"
+        assert "validation" in channel_names, "User 'validation' channel should be included"
+        assert len(channel_names) == 4, "Should have exactly 4 channels"
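
The added test asserts inclusion but not the dedup branch, where a user channel shares a name with an internal one and must not be appended twice. A hedged sketch of a companion test, reusing the factories above and assuming _build_training_job_definition behaves as in the fix:

def test_user_channel_shadows_internal_channel():
    """A user-supplied 'code' channel should not be duplicated by the internal one."""
    from sagemaker.core.training.configs import InputData

    mock_trainer = _create_mock_model_trainer(with_internal_channels=True)
    tuner = HyperparameterTuner(
        model_trainer=mock_trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    # "code" collides with the trainer's internal channel name
    user_inputs = [InputData(channel_name="code", data_source="s3://bucket/my-code")]
    definition = tuner._build_training_job_definition(user_inputs)

    channel_names = [ch.channel_name for ch in definition.input_data_config]
    assert channel_names.count("code") == 1, "Colliding channel must appear exactly once"
    assert "sm_drivers" in channel_names, "Non-colliding internal channel is still added"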
