Skip to content

Commit 97a4afe

Browse files
authored
feat: bedrock-oss-provisioned-throughput-polling (#5893)
* feat: add import job polling and provisioned throughput for Bedrock OSS deployments - deploy() for non-Nova models now waits for import job completion and returns job details (model ready for on-demand inference). - New public method: create_provisioned_throughput() with polling. - New private methods: _wait_for_import_job_complete(), _wait_for_provisioned_throughput_in_service(). - Added unit tests and integ tests (serial to avoid concurrent quota issues). - Mark bedrock integ tests as serial to avoid concurrent import job quota issues. X-AI-Prompt: add import polling and PT for bedrock OSS deployments X-AI-Tool: kiro-cli * imporve docstring * feat: make model_id optional in create_provisioned_throughput - Store imported model ID after deploy() completes - create_provisioned_throughput() now falls back to the stored model ID if model_id is not explicitly passed - Added unit tests for fallback and explicit override behavior * docs: update example notebooks for new OSS deploy polling behavior - bedrock-modelbuilder-deployment.ipynb: deploy() now waits for import completion, removed manual polling cell, added PT usage example - mtrl_finetuning_example_notebook_v3_prod.ipynb: removed manual polling loop, updated description to reflect automatic waiting, added optional create_provisioned_throughput() example
1 parent ec3da03 commit 97a4afe

7 files changed

Lines changed: 795 additions & 161 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ sagemaker_train/src/**/container_drivers/distributed.json
4444
docs/api/generated/
4545
.hypothesis
4646
.kiro
47+
bedrock_api_logs/

sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py

Lines changed: 169 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def __init__(
101101
self.model = model
102102
self._bedrock_client = None
103103
self._sagemaker_client = None
104+
self._imported_model_id = None
104105
self.boto_session = Session().boto_session
105106
self.model_package = self._fetch_model_package() if model else None
106107
self._is_rmp = is_restricted_model_package(self.model_package)
@@ -153,29 +154,32 @@ def deploy(
153154
"""Deploy the model to Bedrock.
154155
155156
Automatically detects if the model is a Nova model and uses the appropriate
156-
Bedrock API (create_custom_model for Nova, create_model_import_job for others).
157-
For Nova models, also creates a custom model deployment for inference.
157+
Bedrock API (create_custom_model for Nova, create_model_import_job for OSS).
158+
For Nova models, creates a custom model deployment and polls until active.
159+
For OSS models, creates a model import job and polls until complete. Once
160+
deploy() returns, the model is ready for on-demand inference. For provisioned
161+
throughput, use the separate create_provisioned_throughput() method.
158162
159163
Args:
160-
job_name: Name for the model import job (non-Nova models only).
161-
imported_model_name: Name for the imported model (non-Nova models only).
164+
job_name: Name for the model import job (OSS models only).
165+
imported_model_name: Name for the imported model (OSS models only).
162166
custom_model_name: Name for the custom model (Nova models only).
163167
role_arn: IAM role ARN with permissions for Bedrock operations.
164-
job_tags: Tags for the import job (non-Nova models only).
165-
imported_model_tags: Tags for the imported model (non-Nova models only).
168+
job_tags: Tags for the import job (OSS models only).
169+
imported_model_tags: Tags for the imported model (OSS models only).
166170
model_tags: Tags for the custom model (Nova models only).
167-
client_request_token: Unique token for idempotency (non-Nova models only).
168-
imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only).
171+
client_request_token: Unique token for idempotency (OSS models only).
172+
imported_model_kms_key_id: KMS key ID for encryption (OSS models only).
169173
deployment_name: Name for the deployment (Nova models only). If not provided,
170174
defaults to custom_model_name suffixed with '-deployment'.
171175
172176
Returns:
173-
Response from Bedrock API. For Nova models, returns the
174-
create_custom_model_deployment response. For others, returns
175-
the create_model_import_job response.
177+
For Nova models: the create_custom_model_deployment response.
178+
For OSS models: the completed get_model_import_job response.
176179
177180
Raises:
178181
ValueError: If model_package is not set or required parameters are missing.
182+
RuntimeError: If the import job or deployment fails or times out.
179183
"""
180184
if not self.model_package:
181185
raise ValueError(
@@ -238,15 +242,24 @@ def deploy(
238242
}
239243
params = {k: v for k, v in params.items() if v is not None}
240244

241-
logger.info("Creating model import job for non-Nova deployment")
245+
logger.info("Creating model import job for OSS model deployment")
242246
print(f"[BedrockModelBuilder] Resolved S3 artifacts path: {self.s3_model_artifacts}")
243247
print(f"[BedrockModelBuilder] create_model_import_job params: {params}")
244-
response = self._get_bedrock_client().create_model_import_job(**params)
248+
import_response = self._get_bedrock_client().create_model_import_job(**params)
245249
logger.warning(
246-
"Bedrock create_model_import_job request: %s, response: %s", params, response
250+
"Bedrock create_model_import_job request: %s, response: %s", params, import_response
247251
)
248-
_log_bedrock_api_call("create_model_import_job", params, response)
249-
return response
252+
_log_bedrock_api_call("create_model_import_job", params, import_response)
253+
254+
job_arn = import_response.get("jobArn")
255+
self._wait_for_import_job_complete(job_arn)
256+
257+
# Return the completed job details and store imported model ID
258+
job_details = self._get_bedrock_client().get_model_import_job(
259+
jobIdentifier=job_arn
260+
)
261+
self._imported_model_id = job_details.get("importedModelName")
262+
return job_details
250263

251264
def create_deployment(
252265
self,
@@ -303,6 +316,146 @@ def create_deployment(
303316

304317
return response
305318

319+
def create_provisioned_throughput(
320+
self,
321+
model_id: Optional[str] = None,
322+
provisioned_model_name: str = None,
323+
model_units: int = 1,
324+
commitment_duration: Optional[str] = None,
325+
tags: Optional[list] = None,
326+
poll_interval: int = 60,
327+
max_wait: int = 3600,
328+
) -> Dict[str, Any]:
329+
"""Create provisioned throughput for an imported model on Bedrock.
330+
331+
Calls CreateProvisionedModelThroughput and polls until the provisioned
332+
throughput reaches InService status.
333+
334+
Args:
335+
model_id: ARN or name of the model. If not provided, uses the model
336+
ID from the most recent deploy() call.
337+
provisioned_model_name: Name for the provisioned throughput resource.
338+
model_units: Number of model units to provision. Defaults to 1.
339+
commitment_duration: Commitment duration. Valid values: 'OneMonth',
340+
'SixMonths'. If not provided, no commitment is set (on-demand).
341+
tags: Tags for the provisioned throughput resource.
342+
poll_interval: Seconds between status checks. Defaults to 60.
343+
max_wait: Maximum seconds to wait. Defaults to 3600.
344+
345+
Returns:
346+
Response from Bedrock create_provisioned_model_throughput API.
347+
348+
Raises:
349+
RuntimeError: If the provisioned throughput fails or times out.
350+
ValueError: If model_id cannot be determined or provisioned_model_name
351+
is not provided.
352+
"""
353+
resolved_model_id = model_id or self._imported_model_id
354+
if not resolved_model_id:
355+
raise ValueError(
356+
"model_id is required for create_provisioned_throughput. "
357+
"Either pass it explicitly or call deploy() first."
358+
)
359+
if not provisioned_model_name:
360+
raise ValueError(
361+
"provisioned_model_name is required for create_provisioned_throughput."
362+
)
363+
364+
params = {
365+
"modelId": resolved_model_id,
366+
"provisionedModelName": provisioned_model_name,
367+
"modelUnits": model_units,
368+
}
369+
if commitment_duration:
370+
params["commitmentDuration"] = commitment_duration
371+
if tags:
372+
params["tags"] = tags
373+
374+
logger.info(
375+
"Creating provisioned throughput '%s' for model %s with %d model units",
376+
provisioned_model_name,
377+
resolved_model_id,
378+
model_units,
379+
)
380+
response = self._get_bedrock_client().create_provisioned_model_throughput(**params)
381+
382+
provisioned_model_arn = response.get("provisionedModelArn")
383+
if provisioned_model_arn:
384+
self._wait_for_provisioned_throughput_in_service(
385+
provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait
386+
)
387+
388+
return response
389+
390+
def _wait_for_import_job_complete(
391+
self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600
392+
):
393+
"""Poll Bedrock until the model import job reaches Completed status.
394+
395+
Args:
396+
job_arn: ARN of the model import job.
397+
poll_interval: Seconds between status checks. Defaults to 60.
398+
max_wait: Maximum seconds to wait. Defaults to 3600.
399+
400+
Raises:
401+
RuntimeError: If the import job fails or times out.
402+
"""
403+
elapsed = 0
404+
status = None
405+
while elapsed < max_wait:
406+
resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn)
407+
status = resp.get("status")
408+
logger.info("Import job status: %s (elapsed %ds)", status, elapsed)
409+
if status == "Completed":
410+
return
411+
if status == "Failed":
412+
failure_reason = resp.get("failureMessage", "Unknown")
413+
raise RuntimeError(
414+
f"Model import job {job_arn} failed. Reason: {failure_reason}"
415+
)
416+
time.sleep(poll_interval)
417+
elapsed += poll_interval
418+
raise RuntimeError(
419+
f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. "
420+
f"Last status: {status}"
421+
)
422+
423+
def _wait_for_provisioned_throughput_in_service(
424+
self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600
425+
):
426+
"""Poll Bedrock until provisioned throughput reaches InService status.
427+
428+
Args:
429+
provisioned_model_arn: ARN of the provisioned model throughput.
430+
poll_interval: Seconds between status checks. Defaults to 60.
431+
max_wait: Maximum seconds to wait. Defaults to 3600.
432+
433+
Raises:
434+
RuntimeError: If the provisioned throughput fails or times out.
435+
"""
436+
elapsed = 0
437+
status = None
438+
while elapsed < max_wait:
439+
resp = self._get_bedrock_client().get_provisioned_model_throughput(
440+
provisionedModelId=provisioned_model_arn
441+
)
442+
status = resp.get("status")
443+
logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed)
444+
if status == "InService":
445+
return
446+
if status == "Failed":
447+
failure_reason = resp.get("failureMessage", "Unknown")
448+
raise RuntimeError(
449+
f"Provisioned throughput {provisioned_model_arn} failed. "
450+
f"Reason: {failure_reason}"
451+
)
452+
time.sleep(poll_interval)
453+
elapsed += poll_interval
454+
raise RuntimeError(
455+
f"Timed out after {max_wait}s waiting for provisioned throughput "
456+
f"{provisioned_model_arn} to become InService. Last status: {status}"
457+
)
458+
306459
def _wait_for_model_active(
307460
self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600
308461
):

0 commit comments

Comments
 (0)