@@ -393,6 +393,68 @@ def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_d
393393 call_kwargs = mock_get_kwargs .call_args
394394 self .assertEqual (call_kwargs .kwargs .get ("config_name" ) or call_kwargs [1 ].get ("config_name" ), "lmi-optimized" )
395395
396+ @patch ('sagemaker.core.jumpstart.factory.utils.get_init_kwargs' )
397+ @patch ('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart' )
398+ @patch ('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode' )
399+ def test_build_for_jumpstart_passes_hub_arn (self , mock_prepare , mock_build_djl , mock_get_kwargs ):
400+ """Test that hub_arn is forwarded to get_init_kwargs for private hub models."""
401+ mock_init_kwargs = Mock ()
402+ mock_init_kwargs .image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
403+ mock_init_kwargs .env = {}
404+ mock_get_kwargs .return_value = mock_init_kwargs
405+
406+ mock_model = Mock (spec = Model )
407+ mock_build_djl .return_value = mock_model
408+
409+ private_hub_arn = "arn:aws:sagemaker:us-west-2:052150106756:hub/MyPrivateHub2"
410+
411+ builder = ModelBuilder (
412+ model = "huggingface-vlm-qwen3-5-27b" ,
413+ role_arn = MOCK_ROLE_ARN ,
414+ sagemaker_session = self .mock_session ,
415+ mode = Mode .SAGEMAKER_ENDPOINT
416+ )
417+ builder ._optimizing = False
418+ builder .hub_arn = private_hub_arn
419+
420+ builder ._build_for_jumpstart ()
421+
422+ mock_get_kwargs .assert_called_once ()
423+ call_kwargs = mock_get_kwargs .call_args
424+ self .assertEqual (
425+ call_kwargs .kwargs .get ("hub_arn" ) or call_kwargs [1 ].get ("hub_arn" ),
426+ private_hub_arn ,
427+ )
428+
429+ @patch ('sagemaker.core.jumpstart.factory.utils.get_init_kwargs' )
430+ @patch ('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart' )
431+ @patch ('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode' )
432+ def test_build_for_jumpstart_hub_arn_none_for_public_hub (self , mock_prepare , mock_build_djl , mock_get_kwargs ):
433+ """Test that hub_arn is None for public hub models (no regression)."""
434+ mock_init_kwargs = Mock ()
435+ mock_init_kwargs .image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
436+ mock_init_kwargs .env = {}
437+ mock_get_kwargs .return_value = mock_init_kwargs
438+
439+ mock_model = Mock (spec = Model )
440+ mock_build_djl .return_value = mock_model
441+
442+ builder = ModelBuilder (
443+ model = "meta-textgeneration-llama-3-3-70b-instruct" ,
444+ role_arn = MOCK_ROLE_ARN ,
445+ sagemaker_session = self .mock_session ,
446+ mode = Mode .SAGEMAKER_ENDPOINT
447+ )
448+ builder ._optimizing = False
449+
450+ builder ._build_for_jumpstart ()
451+
452+ mock_get_kwargs .assert_called_once ()
453+ call_kwargs = mock_get_kwargs .call_args
454+ self .assertIsNone (
455+ call_kwargs .kwargs .get ("hub_arn" ) or call_kwargs [1 ].get ("hub_arn" ),
456+ )
457+
396458 @patch ('sagemaker.core.jumpstart.factory.utils.get_init_kwargs' )
397459 @patch ('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart' )
398460 @patch ('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode' )
0 commit comments