Skip to content

Commit 667723b

Browse files
committed
Adding unit testing for nova subscription recipe fix
1 parent 69b7533 commit 667723b

1 file changed

Lines changed: 219 additions & 0 deletions

File tree

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,222 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
864864
# Should still have standard params, just not datamix ones
865865
assert "max_steps" in options._specs
866866
assert "customer_data_percent" not in options._specs
867+
868+
869+
class TestSubscriptionOnlyModelFallback:
870+
"""Tests for models that only have subscription recipes (e.g., Nova Micro v2)."""
871+
872+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
873+
def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
874+
"""When no non-subscription LORA recipe exists, falls back to subscription recipe."""
875+
mock_session = Mock()
876+
mock_session.boto_session.region_name = "us-east-1"
877+
mock_s3 = Mock()
878+
mock_sts = Mock()
879+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
880+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
881+
882+
# Only subscription recipes (like Nova Micro v2)
883+
mock_get_hub_content.return_value = {
884+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
885+
'hub_content_document': {
886+
"GatedBucket": False,
887+
"RecipeCollection": [
888+
{
889+
"CustomizationTechnique": "SFT",
890+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
891+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
892+
"Peft": True,
893+
"IsSubscriptionModel": True,
894+
"Name": "nova_micro_v2_sft_lora"
895+
},
896+
{
897+
"CustomizationTechnique": "SFT",
898+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
899+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json",
900+
"Peft": False,
901+
"IsSubscriptionModel": True,
902+
"Name": "nova_micro_v2_sft_full"
903+
}
904+
]
905+
}
906+
}
907+
908+
sub_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
909+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
910+
911+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
912+
"nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
913+
)
914+
915+
assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2"
916+
assert "max_steps" in options._specs
917+
assert is_gated is False
918+
919+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
920+
def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content):
921+
"""When no non-subscription FULL recipe exists, falls back to subscription recipe."""
922+
mock_session = Mock()
923+
mock_session.boto_session.region_name = "us-east-1"
924+
mock_s3 = Mock()
925+
mock_sts = Mock()
926+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
927+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
928+
929+
mock_get_hub_content.return_value = {
930+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
931+
'hub_content_document': {
932+
"GatedBucket": False,
933+
"RecipeCollection": [
934+
{
935+
"CustomizationTechnique": "SFT",
936+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
937+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json",
938+
"Peft": False,
939+
"IsSubscriptionModel": True,
940+
"Name": "nova_micro_v2_sft_full"
941+
}
942+
]
943+
}
944+
}
945+
946+
sub_params = json.dumps({"learning_rate": {"type": "float", "required": True, "default": 5e-6}})
947+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
948+
949+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
950+
"nova-textgeneration-micro-v2", "SFT", "FULL", mock_session,
951+
)
952+
953+
assert "learning_rate" in options._specs
954+
955+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
956+
def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get_hub_content):
957+
"""When subscription recipe download fails (AccessDenied), raises actionable ValueError."""
958+
mock_session = Mock()
959+
mock_session.boto_session.region_name = "us-east-1"
960+
mock_s3 = Mock()
961+
mock_sts = Mock()
962+
mock_sts.get_caller_identity.return_value = {"Account": "999999999999"}
963+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
964+
965+
mock_get_hub_content.return_value = {
966+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
967+
'hub_content_document': {
968+
"GatedBucket": False,
969+
"RecipeCollection": [
970+
{
971+
"CustomizationTechnique": "SFT",
972+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
973+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
974+
"Peft": True,
975+
"IsSubscriptionModel": True,
976+
"Name": "nova_micro_v2_sft_lora"
977+
}
978+
]
979+
}
980+
}
981+
982+
# Simulate AccessDenied from S3 access point
983+
from botocore.exceptions import ClientError
984+
mock_s3.get_object.side_effect = ClientError(
985+
{"Error": {"Code": "AccessDenied", "Message": "Could not access through this access point"}},
986+
"GetObject"
987+
)
988+
989+
with pytest.raises(ValueError) as exc_info:
990+
_get_fine_tuning_options_and_model_arn(
991+
"nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
992+
)
993+
994+
error_msg = str(exc_info.value)
995+
assert "subscription" in error_msg.lower()
996+
assert "nova-textgeneration-micro-v2" in error_msg
997+
998+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
999+
def test_non_subscription_recipe_preferred_over_subscription(self, mock_get_hub_content):
1000+
"""When both standard and subscription recipes exist, standard is selected as primary."""
1001+
mock_session = Mock()
1002+
mock_session.boto_session.region_name = "us-east-1"
1003+
mock_s3 = Mock()
1004+
mock_sts = Mock()
1005+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
1006+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
1007+
1008+
mock_get_hub_content.return_value = {
1009+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-lite-v2",
1010+
'hub_content_document': {
1011+
"GatedBucket": False,
1012+
"RecipeCollection": [
1013+
{
1014+
"CustomizationTechnique": "SFT",
1015+
"SmtjRecipeTemplateS3Uri": "s3://bucket/standard_template.yaml",
1016+
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
1017+
"Peft": True,
1018+
"Name": "nova_lite_v2_sft_lora"
1019+
},
1020+
{
1021+
"CustomizationTechnique": "SFT",
1022+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
1023+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
1024+
"Peft": True,
1025+
"IsSubscriptionModel": True,
1026+
"Name": "nova_lite_v2_sft_lora_datamix"
1027+
}
1028+
]
1029+
}
1030+
}
1031+
1032+
standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
1033+
datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})
1034+
mock_s3.get_object.side_effect = [
1035+
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
1036+
{"Body": Mock(read=Mock(return_value=datamix_params.encode()))},
1037+
]
1038+
1039+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
1040+
"nova-textgeneration-lite-v2", "SFT", "LORA", mock_session,
1041+
)
1042+
1043+
# Standard recipe's params should be loaded as primary
1044+
assert "max_steps" in options._specs
1045+
assert options._specs["max_steps"]["default"] == 100
1046+
# Subscription params merged with None defaults
1047+
assert "customer_data_percent" in options._specs
1048+
assert options._specs["customer_data_percent"]["default"] is None
1049+
1050+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
1051+
def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content):
1052+
"""When primary recipe IS a subscription recipe, overlay merge is skipped."""
1053+
mock_session = Mock()
1054+
mock_session.boto_session.region_name = "us-east-1"
1055+
mock_s3 = Mock()
1056+
mock_sts = Mock()
1057+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
1058+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
1059+
1060+
mock_get_hub_content.return_value = {
1061+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
1062+
'hub_content_document': {
1063+
"GatedBucket": False,
1064+
"RecipeCollection": [
1065+
{
1066+
"CustomizationTechnique": "SFT",
1067+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
1068+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
1069+
"Peft": True,
1070+
"IsSubscriptionModel": True,
1071+
"Name": "nova_micro_v2_sft_lora"
1072+
}
1073+
]
1074+
}
1075+
}
1076+
1077+
sub_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
1078+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
1079+
1080+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
1081+
"nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
1082+
)
1083+
1084+
# S3 get_object should only be called once (primary recipe), not twice (no overlay merge)
1085+
assert mock_s3.get_object.call_count == 1

0 commit comments

Comments
 (0)