Skip to content

Commit f593ee3

Browse files
committed
feat: add Bedrock integration (BIP-0042) as separate PR
1 parent 5ca9739 commit f593ee3

5 files changed

Lines changed: 500 additions & 41 deletions

File tree

burr/integrations/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,14 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
18+
19+
def __getattr__(name: str):
20+
"""Lazy load Bedrock integration to avoid requiring boto3 unless used."""
21+
if name == "BedrockAction":
22+
from burr.integrations.bedrock import BedrockAction
23+
return BedrockAction
24+
if name == "BedrockStreamingAction":
25+
from burr.integrations.bedrock import BedrockStreamingAction
26+
return BedrockStreamingAction
27+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

burr/integrations/bedrock.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Amazon Bedrock integration for Burr.
19+
20+
This module provides Action classes for invoking Amazon Bedrock models
21+
within Burr applications.
22+
23+
Example usage:
24+
from burr.integrations.bedrock import BedrockAction
25+
26+
def prompt_mapper(state):
27+
return {
28+
"messages": [{"role": "user", "content": state["user_input"]}],
29+
"system": [{"text": "You are a helpful assistant."}],
30+
}
31+
32+
# With default client (created lazily on first use):
33+
action = BedrockAction(
34+
model_id="anthropic.claude-3-sonnet-20240229-v1:0",
35+
input_mapper=prompt_mapper,
36+
reads=["user_input"],
37+
writes=["response"],
38+
)
39+
40+
# With injected client (for tests or distributed execution):
41+
# client = boto3.client("bedrock-runtime", region_name="us-east-1")
42+
# action = BedrockAction(..., client=client)
43+
"""
44+
45+
import logging
46+
from typing import Any, Generator, Optional, Protocol
47+
48+
from burr.core.action import SingleStepAction, StreamingAction
49+
from burr.core.state import State
50+
from burr.integrations.base import require_plugin
51+
52+
logger = logging.getLogger(__name__)
53+
54+
# Type for injected Bedrock client (avoids boto3 import at type-check time)
55+
BedrockClient = Any
56+
57+
try:
58+
import boto3
59+
from botocore.config import Config
60+
from botocore.exceptions import ClientError
61+
except ImportError as e:
62+
require_plugin(e, "bedrock")
63+
64+
65+
class StateToPromptMapper(Protocol):
66+
"""Protocol for mapping Burr state to Bedrock prompt format."""
67+
68+
def __call__(self, state: State) -> dict[str, Any]:
69+
...
70+
71+
72+
class BedrockAction(SingleStepAction):
73+
"""Action that invokes Amazon Bedrock models using the Converse API."""
74+
75+
def __init__(
76+
self,
77+
model_id: str,
78+
input_mapper: StateToPromptMapper,
79+
reads: list[str],
80+
writes: list[str],
81+
name: str = "bedrock_invoke",
82+
region: Optional[str] = None,
83+
guardrail_id: Optional[str] = None,
84+
guardrail_version: Optional[str] = None,
85+
inference_config: Optional[dict[str, Any]] = None,
86+
max_retries: int = 3,
87+
client: Optional[BedrockClient] = None,
88+
):
89+
super().__init__()
90+
self._model_id = model_id
91+
self._input_mapper = input_mapper
92+
self._reads = reads
93+
self._writes = writes
94+
self._name = name
95+
self._region = region
96+
self._guardrail_id = guardrail_id
97+
self._guardrail_version = guardrail_version or "DRAFT"
98+
self._inference_config = inference_config or {"maxTokens": 4096}
99+
self._max_retries = max_retries
100+
self._client = client
101+
102+
def _get_client(self) -> BedrockClient:
103+
"""Return the Bedrock client, creating lazily if not injected."""
104+
if self._client is not None:
105+
return self._client
106+
config = Config(
107+
retries={"max_attempts": self._max_retries, "mode": "adaptive"}
108+
)
109+
self._client = boto3.client(
110+
"bedrock-runtime", region_name=self._region, config=config
111+
)
112+
return self._client
113+
114+
@property
115+
def reads(self) -> list[str]:
116+
return self._reads
117+
118+
@property
119+
def writes(self) -> list[str]:
120+
return self._writes
121+
122+
@property
123+
def name(self) -> str:
124+
return self._name
125+
126+
def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
127+
prompt = self._input_mapper(state)
128+
129+
request: dict[str, Any] = {
130+
"modelId": self._model_id,
131+
"messages": prompt["messages"],
132+
"inferenceConfig": self._inference_config,
133+
}
134+
135+
if "system" in prompt:
136+
request["system"] = prompt["system"]
137+
138+
if self._guardrail_id:
139+
request["guardrailConfig"] = {
140+
"guardrailIdentifier": self._guardrail_id,
141+
"guardrailVersion": self._guardrail_version,
142+
}
143+
144+
try:
145+
response = self._get_client().converse(**request)
146+
except ClientError as e:
147+
logger.error("Bedrock API error: %s", e)
148+
raise
149+
150+
output_message = response["output"]["message"]
151+
content_blocks = output_message.get("content", [])
152+
text = content_blocks[0]["text"] if content_blocks else ""
153+
154+
result: dict[str, Any] = {
155+
"response": text,
156+
"usage": response.get("usage", {}),
157+
"stop_reason": response.get("stopReason"),
158+
}
159+
160+
updates = {key: result[key] for key in self._writes if key in result}
161+
new_state = state.update(**updates)
162+
163+
return result, new_state
164+
165+
166+
class BedrockStreamingAction(StreamingAction):
167+
"""Streaming variant of BedrockAction using Converse Stream API."""
168+
169+
def __init__(
170+
self,
171+
model_id: str,
172+
input_mapper: StateToPromptMapper,
173+
reads: list[str],
174+
writes: list[str],
175+
name: str = "bedrock_stream",
176+
region: Optional[str] = None,
177+
guardrail_id: Optional[str] = None,
178+
guardrail_version: Optional[str] = None,
179+
inference_config: Optional[dict[str, Any]] = None,
180+
max_retries: int = 3,
181+
client: Optional[BedrockClient] = None,
182+
):
183+
super().__init__()
184+
self._model_id = model_id
185+
self._input_mapper = input_mapper
186+
self._reads = reads
187+
self._writes = writes
188+
self._name = name
189+
self._region = region
190+
self._guardrail_id = guardrail_id
191+
self._guardrail_version = guardrail_version or "DRAFT"
192+
self._inference_config = inference_config or {"maxTokens": 4096}
193+
self._max_retries = max_retries
194+
self._client = client
195+
196+
def _get_client(self) -> BedrockClient:
197+
"""Return the Bedrock client, creating lazily if not injected."""
198+
if self._client is not None:
199+
return self._client
200+
config = Config(
201+
retries={"max_attempts": self._max_retries, "mode": "adaptive"}
202+
)
203+
self._client = boto3.client(
204+
"bedrock-runtime", region_name=self._region, config=config
205+
)
206+
return self._client
207+
208+
@property
209+
def reads(self) -> list[str]:
210+
return self._reads
211+
212+
@property
213+
def writes(self) -> list[str]:
214+
return self._writes
215+
216+
@property
217+
def name(self) -> str:
218+
return self._name
219+
220+
def stream_run(
221+
self, state: State, **run_kwargs
222+
) -> Generator[dict, None, None]:
223+
prompt = self._input_mapper(state)
224+
225+
request: dict[str, Any] = {
226+
"modelId": self._model_id,
227+
"messages": prompt["messages"],
228+
"inferenceConfig": self._inference_config,
229+
}
230+
231+
if "system" in prompt:
232+
request["system"] = prompt["system"]
233+
234+
if self._guardrail_id:
235+
request["guardrailConfig"] = {
236+
"guardrailIdentifier": self._guardrail_id,
237+
"guardrailVersion": self._guardrail_version,
238+
}
239+
240+
try:
241+
response = self._get_client().converse_stream(**request)
242+
except ClientError as e:
243+
logger.error("Bedrock streaming API error: %s", e)
244+
raise
245+
246+
full_response = ""
247+
stream = response.get("stream", [])
248+
for event in stream:
249+
if "contentBlockDelta" in event:
250+
chunk = event["contentBlockDelta"]["delta"].get("text", "")
251+
full_response += chunk
252+
yield {"chunk": chunk, "response": full_response}
253+
254+
yield {"chunk": "", "response": full_response, "complete": True}
255+
256+
def update(self, result: dict, state: State) -> State:
257+
if result.get("complete"):
258+
updates = {"response": result.get("response", "")}
259+
filtered = {k: v for k, v in updates.items() if k in self._writes}
260+
return state.update(**filtered)
261+
return state

0 commit comments

Comments
 (0)