@@ -864,3 +864,222 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
864864 # Should still have standard params, just not datamix ones
865865 assert "max_steps" in options ._specs
866866 assert "customer_data_percent" not in options ._specs
867+
868+
class TestSubscriptionOnlyModelFallback:
    """Tests for models that only have subscription recipes (e.g., Nova Micro v2).

    Every test patches ``_get_hub_content_metadata`` to return a hand-built
    recipe collection and routes ``boto_session.client`` to mock S3/STS
    clients, so no AWS call is made. The shared fixtures are built by the
    private helpers below; pytest does not collect them because their names
    do not start with ``test``.
    """

    _HUB_METADATA_TARGET = 'sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata'
    _MICRO_V2_ARN = "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2"
    # NOTE: ``{customer_id}`` is a literal placeholder inside the URI, not a
    # Python format field, so this must stay a plain (non-f) string.
    _ACCESS_POINT_PREFIX = "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/"

    @staticmethod
    def _make_session(account_id="123456789012"):
        """Build a mocked SageMaker session.

        Args:
            account_id: Value reported by the mocked STS ``get_caller_identity``.

        Returns:
            ``(mock_session, mock_s3)``. ``mock_session.boto_session.client``
            hands back ``mock_s3`` for the ``"s3"`` service and the STS mock
            for every other service name.
        """
        mock_session = Mock()
        mock_session.boto_session.region_name = "us-east-1"
        mock_s3 = Mock()
        mock_sts = Mock()
        mock_sts.get_caller_identity.return_value = {"Account": account_id}
        mock_session.boto_session.client.side_effect = (
            lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
        )
        return mock_session, mock_s3

    @classmethod
    def _subscription_recipe(cls, name, *, peft, params_file):
        """Return an SFT recipe entry flagged ``IsSubscriptionModel`` whose URIs use the access point."""
        return {
            "CustomizationTechnique": "SFT",
            "SmtjRecipeTemplateS3Uri": cls._ACCESS_POINT_PREFIX + "template.yaml",
            "SmtjOverrideParamsS3Uri": cls._ACCESS_POINT_PREFIX + params_file,
            "Peft": peft,
            "IsSubscriptionModel": True,
            "Name": name,
        }

    @staticmethod
    def _hub_content(model_arn, recipes):
        """Return the dict the patched ``_get_hub_content_metadata`` should yield (non-gated)."""
        return {
            'hub_content_arn': model_arn,
            'hub_content_document': {
                "GatedBucket": False,
                "RecipeCollection": recipes,
            },
        }

    @staticmethod
    def _s3_response(params):
        """Wrap ``params`` as a fake ``get_object`` response whose body reads as JSON bytes."""
        return {"Body": Mock(read=Mock(return_value=json.dumps(params).encode()))}

    @patch(_HUB_METADATA_TARGET)
    def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
        """When no non-subscription LORA recipe exists, falls back to subscription recipe."""
        mock_session, mock_s3 = self._make_session()

        # Only subscription recipes (like Nova Micro v2)
        mock_get_hub_content.return_value = self._hub_content(
            self._MICRO_V2_ARN,
            [
                self._subscription_recipe(
                    "nova_micro_v2_sft_lora", peft=True, params_file="lora_params.json"
                ),
                self._subscription_recipe(
                    "nova_micro_v2_sft_full", peft=False, params_file="full_params.json"
                ),
            ],
        )
        mock_s3.get_object.return_value = self._s3_response(
            {"max_steps": {"type": "integer", "required": True, "default": 100}}
        )

        options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
            "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
        )

        assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2"
        assert "max_steps" in options._specs
        assert is_gated is False

    @patch(_HUB_METADATA_TARGET)
    def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content):
        """When no non-subscription FULL recipe exists, falls back to subscription recipe."""
        mock_session, mock_s3 = self._make_session()

        mock_get_hub_content.return_value = self._hub_content(
            self._MICRO_V2_ARN,
            [
                self._subscription_recipe(
                    "nova_micro_v2_sft_full", peft=False, params_file="full_params.json"
                ),
            ],
        )
        mock_s3.get_object.return_value = self._s3_response(
            {"learning_rate": {"type": "float", "required": True, "default": 5e-6}}
        )

        options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
            "nova-textgeneration-micro-v2", "SFT", "FULL", mock_session,
        )

        assert "learning_rate" in options._specs
        # Same ARN / gating checks as the LORA fallback test, which uses the
        # same hub metadata shape; previously these values were unpacked but
        # never verified.
        assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2"
        assert is_gated is False

    @patch(_HUB_METADATA_TARGET)
    def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get_hub_content):
        """When subscription recipe download fails (AccessDenied), raises actionable ValueError."""
        # Imported locally, as in the original test, so the rest of the class
        # does not depend on botocore being importable at module load.
        from botocore.exceptions import ClientError

        mock_session, mock_s3 = self._make_session(account_id="999999999999")

        mock_get_hub_content.return_value = self._hub_content(
            self._MICRO_V2_ARN,
            [
                self._subscription_recipe(
                    "nova_micro_v2_sft_lora", peft=True, params_file="lora_params.json"
                ),
            ],
        )

        # Simulate AccessDenied from S3 access point
        mock_s3.get_object.side_effect = ClientError(
            {"Error": {"Code": "AccessDenied", "Message": "Could not access through this access point"}},
            "GetObject",
        )

        with pytest.raises(ValueError) as exc_info:
            _get_fine_tuning_options_and_model_arn(
                "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
            )

        error_msg = str(exc_info.value)
        assert "subscription" in error_msg.lower()
        assert "nova-textgeneration-micro-v2" in error_msg

    @patch(_HUB_METADATA_TARGET)
    def test_non_subscription_recipe_preferred_over_subscription(self, mock_get_hub_content):
        """When both standard and subscription recipes exist, standard is selected as primary."""
        mock_session, mock_s3 = self._make_session()

        mock_get_hub_content.return_value = self._hub_content(
            "arn:aws:sagemaker:us-east-1:123456789012:model/nova-lite-v2",
            [
                # Standard recipe: plain bucket URIs and no IsSubscriptionModel key.
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/standard_template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
                    "Peft": True,
                    "Name": "nova_lite_v2_sft_lora",
                },
                self._subscription_recipe(
                    "nova_lite_v2_sft_lora_datamix", peft=True, params_file="lora_params.json"
                ),
            ],
        )

        # get_object responses are consumed in call order: the first call gets
        # the standard params, the second the datamix (subscription) params.
        mock_s3.get_object.side_effect = [
            self._s3_response(
                {"max_steps": {"type": "integer", "required": True, "default": 100}}
            ),
            self._s3_response(
                {"customer_data_percent": {"type": "integer", "required": False, "default": 50}}
            ),
        ]

        options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
            "nova-textgeneration-lite-v2", "SFT", "LORA", mock_session,
        )

        # Standard recipe's params should be loaded as primary
        assert "max_steps" in options._specs
        assert options._specs["max_steps"]["default"] == 100
        # Subscription params merged with None defaults
        assert "customer_data_percent" in options._specs
        assert options._specs["customer_data_percent"]["default"] is None

    @patch(_HUB_METADATA_TARGET)
    def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content):
        """When primary recipe IS a subscription recipe, overlay merge is skipped."""
        mock_session, mock_s3 = self._make_session()

        mock_get_hub_content.return_value = self._hub_content(
            self._MICRO_V2_ARN,
            [
                self._subscription_recipe(
                    "nova_micro_v2_sft_lora", peft=True, params_file="lora_params.json"
                ),
            ],
        )
        mock_s3.get_object.return_value = self._s3_response(
            {"max_steps": {"type": "integer", "required": True, "default": 100}}
        )

        # Return value is intentionally ignored; only the S3 call count matters here.
        _get_fine_tuning_options_and_model_arn(
            "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
        )

        # S3 get_object should only be called once (primary recipe), not twice (no overlay merge)
        assert mock_s3.get_object.call_count == 1
0 commit comments