Skip to content

Commit 32fab32

Browse files
committed
fix(train): use latest model package in LLM-as-judge base model integ tests
The hardcoded model package version 1 predates the backend's SageMakerPublicHub requirement for serverless training jobs, causing consistent failures across all PRs.
1 parent 7cdd30f commit 32fab32

1 file changed

Lines changed: 52 additions & 16 deletions

File tree

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"""
1919
from __future__ import absolute_import
2020

21+
import boto3
2122
import json
2223
import time
2324
import pytest
@@ -64,51 +65,78 @@
6465
}
6566

6667
# Test configuration
68+
MODEL_PACKAGE_GROUP = "sdk-test-finetuned-models"
69+
REGION = "us-west-2"
70+
ACCOUNT_ID = "729646638167"
71+
6772
TEST_CONFIG = {
68-
"model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
6973
"evaluator_model": "anthropic.claude-3-5-haiku-20241022-v1:0",
70-
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/gen_qa.jsonl",
74+
"dataset_s3_uri": f"s3://sagemaker-{REGION}-{ACCOUNT_ID}/model-customization/eval/gen_qa.jsonl",
7175
"builtin_metrics": ["Completeness", "Faithfulness"],
7276
"custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]),
73-
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/base-model-fix-test/",
74-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
77+
"s3_output_path": f"s3://sagemaker-{REGION}-{ACCOUNT_ID}/model-customization/eval/base-model-fix-test/",
78+
"mlflow_tracking_server_arn": f"arn:aws:sagemaker:{REGION}:{ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
7579
"evaluate_base_model": True, # This is the key difference - testing base model evaluation
76-
"region": "us-west-2",
80+
"region": REGION,
7781
}
7882

7983

84+
def _get_latest_model_package_arn():
85+
"""Return the ARN of the latest approved model package, or None."""
86+
sm_client = boto3.client("sagemaker", region_name=REGION)
87+
packages = sm_client.list_model_packages(
88+
ModelPackageGroupName=MODEL_PACKAGE_GROUP,
89+
ModelApprovalStatus="Approved",
90+
SortBy="CreationTime",
91+
SortOrder="Descending",
92+
MaxResults=1,
93+
)
94+
summaries = packages.get("ModelPackageSummaryList", [])
95+
if not summaries:
96+
return None
97+
return summaries[0]["ModelPackageArn"]
98+
99+
80100
@pytest.mark.serial
81101
class TestLLMAsJudgeBaseModelFix:
82102
"""Integration test for base model fix in LLMAsJudgeEvaluator"""
83103

84104
def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
85105
"""
86106
Test that base model evaluation uses original base model weights.
87-
107+
88108
This test verifies the fix for the bug where base model evaluation
89109
incorrectly used fine-tuned model weights. The test:
90-
110+
91111
1. Creates an evaluator with evaluate_base_model=True
92112
2. Starts the evaluation pipeline
93-
3. Verifies the pipeline has both EvaluateBaseInferenceModel and
113+
3. Verifies the pipeline has both EvaluateBaseInferenceModel and
94114
EvaluateCustomInferenceModel steps
95115
4. Waits for completion
96116
5. Compares results to ensure base and custom models produce different outputs
97-
117+
98118
Expected behavior:
99119
- EvaluateBaseInferenceModel should use only BaseModelArn (no ModelPackageConfig)
100120
- EvaluateCustomInferenceModel should use ModelPackageConfig with SourceModelPackageArn
101121
- Results should show different performance between base and custom models
102122
"""
123+
model_package_arn = _get_latest_model_package_arn()
124+
if not model_package_arn:
125+
pytest.skip(
126+
f"No approved model packages in group '{MODEL_PACKAGE_GROUP}'. "
127+
"Run SFT/RLVR training first."
128+
)
129+
103130
logger.info("=" * 80)
104131
logger.info("Testing Base Model Fix: evaluate_base_model=True")
105132
logger.info("=" * 80)
106-
133+
107134
# Step 1: Create evaluator with evaluate_base_model=True
108135
logger.info("Creating LLMAsJudgeEvaluator with evaluate_base_model=True")
109-
136+
logger.info(f"Using model package: {model_package_arn}")
137+
110138
evaluator = LLMAsJudgeEvaluator(
111-
model=TEST_CONFIG["model_package_arn"],
139+
model=model_package_arn,
112140
evaluator_model=TEST_CONFIG["evaluator_model"],
113141
dataset=TEST_CONFIG["dataset_s3_uri"],
114142
builtin_metrics=TEST_CONFIG["builtin_metrics"],
@@ -254,19 +282,27 @@ def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
254282
def test_base_model_false_still_works(self, mlflow_resource_arn):
255283
"""
256284
Test that evaluate_base_model=False still works correctly (backward compatibility).
257-
285+
258286
This test ensures the fix doesn't break existing functionality when
259287
evaluate_base_model=False (the default behavior).
260288
"""
289+
model_package_arn = _get_latest_model_package_arn()
290+
if not model_package_arn:
291+
pytest.skip(
292+
f"No approved model packages in group '{MODEL_PACKAGE_GROUP}'. "
293+
"Run SFT/RLVR training first."
294+
)
295+
261296
logger.info("=" * 80)
262297
logger.info("Testing Backward Compatibility: evaluate_base_model=False")
263298
logger.info("=" * 80)
264-
299+
265300
# Create evaluator with evaluate_base_model=False
266301
logger.info("Creating LLMAsJudgeEvaluator with evaluate_base_model=False")
267-
302+
logger.info(f"Using model package: {model_package_arn}")
303+
268304
evaluator = LLMAsJudgeEvaluator(
269-
model=TEST_CONFIG["model_package_arn"],
305+
model=model_package_arn,
270306
evaluator_model=TEST_CONFIG["evaluator_model"],
271307
dataset=TEST_CONFIG["dataset_s3_uri"],
272308
builtin_metrics=TEST_CONFIG["builtin_metrics"],

0 commit comments

Comments
 (0)