1414from __future__ import absolute_import
1515
1616import pytest
17- from unittest .mock import MagicMock , patch , PropertyMock
17+ from unittest .mock import MagicMock , patch
1818
1919from sagemaker .train .tuner import (
2020 HyperparameterTuner ,
2626 ContinuousParameter ,
2727 IntegerParameter ,
2828)
29- from sagemaker .core .shapes import HyperParameterTuningJobWarmStartConfig
29+ from sagemaker .core .shapes import (
30+ HyperParameterTuningJobWarmStartConfig ,
31+ Channel ,
32+ DataSource ,
33+ S3DataSource ,
34+ )
35+
36+
37+ # ---------------------------------------------------------------------------
38+ # Factory functions for creating test objects (reduces fixture duplication)
39+ # ---------------------------------------------------------------------------
40+
41+ def _create_mock_model_trainer (with_internal_channels = False ):
42+ """Create a mock ModelTrainer with common attributes.
43+
44+ Args:
45+ with_internal_channels: If True, adds internal channels (code, sm_drivers)
46+ to input_data_config for testing channel inclusion in tuning jobs.
47+ """
48+ trainer = MagicMock ()
49+ trainer .sagemaker_session = MagicMock ()
50+ trainer .hyperparameters = {"learning_rate" : 0.1 , "batch_size" : 32 , "optimizer" : "adam" }
51+ trainer .training_image = "test-image:latest"
52+ trainer .training_input_mode = "File"
53+ trainer .role = "arn:aws:iam::123456789012:role/SageMakerRole"
54+ trainer .output_data_config = MagicMock ()
55+ trainer .output_data_config .s3_output_path = "s3://bucket/output"
56+ trainer .compute = MagicMock ()
57+ trainer .compute .instance_type = "ml.m5.xlarge"
58+ trainer .compute .instance_count = 1
59+ trainer .compute .volume_size_in_gb = 30
60+ trainer .stopping_condition = MagicMock ()
61+ trainer .stopping_condition .max_runtime_in_seconds = 3600
62+ trainer .input_data_config = None
63+
64+ if with_internal_channels :
65+ trainer .input_data_config = [
66+ _create_channel ("code" , "s3://bucket/code" ),
67+ _create_channel ("sm_drivers" , "s3://bucket/drivers" ),
68+ ]
69+ return trainer
70+
71+
72+ def _create_hyperparameter_ranges ():
73+ """Create sample hyperparameter ranges."""
74+ return {
75+ "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
76+ "batch_size" : IntegerParameter (32 , 256 ),
77+ "optimizer" : CategoricalParameter (["sgd" , "adam" ]),
78+ }
79+
80+
81+ def _create_single_hp_range ():
82+ """Create a single hyperparameter range for simple tests."""
83+ return {"learning_rate" : ContinuousParameter (0.001 , 0.1 )}
84+
85+
86+ def _create_channel (name : str , uri : str ) -> Channel :
87+ """Create a Channel with S3 data source."""
88+ return Channel (
89+ channel_name = name ,
90+ data_source = DataSource (
91+ s3_data_source = S3DataSource (
92+ s3_data_type = "S3Prefix" ,
93+ s3_uri = uri ,
94+ s3_data_distribution_type = "FullyReplicated"
95+ )
96+ )
97+ )
3098
3199
100+ # ---------------------------------------------------------------------------
101+ # Test Classes
102+ # ---------------------------------------------------------------------------
103+
32104class TestWarmStartTypes :
33105 """Test WarmStartTypes enum."""
34106
@@ -47,19 +119,12 @@ class TestHyperparameterTunerInit:
47119 @pytest .fixture
48120 def mock_model_trainer (self ):
49121 """Create a mock ModelTrainer."""
50- trainer = MagicMock ()
51- trainer .sagemaker_session = MagicMock ()
52- trainer .hyperparameters = {"learning_rate" : 0.1 }
53- return trainer
122+ return _create_mock_model_trainer ()
54123
55124 @pytest .fixture
56125 def hyperparameter_ranges (self ):
57126 """Create sample hyperparameter ranges."""
58- return {
59- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
60- "batch_size" : IntegerParameter (32 , 256 ),
61- "optimizer" : CategoricalParameter (["sgd" , "adam" ]),
62- }
127+ return _create_hyperparameter_ranges ()
63128
64129 def test_init_with_basic_params (self , mock_model_trainer , hyperparameter_ranges ):
65130 """Test initialization with basic parameters."""
@@ -266,14 +331,10 @@ class TestHyperparameterTunerProperties:
266331 @pytest .fixture
267332 def tuner (self ):
268333 """Create a basic tuner instance."""
269- mock_trainer = MagicMock ()
270- mock_trainer .sagemaker_session = MagicMock ()
271334 return HyperparameterTuner (
272- model_trainer = mock_trainer ,
335+ model_trainer = _create_mock_model_trainer () ,
273336 objective_metric_name = "accuracy" ,
274- hyperparameter_ranges = {
275- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
276- },
337+ hyperparameter_ranges = _create_single_hp_range (),
277338 )
278339
279340 def test_sagemaker_session_property (self , tuner ):
@@ -293,15 +354,10 @@ def test_hyperparameter_ranges_dict_property_returns_none(self, tuner):
293354
294355 def test_hyperparameter_ranges_dict_property_with_dict (self ):
295356 """Test hyperparameter_ranges_dict property with model_trainer_dict."""
296- mock_trainer = MagicMock ()
297- mock_trainer .sagemaker_session = MagicMock ()
298-
299357 tuner = HyperparameterTuner (
300- model_trainer = mock_trainer ,
358+ model_trainer = _create_mock_model_trainer () ,
301359 objective_metric_name = "accuracy" ,
302- hyperparameter_ranges = {
303- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
304- },
360+ hyperparameter_ranges = _create_single_hp_range (),
305361 model_trainer_name = "trainer1" ,
306362 )
307363
@@ -316,28 +372,21 @@ class TestHyperparameterTunerMethods:
316372 @pytest .fixture
317373 def tuner_with_job (self ):
318374 """Create a tuner with a latest_tuning_job."""
319- mock_trainer = MagicMock ()
320- mock_trainer .sagemaker_session = MagicMock ()
321375 tuner = HyperparameterTuner (
322- model_trainer = mock_trainer ,
376+ model_trainer = _create_mock_model_trainer () ,
323377 objective_metric_name = "accuracy" ,
324- hyperparameter_ranges = {
325- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
326- },
378+ hyperparameter_ranges = _create_single_hp_range (),
327379 )
328380 tuner .latest_tuning_job = MagicMock ()
329381 tuner ._current_job_name = "test-tuning-job"
330382 return tuner
331383
332384 def test_ensure_last_tuning_job_raises_error_when_none (self ):
333385 """Test _ensure_last_tuning_job raises error when no job exists."""
334- mock_trainer = MagicMock ()
335386 tuner = HyperparameterTuner (
336- model_trainer = mock_trainer ,
387+ model_trainer = _create_mock_model_trainer () ,
337388 objective_metric_name = "accuracy" ,
338- hyperparameter_ranges = {
339- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
340- },
389+ hyperparameter_ranges = _create_single_hp_range (),
341390 )
342391
343392 with pytest .raises (ValueError ):
@@ -363,7 +412,7 @@ def test_best_training_job(self, tuner_with_job):
363412 mock_best_job = MagicMock ()
364413 mock_best_job .training_job_name = "best-job-123"
365414 mock_best_job .training_job_definition_name = "training-def"
366-
415+
367416 mock_tuning_job = MagicMock ()
368417 mock_tuning_job .best_training_job = mock_best_job
369418 tuner_with_job .latest_tuning_job .refresh .return_value = mock_tuning_job
@@ -426,16 +475,8 @@ class TestHyperparameterTunerStaticMethods:
426475
427476 def test_prepare_static_hyperparameters (self ):
428477 """Test _prepare_static_hyperparameters method."""
429- mock_trainer = MagicMock ()
430- mock_trainer .hyperparameters = {
431- "learning_rate" : 0.1 ,
432- "batch_size" : 32 ,
433- "optimizer" : "adam" ,
434- }
435-
436- hyperparameter_ranges = {
437- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
438- }
478+ mock_trainer = _create_mock_model_trainer ()
479+ hyperparameter_ranges = _create_single_hp_range ()
439480
440481 static_hps = HyperparameterTuner ._prepare_static_hyperparameters (
441482 mock_trainer , hyperparameter_ranges
@@ -491,11 +532,7 @@ def test_extract_hyperparameters_from_parameter_ranges(self):
491532
492533 def test_prepare_parameter_ranges_for_tuning (self ):
493534 """Test _prepare_parameter_ranges_for_tuning method."""
494- parameter_ranges = {
495- "learning_rate" : ContinuousParameter (0.001 , 0.1 ),
496- "batch_size" : IntegerParameter (32 , 256 ),
497- "optimizer" : CategoricalParameter (["sgd" , "adam" ]),
498- }
535+ parameter_ranges = _create_hyperparameter_ranges ()
499536
500537 processed_ranges = HyperparameterTuner ._prepare_parameter_ranges_for_tuning (
501538 parameter_ranges
@@ -507,3 +544,37 @@ def test_prepare_parameter_ranges_for_tuning(self):
507544 assert len (processed_ranges ["ContinuousParameterRanges" ]) == 1
508545 assert len (processed_ranges ["IntegerParameterRanges" ]) == 1
509546 assert len (processed_ranges ["CategoricalParameterRanges" ]) == 1
547+
548+ def test_build_training_job_definition_includes_internal_channels (self ):
549+ """Test that _build_training_job_definition includes ModelTrainer's internal channels.
550+
551+ This test verifies the fix for GitHub issue #5508 where tuning jobs were missing
552+ internal channels (code, sm_drivers) that ModelTrainer creates for custom training.
553+ """
554+ from sagemaker .core .training .configs import InputData
555+
556+ # Create mock ModelTrainer with internal channels (code, sm_drivers)
557+ mock_trainer = _create_mock_model_trainer (with_internal_channels = True )
558+
559+ # User-provided inputs
560+ user_inputs = [
561+ InputData (channel_name = "train" , data_source = "s3://bucket/train" ),
562+ InputData (channel_name = "validation" , data_source = "s3://bucket/val" )
563+ ]
564+
565+ tuner = HyperparameterTuner (
566+ model_trainer = mock_trainer ,
567+ objective_metric_name = "accuracy" ,
568+ hyperparameter_ranges = _create_single_hp_range (),
569+ )
570+
571+ # Build training job definition
572+ definition = tuner ._build_training_job_definition (user_inputs )
573+
574+ # Verify all channels are included
575+ channel_names = [ch .channel_name for ch in definition .input_data_config ]
576+ assert "code" in channel_names , "Internal 'code' channel should be included"
577+ assert "sm_drivers" in channel_names , "Internal 'sm_drivers' channel should be included"
578+ assert "train" in channel_names , "User 'train' channel should be included"
579+ assert "validation" in channel_names , "User 'validation' channel should be included"
580+ assert len (channel_names ) == 4 , "Should have exactly 4 channels"
0 commit comments