@@ -1013,3 +1013,52 @@ def test_benchmark_evaluator_uses_evaluation_metric_key_for_non_nova(mock_artifa
10131013 assert 'evaluation_metric' in additions
10141014 assert additions ['evaluation_metric' ] == 'accuracy'
10151015 assert 'metric' not in additions
1016+
1017+
1018+ @patch ('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn' )
1019+ @patch ('sagemaker.train.common_utils.recipe_utils._is_nova_model' )
1020+ @patch ('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options' )
1021+ @patch ('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params' )
1022+ @patch ('sagemaker.train.common_utils.model_resolution._resolve_base_model' )
1023+ @patch ('sagemaker.core.resources.Artifact' )
1024+ def test_benchmark_evaluator_custom_hub_name_forwarded (
1025+ mock_artifact , mock_resolve , mock_get_params , mock_extract_options , mock_is_nova , mock_resolve_mlflow
1026+ ):
1027+ """Custom hub_name on BenchMarkEvaluator is forwarded to hub override-params lookup."""
1028+ mock_resolve_mlflow .return_value = DEFAULT_MLFLOW_ARN
1029+ mock_info = Mock ()
1030+ mock_info .base_model_name = DEFAULT_MODEL
1031+ mock_info .base_model_arn = DEFAULT_BASE_MODEL_ARN
1032+ mock_info .source_model_package_arn = None
1033+ mock_resolve .return_value = mock_info
1034+
1035+ mock_artifact .get_all .return_value = iter ([])
1036+ mock_artifact_instance = Mock ()
1037+ mock_artifact_instance .artifact_arn = DEFAULT_ARTIFACT_ARN
1038+ mock_artifact .create .return_value = mock_artifact_instance
1039+
1040+ mock_session = Mock ()
1041+ mock_session .boto_region_name = DEFAULT_REGION
1042+ mock_session .boto_session = Mock ()
1043+ mock_session .get_caller_identity_arn .return_value = DEFAULT_ROLE
1044+ mock_session .sagemaker_config = None
1045+
1046+ mock_is_nova .return_value = False
1047+ mock_get_params .return_value = {'temperature' : 0.7 }
1048+ mock_extract_options .return_value = {'temperature' : {'value' : 0.7 }}
1049+
1050+ evaluator = BenchMarkEvaluator (
1051+ benchmark = _Benchmark .MMLU ,
1052+ model = DEFAULT_MODEL ,
1053+ s3_output_path = DEFAULT_S3_OUTPUT ,
1054+ mlflow_resource_arn = DEFAULT_MLFLOW_ARN ,
1055+ model_package_group = DEFAULT_MODEL_PACKAGE_GROUP_ARN ,
1056+ sagemaker_session = mock_session ,
1057+ hub_name = "MyPrivateHub" ,
1058+ )
1059+
1060+ # Trigger lazy-loaded hyperparameters to hit the hub lookup
1061+ _ = evaluator .hyperparameters
1062+
1063+ assert evaluator .hub_name == "MyPrivateHub"
1064+ assert mock_get_params .call_args .kwargs ["hub_name" ] == "MyPrivateHub"
0 commit comments