Skip to content

Commit 6a0ef97

Browse files
Support IAM role for BaseEvaluator (#5671)
1 parent d9b47ef commit 6a0ef97

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class BaseEvaluator(BaseModel):
4949
Attributes:
5050
region (Optional[str]): AWS region for evaluation jobs. If not provided, will use
5151
SAGEMAKER_REGION env var or default region.
52+
role (Optional[str]): IAM execution role ARN for SageMaker pipeline and training jobs.
53+
If not provided, will be derived from the session's caller identity. Use this when
54+
running outside SageMaker-managed environments (e.g., local notebooks, CI/CD) where
55+
the caller identity is not a SageMaker-assumable role.
5256
sagemaker_session (Optional[Any]): SageMaker session object. If not provided, a default
5357
session will be created automatically.
5458
model (Union[str, Any]): Model for evaluation. Can be:
@@ -88,6 +92,7 @@ class BaseEvaluator(BaseModel):
8892
"""
8993

9094
region: Optional[str] = None
95+
role: Optional[str] = None
9196
sagemaker_session: Optional[Any] = None
9297
model: Union[str, BaseTrainer, ModelPackage]
9398
base_eval_name: Optional[str] = None
@@ -631,9 +636,12 @@ def _get_aws_execution_context(self) -> Dict[str, str]:
631636
- account_id (str): AWS account ID
632637
"""
633638
# Get role ARN
634-
role_arn = (self.sagemaker_session.get_caller_identity_arn()
635-
if hasattr(self.sagemaker_session, 'get_caller_identity_arn')
636-
else self.sagemaker_session.expand_role())
639+
if self.role:
640+
role_arn = self.role
641+
else:
642+
role_arn = (self.sagemaker_session.get_caller_identity_arn()
643+
if hasattr(self.sagemaker_session, 'get_caller_identity_arn')
644+
else self.sagemaker_session.expand_role())
637645

638646
# Get region - prefer self.region if set, otherwise extract from session
639647
region = self.region or (self.sagemaker_session.boto_region_name

sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,27 @@ def test_get_aws_execution_context(self, mock_resolve, mock_session, mock_model_
738738
assert context['region'] == DEFAULT_REGION
739739
assert context['account_id'] == '123456789012'
740740

741+
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
742+
def test_get_aws_execution_context_with_explicit_role(self, mock_resolve, mock_session, mock_model_info):
743+
"""Test that an explicit role overrides the session-derived role."""
744+
mock_resolve.return_value = mock_model_info
745+
explicit_role = "arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole"
746+
747+
evaluator = BaseEvaluator(
748+
model=DEFAULT_MODEL,
749+
s3_output_path=DEFAULT_S3_OUTPUT,
750+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
751+
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
752+
sagemaker_session=mock_session,
753+
region=DEFAULT_REGION,
754+
role=explicit_role,
755+
)
756+
757+
context = evaluator._get_aws_execution_context()
758+
759+
assert context['role_arn'] == explicit_role
760+
mock_session.get_caller_identity_arn.assert_not_called()
761+
741762
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
742763
def test_get_aws_execution_context_without_region(self, mock_resolve, mock_session, mock_model_info):
743764
"""Test getting AWS execution context without explicit region."""

0 commit comments

Comments
 (0)