Skip to content

Commit 7bd5df9

Browse files
committed
feat(train): Add SequenceLength support for SFT, DPO, RLVR, RLAIF trainers
Add optional sequence_length parameter to all four trainers that enables customers to specify their desired context length for serverless training jobs. The parameter is passed in ServerlessJobConfig for recipe filtering. During trainer initialization, _get_fine_tuning_options_and_model_arn filters recipes by SequenceLength field, picking the smallest recipe with context length >= the requested value. Raises ValueError if no sufficient recipe exists or if recipes lack SequenceLength metadata. Changes: - ServerlessJobConfig: add sequence_length field - _parse_context_length: parse values like '8K' to integers - _get_fine_tuning_options_and_model_arn: filter by SequenceLength - _create_serverless_config: conditionally include sequence_length - SFTTrainer, DPOTrainer, RLVRTrainer, RLAIFTrainer: accept and thread sequence_length through init and train methods - Unit tests for all new functionality
1 parent 4374751 commit 7bd5df9

11 files changed

Lines changed: 525 additions & 68 deletions

File tree

sagemaker-core/src/sagemaker/core/shapes/shapes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9588,6 +9588,7 @@ class ServerlessJobConfig(Base):
95889588
peft: The parameter-efficient fine-tuning configuration.
95899589
evaluation_type: The evaluation job type. Required when serverless job type is Evaluation.
95909590
evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.
9591+
sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
95919592
"""
95929593

95939594
base_model_arn: StrPipeVar
@@ -9597,7 +9598,7 @@ class ServerlessJobConfig(Base):
95979598
peft: Optional[StrPipeVar] = Unassigned()
95989599
evaluation_type: Optional[StrPipeVar] = Unassigned()
95999600
evaluator_arn: Optional[StrPipeVar] = Unassigned()
9600-
9601+
sequence_length: Optional[StrPipeVar] = Unassigned()
96019602

96029603
class MlflowConfig(Base):
96039604
"""

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -318,10 +318,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]:
318318
return None
319319

320320

321-
def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session,
322-
hub_name: Optional[str] = None) -> tuple:
321+
def _parse_context_length(value) -> int:
322+
"""Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192).
323+
324+
Returns 0 if value is None or unparseable.
325+
"""
326+
if not value:
327+
return 0
328+
value = str(value).strip().upper()
329+
if value.endswith("K"):
330+
try:
331+
return int(value[:-1]) * 1024
332+
except ValueError:
333+
return 0
334+
try:
335+
return int(value)
336+
except ValueError:
337+
return 0
338+
339+
340+
def _get_fine_tuning_options_and_model_arn(
341+
model_name: str,
342+
customization_technique: str,
343+
training_type,
344+
sagemaker_session,
345+
sequence_length=None,
346+
hub_name: str = "SageMakerPublicHub"
347+
) -> tuple:
323348
"""Get fine-tuning options and model ARN for given customization technique.
324349
350+
Args:
351+
model_name: Name of the model in the hub.
352+
customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF").
353+
training_type: TrainingType enum or string ("LORA", "FULL").
354+
sagemaker_session: SageMaker session for API calls.
355+
sequence_length: Optional sequence length (e.g., "8K"). When provided, filters
356+
recipes by MaxContextLength >= the requested value.
357+
hub_name: Hub name (default: "SageMakerPublicHub").
358+
325359
Returns:
326360
tuple: (FineTuningOptions, model_arn, is_gated_model)
327361
"""
@@ -362,9 +396,34 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
362396
# Collect override_params from ALL matching recipes (standard + subscription)
363397
recipe = None
364398
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
365-
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
399+
candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")]
366400
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
367-
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
401+
candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")]
402+
else:
403+
candidates = []
404+
405+
# Filter by SequenceLength if sequence_length is provided
406+
if sequence_length and candidates:
407+
requested = _parse_context_length(sequence_length)
408+
candidates_with_context = [r for r in candidates if r.get("SequenceLength")]
409+
if candidates_with_context:
410+
filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested]
411+
if filtered:
412+
filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength")))
413+
recipe = filtered[0]
414+
else:
415+
available = sorted(set(r.get("SequenceLength") for r in candidates_with_context))
416+
raise ValueError(
417+
f"No recipes found with SequenceLength >= {sequence_length}. "
418+
f"Available sequence lengths: {available}"
419+
)
420+
else:
421+
raise ValueError(
422+
f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, "
423+
f"and sequence length:{sequence_length}"
424+
)
425+
elif candidates:
426+
recipe = candidates[0]
368427

369428
if not recipe:
370429
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
@@ -519,7 +578,8 @@ def _resolve_model_and_name(model, sagemaker_session=None):
519578

520579

521580
def _create_serverless_config(model_arn, customization_technique,
522-
training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']:
581+
training_type, accept_eula, evaluator_arn=None,
582+
sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']:
523583
"""Create serverless job configuration for fine-tuning.
524584
525585
Args:
@@ -528,6 +588,7 @@ def _create_serverless_config(model_arn, customization_technique,
528588
training_type: Training type (TrainingType enum or string)
529589
accept_eula: Boolean indicating if EULA is accepted
530590
evaluator_arn: Optional evaluator ARN for RLVR/RLAIF
591+
sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K")
531592
job_type: Type of job (default: "FineTuning")
532593
533594
Returns:
@@ -537,14 +598,18 @@ def _create_serverless_config(model_arn, customization_technique,
537598
else (training_type.value if isinstance(training_type, TrainingType) else training_type)
538599

539600
# Create ServerlessJobConfig using shapes
540-
serverless_config = ServerlessJobConfig(
601+
config_kwargs = dict(
541602
job_type=job_type,
542603
base_model_arn=model_arn,
543604
customization_technique=customization_technique,
544605
peft=peft,
545606
evaluator_arn=evaluator_arn,
546-
accept_eula=accept_eula
607+
accept_eula=accept_eula,
547608
)
609+
if sequence_length is not None:
610+
config_kwargs["sequence_length"] = sequence_length
611+
612+
serverless_config = ServerlessJobConfig(**config_kwargs)
548613

549614
return serverless_config
550615

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ class DPOTrainer(BaseTrainer):
100100
stopping_condition (Optional[StoppingCondition]):
101101
The stopping condition to override training runtime limit.
102102
If not specified, uses SageMaker service default (24 hours for serverless training).
103+
sequence_length (Optional[str]):
104+
The sequence length for the training job. Valid values are
105+
"1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
106+
If not specified, the service will use default recipe selection behavior.
103107
"""
104108
def __init__(
105109
self,
@@ -116,6 +120,7 @@ def __init__(
116120
networking: Optional[VpcConfig] = None,
117121
accept_eula: bool = False,
118122
stopping_condition: Optional[StoppingCondition] = None,
123+
sequence_length: Optional[str] = None,
119124
**kwargs,
120125
):
121126
super().__init__(**kwargs)
@@ -134,16 +139,17 @@ def __init__(
134139
self.kms_key_id = kms_key_id
135140
self.networking = networking
136141
self.stopping_condition = stopping_condition
142+
self.sequence_length = sequence_length
137143

138144
# Initialize fine-tuning options with beta session fallback
139-
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
140-
CustomizationTechnique.DPO.value,
141-
self.training_type,
142-
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
143-
sagemaker_session=self.sagemaker_session
144-
145-
))
146-
145+
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
146+
self._model_name,
147+
CustomizationTechnique.DPO.value,
148+
self.training_type,
149+
self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
150+
sequence_length=self.sequence_length
151+
)
152+
147153
# Process hyperparameters
148154
self._process_hyperparameters()
149155

@@ -227,12 +233,14 @@ def train(self,
227233
kms_key_id=self.kms_key_id
228234
)
229235

230-
serverless_config = _create_serverless_config(model_arn=self._model_arn,
231-
customization_technique=CustomizationTechnique.DPO.value,
232-
training_type=self.training_type,
233-
accept_eula=self.accept_eula,
234-
job_type=JOB_TYPE
235-
)
236+
serverless_config = _create_serverless_config(
237+
model_arn=self._model_arn,
238+
customization_technique=CustomizationTechnique.DPO.value,
239+
training_type=self.training_type,
240+
accept_eula=self.accept_eula,
241+
sequence_length=self.sequence_length,
242+
job_type=JOB_TYPE
243+
)
236244

237245
mlflow_config = _create_mlflow_config(
238246
sagemaker_session,

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class RLAIFTrainer(BaseTrainer):
114114
stopping_condition (Optional[StoppingCondition]):
115115
The stopping condition to override training runtime limit.
116116
If not specified, uses SageMaker service default (24 hours for serverless training).
117+
sequence_length (Optional[str]):
118+
The sequence length for the training job. Valid values are
119+
"1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
120+
If not specified, the service will use default recipe selection behavior.
117121
"""
118122

119123
def __init__(
@@ -135,6 +139,7 @@ def __init__(
135139
networking: Optional[VpcConfig] = None,
136140
accept_eula: bool = False,
137141
stopping_condition: Optional[StoppingCondition] = None,
142+
sequence_length: Optional[str] = None,
138143
**kwargs,
139144
):
140145
super().__init__(**kwargs)
@@ -156,14 +161,16 @@ def __init__(
156161
self.kms_key_id = kms_key_id
157162
self.networking = networking
158163
self.stopping_condition = stopping_condition
164+
self.sequence_length = sequence_length
159165

160166
# Initialize fine-tuning options with beta session fallback
161-
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
162-
CustomizationTechnique.RLAIF.value,
163-
self.training_type,
164-
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
165-
sagemaker_session=self.sagemaker_session
166-
))
167+
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
168+
self._model_name,
169+
CustomizationTechnique.RLAIF.value,
170+
self.training_type,
171+
self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
172+
sequence_length=self.sequence_length
173+
)
167174

168175
# Validate and set EULA acceptance
169176
self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
@@ -242,13 +249,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
242249
)
243250

244251
evaluator_arn = getattr(self, '_evaluator_arn', None)
245-
serverless_config = _create_serverless_config(model_arn=self._model_arn,
246-
customization_technique=CustomizationTechnique.RLAIF.value,
247-
training_type=self.training_type,
248-
accept_eula=self.accept_eula,
249-
evaluator_arn=evaluator_arn,
250-
job_type=JOB_TYPE
251-
)
252+
serverless_config = _create_serverless_config(
253+
model_arn=self._model_arn,
254+
customization_technique=CustomizationTechnique.RLAIF.value,
255+
training_type=self.training_type,
256+
accept_eula=self.accept_eula,
257+
evaluator_arn=evaluator_arn,
258+
sequence_length=self.sequence_length,
259+
job_type=JOB_TYPE
260+
)
252261

253262
mlflow_config = _create_mlflow_config(
254263
sagemaker_session,

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ class RLVRTrainer(BaseTrainer):
106106
stopping_condition (Optional[StoppingCondition]):
107107
The stopping condition to override training runtime limit.
108108
If not specified, uses SageMaker service default (24 hours for serverless training).
109+
sequence_length (Optional[str]):
110+
The sequence length for the training job. Valid values are
111+
"1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K".
112+
If not specified, the service will use default recipe selection behavior.
109113
"""
110114

111115
def __init__(
@@ -126,6 +130,7 @@ def __init__(
126130
networking: Optional[VpcConfig] = None,
127131
accept_eula: bool = False,
128132
stopping_condition: Optional[StoppingCondition] = None,
133+
sequence_length: Optional[str] = None,
129134
**kwargs,
130135
):
131136
super().__init__(**kwargs)
@@ -146,15 +151,17 @@ def __init__(
146151
self.kms_key_id = kms_key_id
147152
self.networking = networking
148153
self.stopping_condition = stopping_condition
154+
self.sequence_length = sequence_length
149155

150156
# Initialize fine-tuning options with beta session fallback
151-
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
152-
CustomizationTechnique.RLVR.value,
153-
self.training_type,
154-
self.sagemaker_session or TrainDefaults.get_sagemaker_session(
155-
sagemaker_session=self.sagemaker_session
156-
))
157-
157+
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
158+
self._model_name,
159+
CustomizationTechnique.RLVR.value,
160+
self.training_type,
161+
self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session),
162+
sequence_length=self.sequence_length
163+
)
164+
158165
# Remove constructor-handled hyperparameters
159166
self._process_hyperparameters()
160167

@@ -233,13 +240,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
233240

234241
# Extract and validate evaluator ARN
235242
evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None
236-
serverless_config = _create_serverless_config(model_arn=self._model_arn,
237-
customization_technique=CustomizationTechnique.RLVR.value,
238-
training_type=self.training_type,
239-
accept_eula=self.accept_eula,
240-
evaluator_arn=evaluator_arn,
241-
job_type=JOB_TYPE
242-
)
243+
serverless_config = _create_serverless_config(
244+
model_arn=self._model_arn,
245+
customization_technique=CustomizationTechnique.RLVR.value,
246+
training_type=self.training_type,
247+
accept_eula=self.accept_eula,
248+
evaluator_arn=evaluator_arn,
249+
sequence_length=self.sequence_length,
250+
job_type=JOB_TYPE
251+
)
243252
mlflow_config = _create_mlflow_config(
244253
sagemaker_session,
245254
mlflow_resource_arn=self.mlflow_resource_arn,

0 commit comments

Comments
 (0)