@@ -864,3 +864,274 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
864864 # Should still have standard params, just not datamix ones
865865 assert "max_steps" in options ._specs
866866 assert "customer_data_percent" not in options ._specs
867+
868+ def test__create_serverless_config_with_sequence_length (self ):
869+ config = _create_serverless_config ("model-arn" , "SFT" , TrainingType .LORA , accept_eula = True , sequence_length = "8K" )
870+
871+ assert config .sequence_length == "8K"
872+ assert config .base_model_arn == "model-arn"
873+
874+ def test__create_serverless_config_without_sequence_length (self ):
875+ config = _create_serverless_config ("model-arn" , "SFT" , TrainingType .LORA , accept_eula = True )
876+
877+ assert config .sequence_length is None
878+
879+ def test__parse_context_length_with_k_suffix (self ):
880+ assert _parse_context_length ("8K" ) == 8192
881+ assert _parse_context_length ("32K" ) == 32768
882+ assert _parse_context_length ("128K" ) == 131072
883+
884+ def test__parse_context_length_with_lowercase (self ):
885+ assert _parse_context_length ("8k" ) == 8192
886+
887+ def test__parse_context_length_with_integer (self ):
888+ assert _parse_context_length ("4096" ) == 4096
889+
890+ def test__parse_context_length_with_none (self ):
891+ assert _parse_context_length (None ) == 0
892+
893+ def test__parse_context_length_with_empty (self ):
894+ assert _parse_context_length ("" ) == 0
895+
896+ @patch ('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata' )
897+ def test__get_fine_tuning_options_filters_by_sequence_length (self , mock_get_hub_content ):
898+ mock_session = Mock ()
899+ mock_session .boto_session .region_name = "us-east-1"
900+ mock_s3 = Mock ()
901+ mock_s3 .get_object .return_value = {
902+ "Body" : Mock (read = Mock (return_value = b'{"max_length": {"default": 32768}}' ))
903+ }
904+ mock_session .boto_session .client .return_value = mock_s3
905+
906+ mock_get_hub_content .return_value = {
907+ 'hub_content_arn' : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
908+ 'hub_content_document' : {
909+ "GatedBucket" : False ,
910+ "RecipeCollection" : [
911+ {
912+ "CustomizationTechnique" : "SFT" ,
913+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-4k.json" ,
914+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-4k.json" ,
915+ "Peft" : True ,
916+ "SequenceLength" : "4K"
917+ },
918+ {
919+ "CustomizationTechnique" : "SFT" ,
920+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-32k.json" ,
921+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-32k.json" ,
922+ "Peft" : True ,
923+ "SequenceLength" : "32K"
924+ }
925+ ]
926+ }
927+ }
928+
929+ result = _get_fine_tuning_options_and_model_arn ("test-model" , "SFT" , "LORA" , mock_session , sequence_length = "8K" )
930+
931+ if result is not None :
932+ options , model_arn , is_gated_model = result
933+ # Should pick the 32K recipe (smallest >= 8K)
934+ mock_s3 .get_object .assert_called_once ()
935+ call_args = mock_s3 .get_object .call_args [1 ]
936+ assert "params-32k" in call_args ["Key" ]
937+
938+ @patch ('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata' )
939+ def test__get_fine_tuning_options_raises_when_no_sufficient_context_length (self , mock_get_hub_content ):
940+ mock_session = Mock ()
941+ mock_session .boto_session .region_name = "us-east-1"
942+
943+ mock_get_hub_content .return_value = {
944+ 'hub_content_arn' : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
945+ 'hub_content_document' : {
946+ "GatedBucket" : False ,
947+ "RecipeCollection" : [
948+ {
949+ "CustomizationTechnique" : "SFT" ,
950+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-4k.json" ,
951+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-4k.json" ,
952+ "Peft" : True ,
953+ "SequenceLength" : "4K"
954+ }
955+ ]
956+ }
957+ }
958+
959+ # Requesting 128K but only 4K available — should raise
960+ with pytest .raises (ValueError , match = "No recipes found with SequenceLength >= 128K" ):
961+ _get_fine_tuning_options_and_model_arn (
962+ "test-model" , "SFT" , "LORA" , mock_session , sequence_length = "128K"
963+ )
964+
965+ def test__parse_context_length_with_invalid_k_value (self ):
966+ assert _parse_context_length ("abcK" ) == 0
967+
968+ def test__parse_context_length_with_non_numeric_string (self ):
969+ assert _parse_context_length ("hello" ) == 0
970+
971+ @patch ("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata" )
972+ def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length (
973+ self , mock_get_hub_content
974+ ):
975+ mock_session = Mock ()
976+ mock_session .boto_session .region_name = "us-east-1"
977+
978+ mock_get_hub_content .return_value = {
979+ "hub_content_arn" : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
980+ "hub_content_document" : {
981+ "GatedBucket" : False ,
982+ "RecipeCollection" : [
983+ {
984+ "CustomizationTechnique" : "SFT" ,
985+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template.json" ,
986+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params.json" ,
987+ "Peft" : True ,
988+ }
989+ ],
990+ },
991+ }
992+
993+ with pytest .raises (ValueError , match = "and sequence length" ):
994+ _get_fine_tuning_options_and_model_arn (
995+ "test-model" , "SFT" , "LORA" , mock_session , sequence_length = "8K"
996+ )
997+
998+ @patch ("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata" )
999+ def test__get_fine_tuning_options_filters_by_sequence_length_full_training (
1000+ self , mock_get_hub_content
1001+ ):
1002+ mock_session = Mock ()
1003+ mock_session .boto_session .region_name = "us-east-1"
1004+ mock_s3 = Mock ()
1005+ mock_s3 .get_object .return_value = {
1006+ "Body" : Mock (read = Mock (return_value = b'{"max_length": {"default": 8192}}' ))
1007+ }
1008+ mock_session .boto_session .client .return_value = mock_s3
1009+
1010+ mock_get_hub_content .return_value = {
1011+ "hub_content_arn" : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
1012+ "hub_content_document" : {
1013+ "GatedBucket" : False ,
1014+ "RecipeCollection" : [
1015+ {
1016+ "CustomizationTechnique" : "SFT" ,
1017+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-8k.json" ,
1018+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-8k.json" ,
1019+ "SequenceLength" : "8K" ,
1020+ },
1021+ {
1022+ "CustomizationTechnique" : "SFT" ,
1023+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-32k.json" ,
1024+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-32k.json" ,
1025+ "SequenceLength" : "32K" ,
1026+ },
1027+ ],
1028+ },
1029+ }
1030+
1031+ result = _get_fine_tuning_options_and_model_arn (
1032+ "test-model" , "SFT" , "FULL" , mock_session , sequence_length = "8K"
1033+ )
1034+
1035+ if result is not None :
1036+ options , model_arn , is_gated_model = result
1037+ mock_s3 .get_object .assert_called_once ()
1038+ call_args = mock_s3 .get_object .call_args [1 ]
1039+ assert "params-8k" in call_args ["Key" ]
1040+
1041+ @patch ("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata" )
1042+ def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length (
1043+ self , mock_get_hub_content
1044+ ):
1045+ mock_session = Mock ()
1046+ mock_session .boto_session .region_name = "us-east-1"
1047+ mock_s3 = Mock ()
1048+ mock_s3 .get_object .return_value = {
1049+ "Body" : Mock (read = Mock (return_value = b'{"max_length": {"default": 16384}}' ))
1050+ }
1051+ mock_session .boto_session .client .return_value = mock_s3
1052+
1053+ mock_get_hub_content .return_value = {
1054+ "hub_content_arn" : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
1055+ "hub_content_document" : {
1056+ "GatedBucket" : False ,
1057+ "RecipeCollection" : [
1058+ {
1059+ "CustomizationTechnique" : "SFT" ,
1060+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-4k.json" ,
1061+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-4k.json" ,
1062+ "Peft" : True ,
1063+ "SequenceLength" : "4K" ,
1064+ },
1065+ {
1066+ "CustomizationTechnique" : "SFT" ,
1067+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-16k.json" ,
1068+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-16k.json" ,
1069+ "Peft" : True ,
1070+ "SequenceLength" : "16K" ,
1071+ },
1072+ {
1073+ "CustomizationTechnique" : "SFT" ,
1074+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-128k.json" ,
1075+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-128k.json" ,
1076+ "Peft" : True ,
1077+ "SequenceLength" : "128K" ,
1078+ },
1079+ ],
1080+ },
1081+ }
1082+
1083+ result = _get_fine_tuning_options_and_model_arn (
1084+ "test-model" , "SFT" , "LORA" , mock_session , sequence_length = "8K"
1085+ )
1086+
1087+ if result is not None :
1088+ options , model_arn , is_gated_model = result
1089+ # Should pick 16K (smallest >= 8K), not 128K
1090+ mock_s3 .get_object .assert_called_once ()
1091+ call_args = mock_s3 .get_object .call_args [1 ]
1092+ assert "params-16k" in call_args ["Key" ]
1093+
1094+ @patch ("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata" )
1095+ def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe (
1096+ self , mock_get_hub_content
1097+ ):
1098+ """Verify that when no sequence_length is provided, existing behavior is unchanged."""
1099+ mock_session = Mock ()
1100+ mock_session .boto_session .region_name = "us-east-1"
1101+ mock_s3 = Mock ()
1102+ mock_s3 .get_object .return_value = {
1103+ "Body" : Mock (read = Mock (return_value = b'{"max_length": {"default": 4096}}' ))
1104+ }
1105+ mock_session .boto_session .client .return_value = mock_s3
1106+
1107+ mock_get_hub_content .return_value = {
1108+ "hub_content_arn" : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
1109+ "hub_content_document" : {
1110+ "GatedBucket" : False ,
1111+ "RecipeCollection" : [
1112+ {
1113+ "CustomizationTechnique" : "SFT" ,
1114+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-first.json" ,
1115+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-first.json" ,
1116+ "Peft" : True ,
1117+ "SequenceLength" : "4K" ,
1118+ },
1119+ {
1120+ "CustomizationTechnique" : "SFT" ,
1121+ "SmtjRecipeTemplateS3Uri" : "s3://bucket/template-second.json" ,
1122+ "SmtjOverrideParamsS3Uri" : "s3://bucket/params-second.json" ,
1123+ "Peft" : True ,
1124+ "SequenceLength" : "32K" ,
1125+ },
1126+ ],
1127+ },
1128+ }
1129+
1130+ result = _get_fine_tuning_options_and_model_arn ("test-model" , "SFT" , "LORA" , mock_session )
1131+
1132+ if result is not None :
1133+ options , model_arn , is_gated_model = result
1134+ # Without sequence_length, should pick the first matching recipe
1135+ mock_s3 .get_object .assert_called_once ()
1136+ call_args = mock_s3 .get_object .call_args [1 ]
1137+ assert "params-first" in call_args ["Key" ]
0 commit comments