Skip to content

Commit c90ff8e

Browse files
committed
address PR review: sequence_length as recipe pre-filter and simplify config
- Move sequence_length filtering above recipe selection to reduce recipes_with_template before existing logic runs - Always pass sequence_length to ServerlessJobConfig (no None guard)
1 parent 85299be commit c90ff8e

2 files changed

Lines changed: 14 additions & 25 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -481,30 +481,15 @@ def _get_fine_tuning_options_and_model_arn(
481481
if not recipes_with_template:
482482
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
483483

484-
# Select recipe based on training type
485-
# Collect override_params from ALL matching recipes (standard + subscription)
486-
recipe = None
487-
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
488-
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
489-
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
490-
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
491-
492-
# Override recipe selection when sequence_length is explicitly requested
484+
# Filter by SequenceLength before recipe selection if sequence_length is requested
493485
if sequence_length:
494-
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
495-
candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")]
496-
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
497-
candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")]
498-
else:
499-
candidates = []
500-
501486
requested = _parse_context_length(sequence_length)
502-
candidates_with_context = [r for r in candidates if r.get("SequenceLength")]
487+
candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")]
503488
if candidates_with_context:
504489
filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested]
505490
if filtered:
506491
filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength")))
507-
recipe = filtered[0]
492+
recipes_with_template = filtered
508493
else:
509494
available = sorted(set(r.get("SequenceLength") for r in candidates_with_context))
510495
raise ValueError(
@@ -517,6 +502,14 @@ def _get_fine_tuning_options_and_model_arn(
517502
f"and sequence length:{sequence_length}"
518503
)
519504

505+
# Select recipe based on training type
506+
# Collect override_params from ALL matching recipes (standard + subscription)
507+
recipe = None
508+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
509+
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
510+
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
511+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
512+
520513
if not recipe:
521514
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
522515

@@ -690,18 +683,15 @@ def _create_serverless_config(model_arn, customization_technique,
690683
else (training_type.value if isinstance(training_type, TrainingType) else training_type)
691684

692685
# Create ServerlessJobConfig using shapes
693-
config_kwargs = dict(
686+
serverless_config = ServerlessJobConfig(
694687
job_type=job_type,
695688
base_model_arn=model_arn,
696689
customization_technique=customization_technique,
697690
peft=peft,
698691
evaluator_arn=evaluator_arn,
699692
accept_eula=accept_eula,
693+
sequence_length=sequence_length,
700694
)
701-
if sequence_length is not None:
702-
config_kwargs["sequence_length"] = sequence_length
703-
704-
serverless_config = ServerlessJobConfig(**config_kwargs)
705695

706696
return serverless_config
707697

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,7 @@ def test__create_serverless_config_with_sequence_length(self):
876876
def test__create_serverless_config_without_sequence_length(self):
877877
config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True)
878878

879-
# sequence_length should remain Unassigned (not set), not None
880-
assert isinstance(config.sequence_length, Unassigned)
879+
assert config.sequence_length is None
881880

882881
def test__parse_context_length_with_k_suffix(self):
883882
assert _parse_context_length("8K") == 8192

0 commit comments

Comments
 (0)