Skip to content

Commit 9fc50b5

Browse files
Fix VolumeSizeInGB missing from v3 deploy for models with inference_v… (#5847)
* Fix VolumeSizeInGB missing from v3 deploy for models with inference_volume_size from_jumpstart_config() retrieves deploy kwargs (which includes volume_size) but never persisted it on the ModelBuilder instance. Then _deploy_core_endpoint read volume_size from kwargs with no self.* fallback, so it was always None. Two-line fix in model_builder.py: 1. from_jumpstart_config(): persist volume_size from deploy_kwargs (same as model_data_download_timeout and container_startup_health_check_timeout) 2. _deploy_core_endpoint(): fall back to self.volume_size when not in kwargs Same pattern as EnableNetworkIsolation fix in v3.10.1. * Fix integ test: add accept_eula=True for Llama model --------- Co-authored-by: Liam Neal Reilly <lnealrei@amazon.com>
1 parent 1572b32 commit 9fc50b5

3 files changed

Lines changed: 181 additions & 6 deletions

File tree

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2774,7 +2774,7 @@ def _deploy_core_endpoint(self, **kwargs):
27742774
self._deserializer = deserializer
27752775

27762776
data_capture_config = kwargs.get("data_capture_config", None)
2777-
volume_size = kwargs.get("volume_size", None)
2777+
volume_size = kwargs.get("volume_size", getattr(self, "volume_size", None))
27782778
inference_recommendation_id = kwargs.get("inference_recommendation_id", None)
27792779
explainer_config = kwargs.get("explainer_config", None)
27802780
endpoint_logging = kwargs.get("endpoint_logging", False)
@@ -3620,6 +3620,7 @@ def from_jumpstart_config(
36203620
"container_startup_health_check_timeout"
36213621
)
36223622
mb_instance.inference_ami_version = deploy_kwargs.get("inference_ami_version")
3623+
mb_instance.volume_size = deploy_kwargs.get("volume_size")
36233624

36243625
# Apply network isolation from JumpStart model spec if not set by user via network param
36253626
if not mb_instance._enable_network_isolation and deploy_kwargs.get(

sagemaker-serve/tests/integ/test_jumpstart_network_isolation.py renamed to sagemaker-serve/tests/integ/test_jumpstart_deploy_parity.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,53 @@ def test_jumpstart_build_enables_network_isolation():
6565
finally:
6666
core_model.delete()
6767
logger.info("Model deleted.")
68+
69+
70+
VOLUME_SIZE_MODEL_ID = "meta-textgenerationneuron-llama-2-7b"
71+
VOLUME_SIZE_INSTANCE_TYPE = "ml.inf2.xlarge"
72+
73+
74+
@pytest.mark.slow_test
75+
def test_jumpstart_build_sets_volume_size():
76+
"""Integration test verifying volume_size from model specs is propagated.
77+
78+
JumpStart model specs define inference_volume_size for models that need
79+
large EBS volumes for model weights. This test validates that ModelBuilder
80+
propagates volume_size through both from_jumpstart_config() and build() paths,
81+
matching v2 behavior where VolumeSizeInGB appears in CreateEndpointConfig.
82+
"""
83+
logger.info("Starting JumpStart volume_size integration test...")
84+
85+
# Test from_jumpstart_config path
86+
compute = Compute(instance_type=VOLUME_SIZE_INSTANCE_TYPE)
87+
jumpstart_config = JumpStartConfig(model_id=VOLUME_SIZE_MODEL_ID, accept_eula=True)
88+
model_builder = ModelBuilder.from_jumpstart_config(
89+
jumpstart_config=jumpstart_config, compute=compute
90+
)
91+
92+
assert getattr(model_builder, "volume_size", None) is not None, (
93+
f"ModelBuilder.volume_size should be set after from_jumpstart_config() "
94+
f"for model {VOLUME_SIZE_MODEL_ID} on {VOLUME_SIZE_INSTANCE_TYPE}, got None"
95+
)
96+
logger.info(f"from_jumpstart_config set volume_size={model_builder.volume_size}")
97+
98+
# Test build path (also sets volume_size via _build_for_jumpstart)
99+
unique_id = str(uuid.uuid4())[:8]
100+
core_model = model_builder.build(model_name=f"js-volsize-test-{unique_id}")
101+
logger.info(f"Model created: {core_model.model_name}")
102+
103+
try:
104+
assert getattr(model_builder, "volume_size", None) is not None, (
105+
f"ModelBuilder.volume_size should persist after build() "
106+
f"for model {VOLUME_SIZE_MODEL_ID}, got None"
107+
)
108+
assert model_builder.volume_size >= 256, (
109+
f"volume_size should be >= 256, "
110+
f"got {model_builder.volume_size}"
111+
)
112+
logger.info(
113+
f"✅ volume_size={model_builder.volume_size} correctly set"
114+
)
115+
finally:
116+
core_model.delete()
117+
logger.info("Model deleted.")

sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from sagemaker.serve.mode.function_pointers import Mode
1313
from sagemaker.serve.utils.types import ModelServer
1414
from sagemaker.core.training.configs import Compute, Networking
15+
from sagemaker.core.jumpstart.configs import JumpStartConfig
16+
from sagemaker.core.inference_config import AsyncInferenceConfig
17+
from botocore.exceptions import ClientError
1518

1619

1720
class TestModelBuilderInit(unittest.TestCase):
@@ -189,7 +192,6 @@ class TestBuildDefaultAsyncInferenceConfig(unittest.TestCase):
189192

190193
def test_build_default_async_config(self):
191194
"""Test building default async inference config."""
192-
from sagemaker.core.inference_config import AsyncInferenceConfig
193195

194196
mb = ModelBuilder(model=Mock())
195197
mb.model_name = "test-model"
@@ -256,7 +258,6 @@ def test_does_ic_exist_true(self):
256258

257259
def test_does_ic_exist_false(self):
258260
"""Test IC doesn't exist."""
259-
from botocore.exceptions import ClientError
260261

261262
mb = ModelBuilder(model=Mock())
262263
mb.sagemaker_session = Mock()
@@ -366,7 +367,6 @@ class TestFromJumpStartConfig(unittest.TestCase):
366367

367368
def test_from_jumpstart_config_basic(self):
368369
"""Test creating ModelBuilder from JumpStart config."""
369-
from sagemaker.core.jumpstart.configs import JumpStartConfig
370370

371371
js_config = JumpStartConfig(
372372
model_id="test-model",
@@ -384,8 +384,6 @@ def test_from_jumpstart_config_basic(self):
384384
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
385385
def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwargs):
386386
"""Test that enable_network_isolation from deploy kwargs is applied."""
387-
from sagemaker.core.jumpstart.configs import JumpStartConfig
388-
from sagemaker.core.training.configs import Compute
389387

390388
mock_deploy_kwargs.return_value = {
391389
"model_data_download_timeout": 600,
@@ -409,6 +407,132 @@ def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwarg
409407

410408
self.assertTrue(mb._enable_network_isolation)
411409

410+
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
411+
def test_from_jumpstart_config_applies_volume_size(self, mock_deploy_kwargs):
412+
"""Test that volume_size from deploy kwargs is applied."""
413+
414+
mock_deploy_kwargs.return_value = {
415+
"model_data_download_timeout": 600,
416+
"volume_size": 256,
417+
}
418+
419+
js_config = JumpStartConfig(
420+
model_id="meta-textgenerationneuron-llama-2-7b",
421+
model_version="1.0.0"
422+
)
423+
424+
mock_session = Mock()
425+
mock_session.boto_region_name = "us-west-2"
426+
427+
mb = ModelBuilder.from_jumpstart_config(
428+
jumpstart_config=js_config,
429+
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
430+
compute=Compute(instance_type="ml.inf2.xlarge"),
431+
sagemaker_session=mock_session,
432+
)
433+
434+
self.assertEqual(mb.volume_size, 256)
435+
436+
@patch("sagemaker.serve.model_builder.Endpoint.get")
437+
@patch("sagemaker.serve.model_builder.session_helper.production_variant")
438+
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
439+
def test_deploy_passes_volume_size_to_production_variant(
440+
self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get
441+
):
442+
"""Test that volume_size kwarg passed to deploy() reaches production_variant."""
443+
444+
mock_deploy_kwargs.return_value = {"volume_size": 256}
445+
mock_prod_variant.return_value = {"VariantName": "AllTraffic"}
446+
mock_endpoint_get.return_value = Mock()
447+
448+
js_config = JumpStartConfig(
449+
model_id="meta-textgenerationneuron-llama-2-7b",
450+
model_version="1.0.0",
451+
)
452+
453+
mock_session = Mock()
454+
mock_session.boto_region_name = "us-west-2"
455+
mock_session.endpoint_in_service_or_not = Mock(return_value=False)
456+
mock_session.endpoint_from_production_variants = Mock()
457+
mock_session.sagemaker_config = {}
458+
mock_session.settings = Mock()
459+
mock_session.settings.include_jumpstart_tags = False
460+
mock_session._append_sagemaker_config_tags = Mock(return_value=[])
461+
462+
mb = ModelBuilder.from_jumpstart_config(
463+
jumpstart_config=js_config,
464+
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
465+
compute=Compute(instance_type="ml.inf2.xlarge"),
466+
sagemaker_session=mock_session,
467+
)
468+
mb.built_model = Mock()
469+
mb.built_model.model_name = "test-model"
470+
mb.model_server = None
471+
mb.mode = Mode.SAGEMAKER_ENDPOINT
472+
473+
# Deploy with explicit volume_size=512 overriding spec's 256
474+
mb.deploy(
475+
endpoint_name="test-ep",
476+
instance_type="ml.inf2.xlarge",
477+
initial_instance_count=1,
478+
volume_size=512,
479+
wait=False,
480+
)
481+
482+
# Verify production_variant was called with user's 512, not spec's 256
483+
mock_prod_variant.assert_called_once()
484+
call_kwargs = mock_prod_variant.call_args[1]
485+
self.assertEqual(call_kwargs["volume_size"], 512)
486+
487+
@patch("sagemaker.serve.model_builder.Endpoint.get")
488+
@patch("sagemaker.serve.model_builder.session_helper.production_variant")
489+
@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
490+
def test_deploy_uses_spec_volume_size_when_not_passed(
491+
self, mock_deploy_kwargs, mock_prod_variant, mock_endpoint_get
492+
):
493+
"""Test that volume_size from spec is used when customer doesn't pass it."""
494+
495+
mock_deploy_kwargs.return_value = {"volume_size": 256}
496+
mock_prod_variant.return_value = {"VariantName": "AllTraffic"}
497+
mock_endpoint_get.return_value = Mock()
498+
499+
js_config = JumpStartConfig(
500+
model_id="meta-textgenerationneuron-llama-2-7b",
501+
model_version="1.0.0",
502+
)
503+
504+
mock_session = Mock()
505+
mock_session.boto_region_name = "us-west-2"
506+
mock_session.endpoint_in_service_or_not = Mock(return_value=False)
507+
mock_session.endpoint_from_production_variants = Mock()
508+
mock_session.sagemaker_config = {}
509+
mock_session.settings = Mock()
510+
mock_session.settings.include_jumpstart_tags = False
511+
mock_session._append_sagemaker_config_tags = Mock(return_value=[])
512+
513+
mb = ModelBuilder.from_jumpstart_config(
514+
jumpstart_config=js_config,
515+
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
516+
compute=Compute(instance_type="ml.inf2.xlarge"),
517+
sagemaker_session=mock_session,
518+
)
519+
mb.built_model = Mock()
520+
mb.built_model.model_name = "test-model"
521+
mb.model_server = None
522+
mb.mode = Mode.SAGEMAKER_ENDPOINT
523+
524+
# Deploy WITHOUT passing volume_size — should use spec's 256
525+
mb.deploy(
526+
endpoint_name="test-ep",
527+
instance_type="ml.inf2.xlarge",
528+
initial_instance_count=1,
529+
wait=False,
530+
)
531+
532+
mock_prod_variant.assert_called_once()
533+
call_kwargs = mock_prod_variant.call_args[1]
534+
self.assertEqual(call_kwargs["volume_size"], 256)
535+
412536

413537
if __name__ == "__main__":
414538
unittest.main()

0 commit comments

Comments
 (0)