Skip to content

Commit daf19b0

Browse files
Fix: Add mlflowconfig to eval base model (#5745)
* Add mlflowconfig to eval * Update unit and integ test
1 parent 6134e57 commit daf19b0

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@
326326
LLMAJ_TEMPLATE_BASE_MODEL_ONLY = """{
327327
"Version": "2020-12-01",
328328
"Metadata": {},
329+
"MlflowConfig": {
330+
"MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %},
331+
"MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}
332+
},
329333
"Parameters": [],
330334
"Steps": [
331335
{
@@ -457,6 +461,10 @@
457461
DETERMINISTIC_TEMPLATE_BASE_MODEL_ONLY = """{
458462
"Version": "2020-12-01",
459463
"Metadata": {},
464+
"MlflowConfig": {
465+
"MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %},
466+
"MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}
467+
},
460468
"Parameters": [],
461469
"Steps": [
462470
{
@@ -843,6 +851,10 @@
843851
CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY = """{
844852
"Version": "2020-12-01",
845853
"Metadata": {},
854+
"MlflowConfig": {
855+
"MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %},
856+
"MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}
857+
},
846858
"Parameters": [],
847859
"Steps": [
848860
{

sagemaker-train/tests/integ/train/test_benchmark_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def test_benchmark_evaluation_base_model_only(self):
307307
benchmark=Benchmark.MMLU,
308308
model=BASE_MODEL_ONLY_CONFIG["base_model_id"],
309309
s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"],
310-
# mlflow_resource_arn=BASE_MODEL_ONLY_CONFIG["mlflow_tracking_server_arn"],
310+
mlflow_resource_arn=BASE_MODEL_ONLY_CONFIG["mlflow_tracking_server_arn"],
311311
base_eval_name="integ-test-base-model-only",
312312
# Note: model_package_group not needed for JumpStart models
313313
)

sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,9 @@ def test_deterministic_base_model_only_with_all_params(self):
317317

318318
base_model_step = pipeline_def["Steps"][0]
319319

320-
# Verify MLflow config is not present in BASE_MODEL_ONLY template
321-
assert "MlflowConfig" not in pipeline_def
320+
# Verify MLflow config is present in BASE_MODEL_ONLY template
321+
assert "MlflowConfig" in pipeline_def
322+
assert pipeline_def["MlflowConfig"]["MlflowResourceArn"] == BASE_CONTEXT["mlflow_resource_arn"]
322323

323324
# Verify KMS key
324325
assert base_model_step["Arguments"]["OutputDataConfig"]["KmsKeyId"] == context["kms_key_id"]
@@ -403,8 +404,9 @@ def test_custom_scorer_base_model_only_minimal(self):
403404

404405
pipeline_def = json.loads(rendered)
405406

406-
# Verify MLflow config is not present in BASE_MODEL_ONLY template
407-
assert "MlflowConfig" not in pipeline_def
407+
# Verify MLflow config is present in BASE_MODEL_ONLY template
408+
assert "MlflowConfig" in pipeline_def
409+
assert pipeline_def["MlflowConfig"]["MlflowResourceArn"] == BASE_CONTEXT["mlflow_resource_arn"]
408410

409411
# Should have only 1 step
410412
assert len(pipeline_def["Steps"]) == 1
@@ -574,8 +576,9 @@ def test_llmaj_base_model_only_minimal(self):
574576

575577
pipeline_def = json.loads(rendered)
576578

577-
# Verify MLflow config is not present in BASE_MODEL_ONLY template
578-
assert "MlflowConfig" not in pipeline_def
579+
# Verify MLflow config is present in BASE_MODEL_ONLY template
580+
assert "MlflowConfig" in pipeline_def
581+
assert pipeline_def["MlflowConfig"]["MlflowResourceArn"] == BASE_CONTEXT["mlflow_resource_arn"]
579582

580583
# Should have 2 steps: EvaluateBaseInferenceModel and EvaluateBaseModelMetrics
581584
assert len(pipeline_def["Steps"]) == 2

0 commit comments

Comments
 (0)