Skip to content

Commit 7cdd30f

Browse files
authored
fix(train): fall back to subscription recipes for models without base recipes (#5946)
* fix(train): fall back to subscription recipes for models without base recipes * Adding unit testing for nova subscription recipe fix * Change failed retrieval of subscription recipes log to warning * Updating unit tests
1 parent 9101cef commit 7cdd30f

2 files changed

Lines changed: 281 additions & 34 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -448,56 +448,84 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
448448
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
449449

450450
# Select recipe based on training type
451-
# Collect override_params from ALL matching recipes (standard + subscription)
451+
# Prefer non-subscription (standard) recipes first, fall back to subscription recipes
452+
# if no standard recipe exists (some models only have subscription recipes).
452453
recipe = None
453454
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
454455
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
456+
if not recipe:
457+
recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
455458
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
456459
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
460+
if not recipe:
461+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
457462

458463
if not recipe:
459464
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
460465

461-
# Start with the standard recipe's override_params
466+
# Start with the selected recipe's override_params
462467
options_dict = {}
463468
if recipe.get("SmtjOverrideParamsS3Uri"):
464469
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
465470
s3 = sagemaker_session.boto_session.client("s3")
471+
# Handle {customer_id} placeholder (subscription recipes use access point URIs)
472+
if "{customer_id}" in s3_uri:
473+
account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]
474+
s3_uri = s3_uri.replace("{customer_id}", account_id)
466475
uri_path = s3_uri.replace("s3://", "")
467-
bucket, key = uri_path.split("/", 1)
468-
obj = s3.get_object(Bucket=bucket, Key=key)
469-
options_dict = json.loads(obj["Body"].read())
470-
471-
# Auto-detect and merge subscription recipe's override_params if available
472-
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
473-
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
474-
else:
475-
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
476-
477-
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
476+
# Handle access point ARN URIs
477+
if uri_path.startswith("arn:"):
478+
arn_parts = uri_path.split("/", 2)
479+
bucket = arn_parts[0] + "/" + arn_parts[1]
480+
key = arn_parts[2] if len(arn_parts) > 2 else ""
481+
else:
482+
bucket, key = uri_path.split("/", 1)
478483
try:
479-
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
480-
sub_uri_path = sub_s3_uri.replace("s3://", "")
481-
# Handle access point ARN URIs
482-
if sub_uri_path.startswith("arn:"):
483-
arn_parts = sub_uri_path.split("/", 2)
484-
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
485-
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
486-
else:
487-
sub_bucket, sub_key = sub_uri_path.split("/", 1)
488-
s3_sub = sagemaker_session.boto_session.client("s3")
489-
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
490-
sub_options = json.loads(sub_obj["Body"].read())
491-
# Merge: subscription params into _specs only (don't set defaults)
492-
# This makes them settable but not serialized unless user explicitly sets them
493-
for k, v in sub_options.items():
494-
if k not in options_dict:
495-
v_copy = v.copy() if isinstance(v, dict) else v
496-
if isinstance(v_copy, dict):
497-
v_copy['default'] = None # No default — won't appear in to_dict() unless set
498-
options_dict[k] = v_copy
484+
obj = s3.get_object(Bucket=bucket, Key=key)
485+
options_dict = json.loads(obj["Body"].read())
499486
except Exception as e:
500-
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
487+
if recipe.get("IsSubscriptionModel"):
488+
raise ValueError(
489+
f"Could not access subscription recipe for model '{model_name}'. "
490+
f"This model only provides subscription-based recipes. "
491+
f"Please verify that your account has an active Nova Forge subscription. "
492+
f"Refer: https://docs.aws.amazon.com/sagemaker/latest/dg/nova-forge.html#nova-forge-prereq-access"
493+
) from e
494+
else:
495+
raise
496+
497+
# Auto-detect and merge subscription recipe's override_params if available
498+
# (only needed when the primary recipe is NOT a subscription recipe)
499+
if not recipe.get("IsSubscriptionModel"):
500+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
501+
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
502+
else:
503+
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
504+
505+
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
506+
try:
507+
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
508+
sub_uri_path = sub_s3_uri.replace("s3://", "")
509+
# Handle access point ARN URIs
510+
if sub_uri_path.startswith("arn:"):
511+
arn_parts = sub_uri_path.split("/", 2)
512+
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
513+
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
514+
else:
515+
sub_bucket, sub_key = sub_uri_path.split("/", 1)
516+
s3_sub = sagemaker_session.boto_session.client("s3")
517+
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
518+
sub_options = json.loads(sub_obj["Body"].read())
519+
# Merge: subscription params into _specs only (don't set defaults)
520+
# This makes them settable but not serialized unless user explicitly sets them
521+
for k, v in sub_options.items():
522+
if k not in options_dict:
523+
v_copy = v.copy() if isinstance(v, dict) else v
524+
if isinstance(v_copy, dict):
525+
v_copy['default'] = None # No default — won't appear in to_dict() unless set
526+
options_dict[k] = v_copy
527+
except Exception as e:
528+
logger.warning(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
501529

502530
if options_dict:
503531
return FineTuningOptions(options_dict), model_arn, is_gated_model

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,222 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
864864
# Should still have standard params, just not datamix ones
865865
assert "max_steps" in options._specs
866866
assert "customer_data_percent" not in options._specs
867+
868+
869+
class TestSubscriptionOnlyModelFallback:
870+
"""Tests for models that only have subscription recipes."""
871+
872+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
873+
def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
874+
"""When no non-subscription LORA recipe exists, falls back to subscription recipe."""
875+
mock_session = Mock()
876+
mock_session.boto_session.region_name = "us-east-1"
877+
mock_s3 = Mock()
878+
mock_sts = Mock()
879+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
880+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
881+
882+
# Only subscription recipes exist for this model
883+
mock_get_hub_content.return_value = {
884+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
885+
'hub_content_document': {
886+
"GatedBucket": False,
887+
"RecipeCollection": [
888+
{
889+
"CustomizationTechnique": "SFT",
890+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
891+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
892+
"Peft": True,
893+
"IsSubscriptionModel": True,
894+
"Name": "subscription_model_sft_lora"
895+
},
896+
{
897+
"CustomizationTechnique": "SFT",
898+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
899+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json",
900+
"Peft": False,
901+
"IsSubscriptionModel": True,
902+
"Name": "subscription_model_sft_full"
903+
}
904+
]
905+
}
906+
}
907+
908+
sub_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
909+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
910+
911+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
912+
"subscription-only-model", "SFT", "LORA", mock_session,
913+
)
914+
915+
assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model"
916+
assert "max_steps" in options._specs
917+
assert is_gated is False
918+
919+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
920+
def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content):
921+
"""When no non-subscription FULL recipe exists, falls back to subscription recipe."""
922+
mock_session = Mock()
923+
mock_session.boto_session.region_name = "us-east-1"
924+
mock_s3 = Mock()
925+
mock_sts = Mock()
926+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
927+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
928+
929+
mock_get_hub_content.return_value = {
930+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
931+
'hub_content_document': {
932+
"GatedBucket": False,
933+
"RecipeCollection": [
934+
{
935+
"CustomizationTechnique": "SFT",
936+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
937+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json",
938+
"Peft": False,
939+
"IsSubscriptionModel": True,
940+
"Name": "subscription_model_sft_full"
941+
}
942+
]
943+
}
944+
}
945+
946+
sub_params = json.dumps({"learning_rate": {"type": "float", "required": True, "default": 5e-6}})
947+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
948+
949+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
950+
"subscription-only-model", "SFT", "FULL", mock_session,
951+
)
952+
953+
assert "learning_rate" in options._specs
954+
955+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
956+
def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get_hub_content):
957+
"""When subscription recipe download fails (AccessDenied), raises actionable ValueError."""
958+
mock_session = Mock()
959+
mock_session.boto_session.region_name = "us-east-1"
960+
mock_s3 = Mock()
961+
mock_sts = Mock()
962+
mock_sts.get_caller_identity.return_value = {"Account": "999999999999"}
963+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
964+
965+
mock_get_hub_content.return_value = {
966+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
967+
'hub_content_document': {
968+
"GatedBucket": False,
969+
"RecipeCollection": [
970+
{
971+
"CustomizationTechnique": "SFT",
972+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
973+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
974+
"Peft": True,
975+
"IsSubscriptionModel": True,
976+
"Name": "subscription_model_sft_lora"
977+
}
978+
]
979+
}
980+
}
981+
982+
# Simulate AccessDenied from S3 access point
983+
from botocore.exceptions import ClientError
984+
mock_s3.get_object.side_effect = ClientError(
985+
{"Error": {"Code": "AccessDenied", "Message": "Could not access through this access point"}},
986+
"GetObject"
987+
)
988+
989+
with pytest.raises(ValueError) as exc_info:
990+
_get_fine_tuning_options_and_model_arn(
991+
"subscription-only-model", "SFT", "LORA", mock_session,
992+
)
993+
994+
error_msg = str(exc_info.value)
995+
assert "subscription" in error_msg.lower()
996+
assert "subscription-only-model" in error_msg
997+
998+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
999+
def test_non_subscription_recipe_preferred_over_subscription(self, mock_get_hub_content):
1000+
"""When both standard and subscription recipes exist, standard is selected as primary."""
1001+
mock_session = Mock()
1002+
mock_session.boto_session.region_name = "us-east-1"
1003+
mock_s3 = Mock()
1004+
mock_sts = Mock()
1005+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
1006+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
1007+
1008+
mock_get_hub_content.return_value = {
1009+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-lite-v2",
1010+
'hub_content_document': {
1011+
"GatedBucket": False,
1012+
"RecipeCollection": [
1013+
{
1014+
"CustomizationTechnique": "SFT",
1015+
"SmtjRecipeTemplateS3Uri": "s3://bucket/standard_template.yaml",
1016+
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
1017+
"Peft": True,
1018+
"Name": "nova_lite_v2_sft_lora"
1019+
},
1020+
{
1021+
"CustomizationTechnique": "SFT",
1022+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
1023+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
1024+
"Peft": True,
1025+
"IsSubscriptionModel": True,
1026+
"Name": "nova_lite_v2_sft_lora_datamix"
1027+
}
1028+
]
1029+
}
1030+
}
1031+
1032+
standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
1033+
datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})
1034+
mock_s3.get_object.side_effect = [
1035+
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
1036+
{"Body": Mock(read=Mock(return_value=datamix_params.encode()))},
1037+
]
1038+
1039+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
1040+
"nova-textgeneration-lite-v2", "SFT", "LORA", mock_session,
1041+
)
1042+
1043+
# Standard recipe's params should be loaded as primary
1044+
assert "max_steps" in options._specs
1045+
assert options._specs["max_steps"]["default"] == 100
1046+
# Subscription params merged with None defaults
1047+
assert "customer_data_percent" in options._specs
1048+
assert options._specs["customer_data_percent"]["default"] is None
1049+
1050+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
1051+
def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content):
1052+
"""When primary recipe IS a subscription recipe, overlay merge is skipped."""
1053+
mock_session = Mock()
1054+
mock_session.boto_session.region_name = "us-east-1"
1055+
mock_s3 = Mock()
1056+
mock_sts = Mock()
1057+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
1058+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
1059+
1060+
mock_get_hub_content.return_value = {
1061+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
1062+
'hub_content_document': {
1063+
"GatedBucket": False,
1064+
"RecipeCollection": [
1065+
{
1066+
"CustomizationTechnique": "SFT",
1067+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
1068+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
1069+
"Peft": True,
1070+
"IsSubscriptionModel": True,
1071+
"Name": "subscription_model_sft_lora"
1072+
}
1073+
]
1074+
}
1075+
}
1076+
1077+
sub_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
1078+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
1079+
1080+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
1081+
"subscription-only-model", "SFT", "LORA", mock_session,
1082+
)
1083+
1084+
# S3 get_object should only be called once (primary recipe), not twice (no overlay merge)
1085+
assert mock_s3.get_object.call_count == 1

0 commit comments

Comments
 (0)