Skip to content

Commit 0832239

Browse files
author
Roja Reddy Sareddy
committed
fix: unskip and mark MTRL tests as GPU-intensive
1 parent 63ac789 commit 0832239

3 files changed

Lines changed: 45 additions & 118 deletions

File tree

sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py

Lines changed: 28 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -240,21 +240,23 @@ def lambda_agent_arn(test_config):
240240
return _ensure_lambda_exists(test_config["account_id"])
241241

242242

243+
@pytest.mark.gpu_intensive
244+
@pytest.mark.serial
243245
class TestMTRLEvaluator3PAgentIntegration:
244246
"""Integration tests for MultiTurnRLEvaluator with Lambda-based 3P agent."""
245247

246-
def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn, test_config):
247-
"""Test evaluating a base model using a Lambda ARN as agent_config.
248+
def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn, test_config):
249+
"""Test full end-to-end: start evaluation, wait for completion, and verify discoverability.
248250
249-
This is the primary 3P integration pattern: customer provides a
250-
Lambda function that wraps their agent (LangChain, Strands, etc.)
251-
and the evaluator runs rollouts against it.
251+
This test validates the complete lifecycle including wait() using
252+
the standard sagemaker-core pipeline execution path, and verifies
253+
the evaluation is discoverable via get_all().
252254
"""
253255
evaluator = MultiTurnRLEvaluator(
254256
model=test_config["base_model"],
255257
dataset=test_config["dataset"],
256258
agent_config=lambda_agent_arn,
257-
s3_output_path=f'{test_config["s3_output_path"]}lambda-base-model/',
259+
s3_output_path=f'{test_config["s3_output_path"]}lambda-e2e/',
258260
mlflow_resource_arn=test_config["mlflow_resource_arn"],
259261
role=test_config["role"],
260262
region=test_config["region"],
@@ -267,9 +269,27 @@ def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn, test_conf
267269
assert execution.arn is not None
268270
assert "pipeline" in execution.arn.lower()
269271
logger.info(f"Started 3P agent base model evaluation: {execution.arn}")
270-
logger.info(f"Status: {execution.status.overall_status}")
271272

272-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
273+
execution.wait()
274+
assert execution.status.overall_status in ("Succeeded", "Failed", "Stopped")
275+
logger.info(f"Execution completed: {execution.status.overall_status}")
276+
277+
if execution.status.overall_status == "Failed":
278+
logger.error(f"Failure reason: {execution.status.failure_reason}")
279+
280+
# Verify it's discoverable via get_all
281+
found = False
282+
for ex in MultiTurnRLEvaluator.get_all(region=test_config["region"]):
283+
if ex.arn == execution.arn:
284+
found = True
285+
break
286+
287+
assert found, (
288+
f"Evaluation {execution.arn} not found via get_all(). "
289+
"Pipeline tagging may not be working correctly."
290+
)
291+
logger.info(f"Successfully discovered evaluation via get_all: {execution.arn}")
292+
273293
def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn, test_config):
274294
"""Test evaluating using a CustomAgentLambda object as agent_config.
275295
@@ -295,77 +315,6 @@ def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn, te
295315
assert execution.arn is not None
296316
logger.info(f"Started CustomAgentLambda object evaluation: {execution.arn}")
297317

298-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
299-
def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn, test_config):
300-
"""Test full end-to-end: start evaluation and wait for completion.
301-
302-
This test validates the complete lifecycle including wait() using
303-
the standard sagemaker-core pipeline execution path.
304-
"""
305-
evaluator = MultiTurnRLEvaluator(
306-
model=test_config["base_model"],
307-
dataset=test_config["dataset"],
308-
agent_config=lambda_agent_arn,
309-
s3_output_path=f'{test_config["s3_output_path"]}lambda-e2e/',
310-
mlflow_resource_arn=test_config["mlflow_resource_arn"],
311-
role=test_config["role"],
312-
region=test_config["region"],
313-
accept_eula=True,
314-
)
315-
316-
execution = evaluator.evaluate()
317-
assert execution is not None
318-
319-
logger.info(f"Waiting for execution: {execution.arn}")
320-
execution.wait()
321-
322-
assert execution.status.overall_status in ("Succeeded", "Failed", "Stopped")
323-
logger.info(f"Execution completed: {execution.status.overall_status}")
324-
325-
if execution.status.overall_status == "Failed":
326-
logger.error(f"Failure reason: {execution.status.failure_reason}")
327-
328-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
329-
def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn, test_config):
330-
"""Test that 3P agent evaluations are discoverable via get_all.
331-
332-
Validates that evaluations started with Lambda agents show up in
333-
the standard get_all() discovery path (pipeline tagging works).
334-
"""
335-
evaluator = MultiTurnRLEvaluator(
336-
model=test_config["base_model"],
337-
dataset=test_config["dataset"],
338-
agent_config=lambda_agent_arn,
339-
s3_output_path=f'{test_config["s3_output_path"]}lambda-discovery/',
340-
mlflow_resource_arn=test_config["mlflow_resource_arn"],
341-
role=test_config["role"],
342-
region=test_config["region"],
343-
accept_eula=True,
344-
)
345-
346-
execution = evaluator.evaluate()
347-
assert execution is not None
348-
started_arn = execution.arn
349-
350-
# Give pipeline time to register
351-
time.sleep(10)
352-
353-
# Verify it's discoverable via get_all
354-
found = False
355-
for ex in MultiTurnRLEvaluator.get_all(region=test_config["region"]):
356-
if ex.arn == started_arn:
357-
found = True
358-
break
359-
360-
assert found, (
361-
f"Evaluation {started_arn} not found via get_all(). "
362-
"Pipeline tagging may not be working correctly."
363-
)
364-
logger.info(f"Successfully discovered evaluation via get_all: {started_arn}")
365-
366-
367-
368-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
369318
def test_evaluate_with_attached_trainer(self, lambda_agent_arn, test_config):
370319
"""Test evaluating a fine-tuned model by attaching to an existing training job."""
371320
from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer

sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def attached_trainer(config):
139139
return trainer
140140

141141

142+
@pytest.mark.gpu_intensive
143+
@pytest.mark.serial
142144
class TestMTRLEvalIntegration:
143145
"""Integration tests for MTRL evaluation: attach → evaluate → wait for success."""
144146

@@ -156,7 +158,10 @@ def test_attach_to_existing_job(self, config):
156158
logger.info(f"[{config['env_name']}] Output model package: {job.output_model_package_arn}")
157159

158160
def test_evaluate_finetuned_model(self, attached_trainer, config):
159-
"""Evaluate a fine-tuned model from attached trainer — submit and wait for completion."""
161+
"""Evaluate a fine-tuned model from attached trainer — submit and wait for completion.
162+
163+
Also validates hyperparameter overrides are passed through to the eval job.
164+
"""
160165
evaluator = MultiTurnRLEvaluator(
161166
model=attached_trainer,
162167
dataset=config["dataset"],
@@ -166,6 +171,10 @@ def test_evaluate_finetuned_model(self, attached_trainer, config):
166171
region=_REGION,
167172
)
168173

174+
# Override MTRL-specific hyperparams
175+
evaluator.hyperparameters.sampling_max_tokens = 1024
176+
evaluator.hyperparameters.eval_group_size = 4
177+
169178
execution = evaluator.evaluate()
170179

171180
assert execution is not None
@@ -184,7 +193,6 @@ def test_evaluate_finetuned_model(self, attached_trainer, config):
184193
f"reason: {execution.status.failure_reason}"
185194
)
186195

187-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
188196
def test_evaluate_base_model(self, config):
189197
"""Evaluate the base model only — submit and wait for completion."""
190198
evaluator = MultiTurnRLEvaluator(
@@ -245,35 +253,3 @@ def test_evaluate_comparison(self, attached_trainer, config):
245253
f"[{config['env_name']}] Comparison eval failed with status: {status}, "
246254
f"reason: {execution.status.failure_reason}"
247255
)
248-
249-
@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
250-
def test_evaluate_with_hyperparam_override(self, attached_trainer, config):
251-
"""Test that hyperparameter overrides are passed through to the eval job."""
252-
evaluator = MultiTurnRLEvaluator(
253-
model=attached_trainer,
254-
dataset=config["dataset"],
255-
s3_output_path=f'{config["s3_output_path"]}hyperparam-override/',
256-
mlflow_resource_arn=config["mlflow_resource_arn"],
257-
role=config["role"],
258-
region=_REGION,
259-
)
260-
261-
# Override MTRL-specific hyperparams
262-
evaluator.hyperparameters.sampling_max_tokens = 1024
263-
evaluator.hyperparameters.eval_group_size = 4
264-
265-
execution = evaluator.evaluate()
266-
267-
assert execution is not None
268-
assert execution.arn is not None
269-
logger.info(f"[{config['env_name']}] Started hyperparam override eval: {execution.arn}")
270-
271-
execution.wait(timeout=EVAL_TIMEOUT)
272-
273-
status = execution.status.overall_status
274-
logger.info(f"[{config['env_name']}] Hyperparam override eval completed: {status}")
275-
276-
assert status == "Succeeded", (
277-
f"[{config['env_name']}] Hyperparam override eval failed with status: {status}, "
278-
f"reason: {execution.status.failure_reason}"
279-
)

sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _get_account_id():
4040

4141
AGENT_RUNTIME_ID = "sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS"
4242
BASE_MODEL = "openai-reasoning-gpt-oss-20b"
43-
EXISTING_JOB_NAME = "openai-reasoning-gpt-oss-20b-mtrl-20260602005937"
43+
EXISTING_JOB_NAME = "openai-reasoning-gpt-oss-20b-mtrl-20260602215955"
4444

4545

4646
@pytest.fixture(scope="module")
@@ -63,7 +63,8 @@ def test_resources():
6363
}
6464

6565

66-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
66+
@pytest.mark.gpu_intensive
67+
@pytest.mark.serial
6768
class TestMultiTurnRLTrainerBedrockAgent:
6869
"""Test MTRL training with Bedrock AgentCore runtime."""
6970

@@ -116,7 +117,8 @@ def test_train_and_stop(self, sagemaker_session, test_resources):
116117
assert job.job_status in ("Stopping", "Stopped")
117118

118119

119-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
120+
@pytest.mark.gpu_intensive
121+
@pytest.mark.serial
120122
class TestMultiTurnRLTrainerLambdaAgent:
121123
"""Test MTRL training with Lambda agent."""
122124

@@ -145,7 +147,7 @@ def test_train_with_lambda_arn(self, sagemaker_session, test_resources):
145147
assert job.output_model_package_arn is not None
146148

147149

148-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
150+
149151
class TestMultiTurnRLTrainerAttach:
150152
"""Test attaching to existing MTRL jobs."""
151153

@@ -162,6 +164,7 @@ def test_attach_and_get_properties(self, sagemaker_session):
162164
assert attached_job.output_model_package_arn is not None
163165
assert attached_job.s3_output_path is not None
164166

167+
@pytest.mark.skip(reason="GPU resource intensive — run manually")
165168
def test_get_all_jobs(self, sagemaker_session):
166169
"""Test listing all MTRL jobs."""
167170
jobs = list(AgentRFTJob.get_all(
@@ -172,7 +175,6 @@ def test_get_all_jobs(self, sagemaker_session):
172175
assert all(j.job_status == "Completed" for j in jobs)
173176

174177

175-
@pytest.mark.skip(reason="GPU resource intensive — run manually")
176178
class TestMultiTurnRLTrainerListModels:
177179
"""Test listing supported models (requires API access)."""
178180

0 commit comments

Comments
 (0)