Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/strands/vended_plugins/context_offloader/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
max_result_tokens: Offload results whose estimated token count exceeds this threshold.
preview_tokens: Number of tokens to keep as a text preview in context.
include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` tool.
Defaults to False.
Defaults to True.

Example:
```python
Expand All @@ -103,13 +103,13 @@

name = "context_offloader"

def __init__(

Check warning on line 106 in src/strands/vended_plugins/context_offloader/plugin.py

View workflow job for this annotation

GitHub Actions / check-api

ContextOffloader.__init__(include_retrieval_tool)

Parameter default was changed: `False` -> `True`
self,
storage: Storage,
max_result_tokens: int = _DEFAULT_MAX_RESULT_TOKENS,
preview_tokens: int = _DEFAULT_PREVIEW_TOKENS,
*,
include_retrieval_tool: bool = False,
include_retrieval_tool: bool = True,
) -> None:
"""Initialize the ContextOffloader plugin.

Expand All @@ -121,7 +121,7 @@
Uses tiktoken for exact slicing when available, falls back to
chars/4 heuristic. Defaults to ``_DEFAULT_PREVIEW_TOKENS`` (1,000).
include_retrieval_tool: Whether to register the ``retrieve_offloaded_content``
tool so the agent can fetch offloaded content. Defaults to False.
tool so the agent can fetch offloaded content. Defaults to True.

Raises:
ValueError: If max_result_tokens is not positive, preview_tokens is negative,
Expand Down Expand Up @@ -273,7 +273,10 @@
"Use your available tools to selectively access the data you need."
)
if self._include_retrieval_tool:
guidance += "\nYou can also use retrieve_offloaded_content with a reference to get the full content."
guidance += (
"\nOnly use retrieve_offloaded_content as a fallback"
" if the data cannot be accessed using your existing tools."
Comment thread
lizradway marked this conversation as resolved.
Outdated
)

preview_text = (
f"[Offloaded: {len(content)} blocks, ~{token_count:,} tokens]\n"
Expand Down
45 changes: 33 additions & 12 deletions src/strands/vended_plugins/context_offloader/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,19 @@ def _extension_for(content_type: str) -> str:
return f".{content_type.split('/')[-1]}"

def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str:
"""Store content as a file and return the filename as reference.
"""Store content as a file and return the path as reference.

The returned path preserves the form of ``artifact_dir`` passed to
the constructor: a relative ``artifact_dir`` yields a relative
reference, an absolute one yields an absolute reference.

Args:
key: A unique key for this content block.
content: The raw content bytes to store.
content_type: MIME type of the content.

Returns:
The filename (not full path) used as the reference.
The file path (e.g., ``./artifacts/1234_1_key.txt``).
"""
self._artifact_dir.mkdir(parents=True, exist_ok=True)

Expand All @@ -156,26 +160,34 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s
file_path = self._artifact_dir / filename
file_path.write_bytes(content)

return filename
return str(file_path)

def retrieve(self, reference: str) -> tuple[bytes, str]:
"""Retrieve content from a stored file.

Accepts both full paths (as returned by ``store()``) and bare
filenames for backward compatibility.

Args:
reference: The filename reference returned by store().
reference: The file path or filename returned by store().

Returns:
A tuple of (content bytes, content type).

Raises:
KeyError: If the file does not exist.
"""
file_path = (self._artifact_dir / reference).resolve()
if not file_path.is_relative_to(self._artifact_dir.resolve()):
resolved_dir = self._artifact_dir.resolve()
ref_path = Path(reference)
file_path = ref_path.resolve() if len(ref_path.parts) > 1 else (self._artifact_dir / reference).resolve()
Comment thread
lizradway marked this conversation as resolved.
if not file_path.is_relative_to(resolved_dir):
file_path = (self._artifact_dir / reference).resolve()
if not file_path.is_relative_to(resolved_dir):
raise KeyError(f"Reference not found: {reference}")
if not file_path.is_file():
raise KeyError(f"Reference not found: {reference}")
content_type = self._content_types.get(reference, "application/octet-stream")
filename = file_path.name
content_type = self._content_types.get(filename, "application/octet-stream")
return file_path.read_bytes(), content_type

def _load_metadata(self) -> dict[str, str]:
Expand Down Expand Up @@ -320,15 +332,15 @@ def __init__(
self._lock = threading.Lock()

def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str:
"""Store content as an S3 object and return the object key as reference.
"""Store content as an S3 object and return an ``s3://`` URI as reference.

Args:
key: A unique key for this content block.
content: The raw content bytes to store.
content_type: MIME type of the content.

Returns:
The S3 object key used as the reference.
An S3 URI (e.g., ``s3://bucket/prefix/1234_1_key``).

Raises:
botocore.exceptions.ClientError: If the S3 operation fails (e.g., bucket
Expand All @@ -348,22 +360,31 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s
ContentType=content_type,
)

return s3_key
return f"s3://{self._bucket}/{s3_key}"

def retrieve(self, reference: str) -> tuple[bytes, str]:
"""Retrieve content from an S3 object.

Accepts both ``s3://`` URIs (as returned by ``store()``) and raw
S3 keys for backward compatibility.

Args:
reference: The S3 object key returned by store().
reference: The S3 URI or object key returned by store().

Returns:
A tuple of (content bytes, content type).

Raises:
KeyError: If the object does not exist.
"""
s3_key = reference
if reference.startswith("s3://"):
expected_prefix = f"s3://{self._bucket}/"
if not reference.startswith(expected_prefix):
raise KeyError(f"Reference not found: {reference}")
s3_key = reference[len(expected_prefix) :]
try:
response = self._client.get_object(Bucket=self._bucket, Key=reference)
response = self._client.get_object(Bucket=self._bucket, Key=s3_key)
content: bytes = response["Body"].read()
content_type: str = response.get("ContentType", "application/octet-stream")
return content, content_type
Expand Down
56 changes: 54 additions & 2 deletions tests/strands/vended_plugins/context_offloader/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from strands.types.tools import ToolContext, ToolUse
from strands.vended_plugins.context_offloader import (
ContextOffloader,
FileStorage,
InMemoryStorage,
)

Expand All @@ -26,6 +27,7 @@ def plugin(storage):
storage=storage,
max_result_tokens=25,
preview_tokens=10,
include_retrieval_tool=False,
)


Expand Down Expand Up @@ -466,10 +468,16 @@ def test_retrieval_tool_registered_when_enabled(self, plugin):
tool_names = [t.tool_name for t in plugin.tools]
assert "retrieve_offloaded_content" in tool_names

def test_retrieval_tool_not_registered_by_default(self):
def test_retrieval_tool_registered_by_default(self):
plugin = ContextOffloader(storage=InMemoryStorage())
plugin.init_agent(MagicMock())
tool_names = [t.tool_name for t in plugin.tools]
assert "retrieve_offloaded_content" in tool_names

def test_retrieval_tool_not_registered_when_disabled(self):
plugin = ContextOffloader(storage=InMemoryStorage(), include_retrieval_tool=False)
plugin.init_agent(MagicMock())
tool_names = [t.tool_name for t in plugin.tools]
assert "retrieve_offloaded_content" not in tool_names

def test_retrieve_text_content(self, plugin, storage, tool_context):
Expand Down Expand Up @@ -531,9 +539,53 @@ async def test_guidance_mentions_retrieval_tool_when_enabled(self, storage, mock

@pytest.mark.asyncio
async def test_guidance_does_not_mention_retrieval_tool_when_disabled(self, storage, mock_agent):
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
plugin = ContextOffloader(
storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=False
)
event = _make_event(mock_agent, "x" * 200)
await plugin._handle_tool_result(event)
result_text = event.result["content"][0]["text"]
assert "retrieve_offloaded_content" not in result_text
assert "available tools" in result_text


class TestActionableReferences:
"""Tests that storage-specific references appear in the offloaded preview."""

@pytest.mark.asyncio
async def test_file_storage_path_in_preview(self, tmp_path, mock_agent):
storage = FileStorage(artifact_dir=str(tmp_path / "artifacts"))
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
event = _make_event(mock_agent, "a" * 200)

await plugin._handle_tool_result(event)

result_text = event.result["content"][0]["text"]
assert str(tmp_path / "artifacts") in result_text

@pytest.mark.asyncio
async def test_file_storage_image_placeholder_has_path(self, tmp_path, mock_agent):
storage = FileStorage(artifact_dir=str(tmp_path / "artifacts"))
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
img_bytes = b"\x89PNG" + b"\x00" * 100
content = [
{"text": "x" * 200},
{"image": {"format": "png", "source": {"bytes": img_bytes}}},
]
event = _make_event(mock_agent, content)

await plugin._handle_tool_result(event)

placeholder = event.result["content"][1]["text"]
assert str(tmp_path / "artifacts") in placeholder

@pytest.mark.asyncio
async def test_inmemory_storage_opaque_reference_in_preview(self, mock_agent):
storage = InMemoryStorage()
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
event = _make_event(mock_agent, "a" * 200)

await plugin._handle_tool_result(event)

result_text = event.result["content"][0]["text"]
assert "mem_" in result_text
43 changes: 39 additions & 4 deletions tests/strands/vended_plugins/context_offloader/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for offload storage backends."""

import threading
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -147,7 +148,30 @@ def test_sanitizes_path_traversal(self, tmp_path):
storage = FileStorage(artifact_dir=str(tmp_path))
ref = storage.store("../../etc/passwd", b"content")
assert ".." not in ref
assert "/" not in ref
assert "/" not in Path(ref).name

def test_reference_includes_artifact_dir(self, tmp_path):
artifact_dir = str(tmp_path / "artifacts")
storage = FileStorage(artifact_dir=artifact_dir)
ref = storage.store("key_1", b"content")
assert Path(ref).parent == Path(artifact_dir)

def test_relative_artifact_dir_gives_relative_reference(self, tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
storage = FileStorage(artifact_dir="./artifacts")
ref = storage.store("key_1", b"content")
assert Path(ref).parent == Path("artifacts")
content, content_type = storage.retrieve(ref)
assert content == b"content"
assert content_type == "text/plain"

def test_retrieve_accepts_bare_filename(self, tmp_path):
storage = FileStorage(artifact_dir=str(tmp_path))
ref = storage.store("key_1", b"hello world")
filename = Path(ref).name
content, content_type = storage.retrieve(filename)
assert content == b"hello world"
assert content_type == "text/plain"

def test_metadata_survives_across_instances(self, tmp_path):
artifact_dir = str(tmp_path / "artifacts")
Expand Down Expand Up @@ -233,9 +257,9 @@ def test_unique_references(self, storage):
assert storage.retrieve(ref1)[0] == b"content a"
assert storage.retrieve(ref2)[0] == b"content b"

def test_reference_includes_prefix(self, storage):
def test_reference_is_s3_uri(self, storage):
ref = storage.store("tool_abc", b"content")
assert ref.startswith("artifacts/")
assert ref.startswith("s3://test-bucket/artifacts/")

def test_empty_prefix(self, mock_s3_client):
with patch("boto3.Session") as mock_session_cls:
Expand All @@ -245,9 +269,20 @@ def test_empty_prefix(self, mock_s3_client):
storage = S3Storage(bucket="test-bucket", prefix="")

ref = storage.store("tool_abc", b"content")
assert not ref.startswith("/")
assert ref.startswith("s3://test-bucket/")
assert storage.retrieve(ref)[0] == b"content"

def test_retrieve_accepts_raw_key(self, storage, mock_s3_client):
ref = storage.store("key_1", b"hello world")
raw_key = ref.removeprefix("s3://test-bucket/")
content, content_type = storage.retrieve(raw_key)
assert content == b"hello world"
assert content_type == "text/plain"

def test_retrieve_rejects_wrong_bucket_uri(self, storage):
with pytest.raises(KeyError, match="Reference not found"):
storage.retrieve("s3://wrong-bucket/artifacts/some_key")

def test_put_object_called_with_correct_params(self, storage, mock_s3_client):
storage.store("key_1", b"test content", "application/json")

Expand Down
Loading