3333 convert_response_to_json ,
3434 logger ,
3535 set_logger ,
36+ limit_async_gen_call
3637)
3738from .base import (
3839 BaseKVStorage ,
4243 BaseHypergraphStorage ,
4344)
4445
46+ from .operate import hyper_query_stream , hyper_query_lite_stream , naive_query_stream , llm_query_stream
47+
4548
4649def always_get_an_event_loop () -> asyncio .AbstractEventLoop :
4750 try :
@@ -60,7 +63,7 @@ class HyperRAG:
6063 working_dir : str = field (
6164 default_factory = lambda : f"./HyperRAG_cache_{ datetime .now ().strftime ('%Y-%m-%d-%H:%M:%S' )} "
6265 )
63- print (working_dir )
66+ # print(working_dir)
6467
6568 current_log_level = logger .level
6669 log_level : str = field (default = current_log_level )
@@ -78,7 +81,7 @@ class HyperRAG:
7881 relation_keywords_to_max_tokens : int = 100
7982
8083 embedding_func : EmbeddingFunc = field (default_factory = lambda : openai_embedding )
81- embedding_batch_num : int = 32
84+ embedding_batch_num : int = 8
8285 embedding_func_max_async : int = 16
8386
8487 # LLM
@@ -89,6 +92,8 @@ class HyperRAG:
8992 llm_model_max_async : int = 16
9093 llm_model_kwargs : dict = field (default_factory = dict )
9194
95+ llm_model_stream_func : callable = None
96+
9297 # storage
9398 key_string_value_json_storage_cls : Type [BaseKVStorage ] = JsonKVStorage
9499 vector_db_storage_cls : Type [BaseVectorStorage ] = NanoVectorDBStorage
@@ -166,6 +171,16 @@ def __post_init__(self):
166171 )
167172 )
168173
174+ if getattr (self , "llm_model_stream_func" , None ) is not None :
175+ # 先把 hashing_kv 注入到 stream func(供 openai_complete_stream_if_cache 使用)
176+ self .llm_model_stream_func = limit_async_gen_call (self .llm_model_max_async )(
177+ partial (
178+ self .llm_model_stream_func ,
179+ hashing_kv = self .llm_response_cache ,
180+ ** self .llm_model_kwargs ,
181+ )
182+ )
183+
169184 def insert (self , string_or_strings ):
170185 loop = always_get_an_event_loop ()
171186 return loop .run_until_complete (self .ainsert (string_or_strings ))
@@ -304,6 +319,61 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()):
304319 await self ._query_done ()
305320 return response
306321
322+ async def astream_query (self , query : str , param : QueryParam = QueryParam ()):
323+ """
324+ 流式查询:返回 async generator(逐 token / 逐块)
325+ 依赖 self.llm_model_stream_func,不提供则抛错。
326+ """
327+ if self .llm_model_stream_func is None :
328+ raise AttributeError ("llm_model_stream_func is not set, streaming is unavailable." )
329+
330+ # 把 stream func 放进 global_config
331+ cfg = asdict (self )
332+ cfg ["llm_model_stream_func" ] = self .llm_model_stream_func
333+
334+ if param .mode == "hyper" :
335+ async for tok in hyper_query_stream (
336+ query ,
337+ self .chunk_entity_relation_hypergraph ,
338+ self .entities_vdb ,
339+ self .relationships_vdb ,
340+ self .text_chunks ,
341+ param ,
342+ cfg ,
343+ ):
344+ yield tok
345+
346+ elif param .mode == "hyper-lite" :
347+ async for tok in hyper_query_lite_stream (
348+ query ,
349+ self .chunk_entity_relation_hypergraph ,
350+ self .entities_vdb ,
351+ self .text_chunks ,
352+ param ,
353+ cfg ,
354+ ):
355+ yield tok
356+
357+ elif param .mode == "naive" :
358+ async for tok in naive_query_stream (
359+ query ,
360+ self .chunks_vdb ,
361+ self .text_chunks ,
362+ param ,
363+ cfg ,
364+ ):
365+ yield tok
366+
367+ elif param .mode == "llm" :
368+ async for tok in llm_query_stream (query , param , cfg ):
369+ yield tok
370+
371+ else :
372+ raise ValueError (f"Unknown mode { param .mode } " )
373+
374+ await self ._query_done ()
375+
376+
307377 async def _query_done (self ):
308378 tasks = []
309379 for storage_inst in [self .llm_response_cache ]:
0 commit comments