@@ -222,7 +222,7 @@ def test_mlflow_arn_validation(self, mock_resolve, mlflow_arn, should_pass, mock
222222 )
223223
224224 @patch ("sagemaker.train.common_utils.model_resolution._resolve_base_model" )
225- @patch ("sagemaker.train.common_utils.finetune_utils ._resolve_mlflow_resource_arn" )
225+ @patch ("sagemaker.train.evaluate.base_evaluator ._resolve_mlflow_resource_arn" )
226226 def test_mlflow_arn_optional_with_resolution (self , mock_resolve_mlflow , mock_resolve , mock_session , mock_model_info ):
227227 """Test that MLflow ARN is optional and gets resolved automatically."""
228228 mock_resolve .return_value = mock_model_info
@@ -240,7 +240,7 @@ def test_mlflow_arn_optional_with_resolution(self, mock_resolve_mlflow, mock_res
240240 mock_resolve_mlflow .assert_called_once_with (mock_session , None )
241241
242242 @patch ("sagemaker.train.common_utils.model_resolution._resolve_base_model" )
243- @patch ("sagemaker.train.common_utils.finetune_utils ._resolve_mlflow_resource_arn" )
243+ @patch ("sagemaker.train.evaluate.base_evaluator ._resolve_mlflow_resource_arn" )
244244 def test_mlflow_arn_provided_skips_resolution (self , mock_resolve_mlflow , mock_resolve , mock_session , mock_model_info ):
245245 """Test that provided MLflow ARN is used instead of resolution."""
246246 mock_resolve .return_value = mock_model_info
@@ -261,7 +261,7 @@ def test_mlflow_arn_provided_skips_resolution(self, mock_resolve_mlflow, mock_re
261261 mock_resolve_mlflow .assert_called_once_with (mock_session , provided_arn )
262262
263263 @patch ("sagemaker.train.common_utils.model_resolution._resolve_base_model" )
264- @patch ("sagemaker.train.common_utils.finetune_utils ._resolve_mlflow_resource_arn" )
264+ @patch ("sagemaker.train.evaluate.base_evaluator ._resolve_mlflow_resource_arn" )
265265 def test_mlflow_arn_resolution_returns_none (self , mock_resolve_mlflow , mock_resolve , mock_session , mock_model_info ):
266266 """Test that MLflow resolution can return None (disabled tracking)."""
267267 mock_resolve .return_value = mock_model_info
@@ -278,7 +278,7 @@ def test_mlflow_arn_resolution_returns_none(self, mock_resolve_mlflow, mock_reso
278278 mock_resolve_mlflow .assert_called_once_with (mock_session , None )
279279
280280 @patch ("sagemaker.train.common_utils.model_resolution._resolve_base_model" )
281- @patch ("sagemaker.train.common_utils.finetune_utils ._resolve_mlflow_resource_arn" )
281+ @patch ("sagemaker.train.evaluate.base_evaluator ._resolve_mlflow_resource_arn" )
282282 def test_mlflow_arn_resolution_with_exception (self , mock_resolve_mlflow , mock_resolve , mock_session , mock_model_info ):
283283 """Test that MLflow resolution exceptions are handled gracefully by returning None."""
284284 mock_resolve .return_value = mock_model_info
@@ -925,6 +925,37 @@ def test_get_base_template_context(self, mock_resolve, mock_session, mock_model_
925925 assert context ['dataset_artifact_arn' ] == DEFAULT_ARTIFACT_ARN
926926 assert 'action_arn_prefix' in context
927927
928+ @patch ("sagemaker.train.common_utils.model_resolution._resolve_base_model" )
929+ @patch ("sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn" )
930+ def test_get_base_template_context_deferred_mlflow_resolution (self , mock_resolve_mlflow , mock_resolve , mock_session , mock_model_info ):
931+ """Test that mlflow_resource_arn is resolved in _get_base_template_context when session was None at construction."""
932+ mock_resolve .return_value = mock_model_info
933+ # Validator returns None because session was None at construction time
934+ mock_resolve_mlflow .return_value = None
935+
936+ evaluator = BaseEvaluator (
937+ model = DEFAULT_MODEL ,
938+ s3_output_path = DEFAULT_S3_OUTPUT ,
939+ model_package_group = DEFAULT_MODEL_PACKAGE_GROUP_ARN ,
940+ sagemaker_session = mock_session ,
941+ )
942+ # Simulate the case where ARN was not resolved at construction (session was None)
943+ evaluator .mlflow_resource_arn = None
944+
945+ resolved_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/deferred"
946+ mock_resolve_mlflow .return_value = resolved_arn
947+
948+ context = evaluator ._get_base_template_context (
949+ role_arn = DEFAULT_ROLE_ARN ,
950+ region = DEFAULT_REGION ,
951+ account_id = "123456789012" ,
952+ model_package_group_arn = DEFAULT_MODEL_PACKAGE_GROUP_ARN ,
953+ resolved_model_artifact_arn = DEFAULT_ARTIFACT_ARN ,
954+ )
955+
956+ assert context ['mlflow_resource_arn' ] == resolved_arn
957+ mock_resolve_mlflow .assert_called_with (mock_session )
958+
928959
929960class TestResolveModelArtifacts :
930961 """Tests for model artifacts resolution."""
0 commit comments