11 changes: 11 additions & 0 deletions sagemaker-train/src/sagemaker/train/dpo_trainer.py
@@ -7,6 +7,7 @@
 from sagemaker.core.shapes import VpcConfig
 from sagemaker.train.defaults import TrainDefaults
 from sagemaker.train.utils import _get_unique_name, _get_studio_tags
+from sagemaker.train.configs import StoppingCondition
 from sagemaker.train.common_utils.finetune_utils import (
     _get_fine_tuning_options_and_model_arn,
     _validate_and_resolve_model_package_group,
@@ -96,6 +97,9 @@ class DPOTrainer(BaseTrainer):
             The KMS key ID for encrypting training job outputs.
         networking (Optional[VpcConfig]):
             The VPC configuration for the training job.
+        stopping_condition (Optional[StoppingCondition]):
+            The stopping condition to override the training runtime limit.
+            If not specified, defaults to a 1-hour max runtime.
     """
     def __init__(
         self,
@@ -111,6 +115,7 @@ def __init__(
         kms_key_id: Optional[str] = None,
         networking: Optional[VpcConfig] = None,
         accept_eula: bool = False,
+        stopping_condition: Optional[StoppingCondition] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -128,6 +133,7 @@ def __init__(
         self.s3_output_path = s3_output_path
         self.kms_key_id = kms_key_id
         self.networking = networking
+        self.stopping_condition = stopping_condition

         # Initialize fine-tuning options with beta session fallback
         self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -198,6 +204,10 @@ def train(self,
         current_training_job_name = _get_unique_name(
             self.base_job_name or f"{self._model_name}-dpo"
         )
+
+        stopping_condition = TrainDefaults.get_stopping_condition(
+            stopping_condition=self.stopping_condition
+        )

         logger.info(f"Training Job Name: {current_training_job_name}")
         print(f"Training Job Name: {current_training_job_name}")
@@ -251,6 +261,7 @@
             hyper_parameters=final_hyperparameters,
             model_package_config=model_package_config,
             vpc_config=vpc_config,
+            stopping_condition=stopping_condition,
             session=sagemaker_session.boto_session,
             region=sagemaker_session.boto_session.region_name,
             tags=tags,
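For orientation, the new parameter is used like any other constructor option. A minimal usage sketch, assuming only what the diff itself shows: the `sagemaker.train.dpo_trainer` module path (confirmed by the `@patch` targets in the tests below), a `StoppingCondition` with a `max_runtime_in_seconds` field, and placeholder model and package-group names:

```python
from sagemaker.train.configs import StoppingCondition
from sagemaker.train.dpo_trainer import DPOTrainer

# Allow up to 4 hours instead of the 1-hour default runtime.
four_hours = StoppingCondition(max_runtime_in_seconds=4 * 60 * 60)

trainer = DPOTrainer(
    model="my-base-model",           # placeholder model name
    model_package_group="my-group",  # placeholder model package group
    stopping_condition=four_hours,
)

# train() resolves the override via TrainDefaults.get_stopping_condition and
# forwards it to the job-creation call; dataset arguments omitted for brevity.
trainer.train()
```

The same pattern applies unchanged to the RLAIF, RLVR, and SFT trainers below.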
11 changes: 11 additions & 0 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
@@ -9,6 +9,7 @@
 from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata
 from sagemaker.ai_registry.dataset import DataSet
 from sagemaker.ai_registry.evaluator import Evaluator
+from sagemaker.train.configs import StoppingCondition
 from sagemaker.train.common_utils.finetune_utils import (
     _get_beta_session,
     _get_fine_tuning_options_and_model_arn,
@@ -110,6 +111,9 @@ class RLAIFTrainer(BaseTrainer):
             The KMS key ID for encrypting training job outputs.
         networking (Optional[VpcConfig]):
             The VPC configuration for the training job.
+        stopping_condition (Optional[StoppingCondition]):
+            The stopping condition to override the training runtime limit.
+            If not specified, defaults to a 1-hour max runtime.
     """

     def __init__(
@@ -130,6 +134,7 @@ def __init__(
         # vpc config
         networking: Optional[VpcConfig] = None,
         accept_eula: bool = False,
+        stopping_condition: Optional[StoppingCondition] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -150,6 +155,7 @@ def __init__(
         self.s3_output_path = s3_output_path
         self.kms_key_id = kms_key_id
         self.networking = networking
+        self.stopping_condition = stopping_condition

         # Initialize fine-tuning options with beta session fallback
         self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -215,6 +221,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
         current_training_job_name = _get_unique_name(
             self.base_job_name or f"{self._model_name}-rlaif"
         )
+
+        stopping_condition = TrainDefaults.get_stopping_condition(
+            stopping_condition=self.stopping_condition
+        )

         logger.info(f"Training Job Name: {current_training_job_name}")

@@ -270,6 +280,7 @@
             hyper_parameters=final_hyperparameters,
             model_package_config=model_package_config,
             vpc_config=vpc_config,
+            stopping_condition=stopping_condition,
             session=sagemaker_session.boto_session,
             region=sagemaker_session.boto_session.region_name,
             tags=tags,
11 changes: 11 additions & 0 deletions sagemaker-train/src/sagemaker/train/rlvr_trainer.py
@@ -8,6 +8,7 @@
 from sagemaker.train.utils import _get_unique_name, _get_studio_tags
 from sagemaker.ai_registry.dataset import DataSet
 from sagemaker.ai_registry.evaluator import Evaluator
+from sagemaker.train.configs import StoppingCondition
 from sagemaker.train.common_utils.finetune_utils import (
     _get_fine_tuning_options_and_model_arn,
     _validate_and_resolve_model_package_group,
@@ -102,6 +103,9 @@ class RLVRTrainer(BaseTrainer):
             The KMS key ID for encrypting training job outputs.
         networking (Optional[VpcConfig]):
             The VPC configuration for the training job.
+        stopping_condition (Optional[StoppingCondition]):
+            The stopping condition to override the training runtime limit.
+            If not specified, defaults to a 1-hour max runtime.
     """

     def __init__(
@@ -121,6 +125,7 @@ def __init__(
         # vpc config
         networking: Optional[VpcConfig] = None,
         accept_eula: bool = False,
+        stopping_condition: Optional[StoppingCondition] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -140,6 +145,7 @@ def __init__(
         self.s3_output_path = s3_output_path
         self.kms_key_id = kms_key_id
         self.networking = networking
+        self.stopping_condition = stopping_condition

         # Initialize fine-tuning options with beta session fallback
         self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -202,6 +208,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
         current_training_job_name = _get_unique_name(
             self.base_job_name or f"{self._model_name}-rlvr"
         )
+
+        stopping_condition = TrainDefaults.get_stopping_condition(
+            stopping_condition=self.stopping_condition
+        )

         logger.info(f"Training Job Name: {current_training_job_name}")

@@ -258,6 +268,7 @@
             hyper_parameters=final_hyperparameters,
             model_package_config=model_package_config,
             vpc_config=vpc_config,
+            stopping_condition=stopping_condition,
             session=sagemaker_session.boto_session,
             region=sagemaker_session.boto_session.region_name,
             tags=tags,
11 changes: 11 additions & 0 deletions sagemaker-train/src/sagemaker/train/sft_trainer.py
@@ -8,6 +8,7 @@
 from sagemaker.train.defaults import TrainDefaults
 from sagemaker.train.utils import _get_unique_name, _get_studio_tags
 from sagemaker.ai_registry.dataset import DataSet
+from sagemaker.train.configs import StoppingCondition
 from sagemaker.train.common_utils.finetune_utils import (
     _get_fine_tuning_options_and_model_arn,
     _validate_and_resolve_model_package_group,
@@ -98,6 +99,9 @@ class SFTTrainer(BaseTrainer):
             The KMS key ID for encrypting training job outputs.
         networking (Optional[VpcConfig]):
             The VPC configuration for the training job.
+        stopping_condition (Optional[StoppingCondition]):
+            The stopping condition to override the training runtime limit.
+            If not specified, defaults to a 1-hour max runtime.
     """

     def __init__(
@@ -114,6 +118,7 @@ def __init__(
         kms_key_id: Optional[str] = None,
         networking: Optional[VpcConfig] = None,
         accept_eula: Optional[bool] = False,
+        stopping_condition: Optional[StoppingCondition] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -132,6 +137,7 @@ def __init__(
         self.s3_output_path = s3_output_path
         self.kms_key_id = kms_key_id
         self.networking = networking
+        self.stopping_condition = stopping_condition

         # Initialize fine-tuning options with beta session fallback
         self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -199,6 +205,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
         current_training_job_name = _get_unique_name(
             self.base_job_name or f"{self._model_name}-sft"
         )
+
+        stopping_condition = TrainDefaults.get_stopping_condition(
+            stopping_condition=self.stopping_condition
+        )

         logger.info(f"Training Job Name: {current_training_job_name}")

@@ -252,6 +262,7 @@
             hyper_parameters=final_hyperparameters,
             model_package_config=model_package_config,
             vpc_config=vpc_config,
+            stopping_condition=stopping_condition,
             session=sagemaker_session.boto_session,
             region=sagemaker_session.boto_session.region_name,
             tags=tags,
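All four `train()` methods funnel the attribute through `TrainDefaults.get_stopping_condition` before passing it on, so the default lives in one place. The helper's body is not part of this diff; the sketch below is only a plausible reading of the behavior the docstrings promise ("defaults to 1 hour max runtime"), with the 3600-second constant assumed:

```python
from typing import Optional

from sagemaker.train.configs import StoppingCondition

# Assumed default; the docstrings only say "1 hour max runtime".
_DEFAULT_MAX_RUNTIME_IN_SECONDS = 3600


def get_stopping_condition(
    stopping_condition: Optional[StoppingCondition] = None,
) -> StoppingCondition:
    """Return the caller's override, or a 1-hour default when none is given."""
    if stopping_condition is not None:
        return stopping_condition
    return StoppingCondition(max_runtime_in_seconds=_DEFAULT_MAX_RUNTIME_IN_SECONDS)
```

Keeping the fallback in `TrainDefaults` rather than in each trainer means a future change to the default touches one function instead of four.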
21 changes: 21 additions & 0 deletions sagemaker-train/tests/unit/train/test_dpo_trainer.py
@@ -357,3 +357,24 @@ def test_process_hyperparameters_with_none_hyperparameters(self):

         # Should not raise an exception
         trainer._process_hyperparameters()
+
+    @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
+    def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
+        """Test DPOTrainer accepts stopping_condition parameter."""
+        from sagemaker.train.configs import StoppingCondition
+
+        mock_validate.return_value = "test-group"
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
+
+        stopping_condition = StoppingCondition(max_runtime_in_seconds=14400)
+        trainer = DPOTrainer(
+            model="test-model",
+            model_package_group="test-group",
+            stopping_condition=stopping_condition
+        )
+
+        assert trainer.stopping_condition == stopping_condition
+        assert trainer.stopping_condition.max_runtime_in_seconds == 14400
22 changes: 22 additions & 0 deletions sagemaker-train/tests/unit/train/test_rlaif_trainer.py
@@ -532,3 +532,25 @@ def test_validate_reward_model_id_none_model(self):

         result = trainer._validate_reward_model_id(None)
         assert result is None
+
+    @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
+    def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
+        """Test RLAIFTrainer accepts stopping_condition parameter."""
+        from sagemaker.train.configs import StoppingCondition
+
+        mock_validate.return_value = "test-group"
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
+
+        stopping_condition = StoppingCondition(max_runtime_in_seconds=86400)
+        trainer = RLAIFTrainer(
+            model="test-model",
+            model_package_group="test-group",
+            reward_model_id="openai.gpt-oss-120b-1:0",
+            stopping_condition=stopping_condition
+        )
+
+        assert trainer.stopping_condition == stopping_condition
+        assert trainer.stopping_condition.max_runtime_in_seconds == 86400
21 changes: 21 additions & 0 deletions sagemaker-train/tests/unit/train/test_rlvr_trainer.py
@@ -360,3 +360,24 @@ def test_process_hyperparameters_with_none_hyperparameters(self):

         # Should not raise an exception
         trainer._process_hyperparameters()
+
+    @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
+    def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
+        """Test RLVRTrainer accepts stopping_condition parameter."""
+        from sagemaker.train.configs import StoppingCondition
+
+        mock_validate.return_value = "test-group"
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
+
+        stopping_condition = StoppingCondition(max_runtime_in_seconds=259200)
+        trainer = RLVRTrainer(
+            model="test-model",
+            model_package_group="test-group",
+            stopping_condition=stopping_condition
+        )
+
+        assert trainer.stopping_condition == stopping_condition
+        assert trainer.stopping_condition.max_runtime_in_seconds == 259200
33 changes: 33 additions & 0 deletions sagemaker-train/tests/unit/train/test_sft_trainer.py
@@ -359,3 +359,36 @@ def test_process_hyperparameters_with_none_hyperparameters(self):

         # Should not raise an exception
         trainer._process_hyperparameters()
+
+    @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
+    def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
+        """Test SFTTrainer accepts stopping_condition parameter."""
+        from sagemaker.train.configs import StoppingCondition
+
+        mock_validate.return_value = "test-group"
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
+
+        stopping_condition = StoppingCondition(max_runtime_in_seconds=7200)
+        trainer = SFTTrainer(
+            model="test-model",
+            model_package_group="test-group",
+            stopping_condition=stopping_condition
+        )
+
+        assert trainer.stopping_condition == stopping_condition
+        assert trainer.stopping_condition.max_runtime_in_seconds == 7200
+
+    @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
+    @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
+    def test_default_stopping_condition_is_none(self, mock_finetuning, mock_validate):
+        """Test SFTTrainer defaults stopping_condition to None."""
+        mock_validate.return_value = "test-group"
+        mock_hyperparams = Mock()
+        mock_hyperparams.to_dict.return_value = {}
+        mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
+
+        trainer = SFTTrainer(model="test-model", model_package_group="test-group")
+        assert trainer.stopping_condition is None