From 764d81e3ddb72a5bb5c633d921d15256e1c12fad Mon Sep 17 00:00:00 2001 From: Molly He Date: Mon, 20 Apr 2026 16:58:14 -0700 Subject: [PATCH] feat: private hub support for MC trainer and evaluator --- .../src/sagemaker/train/base_trainer.py | 8 +++ .../src/sagemaker/train/dpo_trainer.py | 5 +- .../train/evaluate/base_evaluator.py | 4 +- .../train/evaluate/benchmark_evaluator.py | 2 +- .../train/evaluate/custom_scorer_evaluator.py | 4 +- .../src/sagemaker/train/rlaif_trainer.py | 7 +-- .../src/sagemaker/train/rlvr_trainer.py | 5 +- .../src/sagemaker/train/sft_trainer.py | 5 +- .../evaluate/test_benchmark_evaluator.py | 49 +++++++++++++++++ .../evaluate/test_custom_scorer_evaluator.py | 48 +++++++++++++++++ .../tests/unit/train/test_dpo_trainer.py | 32 ++++++++++++ .../tests/unit/train/test_rlaif_trainer.py | 52 +++++++++++++++++++ .../tests/unit/train/test_rlvr_trainer.py | 32 ++++++++++++ .../tests/unit/train/test_sft_trainer.py | 32 ++++++++++++ 14 files changed, 272 insertions(+), 13 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/base_trainer.py b/sagemaker-train/src/sagemaker/train/base_trainer.py index a422dc3240..a3fbe5f0d2 100644 --- a/sagemaker-train/src/sagemaker/train/base_trainer.py +++ b/sagemaker-train/src/sagemaker/train/base_trainer.py @@ -4,6 +4,7 @@ from sagemaker.core.training.configs import Tag, Networking, InputData, Channel from sagemaker.core.shapes import shapes from sagemaker.core.resources import TrainingJob +from sagemaker.train.constants import HUB_NAME class BaseTrainer(ABC): @@ -36,6 +37,10 @@ class BaseTrainer(ABC): Can include training and validation datasets. environment (Optional[Dict[str, str]]): Environment variables to set in the training container. + hub_name (Optional[str]): + Name of the SageMaker Hub to pull model recipes and metadata from. + Defaults to ``"SageMakerPublicHub"``. Set to a private hub name to test + pre-release recipes (e.g., during development or E2E testing). """ # Class-level attributes with default values @@ -48,6 +53,7 @@ class BaseTrainer(ABC): input_data_config: Optional[List[Union[Channel, InputData]]] = None environment: Optional[Dict[str, str]] = None latest_training_job: Optional[TrainingJob] = None + hub_name: str = HUB_NAME def __init__( self, @@ -59,6 +65,7 @@ def __init__( output_data_config: Optional[shapes.OutputDataConfig] = None, input_data_config: Optional[List[Union[Channel, InputData]]] = None, environment: Optional[Dict[str, str]] = None, + hub_name: Optional[str] = None, ): self.sagemaker_session = sagemaker_session self.role = role @@ -68,6 +75,7 @@ def __init__( self.output_data_config = output_data_config self.input_data_config = input_data_config self.environment = environment or {} + self.hub_name = hub_name or HUB_NAME def _is_nova_model_for_telemetry(self) -> bool: """Check if the model is a Nova model for telemetry tracking.""" diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 75a450c3c8..525ce58974 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -142,7 +142,8 @@ def __init__( self.sagemaker_session or TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session - )) + ), + hub_name=self.hub_name) # Process hyperparameters self._process_hyperparameters() @@ -244,7 +245,7 @@ def train(self, ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, self.hub_name) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 4bf718b050..62106b1eca 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -104,6 +104,7 @@ class BaseEvaluator(BaseModel): networking: Optional[VpcConfig] = None kms_key_id: Optional[str] = None model_package_group: Optional[Union[str, ModelPackageGroup]] = None + hub_name: Optional[str] = None class Config: arbitrary_types_allowed = True @@ -315,7 +316,8 @@ def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: d # Resolve model information model_info = _resolve_base_model( base_model=v, - sagemaker_session=session + sagemaker_session=session, + hub_name=values.get('hub_name') ) # If model is a ModelPackage object or ARN (has source_model_package_arn), diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index d6bad422c6..d9ce1dd952 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -466,7 +466,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=self.hub_name or "SageMakerPublicHub", evaluation_type=evaluation_type, region=region, session=boto_session diff --git a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py index 78d297006c..f7533c5f47 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py @@ -240,7 +240,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=self.hub_name or "SageMakerPublicHub", evaluation_type="DeterministicEvaluation", region=region, session=boto_session @@ -365,7 +365,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict: _logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}") override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=self.hub_name or "SageMakerPublicHub", evaluation_type="DeterministicEvaluation", region=region, session=session diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index db19a5e1d9..d22f348b0a 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -163,7 +163,8 @@ def __init__( self.training_type, self.sagemaker_session or TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session - )) + ), + hub_name=self.hub_name) # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -263,7 +264,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, self.hub_name) # Build TrainingJob.create() arguments create_args = { @@ -358,7 +359,7 @@ def _process_non_builtin_reward_prompt(self): sagemaker_session=self.sagemaker_session ) hub_content = _get_hub_content_metadata( - hub_name=HUB_NAME, + hub_name=self.hub_name, hub_content_type="JsonDoc", hub_content_name=self.reward_prompt, session=session.boto_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 8a11cfb0d8..4c06b197ef 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -153,7 +153,8 @@ def __init__( self.training_type, self.sagemaker_session or TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session - )) + ), + hub_name=self.hub_name) # Remove constructor-handled hyperparameters self._process_hyperparameters() @@ -251,7 +252,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, self.hub_name) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 80465c061d..12635ad2a2 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -145,7 +145,8 @@ def __init__( self.training_type, self.sagemaker_session or TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session - )) + ), + hub_name=self.hub_name) # Process hyperparameters self._process_hyperparameters() @@ -245,7 +246,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, self.hub_name) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py index d87a435ba0..f13621b950 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py @@ -1013,3 +1013,52 @@ def test_benchmark_evaluator_uses_evaluation_metric_key_for_non_nova(mock_artifa assert 'evaluation_metric' in additions assert additions['evaluation_metric'] == 'accuracy' assert 'metric' not in additions + + +@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn') +@patch('sagemaker.train.common_utils.recipe_utils._is_nova_model') +@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options') +@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params') +@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model') +@patch('sagemaker.core.resources.Artifact') +def test_benchmark_evaluator_custom_hub_name_forwarded( + mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_is_nova, mock_resolve_mlflow +): + """Custom hub_name on BenchMarkEvaluator is forwarded to hub override-params lookup.""" + mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN + mock_info = Mock() + mock_info.base_model_name = DEFAULT_MODEL + mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN + mock_info.source_model_package_arn = None + mock_resolve.return_value = mock_info + + mock_artifact.get_all.return_value = iter([]) + mock_artifact_instance = Mock() + mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN + mock_artifact.create.return_value = mock_artifact_instance + + mock_session = Mock() + mock_session.boto_region_name = DEFAULT_REGION + mock_session.boto_session = Mock() + mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE + mock_session.sagemaker_config = None + + mock_is_nova.return_value = False + mock_get_params.return_value = {'temperature': 0.7} + mock_extract_options.return_value = {'temperature': {'value': 0.7}} + + evaluator = BenchMarkEvaluator( + benchmark=_Benchmark.MMLU, + model=DEFAULT_MODEL, + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, + sagemaker_session=mock_session, + hub_name="MyPrivateHub", + ) + + # Trigger lazy-loaded hyperparameters to hit the hub lookup + _ = evaluator.hyperparameters + + assert evaluator.hub_name == "MyPrivateHub" + assert mock_get_params.call_args.kwargs["hub_name"] == "MyPrivateHub" diff --git a/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py index 9267cc7f73..4ea99f8c49 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py @@ -1129,3 +1129,51 @@ def test_custom_scorer_evaluator_no_lambda_type_for_non_nova_models( assert 'evaluation_metric' in additions assert additions['evaluation_metric'] == 'all' assert 'metric' not in additions + + +@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn') +@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options') +@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params') +@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model') +@patch('sagemaker.core.resources.Artifact') +def test_custom_scorer_evaluator_custom_hub_name_forwarded( + mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_resolve_mlflow +): + """Custom hub_name on CustomScorerEvaluator is forwarded to hub override-params lookup.""" + mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN + mock_info = Mock() + mock_info.base_model_name = DEFAULT_MODEL + mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN + mock_info.source_model_package_arn = None + mock_resolve.return_value = mock_info + + mock_artifact.get_all.return_value = iter([]) + mock_artifact_instance = Mock() + mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN + mock_artifact.create.return_value = mock_artifact_instance + + mock_session = Mock() + mock_session.boto_region_name = DEFAULT_REGION + mock_session.boto_session = Mock() + mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE + mock_session.sagemaker_config = None + + mock_get_params.return_value = {'temperature': 0.5} + mock_extract_options.return_value = {'temperature': {'value': 0.5}} + + evaluator = CustomScorerEvaluator( + evaluator=DEFAULT_EVALUATOR_ARN, + dataset=DEFAULT_DATASET, + model=DEFAULT_MODEL, + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, + sagemaker_session=mock_session, + hub_name="MyPrivateHub", + ) + + # Trigger lazy-loaded hyperparameters to hit the hub lookup + _ = evaluator.hyperparameters + + assert evaluator.hub_name == "MyPrivateHub" + assert mock_get_params.call_args.kwargs["hub_name"] == "MyPrivateHub" diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 93a4b18fa9..2ca89a331f 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -378,3 +378,35 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): assert trainer.stopping_condition == stopping_condition assert trainer.stopping_condition.max_runtime_in_seconds == 14400 + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_hub_name_defaults_to_public_hub(self, mock_finetuning_options, mock_validate_group): + """hub_name defaults to 'SageMakerPublicHub' and is forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = DPOTrainer(model="test-model", model_package_group="test-group") + + assert trainer.hub_name == "SageMakerPublicHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "SageMakerPublicHub" + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_custom_hub_name_forwarded(self, mock_finetuning_options, mock_validate_group): + """Custom hub_name is stored on the trainer and forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = DPOTrainer( + model="test-model", + model_package_group="test-group", + hub_name="MyPrivateHub", + ) + + assert trainer.hub_name == "MyPrivateHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "MyPrivateHub" diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index be8b9b96b6..e788c965d8 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -554,3 +554,55 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): assert trainer.stopping_condition == stopping_condition assert trainer.stopping_condition.max_runtime_in_seconds == 86400 + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_hub_name_defaults_to_public_hub(self, mock_finetuning_options, mock_validate_group): + """hub_name defaults to 'SageMakerPublicHub' and is forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group") + + assert trainer.hub_name == "SageMakerPublicHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "SageMakerPublicHub" + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_custom_hub_name_forwarded(self, mock_finetuning_options, mock_validate_group): + """Custom hub_name is stored on the trainer and forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = RLAIFTrainer( + model="test-model", + model_package_group="test-group", + hub_name="MyPrivateHub", + ) + + assert trainer.hub_name == "MyPrivateHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "MyPrivateHub" + + def test_process_non_builtin_reward_prompt_uses_custom_hub_name(self): + """Non-builtin reward prompt lookup uses trainer's hub_name.""" + mock_hyperparams = Mock() + mock_hyperparams._specs = {} + + trainer = RLAIFTrainer.__new__(RLAIFTrainer) + trainer.hyperparameters = mock_hyperparams + trainer.reward_prompt = "custom-prompt-name" + trainer.sagemaker_session = None + trainer.hub_name = "MyPrivateHub" + + with patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') as mock_session, \ + patch('sagemaker.train.rlaif_trainer._get_hub_content_metadata') as mock_hub: + mock_session.return_value = Mock(boto_session=Mock(region_name="us-west-2")) + mock_hub.return_value = Mock(hub_content_arn="hub-content-arn") + + trainer._process_non_builtin_reward_prompt() + + assert mock_hub.call_args.kwargs["hub_name"] == "MyPrivateHub" diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 4ee785285e..be92f1946c 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -381,3 +381,35 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): assert trainer.stopping_condition == stopping_condition assert trainer.stopping_condition.max_runtime_in_seconds == 259200 + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_hub_name_defaults_to_public_hub(self, mock_finetuning_options, mock_validate_group): + """hub_name defaults to 'SageMakerPublicHub' and is forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group") + + assert trainer.hub_name == "SageMakerPublicHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "SageMakerPublicHub" + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_custom_hub_name_forwarded(self, mock_finetuning_options, mock_validate_group): + """Custom hub_name is stored on the trainer and forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = RLVRTrainer( + model="test-model", + model_package_group="test-group", + hub_name="MyPrivateHub", + ) + + assert trainer.hub_name == "MyPrivateHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "MyPrivateHub" diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 6af829e1a7..c3494f35fa 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -392,3 +392,35 @@ def test_default_stopping_condition_is_none(self, mock_finetuning, mock_validate trainer = SFTTrainer(model="test-model", model_package_group="test-group") assert trainer.stopping_condition is None + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_hub_name_defaults_to_public_hub(self, mock_finetuning_options, mock_validate_group): + """hub_name defaults to 'SageMakerPublicHub' and is forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + + assert trainer.hub_name == "SageMakerPublicHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "SageMakerPublicHub" + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_custom_hub_name_forwarded(self, mock_finetuning_options, mock_validate_group): + """Custom hub_name is stored on the trainer and forwarded to fine-tuning options lookup.""" + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + + trainer = SFTTrainer( + model="test-model", + model_package_group="test-group", + hub_name="MyPrivateHub", + ) + + assert trainer.hub_name == "MyPrivateHub" + assert mock_finetuning_options.call_args.kwargs["hub_name"] == "MyPrivateHub"