Skip to content

Commit 9267183

Browse files
committed
feat(offloader): return relevant path type from storage
1 parent 52cdb9d commit 9267183

3 files changed

Lines changed: 113 additions & 15 deletions

File tree

src/strands/vended_plugins/context_offloader/storage.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,19 @@ def _extension_for(content_type: str) -> str:
131131
return f".{content_type.split('/')[-1]}"
132132

133133
def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str:
134-
"""Store content as a file and return the filename as reference.
134+
"""Store content as a file and return the path as reference.
135+
136+
The returned path preserves the form of ``artifact_dir`` passed to
137+
the constructor: a relative ``artifact_dir`` yields a relative
138+
reference, an absolute one yields an absolute reference.
135139
136140
Args:
137141
key: A unique key for this content block.
138142
content: The raw content bytes to store.
139143
content_type: MIME type of the content.
140144
141145
Returns:
142-
The filename (not full path) used as the reference.
146+
The file path (e.g., ``./artifacts/1234_1_key.txt``).
143147
"""
144148
self._artifact_dir.mkdir(parents=True, exist_ok=True)
145149

@@ -156,26 +160,33 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s
156160
file_path = self._artifact_dir / filename
157161
file_path.write_bytes(content)
158162

159-
return filename
163+
return str(file_path)
160164

161165
def retrieve(self, reference: str) -> tuple[bytes, str]:
162166
"""Retrieve content from a stored file.
163167
168+
Accepts both full paths (as returned by ``store()``) and bare
169+
filenames for backward compatibility.
170+
164171
Args:
165-
reference: The filename reference returned by store().
172+
reference: The file path or filename returned by store().
166173
167174
Returns:
168175
A tuple of (content bytes, content type).
169176
170177
Raises:
171178
KeyError: If the file does not exist.
172179
"""
173-
file_path = (self._artifact_dir / reference).resolve()
180+
ref_path = Path(reference)
181+
file_path = ref_path.resolve() if len(ref_path.parts) > 1 else (self._artifact_dir / reference).resolve()
182+
if not file_path.is_relative_to(self._artifact_dir.resolve()):
183+
file_path = (self._artifact_dir / reference).resolve()
174184
if not file_path.is_relative_to(self._artifact_dir.resolve()):
175185
raise KeyError(f"Reference not found: {reference}")
176186
if not file_path.is_file():
177187
raise KeyError(f"Reference not found: {reference}")
178-
content_type = self._content_types.get(reference, "application/octet-stream")
188+
filename = file_path.name
189+
content_type = self._content_types.get(filename, "application/octet-stream")
179190
return file_path.read_bytes(), content_type
180191

181192
def _load_metadata(self) -> dict[str, str]:
@@ -320,15 +331,15 @@ def __init__(
320331
self._lock = threading.Lock()
321332

322333
def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str:
323-
"""Store content as an S3 object and return the object key as reference.
334+
"""Store content as an S3 object and return an ``s3://`` URI as reference.
324335
325336
Args:
326337
key: A unique key for this content block.
327338
content: The raw content bytes to store.
328339
content_type: MIME type of the content.
329340
330341
Returns:
331-
The S3 object key used as the reference.
342+
An S3 URI (e.g., ``s3://bucket/prefix/1234_1_key``).
332343
333344
Raises:
334345
botocore.exceptions.ClientError: If the S3 operation fails (e.g., bucket
@@ -348,22 +359,31 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s
348359
ContentType=content_type,
349360
)
350361

351-
return s3_key
362+
return f"s3://{self._bucket}/{s3_key}"
352363

353364
def retrieve(self, reference: str) -> tuple[bytes, str]:
354365
"""Retrieve content from an S3 object.
355366
367+
Accepts both ``s3://`` URIs (as returned by ``store()``) and raw
368+
S3 keys for backward compatibility.
369+
356370
Args:
357-
reference: The S3 object key returned by store().
371+
reference: The S3 URI or object key returned by store().
358372
359373
Returns:
360374
A tuple of (content bytes, content type).
361375
362376
Raises:
363377
KeyError: If the object does not exist.
364378
"""
379+
s3_key = reference
380+
if reference.startswith("s3://"):
381+
expected_prefix = f"s3://{self._bucket}/"
382+
if not reference.startswith(expected_prefix):
383+
raise KeyError(f"Reference not found: {reference}")
384+
s3_key = reference[len(expected_prefix):]
365385
try:
366-
response = self._client.get_object(Bucket=self._bucket, Key=reference)
386+
response = self._client.get_object(Bucket=self._bucket, Key=s3_key)
367387
content: bytes = response["Body"].read()
368388
content_type: str = response.get("ContentType", "application/octet-stream")
369389
return content, content_type

tests/strands/vended_plugins/context_offloader/test_plugin.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from strands.types.tools import ToolContext, ToolUse
1212
from strands.vended_plugins.context_offloader import (
1313
ContextOffloader,
14+
FileStorage,
1415
InMemoryStorage,
1516
)
1617

@@ -537,3 +538,45 @@ async def test_guidance_does_not_mention_retrieval_tool_when_disabled(self, stor
537538
result_text = event.result["content"][0]["text"]
538539
assert "retrieve_offloaded_content" not in result_text
539540
assert "available tools" in result_text
541+
542+
543+
class TestActionableReferences:
544+
"""Tests that storage-specific references appear in the offloaded preview."""
545+
546+
@pytest.mark.asyncio
547+
async def test_file_storage_path_in_preview(self, tmp_path, mock_agent):
548+
storage = FileStorage(artifact_dir=str(tmp_path / "artifacts"))
549+
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
550+
event = _make_event(mock_agent, "a" * 200)
551+
552+
await plugin._handle_tool_result(event)
553+
554+
result_text = event.result["content"][0]["text"]
555+
assert str(tmp_path / "artifacts") in result_text
556+
557+
@pytest.mark.asyncio
558+
async def test_file_storage_image_placeholder_has_path(self, tmp_path, mock_agent):
559+
storage = FileStorage(artifact_dir=str(tmp_path / "artifacts"))
560+
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
561+
img_bytes = b"\x89PNG" + b"\x00" * 100
562+
content = [
563+
{"text": "x" * 200},
564+
{"image": {"format": "png", "source": {"bytes": img_bytes}}},
565+
]
566+
event = _make_event(mock_agent, content)
567+
568+
await plugin._handle_tool_result(event)
569+
570+
placeholder = event.result["content"][1]["text"]
571+
assert str(tmp_path / "artifacts") in placeholder
572+
573+
@pytest.mark.asyncio
574+
async def test_inmemory_storage_opaque_reference_in_preview(self, mock_agent):
575+
storage = InMemoryStorage()
576+
plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10)
577+
event = _make_event(mock_agent, "a" * 200)
578+
579+
await plugin._handle_tool_result(event)
580+
581+
result_text = event.result["content"][0]["text"]
582+
assert "mem_" in result_text

tests/strands/vended_plugins/context_offloader/test_storage.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for offload storage backends."""
22

33
import threading
4+
from pathlib import Path
45
from unittest.mock import MagicMock, patch
56

67
import pytest
@@ -147,7 +148,30 @@ def test_sanitizes_path_traversal(self, tmp_path):
147148
storage = FileStorage(artifact_dir=str(tmp_path))
148149
ref = storage.store("../../etc/passwd", b"content")
149150
assert ".." not in ref
150-
assert "/" not in ref
151+
assert "/" not in Path(ref).name
152+
153+
def test_reference_includes_artifact_dir(self, tmp_path):
154+
artifact_dir = str(tmp_path / "artifacts")
155+
storage = FileStorage(artifact_dir=artifact_dir)
156+
ref = storage.store("key_1", b"content")
157+
assert ref.startswith(artifact_dir + "/")
158+
159+
def test_relative_artifact_dir_gives_relative_reference(self, tmp_path, monkeypatch):
160+
monkeypatch.chdir(tmp_path)
161+
storage = FileStorage(artifact_dir="./artifacts")
162+
ref = storage.store("key_1", b"content")
163+
assert ref.startswith("artifacts/")
164+
content, content_type = storage.retrieve(ref)
165+
assert content == b"content"
166+
assert content_type == "text/plain"
167+
168+
def test_retrieve_accepts_bare_filename(self, tmp_path):
169+
storage = FileStorage(artifact_dir=str(tmp_path))
170+
ref = storage.store("key_1", b"hello world")
171+
filename = Path(ref).name
172+
content, content_type = storage.retrieve(filename)
173+
assert content == b"hello world"
174+
assert content_type == "text/plain"
151175

152176
def test_metadata_survives_across_instances(self, tmp_path):
153177
artifact_dir = str(tmp_path / "artifacts")
@@ -233,9 +257,9 @@ def test_unique_references(self, storage):
233257
assert storage.retrieve(ref1)[0] == b"content a"
234258
assert storage.retrieve(ref2)[0] == b"content b"
235259

236-
def test_reference_includes_prefix(self, storage):
260+
def test_reference_is_s3_uri(self, storage):
237261
ref = storage.store("tool_abc", b"content")
238-
assert ref.startswith("artifacts/")
262+
assert ref.startswith("s3://test-bucket/artifacts/")
239263

240264
def test_empty_prefix(self, mock_s3_client):
241265
with patch("boto3.Session") as mock_session_cls:
@@ -245,9 +269,20 @@ def test_empty_prefix(self, mock_s3_client):
245269
storage = S3Storage(bucket="test-bucket", prefix="")
246270

247271
ref = storage.store("tool_abc", b"content")
248-
assert not ref.startswith("/")
272+
assert ref.startswith("s3://test-bucket/")
249273
assert storage.retrieve(ref)[0] == b"content"
250274

275+
def test_retrieve_accepts_raw_key(self, storage, mock_s3_client):
276+
ref = storage.store("key_1", b"hello world")
277+
raw_key = ref.removeprefix("s3://test-bucket/")
278+
content, content_type = storage.retrieve(raw_key)
279+
assert content == b"hello world"
280+
assert content_type == "text/plain"
281+
282+
def test_retrieve_rejects_wrong_bucket_uri(self, storage):
283+
with pytest.raises(KeyError, match="Reference not found"):
284+
storage.retrieve("s3://wrong-bucket/artifacts/some_key")
285+
251286
def test_put_object_called_with_correct_params(self, storage, mock_s3_client):
252287
storage.store("key_1", b"test content", "application/json")
253288

0 commit comments

Comments
 (0)