Skip to content

Commit eade044

Browse files
committed
test: add more unit tests for sequence_length coverage
- _parse_context_length with invalid K value and non-numeric string - sequence_length filter when no recipes have SequenceLength field - sequence_length filter with FULL training type - verify smallest sufficient sequence_length is selected - verify no sequence_length uses first recipe (backward compat)
1 parent 7c6b02f commit eade044

1 file changed

Lines changed: 271 additions & 0 deletions

File tree

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,274 @@ def test__get_fine_tuning_options_subscription_enabled_but_not_subscribed(self,
864864
# Should still have standard params, just not datamix ones
865865
assert "max_steps" in options._specs
866866
assert "customer_data_percent" not in options._specs
867+
868+
def test__create_serverless_config_with_sequence_length(self):
    """An explicit ``sequence_length`` is exposed on the returned config.

    Also checks that the model ARN passed positionally is available as
    ``base_model_arn`` on the same config object.
    """
    serverless_config = _create_serverless_config(
        "model-arn",
        "SFT",
        TrainingType.LORA,
        accept_eula=True,
        sequence_length="8K",
    )

    assert serverless_config.sequence_length == "8K"
    assert serverless_config.base_model_arn == "model-arn"
873+
874+
def test__create_serverless_config_without_sequence_length(self):
    """Omitting ``sequence_length`` leaves the config attribute as ``None``."""
    serverless_config = _create_serverless_config(
        "model-arn", "SFT", TrainingType.LORA, accept_eula=True
    )

    assert serverless_config.sequence_length is None
878+
879+
def test__parse_context_length_with_k_suffix(self):
    """Uppercase ``K``-suffixed inputs parse to the expected token counts."""
    # (raw input, expected parsed value) pairs, checked in order.
    expectations = (
        ("8K", 8192),
        ("32K", 32768),
        ("128K", 131072),
    )
    for raw_value, expected_tokens in expectations:
        assert _parse_context_length(raw_value) == expected_tokens
883+
884+
def test__parse_context_length_with_lowercase(self):
    """A lowercase ``k`` suffix is parsed the same way as ``8K`` (8192)."""
    parsed = _parse_context_length("8k")
    assert parsed == 8192
886+
887+
def test__parse_context_length_with_integer(self):
    """A plain numeric string without a suffix parses to that integer."""
    parsed = _parse_context_length("4096")
    assert parsed == 4096
889+
890+
def test__parse_context_length_with_none(self):
    """``None`` input is expected to parse to 0."""
    parsed = _parse_context_length(None)
    assert parsed == 0
892+
893+
def test__parse_context_length_with_empty(self):
    """An empty string is expected to parse to 0."""
    parsed = _parse_context_length("")
    assert parsed == 0
895+
896+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content):
    """Requesting 8K with 4K and 32K recipes available must load the 32K params.

    The hub metadata lookup is patched to return two SFT recipes ("4K" and
    "32K"). The mocked S3 client records which override-params object is
    fetched; exactly one fetch is expected and its key must reference
    ``params-32k``.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        'hub_content_document': {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
                    "Peft": True,
                    "SequenceLength": "4K"
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json",
                    "Peft": True,
                    "SequenceLength": "32K"
                }
            ]
        }
    }

    result = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "LORA", mock_session, sequence_length="8K"
    )

    # Assert unconditionally: guarding the checks with `if result is not None`
    # would let the test pass without verifying anything when None is returned.
    assert result is not None
    # Unpacking keeps the implicit check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result
    # Should pick the 32K recipe (smallest >= 8K)
    mock_s3.get_object.assert_called_once()
    call_args = mock_s3.get_object.call_args[1]
    assert "params-32k" in call_args["Key"]
937+
938+
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(self, mock_get_hub_content):
    """A 128K request against a single 4K recipe must raise ``ValueError``.

    The patched hub metadata exposes only one SFT recipe with
    ``SequenceLength`` "4K"; the error message is matched against
    "No recipes found with SequenceLength >= 128K".
    """
    only_4k_recipe = {
        "CustomizationTechnique": "SFT",
        "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
        "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
        "Peft": True,
        "SequenceLength": "4K",
    }
    hub_document = {
        "GatedBucket": False,
        "RecipeCollection": [only_4k_recipe],
    }
    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": hub_document,
    }

    session = Mock()
    session.boto_session.region_name = "us-east-1"

    # Requesting 128K but only 4K available — should raise
    expected_message = "No recipes found with SequenceLength >= 128K"
    with pytest.raises(ValueError, match=expected_message):
        _get_fine_tuning_options_and_model_arn(
            "test-model", "SFT", "LORA", session, sequence_length="128K"
        )
964+
965+
def test__parse_context_length_with_invalid_k_value(self):
    """A ``K`` suffix with a non-numeric prefix ("abcK") is expected to give 0."""
    parsed = _parse_context_length("abcK")
    assert parsed == 0
967+
968+
def test__parse_context_length_with_non_numeric_string(self):
    """A string with no digits and no suffix ("hello") is expected to give 0."""
    parsed = _parse_context_length("hello")
    assert parsed == 0
970+
971+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length(
    self, mock_get_hub_content
):
    """An 8K request must raise when the only recipe has no ``SequenceLength``.

    The patched hub metadata exposes a single SFT recipe that omits the
    ``SequenceLength`` key; the ``ValueError`` message is matched against
    "and sequence length".
    """
    recipe_without_length = {
        "CustomizationTechnique": "SFT",
        "SmtjRecipeTemplateS3Uri": "s3://bucket/template.json",
        "SmtjOverrideParamsS3Uri": "s3://bucket/params.json",
        "Peft": True,
    }
    hub_document = {
        "GatedBucket": False,
        "RecipeCollection": [recipe_without_length],
    }
    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": hub_document,
    }

    session = Mock()
    session.boto_session.region_name = "us-east-1"

    expected_message = "and sequence length"
    with pytest.raises(ValueError, match=expected_message):
        _get_fine_tuning_options_and_model_arn(
            "test-model", "SFT", "LORA", session, sequence_length="8K"
        )
997+
998+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_filters_by_sequence_length_full_training(
    self, mock_get_hub_content
):
    """With training type "FULL", an 8K request must load the 8K params.

    The patched hub metadata exposes two SFT recipes ("8K" and "32K"),
    neither of which carries a ``Peft`` key. Exactly one S3 fetch is
    expected and its key must reference ``params-8k``.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 8192}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-8k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-8k.json",
                    "SequenceLength": "8K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json",
                    "SequenceLength": "32K",
                },
            ],
        },
    }

    result = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "FULL", mock_session, sequence_length="8K"
    )

    # Assert unconditionally: guarding the checks with `if result is not None`
    # would let the test pass without verifying anything when None is returned.
    assert result is not None
    # Unpacking keeps the implicit check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result
    mock_s3.get_object.assert_called_once()
    call_args = mock_s3.get_object.call_args[1]
    assert "params-8k" in call_args["Key"]
1040+
1041+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length(
    self, mock_get_hub_content
):
    """An 8K request with 4K/16K/128K recipes must load the 16K params.

    The patched hub metadata exposes three SFT recipes. Exactly one S3
    fetch is expected and its key must reference ``params-16k``, i.e. the
    smallest recipe that still covers the requested length.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 16384}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
                    "Peft": True,
                    "SequenceLength": "4K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-16k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-16k.json",
                    "Peft": True,
                    "SequenceLength": "16K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-128k.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-128k.json",
                    "Peft": True,
                    "SequenceLength": "128K",
                },
            ],
        },
    }

    result = _get_fine_tuning_options_and_model_arn(
        "test-model", "SFT", "LORA", mock_session, sequence_length="8K"
    )

    # Assert unconditionally: guarding the checks with `if result is not None`
    # would let the test pass without verifying anything when None is returned.
    assert result is not None
    # Unpacking keeps the implicit check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result
    # Should pick 16K (smallest >= 8K), not 128K
    mock_s3.get_object.assert_called_once()
    call_args = mock_s3.get_object.call_args[1]
    assert "params-16k" in call_args["Key"]
1093+
1094+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe(
    self, mock_get_hub_content
):
    """Verify that when no sequence_length is provided, existing behavior is unchanged.

    The patched hub metadata exposes two SFT recipes; the call omits
    ``sequence_length``. Exactly one S3 fetch is expected and its key must
    reference ``params-first``, the first recipe in the collection.
    """
    mock_session = Mock()
    mock_session.boto_session.region_name = "us-east-1"
    mock_s3 = Mock()
    mock_s3.get_object.return_value = {
        "Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 4096}}'))
    }
    mock_session.boto_session.client.return_value = mock_s3

    mock_get_hub_content.return_value = {
        "hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
        "hub_content_document": {
            "GatedBucket": False,
            "RecipeCollection": [
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-first.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-first.json",
                    "Peft": True,
                    "SequenceLength": "4K",
                },
                {
                    "CustomizationTechnique": "SFT",
                    "SmtjRecipeTemplateS3Uri": "s3://bucket/template-second.json",
                    "SmtjOverrideParamsS3Uri": "s3://bucket/params-second.json",
                    "Peft": True,
                    "SequenceLength": "32K",
                },
            ],
        },
    }

    result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)

    # Assert unconditionally: guarding the checks with `if result is not None`
    # would let the test pass without verifying anything when None is returned.
    assert result is not None
    # Unpacking keeps the implicit check that the result is a 3-item sequence.
    _options, _model_arn, _is_gated_model = result
    # Without sequence_length, should pick the first matching recipe
    mock_s3.get_object.assert_called_once()
    call_args = mock_s3.get_object.call_args[1]
    assert "params-first" in call_args["Key"]

0 commit comments

Comments
 (0)