Skip to content

Commit c1084f7

Browse files
fix(train): use latest model package in LLM-as-judge base model integ tests (#5959)
* fix(train): use latest model package in LLM-as-judge base model integ tests The hardcoded model package version 1 predates the backend's SageMakerPublicHub requirement for serverless training jobs, causing consistent failures across all PRs. * fix(train): always use SageMakerPublicHub for base model ARN in evaluations The backend now enforces that serverless training jobs only accept BaseModelArn values pointing to SageMakerPublicHub. The SAGEMAKER_HUB_NAME env var is for training recipe lookups only and should not affect the base model ARN passed to evaluation pipelines. Also remove ModelApprovalStatus filter from test helper since training tests never set approval status on output packages. * fix(train): fall back to SageMakerPublicHub when model not found in private hub When _get_hub_content_metadata fails to find a model in the configured private hub (e.g. sdktest), retry with SageMakerPublicHub. This handles models like meta-textgeneration-llama-3-2-1b-instruct that only exist in the public hub.
1 parent b84338b commit c1084f7

3 files changed

Lines changed: 77 additions & 30 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
- ModelPackage objects or ARNs (fine-tuned models)
77
"""
88

9-
import os
109
import json
1110
import boto3
1211
from typing import Union, Optional, Dict, Any
@@ -239,11 +238,8 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
239238
arn_parts = model_pkg_arn.split(':')
240239
if len(arn_parts) >= 4:
241240
region = arn_parts[3]
242-
# Use SAGEMAKER_HUB_NAME if set (private hub), otherwise fall back to public hub
243-
hub_name = os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub")
244-
# Private hubs are account-scoped; public hub uses 'aws' as account
245-
hub_account = "aws" if hub_name == "SageMakerPublicHub" else arn_parts[4]
246-
base_model_arn = f"arn:aws:sagemaker:{region}:{hub_account}:hub-content/{hub_name}/Model/{hub_content_name}/{hub_content_version}"
241+
# Base model always lives in SageMakerPublicHub (SAGEMAKER_HUB_NAME is for training recipes only)
242+
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
247243

248244
# If we couldn't extract or construct base model ARN, this is not a supported model package
249245
if not base_model_arn:

sagemaker-train/src/sagemaker/train/common_utils/recipe_utils.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,30 @@ def _get_hub_content_metadata(
6363
... )
6464
>>> print(metadata['HubContentName'])
6565
"""
66-
hub_content = HubContent.get(
67-
hub_name=hub_name,
68-
hub_content_type=hub_content_type,
69-
hub_content_name=hub_content_name,
70-
region=region,
71-
session=session
72-
)
73-
66+
try:
67+
hub_content = HubContent.get(
68+
hub_name=hub_name,
69+
hub_content_type=hub_content_type,
70+
hub_content_name=hub_content_name,
71+
region=region,
72+
session=session
73+
)
74+
except Exception:
75+
if hub_name != "SageMakerPublicHub":
76+
logger.info(
77+
f"Hub content '{hub_content_name}' not found in '{hub_name}', "
78+
f"falling back to SageMakerPublicHub"
79+
)
80+
hub_content = HubContent.get(
81+
hub_name="SageMakerPublicHub",
82+
hub_content_type=hub_content_type,
83+
hub_content_name=hub_content_name,
84+
region=region,
85+
session=session
86+
)
87+
else:
88+
raise
89+
7490
# Convert to dict for easier access
7591
hub_content_dict = hub_content.__dict__
7692

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919
from __future__ import absolute_import
2020

21+
import boto3
2122
import json
2223
import time
2324
import pytest
@@ -64,51 +65,77 @@
6465
}
6566

6667
# Test configuration
68+
MODEL_PACKAGE_GROUP = "sdk-test-finetuned-models"
69+
REGION = "us-west-2"
70+
ACCOUNT_ID = "729646638167"
71+
6772
TEST_CONFIG = {
68-
"model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
6973
"evaluator_model": "anthropic.claude-3-5-haiku-20241022-v1:0",
70-
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/gen_qa.jsonl",
74+
"dataset_s3_uri": f"s3://sagemaker-{REGION}-{ACCOUNT_ID}/model-customization/eval/gen_qa.jsonl",
7175
"builtin_metrics": ["Completeness", "Faithfulness"],
7276
"custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]),
73-
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/base-model-fix-test/",
74-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
77+
"s3_output_path": f"s3://sagemaker-{REGION}-{ACCOUNT_ID}/model-customization/eval/base-model-fix-test/",
78+
"mlflow_tracking_server_arn": f"arn:aws:sagemaker:{REGION}:{ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
7579
"evaluate_base_model": True, # This is the key difference - testing base model evaluation
76-
"region": "us-west-2",
80+
"region": REGION,
7781
}
7882

7983

84+
def _get_latest_model_package_arn():
85+
"""Return the ARN of the latest model package, or None."""
86+
sm_client = boto3.client("sagemaker", region_name=REGION)
87+
packages = sm_client.list_model_packages(
88+
ModelPackageGroupName=MODEL_PACKAGE_GROUP,
89+
SortBy="CreationTime",
90+
SortOrder="Descending",
91+
MaxResults=1,
92+
)
93+
summaries = packages.get("ModelPackageSummaryList", [])
94+
if not summaries:
95+
return None
96+
return summaries[0]["ModelPackageArn"]
97+
98+
8099
@pytest.mark.serial
81100
class TestLLMAsJudgeBaseModelFix:
82101
"""Integration test for base model fix in LLMAsJudgeEvaluator"""
83102

84103
def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
85104
"""
86105
Test that base model evaluation uses original base model weights.
87-
106+
88107
This test verifies the fix for the bug where base model evaluation
89108
incorrectly used fine-tuned model weights. The test:
90-
109+
91110
1. Creates an evaluator with evaluate_base_model=True
92111
2. Starts the evaluation pipeline
93-
3. Verifies the pipeline has both EvaluateBaseInferenceModel and
112+
3. Verifies the pipeline has both EvaluateBaseInferenceModel and
94113
EvaluateCustomInferenceModel steps
95114
4. Waits for completion
96115
5. Compares results to ensure base and custom models produce different outputs
97-
116+
98117
Expected behavior:
99118
- EvaluateBaseInferenceModel should use only BaseModelArn (no ModelPackageConfig)
100119
- EvaluateCustomInferenceModel should use ModelPackageConfig with SourceModelPackageArn
101120
- Results should show different performance between base and custom models
102121
"""
122+
model_package_arn = _get_latest_model_package_arn()
123+
if not model_package_arn:
124+
pytest.skip(
125+
f"No model packages in group '{MODEL_PACKAGE_GROUP}'. "
126+
"Run SFT/RLVR training first."
127+
)
128+
103129
logger.info("=" * 80)
104130
logger.info("Testing Base Model Fix: evaluate_base_model=True")
105131
logger.info("=" * 80)
106-
132+
107133
# Step 1: Create evaluator with evaluate_base_model=True
108134
logger.info("Creating LLMAsJudgeEvaluator with evaluate_base_model=True")
109-
135+
logger.info(f"Using model package: {model_package_arn}")
136+
110137
evaluator = LLMAsJudgeEvaluator(
111-
model=TEST_CONFIG["model_package_arn"],
138+
model=model_package_arn,
112139
evaluator_model=TEST_CONFIG["evaluator_model"],
113140
dataset=TEST_CONFIG["dataset_s3_uri"],
114141
builtin_metrics=TEST_CONFIG["builtin_metrics"],
@@ -254,19 +281,27 @@ def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
254281
def test_base_model_false_still_works(self, mlflow_resource_arn):
255282
"""
256283
Test that evaluate_base_model=False still works correctly (backward compatibility).
257-
284+
258285
This test ensures the fix doesn't break existing functionality when
259286
evaluate_base_model=False (the default behavior).
260287
"""
288+
model_package_arn = _get_latest_model_package_arn()
289+
if not model_package_arn:
290+
pytest.skip(
291+
f"No model packages in group '{MODEL_PACKAGE_GROUP}'. "
292+
"Run SFT/RLVR training first."
293+
)
294+
261295
logger.info("=" * 80)
262296
logger.info("Testing Backward Compatibility: evaluate_base_model=False")
263297
logger.info("=" * 80)
264-
298+
265299
# Create evaluator with evaluate_base_model=False
266300
logger.info("Creating LLMAsJudgeEvaluator with evaluate_base_model=False")
267-
301+
logger.info(f"Using model package: {model_package_arn}")
302+
268303
evaluator = LLMAsJudgeEvaluator(
269-
model=TEST_CONFIG["model_package_arn"],
304+
model=model_package_arn,
270305
evaluator_model=TEST_CONFIG["evaluator_model"],
271306
dataset=TEST_CONFIG["dataset_s3_uri"],
272307
builtin_metrics=TEST_CONFIG["builtin_metrics"],

0 commit comments

Comments
 (0)