-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathmemory.py
More file actions
213 lines (177 loc) · 7.55 KB
/
memory.py
File metadata and controls
213 lines (177 loc) · 7.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""Memory write functions for AgentCore Memory.
Best-effort (fail-open): all write operations are wrapped in try/except
so a Memory API outage never blocks the agent pipeline. Infrastructure
errors (network, auth, throttling) are caught and logged at WARN level;
programming errors (TypeError, ValueError, AttributeError, KeyError) are
logged at ERROR level to surface bugs quickly.
"""
import hashlib
import os
import re
import time
from sanitization import sanitize_external_content
# Process-wide cache for the boto3 "bedrock-agentcore" client.
# Stays None until _get_client() builds the client on first use.
_client = None

# Validates "owner/repo" format — must match the TypeScript-side isValidRepo pattern.
# Each side of the slash allows ASCII letters, digits, '.', '_' and '-'.
_REPO_PATTERN = re.compile(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$")

# Current event schema version (written into every event's metadata):
#   v1 = repos/ prefix
#   v2 = namespace templates (/{actorId}/...)
#   v3 = adds source_type provenance + content_sha256 integrity hash
_SCHEMA_VERSION = "3"

# Valid source_type values for provenance tracking (schema v3).
# Must stay in sync with MemorySourceType in cdk/src/handlers/shared/memory.ts.
# The write functions in this module emit "agent_episode" and "agent_learning".
MEMORY_SOURCE_TYPES = frozenset({"agent_episode", "agent_learning", "orchestrator_fallback"})
def _get_client():
    """Return the shared AgentCore client, building it on the first call.

    The client is created once and stored in the module-level ``_client``;
    later calls return that same object. ``boto3`` is imported inside the
    function, so it is only loaded when a client actually has to be built.

    The region is taken from ``AWS_REGION``, falling back to
    ``AWS_DEFAULT_REGION`` when the former is unset or empty.

    Raises:
        ValueError: if neither environment variable provides a region.
    """
    global _client
    if _client is None:
        import boto3

        aws_region = os.environ.get("AWS_REGION")
        if not aws_region:
            aws_region = os.environ.get("AWS_DEFAULT_REGION")
        if not aws_region:
            raise ValueError("AWS_REGION or AWS_DEFAULT_REGION must be set for memory operations")
        _client = boto3.client("bedrock-agentcore", region_name=aws_region)
    return _client
def _validate_repo(repo: str) -> None:
    """Raise ValueError if repo does not match expected owner/repo format.

    The whole string must match ``_REPO_PATTERN``. ``fullmatch`` is used
    rather than ``match`` because in Python a trailing ``$`` also matches
    just before a final newline, so ``match`` would accept values such as
    ``"owner/repo\\n"``. Requiring the match to consume the entire string
    rejects those.

    Args:
        repo: Repository identifier, expected in ``owner/repo`` form.

    Raises:
        ValueError: if ``repo`` is not exactly ``owner/repo`` built from the
            characters allowed by ``_REPO_PATTERN``.
    """
    if not _REPO_PATTERN.fullmatch(repo):
        raise ValueError(
            f"repo '{repo}' does not match expected owner/repo format "
            f"(pattern: {_REPO_PATTERN.pattern})"
        )
def _log_error(func_name: str, err: Exception, memory_id: str, task_id: str) -> None:
    """Print a one-line report for a failed memory write.

    Severity is chosen from the exception type: TypeError, ValueError,
    AttributeError and KeyError (including their subclasses) are reported at
    ERROR level as an "unexpected error"; every other exception is reported
    at WARN level as an "infra failure".

    Args:
        func_name: Name of the write function that failed.
        err: The exception that was caught.
        memory_id: Memory identifier included in the log line for context.
        task_id: Task identifier included in the log line for context.
    """
    if isinstance(err, (TypeError, ValueError, AttributeError, KeyError)):
        level, label = "ERROR", "unexpected error"
    else:
        level, label = "WARN", "infra failure"

    message = (
        f"[memory] [{level}] {func_name} {label}: {type(err).__name__}: {err}"
        f" (memory_id={memory_id}, task_id={task_id})"
    )
    # Flush immediately so the line is not held back in a stdout buffer.
    print(message, flush=True)
def write_task_episode(
    memory_id: str,
    repo: str,
    task_id: str,
    status: str,
    pr_url: str | None = None,
    cost_usd: float | None = None,
    duration_s: float | None = None,
    self_feedback: str | None = None,
) -> bool:
    """Record the outcome of one task run as a short-term AgentCore Memory event.

    The event text is a single space-joined summary: the task id and status,
    followed by the PR URL, duration, cost and agent notes (self-feedback
    from the PR body's "## Agent notes" section) for whichever of those
    were supplied.

    The event is created with actorId=repo and sessionId=task_id so that the
    extraction strategy namespace templates
    (/{actorId}/episodes/{sessionId}/) file it under the per-repo, per-task
    namespace.

    Metadata carries source_type='agent_episode' for provenance and
    content_sha256 for integrity auditing on read (schema v3). The hash is
    taken over the sanitized text while the unsanitized text is what gets
    stored.

    Args:
        memory_id: Identifier of the target memory.
        repo: Repository in "owner/repo" form; validated before any write.
        task_id: Identifier of the task this episode describes.
        status: Final task status, embedded in the summary text.
        pr_url: Pull request URL; added to text and metadata when truthy.
        cost_usd: Task cost; formatted to four decimals when not None.
        duration_s: Task duration in seconds; included when not None.
        self_feedback: Agent notes; included when truthy.

    Returns:
        True on success, False on any failure (fail-open: the exception is
        logged through _log_error and never propagated).
    """
    try:
        _validate_repo(repo)
        agentcore = _get_client()

        # Assemble the human-readable summary, one sentence per known fact.
        sentences = [f"Task {task_id} completed with status: {status}."]
        if pr_url:
            sentences.append(f"PR: {pr_url}.")
        if duration_s is not None:
            sentences.append(f"Duration: {duration_s}s.")
        if cost_usd is not None:
            sentences.append(f"Cost: ${cost_usd:.4f}.")
        if self_feedback:
            sentences.append(f"Agent notes: {self_feedback}")
        episode_text = " ".join(sentences)

        # Hash the sanitized form; store the original. The read path re-sanitizes
        # and checks against this hash: sanitize(original) at write == sanitize(stored) at read.
        digest = hashlib.sha256(
            sanitize_external_content(episode_text).encode("utf-8")
        ).hexdigest()

        # Collect plain values first, then wrap each one in the
        # {"stringValue": ...} envelope used for event metadata.
        plain_fields = {
            "task_id": task_id,
            "type": "task_episode",
            "source_type": "agent_episode",
            "content_sha256": digest,
            "schema_version": _SCHEMA_VERSION,
        }
        if pr_url:
            plain_fields["pr_url"] = pr_url
        metadata = {key: {"stringValue": value} for key, value in plain_fields.items()}

        message = {
            "conversational": {
                "content": {"text": episode_text},
                "role": "OTHER",
            }
        }
        agentcore.create_event(
            memoryId=memory_id,
            actorId=repo,
            sessionId=task_id,
            eventTimestamp=_iso_now(),
            payload=[message],
            metadata=metadata,
        )
        print("[memory] Task episode written", flush=True)
        return True
    except Exception as exc:
        # Best-effort write: report the failure and let the pipeline continue.
        _log_error("write_task_episode", exc, memory_id, task_id)
        return False
def write_repo_learnings(
    memory_id: str,
    repo: str,
    task_id: str,
    learnings: str,
) -> bool:
    """Store repository learnings as an AgentCore Memory event.

    Learnings are the patterns, conventions and insights discovered about
    the repository while a task ran. They are written as their own event so
    the semantic extraction strategy can surface them in future tasks.

    The event is created with actorId=repo and sessionId=task_id so that the
    extraction strategy namespace templates (/{actorId}/knowledge/) file it
    under the per-repo namespace.

    Metadata carries source_type='agent_learning' for provenance and
    content_sha256 for integrity auditing on read (schema v3). The hash is
    taken over the sanitized text while the unsanitized text is what gets
    stored.

    Note: hash auditing only happens on the TS orchestrator read path
    (loadMemoryContext in memory.ts), where mismatches are logged but
    records are kept — the Python side does not independently check hashes.

    Args:
        memory_id: Identifier of the target memory.
        repo: Repository in "owner/repo" form; validated before any write.
        task_id: Identifier of the task that produced the learnings.
        learnings: Free-form learnings text.

    Returns:
        True on success, False on any failure (fail-open: the exception is
        logged through _log_error and never propagated).
    """
    try:
        _validate_repo(repo)
        agentcore = _get_client()

        learnings_text = f"Repository learnings: {learnings}"

        # Hash the sanitized form; store the original. The read path re-sanitizes
        # and checks against this hash: sanitize(original) at write == sanitize(stored) at read.
        digest = hashlib.sha256(
            sanitize_external_content(learnings_text).encode("utf-8")
        ).hexdigest()

        # Collect plain values first, then wrap each one in the
        # {"stringValue": ...} envelope used for event metadata.
        plain_fields = {
            "task_id": task_id,
            "type": "repo_learnings",
            "source_type": "agent_learning",
            "content_sha256": digest,
            "schema_version": _SCHEMA_VERSION,
        }
        metadata = {key: {"stringValue": value} for key, value in plain_fields.items()}

        message = {
            "conversational": {
                "content": {"text": learnings_text},
                "role": "OTHER",
            }
        }
        agentcore.create_event(
            memoryId=memory_id,
            actorId=repo,
            sessionId=task_id,
            eventTimestamp=_iso_now(),
            payload=[message],
            metadata=metadata,
        )
        print("[memory] Repo learnings written", flush=True)
        return True
    except Exception as exc:
        # Best-effort write: report the failure and let the pipeline continue.
        _log_error("write_repo_learnings", exc, memory_id, task_id)
        return False
def _iso_now() -> str:
    """Return the current UTC time as an ISO 8601 string.

    The result has second resolution and a literal ``Z`` suffix, for
    example ``2024-01-31T12:34:56Z``.
    """
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", utc_now)