Skip to content

Commit ec3da03

Browse files
rsareddy0329Roja Reddy Sareddy
andauthored
Master mtrl eval issue fix (#5923)
* Fix:MTRL eval override hyperparams issue * Fix:MTRL eval override hyperparams issue * Fix:MTRL eval override hyperparams issue * Fix: Pass mlflow_resource_arn in LLM-as-Judge base model fix tests * Fix: Pass mlflow_resource_arn in benchmark, custom scorer, and LLM-as-Judge integ tests --------- Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent fda2565 commit ec3da03

9 files changed

Lines changed: 64 additions & 23 deletions

sagemaker-train/src/sagemaker/train/evaluate/multi_turn_rl_evaluator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def hyperparameters(self):
317317
override_params = _get_evaluation_override_params(
318318
hub_content_name=hub_content_name,
319319
hub_name="SageMakerPublicHub",
320-
evaluation_type="AgentRFTEvaluation",
320+
evaluation_type="MTRLEvaluation",
321321
region=self.region,
322322
session=boto_session,
323323
)
@@ -328,7 +328,9 @@ def hyperparameters(self):
328328
f"JumpStart hub."
329329
)
330330

331-
spec = _extract_eval_override_options(override_params, return_full_spec=True)
331+
spec = _extract_eval_override_options(
332+
override_params, param_names=list(override_params.keys()), return_full_spec=True
333+
)
332334
self._hyperparameters = FineTuningOptions(spec)
333335
return self._hyperparameters
334336

sagemaker-train/tests/integ/train/test_benchmark_evaluator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
"model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
4848
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
4949
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
50-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
50+
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
5151
"model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
5252
"region": "us-west-2",
5353
}
@@ -57,7 +57,7 @@
5757
"base_model_id": "meta-textgeneration-llama-3-2-1b-instruct",
5858
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
5959
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
60-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
60+
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
6161
"region": "us-west-2",
6262
}
6363

@@ -124,7 +124,7 @@ def test_benchmark_evaluation_full_flow(self):
124124
benchmark=Benchmark.MMLU,
125125
model=TEST_CONFIG["model_package_arn"],
126126
s3_output_path=TEST_CONFIG["s3_output_path"],
127-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
127+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
128128
model_package_group=TEST_CONFIG["model_package_group_arn"],
129129
base_eval_name="integ-test-gen-qa-eval",
130130
)
@@ -242,7 +242,7 @@ def test_benchmark_evaluator_validation(self):
242242
benchmark="invalid_benchmark",
243243
model=TEST_CONFIG["model_package_arn"],
244244
s3_output_path=TEST_CONFIG["s3_output_path"],
245-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
245+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
246246
)
247247

248248
# Test invalid MLflow ARN format
@@ -265,7 +265,7 @@ def test_benchmark_subtasks_validation(self):
265265
benchmark=Benchmark.MMLU,
266266
model=TEST_CONFIG["model_package_arn"],
267267
s3_output_path=TEST_CONFIG["s3_output_path"],
268-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
268+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
269269
subtasks="abstract_algebra",
270270
model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/test",
271271
)
@@ -277,7 +277,7 @@ def test_benchmark_subtasks_validation(self):
277277
benchmark=Benchmark.MMLU,
278278
model=TEST_CONFIG["model_package_arn"],
279279
s3_output_path=TEST_CONFIG["s3_output_path"],
280-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
280+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
281281
subtasks=["invalid"],
282282
model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/test",
283283
)

sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
4949
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
5050
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
51-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
51+
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
5252
"model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
5353
"evaluate_base_model": False,
5454
"region": "us-west-2",
@@ -60,7 +60,7 @@
6060
"evaluator_arn": "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/eval-lambda-test/0.0.1",
6161
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
6262
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
63-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
63+
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
6464
"region": "us-west-2",
6565
}
6666

@@ -111,7 +111,7 @@ def test_custom_scorer_evaluation_full_flow(self):
111111
dataset=TEST_CONFIG["dataset_s3_uri"],
112112
model=TEST_CONFIG["model_package_arn"],
113113
s3_output_path=TEST_CONFIG["s3_output_path"],
114-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
114+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
115115
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
116116
)
117117

@@ -228,7 +228,7 @@ def test_custom_scorer_evaluator_validation(self):
228228
evaluator=123, # Invalid type (not string, enum, or object)
229229
model=TEST_CONFIG["model_package_arn"],
230230
s3_output_path=TEST_CONFIG["s3_output_path"],
231-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
231+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
232232
dataset=TEST_CONFIG["dataset_s3_uri"],
233233
)
234234

@@ -268,7 +268,7 @@ def test_custom_scorer_with_builtin_metric(self):
268268
dataset=TEST_CONFIG["dataset_s3_uri"],
269269
model=TEST_CONFIG["model_package_arn"],
270270
s3_output_path=TEST_CONFIG["s3_output_path"],
271-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
271+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
272272
evaluate_base_model=False,
273273
)
274274

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
"builtin_metrics": ["Completeness", "Faithfulness"],
7272
"custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]),
7373
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/base-model-fix-test/",
74-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
74+
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
7575
"evaluate_base_model": True, # This is the key difference - testing base model evaluation
7676
"region": "us-west-2",
7777
}
@@ -115,6 +115,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
115115
custom_metrics=TEST_CONFIG["custom_metrics_json"],
116116
s3_output_path=TEST_CONFIG["s3_output_path"],
117117
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
118+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
118119
)
119120

120121
# Verify evaluator configuration
@@ -271,6 +272,7 @@ def test_base_model_false_still_works(self):
271272
builtin_metrics=TEST_CONFIG["builtin_metrics"],
272273
s3_output_path=TEST_CONFIG["s3_output_path"],
273274
evaluate_base_model=False, # Only evaluate custom model
275+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
274276
)
275277

276278
# Verify evaluator configuration

sagemaker-train/tests/integ/train/test_llm_as_judge_evaluator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
"builtin_metrics": ["Completeness", "Faithfulness"],
7878
"custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]),
7979
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
80-
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
80+
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
8181
# "model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
8282
"evaluate_base_model": False,
8383
"region": "us-west-2",
@@ -113,7 +113,7 @@ def test_llm_as_judge_evaluation_full_flow(self):
113113
dataset=TEST_CONFIG["dataset_s3_uri"],
114114
builtin_metrics=TEST_CONFIG["builtin_metrics"],
115115
custom_metrics=TEST_CONFIG["custom_metrics_json"],
116-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
116+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
117117
s3_output_path=TEST_CONFIG["s3_output_path"],
118118
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
119119
)
@@ -236,7 +236,7 @@ def test_llm_as_judge_builtin_metrics_prefix_handling(self):
236236
evaluator_model=TEST_CONFIG["evaluator_model"],
237237
dataset=TEST_CONFIG["dataset_s3_uri"],
238238
s3_output_path=TEST_CONFIG["s3_output_path"],
239-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
239+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
240240
builtin_metrics=["Builtin.Correctness", "Builtin.Helpfulness"],
241241
)
242242
assert evaluator_with_prefix.builtin_metrics == ["Builtin.Correctness", "Builtin.Helpfulness"]
@@ -247,7 +247,7 @@ def test_llm_as_judge_builtin_metrics_prefix_handling(self):
247247
evaluator_model=TEST_CONFIG["evaluator_model"],
248248
dataset=TEST_CONFIG["dataset_s3_uri"],
249249
s3_output_path=TEST_CONFIG["s3_output_path"],
250-
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
250+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
251251
builtin_metrics=["Correctness", "Helpfulness"],
252252
)
253253
assert evaluator_without_prefix.builtin_metrics == ["Correctness", "Helpfulness"]

sagemaker-train/tests/integ/train/test_mtrl_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
"agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{_ACCOUNT_ID}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
4848
"dataset": f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet",
4949
"s3_output_path": f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/output-artifacts/",
50-
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU",
50+
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
5151
"model_package_group": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
5252
"role": f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin",
5353
"region": _REGION,

sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def handler(event, context):
161161
),
162162
"mlflow_resource_arn": os.environ.get(
163163
"MTRL_3P_MLFLOW_ARN",
164-
f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU",
164+
f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
165165
),
166166
"role": os.environ.get(
167167
"MTRL_3P_ROLE",
@@ -262,6 +262,7 @@ def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn):
262262
logger.info(f"Started 3P agent base model evaluation: {execution.arn}")
263263
logger.info(f"Status: {execution.status.overall_status}")
264264

265+
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
265266
def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn):
266267
"""Test evaluating using an CustomAgentLambda object as agent_config.
267268
@@ -287,6 +288,7 @@ def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn):
287288
assert execution.arn is not None
288289
logger.info(f"Started CustomAgentLambda object evaluation: {execution.arn}")
289290

291+
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
290292
def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn):
291293
"""Test full end-to-end: start evaluation and wait for completion.
292294
@@ -316,6 +318,7 @@ def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn):
316318
if execution.status.overall_status == "Failed":
317319
logger.error(f"Failure reason: {execution.status.failure_reason}")
318320

321+
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
319322
def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn):
320323
"""Test that 3P agent evaluations are discoverable via get_all.
321324
@@ -355,6 +358,7 @@ def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn):
355358

356359

357360

361+
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
358362
def test_evaluate_with_attached_trainer(self, lambda_agent_arn):
359363
"""Test evaluating a fine-tuned model by attaching to an existing training job."""
360364
from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer

sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def _get_account_id():
5555
# PROD — Main account (729646638167)
5656
"729646638167": {
5757
"env_name": "PROD",
58-
"existing_job_name": "openai-reasoning-gpt-oss-20b-mtrl-20260602150414",
58+
"existing_job_name": "openai-reasoning-gpt-oss-20b-mtrl-20260602215955",
5959
"base_model": "openai-reasoning-gpt-oss-20b",
6060
"agent_core_arn": "arn:aws:bedrock-agentcore:us-west-2:729646638167:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
6161
"dataset": "s3://sagemaker-rft-729646638167/prompts/gsm8k_small/prompts.parquet",
6262
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/mtrl-integ/eval-output/",
63-
"mlflow_resource_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-ZG6FYITNGMMU",
63+
"mlflow_resource_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
6464
"model_package_group": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
6565
"role": "arn:aws:iam::729646638167:role/Admin",
6666
},
@@ -187,6 +187,7 @@ def test_evaluate_finetuned_model(self, attached_trainer, config):
187187
f"reason: {execution.status.failure_reason}"
188188
)
189189

190+
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
190191
def test_evaluate_base_model(self, config):
191192
"""Evaluate the base model only — submit and wait for completion."""
192193
evaluator = MultiTurnRLEvaluator(
@@ -247,3 +248,35 @@ def test_evaluate_comparison(self, attached_trainer, config):
247248
f"[{config['env_name']}] Comparison eval failed with status: {status}, "
248249
f"reason: {execution.status.failure_reason}"
249250
)
251+
252+
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
253+
def test_evaluate_with_hyperparam_override(self, attached_trainer, config):
254+
"""Test that hyperparameter overrides are passed through to the eval job."""
255+
evaluator = MultiTurnRLEvaluator(
256+
model=attached_trainer,
257+
dataset=config["dataset"],
258+
s3_output_path=f'{config["s3_output_path"]}hyperparam-override/',
259+
mlflow_resource_arn=config["mlflow_resource_arn"],
260+
role=config["role"],
261+
region=_REGION,
262+
)
263+
264+
# Override MTRL-specific hyperparams
265+
evaluator.hyperparameters.sampling_max_tokens = 1024
266+
evaluator.hyperparameters.eval_group_size = 4
267+
268+
execution = evaluator.evaluate()
269+
270+
assert execution is not None
271+
assert execution.arn is not None
272+
logger.info(f"[{config['env_name']}] Started hyperparam override eval: {execution.arn}")
273+
274+
execution.wait(timeout=EVAL_TIMEOUT)
275+
276+
status = execution.status.overall_status
277+
logger.info(f"[{config['env_name']}] Hyperparam override eval completed: {status}")
278+
279+
assert status == "Succeeded", (
280+
f"[{config['env_name']}] Hyperparam override eval failed with status: {status}, "
281+
f"reason: {execution.status.failure_reason}"
282+
)

sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
AGENT_RUNTIME_ID = "sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS"
3737
ROLE_ARN = f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin"
38-
MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU"
38+
MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6"
3939
S3_INPUT_PATH = f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet"
4040
S3_OUTPUT_PATH = f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/mtrl-trainer-integ/"
4141
LAMBDA_ARN = f"arn:aws:lambda:{_REGION}:{_ACCOUNT_ID}:function:SageMaker-AgentConnector-Lambda-MTRL-integ-test"

0 commit comments

Comments
 (0)