@@ -146,8 +146,32 @@ def prepare(cls, langs, out_dir: Path, dataset_ids=Dict[str, List[DatasetId]],
146146 if dataset_ids .get ('dev' ):
147147 dev_entries = cls .resolve_entries (dataset_ids ['dev' ])
148148 dataset .add_dev_entries (dev_entries )
149- if dataset_ids .get ('train' ): # this might take some time
149+ # Phase 2: Process train parts and mono entries concurrently
150+ # Bitext train parts and mono entries are independent — they write to different dirs
151+ all_tasks = [] # list of (func, task_dict) pairs
152+ train_entries = []
153+ if dataset_ids .get ('train' ):
150154 train_entries = cls .resolve_entries (dataset_ids ['train' ])
155+ for ent in train_entries :
156+ all_tasks .append ((dataset .add_part , dict (
157+ dir_path = dataset .train_parts_dir , entry = ent ,
158+ drop_noise = dataset .drop_train_noise , compress = compress )))
159+ mono_groups = []
160+ for key , dirpath in [('mono_train' , dataset .mono_train_parts_dir ),
161+ ('mono_dev' , dataset .mono_tests_dir ),
162+ ('mono_test' , dataset .mono_tests_dir )]:
163+ if dataset_ids .get (key ):
164+ dirpath .mkdir (exist_ok = True )
165+ entries = cls .resolve_entries (dataset_ids [key ])
166+ mono_groups .append ((key , entries ))
167+ for ent in entries :
168+ all_tasks .append ((dataset .add_mono_entry , dict (
169+ dirpath = dirpath , entry = ent , compress = compress )))
170+ if all_tasks :
171+ dataset ._run_entries_multi (all_tasks , desc = 'Processing entries' ,
172+ fail_on_error = fail_on_error )
173+ # Phase 3: Merge train if requested (must happen after all train parts are written)
174+ if train_entries and merge_train :
151175 drop_hashes = None
152176 if drop_tests :
153177 pair_files = []
@@ -157,22 +181,8 @@ def prepare(cls, langs, out_dir: Path, dataset_ids=Dict[str, List[DatasetId]],
157181 p1 , p2 = p2 , p1 # swap
158182 pair_files .append ((p1 , p2 ))
159183 test_pair_hash , test_seg_hash = dataset .hash_all_bitexts (pair_files )
160- drop_hashes = test_pair_hash | test_seg_hash # set union
161- dataset .add_train_entries (train_entries , merge_train = merge_train , compress = compress ,
162- drop_hashes = drop_hashes )
163- for key , dirpath in [('mono_train' , dataset .mono_train_parts_dir ),
164- ('mono_dev' , dataset .mono_tests_dir ),
165- ('mono_test' , dataset .mono_tests_dir )]:
166- if dataset_ids .get (key ):
167- dirpath .mkdir (exist_ok = True )
168- entries = cls .resolve_entries (dataset_ids [key ])
169- dataset .add_mono_entries (
170- dirpath ,
171- entries ,
172- compress = compress ,
173- desc = key .replace ('_' , ' ' ),
174- fail_on_error = fail_on_error ,
175- )
184+ drop_hashes = test_pair_hash | test_seg_hash
185+ dataset ._merge_train (train_entries , compress = compress , drop_hashes = drop_hashes )
176186
177187 # citations
178188 refs_file = out_dir / 'references.bib'
@@ -204,11 +214,8 @@ def hash_all_bitexts(self, paired_files):
204214 seg_hashes .add (hash (seg2 ))
205215 return paired_hashes , seg_hashes
206216
207- def add_train_entries (self , entries , merge_train = False , compress = False , drop_hashes = None ):
208- self .add_parts (self .train_parts_dir , entries , drop_noise = self .drop_train_noise ,
209- compress = compress , desc = 'Training sets' , fail_on_error = self .fail_on_error )
210- if not merge_train :
211- return
217+ def _merge_train (self , entries , compress = False , drop_hashes = None ):
218+ """Merge already-written train parts into single train files."""
212219 lang1 , lang2 = self .langs
213220 # paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2)
214221 paired_files = {}
@@ -274,41 +281,34 @@ def add_mono_entry(self, dirpath, entry: Entry, compress=False):
274281 cache_path = self .cache .get_entry (entry )
275282 parser = Parser (cache_path , ext = entry .in_ext or None , ent = entry )
276283 out_path , meta_file = self .get_paths (dirpath , entry , compress = compress )
277- log . info ( "Writing %s to %s" , entry .did , out_path )
284+ pbar_man . emit_log ( log . INFO , f "Writing { entry .did } to { out_path } " )
278285 io_args = dict (encoding = 'utf-8' , errors = 'ignore' )
279- has_meta = None # None -> True/False on first row ;then ensure it is consistent
280286 with pbar_man .counter (unit = 'line' , desc = f"Processing { entry .did } " ) as pbar , \
281287 IO .writer (out_path , ** io_args ) as out , IO .writer (meta_file , ** io_args ) as out_meta :
282- count , skips , read_count = 0 , 0 , 0
288+ count , skips = 0 , 0
289+ has_meta = None
283290 for row in parser .read_segs (show_pbar = False ):
284- read_count += 1
285- if has_meta is None : # first row only
291+ if has_meta is None :
286292 has_meta = bool (isinstance (row , (list , tuple )) and len (row ) > 1 )
287- sentence = row
288- if isinstance (row , (list , tuple )):
289- sentence = row [0 ] # flatten list
290- assert isinstance (sentence , str ), f'str sentence expected. found: { type (sentence )} ; entry: { entry .did } '
291- sentence = sentence and sentence .strip ()
293+ sentence = row [0 ] if isinstance (row , (list , tuple )) else row
294+ sentence = sentence .strip ().replace ('\t ' , ' ' ).replace ('\r ' , ' ' ) if sentence else ''
292295 if not sentence :
293296 skips += 1
294- pbar .update (incr = 1 , write_count = count )
297+ pbar .update ()
295298 continue
296- sentence = sentence .replace ('\n ' , ' ' ).replace ('\t ' , ' ' ).replace ('\r ' , ' ' )
297- out .write (f'{ sentence } \n ' )
299+ out .write (sentence + '\n ' )
298300 if has_meta :
299- assert len (row ) >= 2 , f'Expected 2 fields, found { len (row )} : { row } '
300- meta = json .dumps (row [1 ], ensure_ascii = False , indent = None )
301- meta = meta .replace ('\n ' , ' ' ).replace ('\t ' , ' ' ).replace ('\r ' , ' ' )
302- out_meta .write (f'{ meta } \n ' )
301+ meta = json .dumps (row [1 ], ensure_ascii = False , indent = None ).replace ('\t ' , ' ' ).replace ('\r ' , ' ' )
302+ out_meta .write (meta + '\n ' )
303303 count += 1
304- pbar .update (incr = 1 , write_count = count )
304+ pbar .update (write_count = count )
305305 msg = f'Looks like an error. { count } segs are valid { skips } are invalid: { entry } '
306306 assert count > 0 , msg
307307 if skips > count :
308308 pbar_man .emit_log (log .WARNING , msg )
309309 pbar_man .emit_log (log .INFO , f"{ entry } : Skips : { skips :,} /{ count :,} => { 100 * skips / count :.4f} %" )
310310 if not has_meta and meta_file .exists ():
311- meta_file .unlink () # remove empty meta file; file may not be empty if its .gz compressed
311+ meta_file .unlink ()
312312 flag_file .touch ()
313313 return count , skips
314314
@@ -471,6 +471,37 @@ def _run_entries(self, func, tasks, entries, desc=None, fail_on_error=False):
471471 finally :
472472 pbar .update (force = True )
473473
474+ def _run_entries_multi (self , func_task_pairs , desc = None , fail_on_error = False ):
475+ """Run heterogeneous (func, task_dict) pairs through a single worker pool."""
476+ for _ , t in func_task_pairs :
477+ t ['fail_on_error' ] = fail_on_error
478+ total = len (func_task_pairs )
479+ if self .n_jobs == 1 :
480+ with pbar_man .counter (total = total , unit = 'it' , desc = desc ) as pbar :
481+ for func , task in func_task_pairs :
482+ self ._entry_worker (func , task )
483+ pbar .update (force = True )
484+ return
485+ import multiprocessing as mp
486+ progress_queue = mp .Queue ()
487+ with pbar_man .counter (total = total , unit = 'it' , desc = desc ) as pbar :
488+ with pbar_man .consume_remote (progress_queue ):
489+ with concurrent .futures .ProcessPoolExecutor (
490+ max_workers = self .n_jobs ,
491+ initializer = _worker_init ,
492+ initargs = (progress_queue , log .WARNING )) as executor :
493+ futures = [executor .submit (self ._entry_worker , func , task )
494+ for func , task in func_task_pairs ]
495+ for future in concurrent .futures .as_completed (futures ):
496+ try :
497+ future .result ()
498+ except Exception as e :
499+ pbar_man .emit_log (log .ERROR , f"Error in worker: { e } " )
500+ if fail_on_error :
501+ raise e
502+ finally :
503+ pbar .update (force = True )
504+
474505 @classmethod
475506 def get_paths (cls , dir_path : Path , entry : Entry , compress = False ) -> Union [Tuple [Path , Path ], Tuple [Path , Path , Path ]]:
476507 """
@@ -497,43 +528,36 @@ def add_part(self, dir_path: Path, entry: Entry, drop_noise=False, compress=Fals
497528 pbar_man .emit_log (log .INFO , f"{ flag_file } exists. Skipping" )
498529 return - 1 , - 1
499530 path = self .cache .get_entry (entry )
500- # swap = entry.is_swap(self.langs)
501531 parser = Parser (path , ext = entry .in_ext or None , ent = entry )
502- # langs = '_'.join(str(lang) for lang in self.langs)
503- # Check that files are written in correct order
504532 l1 , l2 , meta_file = self .get_paths (dir_path , entry , compress = compress )
505533 io_args = dict (encoding = 'utf-8' , errors = 'ignore' )
506- has_meta = None
507534 with pbar_man .counter (unit = 'line' , desc = f"Processing { entry .did } " ) as pbar , \
508535 IO .writer (l1 , ** io_args ) as f1 , IO .writer (l2 , ** io_args ) as f2 , IO .writer (meta_file , ** io_args ) as f3 :
509536 count , skips , noise = 0 , 0 , 0
537+ has_meta = None
510538 for rec in parser .read_segs (show_pbar = False ):
511539 if has_meta is None :
512540 has_meta = bool (len (rec ) > 2 )
513541 if len (rec ) < 2 :
514542 skips += 1
515- pbar .update (incr = 1 , write_count = count )
543+ pbar .update ()
516544 continue
517545 if drop_noise and entry .is_noisy (seg1 = rec [0 ], seg2 = rec [1 ]):
518546 skips += 1
519547 noise += 1
520- pbar .update (incr = 1 , write_count = count )
548+ pbar .update ()
521549 continue
522- sent1 , sent2 = [ s .strip () for s in rec [: 2 ]]
550+ sent1 , sent2 = rec [ 0 ] .strip (), rec [1 ]. strip ()
523551 if not sent1 or not sent2 :
524552 skips += 1
525- pbar .update (incr = 1 , write_count = count )
553+ pbar .update ()
526554 continue
527- sent1 = sent1 .replace ('\n ' , ' ' ).replace ('\t ' , ' ' ).replace ('\r ' , ' ' )
528- sent2 = sent2 .replace ('\n ' , ' ' ).replace ('\t ' , ' ' ).replace ('\r ' , ' ' )
529- f1 .write (f'{ sent1 } \n ' )
530- f2 .write (f'{ sent2 } \n ' )
555+ f1 .write (sent1 .replace ('\t ' , ' ' ).replace ('\r ' , ' ' ) + '\n ' )
556+ f2 .write (sent2 .replace ('\t ' , ' ' ).replace ('\r ' , ' ' ) + '\n ' )
531557 if has_meta :
532- assert len (rec ) >= 3 , f'Expected 3 fields, found { len (rec )} : { rec } '
533- meta = json .dumps (rec [2 ], ensure_ascii = False , indent = None )
534- f3 .write (f'{ meta } \n ' )
558+ f3 .write (json .dumps (rec [2 ], ensure_ascii = False , indent = None ) + '\n ' )
535559 count += 1
536- pbar .update (incr = 1 , write_count = count )
560+ pbar .update (write_count = count )
537561 msg = f'Looks like an error. { count } segs are valid { skips } are invalid: { entry } '
538562 assert count > 0 , msg
539563 if skips > count :
0 commit comments