@@ -501,6 +501,71 @@ def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_create, mock
501501 mock_create .assert_called_once ()
502502
503503
504+ @patch ("sagemaker.core.jumpstart.factory.utils.get_init_kwargs" )
505+ @patch ("sagemaker.serve.model_builder.ModelBuilder._create_model" )
506+ @patch ("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode" )
507+ def test_build_for_jumpstart_applies_network_isolation_from_spec (
508+ self , mock_prepare , mock_create , mock_get_kwargs
509+ ):
510+ """Test that enable_network_isolation from JumpStart model spec is applied."""
511+ mock_init_kwargs = Mock ()
512+ mock_init_kwargs .image_uri = (
513+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
514+ )
515+ mock_init_kwargs .env = {}
516+ mock_init_kwargs .model_data = "s3://jumpstart-cache/models/model.tar.gz"
517+ mock_init_kwargs .enable_network_isolation = True
518+ mock_get_kwargs .return_value = mock_init_kwargs
519+
520+ mock_model = Mock (spec = Model )
521+ mock_create .return_value = mock_model
522+
523+ builder = ModelBuilder (
524+ model = "meta-textgeneration-llama-3-8b" ,
525+ role_arn = MOCK_ROLE_ARN ,
526+ sagemaker_session = self .mock_session ,
527+ mode = Mode .SAGEMAKER_ENDPOINT ,
528+ )
529+ builder ._optimizing = False
530+
531+ builder ._build_for_jumpstart ()
532+
533+ self .assertTrue (builder ._enable_network_isolation )
534+
535+ @patch ("sagemaker.core.jumpstart.factory.utils.get_init_kwargs" )
536+ @patch ("sagemaker.serve.model_builder.ModelBuilder._create_model" )
537+ @patch ("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode" )
538+ def test_build_for_jumpstart_does_not_override_user_network_isolation (
539+ self , mock_prepare , mock_create , mock_get_kwargs
540+ ):
541+ """Test that user-set network isolation is not overridden by spec."""
542+ mock_init_kwargs = Mock ()
543+ mock_init_kwargs .image_uri = (
544+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
545+ )
546+ mock_init_kwargs .env = {}
547+ mock_init_kwargs .model_data = "s3://jumpstart-cache/models/model.tar.gz"
548+ mock_init_kwargs .enable_network_isolation = False
549+ mock_get_kwargs .return_value = mock_init_kwargs
550+
551+ mock_model = Mock (spec = Model )
552+ mock_create .return_value = mock_model
553+
554+ builder = ModelBuilder (
555+ model = "meta-textgeneration-llama-3-8b" ,
556+ role_arn = MOCK_ROLE_ARN ,
557+ sagemaker_session = self .mock_session ,
558+ mode = Mode .SAGEMAKER_ENDPOINT ,
559+ )
560+ builder ._optimizing = False
561+ builder ._enable_network_isolation = True # User explicitly set
562+
563+ builder ._build_for_jumpstart ()
564+
565+ # User's True should not be overridden by spec's False
566+ self .assertTrue (builder ._enable_network_isolation )
567+
568+
504569class TestDeployWrappers (unittest .TestCase ):
505570 """Test deploy wrapper methods."""
506571
0 commit comments