@@ -448,56 +448,84 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
448448 raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } " )
449449
450450 # Select recipe based on training type
451- # Collect override_params from ALL matching recipes (standard + subscription)
451+ # Prefer non-subscription (standard) recipes first, fall back to subscription recipes
452+ # if no standard recipe exists (e.g., Nova Micro v2 only has subscription recipes).
452453 recipe = None
453454 if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
454455 recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
456+ if not recipe :
457+ recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and r .get ("IsSubscriptionModel" )), None )
455458 elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
456459 recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
460+ if not recipe :
461+ recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and r .get ("IsSubscriptionModel" )), None )
457462
458463 if not recipe :
459464 raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } ,training_type:{ training_type } " )
460465
461- # Start with the standard recipe's override_params
466+ # Start with the selected recipe's override_params
462467 options_dict = {}
463468 if recipe .get ("SmtjOverrideParamsS3Uri" ):
464469 s3_uri = recipe ["SmtjOverrideParamsS3Uri" ]
465470 s3 = sagemaker_session .boto_session .client ("s3" )
471+ # Handle {customer_id} placeholder (subscription recipes use access point URIs)
472+ if "{customer_id}" in s3_uri :
473+ account_id = sagemaker_session .boto_session .client ("sts" ).get_caller_identity ()["Account" ]
474+ s3_uri = s3_uri .replace ("{customer_id}" , account_id )
466475 uri_path = s3_uri .replace ("s3://" , "" )
467- bucket , key = uri_path .split ("/" , 1 )
468- obj = s3 .get_object (Bucket = bucket , Key = key )
469- options_dict = json .loads (obj ["Body" ].read ())
470-
471- # Auto-detect and merge subscription recipe's override_params if available
472- if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
473- sub_recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and r .get ("IsSubscriptionModel" )), None )
474- else :
475- sub_recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and r .get ("IsSubscriptionModel" )), None )
476-
477- if sub_recipe and sub_recipe .get ("SmtjOverrideParamsS3Uri" ):
476+ # Handle access point ARN URIs
477+ if uri_path .startswith ("arn:" ):
478+ arn_parts = uri_path .split ("/" , 2 )
479+ bucket = arn_parts [0 ] + "/" + arn_parts [1 ]
480+ key = arn_parts [2 ] if len (arn_parts ) > 2 else ""
481+ else :
482+ bucket , key = uri_path .split ("/" , 1 )
478483 try :
479- sub_s3_uri = sub_recipe ["SmtjOverrideParamsS3Uri" ].replace ("{customer_id}" , sagemaker_session .boto_session .client ("sts" ).get_caller_identity ()["Account" ])
480- sub_uri_path = sub_s3_uri .replace ("s3://" , "" )
481- # Handle access point ARN URIs
482- if sub_uri_path .startswith ("arn:" ):
483- arn_parts = sub_uri_path .split ("/" , 2 )
484- sub_bucket = arn_parts [0 ] + "/" + arn_parts [1 ]
485- sub_key = arn_parts [2 ] if len (arn_parts ) > 2 else ""
486- else :
487- sub_bucket , sub_key = sub_uri_path .split ("/" , 1 )
488- s3_sub = sagemaker_session .boto_session .client ("s3" )
489- sub_obj = s3_sub .get_object (Bucket = sub_bucket , Key = sub_key )
490- sub_options = json .loads (sub_obj ["Body" ].read ())
491- # Merge: subscription params into _specs only (don't set defaults)
492- # This makes them settable but not serialized unless user explicitly sets them
493- for k , v in sub_options .items ():
494- if k not in options_dict :
495- v_copy = v .copy () if isinstance (v , dict ) else v
496- if isinstance (v_copy , dict ):
497- v_copy ['default' ] = None # No default — won't appear in to_dict() unless set
498- options_dict [k ] = v_copy
484+ obj = s3 .get_object (Bucket = bucket , Key = key )
485+ options_dict = json .loads (obj ["Body" ].read ())
499486 except Exception as e :
500- logger .debug (f"Could not fetch subscription recipe override_params: { type (e ).__name__ } : { e } " )
487+ if recipe .get ("IsSubscriptionModel" ):
488+ raise ValueError (
489+ f"Could not access subscription recipe for model '{ model_name } '. "
490+ f"This model only provides subscription-based recipes. "
491+ f"Please verify that your account has an active Nova Forge subscription. "
492+ f"Refer: https://docs.aws.amazon.com/sagemaker/latest/dg/nova-forge.html#nova-forge-prereq-access"
493+ ) from e
494+ else :
495+ raise
496+
497+ # Auto-detect and merge subscription recipe's override_params if available
498+ # (only needed when the primary recipe is NOT a subscription recipe)
499+ if not recipe .get ("IsSubscriptionModel" ):
500+ if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
501+ sub_recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and r .get ("IsSubscriptionModel" )), None )
502+ else :
503+ sub_recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and r .get ("IsSubscriptionModel" )), None )
504+
505+ if sub_recipe and sub_recipe .get ("SmtjOverrideParamsS3Uri" ):
506+ try :
507+ sub_s3_uri = sub_recipe ["SmtjOverrideParamsS3Uri" ].replace ("{customer_id}" , sagemaker_session .boto_session .client ("sts" ).get_caller_identity ()["Account" ])
508+ sub_uri_path = sub_s3_uri .replace ("s3://" , "" )
509+ # Handle access point ARN URIs
510+ if sub_uri_path .startswith ("arn:" ):
511+ arn_parts = sub_uri_path .split ("/" , 2 )
512+ sub_bucket = arn_parts [0 ] + "/" + arn_parts [1 ]
513+ sub_key = arn_parts [2 ] if len (arn_parts ) > 2 else ""
514+ else :
515+ sub_bucket , sub_key = sub_uri_path .split ("/" , 1 )
516+ s3_sub = sagemaker_session .boto_session .client ("s3" )
517+ sub_obj = s3_sub .get_object (Bucket = sub_bucket , Key = sub_key )
518+ sub_options = json .loads (sub_obj ["Body" ].read ())
519+ # Merge: subscription params into _specs only (don't set defaults)
520+ # This makes them settable but not serialized unless user explicitly sets them
521+ for k , v in sub_options .items ():
522+ if k not in options_dict :
523+ v_copy = v .copy () if isinstance (v , dict ) else v
524+ if isinstance (v_copy , dict ):
525+ v_copy ['default' ] = None # No default — won't appear in to_dict() unless set
526+ options_dict [k ] = v_copy
527+ except Exception as e :
528+ logger .debug (f"Could not fetch subscription recipe override_params: { type (e ).__name__ } : { e } " )
501529
502530 if options_dict :
503531 return FineTuningOptions (options_dict ), model_arn , is_gated_model
0 commit comments