-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathagent.py
More file actions
197 lines (166 loc) · 5.87 KB
/
Copy pathagent.py
File metadata and controls
197 lines (166 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Awaitable, Callable, TypedDict
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
# Load the .env file that sits next to this module before the os.getenv calls
# below, so settings defined there can be picked up by them.
load_dotenv(Path(__file__).with_name(".env"))
# Async callback receiving one text payload (a status update or a streamed token).
EmitCallback = Callable[[str], Awaitable[None]]
# Chat model name; overridable via the OPENAI_MODEL environment variable.
MODEL_NAME = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
# Sampling temperature from OPENAI_TEMPERATURE (default "0"); parsed with float(),
# so a non-numeric value raises ValueError at import time.
MODEL_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", "0"))
def _approval_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {"confirm": {"type": "boolean"}},
"required": ["confirm"],
}
def _requires_human_approval(message: str) -> bool:
lowered = message.lower()
destructive = any(
term in lowered
for term in (
"delete",
"drop",
"truncate",
"wipe",
"remove",
"purge",
"clear",
"invalidate",
)
)
sensitive_target = any(
term in lowered
for term in (
"database",
"db",
"table",
"tables",
"row",
"rows",
"record",
"records",
"cache",
"redis",
"key",
"keys",
"production",
"prod",
)
)
return destructive and sensitive_target
class AgentState(TypedDict):
    """State dict passed into and returned from the LangGraph graph in this module."""

    # The user's request text; sent to the model as the user message.
    message: str
    # Extra context (run_langgraph_agent fills it from rag_context); may be "".
    context: str
    # Model answer; "" in the initial state, filled in by the "generate" node.
    response: str
async def _emit_optional(callback: EmitCallback | None, payload: str) -> None:
if callback is not None and payload:
await callback(payload)
def _chunk_to_text(chunk: Any) -> str:
content = getattr(chunk, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(item.get("text", ""))
return "".join(parts)
return str(content) if content else ""
def _build_messages(state: AgentState) -> list[dict[str, str]]:
context = state["context"].strip()
system_lines = [
"You are a concise, helpful assistant in a Redis Agent Kit demo.",
"Answer in 2-4 short paragraphs or bullets.",
"If extra context is supplied, use it. If it is empty, answer from general knowledge.",
"When relevant, explain concepts clearly for engineers evaluating agent infrastructure.",
]
if context:
system_lines.extend(["", "Additional context:", context])
return [
{"role": "system", "content": "\n".join(system_lines)},
{"role": "user", "content": state["message"]},
]
def _create_llm() -> ChatOpenAI:
    """Construct the streaming chat model from the module-level settings.

    Uses ``MODEL_NAME`` and ``MODEL_TEMPERATURE`` with streaming enabled.
    """
    settings = {
        "model": MODEL_NAME,
        "temperature": MODEL_TEMPERATURE,
        "streaming": True,
    }
    return ChatOpenAI(**settings)
def _build_graph(
    *,
    emit_update: EmitCallback | None = None,
    emit_token: EmitCallback | None = None,
):
    """Build and compile a one-node LangGraph (``generate`` -> END).

    A fresh chat model is created via ``_create_llm`` on every call. The emit
    callbacks are captured by the node closure, so each compiled graph is tied
    to the callbacks it was built with.

    Args:
        emit_update: Optional async callback for status messages; receives
            "Running LangGraph model node..." when the node starts.
        emit_token: Optional async callback invoked once per non-empty
            streamed text chunk.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    llm = _create_llm()

    async def generate(state: AgentState) -> AgentState:
        """Stream the model's answer for ``state`` and return the updated state."""
        await _emit_optional(emit_update, "Running LangGraph model node...")
        messages = _build_messages(state)
        chunks: list[str] = []
        async for chunk in llm.astream(messages):
            text = _chunk_to_text(chunk)
            if text:
                chunks.append(text)
                await _emit_optional(emit_token, text)
        response = "".join(chunks).strip()
        if not response:
            # Fallback: streaming produced no usable text, so make one
            # non-streaming call. Extract its text with _chunk_to_text, the
            # same as the streaming path. str(message.content) would turn
            # list-of-blocks content into a Python list repr, and None
            # content into the literal string "None".
            final = await llm.ainvoke(messages)
            response = _chunk_to_text(final).strip()
        return {
            "message": state["message"],
            "context": state["context"],
            "response": response,
        }

    graph = StateGraph(AgentState)
    graph.add_node("generate", generate)
    graph.set_entry_point("generate")
    graph.add_edge("generate", END)
    return graph.compile()
async def run_langgraph_agent(
    message: str,
    *,
    rag_context: str = "",
    emit_update: EmitCallback | None = None,
    emit_token: EmitCallback | None = None,
) -> dict[str, Any]:
    """Run the single-node LangGraph agent on ``message`` and return its answer.

    Args:
        message: The user's request.
        rag_context: Optional extra context appended to the system prompt.
        emit_update: Optional async callback for status updates.
        emit_token: Optional async callback for streamed text chunks.

    Returns:
        ``{"response": <final answer text>}``.
    """
    initial_state: AgentState = {
        "message": message,
        "context": rag_context,
        "response": "",
    }
    compiled = _build_graph(emit_update=emit_update, emit_token=emit_token)
    final_state = await compiled.ainvoke(initial_state)
    return {"response": final_state["response"]}
async def run_task(ctx) -> dict[str, Any]:
    """Worker entry point: gate destructive requests on human approval, then run the agent.

    Reads ``ctx.message``, ``ctx.task_id`` and the optional ``ctx.rag_context``.
    Reports progress through ``ctx.emitter`` (``emit`` / ``emit_token``) and
    uses ``ctx.kit.task_manager`` to fetch the task and request human input.

    For a message flagged by ``_requires_human_approval``:
      * If no ``confirm`` answer is recorded yet, approval is requested and
        ``{"response": "Awaiting human approval."}`` is returned.
      * If ``confirm`` is anything other than the boolean ``True``, the task
        is cancelled.
      * If ``confirm is True``, execution continues to the LangGraph agent.

    Unflagged messages go straight to the agent.

    Returns:
        A dict with a single ``"response"`` key.
    """
    await ctx.emitter.emit("Worker picked up the task.")
    task = await ctx.kit.task_manager.get_task(ctx.task_id)
    input_response = getattr(task, "input_response", None) if task else None

    if _requires_human_approval(ctx.message):
        # NOTE(review): assumes input_response is dict-like (has .get), as the
        # original code did; confirm against the task manager's task model.
        confirm = input_response.get("confirm") if input_response else None

        if confirm is None:
            # No answer yet: ask a human and stop; the task is expected to be
            # re-run once the input arrives.
            await ctx.emitter.emit(
                "This request needs human approval before it can continue."
            )
            await ctx.kit.task_manager.request_input(
                task_id=ctx.task_id,
                prompt="This request asks to perform a destructive action on production data or cache. Approve before the agent continues?",
                json_schema=_approval_schema(),
                metadata={"kind": "approval", "reason": "destructive_operation"},
            )
            return {"response": "Awaiting human approval."}

        # Only an explicit boolean True counts as approval. A truthiness check
        # (bool(confirm)) would accept any truthy value, e.g. the string
        # "false", as consent for a destructive operation.
        if confirm is not True:
            await ctx.emitter.emit("Human approval declined. Task cancelled.")
            return {
                "response": "Cancelled. Database-destructive actions require explicit human approval."
            }

        await ctx.emitter.emit("Human approval received. Continuing task.")

    return await run_langgraph_agent(
        ctx.message,
        rag_context=getattr(ctx, "rag_context", ""),
        emit_update=ctx.emitter.emit,
        emit_token=ctx.emitter.emit_token,
    )