Skip to content

Commit 3d36503

Browse files
committed
Address PR review: apply sequence_length as a recipe pre-filter and simplify config construction
- Move sequence_length filtering above recipe selection so that recipes_with_template is narrowed before the existing selection logic runs
- Always pass sequence_length to ServerlessJobConfig (no None guard)
1 parent 1c6f0ef commit 3d36503

2 files changed

Lines changed: 14 additions & 25 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -392,30 +392,15 @@ def _get_fine_tuning_options_and_model_arn(
392392
if not recipes_with_template:
393393
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
394394

395-
# Select recipe based on training type
396-
# Collect override_params from ALL matching recipes (standard + subscription)
397-
recipe = None
398-
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
399-
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
400-
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
401-
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
402-
403-
# Override recipe selection when sequence_length is explicitly requested
395+
# Filter by SequenceLength before recipe selection if sequence_length is requested
404396
if sequence_length:
405-
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
406-
candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")]
407-
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
408-
candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")]
409-
else:
410-
candidates = []
411-
412397
requested = _parse_context_length(sequence_length)
413-
candidates_with_context = [r for r in candidates if r.get("SequenceLength")]
398+
candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")]
414399
if candidates_with_context:
415400
filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested]
416401
if filtered:
417402
filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength")))
418-
recipe = filtered[0]
403+
recipes_with_template = filtered
419404
else:
420405
available = sorted(set(r.get("SequenceLength") for r in candidates_with_context))
421406
raise ValueError(
@@ -428,6 +413,14 @@ def _get_fine_tuning_options_and_model_arn(
428413
f"and sequence length:{sequence_length}"
429414
)
430415

416+
# Select recipe based on training type
417+
# Collect override_params from ALL matching recipes (standard + subscription)
418+
recipe = None
419+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
420+
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
421+
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
422+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
423+
431424
if not recipe:
432425
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
433426

@@ -601,18 +594,15 @@ def _create_serverless_config(model_arn, customization_technique,
601594
else (training_type.value if isinstance(training_type, TrainingType) else training_type)
602595

603596
# Create ServerlessJobConfig using shapes
604-
config_kwargs = dict(
597+
serverless_config = ServerlessJobConfig(
605598
job_type=job_type,
606599
base_model_arn=model_arn,
607600
customization_technique=customization_technique,
608601
peft=peft,
609602
evaluator_arn=evaluator_arn,
610603
accept_eula=accept_eula,
604+
sequence_length=sequence_length,
611605
)
612-
if sequence_length is not None:
613-
config_kwargs["sequence_length"] = sequence_length
614-
615-
serverless_config = ServerlessJobConfig(**config_kwargs)
616606

617607
return serverless_config
618608

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,7 @@ def test__create_serverless_config_with_sequence_length(self):
703703
def test__create_serverless_config_without_sequence_length(self):
704704
config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True)
705705

706-
# sequence_length should remain Unassigned (not set), not None
707-
assert isinstance(config.sequence_length, Unassigned)
706+
assert config.sequence_length is None
708707

709708
def test__parse_context_length_with_k_suffix(self):
710709
assert _parse_context_length("8K") == 8192

0 commit comments

Comments
 (0)