Skip to content

Commit e5b65df

Browse files
MarvinPittMarvinPitt
authored andcommitted
update output_data_config reference in HyperparameterTuner and add test for its inclusion
1 parent 196562c commit e5b65df

2 files changed

Lines changed: 16 additions & 6 deletions

File tree

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ def __init__(
253253
self.instance_configs_dict = None
254254
self.instance_configs = None
255255
self.autotune = autotune
256-
self.output_data_config = model_trainer.output_data_config
257256

258257
def override_resource_config(
259258
self,
@@ -1422,7 +1421,6 @@ def _build_training_job_definition(self, inputs):
14221421
if not any(c.channel_name == channel.channel_name for c in input_data_config):
14231422
input_data_config.append(channel)
14241423

1425-
14261424
# Build resource config
14271425
resource_config = ResourceConfig(
14281426
instance_type=(
@@ -1448,7 +1446,7 @@ def _build_training_job_definition(self, inputs):
14481446
algorithm_specification=algorithm_spec,
14491447
role_arn=model_trainer.role,
14501448
input_data_config=input_data_config if input_data_config else None,
1451-
output_data_config=self.output_data_config,
1449+
output_data_config=model_trainer.output_data_config,
14521450
resource_config=resource_config,
14531451
stopping_condition=stopping_condition,
14541452
static_hyper_parameters=self.static_hyperparameters or {},

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Channel,
3232
DataSource,
3333
S3DataSource,
34+
OutputDataConfig
3435
)
3536

3637

@@ -52,9 +53,7 @@ def _create_mock_model_trainer(with_internal_channels=False):
5253
trainer.training_image = "test-image:latest"
5354
trainer.training_input_mode = "File"
5455
trainer.role = "arn:aws:iam::123456789012:role/SageMakerRole"
55-
trainer.output_data_config = MagicMock()
56-
trainer.output_data_config.kms_key_id = None
57-
trainer.output_data_config.s3_output_path = "s3://bucket/output"
56+
trainer.output_data_config = OutputDataConfig(kms_key_id="arn:aws:kms:us-west-2:123456789012:key/abc123", s3_output_path="s3://bucket/output")
5857
trainer.compute = MagicMock()
5958
trainer.compute.instance_type = "ml.m5.xlarge"
6059
trainer.compute.instance_count = 1
@@ -575,3 +574,16 @@ def test_build_training_job_definition_includes_internal_channels(self):
575574
assert "train" in channel_names, "User 'train' channel should be included"
576575
assert "validation" in channel_names, "User 'validation' channel should be included"
577576
assert len(channel_names) == 4, "Should have exactly 4 channels"
577+
578+
def test_build_training_job_definition_includes_output_data_config(self):
579+
"""Test that _build_training_job_definition includes ModelTrainer's output data config."""
580+
mock_trainer = _create_mock_model_trainer()
581+
tuner = HyperparameterTuner(
582+
model_trainer=mock_trainer,
583+
objective_metric_name="accuracy",
584+
hyperparameter_ranges=_create_single_hp_range(),
585+
)
586+
587+
definition = tuner._build_training_job_definition([])
588+
589+
assert definition.output_data_config == mock_trainer.output_data_config

0 commit comments

Comments
 (0)