Skip to content

Commit 764d81e

Browse files
committed
feat: private hub support for MC trainer and evaluator
1 parent 4c184d4 commit 764d81e

14 files changed

Lines changed: 272 additions & 13 deletions

sagemaker-train/src/sagemaker/train/base_trainer.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from sagemaker.core.training.configs import Tag, Networking, InputData, Channel
55
from sagemaker.core.shapes import shapes
66
from sagemaker.core.resources import TrainingJob
7+
from sagemaker.train.constants import HUB_NAME
78

89

910
class BaseTrainer(ABC):
@@ -36,6 +37,10 @@ class BaseTrainer(ABC):
3637
Can include training and validation datasets.
3738
environment (Optional[Dict[str, str]]):
3839
Environment variables to set in the training container.
40+
hub_name (Optional[str]):
41+
Name of the SageMaker Hub to pull model recipes and metadata from.
42+
Defaults to ``"SageMakerPublicHub"``. Set to a private hub name to test
43+
pre-release recipes (e.g., during development or E2E testing).
3944
"""
4045

4146
# Class-level attributes with default values
@@ -48,6 +53,7 @@ class BaseTrainer(ABC):
4853
input_data_config: Optional[List[Union[Channel, InputData]]] = None
4954
environment: Optional[Dict[str, str]] = None
5055
latest_training_job: Optional[TrainingJob] = None
56+
hub_name: str = HUB_NAME
5157

5258
def __init__(
5359
self,
@@ -59,6 +65,7 @@ def __init__(
5965
output_data_config: Optional[shapes.OutputDataConfig] = None,
6066
input_data_config: Optional[List[Union[Channel, InputData]]] = None,
6167
environment: Optional[Dict[str, str]] = None,
68+
hub_name: Optional[str] = None,
6269
):
6370
self.sagemaker_session = sagemaker_session
6471
self.role = role
@@ -68,6 +75,7 @@ def __init__(
6875
self.output_data_config = output_data_config
6976
self.input_data_config = input_data_config
7077
self.environment = environment or {}
78+
self.hub_name = hub_name or HUB_NAME
7179

7280
def _is_nova_model_for_telemetry(self) -> bool:
7381
"""Check if the model is a Nova model for telemetry tracking."""

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def __init__(
142142
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
143143
sagemaker_session=self.sagemaker_session
144144

145-
))
145+
),
146+
hub_name=self.hub_name)
146147

147148
# Process hyperparameters
148149
self._process_hyperparameters()
@@ -244,7 +245,7 @@ def train(self,
244245
)
245246

246247
vpc_config = self.networking if self.networking else None
247-
tags = _get_studio_tags(self._model_name, HUB_NAME)
248+
tags = _get_studio_tags(self._model_name, self.hub_name)
248249

249250
# Build TrainingJob.create() arguments
250251
create_args = {

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class BaseEvaluator(BaseModel):
104104
networking: Optional[VpcConfig] = None
105105
kms_key_id: Optional[str] = None
106106
model_package_group: Optional[Union[str, ModelPackageGroup]] = None
107+
hub_name: Optional[str] = None
107108

108109
class Config:
109110
arbitrary_types_allowed = True
@@ -315,7 +316,8 @@ def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: d
315316
# Resolve model information
316317
model_info = _resolve_base_model(
317318
base_model=v,
318-
sagemaker_session=session
319+
sagemaker_session=session,
320+
hub_name=values.get('hub_name')
319321
)
320322

321323
# If model is a ModelPackage object or ARN (has source_model_package_arn),

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def hyperparameters(self):
466466

467467
override_params = _get_evaluation_override_params(
468468
hub_content_name=hub_content_name,
469-
hub_name="SageMakerPublicHub",
469+
hub_name=self.hub_name or "SageMakerPublicHub",
470470
evaluation_type=evaluation_type,
471471
region=region,
472472
session=boto_session

sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def hyperparameters(self):
240240

241241
override_params = _get_evaluation_override_params(
242242
hub_content_name=hub_content_name,
243-
hub_name="SageMakerPublicHub",
243+
hub_name=self.hub_name or "SageMakerPublicHub",
244244
evaluation_type="DeterministicEvaluation",
245245
region=region,
246246
session=boto_session
@@ -365,7 +365,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict:
365365
_logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}")
366366
override_params = _get_evaluation_override_params(
367367
hub_content_name=hub_content_name,
368-
hub_name="SageMakerPublicHub",
368+
hub_name=self.hub_name or "SageMakerPublicHub",
369369
evaluation_type="DeterministicEvaluation",
370370
region=region,
371371
session=session

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def __init__(
163163
self.training_type,
164164
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
165165
sagemaker_session=self.sagemaker_session
166-
))
166+
),
167+
hub_name=self.hub_name)
167168

168169
# Validate and set EULA acceptance
169170
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
@@ -263,7 +264,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
263264
)
264265

265266
vpc_config = self.networking if self.networking else None
266-
tags = _get_studio_tags(self._model_name, HUB_NAME)
267+
tags = _get_studio_tags(self._model_name, self.hub_name)
267268

268269
# Build TrainingJob.create() arguments
269270
create_args = {
@@ -358,7 +359,7 @@ def _process_non_builtin_reward_prompt(self):
358359
sagemaker_session=self.sagemaker_session
359360
)
360361
hub_content = _get_hub_content_metadata(
361-
hub_name=HUB_NAME,
362+
hub_name=self.hub_name,
362363
hub_content_type="JsonDoc",
363364
hub_content_name=self.reward_prompt,
364365
session=session.boto_session,

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ def __init__(
153153
self.training_type,
154154
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
155155
sagemaker_session=self.sagemaker_session
156-
))
156+
),
157+
hub_name=self.hub_name)
157158

158159
# Remove constructor-handled hyperparameters
159160
self._process_hyperparameters()
@@ -251,7 +252,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
251252
)
252253

253254
vpc_config = self.networking if self.networking else None
254-
tags = _get_studio_tags(self._model_name, HUB_NAME)
255+
tags = _get_studio_tags(self._model_name, self.hub_name)
255256

256257
# Build TrainingJob.create() arguments
257258
create_args = {

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ def __init__(
145145
self.training_type,
146146
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
147147
sagemaker_session=self.sagemaker_session
148-
))
148+
),
149+
hub_name=self.hub_name)
149150

150151
# Process hyperparameters
151152
self._process_hyperparameters()
@@ -245,7 +246,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
245246
)
246247

247248
vpc_config = self.networking if self.networking else None
248-
tags = _get_studio_tags(self._model_name, HUB_NAME)
249+
tags = _get_studio_tags(self._model_name, self.hub_name)
249250

250251
# Build TrainingJob.create() arguments
251252
create_args = {

sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,3 +1013,52 @@ def test_benchmark_evaluator_uses_evaluation_metric_key_for_non_nova(mock_artifa
10131013
assert 'evaluation_metric' in additions
10141014
assert additions['evaluation_metric'] == 'accuracy'
10151015
assert 'metric' not in additions
1016+
1017+
1018+
@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn')
1019+
@patch('sagemaker.train.common_utils.recipe_utils._is_nova_model')
1020+
@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options')
1021+
@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params')
1022+
@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
1023+
@patch('sagemaker.core.resources.Artifact')
1024+
def test_benchmark_evaluator_custom_hub_name_forwarded(
1025+
mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_is_nova, mock_resolve_mlflow
1026+
):
1027+
"""Custom hub_name on BenchMarkEvaluator is forwarded to hub override-params lookup."""
1028+
mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN
1029+
mock_info = Mock()
1030+
mock_info.base_model_name = DEFAULT_MODEL
1031+
mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN
1032+
mock_info.source_model_package_arn = None
1033+
mock_resolve.return_value = mock_info
1034+
1035+
mock_artifact.get_all.return_value = iter([])
1036+
mock_artifact_instance = Mock()
1037+
mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
1038+
mock_artifact.create.return_value = mock_artifact_instance
1039+
1040+
mock_session = Mock()
1041+
mock_session.boto_region_name = DEFAULT_REGION
1042+
mock_session.boto_session = Mock()
1043+
mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
1044+
mock_session.sagemaker_config = None
1045+
1046+
mock_is_nova.return_value = False
1047+
mock_get_params.return_value = {'temperature': 0.7}
1048+
mock_extract_options.return_value = {'temperature': {'value': 0.7}}
1049+
1050+
evaluator = BenchMarkEvaluator(
1051+
benchmark=_Benchmark.MMLU,
1052+
model=DEFAULT_MODEL,
1053+
s3_output_path=DEFAULT_S3_OUTPUT,
1054+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1055+
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
1056+
sagemaker_session=mock_session,
1057+
hub_name="MyPrivateHub",
1058+
)
1059+
1060+
# Trigger lazy-loaded hyperparameters to hit the hub lookup
1061+
_ = evaluator.hyperparameters
1062+
1063+
assert evaluator.hub_name == "MyPrivateHub"
1064+
assert mock_get_params.call_args.kwargs["hub_name"] == "MyPrivateHub"

sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,3 +1129,51 @@ def test_custom_scorer_evaluator_no_lambda_type_for_non_nova_models(
11291129
assert 'evaluation_metric' in additions
11301130
assert additions['evaluation_metric'] == 'all'
11311131
assert 'metric' not in additions
1132+
1133+
1134+
@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn')
1135+
@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options')
1136+
@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params')
1137+
@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
1138+
@patch('sagemaker.core.resources.Artifact')
1139+
def test_custom_scorer_evaluator_custom_hub_name_forwarded(
1140+
mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_resolve_mlflow
1141+
):
1142+
"""Custom hub_name on CustomScorerEvaluator is forwarded to hub override-params lookup."""
1143+
mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN
1144+
mock_info = Mock()
1145+
mock_info.base_model_name = DEFAULT_MODEL
1146+
mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN
1147+
mock_info.source_model_package_arn = None
1148+
mock_resolve.return_value = mock_info
1149+
1150+
mock_artifact.get_all.return_value = iter([])
1151+
mock_artifact_instance = Mock()
1152+
mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
1153+
mock_artifact.create.return_value = mock_artifact_instance
1154+
1155+
mock_session = Mock()
1156+
mock_session.boto_region_name = DEFAULT_REGION
1157+
mock_session.boto_session = Mock()
1158+
mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
1159+
mock_session.sagemaker_config = None
1160+
1161+
mock_get_params.return_value = {'temperature': 0.5}
1162+
mock_extract_options.return_value = {'temperature': {'value': 0.5}}
1163+
1164+
evaluator = CustomScorerEvaluator(
1165+
evaluator=DEFAULT_EVALUATOR_ARN,
1166+
dataset=DEFAULT_DATASET,
1167+
model=DEFAULT_MODEL,
1168+
s3_output_path=DEFAULT_S3_OUTPUT,
1169+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1170+
model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
1171+
sagemaker_session=mock_session,
1172+
hub_name="MyPrivateHub",
1173+
)
1174+
1175+
# Trigger lazy-loaded hyperparameters to hit the hub lookup
1176+
_ = evaluator.hyperparameters
1177+
1178+
assert evaluator.hub_name == "MyPrivateHub"
1179+
assert mock_get_params.call_args.kwargs["hub_name"] == "MyPrivateHub"

0 commit comments

Comments (0)