Skip to content

Commit 4188441

Browse files
refactor: use async_to_sync_method
1 parent c42cd24 commit 4188441

16 files changed

Lines changed: 112 additions & 120 deletions

File tree

graphgen/graphgen.py

Lines changed: 42 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
from dataclasses import dataclass, field
5-
from typing import Dict, List, Union, cast
5+
from typing import Dict, cast
66

77
import gradio as gr
88
from tqdm.asyncio import tqdm as tqdm_async
@@ -16,23 +16,22 @@
1616
OpenAIModel,
1717
Tokenizer,
1818
TraverseStrategy,
19-
read_file,
20-
split_chunks,
2119
)
22-
23-
from .operators import (
20+
from graphgen.operators import (
2421
extract_kg,
2522
generate_cot,
2623
judge_statement,
2724
quiz,
25+
read_files,
2826
search_all,
27+
split_chunks,
2928
traverse_graph_for_aggregated,
3029
traverse_graph_for_atomic,
3130
traverse_graph_for_multi_hop,
3231
)
33-
from .utils import (
32+
from graphgen.utils import (
33+
async_to_sync_method,
3434
compute_content_hash,
35-
create_event_loop,
3635
detect_main_language,
3736
format_generation_results,
3837
logger,
@@ -106,15 +105,25 @@ def __post_init__(self):
106105
namespace=f"qa-{self.unique_id}",
107106
)
108107

109-
async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
110-
# TODO: configurable whether to use coreference resolution
108+
@async_to_sync_method
109+
async def insert(self):
110+
"""
111+
insert chunks into the graph
112+
"""
113+
114+
input_file = self.config["read"]["input_file"]
115+
116+
# Step 1: Read files
117+
data = read_files(input_file)
111118
if len(data) == 0:
112-
return {}
119+
logger.warning("No data to process")
120+
return
121+
122+
# TODO: configurable whether to use coreference resolution
113123

124+
# Step 2: Split chunks and filter existing ones
114125
inserting_chunks = {}
115126
assert isinstance(data, list) and isinstance(data[0], dict)
116-
117-
# compute hash for each document
118127
new_docs = {
119128
compute_content_hash(doc["content"], prefix="doc-"): {
120129
"content": doc["content"]
@@ -123,9 +132,10 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
123132
}
124133
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
125134
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
135+
126136
if len(new_docs) == 0:
127137
logger.warning("All docs are already in the storage")
128-
return {}
138+
return
129139
logger.info("[New Docs] inserting %d docs", len(new_docs))
130140

131141
cur_index = 1
@@ -162,29 +172,16 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
162172
inserting_chunks = {
163173
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
164174
}
165-
await self.full_docs_storage.upsert(new_docs)
166-
await self.text_chunks_storage.upsert(inserting_chunks)
167-
168-
return inserting_chunks
169-
170-
def insert(self):
171-
loop = create_event_loop()
172-
loop.run_until_complete(self.async_insert())
173-
174-
async def async_insert(self):
175-
"""
176-
insert chunks into the graph
177-
"""
178-
179-
input_file = self.config["read"]["input_file"]
180-
data = read_file(input_file)
181-
inserting_chunks = await self.async_split_chunks(data)
182175

183176
if len(inserting_chunks) == 0:
184177
logger.warning("All chunks are already in the storage")
185178
return
179+
186180
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
181+
await self.full_docs_storage.upsert(new_docs)
182+
await self.text_chunks_storage.upsert(inserting_chunks)
187183

184+
# Step 3: Extract entities and relations from chunks
188185
logger.info("[Entity and Relation Extraction]...")
189186
_add_entities_and_relations = await extract_kg(
190187
llm_client=self.synthesizer_llm_client,
@@ -214,11 +211,8 @@ async def _insert_done(self):
214211
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
215212
await asyncio.gather(*tasks)
216213

217-
def search(self):
218-
loop = create_event_loop()
219-
loop.run_until_complete(self.async_search())
220-
221-
async def async_search(self):
214+
@async_to_sync_method
215+
async def search(self):
222216
logger.info(
223217
"Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
224218
)
@@ -254,11 +248,8 @@ async def async_search(self):
254248
# TODO: fix insert after search
255249
await self.async_insert()
256250

257-
def quiz(self):
258-
loop = create_event_loop()
259-
loop.run_until_complete(self.async_quiz())
260-
261-
async def async_quiz(self):
251+
@async_to_sync_method
252+
async def quiz(self):
262253
max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
263254
await quiz(
264255
self.synthesizer_llm_client,
@@ -268,11 +259,8 @@ async def async_quiz(self):
268259
)
269260
await self.rephrase_storage.index_done_callback()
270261

271-
def judge(self):
272-
loop = create_event_loop()
273-
loop.run_until_complete(self.async_judge())
274-
275-
async def async_judge(self):
262+
@async_to_sync_method
263+
async def judge(self):
276264
re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
277265
_update_relations = await judge_statement(
278266
self.trainee_llm_client,
@@ -282,11 +270,8 @@ async def async_judge(self):
282270
)
283271
await _update_relations.index_done_callback()
284272

285-
def traverse(self):
286-
loop = create_event_loop()
287-
loop.run_until_complete(self.async_traverse())
288-
289-
async def async_traverse(self):
273+
@async_to_sync_method
274+
async def traverse(self):
290275
output_data_type = self.config["output_data_type"]
291276

292277
if output_data_type == "atomic":
@@ -326,11 +311,12 @@ async def async_traverse(self):
326311
await self.qa_storage.upsert(results)
327312
await self.qa_storage.index_done_callback()
328313

329-
def generate_reasoning(self, method_params):
330-
loop = create_event_loop()
331-
loop.run_until_complete(self.async_generate_reasoning(method_params))
314+
# def generate_reasoning(self, method_params):
315+
# loop = create_event_loop()
316+
# loop.run_until_complete(self.async_generate_reasoning(method_params))
332317

333-
async def async_generate_reasoning(self, method_params):
318+
@async_to_sync_method
319+
async def generate_reasoning(self, method_params):
334320
results = await generate_cot(
335321
self.graph_storage,
336322
self.synthesizer_llm_client,
@@ -344,11 +330,8 @@ async def async_generate_reasoning(self, method_params):
344330
await self.qa_storage.upsert(results)
345331
await self.qa_storage.index_done_callback()
346332

347-
def clear(self):
348-
loop = create_event_loop()
349-
loop.run_until_complete(self.async_clear())
350-
351-
async def async_clear(self):
333+
@async_to_sync_method
334+
async def clear(self):
352335
await self.full_docs_storage.drop()
353336
await self.text_chunks_storage.drop()
354337
await self.search_storage.drop()

graphgen/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from .llm.openai_model import OpenAIModel
77
from .llm.tokenizer import Tokenizer
88
from .llm.topk_token_model import Token, TopkTokenModel
9-
from .reader import read_file
9+
from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
1010
from .search.db.uniprot_search import UniProtSearch
1111
from .search.kg.wiki_search import WikiSearch
1212
from .search.web.bing_search import BingSearch
1313
from .search.web.google_search import GoogleSearch
14-
from .splitter import split_chunks
14+
from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
1515
from .storage.json_storage import JsonKVStorage, JsonListStorage
1616
from .storage.networkx_storage import NetworkXStorage
1717
from .strategy.travserse_strategy import TraverseStrategy

graphgen/models/reader/__init__.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,3 @@
22
from .json_reader import JsonReader
33
from .jsonl_reader import JsonlReader
44
from .txt_reader import TxtReader
5-
6-
_MAPPING = {
7-
"jsonl": JsonlReader,
8-
"json": JsonReader,
9-
"txt": TxtReader,
10-
"csv": CsvReader,
11-
}
12-
13-
14-
def read_file(file_path: str):
15-
suffix = file_path.split(".")[-1]
16-
if suffix in _MAPPING:
17-
reader = _MAPPING[suffix]()
18-
else:
19-
raise ValueError(
20-
f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}"
21-
)
22-
return reader.read(file_path)

graphgen/models/splitter/__init__.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,4 @@
1-
from functools import lru_cache
2-
from typing import Union
3-
41
from .recursive_character_splitter import (
52
ChineseRecursiveTextSplitter,
63
RecursiveCharacterSplitter,
74
)
8-
9-
_MAPPING = {
10-
"en": RecursiveCharacterSplitter,
11-
"zh": ChineseRecursiveTextSplitter,
12-
}
13-
14-
SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]
15-
16-
17-
@lru_cache(maxsize=None)
18-
def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
19-
cls = _MAPPING[language]
20-
kwargs = dict(frozen_kwargs)
21-
return cls(**kwargs)
22-
23-
24-
def split_chunks(text: str, language: str = "en", **kwargs) -> list:
25-
if language not in _MAPPING:
26-
raise ValueError(
27-
f"Unsupported language: {language}. "
28-
f"Supported languages are: {list(_MAPPING.keys())}"
29-
)
30-
splitter = _get_splitter(language, frozenset(kwargs.items()))
31-
return splitter.split_text(text)

graphgen/operators/__init__.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
1+
from graphgen.operators.build_kg.extract_kg import extract_kg
12
from graphgen.operators.generate.generate_cot import generate_cot
2-
from graphgen.operators.kg.extract_kg import extract_kg
33
from graphgen.operators.search.search_all import search_all
44

55
from .judge import judge_statement
66
from .quiz import quiz
7+
from .read import read_files
8+
from .split import split_chunks
79
from .traverse_graph import (
810
traverse_graph_for_aggregated,
911
traverse_graph_for_atomic,
1012
traverse_graph_for_multi_hop,
1113
)
12-
13-
__all__ = [
14-
"extract_kg",
15-
"quiz",
16-
"judge_statement",
17-
"search_all",
18-
"traverse_graph_for_aggregated",
19-
"traverse_graph_for_atomic",
20-
"traverse_graph_for_multi_hop",
21-
"generate_cot",
22-
]
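
graphgen/operators/build_kg/extract_kg.py (file path not captured in this view; presumably renamed from graphgen/operators/kg/extract_kg.py)

Lines changed: 1 addition & 1 deletion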
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from graphgen.bases.base_storage import BaseGraphStorage
1010
from graphgen.bases.datatypes import Chunk
1111
from graphgen.models import OpenAIModel, Tokenizer
12-
from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
12+
from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
1313
from graphgen.templates import KG_EXTRACTION_PROMPT
1414
from graphgen.utils import (
1515
detect_if_chinese,
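
graphgen/operators/read/__init__.py (file path not captured in this view; presumably the new `read` operator package)

Lines changed: 1 addition & 0 deletions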
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .read_files import read_files

0 commit comments

Comments
 (0)