Skip to content

Commit 443ee29

Browse files
Merge pull request InternScience#42 from open-sciencelab/community
feat: add cot data generation pipeline
2 parents f08558a + 5a978ed commit 443ee29

52 files changed

Lines changed: 2432 additions & 562 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ Furthermore, GraphGen incorporates multi-hop neighborhood sampling to capture co
5656

5757
## 📌 Latest Updates
5858

59-
- **2025.07.31**: We have added Google, Bing, Wikipedia, and UniProt as search back-ends, perfect for closing data gaps.
59+
- **2025.07.31**: We have added Google, Bing, Wikipedia, and UniProt as search back-ends.
6060
- **2025.04.21**: We have released the initial version of GraphGen.
6161

6262
## 🚀 Quick Start
@@ -136,18 +136,31 @@ For any questions, please check [FAQ](https://github.com/open-sciencelab/GraphGe
136136
TRAINEE_BASE_URL=your_base_url_for_trainee_model
137137
TRAINEE_API_KEY=your_api_key_for_trainee_model
138138
```
139-
2. (Optional) If you want to modify the default generated configuration, you can edit the content of the configs/graphgen_config.yaml file.
139+
2. (Optional) Customize the generation parameters in the `graphgen/configs/` folder.
140+
141+
Edit the corresponding YAML file, e.g.:
142+
140143
```yaml
141-
# configs/graphgen_config.yaml
142-
# Example configuration
143-
data_type: "raw"
144-
input_file: "resources/examples/raw_demo.jsonl"
145-
# more configurations...
144+
# configs/cot_config.yaml
145+
input_data_type: raw
146+
input_file: resources/input_examples/raw_demo.jsonl
147+
output_data_type: cot
148+
tokenizer: cl100k_base
149+
# additional settings...
146150
```
147-
3. Run the generation script
148-
```bash
149-
bash scripts/generate.sh
150-
```
151+
152+
3. Generate data
153+
154+
Pick the desired format and run the matching script:
155+
156+
| Format | Script to run | Notes |
157+
| ------------ | ---------------------------------------------- |-------------------------------------------------------------------|
158+
| `cot` | `bash scripts/generate/generate_cot.sh` | Chain-of-Thought Q\&A pairs |
159+
| `atomic` | `bash scripts/generate/generate_atomic.sh` | Atomic Q\&A pairs covering basic knowledge |
160+
| `aggregated` | `bash scripts/generate/generate_aggregated.sh` | Aggregated Q\&A pairs incorporating complex, integrated knowledge |
161+
| `multi-hop` | `bash scripts/generate/generate_multihop.sh` | Multi-hop reasoning Q\&A pairs |
162+
163+
151164
4. Get the generated data
152165
```bash
153166
ls cache/data/graphgen
@@ -176,7 +189,8 @@ See [analysis](https://deepwiki.com/open-sciencelab/GraphGen) by deepwiki for a
176189
## 🍀 Acknowledgements
177190
- [SiliconFlow](https://siliconflow.cn) Abundant LLM API, some models are free
178191
- [LightRAG](https://github.com/HKUDS/LightRAG) Simple and efficient graph retrieval solution
179-
- [ROGRAG](https://github.com/tpoisonooo/ROGRAG) ROGRAG: A Robustly Optimized GraphRAG Framework
192+
- [ROGRAG](https://github.com/tpoisonooo/ROGRAG) A robustly optimized GraphRAG framework
193+
- [DB-GPT](https://github.com/eosphoros-ai/DB-GPT) An AI native data app development framework
180194

181195

182196
## 📚 Citation

baselines/EntiGraph/entigraph.py

Lines changed: 90 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# https://arxiv.org/abs/2409.07431
22
# https://github.com/zitongyang/synthetic_continued_pretraining
33

4-
import os
4+
import argparse
5+
import asyncio
56
import json
7+
import os
68
import random
7-
import asyncio
8-
import argparse
99
from hashlib import md5
1010

1111
from tqdm.asyncio import tqdm as tqdm_async
@@ -18,9 +18,9 @@ def compute_content_hash(content, prefix: str = ""):
1818
return prefix + md5(content.encode()).hexdigest()
1919

2020

21-
async def generate_entities(document_content: str,
22-
system_message: str,
23-
openai_model: str):
21+
async def generate_entities(
22+
document_content: str, system_message: str, openai_model: str
23+
):
2424
prompt = f"""
2525
### Document Content:
2626
{document_content}
@@ -30,41 +30,44 @@ async def generate_entities(document_content: str,
3030
max_tries = 5
3131
while not can_read_entities and max_tries > 0:
3232
try:
33-
completion = await gptqa(prompt,
34-
openai_model,
35-
system_message,
36-
json_format=False)
37-
completion = completion[completion.find("{"): completion.rfind("}") + 1]
33+
completion = await gptqa(
34+
prompt, openai_model, system_message, json_format=False
35+
)
36+
completion = completion[completion.find("{") : completion.rfind("}") + 1]
3837
response = json.loads(completion)
39-
can_read_entities = response['entities']
38+
can_read_entities = response["entities"]
4039
return response
41-
except Exception as e: # pylint: disable=broad-except
40+
except Exception as e: # pylint: disable=broad-except
4241
print(f"Failed to generate entities: {str(e)}")
4342
max_tries -= 1
4443

45-
async def generate_two_entity_relations(document_content: str,
46-
entity1: str,
47-
entity2: str,
48-
system_message: str,
49-
openai_model: str):
44+
45+
async def generate_two_entity_relations(
46+
document_content: str,
47+
entity1: str,
48+
entity2: str,
49+
system_message: str,
50+
openai_model: str,
51+
):
5052
prompt = f"""
5153
### Document Content:
5254
{document_content}
5355
### Entities:
5456
- {entity1}
5557
- {entity2}
5658
"""
57-
completion = await gptqa(prompt,
58-
openai_model,
59-
system_message)
59+
completion = await gptqa(prompt, openai_model, system_message)
6060
return completion
6161

62-
async def generate_three_entity_relations(document_content: str,
63-
entity1: str,
64-
entity2: str,
65-
entity3: str,
66-
system_message: str,
67-
openai_model: str):
62+
63+
async def generate_three_entity_relations(
64+
document_content: str,
65+
entity1: str,
66+
entity2: str,
67+
entity3: str,
68+
system_message: str,
69+
openai_model: str,
70+
):
6871
prompt = f"""
6972
### Document Content:
7073
{document_content}
@@ -73,11 +76,10 @@ async def generate_three_entity_relations(document_content: str,
7376
- {entity2}
7477
- {entity3}
7578
"""
76-
completion = await gptqa(prompt,
77-
openai_model,
78-
system_message)
79+
completion = await gptqa(prompt, openai_model, system_message)
7980
return completion
8081

82+
8183
def _post_process_synthetic_data(data):
8284
block = data.split("\n\n")
8385
qas = {}
@@ -87,7 +89,7 @@ def _post_process_synthetic_data(data):
8789
answer = line.split("Answer: ")[1]
8890
qas[compute_content_hash(question)] = {
8991
"question": question,
90-
"answer": answer
92+
"answer": answer,
9193
}
9294
break
9395
return qas
@@ -105,25 +107,26 @@ async def generate_document_entities(doc):
105107
async with semaphore:
106108
try:
107109
entities = await generate_entities(
108-
doc.text,
109-
task.openai_system_generate_entities,
110-
model_name)
110+
doc.text, task.openai_system_generate_entities, model_name
111+
)
111112
if not entities:
112113
return None
113114
return {
114-
'document': doc.text,
115-
'entities': entities['entities'],
116-
'summary': entities['summary']
115+
"document": doc.text,
116+
"entities": entities["entities"],
117+
"summary": entities["summary"],
117118
}
118-
except Exception as e: # pylint: disable=broad-except
119+
except Exception as e: # pylint: disable=broad-except
119120
print(f"Error: {e}")
120121
return None
121122

122123
entities_list = []
123124
for result in tqdm_async(
124-
asyncio.as_completed([generate_document_entities(doc) for doc in task.documents]),
125-
total=len(task.documents),
126-
desc="Generating entities"
125+
asyncio.as_completed(
126+
[generate_document_entities(doc) for doc in task.documents]
127+
),
128+
total=len(task.documents),
129+
desc="Generating entities",
127130
):
128131
result = await result
129132
if result:
@@ -132,38 +135,42 @@ async def generate_document_entities(doc):
132135
# iterate over pairs of entities within each document and generate relations
133136
pair_list = []
134137
for doc in entities_list:
135-
entities = doc['entities']
138+
entities = doc["entities"]
136139
temp = []
137140
for i, entity_i in enumerate(entities):
138141
if i == len(entities) - 1:
139142
break
140143
for j in range(i + 1, len(entities)):
141144
entity_j = entities[j]
142-
pair = (doc['document'], entity_i, entity_j)
145+
pair = (doc["document"], entity_i, entity_j)
143146
temp.append(pair)
144147

145148
# Computing all possible combinations of entities is impractical, so we randomly sample up to 10 pairs per document
146149
pair_list.extend(random.sample(temp, min(len(temp), 10)))
147150

148-
149151
async def process_two_entity_relations(pair):
150152
async with semaphore:
151153
try:
152154
document, entity1, entity2 = pair
153155
response = await generate_two_entity_relations(
154-
document, entity1, entity2,
156+
document,
157+
entity1,
158+
entity2,
155159
task.openai_system_generate_two_entity_relations,
156-
model_name)
160+
model_name,
161+
)
157162
return response
158-
except Exception as e: # pylint: disable=broad-except
163+
except Exception as e: # pylint: disable=broad-except
159164
print(f"Error: {e}")
160165
return None
161166

162-
corpus= []
167+
corpus = []
163168
for result in tqdm_async(
164-
asyncio.as_completed([process_two_entity_relations(pair) for pair in pair_list]),
165-
total=len(pair_list),
166-
desc="Generating two entity relations"
169+
asyncio.as_completed(
170+
[process_two_entity_relations(pair) for pair in pair_list]
171+
),
172+
total=len(pair_list),
173+
desc="Generating two entity relations",
167174
):
168175
result = await result
169176
if result:
@@ -194,51 +201,60 @@ async def process_two_entity_relations(pair):
194201
# ):
195202
# corpus.append(await result)
196203

197-
corpus = [doc['summary'] for doc in entities_list] + corpus
204+
corpus = [doc["summary"] for doc in entities_list] + corpus
198205

199206
qa_sft_results = {}
200207

201208
async def generate_qa_sft(content):
202209
async with semaphore:
203-
completion = await gptqa(content, model_name, task.openai_system_quality_qa_sft)
210+
completion = await gptqa(
211+
content, model_name, task.openai_system_quality_qa_sft
212+
)
204213
return completion
205214

206-
207215
for result in tqdm_async(
208-
asyncio.as_completed([generate_qa_sft(content) for content in corpus]),
209-
total=len(corpus),
210-
desc="Generating QA SFT"
216+
asyncio.as_completed([generate_qa_sft(content) for content in corpus]),
217+
total=len(corpus),
218+
desc="Generating QA SFT",
211219
):
212220
try:
213221
result = await result
214222
if result:
215223
qa_sft_results.update(_post_process_synthetic_data(result))
216-
except Exception as e: # pylint: disable=broad-except
224+
except Exception as e: # pylint: disable=broad-except
217225
print(f"Error: {e}")
218226

219227
return qa_sft_results
220228

221229

222-
if __name__ == '__main__':
230+
if __name__ == "__main__":
223231
parser = argparse.ArgumentParser()
224-
parser.add_argument('--input_file',
225-
help='Raw context jsonl path.',
226-
default='resources/examples/chunked_demo.json',
227-
type=str)
228-
parser.add_argument('--data_type',
229-
help='Data type of input file. (Raw context or chunked context)',
230-
choices=['raw', 'chunked'],
231-
default='raw',
232-
type=str)
233-
parser.add_argument('--output_file',
234-
help='Output file path.',
235-
default='cache/data/entigraph.json',
236-
type=str)
232+
parser.add_argument(
233+
"--input_file",
234+
help="Raw context jsonl path.",
235+
default="resources/input_examples/chunked_demo.json",
236+
type=str,
237+
)
238+
parser.add_argument(
239+
"--data_type",
240+
help="Data type of input file. (Raw context or chunked context)",
241+
choices=["raw", "chunked"],
242+
default="raw",
243+
type=str,
244+
)
245+
parser.add_argument(
246+
"--output_file",
247+
help="Output file path.",
248+
default="cache/data/entigraph.json",
249+
type=str,
250+
)
237251

238252
args = parser.parse_args()
239253

240-
results = asyncio.run(generate_synthetic_data_for_document(args.input_file, args.data_type))
254+
results = asyncio.run(
255+
generate_synthetic_data_for_document(args.input_file, args.data_type)
256+
)
241257

242258
# Save results
243-
with open(args.output_file, "w", encoding='utf-8') as f:
259+
with open(args.output_file, "w", encoding="utf-8") as f:
244260
json.dump(results, f, indent=4, ensure_ascii=False)

0 commit comments

Comments
 (0)