3030 _validate_s3_path_exists ,
3131 _parse_context_length
3232)
33- from sagemaker .core .resources import ModelPackage
34- from sagemaker .core .utils .utils import Unassigned , ModelPackageGroup
33+ from sagemaker .core .resources import ModelPackage , ModelPackageGroup
34+ from sagemaker .core .utils .utils import Unassigned
3535from sagemaker .ai_registry .dataset import DataSet
3636from sagemaker .train .common import TrainingType
3737from sagemaker .train .configs import InputData
@@ -462,7 +462,6 @@ def test__convert_input_data_to_channels(self):
462462 def test__validate_eula_for_gated_model_with_model_package (self ):
463463 """Test EULA validation returns True for ModelPackage input"""
464464 from sagemaker .core .resources import ModelPackage
465- from sagemaker .core .utils .utils import Unassigned
466465 model_package = Mock (spec = ModelPackage )
467466
468467 result = _validate_eula_for_gated_model (model_package , False , True )
@@ -725,10 +724,14 @@ def test__parse_context_length_with_empty(self):
725724 assert _parse_context_length ("" ) == 0
726725
727726 @patch ('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata' )
728- @patch ('boto3.client' )
729- def test__get_fine_tuning_options_filters_by_sequence_length (self , mock_boto_client , mock_get_hub_content ):
727+ def test__get_fine_tuning_options_filters_by_sequence_length (self , mock_get_hub_content ):
730728 mock_session = Mock ()
731729 mock_session .boto_session .region_name = "us-east-1"
730+ mock_s3 = Mock ()
731+ mock_s3 .get_object .return_value = {
732+ "Body" : Mock (read = Mock (return_value = b'{"max_length": {"default": 32768}}' ))
733+ }
734+ mock_session .boto_session .client .return_value = mock_s3
732735
733736 mock_get_hub_content .return_value = {
734737 'hub_content_arn' : "arn:aws:sagemaker:us-east-1:123456789012:model/test-model" ,
@@ -753,19 +756,13 @@ def test__get_fine_tuning_options_filters_by_sequence_length(self, mock_boto_cli
753756 }
754757 }
755758
756- mock_s3_client = Mock ()
757- mock_boto_client .return_value = mock_s3_client
758- mock_s3_client .get_object .return_value = {
759- "Body" : Mock (read = Mock (return_value = b'{"max_length": {"default": 32768}}' ))
760- }
761-
762759 result = _get_fine_tuning_options_and_model_arn ("test-model" , "SFT" , "LORA" , mock_session , sequence_length = "8K" )
763760
764761 if result is not None :
765762 options , model_arn , is_gated_model = result
766763 # Should pick the 32K recipe (smallest >= 8K)
767- mock_s3_client .get_object .assert_called_once ()
768- call_args = mock_s3_client .get_object .call_args [1 ]
764+ mock_s3 .get_object .assert_called_once ()
765+ call_args = mock_s3 .get_object .call_args [1 ]
769766 assert "params-32k" in call_args ["Key" ]
770767
771768 @patch ('sagemaker.train.common_utils.finetune_utils._get_hub_content_metadata' )
0 commit comments