|
54 | 54 | "region": "us-west-2", |
55 | 55 | } |
56 | 56 |
|
| 57 | +# Base-model-only evaluation configuration (uses a JumpStart model ID directly; no model package ARN) |
| 58 | +BASE_MODEL_ONLY_CONFIG = { |
| 59 | + "base_model_id": "meta-textgeneration-llama-3-2-1b-instruct", |
| 60 | + "evaluator_arn": "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/eval-lambda-test/0.0.1", |
| 61 | + "dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl", |
| 62 | + "s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/", |
| 63 | + "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX", |
| 64 | + "region": "us-west-2", |
| 65 | +} |
| 66 | + |
57 | 67 |
|
58 | 68 | # @pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/") |
59 | 69 | @pytest.mark.xdist_group("custom_scorer_evaluator") |
@@ -288,13 +298,125 @@ def test_custom_scorer_with_builtin_metric(self): |
288 | 298 | logger.info("Built-in metric evaluation completed successfully") |
289 | 299 |
|
290 | 300 | # @pytest.mark.skip(reason="Base model only evaluation - not working yet per notebook") |
| 301 | + @pytest.mark.gpu_intensive |
291 | 302 | def test_custom_scorer_base_model_only(self): |
292 | 303 | """ |
293 | 304 | Test custom scorer evaluation with base model only (no fine-tuned model). |
294 | 305 | |
295 | | - Note: Per the notebook, "Evaluation with Base Model Only is yet to be |
296 | | - implemented/tested - Not Working currently". This test is skipped until |
297 | | - that functionality is available. |
| 306 | + This test uses a JumpStart model ID directly instead of a model package ARN, |
| 307 | + which triggers the CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY template path. |
| 308 | + The evaluation runs against only the base model without any fine-tuned weights. |
| 309 | + |
| 310 | + This test covers: |
| 311 | + 1. Creating CustomScorerEvaluator with a JumpStart model ID (base model only) |
| 312 | + 2. Accessing hyperparameters |
| 313 | + 3. Starting evaluation |
| 314 | + 4. Monitoring execution |
| 315 | + 5. Waiting for completion |
| 316 | + 6. Viewing results |
| 317 | + 7. Retrieving execution by ARN |
298 | 318 | """ |
299 | | - logger.info("Base model only evaluation - not yet implemented") |
300 | | - pass |
| 319 | + # Step 1: Create CustomScorerEvaluator with JumpStart model ID |
| 320 | + logger.info("Creating CustomScorerEvaluator with base model only (JumpStart model ID)") |
| 321 | + |
| 322 | + evaluator = CustomScorerEvaluator( |
| 323 | + evaluator=BASE_MODEL_ONLY_CONFIG["evaluator_arn"], |
| 324 | + dataset=BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"], |
| 325 | + model=BASE_MODEL_ONLY_CONFIG["base_model_id"], |
| 326 | + s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"], |
| 327 | + evaluate_base_model=False, |
| 328 | + ) |
| 329 | + |
| 330 | + # Verify the evaluator stores the configured evaluator ARN, base model ID, and dataset |
| 331 | + assert evaluator is not None |
| 332 | + assert evaluator.evaluator == BASE_MODEL_ONLY_CONFIG["evaluator_arn"] |
| 333 | + assert evaluator.model == BASE_MODEL_ONLY_CONFIG["base_model_id"] |
| 334 | + assert evaluator.dataset == BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"] |
| 335 | + |
| 336 | + logger.info(f"Created evaluator with base model: {BASE_MODEL_ONLY_CONFIG['base_model_id']}") |
| 337 | + |
| 338 | + # Step 2: Access hyperparameters |
| 339 | + logger.info("Accessing hyperparameters") |
| 340 | + hyperparams = evaluator.hyperparameters.to_dict() |
| 341 | + |
| 342 | + # Verify hyperparameters structure |
| 343 | + assert isinstance(hyperparams, dict) |
| 344 | + assert "max_new_tokens" in hyperparams |
| 345 | + assert "temperature" in hyperparams |
| 346 | + |
| 347 | + logger.info(f"Hyperparameters: {hyperparams}") |
| 348 | + |
| 349 | + # Step 3: Start evaluation |
| 350 | + logger.info("Starting evaluation execution") |
| 351 | + execution = evaluator.evaluate() |
| 352 | + |
| 353 | + # Verify execution was created |
| 354 | + assert execution is not None |
| 355 | + assert execution.arn is not None |
| 356 | + assert execution.name is not None |
| 357 | + assert execution.eval_type is not None |
| 358 | + |
| 359 | + logger.info(f"Pipeline Execution ARN: {execution.arn}") |
| 360 | + logger.info(f"Initial Status: {execution.status.overall_status}") |
| 361 | + |
| 362 | + # Step 4: Monitor execution |
| 363 | + logger.info("Refreshing execution status") |
| 364 | + execution.refresh() |
| 365 | + |
| 366 | + # Verify a status is still reported after the refresh |
| 367 | + assert execution.status.overall_status is not None |
| 368 | + |
| 369 | + # Log step details if available |
| 370 | + if execution.status.step_details: |
| 371 | + logger.info("Step Details:") |
| 372 | + for step in execution.status.step_details: |
| 373 | + logger.info(f" {step.name}: {step.status}") |
| 374 | + |
| 375 | + # Step 5: Wait for completion |
| 376 | + logger.info(f"Waiting for evaluation to complete (timeout: {EVALUATION_TIMEOUT_SECONDS}s / {EVALUATION_TIMEOUT_SECONDS//3600}h)") |
| 377 | + |
| 378 | + try: |
| 379 | + execution.wait(target_status="Succeeded", poll=30, timeout=EVALUATION_TIMEOUT_SECONDS) |
| 380 | + logger.info(f"Final Status: {execution.status.overall_status}") |
| 381 | + |
| 382 | + # Verify completion |
| 383 | + assert execution.status.overall_status == "Succeeded" |
| 384 | + |
| 385 | + # Step 6: View results |
| 386 | + logger.info("Displaying results") |
| 387 | + execution.show_results() |
| 388 | + |
| 389 | + # Verify S3 output path is set |
| 390 | + assert execution.s3_output_path is not None |
| 391 | + logger.info(f"Results stored at: {execution.s3_output_path}") |
| 392 | + |
| 393 | + except Exception as e: |
| 394 | + logger.error(f"Evaluation failed or timed out: {e}") |
| 395 | + logger.error(f"Final status: {execution.status.overall_status}") |
| 396 | + if execution.status.failure_reason: |
| 397 | + logger.error(f"Failure reason: {execution.status.failure_reason}") |
| 398 | + |
| 399 | + # Log step failures |
| 400 | + if execution.status.step_details: |
| 401 | + for step in execution.status.step_details: |
| 402 | + if "failed" in step.status.lower(): |
| 403 | + logger.error(f"Failed step: {step.name}") |
| 404 | + if step.failure_reason: |
| 405 | + logger.error(f" Reason: {step.failure_reason}") |
| 406 | + |
| 407 | + # Re-raise to fail the test |
| 408 | + raise |
| 409 | + |
| 410 | + # Step 7: Retrieve execution by ARN |
| 411 | + logger.info("Retrieving execution by ARN") |
| 412 | + retrieved_execution = EvaluationPipelineExecution.get( |
| 413 | + arn=execution.arn, |
| 414 | + region=BASE_MODEL_ONLY_CONFIG["region"] |
| 415 | + ) |
| 416 | + |
| 417 | + # Verify retrieved execution matches |
| 418 | + assert retrieved_execution.arn == execution.arn |
| 419 | + assert retrieved_execution.status.overall_status == "Succeeded" |
| 420 | + |
| 421 | + logger.info(f"Retrieved execution status: {retrieved_execution.status.overall_status}") |
| 422 | + logger.info("Base model only evaluation completed successfully") |
0 commit comments