1313"""Unit tests for workflow pipeline_experiment_config."""
1414from __future__ import absolute_import
1515
16+ from unittest .mock import Mock
17+
1618from sagemaker .mlops .workflow .pipeline_experiment_config import (
1719 PipelineExperimentConfig , PipelineExperimentConfigProperties
1820)
21+ from sagemaker .mlops .workflow .pipeline import Pipeline , _DEFAULT_EXPERIMENT_CFG
1922from sagemaker .core .workflow .execution_variables import ExecutionVariables
2023
2124
@@ -41,3 +44,42 @@ def test_pipeline_experiment_config_with_execution_variables():
4144def test_pipeline_experiment_config_properties ():
4245 assert PipelineExperimentConfigProperties .EXPERIMENT_NAME .name == "ExperimentName"
4346 assert PipelineExperimentConfigProperties .TRIAL_NAME .name == "TrialName"
47+
48+
49+ def _create_mock_session (region : str ) -> Mock :
50+ """Helper to create a mock SageMaker session with specified region."""
51+ mock_session = Mock ()
52+ mock_session .boto_region_name = region
53+ mock_session .boto_session = Mock ()
54+ mock_session .boto_session .client = Mock (return_value = Mock ())
55+ mock_session .local_mode = False
56+ return mock_session
57+
58+
59+ def test_default_config_applied_in_ga_region ():
60+ """Default config applied when nothing provided in GA region."""
61+ mock_session = _create_mock_session ("us-east-1" )
62+ pipeline = Pipeline (name = "test-pipeline" , sagemaker_session = mock_session )
63+ assert pipeline .pipeline_experiment_config == _DEFAULT_EXPERIMENT_CFG
64+
65+
66+ def test_no_default_config_in_non_ga_region ():
67+ """No default config when nothing provided in non-GA region (THE FIX)."""
68+ mock_session = _create_mock_session ("us-gov-west-1" )
69+ pipeline = Pipeline (name = "test-pipeline" , sagemaker_session = mock_session )
70+ assert pipeline .pipeline_experiment_config is None
71+
72+
73+ def test_explicit_none_respected_in_ga_region ():
74+ """None gets default config in GA region."""
75+ mock_session = _create_mock_session ("us-east-1" )
76+ pipeline = Pipeline (name = "test-pipeline" , sagemaker_session = mock_session , pipeline_experiment_config = None )
77+ assert pipeline .pipeline_experiment_config is None
78+
79+
80+ def test_custom_config_respected ():
81+ """Custom config respected regardless of region."""
82+ mock_session = _create_mock_session ("us-east-1" )
83+ custom_config = PipelineExperimentConfig ("my-experiment" , "my-trial" )
84+ pipeline = Pipeline (name = "test-pipeline" , sagemaker_session = mock_session , pipeline_experiment_config = custom_config )
85+ assert pipeline .pipeline_experiment_config == custom_config
0 commit comments