Skip to content

Commit 2b6fc64

Browse files
committed
Further fix for the same LLM as judge integ test failure
1 parent 67f3163 commit 2b6fc64

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,18 @@ def _get_base_template_context(
701701
Returns:
702702
dict: Base template context dictionary
703703
"""
704+
# Generate default mlflow_experiment_name if not provided
705+
# This is required by AWS when ModelPackageGroupArn is not provided in training jobs
706+
mlflow_experiment_name = self.mlflow_experiment_name
707+
if not mlflow_experiment_name and self.mlflow_resource_arn:
708+
# Use pipeline_name as default experiment name
709+
mlflow_experiment_name = '{{ pipeline_name }}'
710+
_logger.info("No mlflow_experiment_name provided, using pipeline_name as default")
711+
704712
return {
705713
'role_arn': role_arn,
706714
'mlflow_resource_arn': self.mlflow_resource_arn,
707-
'mlflow_experiment_name': self.mlflow_experiment_name,
715+
'mlflow_experiment_name': mlflow_experiment_name,
708716
'mlflow_run_name': self.mlflow_run_name,
709717
'model_package_group_arn': model_package_group_arn,
710718
'source_model_package_arn': self._source_model_package_arn,

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@
337337
"Name": "EvaluateBaseInferenceModel",
338338
"Type": "Training",
339339
"Arguments": {
340-
"TrainingJobName": "BaseInference",
340+
"TrainingJobName": "BaseInference",{% if mlflow_experiment_name %}
341+
"MlflowExperimentName": "{{ mlflow_experiment_name }}",{% endif %}
341342
"RoleArn": "{{ role_arn }}",
342343
"ServerlessJobConfig": {
343344
"BaseModelArn": "{{ base_model_arn }}",
@@ -1007,7 +1008,8 @@
10071008
"Name": "EvaluateBaseInferenceModel",
10081009
"Type": "Training",
10091010
"Arguments": {
1010-
"TrainingJobName": "BaseInference",
1011+
"TrainingJobName": "BaseInference",{% if mlflow_experiment_name %}
1012+
"MlflowExperimentName": "{{ mlflow_experiment_name }}",{% endif %}
10111013
"RoleArn": "{{ role_arn }}",
10121014
"ServerlessJobConfig": {
10131015
"BaseModelArn": "{{ base_model_arn }}",

sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
112112
builtin_metrics=TEST_CONFIG["builtin_metrics"],
113113
custom_metrics=TEST_CONFIG["custom_metrics_json"],
114114
s3_output_path=TEST_CONFIG["s3_output_path"],
115+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
115116
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
116117
)
117118

@@ -247,6 +248,7 @@ def test_base_model_false_still_works(self):
247248
dataset=TEST_CONFIG["dataset_s3_uri"],
248249
builtin_metrics=TEST_CONFIG["builtin_metrics"],
249250
s3_output_path=TEST_CONFIG["s3_output_path"],
251+
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
250252
evaluate_base_model=False, # Only evaluate custom model
251253
)
252254

0 commit comments

Comments
 (0)