Skip to content

Commit eca36f2

Browse files
committed
fix(eval): resolve mlflow_resource_arn in _get_base_template_context when session was absent at construction
1 parent daf19b0 commit eca36f2

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.core.helper.session_helper import Session
2222

2323
from sagemaker.train.base_trainer import BaseTrainer
24+
from sagemaker.train.common_utils.finetune_utils import _resolve_mlflow_resource_arn
2425
# Module-level logger
2526
_logger = logging.getLogger(__name__)
2627

@@ -163,8 +164,6 @@ def _resolve_dataset(cls, v):
163164
@validator('mlflow_resource_arn', pre=True, always=True)
164165
def _resolve_mlflow_arn(cls, v, values):
165166
"""Resolve MLflow resource ARN using default experience logic if not provided."""
166-
from ..common_utils.finetune_utils import _resolve_mlflow_resource_arn
167-
168167
# Get sagemaker_session from values
169168
sagemaker_session = values.get('sagemaker_session')
170169
if sagemaker_session is None:
@@ -709,6 +708,10 @@ def _get_base_template_context(
709708
Returns:
710709
dict: Base template context dictionary
711710
"""
711+
# Resolve MLflow ARN if not already resolved (e.g. session was None at construction time)
712+
if not self.mlflow_resource_arn and self.sagemaker_session:
713+
self.mlflow_resource_arn = _resolve_mlflow_resource_arn(self.sagemaker_session)
714+
712715
# Generate default mlflow_experiment_name if not provided
713716
# This is required by AWS when ModelPackageGroupArn is not provided in training jobs
714717
mlflow_experiment_name = self.mlflow_experiment_name

sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,37 @@ def test_get_base_template_context(self, mock_resolve, mock_session, mock_model_
925925
assert context['dataset_artifact_arn'] == DEFAULT_ARTIFACT_ARN
926926
assert 'action_arn_prefix' in context
927927

928+
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
929+
@patch("sagemaker.train.evaluate.base_evaluator._resolve_mlflow_resource_arn")
930+
def test_get_base_template_context_deferred_mlflow_resolution(self, mock_resolve_mlflow, mock_resolve, mock_session, mock_model_info):
931+
"""Test that mlflow_resource_arn is resolved in _get_base_template_context when session was None at construction."""
932+
mock_resolve.return_value = mock_model_info
933+
# Validator returns None because session was None at construction time
934+
mock_resolve_mlflow.return_value = None
935+
936+
evaluator = BaseEvaluator(
937+
model=DEFAULT_MODEL,
938+
s3_output_path=DEFAULT_S3_OUTPUT,
939+
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
940+
sagemaker_session=mock_session,
941+
)
942+
# Simulate the case where ARN was not resolved at construction (session was None)
943+
evaluator.mlflow_resource_arn = None
944+
945+
resolved_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/deferred"
946+
mock_resolve_mlflow.return_value = resolved_arn
947+
948+
context = evaluator._get_base_template_context(
949+
role_arn=DEFAULT_ROLE_ARN,
950+
region=DEFAULT_REGION,
951+
account_id="123456789012",
952+
model_package_group_arn=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
953+
resolved_model_artifact_arn=DEFAULT_ARTIFACT_ARN,
954+
)
955+
956+
assert context['mlflow_resource_arn'] == resolved_arn
957+
mock_resolve_mlflow.assert_called_with(mock_session)
958+
928959

929960
class TestResolveModelArtifacts:
930961
"""Tests for model artifacts resolution."""

0 commit comments

Comments
 (0)