Skip to content

Commit 65d1b41

Browse files
committed
Update unit and integ test
1 parent a85f62e commit 65d1b41

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

sagemaker-train/tests/integ/train/test_benchmark_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def test_benchmark_evaluation_base_model_only(self):
307307
benchmark=Benchmark.MMLU,
308308
model=BASE_MODEL_ONLY_CONFIG["base_model_id"],
309309
s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"],
310-
# mlflow_resource_arn=BASE_MODEL_ONLY_CONFIG["mlflow_tracking_server_arn"],
310+
mlflow_resource_arn=BASE_MODEL_ONLY_CONFIG["mlflow_tracking_server_arn"],
311311
base_eval_name="integ-test-base-model-only",
312312
# Note: model_package_group not needed for JumpStart models
313313
)

sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,9 @@ def test_deterministic_base_model_only_with_all_params(self):
317317

318318
base_model_step = pipeline_def["Steps"][0]
319319

320-
# Verify MLflow config is not present in BASE_MODEL_ONLY template
321-
assert "MlflowConfig" not in pipeline_def
320+
# Verify MLflow config is present in BASE_MODEL_ONLY template
321+
assert "MlflowConfig" in pipeline_def
322+
assert pipeline_def["MlflowConfig"]["MlflowResourceArn"] == BASE_CONTEXT["mlflow_resource_arn"]
322323

323324
# Verify KMS key
324325
assert base_model_step["Arguments"]["OutputDataConfig"]["KmsKeyId"] == context["kms_key_id"]
@@ -403,8 +404,9 @@ def test_custom_scorer_base_model_only_minimal(self):
403404

404405
pipeline_def = json.loads(rendered)
405406

406-
# Verify MLflow config is not present in BASE_MODEL_ONLY template
407-
assert "MlflowConfig" not in pipeline_def
407+
# Verify MLflow config is present in BASE_MODEL_ONLY template
408+
assert "MlflowConfig" in pipeline_def
409+
assert pipeline_def["MlflowConfig"]["MlflowResourceArn"] == BASE_CONTEXT["mlflow_resource_arn"]
408410

409411
# Should have only 1 step
410412
assert len(pipeline_def["Steps"]) == 1
@@ -574,8 +576,9 @@ def test_llmaj_base_model_only_minimal(self):
574576

575577
pipeline_def = json.loads(rendered)
576578

577-
# Verify MLflow config is not present in BASE_MODEL_ONLY template
578-
assert "MlflowConfig" not in pipeline_def
579+
# Verify MLflow config is present in BASE_MODEL_ONLY template
580+
assert "MlflowConfig" in pipeline_def
581+
assert pipeline_def["MlflowConfig"]["MlflowResourceArn"] == BASE_CONTEXT["mlflow_resource_arn"]
579582

580583
# Should have 2 steps: EvaluateBaseInferenceModel and EvaluateBaseModelMetrics
581584
assert len(pipeline_def["Steps"]) == 2

0 commit comments

Comments
 (0)