forked from strands-agents/sdk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplugin.py
More file actions
328 lines (275 loc) · 13.4 KB
/
plugin.py
File metadata and controls
328 lines (275 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
"""ContextOffloader plugin for managing large tool outputs.

This module provides the ContextOffloader plugin that intercepts oversized
tool results, persists each content block to a storage backend, and replaces
the in-context result with a truncated preview and per-block references.

Example:
    ```python
    from strands import Agent
    from strands.vended_plugins.context_offloader import (
        ContextOffloader,
        InMemoryStorage,
        FileStorage,
    )

    # In-memory storage
    agent = Agent(plugins=[
        ContextOffloader(storage=InMemoryStorage())
    ])

    # File storage with custom thresholds and retrieval tool enabled
    # (include_retrieval_tool already defaults to True; shown for clarity)
    agent = Agent(plugins=[
        ContextOffloader(
            storage=FileStorage("./artifacts"),
            max_result_tokens=5_000,
            preview_tokens=2_000,
            include_retrieval_tool=True,
        )
    ])
    ```
"""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING
from ...hooks.events import AfterToolCallEvent
from ...plugins import Plugin, hook
from ...tools.decorator import tool
from ...types.content import Message
from ...types.tools import ToolContext, ToolResult, ToolResultContent
from .storage import Storage
if TYPE_CHECKING:
    # Agent is only needed for annotations; it is not imported at runtime
    # (presumably to avoid an import cycle with the agent package).
    from ...agent.agent import Agent

logger = logging.getLogger(__name__)

_DEFAULT_MAX_RESULT_TOKENS = 2_500
"""Default token threshold above which tool results are offloaded."""

_DEFAULT_PREVIEW_TOKENS = 1_000
"""Default number of tokens to keep as a preview in context."""

_CHARS_PER_TOKEN = 4
"""Approximate characters per token, used to size the in-context text preview."""
class ContextOffloader(Plugin):
    """Plugin that offloads oversized tool results to reduce context consumption.

    When a tool result exceeds the configured token threshold, this plugin
    stores each content block individually to a storage backend and replaces
    the in-context result with a truncated text preview plus per-block references.

    Token counts come from the agent model's ``count_tokens`` method. The
    in-context preview is sliced with a character heuristic
    (``_CHARS_PER_TOKEN`` characters per token), so its length is approximate.

    Content type handling:

    - **Text**: stored as ``text/plain``, replaced with a preview
    - **JSON**: stored as ``application/json``, replaced with a preview
    - **Image**: stored as ``image/<format>`` (e.g., ``image/png``), replaced with a
      placeholder showing format and size
    - **Document**: stored as ``application/<format>`` (e.g., ``application/pdf``),
      replaced with a placeholder showing format, name, and size
    - **Unknown types**: passed through unchanged

    This operates proactively at tool execution time via ``AfterToolCallEvent``,
    before the result enters the conversation — unlike ``SlidingWindowConversationManager``
    which truncates reactively after context overflow.

    Args:
        storage: Backend for storing offloaded content (required).
        max_result_tokens: Offload results whose estimated token count exceeds this threshold.
        preview_tokens: Approximate number of tokens to keep as a text preview in context.
        include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` tool.
            Keyword-only. Defaults to True.

    Example:
        ```python
        from strands import Agent
        from strands.vended_plugins.context_offloader import ContextOffloader, InMemoryStorage

        agent = Agent(plugins=[
            ContextOffloader(storage=InMemoryStorage())
        ])
        ```
    """

    name = "context_offloader"

    def __init__(
        self,
        storage: Storage,
        max_result_tokens: int = _DEFAULT_MAX_RESULT_TOKENS,
        preview_tokens: int = _DEFAULT_PREVIEW_TOKENS,
        *,
        include_retrieval_tool: bool = True,
    ) -> None:
        """Initialize the ContextOffloader plugin.

        Args:
            storage: Backend for storing offloaded content.
            max_result_tokens: Offload results whose estimated token count exceeds this
                threshold. Defaults to ``_DEFAULT_MAX_RESULT_TOKENS`` (2,500).
            preview_tokens: Approximate number of tokens to keep as a text preview in
                context. The preview is cut at ``preview_tokens * _CHARS_PER_TOKEN``
                characters (a chars/4 heuristic, not a tokenizer). Defaults to
                ``_DEFAULT_PREVIEW_TOKENS`` (1,000).
            include_retrieval_tool: Whether to register the ``retrieve_offloaded_content``
                tool so the agent can fetch offloaded content. Defaults to True.

        Raises:
            ValueError: If max_result_tokens is not positive, preview_tokens is negative,
                or preview_tokens >= max_result_tokens.
        """
        if max_result_tokens <= 0:
            raise ValueError("max_result_tokens must be positive")
        if preview_tokens < 0:
            raise ValueError("preview_tokens must be non-negative")
        if preview_tokens >= max_result_tokens:
            raise ValueError("preview_tokens must be less than max_result_tokens")

        self._storage = storage
        self._max_result_tokens = max_result_tokens
        self._preview_tokens = preview_tokens
        self._include_retrieval_tool = include_retrieval_tool
        # Configuration is set before the base initializer runs; init_agent later
        # reads self._tools, which is presumably populated by Plugin.__init__.
        super().__init__()

    def init_agent(self, agent: Agent) -> None:
        """Drop the retrieval tool from this plugin's tools when it is disabled.

        Args:
            agent: The agent the plugin is attached to (not used here).
        """
        if not self._include_retrieval_tool:
            # Remove the retrieval tool from the plugin's tool list.
            self._tools = [t for t in self._tools if t.tool_name != "retrieve_offloaded_content"]

    @tool(context=True)
    def retrieve_offloaded_content(
        self,
        reference: str,
        tool_context: ToolContext,
    ) -> dict | str:
        """Retrieve offloaded content by reference.

        Use this tool when you see a placeholder with a reference (ref: ...)
        and need the full content. Only use this as a fallback if the data
        cannot be accessed using your existing tools.

        Args:
            reference: The reference string from the offload placeholder.
            tool_context: Injected by the framework. Not user-facing.
        """
        # NOTE: the docstring above is left verbatim because the @tool decorator
        # presumably turns it into the model-facing tool description.
        try:
            content_bytes, content_type = self._storage.retrieve(reference)
        except KeyError:
            return f"Error: reference not found: {reference}"

        # Dispatch on the content type recorded when the block was stored.
        if content_type.startswith("text/"):
            return content_bytes.decode("utf-8")
        if content_type == "application/json":
            return {"status": "success", "content": [{"json": json.loads(content_bytes)}]}
        if content_type.startswith("image/"):
            img_format = content_type.split("/")[-1]
            return {
                "status": "success",
                "content": [{"image": {"format": img_format, "source": {"bytes": content_bytes}}}],
            }
        if content_type.startswith("application/"):
            # Documents are stored as application/<format>; JSON was handled above.
            # NOTE(review): the reference is reused as the document name — confirm it
            # satisfies the model provider's document-name constraints.
            doc_format = content_type.split("/")[-1]
            doc_block = {"format": doc_format, "name": reference, "source": {"bytes": content_bytes}}
            return {"status": "success", "content": [{"document": doc_block}]}
        # Unrecognized type: best-effort text decode rather than failing the tool call.
        return content_bytes.decode("utf-8", errors="replace")

    @hook
    async def _handle_tool_result(self, event: AfterToolCallEvent) -> None:
        """Intercept oversized tool results, offload per-block, and replace with preview.

        The result is left untouched when the tool call was cancelled, when it comes
        from the retrieval tool itself, when its estimated token count is within
        ``max_result_tokens``, or when any storage write fails.

        Args:
            event: The after-tool-call event. ``event.result`` is replaced when the
                result is offloaded.
        """
        if event.cancel_message is not None:
            return
        # Skip the retrieval tool's own output so retrieved content is not offloaded again.
        if self._include_retrieval_tool and event.tool_use.get("name") == self.retrieve_offloaded_content.tool_name:
            return

        result = event.result
        content = result["content"]
        tool_use_id = event.tool_use["toolUseId"]

        # Estimate token count by wrapping the tool result as a message for count_tokens
        tool_result_message: Message = {"role": "user", "content": [{"toolResult": result}]}
        token_count = await event.agent.model.count_tokens([tool_result_message])
        if token_count <= self._max_result_tokens:
            return

        full_text = self._collect_preview_text(content)

        try:
            references = self._offload_blocks(tool_use_id, content)
        except Exception:
            # Best effort: a storage failure must not break the tool call, so the
            # original result is kept in context.
            logger.warning(
                "tool_use_id=<%s> | failed to offload tool result, keeping original",
                tool_use_id,
                exc_info=True,
            )
            return

        logger.debug(
            "tool_use_id=<%s>, blocks=<%d>, tokens=<%d> | tool result offloaded",
            tool_use_id,
            len(references),
            token_count,
        )

        preview = self._slice_preview(full_text) if full_text else ""
        # Dict preserves insertion order, i.e. ascending block index.
        ref_lines = "\n".join(f" {ref} ({desc})" for ref, _, desc in references.values() if ref)

        guidance = (
            "Tool result was offloaded to external storage due to size.\n"
            "Use the preview below to answer if possible.\n"
            "Use your available tools to selectively access the data you need."
        )
        if self._include_retrieval_tool:
            guidance += "\nYou can also use retrieve_offloaded_content with a reference to get the full content."

        preview_text = (
            f"[Offloaded: {len(content)} blocks, ~{token_count:,} tokens]\n"
            f"{guidance}\n\n"
            f"{preview}\n\n"
            f"[Stored references:]\n{ref_lines}"
        )

        # Build new content with preview + placeholders for non-text blocks
        new_content: list[ToolResultContent] = [ToolResultContent(text=preview_text)]
        for i, block in enumerate(content):
            if "text" in block or "json" in block:
                # Already represented by the preview text and the stored references.
                continue
            # Look up by content index: skipped blocks (empty text, unknown types)
            # have no entry, so positional indexing would pick the wrong reference.
            ref = references[i][0] if i in references else ""
            placeholder = self._placeholder_for(block, ref)
            new_content.append(block if placeholder is None else ToolResultContent(text=placeholder))

        event.result = ToolResult(
            toolUseId=result["toolUseId"],
            status=result["status"],
            content=new_content,
        )

    @staticmethod
    def _collect_preview_text(content: list[ToolResultContent]) -> str:
        """Join non-empty text blocks and pretty-printed JSON blocks into the preview source.

        Empty text blocks are intentionally excluded — they add no content value.

        Args:
            content: The tool result's content blocks.

        Returns:
            The newline-joined text, or ``""`` when there is no text or JSON content.
        """
        parts: list[str] = []
        for block in content:
            if block.get("text"):
                parts.append(block["text"])
            elif "json" in block:
                parts.append(json.dumps(block["json"], indent=2))
        return "\n".join(parts)

    def _offload_blocks(
        self,
        tool_use_id: str,
        content: list[ToolResultContent],
    ) -> dict[int, tuple[str, str, str]]:
        """Store each offloadable block and return its reference info keyed by block index.

        Each block is stored under the key ``f"{tool_use_id}_{i}"``. Empty text
        blocks and unknown block types are not stored and get no entry. Image and
        document blocks without bytes get an entry whose ``ref`` is ``""``.

        Args:
            tool_use_id: ID of the tool use whose result is being offloaded.
            content: The tool result's content blocks.

        Returns:
            Mapping of content index to ``(ref, content_type, description)``.

        Raises:
            Exception: Propagates whatever ``self._storage.store`` raises.
        """
        references: dict[int, tuple[str, str, str]] = {}
        for i, block in enumerate(content):
            key = f"{tool_use_id}_{i}"
            if block.get("text"):
                text = block["text"]
                ref = self._storage.store(key, text.encode("utf-8"), "text/plain")
                references[i] = (ref, "text/plain", f"text, {len(text):,} chars")
            elif "json" in block:
                json_bytes = json.dumps(block["json"], indent=2).encode("utf-8")
                ref = self._storage.store(key, json_bytes, "application/json")
                references[i] = (ref, "application/json", f"json, {len(json_bytes):,} bytes")
            elif "image" in block:
                image = block["image"]
                img_type = f"image/{image.get('format', 'unknown')}"
                img_bytes = image.get("source", {}).get("bytes", b"")
                if img_bytes:
                    ref = self._storage.store(key, img_bytes, img_type)
                    references[i] = (ref, img_type, f"{img_type}, {len(img_bytes):,} bytes")
                else:
                    references[i] = ("", img_type, f"{img_type}, 0 bytes")
            elif "document" in block:
                doc = block["document"]
                doc_type = f"application/{doc.get('format', 'unknown')}"
                doc_name = doc.get("name", "unknown")
                doc_bytes = doc.get("source", {}).get("bytes", b"")
                if doc_bytes:
                    ref = self._storage.store(key, doc_bytes, doc_type)
                    references[i] = (ref, doc_type, f"{doc_name}, {len(doc_bytes):,} bytes")
                else:
                    references[i] = ("", doc_type, f"{doc_name}, 0 bytes")
        return references

    @staticmethod
    def _placeholder_for(block: ToolResultContent, ref: str) -> str | None:
        """Build the in-context placeholder for an image or document block.

        Args:
            block: A non-text, non-JSON content block.
            ref: Storage reference for the block, or ``""`` if it was not stored.

        Returns:
            The placeholder text, or None for unknown block types (passed through as-is).
        """
        if "image" in block:
            image = block["image"]
            img_bytes = image.get("source", {}).get("bytes", b"")
            placeholder = f"[image: {image.get('format', 'unknown')}, {len(img_bytes) if img_bytes else 0} bytes"
        elif "document" in block:
            doc = block["document"]
            doc_bytes = doc.get("source", {}).get("bytes", b"")
            placeholder = (
                f"[document: {doc.get('format', 'unknown')}, {doc.get('name', 'unknown')}, "
                f"{len(doc_bytes) if doc_bytes else 0} bytes"
            )
        else:
            return None
        if ref:
            placeholder += f" | ref: {ref}"
        return placeholder + "]"

    def _slice_preview(self, text: str) -> str:
        """Truncate text to roughly ``preview_tokens`` tokens.

        Uses a fixed ``_CHARS_PER_TOKEN`` characters-per-token heuristic rather
        than a tokenizer, so the preview length in tokens is approximate.

        Args:
            text: The full text to slice.

        Returns:
            The first ``preview_tokens * _CHARS_PER_TOKEN`` characters of ``text``.
        """
        return text[: self._preview_tokens * _CHARS_PER_TOKEN]