-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathgraph.py
More file actions
192 lines (154 loc) · 5.79 KB
/
graph.py
File metadata and controls
192 lines (154 loc) · 5.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""LangGraph state machine for agentic RAG with authorization."""
from langgraph.graph import StateGraph, END
from .state import AgenticRAGState
from .nodes import (
retrieval_node,
authorization_node,
reasoning_node,
generation_node,
)
from .validation import validate_query, validate_subject_id, ValidationError
def should_retry_or_generate(state: AgenticRAGState) -> str:
    """Route out of the ``reason`` node: loop back to retrieval or finish.

    The run goes back to ``"retrieve"`` only when both of these hold:
    the retrieval budget is not spent (``retrieval_attempt < max_attempts``),
    and no document has been authorized so far. In every other case it
    proceeds to ``"generate"``, which may produce an access-denied style
    answer.

    Args:
        state: Current graph state.

    Returns:
        ``"retrieve"`` or ``"generate"`` (keys of the conditional-edge map).
    """
    has_budget = state["retrieval_attempt"] < state["max_attempts"]
    if not has_budget:
        # Out of attempts: stop looping regardless of what was found.
        return "generate"

    nothing_authorized = len(state["authorized_documents"]) == 0
    return "retrieve" if nothing_authorized else "generate"
def should_reason_or_generate(state: AgenticRAGState) -> str:
    """Route out of the ``authorize`` node: reason about a retry or answer now.

    Routing rules:
    * Authorization passed -> ``"generate"``.
    * Authorization failed, retries are enabled (``max_attempts > 1``) and
      ``retrieval_attempt < max_attempts`` -> ``"reason"``.
    * Otherwise -> ``"generate"``; the answer may explain the access denial.

    With ``max_attempts == 1`` (the entry-point default), the reasoning step
    is never entered. The run is simply retrieve -> authorize -> generate.

    Args:
        state: Current graph state.

    Returns:
        ``"reason"`` or ``"generate"`` (keys of the conditional-edge map).
    """
    if state["authorization_passed"]:
        return "generate"

    retries_enabled = state["max_attempts"] > 1
    # Short-circuits, so retrieval_attempt is only read when retries are enabled.
    budget_left = retries_enabled and state["retrieval_attempt"] < state["max_attempts"]
    return "reason" if budget_left else "generate"
def build_agentic_rag_graph():
    """Assemble and compile the agentic RAG state graph.

    Topology:
        retrieve -> authorize
        authorize --(should_reason_or_generate)--> reason | generate
        reason    --(should_retry_or_generate)--> retrieve | generate
        generate  -> END

    ``authorize`` is the deterministic permission check and sits on every
    path out of ``retrieve``. The ``reason`` -> ``retrieve`` edge is the only
    loop. It is reachable only when ``max_attempts > 1``, so with
    ``max_attempts == 1`` a run is exactly retrieve -> authorize -> generate.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(AgenticRAGState)

    node_table = (
        ("retrieve", retrieval_node),
        ("authorize", authorization_node),  # security boundary: on every path
        ("reason", reasoning_node),
        ("generate", generation_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Every run starts with retrieval and is always followed by authorization.
    graph.set_entry_point("retrieve")
    graph.add_edge("retrieve", "authorize")

    # Router return values double as target node names.
    graph.add_conditional_edges(
        "authorize",
        should_reason_or_generate,
        {branch: branch for branch in ("reason", "generate")},
    )
    graph.add_conditional_edges(
        "reason",
        should_retry_or_generate,
        {branch: branch for branch in ("retrieve", "generate")},
    )

    graph.add_edge("generate", END)
    return graph.compile()
# LangGraph's documented default recursion limit (maximum supersteps per run).
_DEFAULT_RECURSION_LIMIT = 25


def _build_initial_state(query: str, subject_id: str, max_attempts: int) -> dict:
    """Return a fresh initial graph state for a single run.

    Shared by the sync and async entry points so the two cannot drift apart.
    New list objects are created on every call, so runs never share mutable state.
    """
    return {
        "query": query,
        "subject_id": subject_id,
        "max_attempts": max_attempts,
        "retrieved_documents": [],
        "authorized_documents": [],
        "denied_count": 0,
        "reasoning": [],
        "retrieval_attempt": 0,
        "authorization_passed": False,
        "messages": [],
        "answer": None,
    }


def _recursion_limit(max_attempts) -> int:
    """Return a LangGraph ``recursion_limit`` large enough for ``max_attempts``.

    In the graph built by ``build_agentic_rag_graph``, each retry cycle is at
    most three supersteps (retrieve -> authorize -> reason), and a run ends
    with one ``generate`` step. So ``3 * max_attempts + 1`` bounds a run that
    uses all its attempts. This assumes the retrieval step advances
    ``retrieval_attempt`` once per pass (TODO confirm against nodes.py).

    The default limit is never lowered. Values that cannot be converted to
    int fall back to the default, so the graph fails exactly as it would have
    without this override.
    """
    try:
        attempts = int(max_attempts)
    except (TypeError, ValueError, OverflowError):
        return _DEFAULT_RECURSION_LIMIT
    return max(_DEFAULT_RECURSION_LIMIT, 3 * attempts + 1)


def run_agentic_rag(query: str, subject_id: str, max_attempts: int = 1) -> dict:
    """
    Run the agentic RAG graph with input validation (synchronous).

    This is the main entry point for running the agentic RAG system.
    ``query`` and ``subject_id`` are validated before the graph runs.
    ``max_attempts`` is passed through unvalidated.

    Args:
        query: User query string
        subject_id: User/subject identifier for authorization
        max_attempts: Maximum number of retrieval attempts (default 1)

    Returns:
        Final state dict with answer and metadata

    Raises:
        ValidationError: If inputs are invalid
    """
    query = validate_query(query)
    subject_id = validate_subject_id(subject_id)

    graph = build_agentic_rag_graph()
    # Raise the recursion limit when needed so large max_attempts values are
    # not cut short by LangGraph's default of 25 supersteps.
    return graph.invoke(
        _build_initial_state(query, subject_id, max_attempts),
        config={"recursion_limit": _recursion_limit(max_attempts)},
    )


async def run_agentic_rag_async(query: str, subject_id: str, max_attempts: int = 1) -> dict:
    """
    Run the agentic RAG graph with input validation (asynchronous).

    Async counterpart of ``run_agentic_rag``. It uses the same validation,
    initial state and recursion limit, but awaits ``graph.ainvoke``.

    Args:
        query: User query string
        subject_id: User/subject identifier for authorization
        max_attempts: Maximum number of retrieval attempts (default 1)

    Returns:
        Final state dict with answer and metadata

    Raises:
        ValidationError: If inputs are invalid
    """
    query = validate_query(query)
    subject_id = validate_subject_id(subject_id)

    graph = build_agentic_rag_graph()
    return await graph.ainvoke(
        _build_initial_state(query, subject_id, max_attempts),
        config={"recursion_limit": _recursion_limit(max_attempts)},
    )