From 69b75336449946fe6d83ed2d03525632bf396ed5 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:36:56 -0700 Subject: [PATCH 1/4] fix(train): fall back to subscription recipes for models without base recipes --- .../train/common_utils/finetune_utils.py | 96 ++++++++++++------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 6479e803bd..d08052e265 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -448,56 +448,84 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") # Select recipe based on training type - # Collect override_params from ALL matching recipes (standard + subscription) + # Prefer non-subscription (standard) recipes first, fall back to subscription recipes + # if no standard recipe exists (e.g., Nova Micro v2 only has subscription recipes). recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if not recipe: + recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + if not recipe: + recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) if not recipe: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}") - # Start with the standard recipe's override_params + # Start with the selected recipe's override_params options_dict = {} if recipe.get("SmtjOverrideParamsS3Uri"): s3_uri = recipe["SmtjOverrideParamsS3Uri"] s3 = sagemaker_session.boto_session.client("s3") + # Handle {customer_id} placeholder (subscription recipes use access point URIs) + if "{customer_id}" in s3_uri: + account_id = sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"] + s3_uri = s3_uri.replace("{customer_id}", account_id) uri_path = s3_uri.replace("s3://", "") - bucket, key = uri_path.split("/", 1) - obj = s3.get_object(Bucket=bucket, Key=key) - options_dict = json.loads(obj["Body"].read()) - - # Auto-detect and merge subscription recipe's override_params if available - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) - else: - sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) - - if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"): + # Handle access point ARN URIs + if uri_path.startswith("arn:"): + arn_parts = uri_path.split("/", 2) + bucket = arn_parts[0] + "/" + arn_parts[1] + key = arn_parts[2] if len(arn_parts) > 2 else "" + else: + bucket, key = uri_path.split("/", 1) try: - sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) - sub_uri_path = sub_s3_uri.replace("s3://", "") - # Handle access point ARN URIs - if sub_uri_path.startswith("arn:"): - arn_parts = sub_uri_path.split("/", 2) - sub_bucket = arn_parts[0] + "/" + arn_parts[1] - sub_key = arn_parts[2] if len(arn_parts) > 2 else "" - else: - sub_bucket, sub_key = sub_uri_path.split("/", 1) - s3_sub = sagemaker_session.boto_session.client("s3") - sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key) - sub_options = json.loads(sub_obj["Body"].read()) - # Merge: subscription params into _specs only (don't set defaults) - # This makes them settable but not serialized unless user explicitly sets them - for k, v in sub_options.items(): - if k not in options_dict: - v_copy = v.copy() if isinstance(v, dict) else v - if isinstance(v_copy, dict): - v_copy['default'] = None # No default — won't appear in to_dict() unless set - options_dict[k] = v_copy + obj = s3.get_object(Bucket=bucket, Key=key) + options_dict = json.loads(obj["Body"].read()) except Exception as e: - logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") + if recipe.get("IsSubscriptionModel"): + raise ValueError( + f"Could not access subscription recipe for model '{model_name}'. " + f"This model only provides subscription-based recipes. " + f"Please verify that your account has an active Nova Forge subscription. " + f"Refer: https://docs.aws.amazon.com/sagemaker/latest/dg/nova-forge.html#nova-forge-prereq-access" + ) from e + else: + raise + + # Auto-detect and merge subscription recipe's override_params if available + # (only needed when the primary recipe is NOT a subscription recipe) + if not recipe.get("IsSubscriptionModel"): + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) + else: + sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) + + if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"): + try: + sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"]) + sub_uri_path = sub_s3_uri.replace("s3://", "") + # Handle access point ARN URIs + if sub_uri_path.startswith("arn:"): + arn_parts = sub_uri_path.split("/", 2) + sub_bucket = arn_parts[0] + "/" + arn_parts[1] + sub_key = arn_parts[2] if len(arn_parts) > 2 else "" + else: + sub_bucket, sub_key = sub_uri_path.split("/", 1) + s3_sub = sagemaker_session.boto_session.client("s3") + sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key) + sub_options = json.loads(sub_obj["Body"].read()) + # Merge: subscription params into _specs only (don't set defaults) + # This makes them settable but not serialized unless user explicitly sets them + for k, v in sub_options.items(): + if k not in options_dict: + v_copy = v.copy() if isinstance(v, dict) else v + if isinstance(v_copy, dict): + v_copy['default'] = None # No default — won't appear in to_dict() unless set + options_dict[k] = v_copy + except Exception as e: + logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") if options_dict: return FineTuningOptions(options_dict), model_arn, is_gated_model From 667723b4fb857cc241aa8151d1810e4ac4ad4be2 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:42:43 -0700 Subject: [PATCH 2/4] Adding unit testing for nova subscription recipe fix --- .../train/common_utils/test_finetune_utils.py | 219 ++++++++++++++++++ 1 file changed, 219 insertions(+) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c98dea477f..8f0540e617 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -864,3 +864,222 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, # Should still have standard params, just not datamix ones assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + + +class TestSubscriptionOnlyModelFallback: + """Tests for models that only have subscription recipes (e.g., Nova Micro v2).""" + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content): + """When no non-subscription LORA recipe exists, falls back to subscription recipe.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + # Only subscription recipes (like Nova Micro v2) + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", + "Peft": True, + "IsSubscriptionModel": True, + "Name": "nova_micro_v2_sft_lora" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json", + "Peft": False, + "IsSubscriptionModel": True, + "Name": "nova_micro_v2_sft_full" + } + ] + } + } + + sub_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))} + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session, + ) + + assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2" + assert "max_steps" in options._specs + assert is_gated is False + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content): + """When no non-subscription FULL recipe exists, falls back to subscription recipe.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json", + "Peft": False, + "IsSubscriptionModel": True, + "Name": "nova_micro_v2_sft_full" + } + ] + } + } + + sub_params = json.dumps({"learning_rate": {"type": "float", "required": True, "default": 5e-6}}) + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))} + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-textgeneration-micro-v2", "SFT", "FULL", mock_session, + ) + + assert "learning_rate" in options._specs + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get_hub_content): + """When subscription recipe download fails (AccessDenied), raises actionable ValueError.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "999999999999"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", + "Peft": True, + "IsSubscriptionModel": True, + "Name": "nova_micro_v2_sft_lora" + } + ] + } + } + + # Simulate AccessDenied from S3 access point + from botocore.exceptions import ClientError + mock_s3.get_object.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Could not access through this access point"}}, + "GetObject" + ) + + with pytest.raises(ValueError) as exc_info: + _get_fine_tuning_options_and_model_arn( + "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session, + ) + + error_msg = str(exc_info.value) + assert "subscription" in error_msg.lower() + assert "nova-textgeneration-micro-v2" in error_msg + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test_non_subscription_recipe_preferred_over_subscription(self, mock_get_hub_content): + """When both standard and subscription recipes exist, standard is selected as primary.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-lite-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/standard_template.yaml", + "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json", + "Peft": True, + "Name": "nova_lite_v2_sft_lora" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", + "Peft": True, + "IsSubscriptionModel": True, + "Name": "nova_lite_v2_sft_lora_datamix" + } + ] + } + } + + standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}}) + mock_s3.get_object.side_effect = [ + {"Body": Mock(read=Mock(return_value=standard_params.encode()))}, + {"Body": Mock(read=Mock(return_value=datamix_params.encode()))}, + ] + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-textgeneration-lite-v2", "SFT", "LORA", mock_session, + ) + + # Standard recipe's params should be loaded as primary + assert "max_steps" in options._specs + assert options._specs["max_steps"]["default"] == 100 + # Subscription params merged with None defaults + assert "customer_data_percent" in options._specs + assert options._specs["customer_data_percent"]["default"] is None + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content): + """When primary recipe IS a subscription recipe, overlay merge is skipped.""" + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_sts = Mock() + mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} + mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/template.yaml", + "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", + "Peft": True, + "IsSubscriptionModel": True, + "Name": "nova_micro_v2_sft_lora" + } + ] + } + } + + sub_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}}) + mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))} + + options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( + "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session, + ) + + # S3 get_object should only be called once (primary recipe), not twice (no overlay merge) + assert mock_s3.get_object.call_count == 1 From 7edc5515f1489c26d4d5c3d76bcbec5f3a4e3d6c Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:29:31 -0700 Subject: [PATCH 3/4] Change failed retrieval of subscription recipes log to warning --- .../src/sagemaker/train/common_utils/finetune_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index d08052e265..2dfee058ab 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -525,7 +525,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni v_copy['default'] = None # No default — won't appear in to_dict() unless set options_dict[k] = v_copy except Exception as e: - logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") + logger.warning(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") if options_dict: return FineTuningOptions(options_dict), model_arn, is_gated_model From f26e1ae0baaa12bc1888ae3454faf4175aec3852 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:40:03 -0700 Subject: [PATCH 4/4] Updating unit tests --- .../train/common_utils/finetune_utils.py | 2 +- .../train/common_utils/test_finetune_utils.py | 34 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 2dfee058ab..242370964d 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -449,7 +449,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni # Select recipe based on training type # Prefer non-subscription (standard) recipes first, fall back to subscription recipes - # if no standard recipe exists (e.g., Nova Micro v2 only has subscription recipes). + # if no standard recipe exists (some models only have subscription recipes). recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 8f0540e617..44089e9eb2 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -867,7 +867,7 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, class TestSubscriptionOnlyModelFallback: - """Tests for models that only have subscription recipes (e.g., Nova Micro v2).""" + """Tests for models that only have subscription recipes.""" @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content): @@ -879,9 +879,9 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content): mock_sts.get_caller_identity.return_value = {"Account": "123456789012"} mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts - # Only subscription recipes (like Nova Micro v2) + # Only subscription recipes exist for this model mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model", 'hub_content_document': { "GatedBucket": False, "RecipeCollection": [ @@ -891,7 +891,7 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content): "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", "Peft": True, "IsSubscriptionModel": True, - "Name": "nova_micro_v2_sft_lora" + "Name": "subscription_model_sft_lora" }, { "CustomizationTechnique": "SFT", @@ -899,7 +899,7 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content): "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json", "Peft": False, "IsSubscriptionModel": True, - "Name": "nova_micro_v2_sft_full" + "Name": "subscription_model_sft_full" } ] } @@ -909,10 +909,10 @@ def test_fallback_to_subscription_recipe_lora(self, mock_get_hub_content): mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))} options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session, + "subscription-only-model", "SFT", "LORA", mock_session, ) - assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2" + assert model_arn == "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model" assert "max_steps" in options._specs assert is_gated is False @@ -927,7 +927,7 @@ def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content): mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model", 'hub_content_document': { "GatedBucket": False, "RecipeCollection": [ @@ -937,7 +937,7 @@ def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content): "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/full_params.json", "Peft": False, "IsSubscriptionModel": True, - "Name": "nova_micro_v2_sft_full" + "Name": "subscription_model_sft_full" } ] } @@ -947,7 +947,7 @@ def test_fallback_to_subscription_recipe_full(self, mock_get_hub_content): mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))} options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "nova-textgeneration-micro-v2", "SFT", "FULL", mock_session, + "subscription-only-model", "SFT", "FULL", mock_session, ) assert "learning_rate" in options._specs @@ -963,7 +963,7 @@ def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model", 'hub_content_document': { "GatedBucket": False, "RecipeCollection": [ @@ -973,7 +973,7 @@ def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", "Peft": True, "IsSubscriptionModel": True, - "Name": "nova_micro_v2_sft_lora" + "Name": "subscription_model_sft_lora" } ] } @@ -988,12 +988,12 @@ def test_subscription_only_model_access_denied_raises_clear_error(self, mock_get with pytest.raises(ValueError) as exc_info: _get_fine_tuning_options_and_model_arn( - "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session, + "subscription-only-model", "SFT", "LORA", mock_session, ) error_msg = str(exc_info.value) assert "subscription" in error_msg.lower() - assert "nova-textgeneration-micro-v2" in error_msg + assert "subscription-only-model" in error_msg @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') def test_non_subscription_recipe_preferred_over_subscription(self, mock_get_hub_content): @@ -1058,7 +1058,7 @@ def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content): mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts mock_get_hub_content.return_value = { - 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/nova-micro-v2", + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/subscription-only-model", 'hub_content_document': { "GatedBucket": False, "RecipeCollection": [ @@ -1068,7 +1068,7 @@ def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content): "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/lora_params.json", "Peft": True, "IsSubscriptionModel": True, - "Name": "nova_micro_v2_sft_lora" + "Name": "subscription_model_sft_lora" } ] } @@ -1078,7 +1078,7 @@ def test_subscription_only_skips_overlay_merge(self, mock_get_hub_content): mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=sub_params.encode()))} options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn( - "nova-textgeneration-micro-v2", "SFT", "LORA", mock_session, + "subscription-only-model", "SFT", "LORA", mock_session, ) # S3 get_object should only be called once (primary recipe), not twice (no overlay merge)