Skip to content

Commit ebe542e

Browse files
committed
feat(bedrock): Harden BedrockModelBuilder for production readiness
Extract _is_nova_model() helper to eliminate duplicated Nova detection logic across deploy() and _get_s3_artifacts(). Uses getattr with safe defaults instead of fragile hasattr chains. Add input validation to deploy() and create_deployment(): - Raise ValueError when model_package is not set - Raise ValueError when custom_model_name or role_arn missing for Nova deployments - Raise ValueError when model_arn is empty in create_deployment Move json and urlparse imports to module level (were previously imported inside _get_checkpoint_uri_from_manifest). Replace f-string logging with lazy %s formatting throughout. Initialize status=None before the polling loop in _wait_for_model_active to avoid UnboundLocalError if the loop body never executes. Rewrite unit tests (43 tests) with full coverage: - _is_nova_model: recipe_name, hub_content_name, case insensitivity, missing base_model, None fields - __init__: None model, TrainingJob, ModelPackage - Client singletons: caching, injection - _fetch_model_package: ModelPackage, TrainingJob, ModelTrainer, unknown type - _get_s3_artifacts: None package, non-Nova, Nova delegation, Nova fallback - _get_checkpoint_uri_from_manifest: success, missing key, NoSuchKey, not TrainingJob, no artifacts, invalid JSON - _wait_for_model_active: immediate, polling, Failed, timeout - create_deployment: polling chain, extra kwargs, empty/None ARN - deploy: non-Nova, Nova full chain, hub_content_name detection, default deployment name, tags, missing params, None stripping Add integration tests for Nova E2E deployment: - Training job existence and status verification - Builder creation and Nova detection via _is_nova_model - S3 artifacts checkpoint validation - Full deploy-with-polling flow (marked @pytest.mark.slow) - Timeout behavior on bogus ARN - Validation error paths (no model_package, empty model_arn) - Resource cleanup fixture for deployments and custom models
1 parent c5e17c6 commit ebe542e

3 files changed

Lines changed: 846 additions & 320 deletions

File tree

sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py

Lines changed: 123 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
"""Holds the BedrockModelBuilder class."""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import time
1718
import logging
1819
from typing import Optional, Dict, Any, Union
20+
from urllib.parse import urlparse
1921

2022
from sagemaker.core.helper.session_helper import Session
2123
from sagemaker.core.resources import TrainingJob, ModelPackage
@@ -27,9 +29,30 @@
2729
logger = logging.getLogger(__name__)
2830

2931

32+
def _is_nova_model(container) -> bool:
33+
"""Determine whether a model package container represents a Nova model.
34+
35+
Checks both recipe_name and hub_content_name for the "nova" substring.
36+
37+
Args:
38+
container: A container from ModelPackage.inference_specification.containers.
39+
40+
Returns:
41+
True if the container represents a Nova model, False otherwise.
42+
"""
43+
base_model = getattr(container, "base_model", None)
44+
if not base_model:
45+
return False
46+
47+
recipe_name = getattr(base_model, "recipe_name", None) or ""
48+
hub_content_name = getattr(base_model, "hub_content_name", None) or ""
49+
50+
return "nova" in recipe_name.lower() or "nova" in hub_content_name.lower()
51+
52+
3053
class BedrockModelBuilder:
3154
"""Builder class for deploying models to Amazon Bedrock.
32-
55+
3356
This class provides functionality to deploy SageMaker models to Bedrock
3457
using either model import jobs or custom model creation, depending on
3558
the model type (Nova models vs. other models).
@@ -42,7 +65,8 @@ def __init__(self, model: Optional[Union[ModelTrainer, TrainingJob, ModelPackage
4265
"""Initialize BedrockModelBuilder with a model instance.
4366
4467
Args:
45-
model: The model to deploy. Can be a ModelTrainer, TrainingJob, or ModelPackage instance.
68+
model: The model to deploy. Can be a ModelTrainer, TrainingJob,
69+
or ModelPackage instance.
4670
"""
4771
self.model = model
4872
self._bedrock_client = None
@@ -53,7 +77,7 @@ def __init__(self, model: Optional[Union[ModelTrainer, TrainingJob, ModelPackage
5377

5478
def _get_bedrock_client(self):
5579
"""Get or create Bedrock client singleton.
56-
80+
5781
Returns:
5882
boto3.client: Bedrock client instance.
5983
"""
@@ -63,7 +87,7 @@ def _get_bedrock_client(self):
6387

6488
def _get_sagemaker_client(self):
6589
"""Get or create SageMaker client singleton.
66-
90+
6791
Returns:
6892
boto3.client: SageMaker client instance.
6993
"""
@@ -73,20 +97,20 @@ def _get_sagemaker_client(self):
7397

7498
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="BedrockModelBuilder.deploy")
7599
def deploy(
76-
self,
77-
job_name: Optional[str] = None,
78-
imported_model_name: Optional[str] = None,
79-
custom_model_name: Optional[str] = None,
80-
role_arn: Optional[str] = None,
81-
job_tags: Optional[list] = None,
82-
imported_model_tags: Optional[list] = None,
83-
model_tags: Optional[list] = None,
84-
client_request_token: Optional[str] = None,
85-
imported_model_kms_key_id: Optional[str] = None,
86-
deployment_name: Optional[str] = None,
100+
self,
101+
job_name: Optional[str] = None,
102+
imported_model_name: Optional[str] = None,
103+
custom_model_name: Optional[str] = None,
104+
role_arn: Optional[str] = None,
105+
job_tags: Optional[list] = None,
106+
imported_model_tags: Optional[list] = None,
107+
model_tags: Optional[list] = None,
108+
client_request_token: Optional[str] = None,
109+
imported_model_kms_key_id: Optional[str] = None,
110+
deployment_name: Optional[str] = None,
87111
) -> Dict[str, Any]:
88112
"""Deploy the model to Bedrock.
89-
113+
90114
Automatically detects if the model is a Nova model and uses the appropriate
91115
Bedrock API (create_custom_model for Nova, create_model_import_job for others).
92116
For Nova models, also creates a custom model deployment for inference.
@@ -108,19 +132,24 @@ def deploy(
108132
Response from Bedrock API. For Nova models, returns the
109133
create_custom_model_deployment response. For others, returns
110134
the create_model_import_job response.
111-
135+
112136
Raises:
113-
ValueError: If required parameters are missing for the detected model type.
137+
ValueError: If model_package is not set or required parameters are missing.
114138
"""
139+
if not self.model_package:
140+
raise ValueError(
141+
"model_package is not set. Provide a valid model during initialization."
142+
)
143+
115144
container = self.model_package.inference_specification.containers[0]
116-
is_nova = (hasattr(container, 'base_model') and container.base_model and
117-
hasattr(container.base_model, 'recipe_name') and container.base_model.recipe_name and
118-
"nova" in container.base_model.recipe_name.lower()) or \
119-
(hasattr(container, 'base_model') and container.base_model and
120-
hasattr(container.base_model, 'hub_content_name') and container.base_model.hub_content_name and
121-
"nova" in container.base_model.hub_content_name.lower())
145+
is_nova = _is_nova_model(container)
122146

123147
if is_nova:
148+
if not custom_model_name:
149+
raise ValueError("custom_model_name is required for Nova model deployment.")
150+
if not role_arn:
151+
raise ValueError("role_arn is required for Nova model deployment.")
152+
124153
params = {
125154
"modelName": custom_model_name,
126155
"modelSourceConfig": {"s3DataSource": {"s3Uri": self.s3_model_artifacts}},
@@ -129,6 +158,8 @@ def deploy(
129158
if model_tags:
130159
params["modelTags"] = model_tags
131160
params = {k: v for k, v in params.items() if v is not None}
161+
162+
logger.info("Creating custom model %s for Nova deployment", custom_model_name)
132163
create_response = self._get_bedrock_client().create_custom_model(**params)
133164

134165
model_arn = create_response.get("modelArn")
@@ -147,6 +178,8 @@ def deploy(
147178
"importedModelKmsKeyId": imported_model_kms_key_id,
148179
}
149180
params = {k: v for k, v in params.items() if v is not None}
181+
182+
logger.info("Creating model import job for non-Nova deployment")
150183
return self._get_bedrock_client().create_model_import_job(**params)
151184

152185
def create_deployment(
@@ -174,7 +207,11 @@ def create_deployment(
174207
175208
Raises:
176209
RuntimeError: If the model fails or times out waiting to become Active.
210+
ValueError: If model_arn is not provided.
177211
"""
212+
if not model_arn:
213+
raise ValueError("model_arn is required for create_deployment.")
214+
178215
self._wait_for_model_active(model_arn, poll_interval=poll_interval, max_wait=max_wait)
179216

180217
params = {
@@ -183,9 +220,13 @@ def create_deployment(
183220
**{k: v for k, v in kwargs.items() if v is not None},
184221
}
185222
params = {k: v for k, v in params.items() if v is not None}
223+
224+
logger.info("Creating deployment %s for model %s", deployment_name, model_arn)
186225
return self._get_bedrock_client().create_custom_model_deployment(**params)
187226

188-
def _wait_for_model_active(self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600):
227+
def _wait_for_model_active(
228+
self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600
229+
):
189230
"""Poll Bedrock until the custom model reaches Active status.
190231
191232
Args:
@@ -197,6 +238,7 @@ def _wait_for_model_active(self, model_arn: str, poll_interval: int = 60, max_wa
197238
RuntimeError: If the model status is Failed or the wait times out.
198239
"""
199240
elapsed = 0
241+
status = None
200242
while elapsed < max_wait:
201243
resp = self._get_bedrock_client().get_custom_model(modelIdentifier=model_arn)
202244
status = resp.get("modelStatus")
@@ -214,13 +256,12 @@ def _wait_for_model_active(self, model_arn: str, poll_interval: int = 60, max_wa
214256
f"Last status: {status}"
215257
)
216258

217-
218259
def _fetch_model_package(self) -> Optional[ModelPackage]:
219260
"""Fetch the ModelPackage from the provided model.
220-
261+
221262
Extracts ModelPackage from ModelTrainer, TrainingJob, or returns
222263
the ModelPackage directly if that's what was provided.
223-
264+
224265
Returns:
225266
ModelPackage instance or None if no model was provided.
226267
"""
@@ -229,98 +270,95 @@ def _fetch_model_package(self) -> Optional[ModelPackage]:
229270
if isinstance(self.model, TrainingJob):
230271
return ModelPackage.get(self.model.output_model_package_arn)
231272
if isinstance(self.model, ModelTrainer):
232-
return ModelPackage.get(self.model._latest_training_job.output_model_package_arn)
273+
return ModelPackage.get(
274+
self.model._latest_training_job.output_model_package_arn
275+
)
233276
return None
234277

235278
def _get_s3_artifacts(self) -> Optional[str]:
236279
"""Extract S3 URI of model artifacts from the model package.
237-
280+
238281
For Nova models, fetches checkpoint URI from manifest.json in training job output.
239282
For other models, returns the model data source S3 URI.
240-
283+
241284
Returns:
242285
S3 URI string of the model artifacts, or None if not available.
243286
"""
244287
if not self.model_package:
245288
return None
246-
289+
247290
container = self.model_package.inference_specification.containers[0]
248-
is_nova = (hasattr(container, 'base_model') and container.base_model and
249-
hasattr(container.base_model, 'recipe_name') and container.base_model.recipe_name and
250-
"nova" in container.base_model.recipe_name.lower()) or \
251-
(hasattr(container, 'base_model') and container.base_model and
252-
hasattr(container.base_model, 'hub_content_name') and container.base_model.hub_content_name and
253-
"nova" in container.base_model.hub_content_name.lower())
254-
291+
is_nova = _is_nova_model(container)
292+
255293
if is_nova and isinstance(self.model, TrainingJob):
256294
return self._get_checkpoint_uri_from_manifest()
257-
258-
if hasattr(container, 'model_data_source') and container.model_data_source:
259-
if hasattr(container.model_data_source, 's3_data_source') and container.model_data_source.s3_data_source:
260-
return container.model_data_source.s3_data_source.s3_uri
295+
296+
if hasattr(container, "model_data_source") and container.model_data_source:
297+
data_source = container.model_data_source
298+
if hasattr(data_source, "s3_data_source") and data_source.s3_data_source:
299+
return data_source.s3_data_source.s3_uri
261300
return None
262-
301+
263302
def _get_checkpoint_uri_from_manifest(self) -> Optional[str]:
264303
"""Get checkpoint URI from manifest.json for Nova models.
265-
304+
266305
Steps:
267306
1. Fetch S3 model artifacts from training job
268-
2. Go one level up in directory
269-
3. Find manifest.json
270-
4. Fetch checkpoint_s3_bucket from manifest
271-
307+
2. Construct path to manifest.json in the output directory
308+
3. Read and parse manifest.json
309+
4. Return checkpoint_s3_bucket value
310+
272311
Returns:
273312
Checkpoint URI from manifest.json.
274-
313+
275314
Raises:
276-
ValueError: If manifest.json cannot be found or parsed.
315+
ValueError: If manifest.json cannot be found or parsed, or if the
316+
model is not a TrainingJob instance.
277317
"""
278-
import json
279-
from urllib.parse import urlparse
280-
import logging
281-
282-
logger = logging.getLogger(__name__)
283-
284318
if not isinstance(self.model, TrainingJob):
285319
raise ValueError("Model must be a TrainingJob instance for Nova models")
286-
287-
# Step 1: Get S3 model artifacts from training job
320+
288321
s3_artifacts = self.model.model_artifacts.s3_model_artifacts
289322
if not s3_artifacts:
290323
raise ValueError("No S3 model artifacts found in training job")
291-
292-
logger.info(f"S3 artifacts path: {s3_artifacts}")
293-
294-
# Step 2: Construct manifest path (same directory as model artifacts)
324+
325+
logger.info("S3 artifacts path: %s", s3_artifacts)
326+
327+
# Construct manifest path
295328
# s3://bucket/path/output/model.tar.gz -> s3://bucket/path/output/output/manifest.json
296-
parts = s3_artifacts.rstrip('/').rsplit('/', 1)
297-
manifest_path = parts[0] + '/output/manifest.json'
298-
299-
logger.info(f"Manifest path: {manifest_path}")
300-
301-
# Step 3: Find and read manifest.json
329+
parts = s3_artifacts.rstrip("/").rsplit("/", 1)
330+
manifest_path = parts[0] + "/output/manifest.json"
331+
332+
logger.info("Manifest path: %s", manifest_path)
333+
302334
parsed = urlparse(manifest_path)
303335
bucket = parsed.netloc
304-
manifest_key = parsed.path.lstrip('/')
305-
306-
logger.info(f"Looking for manifest at s3://{bucket}/{manifest_key}")
307-
308-
s3_client = self.boto_session.client('s3')
336+
manifest_key = parsed.path.lstrip("/")
337+
338+
logger.info("Looking for manifest at s3://%s/%s", bucket, manifest_key)
339+
340+
s3_client = self.boto_session.client("s3")
309341
try:
310342
response = s3_client.get_object(Bucket=bucket, Key=manifest_key)
311-
manifest = json.loads(response['Body'].read().decode('utf-8'))
312-
logger.info(f"Manifest content: {manifest}")
313-
314-
# Step 4: Fetch checkpoint_s3_bucket from manifest
315-
checkpoint_uri = manifest.get('checkpoint_s3_bucket')
343+
manifest = json.loads(response["Body"].read().decode("utf-8"))
344+
logger.info("Manifest content: %s", manifest)
345+
346+
checkpoint_uri = manifest.get("checkpoint_s3_bucket")
316347
if not checkpoint_uri:
317-
raise ValueError(f"'checkpoint_s3_bucket' not found in manifest. Available keys: {list(manifest.keys())}")
318-
319-
logger.info(f"Checkpoint URI: {checkpoint_uri}")
348+
raise ValueError(
349+
"'checkpoint_s3_bucket' not found in manifest. "
350+
"Available keys: %s" % list(manifest.keys())
351+
)
352+
353+
logger.info("Checkpoint URI: %s", checkpoint_uri)
320354
return checkpoint_uri
321355
except s3_client.exceptions.NoSuchKey:
322-
raise ValueError(f"manifest.json not found at s3://{bucket}/{manifest_key}")
356+
raise ValueError(
357+
"manifest.json not found at s3://%s/%s" % (bucket, manifest_key)
358+
)
323359
except json.JSONDecodeError as e:
324-
raise ValueError(f"Failed to parse manifest.json: {e}")
360+
raise ValueError("Failed to parse manifest.json: %s" % e)
361+
except ValueError:
362+
raise
325363
except Exception as e:
326-
raise ValueError(f"Error reading manifest.json: {e}")
364+
raise ValueError("Error reading manifest.json: %s" % e)

0 commit comments

Comments
 (0)