Skip to content

Commit bba6d97

Browse files
committed
feat(train): Auto-detect subscription recipe hyperparameters in SFTTrainer
When a model has subscription-gated recipes (IsSubscriptionModel: true in RecipeCollection), automatically attempt to fetch the subscription recipe's override_params from the S3 access point and merge its additional hyperparameter keys into the trainer's _specs schema.

This allows subscribed users to natively set datamix hyperparameters (e.g. customer_data_percent, nova_*_percent) via trainer.hyperparameters without any explicit flag or workaround. For non-subscribed users, the fetch fails (e.g. AccessDenied), the failure is logged at debug level only, and just the standard recipe hyperparameters are available. The extra latency only occurs when subscription recipes exist in the hub metadata.

Changes:
- After loading the standard override_params, check whether a recipe for the selected training type has IsSubscriptionModel: true
- If found: resolve the {customer_id} placeholder with the caller's account ID, download override_params from the access point, and merge only the keys not already present (with their defaults set to None so they are not serialized unless explicitly set)
- Handle the S3 access point ARN URI format for GetObject
- Silent fallback on failure (non-subscribed users unaffected)
- Add unit tests for positive, negative, and fallback cases
1 parent a3a20c7 commit bba6d97

2 files changed

Lines changed: 180 additions & 4 deletions

File tree

sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,21 +358,58 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
358358
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}")
359359

360360
# Select recipe based on training type
361+
# Collect override_params from ALL matching recipes (standard + subscription)
361362
recipe = None
362363
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
363-
recipe = next((r for r in recipes_with_template if r.get("Peft")), None)
364+
recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None)
364365
elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL":
365-
recipe = next((r for r in recipes_with_template if not r.get("Peft")), None)
366+
recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None)
366367

367368
if not recipe:
368369
raise ValueError(f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}")
369370

370-
elif recipe and recipe.get("SmtjOverrideParamsS3Uri"):
371+
# Start with the standard recipe's override_params
372+
options_dict = {}
373+
if recipe.get("SmtjOverrideParamsS3Uri"):
371374
s3_uri = recipe["SmtjOverrideParamsS3Uri"]
372375
s3 = boto3.client("s3")
373-
bucket, key = s3_uri.replace("s3://", "").split("/", 1)
376+
uri_path = s3_uri.replace("s3://", "")
377+
bucket, key = uri_path.split("/", 1)
374378
obj = s3.get_object(Bucket=bucket, Key=key)
375379
options_dict = json.loads(obj["Body"].read())
380+
381+
# Auto-detect and merge subscription recipe's override_params if available
382+
if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA":
383+
sub_recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None)
384+
else:
385+
sub_recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None)
386+
387+
if sub_recipe and sub_recipe.get("SmtjOverrideParamsS3Uri"):
388+
try:
389+
sub_s3_uri = sub_recipe["SmtjOverrideParamsS3Uri"].replace("{customer_id}", boto3.client("sts").get_caller_identity()["Account"])
390+
sub_uri_path = sub_s3_uri.replace("s3://", "")
391+
# Handle access point ARN URIs
392+
if sub_uri_path.startswith("arn:"):
393+
arn_parts = sub_uri_path.split("/", 2)
394+
sub_bucket = arn_parts[0] + "/" + arn_parts[1]
395+
sub_key = arn_parts[2] if len(arn_parts) > 2 else ""
396+
else:
397+
sub_bucket, sub_key = sub_uri_path.split("/", 1)
398+
s3_sub = boto3.client("s3")
399+
sub_obj = s3_sub.get_object(Bucket=sub_bucket, Key=sub_key)
400+
sub_options = json.loads(sub_obj["Body"].read())
401+
# Merge: subscription params into _specs only (don't set defaults)
402+
# This makes them settable but not serialized unless user explicitly sets them
403+
for k, v in sub_options.items():
404+
if k not in options_dict:
405+
v_copy = v.copy() if isinstance(v, dict) else v
406+
if isinstance(v_copy, dict):
407+
v_copy['default'] = None # No default — won't appear in to_dict() unless set
408+
options_dict[k] = v_copy
409+
except Exception as e:
410+
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
411+
412+
if options_dict:
376413
return FineTuningOptions(options_dict), model_arn, is_gated_model
377414
else:
378415
return FineTuningOptions({}), model_arn, is_gated_model

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import pytest
23
from unittest.mock import Mock, patch, MagicMock
34
from sagemaker.train.common_utils.finetune_utils import (
@@ -304,6 +305,7 @@ def test__get_fine_tuning_options_and_model_arn(self, mock_boto_client, mock_get
304305
mock_s3_client.get_object.return_value = {
305306
"Body": Mock(read=Mock(return_value=b'{"learning_rate": 0.001}'))
306307
}
308+
mock_session.boto_session.client.return_value = mock_s3_client
307309

308310
result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)
309311

@@ -551,3 +553,140 @@ def test__validate_s3_path_exists_with_prefix_not_exists(self, mock_boto_client)
551553
mock_s3_client.put_object.assert_called_once_with(Bucket="test-bucket", Key="prefix/", Body=b'')
552554

553555

556+
557+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
558+
def test__get_fine_tuning_options_with_subscription_recipe_enabled(self, mock_get_hub_content):
559+
"""When a subscription recipe exists and the user is subscribed, datamix HPs are available."""
560+
mock_session = Mock()
561+
mock_session.boto_session.region_name = "us-east-1"
562+
mock_s3 = Mock()
563+
mock_sts = Mock()
564+
mock_sts.get_caller_identity.return_value = {"Account": "123456789012"}
565+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
566+
567+
mock_get_hub_content.return_value = {
568+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
569+
'hub_content_document': {
570+
"GatedBucket": False,
571+
"RecipeCollection": [
572+
{
573+
"CustomizationTechnique": "SFT",
574+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
575+
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
576+
"Name": "standard_sft"
577+
},
578+
{
579+
"CustomizationTechnique": "SFT",
580+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-123456789012/source/template.yaml",
581+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
582+
"Name": "datamix_sft",
583+
"IsSubscriptionModel": True
584+
}
585+
]
586+
}
587+
}
588+
589+
# Standard recipe returns base params
590+
standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
591+
# Subscription recipe returns datamix params
592+
datamix_params = json.dumps({"customer_data_percent": {"type": "integer", "required": False, "default": 50}})
593+
594+
mock_s3.get_object.side_effect = [
595+
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
596+
{"Body": Mock(read=Mock(return_value=datamix_params.encode()))},
597+
]
598+
599+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
600+
"test-model", "SFT", "FULL", mock_session,
601+
)
602+
603+
assert "max_steps" in options._specs
604+
assert "customer_data_percent" in options._specs
605+
assert options._specs["customer_data_percent"]["default"] is None # defaults are None so they dont serialize unless explicitly set
606+
607+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
608+
def test__get_fine_tuning_options_subscription_disabled_no_datamix_hps(self, mock_get_hub_content):
609+
"""By default (no subscription datamix params are returned), datamix HPs are NOT available."""
610+
mock_session = Mock()
611+
mock_session.boto_session.region_name = "us-east-1"
612+
mock_s3 = Mock()
613+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3
614+
615+
mock_get_hub_content.return_value = {
616+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
617+
'hub_content_document': {
618+
"GatedBucket": False,
619+
"RecipeCollection": [
620+
{
621+
"CustomizationTechnique": "SFT",
622+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
623+
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
624+
"Name": "standard_sft"
625+
},
626+
{
627+
"CustomizationTechnique": "SFT",
628+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
629+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
630+
"Name": "datamix_sft",
631+
"IsSubscriptionModel": True
632+
}
633+
]
634+
}
635+
}
636+
637+
standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
638+
mock_s3.get_object.return_value = {"Body": Mock(read=Mock(return_value=standard_params.encode()))}
639+
640+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
641+
"test-model", "SFT", "FULL", mock_session,
642+
)
643+
644+
assert "max_steps" in options._specs
645+
assert "customer_data_percent" not in options._specs
646+
647+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
648+
def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, mock_get_hub_content):
649+
"""When a subscription recipe exists but the user is NOT subscribed, falls back gracefully."""
650+
mock_session = Mock()
651+
mock_session.boto_session.region_name = "us-east-1"
652+
mock_s3 = Mock()
653+
mock_sts = Mock()
654+
mock_sts.get_caller_identity.return_value = {"Account": "999999999999"}
655+
mock_session.boto_session.client.side_effect = lambda service, **kwargs: mock_s3 if service == "s3" else mock_sts
656+
657+
mock_get_hub_content.return_value = {
658+
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
659+
'hub_content_document': {
660+
"GatedBucket": False,
661+
"RecipeCollection": [
662+
{
663+
"CustomizationTechnique": "SFT",
664+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.yaml",
665+
"SmtjOverrideParamsS3Uri": "s3://bucket/standard_params.json",
666+
"Name": "standard_sft"
667+
},
668+
{
669+
"CustomizationTechnique": "SFT",
670+
"SmtjRecipeTemplateS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/template.yaml",
671+
"SmtjOverrideParamsS3Uri": "s3://arn:aws:s3:us-east-1:334772094012:accesspoint/recipes-{customer_id}/source/params.json",
672+
"Name": "datamix_sft",
673+
"IsSubscriptionModel": True
674+
}
675+
]
676+
}
677+
}
678+
679+
standard_params = json.dumps({"max_steps": {"type": "integer", "required": True, "default": 100}})
680+
# First call succeeds (standard recipe), second call fails (access denied)
681+
mock_s3.get_object.side_effect = [
682+
{"Body": Mock(read=Mock(return_value=standard_params.encode()))},
683+
Exception("Access Denied"),
684+
]
685+
686+
options, model_arn, is_gated = _get_fine_tuning_options_and_model_arn(
687+
"test-model", "SFT", "FULL", mock_session,
688+
)
689+
690+
# Should still have standard params, just not datamix ones
691+
assert "max_steps" in options._specs
692+
assert "customer_data_percent" not in options._specs

0 commit comments

Comments
 (0)