Skip to content

Commit 9050d00

Browse files
committed
Fix RetrievalWorkflow: Use Context to send events in workflow steps
1 parent 9aaa5dc commit 9050d00

1 file changed

Lines changed: 13 additions & 11 deletions

File tree

src/pipeline/retrieval.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
StartEvent,
88
StopEvent,
99
step,
10+
Context,
1011
)
1112
from llama_index.core.postprocessor import LongContextReorder
1213
from llama_index.llms.openai import OpenAI
@@ -67,18 +68,18 @@ async def _call_llm_with_retry(self, prompt: str):
6768
return response
6869

6970
@step
70-
async def process_start(self, ev: StartEvent) -> Union[QueryTransformedEvent, StopEvent, StreamingStatusEvent]:
71+
async def process_start(self, ctx: Context, ev: StartEvent) -> Union[QueryTransformedEvent, StopEvent, StreamingStatusEvent]:
7172
query_str = ev.get("query")
7273
if not query_str:
7374
raise RetrievalException("query must be provided in StartEvent", status_code=400)
7475

7576
# 1. Cache Check
7677
cached_answer = self.cache.get_cache(query_str)
7778
if cached_answer:
78-
self.send_event(StreamingStatusEvent(status="Cache Hit! Returning cached response."))
79+
ctx.send_event(StreamingStatusEvent(status="Cache Hit! Returning cached response."))
7980
return StopEvent(result={"answer": cached_answer, "source_nodes": [], "from_cache": True})
8081

81-
self.send_event(StreamingStatusEvent(status="Transforming query..."))
82+
ctx.send_event(StreamingStatusEvent(status="Transforming query..."))
8283

8384
# Decompose & HyDE (Simplified)
8485
hyde_prompt = f"Write a hypothetical document that would answer the following question: {query_str}"
@@ -93,8 +94,8 @@ async def process_start(self, ev: StartEvent) -> Union[QueryTransformedEvent, St
9394
return QueryTransformedEvent(query_bundle=query_bundle, loops=0)
9495

9596
@step
96-
async def retrieve_context(self, ev: QueryTransformedEvent) -> Union[ContextRetrievedEvent, StreamingStatusEvent]:
97-
self.send_event(StreamingStatusEvent(status="Retrieving context from Chroma Cloud..."))
97+
async def retrieve_context(self, ctx: Context, ev: QueryTransformedEvent) -> Union[ContextRetrievedEvent, StreamingStatusEvent]:
98+
ctx.send_event(StreamingStatusEvent(status="Retrieving context from Chroma Cloud..."))
9899

99100
# Using Chroma Cloud Hybrid Search
100101
results = await self.chroma_service.hybrid_search(ev.query_bundle.query_str, n_results=20)
@@ -115,11 +116,11 @@ async def retrieve_context(self, ev: QueryTransformedEvent) -> Union[ContextRetr
115116
return ContextRetrievedEvent(nodes=nodes, query_bundle=ev.query_bundle, loops=ev.loops)
116117

117118
@step
118-
async def judge_relevance(self, ev: ContextRetrievedEvent) -> Union[RelevanceJudgedEvent, QueryTransformedEvent, StreamingStatusEvent]:
119+
async def judge_relevance(self, ctx: Context, ev: ContextRetrievedEvent) -> Union[RelevanceJudgedEvent, QueryTransformedEvent, StreamingStatusEvent]:
119120
if ev.loops >= 1 or not ev.nodes:
120121
return RelevanceJudgedEvent(is_relevant=True, nodes=ev.nodes, query_bundle=ev.query_bundle)
121122

122-
self.send_event(StreamingStatusEvent(status="Judging context relevance..."))
123+
ctx.send_event(StreamingStatusEvent(status="Judging context relevance..."))
123124
context_text = "\n".join([n.get_content() for n in ev.nodes[:3]])
124125

125126
judge_prompt = (
@@ -136,7 +137,7 @@ async def judge_relevance(self, ev: ContextRetrievedEvent) -> Union[RelevanceJud
136137
is_relevant = True
137138

138139
if not is_relevant:
139-
self.send_event(StreamingStatusEvent(status="Refining query..."))
140+
ctx.send_event(StreamingStatusEvent(status="Refining query..."))
140141
refine_prompt = f"Rewrite the query '{ev.query_bundle.query_str}' to be more specific for better search results."
141142
try:
142143
new_query_resp = await self._call_llm_with_retry(refine_prompt)
@@ -148,13 +149,13 @@ async def judge_relevance(self, ev: ContextRetrievedEvent) -> Union[RelevanceJud
148149
return RelevanceJudgedEvent(is_relevant=True, nodes=ev.nodes, query_bundle=ev.query_bundle)
149150

150151
@step
151-
async def post_process(self, ev: RelevanceJudgedEvent) -> Union[StopEvent, StreamingStatusEvent]:
152+
async def post_process(self, ctx: Context, ev: RelevanceJudgedEvent) -> Union[StopEvent, StreamingStatusEvent]:
152153
if not ev.nodes:
153154
return StopEvent(result={"answer": "No relevant context found.", "source_nodes": [], "from_cache": False})
154155

155156
try:
156157
if self.reranker:
157-
self.send_event(StreamingStatusEvent(status="Reranking results..."))
158+
ctx.send_event(StreamingStatusEvent(status="Reranking results..."))
158159
reranked_nodes = self.reranker.postprocess_nodes(ev.nodes, query_bundle=ev.query_bundle)
159160
final_nodes = self.reorder.postprocess_nodes(reranked_nodes)
160161
else:
@@ -163,7 +164,7 @@ async def post_process(self, ev: RelevanceJudgedEvent) -> Union[StopEvent, Strea
163164
logger.error(f"[RETRIEVAL] Post-processing error: {e}")
164165
final_nodes = ev.nodes
165166

166-
self.send_event(StreamingStatusEvent(status="Generating answer..."))
167+
ctx.send_event(StreamingStatusEvent(status="Generating answer..."))
167168
context_str = "\n".join([n.get_content() for n in final_nodes])
168169
final_prompt = f"Context:\n{context_str}\n\nQuestion: {ev.query_bundle.query_str}\n\nAnswer:"
169170

@@ -175,3 +176,4 @@ async def post_process(self, ev: RelevanceJudgedEvent) -> Union[StopEvent, Strea
175176
except Exception as e:
176177
logger.error(f"[RETRIEVAL] Answer generation failed: {e}")
177178
raise RetrievalException(f"Failed to generate answer: {e}")
179+

0 commit comments

Comments
 (0)