Skip to content

Commit 69b7533

Browse files
committed
fix(train): fall back to subscription recipes for models without base recipes
1 parent a15a449 commit 69b7533

1 file changed

Lines changed: 62 additions & 34 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 62 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -448,56 +448,84 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
448448
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
449449

450450
# Select recipe based on training type
451-
# Collect override_params from ALL matching recipes (standard + subscription)
451+
# Prefer non-subscription (standard) recipes first, fall back to subscription recipes
452+
# if no standard recipe exists (e.g., Nova Micro v2 only has subscription recipes).
452453
recipe = None
453454
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
454455
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
456+
if not recipe:
457+
recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
455458
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
456459
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
460+
if not recipe:
461+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
457462

458463
if not recipe:
459464
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
460465

461-
# Start with the standard recipe's override_params
466+
# Start with the selected recipe's override_params
462467
options_dict = {}
463468
if recipe.get("SmtjOverrideParamsS3Uri"):
464469
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
465470
s3 = sagemaker_session.boto_session.client("s3")
471+
# Handle {customer_id} placeholder (subscription recipes use access point URIs)
472+
if "{customer_id}" in s3_uri:
473+
account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]
474+
s3_uri = s3_uri.replace("{customer_id}", account_id)
466475
uri_path = s3_uri.replace("s3://", "")
467-
bucket, key = uri_path.split("/", 1)
468-
obj = s3.get_object(Bucket=bucket, Key=key)
469-
options_dict = json.loads(obj["Body"].read())
470-
471-
# Auto-detect and merge subscription recipe's override_params if available
472-
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
473-
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
474-
else:
475-
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
476-
477-
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
476+
# Handle access point ARN URIs
477+
if uri_path.startswith("arn:"):
478+
arn_parts = uri_path.split("/", 2)
479+
bucket = arn_parts[0] + "/" + arn_parts[1]
480+
key = arn_parts[2] if len(arn_parts) > 2 else ""
481+
else:
482+
bucket, key = uri_path.split("/", 1)
478483
try:
479-
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
480-
sub_uri_path = sub_s3_uri.replace("s3://", "")
481-
# Handle access point ARN URIs
482-
if sub_uri_path.startswith("arn:"):
483-
arn_parts = sub_uri_path.split("/", 2)
484-
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
485-
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
486-
else:
487-
sub_bucket, sub_key = sub_uri_path.split("/", 1)
488-
s3_sub = sagemaker_session.boto_session.client("s3")
489-
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
490-
sub_options = json.loads(sub_obj["Body"].read())
491-
# Merge: subscription params into _specs only (don't set defaults)
492-
# This makes them settable but not serialized unless user explicitly sets them
493-
for k, v in sub_options.items():
494-
if k not in options_dict:
495-
v_copy = v.copy() if isinstance(v, dict) else v
496-
if isinstance(v_copy, dict):
497-
v_copy['default'] = None # No default — won't appear in to_dict() unless set
498-
options_dict[k] = v_copy
484+
obj = s3.get_object(Bucket=bucket, Key=key)
485+
options_dict = json.loads(obj["Body"].read())
499486
except Exception as e:
500-
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
487+
if recipe.get("IsSubscriptionModel"):
488+
raise ValueError(
489+
f"Could not access subscription recipe for model '{model_name}'. "
490+
f"This model only provides subscription-based recipes. "
491+
f"Please verify that your account has an active Nova Forge subscription. "
492+
f"Refer: https://docs.aws.amazon.com/sagemaker/latest/dg/nova-forge.html#nova-forge-prereq-access"
493+
) from e
494+
else:
495+
raise
496+
497+
# Auto-detect and merge subscription recipe's override_params if available
498+
# (only needed when the primary recipe is NOT a subscription recipe)
499+
if not recipe.get("IsSubscriptionModel"):
500+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
501+
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
502+
else:
503+
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
504+
505+
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
506+
try:
507+
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
508+
sub_uri_path = sub_s3_uri.replace("s3://", "")
509+
# Handle access point ARN URIs
510+
if sub_uri_path.startswith("arn:"):
511+
arn_parts = sub_uri_path.split("/", 2)
512+
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
513+
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
514+
else:
515+
sub_bucket, sub_key = sub_uri_path.split("/", 1)
516+
s3_sub = sagemaker_session.boto_session.client("s3")
517+
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
518+
sub_options = json.loads(sub_obj["Body"].read())
519+
# Merge: subscription params into _specs only (don't set defaults)
520+
# This makes them settable but not serialized unless user explicitly sets them
521+
for k, v in sub_options.items():
522+
if k not in options_dict:
523+
v_copy = v.copy() if isinstance(v, dict) else v
524+
if isinstance(v_copy, dict):
525+
v_copy['default'] = None # No default — won't appear in to_dict() unless set
526+
options_dict[k] = v_copy
527+
except Exception as e:
528+
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
501529

502530
if options_dict:
503531
return FineTuningOptions(options_dict), model_arn, is_gated_model

0 commit comments

Comments
 (0)