@@ -101,6 +101,7 @@ def __init__(
101101 self .model = model
102102 self ._bedrock_client = None
103103 self ._sagemaker_client = None
104+ self ._imported_model_id = None
104105 self .boto_session = Session ().boto_session
105106 self .model_package = self ._fetch_model_package () if model else None
106107 self ._is_rmp = is_restricted_model_package (self .model_package )
@@ -153,29 +154,32 @@ def deploy(
153154 """Deploy the model to Bedrock.
154155
155156 Automatically detects if the model is a Nova model and uses the appropriate
156- Bedrock API (create_custom_model for Nova, create_model_import_job for others).
157- For Nova models, also creates a custom model deployment for inference.
157+ Bedrock API (create_custom_model for Nova, create_model_import_job for OSS).
158+ For Nova models, creates a custom model deployment and polls until active.
159+ For OSS models, creates a model import job and polls until complete. Once
160+ deploy() returns, the model is ready for on-demand inference. For provisioned
161+ throughput, use the separate create_provisioned_throughput() method.
158162
159163 Args:
160- job_name: Name for the model import job (non-Nova models only).
161- imported_model_name: Name for the imported model (non-Nova models only).
164+ job_name: Name for the model import job (OSS models only).
165+ imported_model_name: Name for the imported model (OSS models only).
162166 custom_model_name: Name for the custom model (Nova models only).
163167 role_arn: IAM role ARN with permissions for Bedrock operations.
164- job_tags: Tags for the import job (non-Nova models only).
165- imported_model_tags: Tags for the imported model (non-Nova models only).
168+ job_tags: Tags for the import job (OSS models only).
169+ imported_model_tags: Tags for the imported model (OSS models only).
166170 model_tags: Tags for the custom model (Nova models only).
167- client_request_token: Unique token for idempotency (non-Nova models only).
168- imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only).
171+ client_request_token: Unique token for idempotency (OSS models only).
172+ imported_model_kms_key_id: KMS key ID for encryption (OSS models only).
169173 deployment_name: Name for the deployment (Nova models only). If not provided,
170174 defaults to custom_model_name suffixed with '-deployment'.
171175
172176 Returns:
173- Response from Bedrock API. For Nova models, returns the
174- create_custom_model_deployment response. For others, returns
175- the create_model_import_job response.
177+ For Nova models: the create_custom_model_deployment response.
178+ For OSS models: the completed get_model_import_job response.
176179
177180 Raises:
178181 ValueError: If model_package is not set or required parameters are missing.
182+ RuntimeError: If the import job or deployment fails or times out.
179183 """
180184 if not self .model_package :
181185 raise ValueError (
@@ -238,15 +242,24 @@ def deploy(
238242 }
239243 params = {k : v for k , v in params .items () if v is not None }
240244
241- logger .info ("Creating model import job for non-Nova deployment" )
245+ logger .info ("Creating model import job for OSS model deployment" )
242246 print (f"[BedrockModelBuilder] Resolved S3 artifacts path: { self .s3_model_artifacts } " )
243247 print (f"[BedrockModelBuilder] create_model_import_job params: { params } " )
244- response = self ._get_bedrock_client ().create_model_import_job (** params )
248+ import_response = self ._get_bedrock_client ().create_model_import_job (** params )
245249 logger .warning (
246- "Bedrock create_model_import_job request: %s, response: %s" , params , response
250+ "Bedrock create_model_import_job request: %s, response: %s" , params , import_response
247251 )
248- _log_bedrock_api_call ("create_model_import_job" , params , response )
249- return response
252+ _log_bedrock_api_call ("create_model_import_job" , params , import_response )
253+
254+ job_arn = import_response .get ("jobArn" )
255+ self ._wait_for_import_job_complete (job_arn )
256+
257+ # Return the completed job details and store imported model ID
258+ job_details = self ._get_bedrock_client ().get_model_import_job (
259+ jobIdentifier = job_arn
260+ )
261+ self ._imported_model_id = job_details .get ("importedModelName" )
262+ return job_details
250263
251264 def create_deployment (
252265 self ,
@@ -303,6 +316,146 @@ def create_deployment(
303316
304317 return response
305318
def create_provisioned_throughput(
    self,
    model_id: Optional[str] = None,
    provisioned_model_name: Optional[str] = None,
    model_units: int = 1,
    commitment_duration: Optional[str] = None,
    tags: Optional[list] = None,
    poll_interval: int = 60,
    max_wait: int = 3600,
) -> Dict[str, Any]:
    """Create provisioned throughput for a model on Bedrock and wait for it.

    Calls ``create_provisioned_model_throughput`` and then polls (via
    ``_wait_for_provisioned_throughput_in_service``) until the resource
    reports ``InService``.

    All arguments are validated *before* the Bedrock call so that a bad
    polling configuration cannot leave a freshly created resource behind
    with nothing waiting on it.

    Args:
        model_id: ARN or name of the model. If not provided, falls back to
            the value stored in ``self._imported_model_id`` by a previous
            deploy() call.
        provisioned_model_name: Name for the provisioned throughput
            resource. Required.
        model_units: Number of model units to provision. Must be a positive
            integer. Defaults to 1.
        commitment_duration: Commitment term, e.g. 'OneMonth' or
            'SixMonths'. If omitted, the field is not sent to Bedrock,
            which requests a no-commitment term (this is still provisioned
            capacity, not on-demand inference).
        tags: Tags for the provisioned throughput resource.
        poll_interval: Seconds between status checks. Must be > 0.
            Defaults to 60.
        max_wait: Maximum seconds to wait. Must be >= 0. Defaults to 3600.

    Returns:
        Response from the Bedrock create_provisioned_model_throughput API.

    Raises:
        ValueError: If model_id cannot be determined, provisioned_model_name
            is not provided, or model_units / poll_interval / max_wait are
            out of range.
        RuntimeError: If the provisioned throughput fails or times out.
    """
    resolved_model_id = model_id or self._imported_model_id
    if not resolved_model_id:
        raise ValueError(
            "model_id is required for create_provisioned_throughput. "
            "Either pass it explicitly or call deploy() first."
        )
    if not provisioned_model_name:
        raise ValueError(
            "provisioned_model_name is required for create_provisioned_throughput."
        )
    # bool is a subclass of int; reject it explicitly so True is not read as 1.
    if (
        isinstance(model_units, bool)
        or not isinstance(model_units, int)
        or model_units < 1
    ):
        raise ValueError(
            f"model_units must be a positive integer, got {model_units!r}."
        )
    # A zero interval would never advance the poller's elapsed counter.
    if poll_interval <= 0:
        raise ValueError(
            f"poll_interval must be greater than 0 seconds, got {poll_interval!r}."
        )
    if max_wait < 0:
        raise ValueError(f"max_wait must not be negative, got {max_wait!r}.")

    params = {
        "modelId": resolved_model_id,
        "provisionedModelName": provisioned_model_name,
        "modelUnits": model_units,
    }
    # Optional fields are only sent when the caller supplied them.
    if commitment_duration:
        params["commitmentDuration"] = commitment_duration
    if tags:
        params["tags"] = tags

    logger.info(
        "Creating provisioned throughput '%s' for model %s with %d model units",
        provisioned_model_name,
        resolved_model_id,
        model_units,
    )
    response = self._get_bedrock_client().create_provisioned_model_throughput(**params)

    provisioned_model_arn = response.get("provisionedModelArn")
    if provisioned_model_arn:
        self._wait_for_provisioned_throughput_in_service(
            provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait
        )
    else:
        # Without an ARN there is nothing to poll; surface that instead of
        # returning as if the resource were known to be InService.
        logger.warning(
            "create_provisioned_model_throughput response for '%s' did not include "
            "provisionedModelArn; skipping wait for InService status",
            provisioned_model_name,
        )

    return response
389+
def _wait_for_import_job_complete(
    self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
    """Poll Bedrock until the model import job reaches Completed status.

    The status is checked immediately, then after every sleep, including
    one final check once ``max_wait`` seconds of sleeping have elapsed.
    The last sleep is shortened so the total sleep time never exceeds
    ``max_wait``.

    Args:
        job_arn: ARN of the model import job.
        poll_interval: Seconds between status checks. Must be > 0.
            Defaults to 60.
        max_wait: Maximum seconds to wait. Defaults to 3600.

    Raises:
        ValueError: If poll_interval is not positive.
        RuntimeError: If the import job fails or times out.
    """
    # A zero interval would never advance `elapsed`, looping forever.
    if poll_interval <= 0:
        raise ValueError(
            f"poll_interval must be greater than 0 seconds, got {poll_interval!r}."
        )

    elapsed = 0
    while True:
        resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn)
        status = resp.get("status")
        logger.info("Import job status: %s (elapsed %ds)", status, elapsed)
        if status == "Completed":
            return
        if status == "Failed":
            failure_reason = resp.get("failureMessage", "Unknown")
            raise RuntimeError(
                f"Model import job {job_arn} failed. Reason: {failure_reason}"
            )

        remaining = max_wait - elapsed
        if remaining <= 0:
            # Budget exhausted and the check above was the final one.
            break
        # Never sleep past the deadline; the next iteration re-checks status.
        pause = min(poll_interval, remaining)
        time.sleep(pause)
        elapsed += pause

    raise RuntimeError(
        f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. "
        f"Last status: {status}"
    )
422+
def _wait_for_provisioned_throughput_in_service(
    self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
    """Poll Bedrock until provisioned throughput reaches InService status.

    The status is checked immediately, then after every sleep, including
    one final check once ``max_wait`` seconds of sleeping have elapsed.
    The last sleep is shortened so the total sleep time never exceeds
    ``max_wait``.

    Args:
        provisioned_model_arn: ARN of the provisioned model throughput.
        poll_interval: Seconds between status checks. Must be > 0.
            Defaults to 60.
        max_wait: Maximum seconds to wait. Defaults to 3600.

    Raises:
        ValueError: If poll_interval is not positive.
        RuntimeError: If the provisioned throughput fails or times out.
    """
    # A zero interval would never advance `elapsed`, looping forever.
    if poll_interval <= 0:
        raise ValueError(
            f"poll_interval must be greater than 0 seconds, got {poll_interval!r}."
        )

    elapsed = 0
    while True:
        resp = self._get_bedrock_client().get_provisioned_model_throughput(
            provisionedModelId=provisioned_model_arn
        )
        status = resp.get("status")
        logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed)
        if status == "InService":
            return
        if status == "Failed":
            failure_reason = resp.get("failureMessage", "Unknown")
            raise RuntimeError(
                f"Provisioned throughput {provisioned_model_arn} failed. "
                f"Reason: {failure_reason}"
            )

        remaining = max_wait - elapsed
        if remaining <= 0:
            # Budget exhausted and the check above was the final one.
            break
        # Never sleep past the deadline; the next iteration re-checks status.
        pause = min(poll_interval, remaining)
        time.sleep(pause)
        elapsed += pause

    raise RuntimeError(
        f"Timed out after {max_wait}s waiting for provisioned throughput "
        f"{provisioned_model_arn} to become InService. Last status: {status}"
    )
458+
306459 def _wait_for_model_active (
307460 self , model_arn : str , poll_interval : int = 60 , max_wait : int = 3600
308461 ):
0 commit comments