Skip to content

Commit e88b276

Browse files
authored
feat(offloader): return explicit paths in preview and auto-enable retrieval (#2222)
1 parent 888c98c commit e88b276

4 files changed

Lines changed: 131 additions & 22 deletions

File tree

src/strands/vended_plugins/context_offloader/plugin.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class ContextOffloader(Plugin):
8888
max_result_tokens: Offload results whose estimated token count exceeds this threshold.
8989
preview_tokens: Number of tokens to keep as a text preview in context.
9090
include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` tool.
91-
Defaults to False.
91+
Defaults to True.
9292
9393
Example:
9494
```python
@@ -109,7 +109,7 @@ def __init__(
109109
max_result_tokens: int = _DEFAULT_MAX_RESULT_TOKENS,
110110
preview_tokens: int = _DEFAULT_PREVIEW_TOKENS,
111111
*,
112-
include_retrieval_tool: bool = False,
112+
include_retrieval_tool: bool = True,
113113
) -> None:
114114
"""Initialize the ContextOffloader plugin.
115115
@@ -121,7 +121,7 @@ def __init__(
121121
Uses tiktoken for exact slicing when available, falls back to
122122
chars/4 heuristic. Defaults to ``_DEFAULT_PREVIEW_TOKENS`` (1,000).
123123
include_retrieval_tool: Whether to register the ``retrieve_offloaded_content``
124-
tool so the agent can fetch offloaded content. Defaults to False.
124+
tool so the agent can fetch offloaded content. Defaults to True.
125125
126126
Raises:
127127
ValueError: If max_result_tokens is not positive, preview_tokens is negative,
@@ -155,7 +155,8 @@ def retrieve_offloaded_content(
155155
"""Retrieve offloaded content by reference.
156156
157157
Use this tool when you see a placeholder with a reference (ref: ...)
158-
and need the full content.
158+
and need the full content. Only use this as a fallback if the data
159+
cannot be accessed using your existing tools.
159160
160161
Args:
161162
reference: The reference string from the offload placeholder.

src/strands/vended_plugins/context_offloader/storage.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,19 @@ def _extension_for(content_type: str) -> str:
131131
return f".{content_type.split('/')[-1]}"
132132

133133
def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str:
134-
"""Store content as a file and return the filename as reference.
134+
"""Store content as a file and return the path as reference.
135+
136+
The returned path preserves the form of ``artifact_dir`` passed to
137+
the constructor: a relative ``artifact_dir`` yields a relative
138+
reference, an absolute one yields an absolute reference.
135139
136140
Args:
137141
key: A unique key for this content block.
138142
content: The raw content bytes to store.
139143
content_type: MIME type of the content.
140144
141145
Returns:
142-
The filename (not full path) used as the reference.
146+
The file path as a string (e.g., ``artifacts/1234_1_key.txt``).
143147
"""
144148
self._artifact_dir.mkdir(parents=True, exist_ok=True)
145149

@@ -156,26 +160,34 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s
156160
file_path = self._artifact_dir / filename
157161
file_path.write_bytes(content)
158162

159-
return filename
163+
return str(file_path)
160164

161165
def retrieve(self, reference: str) -> tuple[bytes, str]:
162166
"""Retrieve content from a stored file.
163167
168+
Accepts both full paths (as returned by ``store()``) and bare
169+
filenames for backward compatibility.
170+
164171
Args:
165-
reference: The filename reference returned by store().
172+
reference: The file path returned by store(), or a bare filename.
166173
167174
Returns:
168175
A tuple of (content bytes, content type).
169176
170177
Raises:
171178
KeyError: If the file does not exist.
172179
"""
173-
file_path = (self._artifact_dir / reference).resolve()
174-
if not file_path.is_relative_to(self._artifact_dir.resolve()):
180+
resolved_dir = self._artifact_dir.resolve()
181+
ref_path = Path(reference)
182+
file_path = ref_path.resolve() if len(ref_path.parts) > 1 else (self._artifact_dir / reference).resolve()
183+
if not file_path.is_relative_to(resolved_dir):
184+
file_path = (self._artifact_dir / reference).resolve()
185+
if not file_path.is_relative_to(resolved_dir):
175186
raise KeyError(f"Reference not found: {reference}")
176187
if not file_path.is_file():
177188
raise KeyError(f"Reference not found: {reference}")
178-
content_type = self._content_types.get(reference, "application/octet-stream")
189+
filename = file_path.name
190+
content_type = self._content_types.get(filename, "application/octet-stream")
179191
return file_path.read_bytes(), content_type
180192

181193
def _load_metadata(self) -> dict[str, str]:
@@ -320,15 +332,15 @@ def __init__(
320332
self._lock = threading.Lock()
321333

322334
def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str:
323-
"""Store content as an S3 object and return the object key as reference.
335+
"""Store content as an S3 object and return an ``s3://`` URI as reference.
324336
325337
Args:
326338
key: A unique key for this content block.
327339
content: The raw content bytes to store.
328340
content_type: MIME type of the content.
329341
330342
Returns:
331-
The S3 object key used as the reference.
343+
An S3 URI (e.g., ``s3://bucket/prefix/1234_1_key``).
332344
333345
Raises:
334346
botocore.exceptions.ClientError: If the S3 operation fails (e.g., bucket
@@ -348,22 +360,31 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s
348360
ContentType=content_type,
349361
)
350362

351-
return s3_key
363+
return f"s3://{self._bucket}/{s3_key}"
352364

353365
def retrieve(self, reference: str) -> tuple[bytes, str]:
354366
"""Retrieve content from an S3 object.
355367
368+
Accepts both ``s3://`` URIs (as returned by ``store()``) and raw
369+
S3 keys for backward compatibility.
370+
356371
Args:
357-
reference: The S3 object key returned by store().
372+
reference: The S3 URI returned by store(), or a raw object key.
358373
359374
Returns:
360375
A tuple of (content bytes, content type).
361376
362377
Raises:
363378
KeyError: If the object does not exist.
364379
"""
380+
s3_key = reference
381+
if reference.startswith("s3://"):
382+
expected_prefix = f"s3://{self._bucket}/"
383+
if not reference.startswith(expected_prefix):
384+
raise KeyError(f"Reference not found: {reference}")
385+
s3_key = reference[len(expected_prefix) :]
365386
try:
366-
response = self._client.get_object(Bucket=self._bucket, Key=reference)
387+
response = self._client.get_object(Bucket=self._bucket, Key=s3_key)
367388
content: bytes = response["Body"].read()
368389
content_type: str = response.get("ContentType", "application/octet-stream")
369390
return content, content_type

tests/strands/vended_plugins/context_offloader/test_plugin.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from strands.types.tools import ToolContext, ToolUse
1212
from strands.vended_plugins.context_offloader import (
1313
ContextOffloader,
14+
FileStorage,
1415
InMemoryStorage,
1516
)
1617

@@ -26,6 +27,7 @@ def plugin(storage):
2627
storage=storage,
2728
max_result_tokens=25,
2829
preview_tokens=10,
30+
include_retrieval_tool=False,
2931
)
3032

3133

@@ -466,10 +468,16 @@ def test_retrieval_tool_registered_when_enabled(self, plugin):
466468
tool_names = [t.tool_name for t in plugin.tools]
467469
assert "retrieve_offloaded_content" in tool_names
468470

469-
def test_retrieval_tool_not_registered_by_default(self):
471+
def test_retrieval_tool_registered_by_default(self):
470472
plugin = ContextOffloader(storage=InMemoryStorage())
471473
plugin.init_agent(MagicMock())
472474
tool_names = [t.tool_name for t in plugin.tools]
475+
assert "retrieve_offloaded_content" in tool_names
476+
477+
def test_retrieval_tool_not_registered_when_disabled(self):
478+
plugin = ContextOffloader(storage=InMemoryStorage(), include_retrieval_tool=False)
479+
plugin.init_agent(MagicMock())
480+
tool_names = [t.tool_name for t in plugin.tools]
473481
assert "retrieve_offloaded_content" not in tool_names
474482

475483
def test_retrieve_text_content(self, plugin, storage, tool_context):
@@ -531,9 +539,53 @@ async def test_guidance_mentions_retrieval_tool_when_enabled(self, storage, mock
531539

532540
@pytest.mark.asyncio
533541
async def test_guidance_does_not_mention_retrieval_tool_when_disabled(self, storage, mock_agent):
534-
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
542+
plugin = ContextOffloader(
543+
storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=False
544+
)
535545
event = _make_event(mock_agent, "x" * 200)
536546
await plugin._handle_tool_result(event)
537547
result_text = event.result["content"][0]["text"]
538548
assert "retrieve_offloaded_content" not in result_text
539549
assert "available tools" in result_text
550+
551+
552+
class TestActionableReferences:
553+
"""Tests that storage-specific references appear in the offloaded preview."""
554+
555+
@pytest.mark.asyncio
556+
async def test_file_storage_path_in_preview(self, tmp_path, mock_agent):
557+
storage = FileStorage(artifact_dir=str(tmp_path / "artifacts"))
558+
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
559+
event = _make_event(mock_agent, "a" * 200)
560+
561+
await plugin._handle_tool_result(event)
562+
563+
result_text = event.result["content"][0]["text"]
564+
assert str(tmp_path / "artifacts") in result_text
565+
566+
@pytest.mark.asyncio
567+
async def test_file_storage_image_placeholder_has_path(self, tmp_path, mock_agent):
568+
storage = FileStorage(artifact_dir=str(tmp_path / "artifacts"))
569+
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
570+
img_bytes = b"\x89PNG" + b"\x00" * 100
571+
content = [
572+
{"text": "x" * 200},
573+
{"image": {"format": "png", "source": {"bytes": img_bytes}}},
574+
]
575+
event = _make_event(mock_agent, content)
576+
577+
await plugin._handle_tool_result(event)
578+
579+
placeholder = event.result["content"][1]["text"]
580+
assert str(tmp_path / "artifacts") in placeholder
581+
582+
@pytest.mark.asyncio
583+
async def test_inmemory_storage_opaque_reference_in_preview(self, mock_agent):
584+
storage = InMemoryStorage()
585+
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
586+
event = _make_event(mock_agent, "a" * 200)
587+
588+
await plugin._handle_tool_result(event)
589+
590+
result_text = event.result["content"][0]["text"]
591+
assert "mem_" in result_text

tests/strands/vended_plugins/context_offloader/test_storage.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for offload storage backends."""
22

33
import threading
4+
from pathlib import Path
45
from unittest.mock import MagicMock, patch
56

67
import pytest
@@ -147,7 +148,30 @@ def test_sanitizes_path_traversal(self, tmp_path):
147148
storage = FileStorage(artifact_dir=str(tmp_path))
148149
ref = storage.store("../../etc/passwd", b"content")
149150
assert ".." not in ref
150-
assert "/" not in ref
151+
assert "/" not in Path(ref).name
152+
153+
def test_reference_includes_artifact_dir(self, tmp_path):
154+
artifact_dir = str(tmp_path / "artifacts")
155+
storage = FileStorage(artifact_dir=artifact_dir)
156+
ref = storage.store("key_1", b"content")
157+
assert Path(ref).parent == Path(artifact_dir)
158+
159+
def test_relative_artifact_dir_gives_relative_reference(self, tmp_path, monkeypatch):
160+
monkeypatch.chdir(tmp_path)
161+
storage = FileStorage(artifact_dir="./artifacts")
162+
ref = storage.store("key_1", b"content")
163+
assert Path(ref).parent == Path("artifacts")
164+
content, content_type = storage.retrieve(ref)
165+
assert content == b"content"
166+
assert content_type == "text/plain"
167+
168+
def test_retrieve_accepts_bare_filename(self, tmp_path):
169+
storage = FileStorage(artifact_dir=str(tmp_path))
170+
ref = storage.store("key_1", b"hello world")
171+
filename = Path(ref).name
172+
content, content_type = storage.retrieve(filename)
173+
assert content == b"hello world"
174+
assert content_type == "text/plain"
151175

152176
def test_metadata_survives_across_instances(self, tmp_path):
153177
artifact_dir = str(tmp_path / "artifacts")
@@ -233,9 +257,9 @@ def test_unique_references(self, storage):
233257
assert storage.retrieve(ref1)[0] == b"content a"
234258
assert storage.retrieve(ref2)[0] == b"content b"
235259

236-
def test_reference_includes_prefix(self, storage):
260+
def test_reference_is_s3_uri(self, storage):
237261
ref = storage.store("tool_abc", b"content")
238-
assert ref.startswith("artifacts/")
262+
assert ref.startswith("s3://test-bucket/artifacts/")
239263

240264
def test_empty_prefix(self, mock_s3_client):
241265
with patch("boto3.Session") as mock_session_cls:
@@ -245,9 +269,20 @@ def test_empty_prefix(self, mock_s3_client):
245269
storage = S3Storage(bucket="test-bucket", prefix="")
246270

247271
ref = storage.store("tool_abc", b"content")
248-
assert not ref.startswith("/")
272+
assert ref.startswith("s3://test-bucket/")
249273
assert storage.retrieve(ref)[0] == b"content"
250274

275+
def test_retrieve_accepts_raw_key(self, storage, mock_s3_client):
276+
ref = storage.store("key_1", b"hello world")
277+
raw_key = ref.removeprefix("s3://test-bucket/")
278+
content, content_type = storage.retrieve(raw_key)
279+
assert content == b"hello world"
280+
assert content_type == "text/plain"
281+
282+
def test_retrieve_rejects_wrong_bucket_uri(self, storage):
283+
with pytest.raises(KeyError, match="Reference not found"):
284+
storage.retrieve("s3://wrong-bucket/artifacts/some_key")
285+
251286
def test_put_object_called_with_correct_params(self, storage, mock_s3_client):
252287
storage.store("key_1", b"test content", "application/json")
253288

0 commit comments

Comments
 (0)