Skip to content

Commit 7444009

Browse files
authored
feat: implement test_custom_scorer_base_model_only integration test (#5902)
Add an integration test for the CustomScorerEvaluator base-model-only evaluation path. The test uses a JumpStart model ID directly (instead of a model package ARN) to exercise the CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY pipeline template. It covers evaluator creation, hyperparameter access, evaluation execution, result verification, and execution retrieval. Also add BASE_MODEL_ONLY_CONFIG, a dedicated test configuration that uses the meta-textgeneration-llama-3-2-1b-instruct model and existing test-account resources (account 729646638167).
1 parent 65d6c04 commit 7444009

1 file changed

Lines changed: 127 additions & 5 deletions

File tree

sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@
5454
"region": "us-west-2",
5555
}
5656

# Base model only evaluation configuration (uses JumpStart model ID directly,
# no model package). Consumed by test_custom_scorer_base_model_only.
BASE_MODEL_ONLY_CONFIG = {
    # JumpStart model ID passed as the evaluator's `model` argument.
    "base_model_id": "meta-textgeneration-llama-3-2-1b-instruct",
    # Hub-content ARN passed as the evaluator's `evaluator` argument.
    "evaluator_arn": "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/eval-lambda-test/0.0.1",
    # JSONL dataset passed as the evaluator's `dataset` argument.
    "dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
    # S3 prefix passed as the evaluator's `s3_output_path` argument.
    "s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
    # NOTE(review): not referenced by the base-model-only test visible here;
    # presumably kept for parity with the other test configs - confirm.
    "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
    # Region used when retrieving the pipeline execution by ARN.
    "region": "us-west-2",
}
66+
5767

5868
# @pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
5969
@pytest.mark.xdist_group("custom_scorer_evaluator")
@@ -288,13 +298,125 @@ def test_custom_scorer_with_builtin_metric(self):
288298
logger.info("Built-in metric evaluation completed successfully")
289299

290300
# @pytest.mark.skip(reason="Base model only evaluation - not working yet per notebook")
@pytest.mark.gpu_intensive
def test_custom_scorer_base_model_only(self):
    """
    Test custom scorer evaluation with base model only (no fine-tuned model).

    This test uses a JumpStart model ID directly instead of a model package ARN,
    which triggers the CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY template path.
    The evaluation runs against only the base model without any fine-tuned weights.

    This test covers:
    1. Creating CustomScorerEvaluator with a JumpStart model ID (base model only)
    2. Accessing hyperparameters
    3. Starting evaluation
    4. Monitoring execution
    5. Waiting for completion
    6. Viewing results
    7. Retrieving execution by ARN

    Raises:
        AssertionError: If any verification step fails.
        Exception: Anything raised while waiting for or displaying results is
            logged together with the execution's failure details and re-raised.
    """
    # Step 1: Create CustomScorerEvaluator with JumpStart model ID
    logger.info("Creating CustomScorerEvaluator with base model only (JumpStart model ID)")

    # NOTE(review): evaluate_base_model is False even though only the base
    # model is evaluated; presumably the flag controls an *additional*
    # base-model run next to a fine-tuned model - confirm against the SDK.
    evaluator = CustomScorerEvaluator(
        evaluator=BASE_MODEL_ONLY_CONFIG["evaluator_arn"],
        dataset=BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"],
        model=BASE_MODEL_ONLY_CONFIG["base_model_id"],
        s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"],
        evaluate_base_model=False,
    )

    # Verify evaluator was created with base model ID
    assert evaluator is not None
    assert evaluator.evaluator == BASE_MODEL_ONLY_CONFIG["evaluator_arn"]
    assert evaluator.model == BASE_MODEL_ONLY_CONFIG["base_model_id"]
    assert evaluator.dataset == BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"]

    logger.info(f"Created evaluator with base model: {BASE_MODEL_ONLY_CONFIG['base_model_id']}")

    # Step 2: Access hyperparameters
    logger.info("Accessing hyperparameters")
    hyperparams = evaluator.hyperparameters.to_dict()

    # Verify hyperparameters structure
    assert isinstance(hyperparams, dict)
    assert "max_new_tokens" in hyperparams
    assert "temperature" in hyperparams

    logger.info(f"Hyperparameters: {hyperparams}")

    # Step 3: Start evaluation
    logger.info("Starting evaluation execution")
    execution = evaluator.evaluate()

    # Verify execution was created
    assert execution is not None
    assert execution.arn is not None
    assert execution.name is not None
    assert execution.eval_type is not None

    logger.info(f"Pipeline Execution ARN: {execution.arn}")
    logger.info(f"Initial Status: {execution.status.overall_status}")

    # Step 4: Monitor execution
    logger.info("Refreshing execution status")
    execution.refresh()

    # Verify status was updated
    assert execution.status.overall_status is not None

    # Log step details if available
    if execution.status.step_details:
        logger.info("Step Details:")
        for step in execution.status.step_details:
            logger.info(f" {step.name}: {step.status}")

    # Step 5: Wait for completion
    logger.info(
        f"Waiting for evaluation to complete (timeout: {EVALUATION_TIMEOUT_SECONDS}s / "
        f"{EVALUATION_TIMEOUT_SECONDS//3600}h)"
    )

    try:
        execution.wait(target_status="Succeeded", poll=30, timeout=EVALUATION_TIMEOUT_SECONDS)
        logger.info(f"Final Status: {execution.status.overall_status}")

        # Verify completion
        assert execution.status.overall_status == "Succeeded"

        # Step 6: View results
        logger.info("Displaying results")
        execution.show_results()

        # Verify S3 output path is set
        assert execution.s3_output_path is not None
        logger.info(f"Results stored at: {execution.s3_output_path}")

    except Exception as e:
        # Broad catch is deliberate: dump as much execution context as
        # possible for post-mortem debugging, then re-raise unchanged.
        logger.error(f"Evaluation failed or timed out: {e}")
        logger.error(f"Final status: {execution.status.overall_status}")
        if execution.status.failure_reason:
            logger.error(f"Failure reason: {execution.status.failure_reason}")

        # Log step failures
        if execution.status.step_details:
            for step in execution.status.step_details:
                if "failed" in step.status.lower():
                    logger.error(f"Failed step: {step.name}")
                    if step.failure_reason:
                        logger.error(f" Reason: {step.failure_reason}")

        # Re-raise to fail the test
        raise

    # Step 7: Retrieve execution by ARN
    logger.info("Retrieving execution by ARN")
    retrieved_execution = EvaluationPipelineExecution.get(
        arn=execution.arn,
        region=BASE_MODEL_ONLY_CONFIG["region"],
    )

    # Verify retrieved execution matches
    assert retrieved_execution.arn == execution.arn
    assert retrieved_execution.status.overall_status == "Succeeded"

    logger.info(f"Retrieved execution status: {retrieved_execution.status.overall_status}")
    logger.info("Base model only evaluation completed successfully")

0 commit comments

Comments
 (0)