1212from sagemaker .serve .mode .function_pointers import Mode
1313from sagemaker .serve .utils .types import ModelServer
1414from sagemaker .core .training .configs import Compute , Networking
15+ from sagemaker .core .jumpstart .configs import JumpStartConfig
16+ from sagemaker .core .inference_config import AsyncInferenceConfig
17+ from botocore .exceptions import ClientError
1518
1619
1720class TestModelBuilderInit (unittest .TestCase ):
@@ -189,7 +192,6 @@ class TestBuildDefaultAsyncInferenceConfig(unittest.TestCase):
189192
190193 def test_build_default_async_config (self ):
191194 """Test building default async inference config."""
192- from sagemaker .core .inference_config import AsyncInferenceConfig
193195
194196 mb = ModelBuilder (model = Mock ())
195197 mb .model_name = "test-model"
@@ -256,7 +258,6 @@ def test_does_ic_exist_true(self):
256258
257259 def test_does_ic_exist_false (self ):
258260 """Test IC doesn't exist."""
259- from botocore .exceptions import ClientError
260261
261262 mb = ModelBuilder (model = Mock ())
262263 mb .sagemaker_session = Mock ()
@@ -366,7 +367,6 @@ class TestFromJumpStartConfig(unittest.TestCase):
366367
367368 def test_from_jumpstart_config_basic (self ):
368369 """Test creating ModelBuilder from JumpStart config."""
369- from sagemaker .core .jumpstart .configs import JumpStartConfig
370370
371371 js_config = JumpStartConfig (
372372 model_id = "test-model" ,
@@ -384,8 +384,6 @@ def test_from_jumpstart_config_basic(self):
384384 @patch ("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs" )
385385 def test_from_jumpstart_config_applies_network_isolation (self , mock_deploy_kwargs ):
386386 """Test that enable_network_isolation from deploy kwargs is applied."""
387- from sagemaker .core .jumpstart .configs import JumpStartConfig
388- from sagemaker .core .training .configs import Compute
389387
390388 mock_deploy_kwargs .return_value = {
391389 "model_data_download_timeout" : 600 ,
@@ -409,6 +407,132 @@ def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwarg
409407
410408 self .assertTrue (mb ._enable_network_isolation )
411409
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_from_jumpstart_config_applies_volume_size(self, mock_deploy_kwargs):
    """Check that a spec-provided ``volume_size`` ends up on the builder.

    The patched ``_retrieve_model_deploy_kwargs`` returns a ``volume_size``
    of 256; after ``ModelBuilder.from_jumpstart_config`` runs, the builder's
    ``volume_size`` attribute is expected to equal that value.
    """
    # Deploy kwargs the (patched) JumpStart lookup hands back.
    spec_deploy_kwargs = {
        "model_data_download_timeout": 600,
        "volume_size": 256,
    }
    mock_deploy_kwargs.return_value = spec_deploy_kwargs

    jumpstart_cfg = JumpStartConfig(
        model_id="meta-textgenerationneuron-llama-2-7b",
        model_version="1.0.0",
    )

    # Only the region is configured explicitly on the fake session.
    fake_session = Mock(boto_region_name="us-west-2")

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=jumpstart_cfg,
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=Compute(instance_type="ml.inf2.xlarge"),
        sagemaker_session=fake_session,
    )

    self.assertEqual(builder.volume_size, 256)

435+
@patch("sagemaker.serve.model_builder.Endpoint.get")
@patch("sagemaker.serve.model_builder.session_helper.production_variant")
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_deploy_passes_volume_size_to_production_variant(
    self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get
):
    """Check that an explicit ``volume_size`` given to ``deploy()`` wins.

    The patched spec lookup reports 256, the caller passes 512 to
    ``deploy()``, and the single ``production_variant`` call is expected to
    receive 512 as its ``volume_size`` keyword argument.
    """
    # Canned return values for the three patched collaborators.
    mock_deploy_kwargs.return_value = {"volume_size": 256}
    mock_prod_variant.return_value = {"VariantName": "AllTraffic"}
    mock_endpoint_get.return_value = Mock()

    jumpstart_cfg = JumpStartConfig(
        model_id="meta-textgenerationneuron-llama-2-7b",
        model_version="1.0.0",
    )

    # Fake session carrying the attributes this test configures up front.
    fake_session = Mock(
        boto_region_name="us-west-2",
        endpoint_in_service_or_not=Mock(return_value=False),
        endpoint_from_production_variants=Mock(),
        sagemaker_config={},
        settings=Mock(include_jumpstart_tags=False),
        _append_sagemaker_config_tags=Mock(return_value=[]),
    )

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=jumpstart_cfg,
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=Compute(instance_type="ml.inf2.xlarge"),
        sagemaker_session=fake_session,
    )
    # Put the builder into an already-built state targeting a SageMaker endpoint.
    builder.built_model = Mock()
    builder.built_model.model_name = "test-model"
    builder.model_server = None
    builder.mode = Mode.SAGEMAKER_ENDPOINT

    # Deploy with explicit volume_size=512 overriding spec's 256
    deploy_args = {
        "endpoint_name": "test-ep",
        "instance_type": "ml.inf2.xlarge",
        "initial_instance_count": 1,
        "volume_size": 512,
        "wait": False,
    }
    builder.deploy(**deploy_args)

    # Verify production_variant was called with user's 512, not spec's 256
    mock_prod_variant.assert_called_once()
    _positional, variant_kwargs = mock_prod_variant.call_args
    self.assertEqual(variant_kwargs["volume_size"], 512)

486+
@patch("sagemaker.serve.model_builder.Endpoint.get")
@patch("sagemaker.serve.model_builder.session_helper.production_variant")
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_deploy_uses_spec_volume_size_when_not_passed(
    self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get
):
    """Check that the spec's ``volume_size`` is the fallback for ``deploy()``.

    The patched spec lookup reports 256 and the caller omits
    ``volume_size`` from ``deploy()``; the single ``production_variant``
    call is expected to receive 256 as its ``volume_size`` keyword argument.
    """
    # Canned return values for the three patched collaborators.
    mock_deploy_kwargs.return_value = {"volume_size": 256}
    mock_prod_variant.return_value = {"VariantName": "AllTraffic"}
    mock_endpoint_get.return_value = Mock()

    jumpstart_cfg = JumpStartConfig(
        model_id="meta-textgenerationneuron-llama-2-7b",
        model_version="1.0.0",
    )

    # Fake session carrying the attributes this test configures up front.
    fake_session = Mock(
        boto_region_name="us-west-2",
        endpoint_in_service_or_not=Mock(return_value=False),
        endpoint_from_production_variants=Mock(),
        sagemaker_config={},
        settings=Mock(include_jumpstart_tags=False),
        _append_sagemaker_config_tags=Mock(return_value=[]),
    )

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=jumpstart_cfg,
        role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
        compute=Compute(instance_type="ml.inf2.xlarge"),
        sagemaker_session=fake_session,
    )
    # Put the builder into an already-built state targeting a SageMaker endpoint.
    builder.built_model = Mock()
    builder.built_model.model_name = "test-model"
    builder.model_server = None
    builder.mode = Mode.SAGEMAKER_ENDPOINT

    # Deploy WITHOUT passing volume_size — should use spec's 256
    deploy_args = {
        "endpoint_name": "test-ep",
        "instance_type": "ml.inf2.xlarge",
        "initial_instance_count": 1,
        "wait": False,
    }
    builder.deploy(**deploy_args)

    mock_prod_variant.assert_called_once()
    _positional, variant_kwargs = mock_prod_variant.call_args
    self.assertEqual(variant_kwargs["volume_size"], 256)

535+
412536
413537if __name__ == "__main__" :
414538 unittest .main ()
0 commit comments