@@ -852,3 +852,177 @@ def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(
852852 _get_fine_tuning_options_and_model_arn (
853853 "test-model" , "SFT" , "LORA" , mock_session , sequence_length = "128K"
854854 )
855+
def test__parse_context_length_with_invalid_k_value(self):
    """A "K"-suffixed value whose prefix is not numeric is expected to parse to 0."""
    parsed = _parse_context_length("abcK")
    assert parsed == 0
858+
def test__parse_context_length_with_non_numeric_string(self):
    """A plain non-numeric string (no "K" suffix) is expected to parse to 0."""
    parsed = _parse_context_length("hello")
    assert parsed == 0
861+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length(
    self, mock_get_hub_content
):
    """Expect a ValueError when a sequence length is requested but no recipe declares one."""
    # The only recipe in the collection deliberately omits "SequenceLength".
    recipe_without_sequence_length = {
        "CustomizationTechnique": "SFT",
        "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json",
        "SmtjOverrideParamsS3Uri": "s3://bucket/params.json",
        "Peft": True,
    }
    hub_document = {
        "GatedBucket": False,
        "RecipeCollection": [recipe_without_sequence_length],
    }
    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": hub_document,
    }

    session = Mock()
    session.boto_session.region_name = "us-east-1"

    # The error message is expected to mention the sequence length.
    with pytest.raises(ValueError, match="and sequence length"):
        _get_fine_tuning_options_and_model_arn(
            "test-model", "SFT", "LORA", session, sequence_length="8K"
        )
888+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_filters_by_sequence_length_full_training(
    self, mock_get_hub_content
):
    """Requesting "8K" with FULL training should read the 8K recipe's override params.

    The hub document offers two SFT recipes without a "Peft" flag ("8K" and "32K").
    The test checks that exactly one S3 object is fetched and that its key belongs
    to the 8K recipe.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"

    # Stub S3 client: every get_object call yields the same JSON payload.
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 8192}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-8k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-8k.json",
                    "SequenceLength": "8K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json",
                    "SequenceLength": "32K",
                },
            ],
        },
    }

    result = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "FULL", mock_session, sequence_length="8K"
    )

    # Fail loudly instead of silently skipping the checks: guarding the
    # assertions with `if result is not None` would let the test pass
    # without verifying anything.
    assert result is not None
    # Unpacking doubles as a check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result

    mock_s3.get_object.assert_called_once()
    _, s3_call_kwargs = mock_s3.get_object.call_args
    assert "params-8k" in s3_call_kwargs["Key"]
931+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length(
    self, mock_get_hub_content
):
    """Requesting "8K" should select the 16K recipe (smallest >= 8K), not 4K or 128K.

    The hub document offers three SFT/PEFT recipes ("4K", "16K", "128K"). The test
    checks that exactly one S3 object is fetched and that its key belongs to the
    16K recipe.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"

    # Stub S3 client: every get_object call yields the same JSON payload.
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 16384}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
                    "Peft": True,
                    "SequenceLength": "4K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-16k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-16k.json",
                    "Peft": True,
                    "SequenceLength": "16K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-128k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-128k.json",
                    "Peft": True,
                    "SequenceLength": "128K",
                },
            ],
        },
    }

    result = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "LORA", mock_session, sequence_length="8K"
    )

    # Fail loudly instead of silently skipping the checks: guarding the
    # assertions with `if result is not None` would let the test pass
    # without verifying anything.
    assert result is not None
    # Unpacking doubles as a check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result

    # Should pick 16K (smallest >= 8K), not 128K
    mock_s3.get_object.assert_called_once()
    _, s3_call_kwargs = mock_s3.get_object.call_args
    assert "params-16k" in s3_call_kwargs["Key"]
984+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe(
    self, mock_get_hub_content
):
    """Verify that when no sequence_length is provided, existing behavior is unchanged.

    The hub document offers two SFT/PEFT recipes ("4K" first, "32K" second). With
    no sequence_length argument, the test checks that exactly one S3 object is
    fetched and that its key belongs to the first recipe.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"

    # Stub S3 client: every get_object call yields the same JSON payload.
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 4096}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-first.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-first.json",
                    "Peft": True,
                    "SequenceLength": "4K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-second.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-second.json",
                    "Peft": True,
                    "SequenceLength": "32K",
                },
            ],
        },
    }

    result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)

    # Fail loudly instead of silently skipping the checks: guarding the
    # assertions with `if result is not None` would let the test pass
    # without verifying anything.
    assert result is not None
    # Unpacking doubles as a check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result

    # Without sequence_length, should pick the first matching recipe
    mock_s3.get_object.assert_called_once()
    _, s3_call_kwargs = mock_s3.get_object.call_args
    assert "params-first" in s3_call_kwargs["Key"]
0 commit comments