Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class BaseEvaluator(BaseModel):
Attributes:
region (Optional[str]): AWS region for evaluation jobs. If not provided, will use
SAGEMAKER_REGION env var or default region.
role (Optional[str]): IAM execution role ARN for SageMaker pipeline and training jobs.
If not provided, will be derived from the session's caller identity. Use this when
running outside SageMaker-managed environments (e.g., local notebooks, CI/CD) where
the caller identity is not a SageMaker-assumable role.
sagemaker_session (Optional[Any]): SageMaker session object. If not provided, a default
session will be created automatically.
model (Union[str, Any]): Model for evaluation. Can be:
Expand Down Expand Up @@ -88,6 +92,7 @@ class BaseEvaluator(BaseModel):
"""

region: Optional[str] = None
role: Optional[str] = None
sagemaker_session: Optional[Any] = None
model: Union[str, BaseTrainer, ModelPackage]
base_eval_name: Optional[str] = None
Expand Down Expand Up @@ -631,9 +636,12 @@ def _get_aws_execution_context(self) -> Dict[str, str]:
- account_id (str): AWS account ID
"""
# Get role ARN
role_arn = (self.sagemaker_session.get_caller_identity_arn()
if hasattr(self.sagemaker_session, 'get_caller_identity_arn')
else self.sagemaker_session.expand_role())
if self.role:
role_arn = self.role
else:
role_arn = (self.sagemaker_session.get_caller_identity_arn()
if hasattr(self.sagemaker_session, 'get_caller_identity_arn')
else self.sagemaker_session.expand_role())

# Get region - prefer self.region if set, otherwise extract from session
region = self.region or (self.sagemaker_session.boto_region_name
Expand Down
21 changes: 21 additions & 0 deletions sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,27 @@ def test_get_aws_execution_context(self, mock_resolve, mock_session, mock_model_
assert context['region'] == DEFAULT_REGION
assert context['account_id'] == '123456789012'

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
def test_get_aws_execution_context_with_explicit_role(self, mock_resolve, mock_session, mock_model_info):
"""Test that an explicit role overrides the session-derived role."""
mock_resolve.return_value = mock_model_info
explicit_role = "arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole"

evaluator = BaseEvaluator(
model=DEFAULT_MODEL,
s3_output_path=DEFAULT_S3_OUTPUT,
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
sagemaker_session=mock_session,
region=DEFAULT_REGION,
role=explicit_role,
)

context = evaluator._get_aws_execution_context()

assert context['role_arn'] == explicit_role
mock_session.get_caller_identity_arn.assert_not_called()

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
def test_get_aws_execution_context_without_region(self, mock_resolve, mock_session, mock_model_info):
"""Test getting AWS execution context without explicit region."""
Expand Down
Loading