Skip to content

Commit 0499dec

Browse files
committed
Add Bedrock integration, tests, docs, and CI
1 parent ab2fb90 commit 0499dec

8 files changed

Lines changed: 753 additions & 1 deletion

File tree

.github/workflows/python-package.yml

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
7474
- name: Run tests
7575
run: |
76-
python -m pytest tests --ignore=tests/integrations/persisters
76+
python -m pytest tests --ignore=tests/integrations/persisters --ignore=tests/integrations/test_bip0042_bedrock.py
7777
7878
test-tracking-server-s3:
7979
runs-on: ubuntu-latest
@@ -98,6 +98,29 @@ jobs:
9898
run: |
9999
python -m pytest tests/tracking/test_bip0042_s3_buffering.py -v
100100
101+
test-bedrock:
102+
runs-on: ubuntu-latest
103+
strategy:
104+
fail-fast: false
105+
matrix:
106+
python-version: ['3.9', '3.10', '3.11', '3.12']
107+
108+
steps:
109+
- uses: actions/checkout@v4
110+
111+
- name: Set up Python ${{ matrix.python-version }}
112+
uses: actions/setup-python@v4
113+
with:
114+
python-version: ${{ matrix.python-version }}
115+
116+
- name: Install dependencies
117+
run: |
118+
python -m pip install -e ".[tests,bedrock]"
119+
120+
- name: Run Bedrock integration tests
121+
run: |
122+
python -m pytest tests/integrations/test_bip0042_bedrock.py -v
123+
101124
validate-examples:
102125
runs-on: ubuntu-latest
103126
steps:

burr/integrations/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,16 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
18+
19+
def __getattr__(name: str):
20+
"""Lazy load Bedrock integration to avoid requiring boto3 unless used."""
21+
if name == "BedrockAction":
22+
from burr.integrations.bedrock import BedrockAction
23+
24+
return BedrockAction
25+
if name == "BedrockStreamingAction":
26+
from burr.integrations.bedrock import BedrockStreamingAction
27+
28+
return BedrockStreamingAction
29+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

burr/integrations/bedrock.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Amazon Bedrock integration for Burr.
19+
20+
This module provides Action classes for invoking Amazon Bedrock models
21+
within Burr applications.
22+
23+
Example usage:
24+
from burr.integrations.bedrock import BedrockAction
25+
26+
def prompt_mapper(state):
27+
return {
28+
"messages": [{"role": "user", "content": state["user_input"]}],
29+
"system": [{"text": "You are a helpful assistant."}],
30+
}
31+
32+
# With default client (created lazily on first use):
33+
action = BedrockAction(
34+
model_id="anthropic.claude-3-sonnet-20240229-v1:0",
35+
input_mapper=prompt_mapper,
36+
reads=["user_input"],
37+
writes=["response"],
38+
)
39+
40+
# With injected client (for tests or distributed execution):
41+
# client = boto3.client("bedrock-runtime", region_name="us-east-1")
42+
# action = BedrockAction(..., client=client)
43+
44+
If ``guardrail_id`` is set, you must pass an explicit ``guardrail_version``
45+
(including the string ``DRAFT`` if you intend to use the unpublished draft).
46+
"""
47+
48+
import logging
49+
from typing import Any, Generator, Optional, Protocol
50+
51+
from burr.core.action import SingleStepAction, StreamingAction
52+
from burr.core.state import State
53+
from burr.integrations.base import require_plugin
54+
55+
logger = logging.getLogger(__name__)
56+
57+
# Type for injected Bedrock client (avoids boto3 import at type-check time)
58+
BedrockClient = Any
59+
60+
try:
61+
import boto3
62+
from botocore.config import Config
63+
from botocore.exceptions import ClientError
64+
except ImportError as e:
65+
require_plugin(e, "bedrock")
66+
67+
68+
class StateToPromptMapper(Protocol):
69+
"""Protocol for mapping Burr state to Bedrock prompt format."""
70+
71+
def __call__(self, state: State) -> dict[str, Any]: ... # noqa: E704
72+
73+
74+
class _BedrockCore:
75+
"""Shared Bedrock client, inference config, and Converse request shape."""
76+
77+
def __init__(
78+
self,
79+
model_id: str,
80+
input_mapper: StateToPromptMapper,
81+
reads: list[str],
82+
writes: list[str],
83+
name: str,
84+
region: Optional[str],
85+
guardrail_id: Optional[str],
86+
guardrail_version: Optional[str],
87+
inference_config: Optional[dict[str, Any]],
88+
max_retries: int,
89+
client: Optional[BedrockClient],
90+
):
91+
if guardrail_id is not None and guardrail_version is None:
92+
raise ValueError(
93+
"guardrail_version is required when guardrail_id is set "
94+
'(pass an explicit published version, or "DRAFT" for the draft).'
95+
)
96+
self._model_id = model_id
97+
self._input_mapper = input_mapper
98+
self._reads = reads
99+
self._writes = writes
100+
self._name = name
101+
self._region = region
102+
self._guardrail_id = guardrail_id
103+
self._guardrail_version = guardrail_version
104+
self._inference_config = (
105+
{"maxTokens": 4096} if inference_config is None else inference_config
106+
)
107+
self._max_retries = max_retries
108+
self._client = client
109+
110+
@property
111+
def reads(self) -> list[str]:
112+
return self._reads
113+
114+
@property
115+
def writes(self) -> list[str]:
116+
return self._writes
117+
118+
@property
119+
def name(self) -> str:
120+
return self._name
121+
122+
def get_client(self) -> BedrockClient:
123+
"""Return the Bedrock runtime client, creating it lazily if not injected."""
124+
if self._client is not None:
125+
return self._client
126+
config = Config(retries={"max_attempts": self._max_retries, "mode": "adaptive"})
127+
self._client = boto3.client("bedrock-runtime", region_name=self._region, config=config)
128+
return self._client
129+
130+
def build_converse_request(self, state: State) -> dict[str, Any]:
131+
"""Build the kwargs dict for ``converse`` / ``converse_stream`` from current state."""
132+
prompt = self._input_mapper(state)
133+
request: dict[str, Any] = {
134+
"modelId": self._model_id,
135+
"messages": prompt["messages"],
136+
"inferenceConfig": self._inference_config,
137+
}
138+
if "system" in prompt:
139+
request["system"] = prompt["system"]
140+
if self._guardrail_id:
141+
request["guardrailConfig"] = {
142+
"guardrailIdentifier": self._guardrail_id,
143+
"guardrailVersion": self._guardrail_version,
144+
}
145+
return request
146+
147+
148+
class BedrockAction(SingleStepAction):
149+
"""Action that invokes Amazon Bedrock models using the Converse API.
150+
151+
:param model_id: Bedrock model identifier (e.g. Anthropic Claude on Bedrock).
152+
:param input_mapper: Callable mapping :class:`~burr.core.state.State` to Bedrock
153+
``messages`` / optional ``system`` keys.
154+
:param reads: State keys this action reads.
155+
:param writes: State keys to update (typically include ``response``).
156+
:param name: Action name for the graph.
157+
:param region: AWS region for the Bedrock runtime client (optional).
158+
:param guardrail_id: If set, ``guardrail_version`` must also be set explicitly.
159+
:param guardrail_version: Guardrail version string (use ``DRAFT`` only when intended).
160+
:param inference_config: Passed as ``inferenceConfig``; if omitted, defaults to
161+
a ``maxTokens`` limit. Pass an empty dict explicitly to send an empty config.
162+
:param max_retries: Botocore retry configuration for the runtime client.
163+
:param client: Optional pre-built ``bedrock-runtime`` client (for tests or injection).
164+
165+
Use :meth:`run_and_update` to run the model and merge outputs into state.
166+
"""
167+
168+
def __init__(
169+
self,
170+
model_id: str,
171+
input_mapper: StateToPromptMapper,
172+
reads: list[str],
173+
writes: list[str],
174+
name: str = "bedrock_invoke",
175+
region: Optional[str] = None,
176+
guardrail_id: Optional[str] = None,
177+
guardrail_version: Optional[str] = None,
178+
inference_config: Optional[dict[str, Any]] = None,
179+
max_retries: int = 3,
180+
client: Optional[BedrockClient] = None,
181+
):
182+
super().__init__()
183+
self._bedrock = _BedrockCore(
184+
model_id=model_id,
185+
input_mapper=input_mapper,
186+
reads=reads,
187+
writes=writes,
188+
name=name,
189+
region=region,
190+
guardrail_id=guardrail_id,
191+
guardrail_version=guardrail_version,
192+
inference_config=inference_config,
193+
max_retries=max_retries,
194+
client=client,
195+
)
196+
197+
@property
198+
def reads(self) -> list[str]:
199+
return self._bedrock.reads
200+
201+
@property
202+
def writes(self) -> list[str]:
203+
return self._bedrock.writes
204+
205+
@property
206+
def name(self) -> str:
207+
return self._bedrock.name
208+
209+
def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
210+
request = self._bedrock.build_converse_request(state)
211+
212+
try:
213+
response = self._bedrock.get_client().converse(**request)
214+
except ClientError as e:
215+
logger.error("Bedrock API error: %s", e)
216+
raise
217+
218+
output_message = response["output"]["message"]
219+
content_blocks = output_message.get("content", [])
220+
text = content_blocks[0]["text"] if content_blocks else ""
221+
222+
result: dict[str, Any] = {
223+
"response": text,
224+
"usage": response.get("usage", {}),
225+
"stop_reason": response.get("stopReason"),
226+
}
227+
228+
updates = {key: result[key] for key in self._bedrock.writes if key in result}
229+
new_state = state.update(**updates)
230+
231+
return result, new_state
232+
233+
234+
class BedrockStreamingAction(StreamingAction):
235+
"""Streaming Bedrock action using the Converse Stream API.
236+
237+
Parameters match :class:`BedrockAction` except the default ``name`` is
238+
``bedrock_stream``. Yields chunk dicts from :meth:`stream_run` and merges the
239+
final response in :meth:`update`.
240+
"""
241+
242+
def __init__(
243+
self,
244+
model_id: str,
245+
input_mapper: StateToPromptMapper,
246+
reads: list[str],
247+
writes: list[str],
248+
name: str = "bedrock_stream",
249+
region: Optional[str] = None,
250+
guardrail_id: Optional[str] = None,
251+
guardrail_version: Optional[str] = None,
252+
inference_config: Optional[dict[str, Any]] = None,
253+
max_retries: int = 3,
254+
client: Optional[BedrockClient] = None,
255+
):
256+
super().__init__()
257+
self._bedrock = _BedrockCore(
258+
model_id=model_id,
259+
input_mapper=input_mapper,
260+
reads=reads,
261+
writes=writes,
262+
name=name,
263+
region=region,
264+
guardrail_id=guardrail_id,
265+
guardrail_version=guardrail_version,
266+
inference_config=inference_config,
267+
max_retries=max_retries,
268+
client=client,
269+
)
270+
271+
@property
272+
def reads(self) -> list[str]:
273+
return self._bedrock.reads
274+
275+
@property
276+
def writes(self) -> list[str]:
277+
return self._bedrock.writes
278+
279+
@property
280+
def name(self) -> str:
281+
return self._bedrock.name
282+
283+
def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]:
284+
request = self._bedrock.build_converse_request(state)
285+
286+
try:
287+
response = self._bedrock.get_client().converse_stream(**request)
288+
except ClientError as e:
289+
logger.error("Bedrock streaming API error: %s", e)
290+
raise
291+
292+
full_response = ""
293+
stream = response.get("stream", [])
294+
for event in stream:
295+
if "contentBlockDelta" in event:
296+
chunk = event["contentBlockDelta"]["delta"].get("text", "")
297+
full_response += chunk
298+
yield {"chunk": chunk, "response": full_response}
299+
300+
yield {"chunk": "", "response": full_response, "complete": True}
301+
302+
def update(self, result: dict, state: State) -> State:
303+
if result.get("complete"):
304+
updates = {"response": result.get("response", "")}
305+
filtered = {k: v for k, v in updates.items() if k in self._bedrock.writes}
306+
return state.update(**filtered)
307+
return state

docs/getting_started/install.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,9 @@ This installs the server dependencies to run the UI and load tracking that was s
174174
pip install "burr[tracking-server]"
175175
176176
This installs the server dependencies for running the UI off a filesystem.
177+
178+
.. code-block:: bash
179+
180+
pip install "burr[bedrock]"
181+
182+
This installs ``boto3`` for the :ref:`Amazon Bedrock integration <bedrock-integration>` (``BedrockAction`` / ``BedrockStreamingAction``).

0 commit comments

Comments
 (0)