11# https://arxiv.org/abs/2409.07431
22# https://github.com/zitongyang/synthetic_continued_pretraining
33
4- import os
4+ import argparse
5+ import asyncio
56import json
7+ import os
68import random
7- import asyncio
8- import argparse
99from hashlib import md5
1010
1111from tqdm .asyncio import tqdm as tqdm_async
@@ -18,9 +18,9 @@ def compute_content_hash(content, prefix: str = ""):
1818 return prefix + md5 (content .encode ()).hexdigest ()
1919
2020
21- async def generate_entities (document_content : str ,
22- system_message : str ,
23- openai_model : str ):
21+ async def generate_entities (
22+ document_content : str , system_message : str , openai_model : str
23+ ):
2424 prompt = f"""
2525 ### Document Content:
2626 { document_content }
@@ -30,41 +30,44 @@ async def generate_entities(document_content: str,
3030 max_tries = 5
3131 while not can_read_entities and max_tries > 0 :
3232 try :
33- completion = await gptqa (prompt ,
34- openai_model ,
35- system_message ,
36- json_format = False )
37- completion = completion [completion .find ("{" ): completion .rfind ("}" ) + 1 ]
33+ completion = await gptqa (
34+ prompt , openai_model , system_message , json_format = False
35+ )
36+ completion = completion [completion .find ("{" ) : completion .rfind ("}" ) + 1 ]
3837 response = json .loads (completion )
39- can_read_entities = response [' entities' ]
38+ can_read_entities = response [" entities" ]
4039 return response
41- except Exception as e : # pylint: disable=broad-except
40+ except Exception as e : # pylint: disable=broad-except
4241 print (f"Failed to generate entities: { str (e )} " )
4342 max_tries -= 1
4443
async def generate_two_entity_relations(
    document_content: str,
    entity1: str,
    entity2: str,
    system_message: str,
    openai_model: str,
):
    """Query the model about the relation between two entities of a document.

    A prompt is assembled from the document text followed by the two entity
    names, then handed to ``gptqa`` along with the model name and the system
    message.

    Args:
        document_content: Text of the source document.
        entity1: First entity of the pair.
        entity2: Second entity of the pair.
        system_message: System prompt forwarded to ``gptqa``.
        openai_model: Model identifier forwarded to ``gptqa``.

    Returns:
        The value produced by ``gptqa`` for this prompt (presumably the
        completion text -- confirm against ``gptqa``).
    """
    # The whitespace inside the template is part of the prompt that is sent.
    prompt = f"""
    ### Document Content:
    {document_content}
    ### Entities:
    - {entity1}
    - {entity2}
    """
    return await gptqa(prompt, openai_model, system_message)
6161
62- async def generate_three_entity_relations (document_content : str ,
63- entity1 : str ,
64- entity2 : str ,
65- entity3 : str ,
66- system_message : str ,
67- openai_model : str ):
62+
63+ async def generate_three_entity_relations (
64+ document_content : str ,
65+ entity1 : str ,
66+ entity2 : str ,
67+ entity3 : str ,
68+ system_message : str ,
69+ openai_model : str ,
70+ ):
6871 prompt = f"""
6972 ### Document Content:
7073 { document_content }
@@ -73,11 +76,10 @@ async def generate_three_entity_relations(document_content: str,
7376 - { entity2 }
7477 - { entity3 }
7578 """
76- completion = await gptqa (prompt ,
77- openai_model ,
78- system_message )
79+ completion = await gptqa (prompt , openai_model , system_message )
7980 return completion
8081
82+
8183def _post_process_synthetic_data (data ):
8284 block = data .split ("\n \n " )
8385 qas = {}
@@ -87,7 +89,7 @@ def _post_process_synthetic_data(data):
8789 answer = line .split ("Answer: " )[1 ]
8890 qas [compute_content_hash (question )] = {
8991 "question" : question ,
90- "answer" : answer
92+ "answer" : answer ,
9193 }
9294 break
9395 return qas
@@ -105,25 +107,26 @@ async def generate_document_entities(doc):
105107 async with semaphore :
106108 try :
107109 entities = await generate_entities (
108- doc .text ,
109- task .openai_system_generate_entities ,
110- model_name )
110+ doc .text , task .openai_system_generate_entities , model_name
111+ )
111112 if not entities :
112113 return None
113114 return {
114- ' document' : doc .text ,
115- ' entities' : entities [' entities' ],
116- ' summary' : entities [' summary' ]
115+ " document" : doc .text ,
116+ " entities" : entities [" entities" ],
117+ " summary" : entities [" summary" ],
117118 }
118- except Exception as e : # pylint: disable=broad-except
119+ except Exception as e : # pylint: disable=broad-except
119120 print (f"Error: { e } " )
120121 return None
121122
122123 entities_list = []
123124 for result in tqdm_async (
124- asyncio .as_completed ([generate_document_entities (doc ) for doc in task .documents ]),
125- total = len (task .documents ),
126- desc = "Generating entities"
125+ asyncio .as_completed (
126+ [generate_document_entities (doc ) for doc in task .documents ]
127+ ),
128+ total = len (task .documents ),
129+ desc = "Generating entities" ,
127130 ):
128131 result = await result
129132 if result :
@@ -132,38 +135,42 @@ async def generate_document_entities(doc):
132135 # iterate over pairs of entities and generate relations
133136 pair_list = []
134137 for doc in entities_list :
135- entities = doc [' entities' ]
138+ entities = doc [" entities" ]
136139 temp = []
137140 for i , entity_i in enumerate (entities ):
138141 if i == len (entities ) - 1 :
139142 break
140143 for j in range (i + 1 , len (entities )):
141144 entity_j = entities [j ]
142- pair = (doc [' document' ], entity_i , entity_j )
145+ pair = (doc [" document" ], entity_i , entity_j )
143146 temp .append (pair )
144147
145148 # Computing all possible combinations of entities is impractical, so we randomly sample up to 10 pairs per document
146149 pair_list .extend (random .sample (temp , min (len (temp ), 10 )))
147150
148-
149151 async def process_two_entity_relations (pair ):
150152 async with semaphore :
151153 try :
152154 document , entity1 , entity2 = pair
153155 response = await generate_two_entity_relations (
154- document , entity1 , entity2 ,
156+ document ,
157+ entity1 ,
158+ entity2 ,
155159 task .openai_system_generate_two_entity_relations ,
156- model_name )
160+ model_name ,
161+ )
157162 return response
158- except Exception as e : # pylint: disable=broad-except
163+ except Exception as e : # pylint: disable=broad-except
159164 print (f"Error: { e } " )
160165 return None
161166
162- corpus = []
167+ corpus = []
163168 for result in tqdm_async (
164- asyncio .as_completed ([process_two_entity_relations (pair ) for pair in pair_list ]),
165- total = len (pair_list ),
166- desc = "Generating two entity relations"
169+ asyncio .as_completed (
170+ [process_two_entity_relations (pair ) for pair in pair_list ]
171+ ),
172+ total = len (pair_list ),
173+ desc = "Generating two entity relations" ,
167174 ):
168175 result = await result
169176 if result :
@@ -194,51 +201,60 @@ async def process_two_entity_relations(pair):
194201 # ):
195202 # corpus.append(await result)
196203
197- corpus = [doc [' summary' ] for doc in entities_list ] + corpus
204+ corpus = [doc [" summary" ] for doc in entities_list ] + corpus
198205
199206 qa_sft_results = {}
200207
201208 async def generate_qa_sft (content ):
202209 async with semaphore :
203- completion = await gptqa (content , model_name , task .openai_system_quality_qa_sft )
210+ completion = await gptqa (
211+ content , model_name , task .openai_system_quality_qa_sft
212+ )
204213 return completion
205214
206-
207215 for result in tqdm_async (
208- asyncio .as_completed ([generate_qa_sft (content ) for content in corpus ]),
209- total = len (corpus ),
210- desc = "Generating QA SFT"
216+ asyncio .as_completed ([generate_qa_sft (content ) for content in corpus ]),
217+ total = len (corpus ),
218+ desc = "Generating QA SFT" ,
211219 ):
212220 try :
213221 result = await result
214222 if result :
215223 qa_sft_results .update (_post_process_synthetic_data (result ))
216- except Exception as e : # pylint: disable=broad-except
224+ except Exception as e : # pylint: disable=broad-except
217225 print (f"Error: { e } " )
218226
219227 return qa_sft_results
220228
221229
if __name__ == "__main__":
    # Command-line entry point: parse arguments, run the async generation
    # pipeline once, and dump the resulting dict to a JSON file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        help="Raw context jsonl path.",
        default="resources/input_examples/chunked_demo.json",
        type=str,
    )
    parser.add_argument(
        "--data_type",
        help="Data type of input file. (Raw context or chunked context)",
        choices=["raw", "chunked"],
        default="raw",
        type=str,
    )
    parser.add_argument(
        "--output_file",
        help="Output file path.",
        default="cache/data/entigraph.json",
        type=str,
    )

    args = parser.parse_args()

    # Create the output directory *before* the generation run: otherwise a
    # missing directory (e.g. the default "cache/data") only surfaces as a
    # FileNotFoundError after all the model calls have already been made,
    # and the results are lost. A bare filename has no directory part, and
    # os.makedirs("") would raise, hence the guard.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    results = asyncio.run(
        generate_synthetic_data_for_document(args.input_file, args.data_type)
    )

    # Save results
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
0 commit comments