Skip to content

Commit 0789d88

Browse files
committed
test: add unit tests for JumpStart network isolation fix
Tests both code paths: 1. _build_for_jumpstart() - verifies enable_network_isolation is applied from init_kwargs and that user-set values are not overridden 2. from_jumpstart_config() - verifies enable_network_isolation is extracted from deploy_kwargs
1 parent 8b84605 commit 0789d88

2 files changed

Lines changed: 89 additions & 0 deletions

File tree

sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,30 @@ def test_from_jumpstart_config_basic(self):
381381
self.assertEqual(mb.model, "test-model")
382382
self.assertEqual(mb.model_version, "1.0.0")
383383

384+
@patch("sagemaker.core.jumpstart.artifacts.kwargs._retrieve_model_deploy_kwargs")
385+
def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwargs):
386+
"""Test that enable_network_isolation from deploy kwargs is applied."""
387+
from sagemaker.core.jumpstart.configs import JumpStartConfig
388+
from sagemaker.core.training.configs import Compute
389+
390+
mock_deploy_kwargs.return_value = {
391+
"model_data_download_timeout": 600,
392+
"enable_network_isolation": True,
393+
}
394+
395+
js_config = JumpStartConfig(
396+
model_id="test-model",
397+
model_version="1.0.0"
398+
)
399+
400+
mb = ModelBuilder.from_jumpstart_config(
401+
jumpstart_config=js_config,
402+
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
403+
compute=Compute(instance_type="ml.g5.xlarge"),
404+
)
405+
406+
self.assertTrue(mb._enable_network_isolation)
407+
384408

385409
if __name__ == "__main__":
386410
unittest.main()

sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,71 @@ def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_create, mock
501501
mock_create.assert_called_once()
502502

503503

504+
@patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs")
505+
@patch("sagemaker.serve.model_builder.ModelBuilder._create_model")
506+
@patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode")
507+
def test_build_for_jumpstart_applies_network_isolation_from_spec(
508+
self, mock_prepare, mock_create, mock_get_kwargs
509+
):
510+
"""Test that enable_network_isolation from JumpStart model spec is applied."""
511+
mock_init_kwargs = Mock()
512+
mock_init_kwargs.image_uri = (
513+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
514+
)
515+
mock_init_kwargs.env = {}
516+
mock_init_kwargs.model_data = "s3://jumpstart-cache/models/model.tar.gz"
517+
mock_init_kwargs.enable_network_isolation = True
518+
mock_get_kwargs.return_value = mock_init_kwargs
519+
520+
mock_model = Mock(spec=Model)
521+
mock_create.return_value = mock_model
522+
523+
builder = ModelBuilder(
524+
model="meta-textgeneration-llama-3-8b",
525+
role_arn=MOCK_ROLE_ARN,
526+
sagemaker_session=self.mock_session,
527+
mode=Mode.SAGEMAKER_ENDPOINT,
528+
)
529+
builder._optimizing = False
530+
531+
builder._build_for_jumpstart()
532+
533+
self.assertTrue(builder._enable_network_isolation)
534+
535+
@patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs")
536+
@patch("sagemaker.serve.model_builder.ModelBuilder._create_model")
537+
@patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode")
538+
def test_build_for_jumpstart_does_not_override_user_network_isolation(
539+
self, mock_prepare, mock_create, mock_get_kwargs
540+
):
541+
"""Test that user-set network isolation is not overridden by spec."""
542+
mock_init_kwargs = Mock()
543+
mock_init_kwargs.image_uri = (
544+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
545+
)
546+
mock_init_kwargs.env = {}
547+
mock_init_kwargs.model_data = "s3://jumpstart-cache/models/model.tar.gz"
548+
mock_init_kwargs.enable_network_isolation = False
549+
mock_get_kwargs.return_value = mock_init_kwargs
550+
551+
mock_model = Mock(spec=Model)
552+
mock_create.return_value = mock_model
553+
554+
builder = ModelBuilder(
555+
model="meta-textgeneration-llama-3-8b",
556+
role_arn=MOCK_ROLE_ARN,
557+
sagemaker_session=self.mock_session,
558+
mode=Mode.SAGEMAKER_ENDPOINT,
559+
)
560+
builder._optimizing = False
561+
builder._enable_network_isolation = True # User explicitly set
562+
563+
builder._build_for_jumpstart()
564+
565+
# User's True should not be overridden by spec's False
566+
self.assertTrue(builder._enable_network_isolation)
567+
568+
504569
class TestDeployWrappers(unittest.TestCase):
505570
"""Test deploy wrapper methods."""
506571

0 commit comments

Comments
 (0)