1313"""Holds the BedrockModelBuilder class."""
1414from __future__ import absolute_import
1515
16+ import json
1617import time
1718import logging
1819from typing import Optional , Dict , Any , Union
20+ from urllib .parse import urlparse
1921
2022from sagemaker .core .helper .session_helper import Session
2123from sagemaker .core .resources import TrainingJob , ModelPackage
2729logger = logging .getLogger (__name__ )
2830
2931
def _is_nova_model(container) -> bool:
    """Determine whether a model package container represents a Nova model.

    A container counts as Nova when the substring "nova" appears,
    case-insensitively, in either ``base_model.recipe_name`` or
    ``base_model.hub_content_name``.

    Args:
        container: A container from ModelPackage.inference_specification.containers.

    Returns:
        True if the container represents a Nova model, False otherwise
        (including when the container has no ``base_model``).
    """
    base_model = getattr(container, "base_model", None)
    if not base_model:
        return False

    for attr_name in ("recipe_name", "hub_content_name"):
        value = getattr(base_model, attr_name, None)
        # Only real strings are inspected: None, missing attributes and any
        # non-string placeholder value are treated as "no match" instead of
        # failing on ``.lower()``.
        if isinstance(value, str) and "nova" in value.lower():
            return True
    return False
51+
52+
3053class BedrockModelBuilder :
3154 """Builder class for deploying models to Amazon Bedrock.
32-
55+
3356 This class provides functionality to deploy SageMaker models to Bedrock
3457 using either model import jobs or custom model creation, depending on
3558 the model type (Nova models vs. other models).
@@ -42,7 +65,8 @@ def __init__(self, model: Optional[Union[ModelTrainer, TrainingJob, ModelPackage
4265 """Initialize BedrockModelBuilder with a model instance.
4366
4467 Args:
45- model: The model to deploy. Can be a ModelTrainer, TrainingJob, or ModelPackage instance.
68+ model: The model to deploy. Can be a ModelTrainer, TrainingJob,
69+ or ModelPackage instance.
4670 """
4771 self .model = model
4872 self ._bedrock_client = None
@@ -53,7 +77,7 @@ def __init__(self, model: Optional[Union[ModelTrainer, TrainingJob, ModelPackage
5377
5478 def _get_bedrock_client (self ):
5579 """Get or create Bedrock client singleton.
56-
80+
5781 Returns:
5882 boto3.client: Bedrock client instance.
5983 """
@@ -63,7 +87,7 @@ def _get_bedrock_client(self):
6387
6488 def _get_sagemaker_client (self ):
6589 """Get or create SageMaker client singleton.
66-
90+
6791 Returns:
6892 boto3.client: SageMaker client instance.
6993 """
@@ -73,20 +97,20 @@ def _get_sagemaker_client(self):
7397
7498 @_telemetry_emitter (feature = Feature .MODEL_CUSTOMIZATION , func_name = "BedrockModelBuilder.deploy" )
7599 def deploy (
76- self ,
77- job_name : Optional [str ] = None ,
78- imported_model_name : Optional [str ] = None ,
79- custom_model_name : Optional [str ] = None ,
80- role_arn : Optional [str ] = None ,
81- job_tags : Optional [list ] = None ,
82- imported_model_tags : Optional [list ] = None ,
83- model_tags : Optional [list ] = None ,
84- client_request_token : Optional [str ] = None ,
85- imported_model_kms_key_id : Optional [str ] = None ,
86- deployment_name : Optional [str ] = None ,
100+ self ,
101+ job_name : Optional [str ] = None ,
102+ imported_model_name : Optional [str ] = None ,
103+ custom_model_name : Optional [str ] = None ,
104+ role_arn : Optional [str ] = None ,
105+ job_tags : Optional [list ] = None ,
106+ imported_model_tags : Optional [list ] = None ,
107+ model_tags : Optional [list ] = None ,
108+ client_request_token : Optional [str ] = None ,
109+ imported_model_kms_key_id : Optional [str ] = None ,
110+ deployment_name : Optional [str ] = None ,
87111 ) -> Dict [str , Any ]:
88112 """Deploy the model to Bedrock.
89-
113+
90114 Automatically detects if the model is a Nova model and uses the appropriate
91115 Bedrock API (create_custom_model for Nova, create_model_import_job for others).
92116 For Nova models, also creates a custom model deployment for inference.
@@ -108,19 +132,24 @@ def deploy(
108132 Response from Bedrock API. For Nova models, returns the
109133 create_custom_model_deployment response. For others, returns
110134 the create_model_import_job response.
111-
135+
112136 Raises:
113- ValueError: If required parameters are missing for the detected model type .
137+ ValueError: If model_package is not set or required parameters are missing .
114138 """
139+ if not self .model_package :
140+ raise ValueError (
141+ "model_package is not set. Provide a valid model during initialization."
142+ )
143+
115144 container = self .model_package .inference_specification .containers [0 ]
116- is_nova = (hasattr (container , 'base_model' ) and container .base_model and
117- hasattr (container .base_model , 'recipe_name' ) and container .base_model .recipe_name and
118- "nova" in container .base_model .recipe_name .lower ()) or \
119- (hasattr (container , 'base_model' ) and container .base_model and
120- hasattr (container .base_model , 'hub_content_name' ) and container .base_model .hub_content_name and
121- "nova" in container .base_model .hub_content_name .lower ())
145+ is_nova = _is_nova_model (container )
122146
123147 if is_nova :
148+ if not custom_model_name :
149+ raise ValueError ("custom_model_name is required for Nova model deployment." )
150+ if not role_arn :
151+ raise ValueError ("role_arn is required for Nova model deployment." )
152+
124153 params = {
125154 "modelName" : custom_model_name ,
126155 "modelSourceConfig" : {"s3DataSource" : {"s3Uri" : self .s3_model_artifacts }},
@@ -129,6 +158,8 @@ def deploy(
129158 if model_tags :
130159 params ["modelTags" ] = model_tags
131160 params = {k : v for k , v in params .items () if v is not None }
161+
162+ logger .info ("Creating custom model %s for Nova deployment" , custom_model_name )
132163 create_response = self ._get_bedrock_client ().create_custom_model (** params )
133164
134165 model_arn = create_response .get ("modelArn" )
@@ -147,6 +178,8 @@ def deploy(
147178 "importedModelKmsKeyId" : imported_model_kms_key_id ,
148179 }
149180 params = {k : v for k , v in params .items () if v is not None }
181+
182+ logger .info ("Creating model import job for non-Nova deployment" )
150183 return self ._get_bedrock_client ().create_model_import_job (** params )
151184
152185 def create_deployment (
@@ -174,7 +207,11 @@ def create_deployment(
174207
175208 Raises:
176209 RuntimeError: If the model fails or times out waiting to become Active.
210+ ValueError: If model_arn is not provided.
177211 """
212+ if not model_arn :
213+ raise ValueError ("model_arn is required for create_deployment." )
214+
178215 self ._wait_for_model_active (model_arn , poll_interval = poll_interval , max_wait = max_wait )
179216
180217 params = {
@@ -183,9 +220,13 @@ def create_deployment(
183220 ** {k : v for k , v in kwargs .items () if v is not None },
184221 }
185222 params = {k : v for k , v in params .items () if v is not None }
223+
224+ logger .info ("Creating deployment %s for model %s" , deployment_name , model_arn )
186225 return self ._get_bedrock_client ().create_custom_model_deployment (** params )
187226
188- def _wait_for_model_active (self , model_arn : str , poll_interval : int = 60 , max_wait : int = 3600 ):
227+ def _wait_for_model_active (
228+ self , model_arn : str , poll_interval : int = 60 , max_wait : int = 3600
229+ ):
189230 """Poll Bedrock until the custom model reaches Active status.
190231
191232 Args:
@@ -197,6 +238,7 @@ def _wait_for_model_active(self, model_arn: str, poll_interval: int = 60, max_wa
197238 RuntimeError: If the model status is Failed or the wait times out.
198239 """
199240 elapsed = 0
241+ status = None
200242 while elapsed < max_wait :
201243 resp = self ._get_bedrock_client ().get_custom_model (modelIdentifier = model_arn )
202244 status = resp .get ("modelStatus" )
@@ -214,13 +256,12 @@ def _wait_for_model_active(self, model_arn: str, poll_interval: int = 60, max_wa
214256 f"Last status: { status } "
215257 )
216258
217-
218259 def _fetch_model_package (self ) -> Optional [ModelPackage ]:
219260 """Fetch the ModelPackage from the provided model.
220-
261+
221262 Extracts ModelPackage from ModelTrainer, TrainingJob, or returns
222263 the ModelPackage directly if that's what was provided.
223-
264+
224265 Returns:
225266 ModelPackage instance or None if no model was provided.
226267 """
@@ -229,98 +270,95 @@ def _fetch_model_package(self) -> Optional[ModelPackage]:
229270 if isinstance (self .model , TrainingJob ):
230271 return ModelPackage .get (self .model .output_model_package_arn )
231272 if isinstance (self .model , ModelTrainer ):
232- return ModelPackage .get (self .model ._latest_training_job .output_model_package_arn )
273+ return ModelPackage .get (
274+ self .model ._latest_training_job .output_model_package_arn
275+ )
233276 return None
234277
def _get_s3_artifacts(self) -> Optional[str]:
    """Extract the S3 URI of the model artifacts from the model package.

    For a Nova model whose source model is a TrainingJob, the checkpoint URI
    is read from manifest.json in the training job output. In every other
    case the S3 URI of the first container's model data source is returned.

    Returns:
        S3 URI string of the model artifacts, or None if not available
        (no model package, no inference specification, no containers, or
        no S3 data source on the first container).
    """
    if not self.model_package:
        return None

    # A package without an inference specification or with an empty
    # container list has no artifacts to report; return None rather than
    # failing on ``containers[0]``.
    inference_spec = getattr(self.model_package, "inference_specification", None)
    containers = getattr(inference_spec, "containers", None)
    if not containers:
        return None

    container = containers[0]

    if _is_nova_model(container) and isinstance(self.model, TrainingJob):
        return self._get_checkpoint_uri_from_manifest()

    data_source = getattr(container, "model_data_source", None)
    s3_data_source = getattr(data_source, "s3_data_source", None) if data_source else None
    if not s3_data_source:
        return None
    return s3_data_source.s3_uri
262-
301+
def _get_checkpoint_uri_from_manifest(self) -> Optional[str]:
    """Get the checkpoint URI from manifest.json for Nova models.

    Steps:
        1. Read the S3 model artifacts path from the training job.
        2. Replace its final path component with ``output/manifest.json``.
        3. Download and parse manifest.json.
        4. Return the ``checkpoint_s3_bucket`` value.

    Returns:
        Checkpoint URI from manifest.json.

    Raises:
        ValueError: If the model is not a TrainingJob instance, the training
            job has no S3 model artifacts, manifest.json cannot be found,
            read or parsed, or it has no ``checkpoint_s3_bucket`` entry.
    """
    if not isinstance(self.model, TrainingJob):
        raise ValueError("Model must be a TrainingJob instance for Nova models")

    # Tolerate a training job whose model_artifacts is missing/None so the
    # documented ValueError is raised instead of an AttributeError.
    model_artifacts = getattr(self.model, "model_artifacts", None)
    s3_artifacts = getattr(model_artifacts, "s3_model_artifacts", None)
    if not s3_artifacts:
        raise ValueError("No S3 model artifacts found in training job")

    logger.info("S3 artifacts path: %s", s3_artifacts)

    # s3://bucket/path/output/model.tar.gz -> s3://bucket/path/output/output/manifest.json
    parent_prefix = s3_artifacts.rstrip("/").rsplit("/", 1)[0]
    manifest_path = parent_prefix + "/output/manifest.json"
    logger.info("Manifest path: %s", manifest_path)

    parsed = urlparse(manifest_path)
    bucket = parsed.netloc
    manifest_key = parsed.path.lstrip("/")
    logger.info("Looking for manifest at s3://%s/%s", bucket, manifest_key)

    s3_client = self.boto_session.client("s3")

    # Only the download and the JSON decoding live inside the try block, so
    # validation errors raised further down are never re-wrapped.
    try:
        response = s3_client.get_object(Bucket=bucket, Key=manifest_key)
        manifest = json.loads(response["Body"].read().decode("utf-8"))
    except s3_client.exceptions.NoSuchKey as e:
        raise ValueError(
            "manifest.json not found at s3://%s/%s" % (bucket, manifest_key)
        ) from e
    except json.JSONDecodeError as e:
        raise ValueError("Failed to parse manifest.json: %s" % (e,)) from e
    except Exception as e:
        # Any other failure (access denied, network, bad encoding, ...) is
        # reported as a ValueError with the original exception chained.
        raise ValueError("Error reading manifest.json: %s" % (e,)) from e

    if not isinstance(manifest, dict):
        raise ValueError(
            "Error reading manifest.json: expected a JSON object, got %s"
            % type(manifest).__name__
        )

    logger.info("Manifest content: %s", manifest)

    checkpoint_uri = manifest.get("checkpoint_s3_bucket")
    if not checkpoint_uri:
        raise ValueError(
            "'checkpoint_s3_bucket' not found in manifest. "
            "Available keys: %s" % (list(manifest),)
        )

    logger.info("Checkpoint URI: %s", checkpoint_uri)
    return checkpoint_uri
0 commit comments