From 5ed6bbc01fbaac27ebd88215a771504e80d0f4ed Mon Sep 17 00:00:00 2001 From: Varun Morishetty Date: Mon, 13 Apr 2026 16:18:20 +0000 Subject: [PATCH] fix: pass hub_arn to get_init_kwargs in _build_for_jumpstart The _build_for_jumpstart method was not forwarding hub_arn to get_init_kwargs, causing private hub models to fail during model creation. Without hub_arn, the SDK could not resolve the correct account ID in the hub content ARN, resulting in 'hub content not found' errors for private hubs. Public hub models are unaffected since hub_arn is None for them. --- .../sagemaker/serve/model_builder_servers.py | 1 + .../test_model_builder_servers_coverage.py | 62 +++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index 6619608395..6c689a7825 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -867,6 +867,7 @@ def _build_for_jumpstart(self) -> Model: model_version=self.model_version or "*", region=self.region, instance_type=self.instance_type, + hub_arn=self.hub_arn, tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None), config_name=getattr(self, 'config_name', None), diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py index 02b0962feb..7fbe6d23f3 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py @@ -393,6 +393,68 @@ def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_d call_kwargs = mock_get_kwargs.call_args self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized") + @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') + @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + def test_build_for_jumpstart_passes_hub_arn(self, mock_prepare, mock_build_djl, mock_get_kwargs): + """Test that hub_arn is forwarded to get_init_kwargs for private hub models.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + mock_init_kwargs.env = {} + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock(spec=Model) + mock_build_djl.return_value = mock_model + + private_hub_arn = "arn:aws:sagemaker:us-west-2:052150106756:hub/MyPrivateHub2" + + builder = ModelBuilder( + model="huggingface-vlm-qwen3-5-27b", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT + ) + builder._optimizing = False + builder.hub_arn = private_hub_arn + + builder._build_for_jumpstart() + + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + self.assertEqual( + call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn"), + private_hub_arn, + ) + + @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') + @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart') + @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + def test_build_for_jumpstart_hub_arn_none_for_public_hub(self, mock_prepare, mock_build_djl, mock_get_kwargs): + """Test that hub_arn is None for public hub models (no regression).""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + mock_init_kwargs.env = {} + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock(spec=Model) + mock_build_djl.return_value = mock_model + + builder = ModelBuilder( + model="meta-textgeneration-llama-3-3-70b-instruct", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT + ) + builder._optimizing = False + + builder._build_for_jumpstart() + + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + self.assertIsNone( + call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn"), + ) + @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart') @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')