Skip to content

Commit f26e1ae

Browse files
committed
Updating unit tests
1 parent 7edc551 commit f26e1ae

2 files changed

Lines changed: 18 additions & 18 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
449449

450450
# Select recipe based on training type
451451
# Prefer non-subscription (standard) recipes first, fall back to subscription recipes
452-
# if no standard recipe exists (e.g., Nova Micro v2 only has subscription recipes).
452+
# if no standard recipe exists (some models only have subscription recipes).
453453
recipe = None
454454
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
455455
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
867867

868868

869869
class TestSubscriptionOnlyModelFallback:
870-
"""Tests for models that only have subscription recipes (e.g., Nova Micro v2)."""
870+
"""Tests for models that only have subscription recipes."""
871871

872872
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
873873
def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
@@ -879,9 +879,9 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
879879
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
880880
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
881881

882-
# Only subscription recipes (like Nova Micro v2)
882+
# Only subscription recipes exist for this model
883883
mock_get_hub_content.return_value = {
884-
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
884+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
885885
'hub_content_document': {
886886
"GatedBucket": False,
887887
"RecipeCollection": [
@@ -891,15 +891,15 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
891891
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
892892
"Peft": True,
893893
"IsSubscriptionModel": True,
894-
"Name": "nova_micro_v2_sft_lora"
894+
"Name": "subscription_model_sft_lora"
895895
},
896896
{
897897
"CustomizationTechnique": "SFT",
898898
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml",
899899
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json",
900900
"Peft": False,
901901
"IsSubscriptionModel": True,
902-
"Name": "nova_micro_v2_sft_full"
902+
"Name": "subscription_model_sft_full"
903903
}
904904
]
905905
}
@@ -909,10 +909,10 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content):
909909
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
910910

911911
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
912-
"nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
912+
"subscription-only-model", "SFT", "LORA", mock_session,
913913
)
914914

915-
assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2"
915+
assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model"
916916
assert "max_steps" in options._specs
917917
assert is_gated is False
918918

@@ -927,7 +927,7 @@ def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content):
927927
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
928928

929929
mock_get_hub_content.return_value = {
930-
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
930+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
931931
'hub_content_document': {
932932
"GatedBucket": False,
933933
"RecipeCollection": [
@@ -937,7 +937,7 @@ def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content):
937937
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json",
938938
"Peft": False,
939939
"IsSubscriptionModel": True,
940-
"Name": "nova_micro_v2_sft_full"
940+
"Name": "subscription_model_sft_full"
941941
}
942942
]
943943
}
@@ -947,7 +947,7 @@ def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content):
947947
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
948948

949949
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
950-
"nova-textgeneration-micro-v2", "SFT", "FULL", mock_session,
950+
"subscription-only-model", "SFT", "FULL", mock_session,
951951
)
952952

953953
assert "learning_rate" in options._specs
@@ -963,7 +963,7 @@ def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get
963963
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
964964

965965
mock_get_hub_content.return_value = {
966-
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
966+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
967967
'hub_content_document': {
968968
"GatedBucket": False,
969969
"RecipeCollection": [
@@ -973,7 +973,7 @@ def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get
973973
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
974974
"Peft": True,
975975
"IsSubscriptionModel": True,
976-
"Name": "nova_micro_v2_sft_lora"
976+
"Name": "subscription_model_sft_lora"
977977
}
978978
]
979979
}
@@ -988,12 +988,12 @@ def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get
988988

989989
with pytest.raises(ValueError) as exc_info:
990990
_get_fine_tuning_options_and_model_arn(
991-
"nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
991+
"subscription-only-model", "SFT", "LORA", mock_session,
992992
)
993993

994994
error_msg = str(exc_info.value)
995995
assert "subscription" in error_msg.lower()
996-
assert "nova-textgeneration-micro-v2" in error_msg
996+
assert "subscription-only-model" in error_msg
997997

998998
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
999999
def test_non_subscription_recipe_preferred_over_subscription(self, mock_get_hub_content):
@@ -1058,7 +1058,7 @@ def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content):
10581058
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
10591059

10601060
mock_get_hub_content.return_value = {
1061-
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2",
1061+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model",
10621062
'hub_content_document': {
10631063
"GatedBucket": False,
10641064
"RecipeCollection": [
@@ -1068,7 +1068,7 @@ def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content):
10681068
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json",
10691069
"Peft": True,
10701070
"IsSubscriptionModel": True,
1071-
"Name": "nova_micro_v2_sft_lora"
1071+
"Name": "subscription_model_sft_lora"
10721072
}
10731073
]
10741074
}
@@ -1078,7 +1078,7 @@ def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content):
10781078
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))}
10791079

10801080
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
1081-
"nova-textgeneration-micro-v2", "SFT", "LORA", mock_session,
1081+
"subscription-only-model", "SFT", "LORA", mock_session,
10821082
)
10831083

10841084
# S3 get_object should only be called once (primary recipe), not twice (no overlay merge)

0 commit comments

Comments
 (0)