88from sagemaker .train .defaults import TrainDefaults
99from sagemaker .train .utils import _get_unique_name , _get_studio_tags
1010from sagemaker .ai_registry .dataset import DataSet
11+ from sagemaker .train .configs import StoppingCondition
1112from sagemaker .train .common_utils .finetune_utils import (
1213 _get_fine_tuning_options_and_model_arn ,
1314 _validate_and_resolve_model_package_group ,
@@ -98,6 +99,9 @@ class SFTTrainer(BaseTrainer):
9899 The KMS key ID for encrypting training job outputs.
99100 networking (Optional[VpcConfig]):
100101 The VPC configuration for the training job.
102+ stopping_condition (Optional[StoppingCondition]):
103+ The stopping condition to override training runtime limit.
104+ If not specified, defaults to 1 hour max runtime.
101105 """
102106
103107 def __init__ (
@@ -114,6 +118,7 @@ def __init__(
114118 kms_key_id : Optional [str ] = None ,
115119 networking : Optional [VpcConfig ] = None ,
116120 accept_eula : Optional [bool ] = False ,
121+ stopping_condition : Optional [StoppingCondition ] = None ,
117122 ** kwargs ,
118123 ):
119124 super ().__init__ (** kwargs )
@@ -132,6 +137,7 @@ def __init__(
132137 self .s3_output_path = s3_output_path
133138 self .kms_key_id = kms_key_id
134139 self .networking = networking
140+ self .stopping_condition = stopping_condition
135141
136142 # Initialize fine-tuning options with beta session fallback
137143 self .hyperparameters , self ._model_arn , is_gated_model = _get_fine_tuning_options_and_model_arn (self ._model_name ,
@@ -199,6 +205,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
199205 current_training_job_name = _get_unique_name (
200206 self .base_job_name or f"{ self ._model_name } -sft"
201207 )
208+
209+ stopping_condition = TrainDefaults .get_stopping_condition (
210+ stopping_condition = self .stopping_condition
211+ )
202212
203213 logger .info (f"Training Job Name: { current_training_job_name } " )
204214
@@ -252,6 +262,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
252262 hyper_parameters = final_hyperparameters ,
253263 model_package_config = model_package_config ,
254264 vpc_config = vpc_config ,
265+ stopping_condition = stopping_condition ,
255266 session = sagemaker_session .boto_session ,
256267 region = sagemaker_session .boto_session .region_name ,
257268 tags = tags ,
0 commit comments