Skip to content

Commit 1752e7a

Browse files
Fix: Remove default for stopping condition for MC trainer (#5586)
1 parent 2396310 commit 1752e7a

File tree

4 files changed: 88 additions and 80 deletions.

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ class DPOTrainer(BaseTrainer):
9999
The VPC configuration for the training job.
100100
stopping_condition (Optional[StoppingCondition]):
101101
The stopping condition to override training runtime limit.
102-
If not specified, defaults to 1 hour max runtime.
102+
If not specified, uses SageMaker service default (24 hours for serverless training).
103103
"""
104104
def __init__(
105105
self,
@@ -204,10 +204,6 @@ def train(self,
204204
current_training_job_name = _get_unique_name(
205205
self.base_job_name or f"{self._model_name}-dpo"
206206
)
207-
208-
stopping_condition = TrainDefaults.get_stopping_condition(
209-
stopping_condition=self.stopping_condition
210-
)
211207

212208
logger.info(f"Training Job Name: {current_training_job_name}")
213209
print(f"Training Job Name: {current_training_job_name}")
@@ -250,22 +246,28 @@ def train(self,
250246
vpc_config = self.networking if self.networking else None
251247
tags = _get_studio_tags(self._model_name, HUB_NAME)
252248

249+
# Build TrainingJob.create() arguments
250+
create_args = {
251+
"training_job_name": current_training_job_name,
252+
"role_arn": role,
253+
"input_data_config": channels,
254+
"output_data_config": output_config,
255+
"serverless_job_config": serverless_config,
256+
"mlflow_config": mlflow_config,
257+
"hyper_parameters": final_hyperparameters,
258+
"model_package_config": model_package_config,
259+
"vpc_config": vpc_config,
260+
"session": sagemaker_session.boto_session,
261+
"region": sagemaker_session.boto_session.region_name,
262+
"tags": tags,
263+
}
264+
265+
# Only pass stopping_condition if explicitly provided by user
266+
if self.stopping_condition is not None:
267+
create_args["stopping_condition"] = self.stopping_condition
268+
253269
try:
254-
training_job = TrainingJob.create(
255-
training_job_name=current_training_job_name,
256-
role_arn=role,
257-
input_data_config=channels,
258-
output_data_config=output_config,
259-
serverless_job_config=serverless_config,
260-
mlflow_config=mlflow_config,
261-
hyper_parameters=final_hyperparameters,
262-
model_package_config=model_package_config,
263-
vpc_config=vpc_config,
264-
stopping_condition=stopping_condition,
265-
session=sagemaker_session.boto_session,
266-
region=sagemaker_session.boto_session.region_name,
267-
tags=tags,
268-
)
270+
training_job = TrainingJob.create(**create_args)
269271
except Exception as e:
270272
logger.error("Error: %s", e)
271273
raise e

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ class RLAIFTrainer(BaseTrainer):
113113
The VPC configuration for the training job.
114114
stopping_condition (Optional[StoppingCondition]):
115115
The stopping condition to override training runtime limit.
116-
If not specified, defaults to 1 hour max runtime.
116+
If not specified, uses SageMaker service default (24 hours for serverless training).
117117
"""
118118

119119
def __init__(
@@ -221,10 +221,6 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
221221
current_training_job_name = _get_unique_name(
222222
self.base_job_name or f"{self._model_name}-rlaif"
223223
)
224-
225-
stopping_condition = TrainDefaults.get_stopping_condition(
226-
stopping_condition=self.stopping_condition
227-
)
228224

229225
logger.info(f"Training Job Name: {current_training_job_name}")
230226

@@ -269,22 +265,28 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
269265
vpc_config = self.networking if self.networking else None
270266
tags = _get_studio_tags(self._model_name, HUB_NAME)
271267

268+
# Build TrainingJob.create() arguments
269+
create_args = {
270+
"training_job_name": current_training_job_name,
271+
"role_arn": role,
272+
"input_data_config": channels,
273+
"output_data_config": output_config,
274+
"serverless_job_config": serverless_config,
275+
"mlflow_config": mlflow_config,
276+
"hyper_parameters": final_hyperparameters,
277+
"model_package_config": model_package_config,
278+
"vpc_config": vpc_config,
279+
"session": sagemaker_session.boto_session,
280+
"region": sagemaker_session.boto_session.region_name,
281+
"tags": tags,
282+
}
283+
284+
# Only pass stopping_condition if explicitly provided by user
285+
if self.stopping_condition is not None:
286+
create_args["stopping_condition"] = self.stopping_condition
287+
272288
try:
273-
training_job = TrainingJob.create(
274-
training_job_name=current_training_job_name,
275-
role_arn=role,
276-
input_data_config=channels,
277-
output_data_config=output_config,
278-
serverless_job_config=serverless_config,
279-
mlflow_config=mlflow_config,
280-
hyper_parameters=final_hyperparameters,
281-
model_package_config=model_package_config,
282-
vpc_config=vpc_config,
283-
stopping_condition=stopping_condition,
284-
session=sagemaker_session.boto_session,
285-
region=sagemaker_session.boto_session.region_name,
286-
tags=tags,
287-
)
289+
training_job = TrainingJob.create(**create_args)
288290
except Exception as e:
289291
logger.error("Error: %s", e)
290292
raise e

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ class RLVRTrainer(BaseTrainer):
105105
The VPC configuration for the training job.
106106
stopping_condition (Optional[StoppingCondition]):
107107
The stopping condition to override training runtime limit.
108-
If not specified, defaults to 1 hour max runtime.
108+
If not specified, uses SageMaker service default (24 hours for serverless training).
109109
"""
110110

111111
def __init__(
@@ -208,10 +208,6 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
208208
current_training_job_name = _get_unique_name(
209209
self.base_job_name or f"{self._model_name}-rlvr"
210210
)
211-
212-
stopping_condition = TrainDefaults.get_stopping_condition(
213-
stopping_condition=self.stopping_condition
214-
)
215211

216212
logger.info(f"Training Job Name: {current_training_job_name}")
217213

@@ -257,22 +253,28 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
257253
vpc_config = self.networking if self.networking else None
258254
tags = _get_studio_tags(self._model_name, HUB_NAME)
259255

256+
# Build TrainingJob.create() arguments
257+
create_args = {
258+
"training_job_name": current_training_job_name,
259+
"role_arn": role,
260+
"input_data_config": channels,
261+
"output_data_config": output_config,
262+
"serverless_job_config": serverless_config,
263+
"mlflow_config": mlflow_config,
264+
"hyper_parameters": final_hyperparameters,
265+
"model_package_config": model_package_config,
266+
"vpc_config": vpc_config,
267+
"session": sagemaker_session.boto_session,
268+
"region": sagemaker_session.boto_session.region_name,
269+
"tags": tags,
270+
}
271+
272+
# Only pass stopping_condition if explicitly provided by user
273+
if self.stopping_condition is not None:
274+
create_args["stopping_condition"] = self.stopping_condition
275+
260276
try:
261-
training_job = TrainingJob.create(
262-
training_job_name=current_training_job_name,
263-
role_arn=role,
264-
input_data_config=channels,
265-
output_data_config=output_config,
266-
serverless_job_config=serverless_config,
267-
mlflow_config=mlflow_config,
268-
hyper_parameters=final_hyperparameters,
269-
model_package_config=model_package_config,
270-
vpc_config=vpc_config,
271-
stopping_condition=stopping_condition,
272-
session=sagemaker_session.boto_session,
273-
region=sagemaker_session.boto_session.region_name,
274-
tags=tags,
275-
)
277+
training_job = TrainingJob.create(**create_args)
276278
except Exception as e:
277279
logger.error("Error: %s", e)
278280
raise e

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ class SFTTrainer(BaseTrainer):
101101
The VPC configuration for the training job.
102102
stopping_condition (Optional[StoppingCondition]):
103103
The stopping condition to override training runtime limit.
104-
If not specified, defaults to 1 hour max runtime.
104+
If not specified, uses SageMaker service default (24 hours for serverless training).
105105
"""
106106

107107
def __init__(
@@ -205,10 +205,6 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
205205
current_training_job_name = _get_unique_name(
206206
self.base_job_name or f"{self._model_name}-sft"
207207
)
208-
209-
stopping_condition = TrainDefaults.get_stopping_condition(
210-
stopping_condition=self.stopping_condition
211-
)
212208

213209
logger.info(f"Training Job Name: {current_training_job_name}")
214210

@@ -251,22 +247,28 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
251247
vpc_config = self.networking if self.networking else None
252248
tags = _get_studio_tags(self._model_name, HUB_NAME)
253249

250+
# Build TrainingJob.create() arguments
251+
create_args = {
252+
"training_job_name": current_training_job_name,
253+
"role_arn": role,
254+
"input_data_config": channels,
255+
"output_data_config": output_config,
256+
"serverless_job_config": serverless_config,
257+
"mlflow_config": mlflow_config,
258+
"hyper_parameters": final_hyperparameters,
259+
"model_package_config": model_package_config,
260+
"vpc_config": vpc_config,
261+
"session": sagemaker_session.boto_session,
262+
"region": sagemaker_session.boto_session.region_name,
263+
"tags": tags,
264+
}
265+
266+
# Only pass stopping_condition if explicitly provided by user
267+
if self.stopping_condition is not None:
268+
create_args["stopping_condition"] = self.stopping_condition
269+
254270
try:
255-
training_job = TrainingJob.create(
256-
training_job_name=current_training_job_name,
257-
role_arn=role,
258-
input_data_config=channels,
259-
output_data_config=output_config,
260-
serverless_job_config=serverless_config,
261-
mlflow_config=mlflow_config,
262-
hyper_parameters=final_hyperparameters,
263-
model_package_config=model_package_config,
264-
vpc_config=vpc_config,
265-
stopping_condition=stopping_condition,
266-
session=sagemaker_session.boto_session,
267-
region=sagemaker_session.boto_session.region_name,
268-
tags=tags,
269-
)
271+
training_job = TrainingJob.create(**create_args)
270272
except Exception as e:
271273
logger.error("Error: %s", e)
272274
raise e

0 commit comments

Comments
 (0)