Skip to content

Commit 6f73fbf

Browse files
committed
Update accept_eula to respect user setup
1 parent ee420cc commit 6f73fbf

2 files changed

Lines changed: 79 additions & 1 deletion

File tree

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2373,6 +2373,13 @@ def _build_single_modelbuilder(
23732373
"HostingArtifactUri not found in JumpStart hub metadata. "
23742374
"Cannot deploy LORA adapter without base model artifacts."
23752375
)
2376+
accept_eula = getattr(self, "accept_eula", None)
2377+
if not accept_eula:
2378+
raise ValueError(
2379+
"accept_eula must be set to True to deploy this model. "
2380+
"Please set accept_eula=True on the ModelBuilder instance to confirm "
2381+
"you have read and accepted the end-user license agreement for this model."
2382+
)
23762383
container_def = ContainerDefinition(
23772384
image=self.image_uri,
23782385
environment=self.env_vars,
@@ -2381,7 +2388,7 @@ def _build_single_modelbuilder(
23812388
"s3_uri": hosting_artifact_uri,
23822389
"s3_data_type": "S3Prefix",
23832390
"compression_type": "None",
2384-
"model_access_config": {"accept_eula": True},
2391+
"model_access_config": {"accept_eula": accept_eula},
23852392
}
23862393
},
23872394
)

sagemaker-serve/tests/unit/test_model_builder.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,74 @@ def test_deploy_passes_inference_config_to_model_customization(self):
715715
call_kwargs = mock_deploy_mc.call_args[1]
716716
self.assertEqual(call_kwargs['inference_config'], inference_config)
717717
self.assertEqual(result, mock_endpoint)
718+
719+
720+
class TestLoraAcceptEula(unittest.TestCase):
721+
"""Tests for accept_eula handling in the LoRA deployment path."""
722+
723+
def _make_mb(self, accept_eula=None):
724+
mb = ModelBuilder.__new__(ModelBuilder)
725+
mb.accept_eula = accept_eula
726+
mb.image_uri = "some-image-uri"
727+
mb.env_vars = {}
728+
mb.model_name = None
729+
mb.role_arn = "arn:aws:iam::123456789012:role/role"
730+
mb.model = MagicMock()
731+
mb._adapter_s3_uri = None
732+
return mb
733+
734+
def _patch_lora_deps(self, mb, hosting_uri="s3://bucket/hosting/"):
735+
"""Patch all dependencies needed to reach the LoRA ContainerDefinition block."""
736+
patches = [
737+
patch.object(mb, "_fetch_peft", return_value="LORA"),
738+
patch.object(mb, "_fetch_hub_document_for_custom_model",
739+
return_value={"HostingArtifactUri": hosting_uri}),
740+
patch.object(mb, "_get_model_package_for_training_job",
741+
return_value=MagicMock()),
742+
]
743+
return patches
744+
745+
def test_lora_build_raises_when_accept_eula_false(self):
746+
mb = self._make_mb(accept_eula=False)
747+
patches = self._patch_lora_deps(mb)
748+
for p in patches:
749+
p.start()
750+
try:
751+
with self.assertRaises(ValueError) as ctx:
752+
mb._build_single_modelbuilder()
753+
self.assertIn("accept_eula", str(ctx.exception))
754+
finally:
755+
for p in patches:
756+
p.stop()
757+
758+
def test_lora_build_raises_when_accept_eula_not_set(self):
759+
mb = self._make_mb(accept_eula=None)
760+
patches = self._patch_lora_deps(mb)
761+
for p in patches:
762+
p.start()
763+
try:
764+
with self.assertRaises(ValueError) as ctx:
765+
mb._build_single_modelbuilder()
766+
self.assertIn("accept_eula", str(ctx.exception))
767+
finally:
768+
for p in patches:
769+
p.stop()
770+
771+
@patch("sagemaker.serve.model_builder.ContainerDefinition")
772+
@patch("sagemaker.serve.model_builder.Model")
773+
def test_lora_build_passes_accept_eula_true(self, mock_model, mock_container_def):
774+
mb = self._make_mb(accept_eula=True)
775+
mock_model.create.return_value = MagicMock()
776+
patches = self._patch_lora_deps(mb)
777+
for p in patches:
778+
p.start()
779+
try:
780+
mb._build_single_modelbuilder()
781+
call_kwargs = mock_container_def.call_args[1]
782+
eula_val = (
783+
call_kwargs["model_data_source"]["s3_data_source"]["model_access_config"]["accept_eula"]
784+
)
785+
self.assertTrue(eula_val)
786+
finally:
787+
for p in patches:
788+
p.stop()

0 commit comments

Comments
 (0)