diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
index 6619608395..6c689a7825 100644
--- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
+++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
@@ -867,6 +867,7 @@ def _build_for_jumpstart(self) -> Model:
             model_version=self.model_version or "*",
             region=self.region,
             instance_type=self.instance_type,
+            hub_arn=self.hub_arn,
             tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
             tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None),
             config_name=getattr(self, 'config_name', None),
diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
index 02b0962feb..7fbe6d23f3 100644
--- a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
+++ b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
@@ -393,6 +393,67 @@ def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_d
         call_kwargs = mock_get_kwargs.call_args
         self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized")
 
+    @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
+    @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
+    @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
+    def test_build_for_jumpstart_passes_hub_arn(self, mock_prepare, mock_build_djl, mock_get_kwargs):
+        """Test that hub_arn is forwarded to get_init_kwargs for private hub models."""
+        mock_init_kwargs = Mock()
+        mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
+        mock_init_kwargs.env = {}
+        mock_get_kwargs.return_value = mock_init_kwargs
+
+        mock_model = Mock(spec=Model)
+        mock_build_djl.return_value = mock_model
+
+        private_hub_arn = "arn:aws:sagemaker:us-west-2:052150106756:hub/MyPrivateHub2"
+
+        builder = ModelBuilder(
+            model="huggingface-vlm-qwen3-5-27b",
+            role_arn=MOCK_ROLE_ARN,
+            sagemaker_session=self.mock_session,
+            mode=Mode.SAGEMAKER_ENDPOINT
+        )
+        builder._optimizing = False
+        builder.hub_arn = private_hub_arn
+
+        builder._build_for_jumpstart()
+
+        mock_get_kwargs.assert_called_once()
+        # call_args.kwargs and call_args[1] are the same mapping, so a single lookup is enough.
+        forwarded_kwargs = mock_get_kwargs.call_args.kwargs
+        self.assertEqual(forwarded_kwargs.get("hub_arn"), private_hub_arn)
+
+    @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
+    @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
+    @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
+    def test_build_for_jumpstart_hub_arn_none_for_public_hub(self, mock_prepare, mock_build_djl, mock_get_kwargs):
+        """Test that hub_arn is forwarded as None for public hub models (no regression)."""
+        mock_init_kwargs = Mock()
+        mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
+        mock_init_kwargs.env = {}
+        mock_get_kwargs.return_value = mock_init_kwargs
+
+        mock_model = Mock(spec=Model)
+        mock_build_djl.return_value = mock_model
+
+        builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-3-70b-instruct",
+            role_arn=MOCK_ROLE_ARN,
+            sagemaker_session=self.mock_session,
+            mode=Mode.SAGEMAKER_ENDPOINT
+        )
+        builder._optimizing = False
+
+        builder._build_for_jumpstart()
+
+        mock_get_kwargs.assert_called_once()
+        # Require the keyword to be present: a bare .get() would also return None
+        # if hub_arn were never forwarded, letting this test pass vacuously.
+        forwarded_kwargs = mock_get_kwargs.call_args.kwargs
+        self.assertIn("hub_arn", forwarded_kwargs)
+        self.assertIsNone(forwarded_kwargs["hub_arn"])
+
     @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
     @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart')
     @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')