Skip to content

Commit b71bc7e

Browse files
rsareddy0329Roja Reddy Sareddy
andauthored
fix:unskip and add mock test setup for mtrl integ tests (#5937)
* fix:unskip and mark MTRL tests as gpu intensive * Add Mock setup for mtrl integ tests * Add Mock setup for mtrl integ tests * Add Mock setup for mtrl integ tests --------- Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent f7bb4d5 commit b71bc7e

5 files changed

Lines changed: 63 additions & 125 deletions

File tree

sagemaker-train/tests/integ/train/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
DEFAULT_REGION = "us-west-2"
2323

2424

25+
@pytest.fixture(autouse=True, scope="session")
26+
def use_private_hub():
27+
os.environ["SAGEMAKER_HUB_NAME"] = "sdktest"
28+
yield
29+
del os.environ["SAGEMAKER_HUB_NAME"]
30+
31+
2532
@pytest.fixture(scope="module")
2633
def sagemaker_session():
2734
region = os.environ.get("AWS_DEFAULT_REGION")

sagemaker-train/tests/integ/train/test_mtrl_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _get_test_config():
4141
boto_session = boto3.Session(region_name=_REGION)
4242
account_id = boto_session.client("sts").get_caller_identity()["Account"]
4343
return {
44-
"base_model": "openai-reasoning-gpt-oss-20b",
44+
"base_model": "mock-oss-test",
4545
"agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{account_id}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
4646
"dataset": f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet",
4747
"s3_output_path": f"s3://sagemaker-{_REGION}-{account_id}/model-evaluation/output-artifacts/",
@@ -142,7 +142,7 @@ def mtrl_trainer(sagemaker_session_mtrl, test_config):
142142

143143
trainer = object.__new__(MultiTurnRLTrainer)
144144
trainer._model_name = test_config["base_model"]
145-
trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{test_config['base_model']}/1.0.0"
145+
trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:{test_config['account_id']}:hub-content/sdktest/Model/{test_config['base_model']}/0.0.1"
146146
trainer.agent_env = test_config["agent_arn"]
147147
trainer.bedrock_agentcore_qualifier = "DEFAULT"
148148
trainer.output_model_package_group = test_config["model_package_group"]

sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py

Lines changed: 35 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _get_3p_test_config():
148148
boto_session = boto3.Session(region_name=_REGION)
149149
account_id = boto_session.client("sts").get_caller_identity()["Account"]
150150
return {
151-
"base_model": "openai-reasoning-gpt-oss-20b",
151+
"base_model": "mock-oss-test",
152152
"dataset": os.environ.get(
153153
"MTRL_3P_DATASET",
154154
f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet",
@@ -240,21 +240,23 @@ def lambda_agent_arn(test_config):
240240
return _ensure_lambda_exists(test_config["account_id"])
241241

242242

243+
@pytest.mark.gpu_intensive
244+
@pytest.mark.serial
243245
class TestMTRLEvaluator3PAgentIntegration:
244246
"""Integration tests for MultiTurnRLEvaluator with Lambda-based 3P agent."""
245247

246-
def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn, test_config):
247-
"""Test evaluating a base model using a Lambda ARN as agent_config.
248+
def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn, test_config):
249+
"""Test full end-to-end: start evaluation, wait for completion, and verify discoverability.
248250
249-
This is the primary 3P integration pattern: customer provides a
250-
Lambda function that wraps their agent (LangChain, Strands, etc.)
251-
and the evaluator runs rollouts against it.
251+
This test validates the complete lifecycle including wait() using
252+
the standard sagemaker-core pipeline execution path, and verifies
253+
the evaluation is discoverable via get_all().
252254
"""
253255
evaluator = MultiTurnRLEvaluator(
254256
model=test_config["base_model"],
255257
dataset=test_config["dataset"],
256258
agent_config=lambda_agent_arn,
257-
s3_output_path=f'{test_config["s3_output_path"]}lambda-base-model/',
259+
s3_output_path=f'{test_config["s3_output_path"]}lambda-e2e/',
258260
mlflow_resource_arn=test_config["mlflow_resource_arn"],
259261
role=test_config["role"],
260262
region=test_config["region"],
@@ -267,9 +269,27 @@ def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn, test_conf
267269
assert execution.arn is not None
268270
assert "pipeline" in execution.arn.lower()
269271
logger.info(f"Started 3P agent base model evaluation: {execution.arn}")
270-
logger.info(f"Status: {execution.status.overall_status}")
271272

272-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
273+
execution.wait(timeout=EVALUATION_TIMEOUT_SECONDS)
274+
assert execution.status.overall_status in ("Succeeded", "Failed", "Stopped")
275+
logger.info(f"Execution completed: {execution.status.overall_status}")
276+
277+
if execution.status.overall_status == "Failed":
278+
logger.error(f"Failure reason: {execution.status.failure_reason}")
279+
280+
# Verify it's discoverable via get_all
281+
found = False
282+
for ex in MultiTurnRLEvaluator.get_all(region=test_config["region"]):
283+
if ex.arn == execution.arn:
284+
found = True
285+
break
286+
287+
assert found, (
288+
f"Evaluation {execution.arn} not found via get_all(). "
289+
"Pipeline tagging may not be working correctly."
290+
)
291+
logger.info(f"Successfully discovered evaluation via get_all: {execution.arn}")
292+
273293
def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn, test_config):
274294
"""Test evaluating using an CustomAgentLambda object as agent_config.
275295
@@ -295,83 +315,15 @@ def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn, te
295315
assert execution.arn is not None
296316
logger.info(f"Started CustomAgentLambda object evaluation: {execution.arn}")
297317

298-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
299-
def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn, test_config):
300-
"""Test full end-to-end: start evaluation and wait for completion.
301-
302-
This test validates the complete lifecycle including wait() using
303-
the standard sagemaker-core pipeline execution path.
304-
"""
305-
evaluator = MultiTurnRLEvaluator(
306-
model=test_config["base_model"],
307-
dataset=test_config["dataset"],
308-
agent_config=lambda_agent_arn,
309-
s3_output_path=f'{test_config["s3_output_path"]}lambda-e2e/',
310-
mlflow_resource_arn=test_config["mlflow_resource_arn"],
311-
role=test_config["role"],
312-
region=test_config["region"],
313-
accept_eula=True,
314-
)
315-
316-
execution = evaluator.evaluate()
317-
assert execution is not None
318-
319-
logger.info(f"Waiting for execution: {execution.arn}")
320-
execution.wait()
321-
322-
assert execution.status.overall_status in ("Succeeded", "Failed", "Stopped")
323-
logger.info(f"Execution completed: {execution.status.overall_status}")
324-
325-
if execution.status.overall_status == "Failed":
326-
logger.error(f"Failure reason: {execution.status.failure_reason}")
327-
328-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
329-
def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn, test_config):
330-
"""Test that 3P agent evaluations are discoverable via get_all.
331-
332-
Validates that evaluations started with Lambda agents show up in
333-
the standard get_all() discovery path (pipeline tagging works).
334-
"""
335-
evaluator = MultiTurnRLEvaluator(
336-
model=test_config["base_model"],
337-
dataset=test_config["dataset"],
338-
agent_config=lambda_agent_arn,
339-
s3_output_path=f'{test_config["s3_output_path"]}lambda-discovery/',
340-
mlflow_resource_arn=test_config["mlflow_resource_arn"],
341-
role=test_config["role"],
342-
region=test_config["region"],
343-
accept_eula=True,
344-
)
345-
346-
execution = evaluator.evaluate()
347-
assert execution is not None
348-
started_arn = execution.arn
349-
350-
# Give pipeline time to register
351-
time.sleep(10)
352-
353-
# Verify it's discoverable via get_all
354-
found = False
355-
for ex in MultiTurnRLEvaluator.get_all(region=test_config["region"]):
356-
if ex.arn == started_arn:
357-
found = True
358-
break
359-
360-
assert found, (
361-
f"Evaluation {started_arn} not found via get_all(). "
362-
"Pipeline tagging may not be working correctly."
363-
)
364-
logger.info(f"Successfully discovered evaluation via get_all: {started_arn}")
365-
318+
execution.wait(timeout=EVALUATION_TIMEOUT_SECONDS)
319+
assert execution.status.overall_status == "Succeeded"
366320

367-
368-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
369321
def test_evaluate_with_attached_trainer(self, lambda_agent_arn, test_config):
370322
"""Test evaluating a fine-tuned model by attaching to an existing training job."""
371323
from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer
372324

373325
attached_job = MultiTurnRLTrainer.attach(
374-
"openai-reasoning-gpt-oss-20b-mtrl-20260602164546", session=boto3.Session(region_name=_REGION)
326+
"mock-oss-test-mtrl-20260615143910", session=boto3.Session(region_name=_REGION)
375327
)
376328

377329
evaluator = MultiTurnRLEvaluator(
@@ -390,3 +342,6 @@ def test_evaluate_with_attached_trainer(self, lambda_agent_arn, test_config):
390342
assert execution is not None
391343
assert execution.arn is not None
392344
logger.info(f"Started attached trainer evaluation: {execution.arn}")
345+
346+
execution.wait(timeout=EVALUATION_TIMEOUT_SECONDS)
347+
assert execution.status.overall_status == "Succeeded"

sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ def _get_account_id():
5252
# PROD — Main account (729646638167)
5353
"729646638167": {
5454
"env_name": "PROD",
55-
"existing_job_name": "openai-reasoning-gpt-oss-20b-mtrl-20260602215955",
56-
"base_model": "openai-reasoning-gpt-oss-20b",
55+
#"existing_job_name": "mock-oss-test-mtrl-20260611170946",
56+
"existing_job_name": "mock-oss-test-mtrl-20260615143910",
57+
"base_model": "mock-oss-test",
5758
"agent_core_arn": "arn:aws:bedrock-agentcore:us-west-2:729646638167:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
5859
"dataset": "s3://sagemaker-rft-729646638167/prompts/gsm8k_small/prompts.parquet",
5960
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/mtrl-integ/eval-output/",
@@ -65,7 +66,7 @@ def _get_account_id():
6566
"391266019386": {
6667
"env_name": "PREPROD",
6768
"existing_job_name": "mtrl-integ-gpt-oss-agentcore-1779143704358",
68-
"base_model": "openai-reasoning-gpt-oss-20b",
69+
"base_model": "mock-oss-test",
6970
"agent_core_arn": "arn:aws:bedrock-agentcore:us-west-2:391266019386:runtime/mtrl_integ_gsm8k_streaming-bIz4H5Echk",
7071
"dataset": "s3://sagemaker-rft-beta-391266019386/prompts/gsm8k_small/prompts.parquet",
7172
"s3_output_path": "s3://sagemaker-us-west-2-391266019386/mtrl-integ/eval-output/",
@@ -77,7 +78,7 @@ def _get_account_id():
7778
"742774200982": {
7879
"env_name": "BETA",
7980
"existing_job_name": "openai-reasoning-gpt-oss-20b-mtrl-20260601114439",
80-
"base_model": "openai-reasoning-gpt-oss-20b",
81+
"base_model": "mock-oss-test",
8182
"agent_core_arn": "arn:aws:bedrock-agentcore:us-west-2:742774200982:runtime/sagemaker_rft_prod_gsm8k_streaming-UwSB6LEfEq",
8283
"dataset": "s3://sagemaker-rft-beta-742774200982/prompts/gsm8k_small/prompts.parquet",
8384
"s3_output_path": "s3://sagemaker-us-west-2-742774200982/mtrl-integ/eval-output/",
@@ -139,6 +140,8 @@ def attached_trainer(config):
139140
return trainer
140141

141142

143+
@pytest.mark.gpu_intensive
144+
@pytest.mark.serial
142145
class TestMTRLEvalIntegration:
143146
"""Integration tests for MTRL evaluation: attach → evaluate → wait for success."""
144147

@@ -184,7 +187,6 @@ def test_evaluate_finetuned_model(self, attached_trainer, config):
184187
f"reason: {execution.status.failure_reason}"
185188
)
186189

187-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
188190
def test_evaluate_base_model(self, config):
189191
"""Evaluate the base model only — submit and wait for completion."""
190192
evaluator = MultiTurnRLEvaluator(
@@ -245,35 +247,3 @@ def test_evaluate_comparison(self, attached_trainer, config):
245247
f"[{config['env_name']}] Comparison eval failed with status: {status}, "
246248
f"reason: {execution.status.failure_reason}"
247249
)
248-
249-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
250-
def test_evaluate_with_hyperparam_override(self, attached_trainer, config):
251-
"""Test that hyperparameter overrides are passed through to the eval job."""
252-
evaluator = MultiTurnRLEvaluator(
253-
model=attached_trainer,
254-
dataset=config["dataset"],
255-
s3_output_path=f'{config["s3_output_path"]}hyperparam-override/',
256-
mlflow_resource_arn=config["mlflow_resource_arn"],
257-
role=config["role"],
258-
region=_REGION,
259-
)
260-
261-
# Override MTRL-specific hyperparams
262-
evaluator.hyperparameters.sampling_max_tokens = 1024
263-
evaluator.hyperparameters.eval_group_size = 4
264-
265-
execution = evaluator.evaluate()
266-
267-
assert execution is not None
268-
assert execution.arn is not None
269-
logger.info(f"[{config['env_name']}] Started hyperparam override eval: {execution.arn}")
270-
271-
execution.wait(timeout=EVAL_TIMEOUT)
272-
273-
status = execution.status.overall_status
274-
logger.info(f"[{config['env_name']}] Hyperparam override eval completed: {status}")
275-
276-
assert status == "Succeeded", (
277-
f"[{config['env_name']}] Hyperparam override eval failed with status: {status}, "
278-
f"reason: {execution.status.failure_reason}"
279-
)

sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ def _get_account_id():
3939
return _ACCOUNT_ID
4040

4141
AGENT_RUNTIME_ID = "sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS"
42-
BASE_MODEL = "openai-reasoning-gpt-oss-20b"
43-
EXISTING_JOB_NAME = "openai-reasoning-gpt-oss-20b-mtrl-20260602005937"
42+
#BASE_MODEL = "openai-reasoning-gpt-oss-20b"
43+
BASE_MODEL = "mock-oss-test"
44+
EXISTING_JOB_NAME = "openai-reasoning-gpt-oss-20b-mtrl-20260602215955"
4445

4546

4647
@pytest.fixture(scope="module")
@@ -63,7 +64,8 @@ def test_resources():
6364
}
6465

6566

66-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
67+
@pytest.mark.gpu_intensive
68+
@pytest.mark.serial
6769
class TestMultiTurnRLTrainerBedrockAgent:
6870
"""Test MTRL training with Bedrock AgentCore runtime."""
6971

@@ -116,7 +118,8 @@ def test_train_and_stop(self, sagemaker_session, test_resources):
116118
assert job.job_status in ("Stopping", "Stopped")
117119

118120

119-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
121+
@pytest.mark.gpu_intensive
122+
@pytest.mark.serial
120123
class TestMultiTurnRLTrainerLambdaAgent:
121124
"""Test MTRL training with Lambda agent."""
122125

@@ -145,7 +148,7 @@ def test_train_with_lambda_arn(self, sagemaker_session, test_resources):
145148
assert job.output_model_package_arn is not None
146149

147150

148-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
151+
149152
class TestMultiTurnRLTrainerAttach:
150153
"""Test attaching to existing MTRL jobs."""
151154

@@ -162,6 +165,7 @@ def test_attach_and_get_properties(self, sagemaker_session):
162165
assert attached_job.output_model_package_arn is not None
163166
assert attached_job.s3_output_path is not None
164167

168+
@pytest.mark.skip(reason="GPU resource intensive — run manually")
165169
def test_get_all_jobs(self, sagemaker_session):
166170
"""Test listing all MTRL jobs."""
167171
jobs = list(AgentRFTJob.get_all(
@@ -172,7 +176,6 @@ def test_get_all_jobs(self, sagemaker_session):
172176
assert all(j.job_status == "Completed" for j in jobs)
173177

174178

175-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
176179
class TestMultiTurnRLTrainerListModels:
177180
"""Test listing supported models (requires API access)."""
178181

@@ -190,3 +193,6 @@ def test_list_bedrock_agentcore_runtimes(self, sagemaker_session):
190193
session=sagemaker_session.boto_session
191194
)
192195
assert isinstance(runtimes, list)
196+
197+
198+

0 commit comments

Comments
 (0)