Skip to content

Commit f0a4618

Browse files
mujtaba1747, mollyheamazon, and nargokul
authored
fix: respect accept_eula in ModelBuilder LoRA deployment path (#5705)
* Update accept_eula to respect user setup

* Enable EULA acceptance in model customization tests

  Set accept_eula to True in model builder to fix tests

* fix: add missing model_path attr in TestLoraAcceptEula

  To fix failing unit tests in PR: #5696

* fix(tests): fix TestLoraAcceptEula missing dataclass attrs and patches

---------

Co-authored-by: Molly He <mollyhe@amazon.com>
Co-authored-by: Gokul Anantha Narayanan <166456257+nargokul@users.noreply.github.com>
1 parent a29d22c commit f0a4618

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2373,6 +2373,13 @@ def _build_single_modelbuilder(
23732373
"HostingArtifactUri not found in JumpStart hub metadata. "
23742374
"Cannot deploy LORA adapter without base model artifacts."
23752375
)
2376+
accept_eula = getattr(self, "accept_eula", None)
2377+
if not accept_eula:
2378+
raise ValueError(
2379+
"accept_eula must be set to True to deploy this model. "
2380+
"Please set accept_eula=True on the ModelBuilder instance to confirm "
2381+
"you have read and accepted the end-user license agreement for this model."
2382+
)
23762383
container_def = ContainerDefinition(
23772384
image=self.image_uri,
23782385
environment=self.env_vars,
@@ -2381,7 +2388,7 @@ def _build_single_modelbuilder(
23812388
"s3_uri": hosting_artifact_uri,
23822389
"s3_data_type": "S3Prefix",
23832390
"compression_type": "None",
2384-
"model_access_config": {"accept_eula": True},
2391+
"model_access_config": {"accept_eula": accept_eula},
23852392
}
23862393
},
23872394
)

sagemaker-serve/tests/integ/test_model_customization_deployment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def test_build_from_training_job(self, training_job_name):
119119

120120
training_job = TrainingJob.get(training_job_name=training_job_name)
121121
model_builder = ModelBuilder(model=training_job)
122+
model_builder.accept_eula = True
122123
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
123124

124125
assert model is not None
@@ -139,6 +140,7 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
139140

140141
training_job = TrainingJob.get(training_job_name=training_job_name)
141142
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
143+
model_builder.accept_eula = True
142144
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
143145

144146
peft_type = model_builder._fetch_peft()
@@ -187,6 +189,7 @@ def test_build_from_model_package(self, model_package_arn):
187189

188190
model_package = ModelPackage.get(model_package_name=model_package_arn)
189191
model_builder = ModelBuilder(model=model_package)
192+
model_builder.accept_eula = True
190193
model = model_builder.build()
191194

192195
assert model is not None
@@ -201,6 +204,7 @@ def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
201204
model_package = ModelPackage.get(model_package_name=model_package_arn)
202205
endpoint_name = f"e2e-{int(time.time())}-{random.randint(100, 10000)}"
203206
model_builder = ModelBuilder(model=model_package)
207+
model_builder.accept_eula = True
204208
model_builder.build()
205209
endpoint = model_builder.deploy(endpoint_name=endpoint_name)
206210

@@ -220,6 +224,7 @@ def test_instance_type_from_recipe(self, training_job_name):
220224

221225
training_job = TrainingJob.get(training_job_name=training_job_name)
222226
model_builder = ModelBuilder(model=training_job)
227+
model_builder.accept_eula = True
223228
model_builder.build()
224229

225230
assert model_builder.instance_type is not None

sagemaker-serve/tests/unit/test_model_builder.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,3 +715,96 @@ def test_deploy_passes_inference_config_to_model_customization(self):
715715
call_kwargs = mock_deploy_mc.call_args[1]
716716
self.assertEqual(call_kwargs['inference_config'], inference_config)
717717
self.assertEqual(result, mock_endpoint)
718+
719+
720+
class TestLoraAcceptEula(unittest.TestCase):
    """Tests for accept_eula handling in the LoRA deployment path."""

    def _make_mb(self, accept_eula=None):
        """Return a ModelBuilder created without running ``__init__``.

        ``ModelBuilder.__new__`` is used so no constructor logic executes;
        the attributes below are stubbed by hand so that
        ``_build_single_modelbuilder`` can be called on the instance.

        Args:
            accept_eula: Value assigned to ``mb.accept_eula``.
        """
        mb = ModelBuilder.__new__(ModelBuilder)
        mb.accept_eula = accept_eula
        mb.image_uri = "some-image-uri"
        mb.env_vars = {}
        mb.model_name = None
        mb.model_path = "/tmp/fake-model-path"
        mb.role_arn = "arn:aws:iam::123456789012:role/role"
        mb.model = MagicMock()
        mb._adapter_s3_uri = None
        mb.shared_libs = []
        mb.dependencies = {"auto": True}
        mb.image_config = None
        mb.inference_spec = None
        mb.schema_builder = None
        mb.modelbuilder_list = None
        mb.sagemaker_session = None
        mb.s3_model_data_url = None
        mb.source_code = None
        mb.model_server = None
        mb.model_metadata = None
        mb.log_level = None
        mb.content_type = None
        mb.accept_type = None
        mb.compute = None
        mb.network = None
        mb.instance_type = None
        mb.mode = None
        return mb

    def _patch_lora_deps(self, mb, hosting_uri="s3://bucket/hosting/"):
        """Return unstarted patchers needed to reach the LoRA ContainerDefinition block.

        Args:
            mb: The ModelBuilder instance whose helper methods are patched.
            hosting_uri: Value returned as ``HostingArtifactUri`` by the
                patched hub-document lookup.

        Returns:
            A list of ``patch.object`` patchers; none of them is started.
        """
        return [
            patch.object(mb, "_get_serve_setting", return_value=MagicMock()),
            patch.object(mb, "_is_model_customization", return_value=True),
            patch.object(mb, "_fetch_model_package", return_value=MagicMock()),
            patch.object(mb, "_fetch_and_cache_recipe_config"),
            patch.object(mb, "_is_nova_model", return_value=False),
            patch.object(mb, "_fetch_peft", return_value="LORA"),
            patch.object(
                mb,
                "_fetch_hub_document_for_custom_model",
                return_value={"HostingArtifactUri": hosting_uri},
            ),
        ]

    def _start_lora_patches(self, mb):
        """Start every LoRA patcher and guarantee each one is stopped.

        The cleanup is registered right after each successful ``start()``,
        so if a later patcher fails to start, the ones already active are
        still undone instead of leaking into other tests.
        """
        for patcher in self._patch_lora_deps(mb):
            patcher.start()
            self.addCleanup(patcher.stop)

    def _assert_build_rejected(self, accept_eula):
        """Assert the LoRA build raises ValueError mentioning ``accept_eula``."""
        mb = self._make_mb(accept_eula=accept_eula)
        self._start_lora_patches(mb)
        with self.assertRaises(ValueError) as ctx:
            mb._build_single_modelbuilder()
        self.assertIn("accept_eula", str(ctx.exception))

    def test_lora_build_raises_when_accept_eula_false(self):
        self._assert_build_rejected(False)

    def test_lora_build_raises_when_accept_eula_not_set(self):
        self._assert_build_rejected(None)

    @patch("sagemaker.serve.model_builder.ContainerDefinition")
    @patch("sagemaker.serve.model_builder.Model")
    def test_lora_build_passes_accept_eula_true(self, mock_model, mock_container_def):
        mb = self._make_mb(accept_eula=True)
        mock_model.create.return_value = MagicMock()
        self._start_lora_patches(mb)

        mb._build_single_modelbuilder()

        call_kwargs = mock_container_def.call_args[1]
        s3_data_source = call_kwargs["model_data_source"]["s3_data_source"]
        eula_val = s3_data_source["model_access_config"]["accept_eula"]
        # Identity check: the user's literal True must be forwarded, not
        # merely some truthy stand-in.
        self.assertIs(eula_val, True)

0 commit comments

Comments
 (0)