Skip to content

Commit 8e060f6

Browse files
committed
fix(train): always use SageMakerPublicHub for base model ARN in evaluations
The backend now enforces that serverless training jobs only accept BaseModelArn values pointing to SageMakerPublicHub. The SAGEMAKER_HUB_NAME env var is for training recipe lookups only and should not affect the base model ARN passed to evaluation pipelines. Also remove ModelApprovalStatus filter from test helper since training tests never set approval status on output packages.
1 parent 32fab32 commit 8e060f6

2 files changed

Lines changed: 5 additions & 10 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
- ModelPackage objects or ARNs (fine-tuned models)
77
"""
88

9-
import os
109
import json
1110
import boto3
1211
from typing import Union, Optional, Dict, Any
@@ -239,11 +238,8 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
239238
arn_parts = model_pkg_arn.split(':')
240239
if len(arn_parts) >= 4:
241240
region = arn_parts[3]
242-
# Use SAGEMAKER_HUB_NAME if set (private hub), otherwise fall back to public hub
243-
hub_name = os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub")
244-
# Private hubs are account-scoped; public hub uses 'aws' as account
245-
hub_account = "aws" if hub_name == "SageMakerPublicHub" else arn_parts[4]
246-
base_model_arn = f"arn:aws:sagemaker:{region}:{hub_account}:hub-content/{hub_name}/Model/{hub_content_name}/{hub_content_version}"
241+
# Base model always lives in SageMakerPublicHub (SAGEMAKER_HUB_NAME is for training recipes only)
242+
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
247243

248244
# If we couldn't extract or construct base model ARN, this is not a supported model package
249245
if not base_model_arn:

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,10 @@
8282

8383

8484
def _get_latest_model_package_arn():
85-
"""Return the ARN of the latest approved model package, or None."""
85+
"""Return the ARN of the latest model package, or None."""
8686
sm_client = boto3.client("sagemaker", region_name=REGION)
8787
packages = sm_client.list_model_packages(
8888
ModelPackageGroupName=MODEL_PACKAGE_GROUP,
89-
ModelApprovalStatus="Approved",
9089
SortBy="CreationTime",
9190
SortOrder="Descending",
9291
MaxResults=1,
@@ -123,7 +122,7 @@ def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
123122
model_package_arn = _get_latest_model_package_arn()
124123
if not model_package_arn:
125124
pytest.skip(
126-
f"No approved model packages in group '{MODEL_PACKAGE_GROUP}'. "
125+
f"No model packages in group '{MODEL_PACKAGE_GROUP}'. "
127126
"Run SFT/RLVR training first."
128127
)
129128

@@ -289,7 +288,7 @@ def test_base_model_false_still_works(self, mlflow_resource_arn):
289288
model_package_arn = _get_latest_model_package_arn()
290289
if not model_package_arn:
291290
pytest.skip(
292-
f"No approved model packages in group '{MODEL_PACKAGE_GROUP}'. "
291+
f"No model packages in group '{MODEL_PACKAGE_GROUP}'. "
293292
"Run SFT/RLVR training first."
294293
)
295294

0 commit comments

Comments
 (0)