Skip to content

Commit 5489954

Browse files
committed
feat(train): Auto-detect subscription recipe hyperparameters in SFTTrainer
When a model has subscription-gated recipes (`IsSubscriptionModel: true` in `RecipeCollection`), automatically attempt to fetch that recipe's override_params from the S3 access point and merge any additional hyperparameter keys into the trainer's `_specs` schema. This allows subscribed users to natively set datamix hyperparameters (e.g. `customer_data_percent`, `nova_*_percent`) via `trainer.hyperparameters` without any explicit flag or workaround. For non-subscribed users, the fetch fails silently (AccessDenied) and only the standard recipe hyperparameters are available. The extra latency is incurred only when subscription recipes exist in the hub metadata.

Changes:
- After loading the standard override_params, check whether any matching recipe has `IsSubscriptionModel: true`.
- If one is found: resolve the `{customer_id}` placeholder with the caller's account ID, download override_params from the access point, and merge the extra keys.
- Handle the S3 access point ARN URI format for GetObject.
- Fall back silently on failure (non-subscribed users are unaffected).
- Add unit tests for the positive, negative, and fallback cases.
1 parent a3a20c7 commit 5489954

2 files changed

Lines changed: 182 additions & 5 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,21 +358,58 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
358358
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
359359

360360
# Select recipe based on training type
361+
# Collect override_params from ALL matching recipes (standard + subscription)
361362
recipe = None
362363
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
363-
recipe = next((r for r in recipes_with_template if r.get("Peft")), None)
364+
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
364365
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
365-
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
366+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
366367

367368
if not recipe:
368369
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
369370

370-
elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
371+
# Start with the standard recipe's override_params
372+
options_dict = {}
373+
if recipe.get("SmtjOverrideParamsS3Uri"):
371374
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
372-
s3 = boto3.client("s3")
373-
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
375+
s3 = sagemaker_session.boto_session.client("s3")
376+
uri_path = s3_uri.replace("s3://", "")
377+
bucket, key = uri_path.split("/", 1)
374378
obj = s3.get_object(Bucket=bucket, Key=key)
375379
options_dict = json.loads(obj["Body"].read())
380+
381+
# Auto-detect and merge subscription recipe's override_params if available
382+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
383+
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
384+
else:
385+
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
386+
387+
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
388+
try:
389+
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", sagemaker_session.boto_session.client("sts").get_caller_identity()["Account"])
390+
sub_uri_path = sub_s3_uri.replace("s3://", "")
391+
# Handle access point ARN URIs
392+
if sub_uri_path.startswith("arn:"):
393+
arn_parts = sub_uri_path.split("/", 2)
394+
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
395+
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
396+
else:
397+
sub_bucket, sub_key = sub_uri_path.split("/", 1)
398+
s3_sub = sagemaker_session.boto_session.client("s3")
399+
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
400+
sub_options = json.loads(sub_obj["Body"].read())
401+
# Merge: subscription params into _specs only (don't set defaults)
402+
# This makes them settable but not serialized unless user explicitly sets them
403+
for k, v in sub_options.items():
404+
if k not in options_dict:
405+
v_copy = v.copy() if isinstance(v, dict) else v
406+
if isinstance(v_copy, dict):
407+
v_copy['default'] = None # No default — won't appear in to_dict() unless set
408+
options_dict[k] = v_copy
409+
except Exception as e:
410+
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
411+
412+
if options_dict:
376413
return FineTuningOptions(options_dict), model_arn, is_gated_model
377414
else:
378415
return FineTuningOptions({}), model_arn, is_gated_model

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import pytest
23
from unittest.mock import Mock, patch, MagicMock
34
from sagemaker.train.common_utils.finetune_utils import (
@@ -304,6 +305,8 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get
304305
mock_s3_client.get_object.return_value = {
305306
"Body": Mock(read=Mock(return_value=b'{"learning_rate": 0.001}'))
306307
}
308+
mock_session.boto_session.client.return_value = mock_s3_client
309+
mock_session.boto_session.client.return_value = mock_s3_client
307310

308311
result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)
309312

@@ -551,3 +554,140 @@ def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client)
551554
mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'')
552555

553556

557+
558+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content):
    """Subscribed caller: subscription-only hyperparameters are merged into the schema.

    The hub metadata lists a standard SFT recipe and a subscription SFT recipe
    whose override-params URI is an S3 access point ARN containing a
    ``{customer_id}`` placeholder. Both S3 downloads succeed, so the returned
    options must expose the standard key and the subscription key, with the
    subscription key's default reset to ``None``.
    """
    account_id = "123456789012"
    access_point = "arn:aws:s3:us-east-1:334772094012:accesspoint"

    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    mock_sts = Mock()
    mock_sts.get_caller_identity.return_value = {"Account": account_id}
    # Route S3 requests to mock_s3; every other service (STS) gets mock_sts.
    mock_session.boto_session.client.side_effect = (
        lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
    )

    mock_get_hub_content.return_value = {
        'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        'hub_content_document': {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
                    "Name": "standard_sft"
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
                    "Name": "datamix_sft",
                    "IsSubscriptionModel": True
                }
            ]
        }
    }

    # First download: the standard recipe's base params.
    standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
    # Second download: the subscription recipe's datamix params.
    datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})

    mock_s3.get_object.side_effect = [
        {"Body": Mock(read=Mock(return_value=standard_params.encode()))},
        {"Body": Mock(read=Mock(return_value=datamix_params.encode()))},
    ]

    options, _, _ = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "FULL", mock_session,
    )

    assert "max_steps" in options._specs
    assert "customer_data_percent" in options._specs
    # Defaults are None so they don't serialize unless explicitly set.
    assert options._specs["customer_data_percent"]["default"] is None

    # The placeholder must be replaced with the caller's account ID, and the
    # access point ARN (which itself contains a "/") must be passed whole as
    # Bucket, with only the remainder used as Key.
    assert mock_s3.get_object.call_count == 2
    mock_s3.get_object.assert_any_call(Bucket="bucket", Key="standard_params.json")
    mock_s3.get_object.assert_any_call(
        Bucket=f"{access_point}/recipes-{account_id}",
        Key="source/params.json",
    )
607+
608+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content):
    """Account lookup unusable: only the standard hyperparameters are returned.

    Every boto client handed out by the session, including the one requested
    for STS, is the same plain ``Mock``. A plain ``Mock`` does not support item
    access, so reading ``["Account"]`` from the caller-identity response fails
    inside the subscription branch. The function is expected to swallow that
    failure, skip the access-point download, and return the standard schema.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    # Deliberately no dedicated STS mock: all services resolve to mock_s3.
    mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3

    mock_get_hub_content.return_value = {
        'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        'hub_content_document': {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
                    "Name": "standard_sft"
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
                    "Name": "datamix_sft",
                    "IsSubscriptionModel": True
                }
            ]
        }
    }

    standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
    mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))}

    options, _, _ = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "FULL", mock_session,
    )

    assert "max_steps" in options._specs
    assert "customer_data_percent" not in options._specs
    # Only the standard recipe's params were downloaded; the subscription
    # fetch was abandoned before reaching S3.
    mock_s3.get_object.assert_called_once_with(Bucket="bucket", Key="standard_params.json")
647+
648+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content):
    """Non-subscribed caller: the subscription download fails and is ignored.

    The hub metadata advertises a subscription recipe, and the caller's account
    ID resolves normally, but the second ``get_object`` call (the access-point
    download) raises. The function must fall back to the standard recipe's
    hyperparameters instead of propagating the error.
    """
    account_id = "999999999999"
    access_point = "arn:aws:s3:us-east-1:334772094012:accesspoint"

    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    mock_sts = Mock()
    mock_sts.get_caller_identity.return_value = {"Account": account_id}
    # Route S3 requests to mock_s3; every other service (STS) gets mock_sts.
    mock_session.boto_session.client.side_effect = (
        lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
    )

    mock_get_hub_content.return_value = {
        'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        'hub_content_document': {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
                    "Name": "standard_sft"
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
                    "SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
                    "Name": "datamix_sft",
                    "IsSubscriptionModel": True
                }
            ]
        }
    }

    standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
    # First call succeeds (standard recipe), second call fails (access denied).
    mock_s3.get_object.side_effect = [
        {"Body": Mock(read=Mock(return_value=standard_params.encode()))},
        Exception("Access Denied"),
    ]

    options, _, _ = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "FULL", mock_session,
    )

    # Should still have standard params, just not datamix ones.
    assert "max_steps" in options._specs
    assert "customer_data_percent" not in options._specs

    # Guard against a vacuous pass: the subscription download must actually
    # have been attempted, against the access point for this caller's account.
    assert mock_s3.get_object.call_count == 2
    mock_s3.get_object.assert_any_call(
        Bucket=f"{access_point}/recipes-{account_id}",
        Key="source/params.json",
    )

0 commit comments

Comments
 (0)