Skip to content

Commit 5ed6bbc

Browse files
committed
fix: pass hub_arn to get_init_kwargs in _build_for_jumpstart
The _build_for_jumpstart method was not forwarding hub_arn to get_init_kwargs, causing private hub models to fail during model creation. Without hub_arn, the SDK could not resolve the correct account ID in the hub content ARN, resulting in 'hub content not found' errors for private hubs. Public hub models are unaffected since hub_arn is None for them.
1 parent 6134e57 commit 5ed6bbc

2 files changed

Lines changed: 63 additions & 0 deletions

File tree

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,7 @@ def _build_for_jumpstart(self) -> Model:
867867
model_version=self.model_version or "*",
868868
region=self.region,
869869
instance_type=self.instance_type,
870+
hub_arn=self.hub_arn,
870871
tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None),
871872
tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None),
872873
config_name=getattr(self, 'config_name', None),

sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,68 @@ def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_d
393393
call_kwargs = mock_get_kwargs.call_args
394394
self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized")
395395

396+
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
397+
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
398+
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
399+
def test_build_for_jumpstart_passes_hub_arn(self, mock_prepare, mock_build_djl, mock_get_kwargs):
400+
"""Test that hub_arn is forwarded to get_init_kwargs for private hub models."""
401+
mock_init_kwargs = Mock()
402+
mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
403+
mock_init_kwargs.env = {}
404+
mock_get_kwargs.return_value = mock_init_kwargs
405+
406+
mock_model = Mock(spec=Model)
407+
mock_build_djl.return_value = mock_model
408+
409+
private_hub_arn = "arn:aws:sagemaker:us-west-2:052150106756:hub/MyPrivateHub2"
410+
411+
builder = ModelBuilder(
412+
model="huggingface-vlm-qwen3-5-27b",
413+
role_arn=MOCK_ROLE_ARN,
414+
sagemaker_session=self.mock_session,
415+
mode=Mode.SAGEMAKER_ENDPOINT
416+
)
417+
builder._optimizing = False
418+
builder.hub_arn = private_hub_arn
419+
420+
builder._build_for_jumpstart()
421+
422+
mock_get_kwargs.assert_called_once()
423+
call_kwargs = mock_get_kwargs.call_args
424+
self.assertEqual(
425+
call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn"),
426+
private_hub_arn,
427+
)
428+
429+
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
430+
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart')
431+
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')
432+
def test_build_for_jumpstart_hub_arn_none_for_public_hub(self, mock_prepare, mock_build_djl, mock_get_kwargs):
433+
"""Test that hub_arn is None for public hub models (no regression)."""
434+
mock_init_kwargs = Mock()
435+
mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
436+
mock_init_kwargs.env = {}
437+
mock_get_kwargs.return_value = mock_init_kwargs
438+
439+
mock_model = Mock(spec=Model)
440+
mock_build_djl.return_value = mock_model
441+
442+
builder = ModelBuilder(
443+
model="meta-textgeneration-llama-3-3-70b-instruct",
444+
role_arn=MOCK_ROLE_ARN,
445+
sagemaker_session=self.mock_session,
446+
mode=Mode.SAGEMAKER_ENDPOINT
447+
)
448+
builder._optimizing = False
449+
450+
builder._build_for_jumpstart()
451+
452+
mock_get_kwargs.assert_called_once()
453+
call_kwargs = mock_get_kwargs.call_args
454+
self.assertIsNone(
455+
call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn"),
456+
)
457+
396458
@patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs')
397459
@patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart')
398460
@patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode')

0 commit comments

Comments
 (0)