Skip to content

Commit 1c6f0ef

Browse files
committed
fix: correct test imports and mock setup for sequence_length tests
1 parent 37a5996 commit 1c6f0ef

1 file changed

Lines changed: 10 additions & 13 deletions

File tree

sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
_validate_s3_path_exists,
3131
_parse_context_length
3232
)
33-
from sagemaker.core.resources import ModelPackage
34-
from sagemaker.core.utils.utils import Unassigned, ModelPackageGroup
33+
from sagemaker.core.resources import ModelPackage, ModelPackageGroup
34+
from sagemaker.core.utils.utils import Unassigned
3535
from sagemaker.ai_registry.dataset import DataSet
3636
from sagemaker.train.common import TrainingType
3737
from sagemaker.train.configs import InputData
@@ -462,7 +462,6 @@ def test__convert_input_data_to_channels(self):
462462
def test__validate_eula_for_gated_model_with_model_package(self):
463463
"""Test EULA validation returns True for ModelPackage input"""
464464
from sagemaker.core.resources import ModelPackage
465-
from sagemaker.core.utils.utils import Unassigned
466465
model_package = Mock(spec=ModelPackage)
467466

468467
result = _validate_eula_for_gated_model(model_package, False, True)
@@ -725,10 +724,14 @@ def test__parse_context_length_with_empty(self):
725724
assert _parse_context_length("") == 0
726725

727726
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')
728-
@patch('boto3.client')
729-
def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_client, mock_get_hub_content):
727+
def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_get_hub_content):
730728
mock_session = Mock()
731729
mock_session.boto_session.region_name = "us-east-1"
730+
mock_s3 = Mock()
731+
mock_s3.get_object.return_value = {
732+
"Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}'))
733+
}
734+
mock_session.boto_session.client.return_value = mock_s3
732735

733736
mock_get_hub_content.return_value = {
734737
'hub_content_arn': "arn:aws:sagemaker:us-east-1:123456789012:model/test-model",
@@ -753,19 +756,13 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_cli
753756
}
754757
}
755758

756-
mock_s3_client = Mock()
757-
mock_boto_client.return_value = mock_s3_client
758-
mock_s3_client.get_object.return_value = {
759-
"Body": Mock(read=Mock(return_value=b'{"max_length": {"default": 32768}}'))
760-
}
761-
762759
result = _get_fine_tuning_options_and_model_arn("test-model", "SFT", "LORA", mock_session, sequence_length="8K")
763760

764761
if result is not None:
765762
options, model_arn, is_gated_model = result
766763
# Should pick the 32K recipe (smallest >= 8K)
767-
mock_s3_client.get_object.assert_called_once()
768-
call_args = mock_s3_client.get_object.call_args[1]
764+
mock_s3.get_object.assert_called_once()
765+
call_args = mock_s3.get_object.call_args[1]
769766
assert "params-32k" in call_args["Key"]
770767

771768
@patch('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata')

0 commit comments

Comments
 (0)