Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sagemaker-train/src/sagemaker/train/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sagemaker.core.training.configs import Tag, Networking, InputData, Channel
from sagemaker.core.shapes import shapes
from sagemaker.core.resources import TrainingJob
from sagemaker.train.constants import HUB_NAME


class BaseTrainer(ABC):
Expand Down Expand Up @@ -36,6 +37,10 @@ class BaseTrainer(ABC):
Can include training and validation datasets.
environment (Optional[Dict[str, str]]):
Environment variables to set in the training container.
hub_name (Optional[str]):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't a customer-accessible feature, so we should avoid documenting it where customers can see it.

Name of the SageMaker Hub to pull model recipes and metadata from.
Defaults to ``"SageMakerPublicHub"``. Set to a private hub name to test
pre-release recipes (e.g., during development or E2E testing).
"""

# Class-level attributes with default values
Expand All @@ -48,6 +53,7 @@ class BaseTrainer(ABC):
input_data_config: Optional[List[Union[Channel, InputData]]] = None
environment: Optional[Dict[str, str]] = None
latest_training_job: Optional[TrainingJob] = None
hub_name: str = HUB_NAME

def __init__(
self,
Expand All @@ -59,6 +65,7 @@ def __init__(
output_data_config: Optional[shapes.OutputDataConfig] = None,
input_data_config: Optional[List[Union[Channel, InputData]]] = None,
environment: Optional[Dict[str, str]] = None,
hub_name: Optional[str] = None,
):
self.sagemaker_session = sagemaker_session
self.role = role
Expand All @@ -68,6 +75,7 @@ def __init__(
self.output_data_config = output_data_config
self.input_data_config = input_data_config
self.environment = environment or {}
self.hub_name = hub_name or HUB_NAME

def _is_nova_model_for_telemetry(self) -> bool:
"""Check if the model is a Nova model for telemetry tracking."""
Expand Down
5 changes: 3 additions & 2 deletions sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def __init__(
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session

))
),
hub_name=self.hub_name)

# Process hyperparameters
self._process_hyperparameters()
Expand Down Expand Up @@ -244,7 +245,7 @@ def train(self,
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, self.hub_name)

# Build TrainingJob.create() arguments
create_args = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class BaseEvaluator(BaseModel):
networking: Optional[VpcConfig] = None
kms_key_id: Optional[str] = None
model_package_group: Optional[Union[str, ModelPackageGroup]] = None
hub_name: Optional[str] = None

class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -315,7 +316,8 @@ def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: d
# Resolve model information
model_info = _resolve_base_model(
base_model=v,
sagemaker_session=session
sagemaker_session=session,
hub_name=values.get('hub_name')
)

# If model is a ModelPackage object or ARN (has source_model_package_arn),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def hyperparameters(self):

override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
hub_name=self.hub_name or "SageMakerPublicHub",
evaluation_type=evaluation_type,
region=region,
session=boto_session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def hyperparameters(self):

override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
hub_name=self.hub_name or "SageMakerPublicHub",
evaluation_type="DeterministicEvaluation",
region=region,
session=boto_session
Expand Down Expand Up @@ -365,7 +365,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict:
_logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}")
override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
hub_name=self.hub_name or "SageMakerPublicHub",
evaluation_type="DeterministicEvaluation",
region=region,
session=session
Expand Down
7 changes: 4 additions & 3 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ def __init__(
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session
))
),
hub_name=self.hub_name)

# Validate and set EULA acceptance
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
Expand Down Expand Up @@ -263,7 +264,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, self.hub_name)

# Build TrainingJob.create() arguments
create_args = {
Expand Down Expand Up @@ -358,7 +359,7 @@ def _process_non_builtin_reward_prompt(self):
sagemaker_session=self.sagemaker_session
)
hub_content = _get_hub_content_metadata(
hub_name=HUB_NAME,
hub_name=self.hub_name,
hub_content_type="JsonDoc",
hub_content_name=self.reward_prompt,
session=session.boto_session,
Expand Down
5 changes: 3 additions & 2 deletions sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def __init__(
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session
))
),
hub_name=self.hub_name)

# Remove constructor-handled hyperparameters
self._process_hyperparameters()
Expand Down Expand Up @@ -251,7 +252,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, self.hub_name)

# Build TrainingJob.create() arguments
create_args = {
Expand Down
5 changes: 3 additions & 2 deletions sagemaker-train/src/sagemaker/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def __init__(
self.training_type,
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
sagemaker_session=self.sagemaker_session
))
),
hub_name=self.hub_name)

# Process hyperparameters
self._process_hyperparameters()
Expand Down Expand Up @@ -245,7 +246,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, self.hub_name)

# Build TrainingJob.create() arguments
create_args = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,52 @@ def test_benchmark_evaluator_uses_evaluation_metric_key_for_non_nova(mock_artifa
assert 'evaluation_metric' in additions
assert additions['evaluation_metric'] == 'accuracy'
assert 'metric' not in additions


@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn')
@patch('sagemaker.train.common_utils.recipe_utils._is_nova_model')
@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options')
@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params')
@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
@patch('sagemaker.core.resources.Artifact')
def test_benchmark_evaluator_custom_hub_name_forwarded(
    mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_is_nova, mock_resolve_mlflow
):
    """A hub_name given to BenchMarkEvaluator reaches the override-params hub lookup."""
    private_hub = "MyPrivateHub"

    # Hub / recipe helpers: non-Nova model with a single overridable parameter.
    mock_is_nova.return_value = False
    mock_get_params.return_value = {'temperature': 0.7}
    mock_extract_options.return_value = {'temperature': {'value': 0.7}}

    # Model and MLflow resolution return fixed stand-ins.
    mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN
    mock_resolve.return_value = Mock(
        base_model_name=DEFAULT_MODEL,
        base_model_arn=DEFAULT_BASE_MODEL_ARN,
        source_model_package_arn=None,
    )

    # Artifact lookup finds nothing, so a new artifact is "created".
    mock_artifact.get_all.return_value = iter([])
    mock_artifact.create.return_value = Mock(artifact_arn=DEFAULT_ARTIFACT_ARN)

    # Minimal stand-in for a SageMaker session.
    fake_session = Mock(
        boto_region_name=DEFAULT_REGION,
        boto_session=Mock(),
        sagemaker_config=None,
    )
    fake_session.get_caller_identity_arn.return_value = DEFAULT_ROLE

    evaluator = BenchMarkEvaluator(
        benchmark=_Benchmark.MMLU,
        model=DEFAULT_MODEL,
        s3_output_path=DEFAULT_S3_OUTPUT,
        mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
        model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
        sagemaker_session=fake_session,
        hub_name=private_hub,
    )

    # Accessing the property is what triggers the hub lookup.
    _ = evaluator.hyperparameters

    assert evaluator.hub_name == private_hub
    assert mock_get_params.call_args.kwargs["hub_name"] == private_hub
Original file line number Diff line number Diff line change
Expand Up @@ -1129,3 +1129,51 @@ def test_custom_scorer_evaluator_no_lambda_type_for_non_nova_models(
assert 'evaluation_metric' in additions
assert additions['evaluation_metric'] == 'all'
assert 'metric' not in additions


@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn')
@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options')
@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params')
@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
@patch('sagemaker.core.resources.Artifact')
def test_custom_scorer_evaluator_custom_hub_name_forwarded(
    mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_resolve_mlflow
):
    """A hub_name given to CustomScorerEvaluator reaches the override-params hub lookup."""
    private_hub = "MyPrivateHub"

    # Recipe helpers return a single overridable parameter.
    mock_get_params.return_value = {'temperature': 0.5}
    mock_extract_options.return_value = {'temperature': {'value': 0.5}}

    # Model and MLflow resolution return fixed stand-ins.
    mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN
    mock_resolve.return_value = Mock(
        base_model_name=DEFAULT_MODEL,
        base_model_arn=DEFAULT_BASE_MODEL_ARN,
        source_model_package_arn=None,
    )

    # Artifact lookup finds nothing, so a new artifact is "created".
    mock_artifact.get_all.return_value = iter([])
    mock_artifact.create.return_value = Mock(artifact_arn=DEFAULT_ARTIFACT_ARN)

    # Minimal stand-in for a SageMaker session.
    fake_session = Mock(
        boto_region_name=DEFAULT_REGION,
        boto_session=Mock(),
        sagemaker_config=None,
    )
    fake_session.get_caller_identity_arn.return_value = DEFAULT_ROLE

    evaluator = CustomScorerEvaluator(
        evaluator=DEFAULT_EVALUATOR_ARN,
        dataset=DEFAULT_DATASET,
        model=DEFAULT_MODEL,
        s3_output_path=DEFAULT_S3_OUTPUT,
        mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
        model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
        sagemaker_session=fake_session,
        hub_name=private_hub,
    )

    # Accessing the property is what triggers the hub lookup.
    _ = evaluator.hyperparameters

    assert evaluator.hub_name == private_hub
    assert mock_get_params.call_args.kwargs["hub_name"] == private_hub
32 changes: 32 additions & 0 deletions sagemaker-train/tests/unit/train/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,35 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):

assert trainer.stopping_condition == stopping_condition
assert trainer.stopping_condition.max_runtime_in_seconds == 14400

@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
def test_hub_name_defaults_to_public_hub(self, mock_finetuning_options, mock_validate_group):
    """With no hub_name given, DPOTrainer stores and forwards 'SageMakerPublicHub'."""
    expected_hub = "SageMakerPublicHub"

    # The options lookup yields (hyperparameters, model ARN, is-gated flag).
    stub_hyperparams = Mock(**{"to_dict.return_value": {}})
    mock_finetuning_options.return_value = (stub_hyperparams, "model-arn", False)
    mock_validate_group.return_value = "test-group"

    trainer = DPOTrainer(model="test-model", model_package_group="test-group")

    assert trainer.hub_name == expected_hub
    assert mock_finetuning_options.call_args.kwargs["hub_name"] == expected_hub

@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
def test_custom_hub_name_forwarded(self, mock_finetuning_options, mock_validate_group):
    """An explicit hub_name is kept on DPOTrainer and passed to the fine-tuning options lookup."""
    private_hub = "MyPrivateHub"

    # The options lookup yields (hyperparameters, model ARN, is-gated flag).
    stub_hyperparams = Mock(**{"to_dict.return_value": {}})
    mock_finetuning_options.return_value = (stub_hyperparams, "model-arn", False)
    mock_validate_group.return_value = "test-group"

    trainer = DPOTrainer(
        model="test-model",
        model_package_group="test-group",
        hub_name=private_hub,
    )

    assert trainer.hub_name == private_hub
    assert mock_finetuning_options.call_args.kwargs["hub_name"] == private_hub
52 changes: 52 additions & 0 deletions sagemaker-train/tests/unit/train/test_rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,55 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):

assert trainer.stopping_condition == stopping_condition
assert trainer.stopping_condition.max_runtime_in_seconds == 86400

@patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
def test_hub_name_defaults_to_public_hub(self, mock_finetuning_options, mock_validate_group):
    """With no hub_name given, RLAIFTrainer stores and forwards 'SageMakerPublicHub'."""
    expected_hub = "SageMakerPublicHub"

    # The options lookup yields (hyperparameters, model ARN, is-gated flag).
    stub_hyperparams = Mock(**{"to_dict.return_value": {}})
    mock_finetuning_options.return_value = (stub_hyperparams, "model-arn", False)
    mock_validate_group.return_value = "test-group"

    trainer = RLAIFTrainer(model="test-model", model_package_group="test-group")

    assert trainer.hub_name == expected_hub
    assert mock_finetuning_options.call_args.kwargs["hub_name"] == expected_hub

@patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
def test_custom_hub_name_forwarded(self, mock_finetuning_options, mock_validate_group):
    """An explicit hub_name is kept on RLAIFTrainer and passed to the fine-tuning options lookup."""
    private_hub = "MyPrivateHub"

    # The options lookup yields (hyperparameters, model ARN, is-gated flag).
    stub_hyperparams = Mock(**{"to_dict.return_value": {}})
    mock_finetuning_options.return_value = (stub_hyperparams, "model-arn", False)
    mock_validate_group.return_value = "test-group"

    trainer = RLAIFTrainer(
        model="test-model",
        model_package_group="test-group",
        hub_name=private_hub,
    )

    assert trainer.hub_name == private_hub
    assert mock_finetuning_options.call_args.kwargs["hub_name"] == private_hub

def test_process_non_builtin_reward_prompt_uses_custom_hub_name(self):
    """The hub-content lookup for a non-builtin reward prompt receives the trainer's hub_name."""
    stub_hyperparams = Mock()
    stub_hyperparams._specs = {}

    # Bypass __init__ and set only the attributes assigned by the original test.
    trainer = RLAIFTrainer.__new__(RLAIFTrainer)
    trainer.hub_name = "MyPrivateHub"
    trainer.sagemaker_session = None
    trainer.reward_prompt = "custom-prompt-name"
    trainer.hyperparameters = stub_hyperparams

    fake_session = Mock(boto_session=Mock(region_name="us-west-2"))
    fake_hub_content = Mock(hub_content_arn="hub-content-arn")

    with patch(
        'sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session',
        return_value=fake_session,
    ), patch(
        'sagemaker.train.rlaif_trainer._get_hub_content_metadata',
        return_value=fake_hub_content,
    ) as mock_hub:
        trainer._process_non_builtin_reward_prompt()

        assert mock_hub.call_args.kwargs["hub_name"] == "MyPrivateHub"
Loading
Loading