
Commit e659f9a

Merge pull request #38 from airsimonhan/streaming
update streaming output
2 parents: 93e6861 + 7d741d1

9 files changed: 1331 additions & 3 deletions

README.md

Lines changed: 19 additions & 0 deletions
@@ -181,6 +181,24 @@ docker-compose up
 
 For comprehensive Docker deployment instructions, configuration options, troubleshooting, and production deployment guidelines, please refer to our detailed [Docker Deployment Guide](./web-ui/DOCKER.md).
 
+## :bulb: Simple Test Demo
+
+1. Run the following steps:
+```bash
+conda activate rag
+cd Hyper-RAG
+python reproduce/Step_0.py
+python reproduce/Step_1.py
+
+python -m uvicorn service_api:app --app-dir . --host 0.0.0.0 --port 8000
+```
+2. Open `testHTML_light.html` in your web browser.
+3. Select the model (`hyper`, `hyper-lite`, or `naive`) and choose whether to stream the output.
+
+<div align="center">
+  <img src="./assets/hyperrag-streaming.gif" alt="Hyper-RAG streaming demo" width="80%" />
+</div>
+
 ## :checkered_flag: Evaluation
 In this work, we propose two evaluation strategies: the **selection-based** and **scoring-based** evaluation.
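For readers who prefer a script over the bundled `testHTML_light.html` page, below is a sketch of consuming the uvicorn service from Python. The route name and JSON payload are assumptions for illustration, not taken from this diff; check `service_api.py` for the actual schema.

```python
# Sketch of calling the streaming service started above. The /query route
# and payload fields are hypothetical -- consult service_api.py for the real API.
import requests

resp = requests.post(
    "http://localhost:8000/query",        # hypothetical route
    json={
        "question": "What is Hyper-RAG?", # hypothetical field names
        "mode": "hyper",                  # hyper | hyper-lite | naive
        "stream": True,
    },
    stream=True,
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```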

assets/hyperrag-streaming.gif

Binary file added (433 KB)

hyperrag/hyperrag.py

Lines changed: 72 additions & 2 deletions
@@ -33,6 +33,7 @@
     convert_response_to_json,
     logger,
     set_logger,
+    limit_async_gen_call
 )
 from .base import (
     BaseKVStorage,
@@ -42,6 +43,8 @@
     BaseHypergraphStorage,
 )
 
+from .operate import hyper_query_stream, hyper_query_lite_stream, naive_query_stream, llm_query_stream
+
 
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
@@ -60,7 +63,7 @@ class HyperRAG:
     working_dir: str = field(
         default_factory=lambda: f"./HyperRAG_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
-    print(working_dir)
+    # print(working_dir)
 
     current_log_level = logger.level
     log_level: str = field(default=current_log_level)
@@ -78,7 +81,7 @@ class HyperRAG:
     relation_keywords_to_max_tokens: int = 100
 
     embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
-    embedding_batch_num: int = 32
+    embedding_batch_num: int = 8
     embedding_func_max_async: int = 16
 
     # LLM
@@ -89,6 +92,8 @@ class HyperRAG:
     llm_model_max_async: int = 16
     llm_model_kwargs: dict = field(default_factory=dict)
 
+    llm_model_stream_func: callable = None
+
     # storage
     key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
     vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
@@ -166,6 +171,16 @@ def __post_init__(self):
             )
         )
 
+        if getattr(self, "llm_model_stream_func", None) is not None:
+            # Inject hashing_kv into the stream func first (used by openai_complete_stream_if_cache)
+            self.llm_model_stream_func = limit_async_gen_call(self.llm_model_max_async)(
+                partial(
+                    self.llm_model_stream_func,
+                    hashing_kv=self.llm_response_cache,
+                    **self.llm_model_kwargs,
+                )
+            )
+
     def insert(self, string_or_strings):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.ainsert(string_or_strings))
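`limit_async_gen_call` is imported from the package's utilities above, but its implementation is not part of this diff. A minimal sketch of what such a limiter could look like, assuming semaphore-based throttling (an illustration, not the repo's actual code):

```python
# Hypothetical sketch of limit_async_gen_call -- the real implementation lives
# in hyperrag/utils and may differ. Idea: cap how many wrapped async
# generators run concurrently by holding a shared semaphore per stream.
import asyncio
from functools import wraps

def limit_async_gen_call(max_size: int):
    """Decorator factory: allow at most `max_size` concurrent streams."""
    def decorator(gen_func):
        sem = asyncio.Semaphore(max_size)

        @wraps(gen_func)
        async def wrapper(*args, **kwargs):
            async with sem:  # the slot is held for the stream's whole lifetime
                async for item in gen_func(*args, **kwargs):
                    yield item

        return wrapper
    return decorator
```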
@@ -304,6 +319,61 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
         await self._query_done()
         return response
 
+    async def astream_query(self, query: str, param: QueryParam = QueryParam()):
+        """
+        Streaming query: returns an async generator (token by token / chunk by chunk).
+        Relies on self.llm_model_stream_func and raises if it is not provided.
+        """
+        if self.llm_model_stream_func is None:
+            raise AttributeError("llm_model_stream_func is not set, streaming is unavailable.")
+
+        # Put the stream func into global_config
+        cfg = asdict(self)
+        cfg["llm_model_stream_func"] = self.llm_model_stream_func
+
+        if param.mode == "hyper":
+            async for tok in hyper_query_stream(
+                query,
+                self.chunk_entity_relation_hypergraph,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.text_chunks,
+                param,
+                cfg,
+            ):
+                yield tok
+
+        elif param.mode == "hyper-lite":
+            async for tok in hyper_query_lite_stream(
+                query,
+                self.chunk_entity_relation_hypergraph,
+                self.entities_vdb,
+                self.text_chunks,
+                param,
+                cfg,
+            ):
+                yield tok
+
+        elif param.mode == "naive":
+            async for tok in naive_query_stream(
+                query,
+                self.chunks_vdb,
+                self.text_chunks,
+                param,
+                cfg,
+            ):
+                yield tok
+
+        elif param.mode == "llm":
+            async for tok in llm_query_stream(query, param, cfg):
+                yield tok
+
+        else:
+            raise ValueError(f"Unknown mode {param.mode}")
+
+        await self._query_done()
+
     async def _query_done(self):
         tasks = []
         for storage_inst in [self.llm_response_cache]:
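A usage sketch for the new `astream_query` API follows. The import paths, the model name, and the `partial` wiring are assumptions inferred from how `__post_init__` consumes `llm_model_stream_func`; adapt them to your setup.

```python
# Sketch: consuming astream_query end to end. Import paths and the model
# name are assumptions -- adjust to your installation.
import asyncio
from functools import partial

from hyperrag import HyperRAG, QueryParam
from hyperrag.llm import openai_complete_stream_if_cache

rag = HyperRAG(
    working_dir="./HyperRAG_cache",
    # Bind only the model; __post_init__ injects hashing_kv and llm_model_kwargs.
    llm_model_stream_func=partial(openai_complete_stream_if_cache, "gpt-4o-mini"),
)

async def main() -> None:
    async for token in rag.astream_query(
        "What is a hypergraph?", QueryParam(mode="hyper")
    ):
        print(token, end="", flush=True)

asyncio.run(main())
```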

hyperrag/llm.py

Lines changed: 62 additions & 0 deletions
@@ -73,6 +73,68 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+async def openai_complete_stream_if_cache(
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    base_url=None,
+    api_key=None,
+    chunk_size: int = 32,
+    **kwargs,
+):
+    """
+    OpenAI-compatible streaming completion (async generator).
+    - Cache hit: replay the cached text, chunk_size characters per yield.
+    - Cache miss: call with stream=True, yield each token, and write the cache at the end.
+    """
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
+
+    openai_async_client = (
+        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+    )
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+
+    messages = []
+    if system_prompt is not None:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    # 1) Cache hit: replay directly
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            cached = if_cache_return["return"] or ""
+            # Yield in chunks rather than returning everything at once
+            for i in range(0, len(cached), chunk_size):
+                yield cached[i:i + chunk_size]
+            return
+
+    # 2) Cache miss: real stream
+    full_text = []
+    stream = await openai_async_client.chat.completions.create(
+        model=model,
+        messages=messages,
+        stream=True,
+        **kwargs,
+    )
+
+    async for event in stream:
+        delta = None
+        if event.choices:
+            delta = getattr(event.choices[0].delta, "content", None)
+        if delta:
+            full_text.append(delta)
+            yield delta
+
+    # 3) Write to cache
+    if hashing_kv is not None:
+        text = "".join(full_text)
+        await hashing_kv.upsert({args_hash: {"return": text, "model": model}})
 
 @retry(
     stop=stop_after_attempt(3),
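The new function can also be driven directly, outside `HyperRAG`. A minimal sketch, assuming `OPENAI_API_KEY` is exported and no cache store is attached (pass `hashing_kv=` to exercise the replay path); the model name is a placeholder:

```python
# Sketch: direct use of the streaming completion. Assumes OPENAI_API_KEY is
# set in the environment; without hashing_kv the cache branches are skipped.
import asyncio
from hyperrag.llm import openai_complete_stream_if_cache

async def demo() -> None:
    async for piece in openai_complete_stream_if_cache(
        "gpt-4o-mini",                            # model name is a placeholder
        "Summarize hypergraphs in one sentence.",
        system_prompt="Answer concisely.",
    ):
        print(piece, end="", flush=True)

asyncio.run(demo())
```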
