From 0a3ccb513244f47c4231989c16d094ad7b7ec2e1 Mon Sep 17 00:00:00 2001 From: guanweim Date: Mon, 4 May 2026 19:03:03 +0000 Subject: [PATCH 1/9] feat(train): Add SequenceLength support for SFT, DPO, RLVR, RLAIF trainers Add optional sequence_length parameter to all four trainers that enables customers to specify their desired context length for serverless training jobs. The parameter is passed in ServerlessJobConfig for recipe filtering. During trainer initialization, _get_fine_tuning_options_and_model_arn filters recipes by SequenceLength field, picking the smallest recipe with context length >= the requested value. Raises ValueError if no sufficient recipe exists or if recipes lack SequenceLength metadata. Changes: - ServerlessJobConfig: add sequence_length field - _parse_context_length: parse values like '8K' to integers - _get_fine_tuning_options_and_model_arn: filter by SequenceLength - _create_serverless_config: conditionally include sequence_length - SFTTrainer, DPOTrainer, RLVRTrainer, RLAIFTrainer: accept and thread sequence_length through init and train methods - Unit tests for all new functionality --- .../src/sagemaker/core/shapes/shapes.py | 3 +- .../train/common_utils/finetune_utils.py | 88 +++++++++++++-- .../src/sagemaker/train/dpo_trainer.py | 36 +++--- .../src/sagemaker/train/rlaif_trainer.py | 35 +++--- .../src/sagemaker/train/rlvr_trainer.py | 37 +++--- .../src/sagemaker/train/sft_trainer.py | 35 +++--- .../train/common_utils/test_finetune_utils.py | 105 +++++++++++++++++- .../tests/unit/train/test_dpo_trainer.py | 65 ++++++++++- .../tests/unit/train/test_rlaif_trainer.py | 68 +++++++++++- .../tests/unit/train/test_rlvr_trainer.py | 65 ++++++++++- .../tests/unit/train/test_sft_trainer.py | 65 ++++++++++- 11 files changed, 530 insertions(+), 72 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index ce25c890dd..5e8217f463 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9717,6 +9717,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. + sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". """ base_model_arn: StrPipeVar @@ -9726,7 +9727,7 @@ class ServerlessJobConfig(Base): peft: Optional[StrPipeVar] = Unassigned() evaluation_type: Optional[StrPipeVar] = Unassigned() evaluator_arn: Optional[StrPipeVar] = Unassigned() - + sequence_length: Optional[StrPipeVar] = Unassigned() class MlflowConfig(Base): """ diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 242370964d..0f8ec7746d 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -407,10 +407,44 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: return None -def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: Optional[str] = None) -> tuple: +def _parse_context_length(value) -> int: + """Parse a context length value like '8K', '32K', '128K' into an integer (e.g., 8192). + + Returns 0 if value is None or unparseable. + """ + if not value: + return 0 + value = str(value).strip().upper() + if value.endswith("K"): + try: + return int(value[:-1]) * 1024 + except ValueError: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def _get_fine_tuning_options_and_model_arn( + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: str = "SageMakerPublicHub" +) -> tuple: """Get fine-tuning options and model ARN for given customization technique. + Args: + model_name: Name of the model in the hub. + customization_technique: Technique (e.g., "SFT", "DPO", "RLVR", "RLAIF"). + training_type: TrainingType enum or string ("LORA", "FULL"). + sagemaker_session: SageMaker session for API calls. + sequence_length: Optional sequence length (e.g., "8K"). When provided, filters + recipes by MaxContextLength >= the requested value. + hub_name: Hub name (default: "SageMakerPublicHub"). + Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ @@ -452,12 +486,40 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni # if no standard recipe exists (some models only have subscription recipes). recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) - if not recipe: - recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) + candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) - if not recipe: + candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] + else: + candidates = [] + + # Filter by SequenceLength if sequence_length is provided + if sequence_length and candidates: + requested = _parse_context_length(sequence_length) + candidates_with_context = [r for r in candidates if r.get("SequenceLength")] + if candidates_with_context: + filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] + if filtered: + filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) + recipe = filtered[0] + else: + available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) + raise ValueError( + f"No recipes found with SequenceLength >= {sequence_length}. " + f"Available sequence lengths: {available}" + ) + else: + raise ValueError( + f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " + f"and sequence length:{sequence_length}" + ) + elif candidates: + recipe = candidates[0] + + # Fall back to subscription recipes if no standard recipe found + if not recipe: + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + recipe = next((r for r in recipes_with_template if r.get("Peft") and r.get("IsSubscriptionModel")), None) + elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": recipe = next((r for r in recipes_with_template if not r.get("Peft") and r.get("IsSubscriptionModel")), None) if not recipe: @@ -636,7 +698,8 @@ def _resolve_model_and_name(model, sagemaker_session=None): def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: + training_type, accept_eula, evaluator_arn=None, + sequence_length=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: """Create serverless job configuration for fine-tuning. Args: @@ -645,6 +708,7 @@ def _create_serverless_config(model_arn, customization_technique, training_type: Training type (TrainingType enum or string) accept_eula: Boolean indicating if EULA is accepted evaluator_arn: Optional evaluator ARN for RLVR/RLAIF + sequence_length: Optional sequence length enum value (e.g., "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K") job_type: Type of job (default: "FineTuning") Returns: @@ -654,14 +718,18 @@ def _create_serverless_config(model_arn, customization_technique, else (training_type.value if isinstance(training_type, TrainingType) else training_type) # Create ServerlessJobConfig using shapes - serverless_config = ServerlessJobConfig( + config_kwargs = dict( job_type=job_type, base_model_arn=model_arn, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, - accept_eula=accept_eula + accept_eula=accept_eula, ) + if sequence_length is not None: + config_kwargs["sequence_length"] = sequence_length + + serverless_config = ServerlessJobConfig(**config_kwargs) return serverless_config diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index bd5d9a11bd..8e3bc17d5e 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -100,6 +100,10 @@ class DPOTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( self, @@ -116,6 +120,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -134,16 +139,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.DPO.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.DPO.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -227,12 +233,14 @@ def train(self, kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.DPO.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.DPO.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index f2d8460989..5d782d8fa3 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -114,6 +114,10 @@ class RLAIFTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -135,6 +139,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -156,14 +161,16 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLAIF.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLAIF.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) # Validate and set EULA acceptance self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) @@ -242,13 +249,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) evaluator_arn = getattr(self, '_evaluator_arn', None) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLAIF.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLAIF.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 333a93fc55..53029155f2 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -106,6 +106,10 @@ class RLVRTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -126,6 +130,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: bool = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -146,15 +151,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.RLVR.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.RLVR.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Remove constructor-handled hyperparameters self._process_hyperparameters() @@ -233,13 +240,15 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, # Extract and validate evaluator ARN evaluator_arn = _extract_evaluator_arn(self.custom_reward_function) if self.custom_reward_function else None - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.RLVR.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - evaluator_arn=evaluator_arn, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.RLVR.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + evaluator_arn=evaluator_arn, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 233f169d0f..e2193f0b9b 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -102,6 +102,10 @@ class SFTTrainer(BaseTrainer): stopping_condition (Optional[StoppingCondition]): The stopping condition to override training runtime limit. If not specified, uses SageMaker service default (24 hours for serverless training). + sequence_length (Optional[str]): + The sequence length for the training job. Valid values are + "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + If not specified, the service will use default recipe selection behavior. """ def __init__( @@ -119,6 +123,7 @@ def __init__( networking: Optional[VpcConfig] = None, accept_eula: Optional[bool] = False, stopping_condition: Optional[StoppingCondition] = None, + sequence_length: Optional[str] = None, **kwargs, ): super().__init__(**kwargs) @@ -138,15 +143,17 @@ def __init__( self.kms_key_id = kms_key_id self.networking = networking self.stopping_condition = stopping_condition + self.sequence_length = sequence_length # Initialize fine-tuning options with beta session fallback - self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, - CustomizationTechnique.SFT.value, - self.training_type, - self.sagemaker_session or TrainDefaults.get_sagemaker_session( - sagemaker_session=self.sagemaker_session - )) - + self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn( + self._model_name, + CustomizationTechnique.SFT.value, + self.training_type, + self.sagemaker_session or TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session), + sequence_length=self.sequence_length + ) + # Process hyperparameters self._process_hyperparameters() @@ -225,12 +232,14 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati kms_key_id=self.kms_key_id ) - serverless_config = _create_serverless_config(model_arn=self._model_arn, - customization_technique=CustomizationTechnique.SFT.value, - training_type=self.training_type, - accept_eula=self.accept_eula, - job_type=JOB_TYPE - ) + serverless_config = _create_serverless_config( + model_arn=self._model_arn, + customization_technique=CustomizationTechnique.SFT.value, + training_type=self.training_type, + accept_eula=self.accept_eula, + sequence_length=self.sequence_length, + job_type=JOB_TYPE + ) mlflow_config = _create_mlflow_config( sagemaker_session, mlflow_resource_arn=self.mlflow_resource_arn, diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 44089e9eb2..8f5e8bbba0 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -27,9 +27,11 @@ _create_mlflow_config, _validate_eula_for_gated_model, _validate_model_region_availability, - _validate_s3_path_exists + _validate_s3_path_exists, + _parse_context_length ) -from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.resources import ModelPackage +from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -464,7 +466,6 @@ def test__convert_input_data_to_channels(self): def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input""" - from sagemaker.core.resources import ModelPackage model_package = Mock(spec=ModelPackage) result = _validate_eula_for_gated_model(model_package, False, True) @@ -865,6 +866,104 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self, assert "max_steps" in options._specs assert "customer_data_percent" not in options._specs + def test__create_serverless_config_with_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True, sequence_length="8K") + + assert config.sequence_length == "8K" + assert config.base_model_arn == "model-arn" + + def test__create_serverless_config_without_sequence_length(self): + config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) + + # sequence_length should remain Unassigned (not set), not None + assert isinstance(config.sequence_length, Unassigned) + + def test__parse_context_length_with_k_suffix(self): + assert _parse_context_length("8K") == 8192 + assert _parse_context_length("32K") == 32768 + assert _parse_context_length("128K") == 131072 + + def test__parse_context_length_with_lowercase(self): + assert _parse_context_length("8k") == 8192 + + def test__parse_context_length_with_integer(self): + assert _parse_context_length("4096") == 4096 + + def test__parse_context_length_with_none(self): + assert _parse_context_length(None) == 0 + + def test__parse_context_length_with_empty(self): + assert _parse_context_length("") == 0 + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + @patch('boto3.client') + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + }, + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json", + "Peft": True, + "SequenceLength": "32K" + } + ] + } + } + + mock_s3_client = Mock() + mock_boto_client.return_value = mock_s3_client + mock_s3_client.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + + result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") + + if result is not None: + options, model_arn, is_gated_model = result + # Should pick the 32K recipe (smallest >= 8K) + mock_s3_client.get_object.assert_called_once() + call_args = mock_s3_client.get_object.call_args[1] + assert "params-32k" in call_args["Key"] + + @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') + def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, mock_get_hub_content): + mock_session = Mock() + mock_session.boto_session.region_name = "us-east-1" + + mock_get_hub_content.return_value = { + 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", + 'hub_content_document': { + "GatedBucket": False, + "RecipeCollection": [ + { + "CustomizationTechnique": "SFT", + "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json", + "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json", + "Peft": True, + "SequenceLength": "4K" + } + ] + } + } + + # Requesting 128K but only 4K available — should raise + with pytest.raises(ValueError, match="No recipes found with SequenceLength >= 128K"): + _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="128K") + class TestSubscriptionOnlyModelFallback: """Tests for models that only have subscription recipes.""" diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 1b70e0bf89..7648b46e35 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -506,4 +506,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = DPOTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.dpo_trainer._resolve_model_and_name') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.dpo_trainer._get_unique_name') + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._create_input_data_config') + @patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.dpo_trainer._create_output_config') + @patch('sagemaker.train.dpo_trainer._create_serverless_config') + @patch('sagemaker.train.dpo_trainer._create_mlflow_config') + @patch('sagemaker.train.dpo_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = DPOTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index e5666883e8..6811c45540 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -682,4 +682,70 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_hyperparams._specs = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", sequence_length="128K") + assert trainer.sequence_length == "128K" + + @patch('sagemaker.train.rlaif_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlaif_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlaif_trainer._get_unique_name') + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._create_input_data_config') + @patch('sagemaker.train.rlaif_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlaif_trainer._create_output_config') + @patch('sagemaker.train.rlaif_trainer._create_serverless_config') + @patch('sagemaker.train.rlaif_trainer._create_mlflow_config') + @patch('sagemaker.train.rlaif_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_fine_tuning_options._specs = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLAIFTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="64K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "64K" diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index 320b81555d..b4c01385e2 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -509,4 +509,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = RLVRTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", sequence_length="32K") + assert trainer.sequence_length == "32K" + + @patch('sagemaker.train.rlvr_trainer._resolve_model_and_name') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.rlvr_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.rlvr_trainer._get_unique_name') + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._create_input_data_config') + @patch('sagemaker.train.rlvr_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.rlvr_trainer._create_output_config') + @patch('sagemaker.train.rlvr_trainer._create_serverless_config') + @patch('sagemaker.train.rlvr_trainer._create_mlflow_config') + @patch('sagemaker.train.rlvr_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = RLVRTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="4K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "4K" diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 108990f839..01fc21f4bd 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -520,4 +520,67 @@ def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_ trainer = SFTTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train") trainer.train(wait=False, wait_timeout=600) - mock_wait.assert_not_called() \ No newline at end of file + mock_wait.assert_not_called() + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_sequence_length_default_none(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + assert trainer.sequence_length is None + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_init_with_sequence_length(self, mock_finetuning_options, mock_validate_group): + mock_validate_group.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning_options.return_value = (mock_hyperparams, "model-arn", False) + trainer = SFTTrainer(model="test-model", model_package_group="test-group", sequence_length="8K") + assert trainer.sequence_length == "8K" + + @patch('sagemaker.train.sft_trainer._resolve_model_and_name') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_role') + @patch('sagemaker.train.sft_trainer.TrainDefaults.get_sagemaker_session') + @patch('sagemaker.train.sft_trainer._get_unique_name') + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._create_input_data_config') + @patch('sagemaker.train.sft_trainer._convert_input_data_to_channels') + @patch('sagemaker.train.sft_trainer._create_output_config') + @patch('sagemaker.train.sft_trainer._create_serverless_config') + @patch('sagemaker.train.sft_trainer._create_mlflow_config') + @patch('sagemaker.train.sft_trainer._create_model_package_config') + @patch('sagemaker.core.resources.TrainingJob.create') + def test_train_passes_sequence_length_to_serverless_config(self, mock_training_job_create, + mock_model_package_config, mock_mlflow_config, mock_serverless_config, + mock_output_config, mock_convert_channels, mock_input_config, + mock_validate_group, mock_unique_name, mock_get_sagemaker_session, + mock_get_role, mock_get_options, mock_resolve_model): + mock_validate_group.return_value = "test-group" + mock_resolve_model.return_value = ("test-model", "test-model") + mock_get_sagemaker_session.return_value = Mock() + mock_fine_tuning_options = Mock() + mock_fine_tuning_options.to_dict.return_value = {} + mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False) + mock_get_role.return_value = "test-role" + mock_unique_name.return_value = "test-job-name" + mock_input_config.return_value = [Mock()] + mock_convert_channels.return_value = [Mock()] + mock_output_config.return_value = Mock() + mock_serverless_config.return_value = Mock() + mock_mlflow_config.return_value = Mock() + mock_model_package_config.return_value = Mock() + mock_training_job = Mock() + mock_training_job_create.return_value = mock_training_job + + trainer = SFTTrainer(model="test-model", model_package_group="test-group", + training_dataset="s3://bucket/train", sequence_length="16K") + trainer.train(wait=False) + + mock_serverless_config.assert_called_once() + call_kwargs = mock_serverless_config.call_args[1] + assert call_kwargs["sequence_length"] == "16K" From 20d07558d140478db3e4db409e0d21fd3d65d682 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 00:45:54 +0000 Subject: [PATCH 2/9] fix: use codegen for SequenceLength shape instead of manual shapes.py edit Add SequenceLength to service-2.json and regenerate shapes.py via codegen (python -m sagemaker.core.tools.codegen) instead of editing shapes.py manually. --- .../sample/sagemaker/2017-07-24/service-2.json | 17 +++++++++++++++++ .../src/sagemaker/core/shapes/shapes.py | 3 ++- .../core/utils/code_injection/shape_dag.py | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json index ceb4f316dc..6b551c5fd6 100644 --- a/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json +++ b/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json @@ -44132,6 +44132,10 @@ "EvaluatorArn":{ "shape":"EvaluatorArn", "documentation":"

The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt.

" + }, + "SequenceLength":{ + "shape":"SequenceLength", + "documentation":"

The sequence length for the training job.

" } }, "documentation":"

The configuration for the serverless training job.

" @@ -44143,6 +44147,19 @@ "Evaluation" ] }, + "SequenceLength":{ + "type":"string", + "enum":[ + "1K", + "2K", + "4K", + "8K", + "16K", + "32K", + "64K", + "128K" + ] + }, "ServerlessMaxConcurrency":{ "type":"integer", "box":true, diff --git a/sagemaker-core/src/sagemaker/core/shapes/shapes.py b/sagemaker-core/src/sagemaker/core/shapes/shapes.py index 5e8217f463..2aa5f2afe8 100644 --- a/sagemaker-core/src/sagemaker/core/shapes/shapes.py +++ b/sagemaker-core/src/sagemaker/core/shapes/shapes.py @@ -9717,7 +9717,7 @@ class ServerlessJobConfig(Base): peft: The parameter-efficient fine-tuning configuration. evaluation_type: The evaluation job type. Required when serverless job type is Evaluation. evaluator_arn: The evaluator Amazon Resource Name (ARN) used as reward function or reward prompt. - sequence_length: The sequence length for the training job. Valid values are "1K", "2K", "4K", "8K", "16K", "32K", "64K", "128K". + sequence_length: The sequence length for the training job. """ base_model_arn: StrPipeVar @@ -9729,6 +9729,7 @@ class ServerlessJobConfig(Base): evaluator_arn: Optional[StrPipeVar] = Unassigned() sequence_length: Optional[StrPipeVar] = Unassigned() + class MlflowConfig(Base): """ MlflowConfig diff --git a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py index 5d0de63efd..977fe6889f 100644 --- a/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py +++ b/sagemaker-core/src/sagemaker/core/utils/code_injection/shape_dag.py @@ -16206,6 +16206,7 @@ {"name": "Peft", "shape": "Peft", "type": "string"}, {"name": "EvaluationType", "shape": "EvaluationType", "type": "string"}, {"name": "EvaluatorArn", "shape": "EvaluatorArn", "type": "string"}, + {"name": "SequenceLength", "shape": "SequenceLength", "type": "string"}, ], "type": "structure", }, From c4f9a964984caed99af3ada96fbcde79c1389cf1 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 17:14:21 +0000 Subject: [PATCH 3/9] refactor: preserve original recipe selection path when sequence_length not provided Keep the existing `next(...)` logic untouched for the default case (no sequence_length). Only build the candidates list and filter when sequence_length is explicitly requested, ensuring zero behavioral change for existing callers. --- .../train/common_utils/finetune_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 0f8ec7746d..aaf203b6d3 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -486,14 +486,19 @@ def _get_fine_tuning_options_and_model_arn( # if no standard recipe exists (some models only have subscription recipes). recipe = None if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] - else: - candidates = [] + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + + # Override recipe selection when sequence_length is explicitly requested + if sequence_length: + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] + elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": + candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] + else: + candidates = [] - # Filter by SequenceLength if sequence_length is provided - if sequence_length and candidates: requested = _parse_context_length(sequence_length) candidates_with_context = [r for r in candidates if r.get("SequenceLength")] if candidates_with_context: @@ -512,8 +517,6 @@ def _get_fine_tuning_options_and_model_arn( f"No recipes found with Smtj for technique: {customization_technique},training_type:{training_type}, " f"and sequence length:{sequence_length}" ) - elif candidates: - recipe = candidates[0] # Fall back to subscription recipes if no standard recipe found if not recipe: From a352a00af59973d90ebbb61fe85161f43b8fdf6e Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 17:31:36 +0000 Subject: [PATCH 4/9] fix: correct test imports and mock setup for sequence_length tests --- .../train/common_utils/test_finetune_utils.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 8f5e8bbba0..c80f0a4c53 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -30,8 +30,8 @@ _validate_s3_path_exists, _parse_context_length ) -from sagemaker.core.resources import ModelPackage -from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup +from sagemaker.core.resources import ModelPackage, ModelPackageGroup +from sagemaker.core.utils.utils import Unassigned from sagemaker.ai_registry.dataset import DataSet from sagemaker.train.common import TrainingType from sagemaker.train.configs import InputData @@ -896,10 +896,14 @@ def test__parse_context_length_with_empty(self): assert _parse_context_length("") == 0 @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') - @patch('boto3.client') - def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content): + def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" + mock_s3 = Mock() + mock_s3.get_object.return_value = { + "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) + } + mock_session.boto_session.client.return_value = mock_s3 mock_get_hub_content.return_value = { 'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model", @@ -924,19 +928,13 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_cli } } - mock_s3_client = Mock() - mock_boto_client.return_value = mock_s3_client - mock_s3_client.get_object.return_value = { - "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}')) - } - result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K") if result is not None: options, model_arn, is_gated_model = result # Should pick the 32K recipe (smallest >= 8K) - mock_s3_client.get_object.assert_called_once() - call_args = mock_s3_client.get_object.call_args[1] + mock_s3.get_object.assert_called_once() + call_args = mock_s3.get_object.call_args[1] assert "params-32k" in call_args["Key"] @patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata') From 0ee458891019141d42a97100859baf302554ce1f Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 21:59:20 +0000 Subject: [PATCH 5/9] address PR review: sequence_length as recipe pre-filter and simplify config - Move sequence_length filtering above recipe selection to reduce recipes_with_template before existing logic runs - Always pass sequence_length to ServerlessJobConfig (no None guard) --- .../train/common_utils/finetune_utils.py | 38 +++++++------------ .../train/common_utils/test_finetune_utils.py | 3 +- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index aaf203b6d3..e469d2c815 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -481,31 +481,15 @@ def _get_fine_tuning_options_and_model_arn( if not recipes_with_template: raise ValueError(f"No recipes found with Smtj for technique: {customization_technique}") - # Select recipe based on training type - # Prefer non-subscription (standard) recipes first, fall back to subscription recipes - # if no standard recipe exists (some models only have subscription recipes). - recipe = None - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) - - # Override recipe selection when sequence_length is explicitly requested + # Filter by SequenceLength before recipe selection if sequence_length is requested if sequence_length: - if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": - candidates = [r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")] - elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": - candidates = [r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")] - else: - candidates = [] - requested = _parse_context_length(sequence_length) - candidates_with_context = [r for r in candidates if r.get("SequenceLength")] + candidates_with_context = [r for r in recipes_with_template if r.get("SequenceLength")] if candidates_with_context: filtered = [r for r in candidates_with_context if _parse_context_length(r.get("SequenceLength")) >= requested] if filtered: filtered.sort(key=lambda r: _parse_context_length(r.get("SequenceLength"))) - recipe = filtered[0] + recipes_with_template = filtered else: available = sorted(set(r.get("SequenceLength") for r in candidates_with_context)) raise ValueError( @@ -518,6 +502,15 @@ def _get_fine_tuning_options_and_model_arn( f"and sequence length:{sequence_length}" ) + # Select recipe based on training type + # Prefer non-subscription (standard) recipes first, fall back to subscription recipes + # if no standard recipe exists (some models only have subscription recipes). + recipe = None + if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": + recipe = next((r for r in recipes_with_template if r.get("Peft") and not r.get("IsSubscriptionModel")), None) + elif (isinstance(training_type, TrainingType) and training_type == TrainingType.FULL) or training_type == "FULL": + recipe = next((r for r in recipes_with_template if not r.get("Peft") and not r.get("IsSubscriptionModel")), None) + # Fall back to subscription recipes if no standard recipe found if not recipe: if (isinstance(training_type, TrainingType) and training_type == TrainingType.LORA) or training_type == "LORA": @@ -721,18 +714,15 @@ def _create_serverless_config(model_arn, customization_technique, else (training_type.value if isinstance(training_type, TrainingType) else training_type) # Create ServerlessJobConfig using shapes - config_kwargs = dict( + serverless_config = ServerlessJobConfig( job_type=job_type, base_model_arn=model_arn, customization_technique=customization_technique, peft=peft, evaluator_arn=evaluator_arn, accept_eula=accept_eula, + sequence_length=sequence_length, ) - if sequence_length is not None: - config_kwargs["sequence_length"] = sequence_length - - serverless_config = ServerlessJobConfig(**config_kwargs) return serverless_config diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c80f0a4c53..4ed28e6d18 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -875,8 +875,7 @@ def test__create_serverless_config_with_sequence_length(self): def test__create_serverless_config_without_sequence_length(self): config = _create_serverless_config("model-arn", "SFT", TrainingType.LORA, accept_eula=True) - # sequence_length should remain Unassigned (not set), not None - assert isinstance(config.sequence_length, Unassigned) + assert config.sequence_length is None def test__parse_context_length_with_k_suffix(self): assert _parse_context_length("8K") == 8192 From 6998099fb8bbbe591dbcb9e3ab7684c4cece3306 Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:02:51 +0000 Subject: [PATCH 6/9] fix: change hub_name default to None for consistency Use Optional[str] = None instead of hardcoded "SageMakerPublicHub" default, letting get_sagemaker_hub_name() resolve it at runtime. --- .../sagemaker/train/common_utils/finetune_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index e469d2c815..ff67567b6a 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -427,12 +427,12 @@ def _parse_context_length(value) -> int: def _get_fine_tuning_options_and_model_arn( - model_name: str, - customization_technique: str, - training_type, - sagemaker_session, - sequence_length=None, - hub_name: str = "SageMakerPublicHub" + model_name: str, + customization_technique: str, + training_type, + sagemaker_session, + sequence_length=None, + hub_name: Optional[str] = None ) -> tuple: """Get fine-tuning options and model ARN for given customization technique. From afb68ac86d6899c2c99716ff323ea09bdb1fe64e Mon Sep 17 00:00:00 2001 From: guanweim Date: Thu, 28 May 2026 22:05:18 +0000 Subject: [PATCH 7/9] test: add integration test for SFT trainer with sequence_length --- .../train/test_sft_trainer_integration.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index 68446991c4..39bf702025 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -135,3 +135,39 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): assert training_job.training_job_status == "Completed" assert hasattr(training_job, 'output_model_package_arn') assert training_job.output_model_package_arn is not None + + +@pytest.mark.gpu_intensive +def test_sft_trainer_lora_with_sequence_length(sagemaker_session): + """Test SFT training workflow with LORA and sequence_length specified.""" + unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}" + + sft_trainer = SFTTrainer( + model="meta-textgeneration-llama-3-2-1b-instruct", + training_type=TrainingType.LORA, + model_package_group="arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models", + training_dataset="s3://mc-flows-sdk-testing/input_data/sft/sample_data_256_final.jsonl", + s3_output_path="s3://mc-flows-sdk-testing/output/", + accept_eula=True, + sequence_length="8K", + base_job_name=f"sft-seqlen-integ-{unique_id}", + ) + + training_job = sft_trainer.train(wait=False) + + max_wait_time = 3600 + poll_interval = 30 + start_time = time.time() + + while time.time() - start_time < max_wait_time: + training_job.refresh() + status = training_job.training_job_status + + if status in ["Completed", "Failed", "Stopped"]: + break + + time.sleep(poll_interval) + + assert training_job.training_job_status == "Completed" + assert hasattr(training_job, 'output_model_package_arn') + assert training_job.output_model_package_arn is not None From ee08245d9d55d31cd3f482132d5f97f8250c0264 Mon Sep 17 00:00:00 2001 From: guanweim Date: Tue, 16 Jun 2026 23:09:41 +0000 Subject: [PATCH 8/9] Regenerate code via codegen and remove hardcoded path in data_extractor.py --- sagemaker-core/src/sagemaker/core/resources.py | 1 - sagemaker-core/src/sagemaker/core/tools/data_extractor.py | 1 - 2 files changed, 2 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/resources.py b/sagemaker-core/src/sagemaker/core/resources.py index 0eb0b78f06..452f820d27 100644 --- a/sagemaker-core/src/sagemaker/core/resources.py +++ b/sagemaker-core/src/sagemaker/core/resources.py @@ -46,7 +46,6 @@ from sagemaker.core.serializers.base import BaseSerializer from sagemaker.core.deserializers.base import BaseDeserializer - logger = get_textual_rich_logger(__name__) diff --git a/sagemaker-core/src/sagemaker/core/tools/data_extractor.py b/sagemaker-core/src/sagemaker/core/tools/data_extractor.py index aa0eb7cdbd..21e3134756 100644 --- a/sagemaker-core/src/sagemaker/core/tools/data_extractor.py +++ b/sagemaker-core/src/sagemaker/core/tools/data_extractor.py @@ -11,7 +11,6 @@ RUNTIME_SERVICE_JSON_FILE_PATH, ) -SERVICE_JSON_FILE_PATH = "/Users/rsareddy/workplace/pysdk-v3-313-release/sagemaker-python-sdk/sagemaker-core/sample/sagemaker/2017-07-24/service-2.json" class ServiceJsonData(BaseModel): sagemaker: dict From 9decd3b97cc6164b7b9211fbdeaa52f88e2b0ce8 Mon Sep 17 00:00:00 2001 From: guanweim Date: Tue, 16 Jun 2026 23:44:19 +0000 Subject: [PATCH 9/9] comment out gpu testing flag --- .../tests/integ/train/test_sft_trainer_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py index 39bf702025..61d9ff3122 100644 --- a/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_sft_trainer_integration.py @@ -137,7 +137,7 @@ def test_sft_trainer_nova_workflow(sagemaker_session_us_east_1): assert training_job.output_model_package_arn is not None -@pytest.mark.gpu_intensive +# @pytest.mark.gpu_intensive def test_sft_trainer_lora_with_sequence_length(sagemaker_session): """Test SFT training workflow with LORA and sequence_length specified.""" unique_id = f"{int(time.time())}-{random.randint(1000, 9999)}"