File tree Expand file tree Collapse file tree 3 files changed +15
-3
lines changed
src/sagemaker/train/evaluate Expand file tree Collapse file tree 3 files changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -701,10 +701,18 @@ def _get_base_template_context(
701701 Returns:
702702 dict: Base template context dictionary
703703 """
704+ # Generate default mlflow_experiment_name if not provided
705+ # This is required by AWS when ModelPackageGroupArn is not provided in training jobs
706+ mlflow_experiment_name = self .mlflow_experiment_name
707+ if not mlflow_experiment_name and self .mlflow_resource_arn :
708+ # Use pipeline_name as default experiment name
709+ mlflow_experiment_name = '{{ pipeline_name }}'
710+ _logger .info ("No mlflow_experiment_name provided, using pipeline_name as default" )
711+
704712 return {
705713 'role_arn' : role_arn ,
706714 'mlflow_resource_arn' : self .mlflow_resource_arn ,
707- 'mlflow_experiment_name' : self . mlflow_experiment_name ,
715+ 'mlflow_experiment_name' : mlflow_experiment_name ,
708716 'mlflow_run_name' : self .mlflow_run_name ,
709717 'model_package_group_arn' : model_package_group_arn ,
710718 'source_model_package_arn' : self ._source_model_package_arn ,
Original file line number Diff line number Diff line change 337337 "Name": "EvaluateBaseInferenceModel",
338338 "Type": "Training",
339339 "Arguments": {
340- "TrainingJobName": "BaseInference",
340+ "TrainingJobName": "BaseInference",{% if mlflow_experiment_name %}
341+ "MlflowExperimentName": "{{ mlflow_experiment_name }}",{% endif %}
341342 "RoleArn": "{{ role_arn }}",
342343 "ServerlessJobConfig": {
343344 "BaseModelArn": "{{ base_model_arn }}",
10071008 "Name": "EvaluateBaseInferenceModel",
10081009 "Type": "Training",
10091010 "Arguments": {
1010- "TrainingJobName": "BaseInference",
1011+ "TrainingJobName": "BaseInference",{% if mlflow_experiment_name %}
1012+ "MlflowExperimentName": "{{ mlflow_experiment_name }}",{% endif %}
10111013 "RoleArn": "{{ role_arn }}",
10121014 "ServerlessJobConfig": {
10131015 "BaseModelArn": "{{ base_model_arn }}",
Original file line number Diff line number Diff line change @@ -112,6 +112,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
112112 builtin_metrics = TEST_CONFIG ["builtin_metrics" ],
113113 custom_metrics = TEST_CONFIG ["custom_metrics_json" ],
114114 s3_output_path = TEST_CONFIG ["s3_output_path" ],
115+ mlflow_resource_arn = TEST_CONFIG ["mlflow_tracking_server_arn" ],
115116 evaluate_base_model = TEST_CONFIG ["evaluate_base_model" ],
116117 )
117118
@@ -247,6 +248,7 @@ def test_base_model_false_still_works(self):
247248 dataset = TEST_CONFIG ["dataset_s3_uri" ],
248249 builtin_metrics = TEST_CONFIG ["builtin_metrics" ],
249250 s3_output_path = TEST_CONFIG ["s3_output_path" ],
251+ mlflow_resource_arn = TEST_CONFIG ["mlflow_tracking_server_arn" ],
250252 evaluate_base_model = False , # Only evaluate custom model
251253 )
252254
You can’t perform that action at this time.
0 commit comments