Skip to content

Commit 973b62d

Browse files
committed
test: add more unit tests for sequence_length coverage
- _parse_context_length with invalid K value and non-numeric string - sequence_length filter when no recipes have SequenceLength field - sequence_length filter with FULL training type - verify smallest sufficient sequence_length is selected - verify no sequence_length uses first recipe (backward compat)
1 parent 303fa79 commit 973b62d

1 file changed

Lines changed: 174 additions & 0 deletions

File tree

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,3 +852,177 @@ def test__get_fine_tuning_options_raises_when_no_sufficient_context_length(
852852
_get_fine_tuning_options_and_model_arn(
853853
"test-model", "SFT", "LORA", mock_session, sequence_length="128K"
854854
)
855+
856+
def test__parse_context_length_with_invalid_k_value(self):
857+
assert _parse_context_length("abcK") == 0
858+
859+
def test__parse_context_length_with_non_numeric_string(self):
860+
assert _parse_context_length("hello") == 0
861+
862+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
863+
def test__get_fine_tuning_options_raises_when_no_recipes_have_sequence_length(
864+
self, mock_get_hub_content
865+
):
866+
mock_session = Mock()
867+
mock_session.boto_session.region_name = "us-east-1"
868+
869+
mock_get_hub_content.return_value = {
870+
"hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
871+
"hub_content_document": {
872+
"GatedBucket": False,
873+
"RecipeCollection": [
874+
{
875+
"CustomizationTechnique": "SFT",
876+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template.json",
877+
"SmtjOverrideParamsS3Uri": "s3://bucket/params.json",
878+
"Peft": True,
879+
}
880+
],
881+
},
882+
}
883+
884+
with pytest.raises(ValueError, match="and sequence length"):
885+
_get_fine_tuning_options_and_model_arn(
886+
"test-model", "SFT", "LORA", mock_session, sequence_length="8K"
887+
)
888+
889+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
890+
def test__get_fine_tuning_options_filters_by_sequence_length_full_training(
891+
self, mock_get_hub_content
892+
):
893+
mock_session = Mock()
894+
mock_session.boto_session.region_name = "us-east-1"
895+
mock_s3 = Mock()
896+
mock_s3.get_object.return_value = {
897+
"Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 8192}}'))
898+
}
899+
mock_session.boto_session.client.return_value = mock_s3
900+
901+
mock_get_hub_content.return_value = {
902+
"hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
903+
"hub_content_document": {
904+
"GatedBucket": False,
905+
"RecipeCollection": [
906+
{
907+
"CustomizationTechnique": "SFT",
908+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-8k.json",
909+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-8k.json",
910+
"SequenceLength": "8K",
911+
},
912+
{
913+
"CustomizationTechnique": "SFT",
914+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-32k.json",
915+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-32k.json",
916+
"SequenceLength": "32K",
917+
},
918+
],
919+
},
920+
}
921+
922+
result = _get_fine_tuning_options_and_model_arn(
923+
"test-model", "SFT", "FULL", mock_session, sequence_length="8K"
924+
)
925+
926+
if result is not None:
927+
options, model_arn, is_gated_model = result
928+
mock_s3.get_object.assert_called_once()
929+
call_args = mock_s3.get_object.call_args[1]
930+
assert "params-8k" in call_args["Key"]
931+
932+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
933+
def test__get_fine_tuning_options_selects_smallest_sufficient_sequence_length(
934+
self, mock_get_hub_content
935+
):
936+
mock_session = Mock()
937+
mock_session.boto_session.region_name = "us-east-1"
938+
mock_s3 = Mock()
939+
mock_s3.get_object.return_value = {
940+
"Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 16384}}'))
941+
}
942+
mock_session.boto_session.client.return_value = mock_s3
943+
944+
mock_get_hub_content.return_value = {
945+
"hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
946+
"hub_content_document": {
947+
"GatedBucket": False,
948+
"RecipeCollection": [
949+
{
950+
"CustomizationTechnique": "SFT",
951+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-4k.json",
952+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-4k.json",
953+
"Peft": True,
954+
"SequenceLength": "4K",
955+
},
956+
{
957+
"CustomizationTechnique": "SFT",
958+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-16k.json",
959+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-16k.json",
960+
"Peft": True,
961+
"SequenceLength": "16K",
962+
},
963+
{
964+
"CustomizationTechnique": "SFT",
965+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-128k.json",
966+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-128k.json",
967+
"Peft": True,
968+
"SequenceLength": "128K",
969+
},
970+
],
971+
},
972+
}
973+
974+
result = _get_fine_tuning_options_and_model_arn(
975+
"test-model", "SFT", "LORA", mock_session, sequence_length="8K"
976+
)
977+
978+
if result is not None:
979+
options, model_arn, is_gated_model = result
980+
# Should pick 16K (smallest >= 8K), not 128K
981+
mock_s3.get_object.assert_called_once()
982+
call_args = mock_s3.get_object.call_args[1]
983+
assert "params-16k" in call_args["Key"]
984+
985+
@patch("sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata")
986+
def test__get_fine_tuning_options_without_sequence_length_uses_first_recipe(
987+
self, mock_get_hub_content
988+
):
989+
"""Verify that when no sequence_length is provided, existing behavior is unchanged."""
990+
mock_session = Mock()
991+
mock_session.boto_session.region_name = "us-east-1"
992+
mock_s3 = Mock()
993+
mock_s3.get_object.return_value = {
994+
"Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 4096}}'))
995+
}
996+
mock_session.boto_session.client.return_value = mock_s3
997+
998+
mock_get_hub_content.return_value = {
999+
"hub_content_arn": "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
1000+
"hub_content_document": {
1001+
"GatedBucket": False,
1002+
"RecipeCollection": [
1003+
{
1004+
"CustomizationTechnique": "SFT",
1005+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-first.json",
1006+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-first.json",
1007+
"Peft": True,
1008+
"SequenceLength": "4K",
1009+
},
1010+
{
1011+
"CustomizationTechnique": "SFT",
1012+
"SmtjRecipeTemplateS3Uri": "s3://bucket/template-second.json",
1013+
"SmtjOverrideParamsS3Uri": "s3://bucket/params-second.json",
1014+
"Peft": True,
1015+
"SequenceLength": "32K",
1016+
},
1017+
],
1018+
},
1019+
}
1020+
1021+
result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session)
1022+
1023+
if result is not None:
1024+
options, model_arn, is_gated_model = result
1025+
# Without sequence_length, should pick the first matching recipe
1026+
mock_s3.get_object.assert_called_once()
1027+
call_args = mock_s3.get_object.call_args[1]
1028+
assert "params-first" in call_args["Key"]

0 commit comments

Comments
 (0)