Skip to content

Commit 37a5996

Browse files
committed
refactor: preserve original recipe selection path when sequence_length not provided
Keep the existing `next(...)` logic untouched for the default case (no sequence_length). Only build the candidates list and filter when sequence_length is explicitly requested, ensuring zero behavioral change for existing callers.
1 parent 13e0f49 commit 37a5996

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -396,14 +396,19 @@ def _get_fine_tuning_options_and_model_arn(
396396
# Collect override_params from ALL matching recipes (standard + subscription)
397397
recipe = None
398398
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
399-
candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")]
399+
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
400400
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
401-
candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")]
402-
else:
403-
candidates = []
401+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
402+
403+
# Override recipe selection when sequence_length is explicitly requested
404+
if sequence_length:
405+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
406+
candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")]
407+
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
408+
candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")]
409+
else:
410+
candidates = []
404411

405-
# Filter by SequenceLength if sequence_length is provided
406-
if sequence_length and candidates:
407412
requested = _parse_context_length(sequence_length)
408413
candidates_with_context = [r for r in candidates if r.get("SequenceLength")]
409414
if candidates_with_context:
@@ -422,8 +427,6 @@ def _get_fine_tuning_options_and_model_arn(
422427
f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, "
423428
f"and sequence length:{sequence_length}"
424429
)
425-
elif candidates:
426-
recipe = candidates[0]
427430

428431
if not recipe:
429432
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")

0 commit comments

Comments
 (0)