|
18 | 18 | """ |
19 | 19 | from __future__ import absolute_import |
20 | 20 |
|
| 21 | +import boto3 |
21 | 22 | import json |
22 | 23 | import time |
23 | 24 | import pytest |
|
64 | 65 | } |
65 | 66 |
|
66 | 67 | # Test configuration |
| 68 | +MODEL_PACKAGE_GROUP = "sdk-test-finetuned-models" |
| 69 | +REGION = "us-west-2" |
| 70 | +ACCOUNT_ID = "729646638167" |
| 71 | + |
67 | 72 | TEST_CONFIG = { |
68 | | - "model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1", |
69 | 73 | "evaluator_model": "anthropic.claude-3-5-haiku-20241022-v1:0", |
70 | | - "dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/gen_qa.jsonl", |
| 74 | + "dataset_s3_uri": f"s3://sagemaker-{REGION}-{ACCOUNT_ID}/model-customization/eval/gen_qa.jsonl", |
71 | 75 | "builtin_metrics": ["Completeness", "Faithfulness"], |
72 | 76 | "custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]), |
73 | | - "s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/base-model-fix-test/", |
74 | | - "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6", |
| 77 | + "s3_output_path": f"s3://sagemaker-{REGION}-{ACCOUNT_ID}/model-customization/eval/base-model-fix-test/", |
| 78 | + "mlflow_tracking_server_arn": f"arn:aws:sagemaker:{REGION}:{ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6", |
75 | 79 | "evaluate_base_model": True, # This is the key difference - testing base model evaluation |
76 | | - "region": "us-west-2", |
| 80 | + "region": REGION, |
77 | 81 | } |
78 | 82 |
|
79 | 83 |
|
def _get_latest_model_package_arn():
    """Look up the newest approved model package in the test group.

    Asks SageMaker (in ``REGION``) for the packages registered under
    ``MODEL_PACKAGE_GROUP`` whose approval status is "Approved", ordered by
    creation time descending and capped at a single result.

    Returns:
        The ``ModelPackageArn`` of the first summary in the response, or
        ``None`` when the response contains no summaries.
    """
    # Descending CreationTime + MaxResults=1 means the single summary that
    # comes back (if any) is the most recently created approved package.
    query = {
        "ModelPackageGroupName": MODEL_PACKAGE_GROUP,
        "ModelApprovalStatus": "Approved",
        "SortBy": "CreationTime",
        "SortOrder": "Descending",
        "MaxResults": 1,
    }
    response = boto3.client("sagemaker", region_name=REGION).list_model_packages(**query)
    found = response.get("ModelPackageSummaryList", [])
    return found[0]["ModelPackageArn"] if found else None
| 98 | + |
| 99 | + |
80 | 100 | @pytest.mark.serial |
81 | 101 | class TestLLMAsJudgeBaseModelFix: |
82 | 102 | """Integration test for base model fix in LLMAsJudgeEvaluator""" |
83 | 103 |
|
84 | 104 | def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn): |
85 | 105 | """ |
86 | 106 | Test that base model evaluation uses original base model weights. |
87 | | - |
| 107 | +
|
88 | 108 | This test verifies the fix for the bug where base model evaluation |
89 | 109 | incorrectly used fine-tuned model weights. The test: |
90 | | - |
| 110 | +
|
91 | 111 | 1. Creates an evaluator with evaluate_base_model=True |
92 | 112 | 2. Starts the evaluation pipeline |
93 | | - 3. Verifies the pipeline has both EvaluateBaseInferenceModel and |
| 113 | + 3. Verifies the pipeline has both EvaluateBaseInferenceModel and |
94 | 114 | EvaluateCustomInferenceModel steps |
95 | 115 | 4. Waits for completion |
96 | 116 | 5. Compares results to ensure base and custom models produce different outputs |
97 | | - |
| 117 | +
|
98 | 118 | Expected behavior: |
99 | 119 | - EvaluateBaseInferenceModel should use only BaseModelArn (no ModelPackageConfig) |
100 | 120 | - EvaluateCustomInferenceModel should use ModelPackageConfig with SourceModelPackageArn |
101 | 121 | - Results should show different performance between base and custom models |
102 | 122 | """ |
| 123 | + model_package_arn = _get_latest_model_package_arn() |
| 124 | + if not model_package_arn: |
| 125 | + pytest.skip( |
| 126 | + f"No approved model packages in group '{MODEL_PACKAGE_GROUP}'. " |
| 127 | + "Run SFT/RLVR training first." |
| 128 | + ) |
| 129 | + |
103 | 130 | logger.info("=" * 80) |
104 | 131 | logger.info("Testing Base Model Fix: evaluate_base_model=True") |
105 | 132 | logger.info("=" * 80) |
106 | | - |
| 133 | + |
107 | 134 | # Step 1: Create evaluator with evaluate_base_model=True |
108 | 135 | logger.info("Creating LLMAsJudgeEvaluator with evaluate_base_model=True") |
109 | | - |
| 136 | + logger.info(f"Using model package: {model_package_arn}") |
| 137 | + |
110 | 138 | evaluator = LLMAsJudgeEvaluator( |
111 | | - model=TEST_CONFIG["model_package_arn"], |
| 139 | + model=model_package_arn, |
112 | 140 | evaluator_model=TEST_CONFIG["evaluator_model"], |
113 | 141 | dataset=TEST_CONFIG["dataset_s3_uri"], |
114 | 142 | builtin_metrics=TEST_CONFIG["builtin_metrics"], |
@@ -254,19 +282,27 @@ def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn): |
254 | 282 | def test_base_model_false_still_works(self, mlflow_resource_arn): |
255 | 283 | """ |
256 | 284 | Test that evaluate_base_model=False still works correctly (backward compatibility). |
257 | | - |
| 285 | +
|
258 | 286 | This test ensures the fix doesn't break existing functionality when |
259 | 287 | evaluate_base_model=False (the default behavior). |
260 | 288 | """ |
| 289 | + model_package_arn = _get_latest_model_package_arn() |
| 290 | + if not model_package_arn: |
| 291 | + pytest.skip( |
| 292 | + f"No approved model packages in group '{MODEL_PACKAGE_GROUP}'. " |
| 293 | + "Run SFT/RLVR training first." |
| 294 | + ) |
| 295 | + |
261 | 296 | logger.info("=" * 80) |
262 | 297 | logger.info("Testing Backward Compatibility: evaluate_base_model=False") |
263 | 298 | logger.info("=" * 80) |
264 | | - |
| 299 | + |
265 | 300 | # Create evaluator with evaluate_base_model=False |
266 | 301 | logger.info("Creating LLMAsJudgeEvaluator with evaluate_base_model=False") |
267 | | - |
| 302 | + logger.info(f"Using model package: {model_package_arn}") |
| 303 | + |
268 | 304 | evaluator = LLMAsJudgeEvaluator( |
269 | | - model=TEST_CONFIG["model_package_arn"], |
| 305 | + model=model_package_arn, |
270 | 306 | evaluator_model=TEST_CONFIG["evaluator_model"], |
271 | 307 | dataset=TEST_CONFIG["dataset_s3_uri"], |
272 | 308 | builtin_metrics=TEST_CONFIG["builtin_metrics"], |
|
0 commit comments