Skip to content

Commit 1e69b0e

Browse files
feat: add search config
1 parent 8eebad1 commit 1e69b0e

2 files changed

Lines changed: 24 additions & 15 deletions

File tree

graphgen/configs/search_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
pipeline:
22
- name: read
33
params:
4-
input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
4+
input_file: resources/input_examples/search_demo.jsonl # input file path; supports json, jsonl, txt, and pdf. See resources/input_examples for examples
55

66
- name: search
77
params:

graphgen/graphgen.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def __init__(
6868
self.working_dir, namespace="graph"
6969
)
7070
self.search_storage: JsonKVStorage = JsonKVStorage(
71-
self.working_dir, namespace="search"
71+
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
72+
namespace="search",
7273
)
7374
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
7475
self.working_dir, namespace="rephrase"
@@ -206,15 +207,23 @@ async def build_kg(self, inputs: List):
206207
# Step 3: store the new entities and relations
207208
await self.graph_storage.index_done_callback()
208209

209-
@op("search", deps=["read"], op_type=OpType.STREAMING)
210+
@op("search", deps=["read"], op_type=OpType.BATCH, batch_size=64)
210211
@async_to_sync_method
211-
async def search(self, search_config: Dict, input_stream: Iterator):
212+
async def search(self, search_config: Dict, inputs: List):
213+
"""
214+
Run the configured search over documents fetched from full_docs_storage.
215+
inputs: document IDs from full_docs_storage
216+
return: None
217+
"""
212218
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
213219

214-
seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
215-
if len(seeds) == 0:
216-
logger.warning("All documents are already been searched")
217-
return
220+
# Step 1: get documents
221+
seeds = {}
222+
for doc_id in inputs:
223+
doc = await self.full_docs_storage.get_by_id(doc_id)
224+
if doc:
225+
seeds[doc_id] = doc
226+
218227
search_results = await search_all(
219228
seed_data=seeds,
220229
search_config=search_config,
@@ -223,16 +232,15 @@ async def search(self, search_config: Dict, input_stream: Iterator):
223232
_add_search_keys = await self.search_storage.filter_keys(
224233
list(search_results.keys())
225234
)
235+
226236
search_results = {
227237
k: v for k, v in search_results.items() if k in _add_search_keys
228238
}
229239
if len(search_results) == 0:
230-
logger.warning("All search results are already in the storage")
231-
return
240+
logger.warning("[Search] No new search results to add to storage")
241+
232242
await self.search_storage.upsert(search_results)
233243
await self.search_storage.index_done_callback()
234-
await self.meta_storage.mark_done(self.full_docs_storage)
235-
await self.meta_storage.index_done_callback()
236244

237245
@op("quiz_and_judge", deps=["build_kg"], op_type=OpType.BARRIER)
238246
@async_to_sync_method
@@ -276,6 +284,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
276284
@op("partition", deps=["build_kg"], op_type=OpType.BARRIER)
277285
@async_to_sync_method
278286
async def partition(self, partition_config: Dict):
287+
# TODO: partition could yield batches
279288
batches = await partition_kg(
280289
self.graph_storage,
281290
self.chunks_storage,
@@ -308,10 +317,10 @@ async def extract(self, extract_config: Dict, input_stream: Iterator):
308317
await self.meta_storage.mark_done(self.chunks_storage)
309318
await self.meta_storage.index_done_callback()
310319

311-
@op("generate", deps=["partition"], op_type=OpType.BARRIER)
320+
@op("generate", deps=["partition"], op_type=OpType.STREAMING)
312321
@async_to_sync_method
313-
async def generate(self, generate_config: Dict, inputs: None):
314-
322+
async def generate(self, generate_config: Dict, input_stream: Iterator):
323+
# TODO:
315324
batches = self.partition_storage.data
316325
if not batches:
317326
logger.warning("No partitions found for QA generation")

0 commit comments

Comments
 (0)