3131 Channel ,
3232 DataSource ,
3333 S3DataSource ,
34+ OutputDataConfig
3435)
3536
3637
@@ -52,9 +53,7 @@ def _create_mock_model_trainer(with_internal_channels=False):
5253 trainer .training_image = "test-image:latest"
5354 trainer .training_input_mode = "File"
5455 trainer .role = "arn:aws:iam::123456789012:role/SageMakerRole"
55- trainer .output_data_config = MagicMock ()
56- trainer .output_data_config .kms_key_id = None
57- trainer .output_data_config .s3_output_path = "s3://bucket/output"
56+ trainer .output_data_config = OutputDataConfig (kms_key_id = "arn:aws:kms:us-west-2:123456789012:key/abc123" , s3_output_path = "s3://bucket/output" )
5857 trainer .compute = MagicMock ()
5958 trainer .compute .instance_type = "ml.m5.xlarge"
6059 trainer .compute .instance_count = 1
@@ -575,3 +574,16 @@ def test_build_training_job_definition_includes_internal_channels(self):
575574 assert "train" in channel_names , "User 'train' channel should be included"
576575 assert "validation" in channel_names , "User 'validation' channel should be included"
577576 assert len (channel_names ) == 4 , "Should have exactly 4 channels"
577+
578+ def test_build_training_job_definition_includes_output_data_config (self ):
579+ """Test that _build_training_job_definition includes ModelTrainer's output data config."""
580+ mock_trainer = _create_mock_model_trainer ()
581+ tuner = HyperparameterTuner (
582+ model_trainer = mock_trainer ,
583+ objective_metric_name = "accuracy" ,
584+ hyperparameter_ranges = _create_single_hp_range (),
585+ )
586+
587+ definition = tuner ._build_training_job_definition ([])
588+
589+ assert definition .output_data_config == mock_trainer .output_data_config
0 commit comments