Skip to content

Commit b604f77

Browse files
committed
improve processing speed. use xz and bzip subproc when available
1 parent 63d9a64 commit b604f77

4 files changed

Lines changed: 197 additions & 98 deletions

File tree

mtdata/data.py

Lines changed: 81 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,32 @@ def prepare(cls, langs, out_dir: Path, dataset_ids=Dict[str, List[DatasetId]],
146146
if dataset_ids.get('dev'):
147147
dev_entries = cls.resolve_entries(dataset_ids['dev'])
148148
dataset.add_dev_entries(dev_entries)
149-
if dataset_ids.get('train'): # this might take some time
149+
# Phase 2: Process train parts and mono entries concurrently
150+
# Bitext train parts and mono entries are independent — they write to different dirs
151+
all_tasks = [] # list of (func, task_dict) pairs
152+
train_entries = []
153+
if dataset_ids.get('train'):
150154
train_entries = cls.resolve_entries(dataset_ids['train'])
155+
for ent in train_entries:
156+
all_tasks.append((dataset.add_part, dict(
157+
dir_path=dataset.train_parts_dir, entry=ent,
158+
drop_noise=dataset.drop_train_noise, compress=compress)))
159+
mono_groups = []
160+
for key, dirpath in [('mono_train', dataset.mono_train_parts_dir),
161+
('mono_dev', dataset.mono_tests_dir),
162+
('mono_test', dataset.mono_tests_dir)]:
163+
if dataset_ids.get(key):
164+
dirpath.mkdir(exist_ok=True)
165+
entries = cls.resolve_entries(dataset_ids[key])
166+
mono_groups.append((key, entries))
167+
for ent in entries:
168+
all_tasks.append((dataset.add_mono_entry, dict(
169+
dirpath=dirpath, entry=ent, compress=compress)))
170+
if all_tasks:
171+
dataset._run_entries_multi(all_tasks, desc='Processing entries',
172+
fail_on_error=fail_on_error)
173+
# Phase 3: Merge train if requested (must happen after all train parts are written)
174+
if train_entries and merge_train:
151175
drop_hashes = None
152176
if drop_tests:
153177
pair_files = []
@@ -157,22 +181,8 @@ def prepare(cls, langs, out_dir: Path, dataset_ids=Dict[str, List[DatasetId]],
157181
p1, p2 = p2, p1 # swap
158182
pair_files.append((p1, p2))
159183
test_pair_hash, test_seg_hash = dataset.hash_all_bitexts(pair_files)
160-
drop_hashes = test_pair_hash | test_seg_hash # set union
161-
dataset.add_train_entries(train_entries, merge_train=merge_train, compress=compress,
162-
drop_hashes=drop_hashes)
163-
for key, dirpath in [('mono_train', dataset.mono_train_parts_dir),
164-
('mono_dev', dataset.mono_tests_dir),
165-
('mono_test', dataset.mono_tests_dir)]:
166-
if dataset_ids.get(key):
167-
dirpath.mkdir(exist_ok=True)
168-
entries = cls.resolve_entries(dataset_ids[key])
169-
dataset.add_mono_entries(
170-
dirpath,
171-
entries,
172-
compress=compress,
173-
desc=key.replace('_', ' '),
174-
fail_on_error=fail_on_error,
175-
)
184+
drop_hashes = test_pair_hash | test_seg_hash
185+
dataset._merge_train(train_entries, compress=compress, drop_hashes=drop_hashes)
176186

177187
# citations
178188
refs_file = out_dir / 'references.bib'
@@ -204,11 +214,8 @@ def hash_all_bitexts(self, paired_files):
204214
seg_hashes.add(hash(seg2))
205215
return paired_hashes, seg_hashes
206216

207-
def add_train_entries(self, entries, merge_train=False, compress=False, drop_hashes=None):
208-
self.add_parts(self.train_parts_dir, entries, drop_noise=self.drop_train_noise,
209-
compress=compress, desc='Training sets', fail_on_error=self.fail_on_error)
210-
if not merge_train:
211-
return
217+
def _merge_train(self, entries, compress=False, drop_hashes=None):
218+
"""Merge already-written train parts into single train files."""
212219
lang1, lang2 = self.langs
213220
# paired_files = self.find_bitext_pairs(self.train_parts_dir, lang1, lang2)
214221
paired_files = {}
@@ -274,41 +281,34 @@ def add_mono_entry(self, dirpath, entry: Entry, compress=False):
274281
cache_path = self.cache.get_entry(entry)
275282
parser = Parser(cache_path, ext=entry.in_ext or None, ent=entry)
276283
out_path, meta_file = self.get_paths(dirpath, entry, compress=compress)
277-
log.info("Writing %s to %s", entry.did, out_path)
284+
pbar_man.emit_log(log.INFO, f"Writing {entry.did} to {out_path}")
278285
io_args = dict(encoding='utf-8', errors='ignore')
279-
has_meta = None # None -> True/False on first row ;then ensure it is consistent
280286
with pbar_man.counter(unit='line', desc=f"Processing {entry.did}") as pbar, \
281287
IO.writer(out_path, **io_args) as out, IO.writer(meta_file, **io_args) as out_meta:
282-
count, skips, read_count = 0, 0, 0
288+
count, skips = 0, 0
289+
has_meta = None
283290
for row in parser.read_segs(show_pbar=False):
284-
read_count += 1
285-
if has_meta is None: # first row only
291+
if has_meta is None:
286292
has_meta = bool(isinstance(row, (list, tuple)) and len(row) > 1)
287-
sentence = row
288-
if isinstance(row, (list, tuple)):
289-
sentence = row[0] # flatten list
290-
assert isinstance(sentence, str), f'str sentence expected. found: {type(sentence)}; entry: {entry.did}'
291-
sentence = sentence and sentence.strip()
293+
sentence = row[0] if isinstance(row, (list, tuple)) else row
294+
sentence = sentence.strip().replace('\t', ' ').replace('\r', ' ') if sentence else ''
292295
if not sentence:
293296
skips += 1
294-
pbar.update(incr=1, write_count=count)
297+
pbar.update()
295298
continue
296-
sentence = sentence.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')
297-
out.write(f'{sentence}\n')
299+
out.write(sentence + '\n')
298300
if has_meta:
299-
assert len(row) >= 2, f'Expected 2 fields, found {len(row)}: {row}'
300-
meta = json.dumps(row[1], ensure_ascii=False, indent=None)
301-
meta = meta.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')
302-
out_meta.write(f'{meta}\n')
301+
meta = json.dumps(row[1], ensure_ascii=False, indent=None).replace('\t', ' ').replace('\r', ' ')
302+
out_meta.write(meta + '\n')
303303
count += 1
304-
pbar.update(incr=1, write_count=count)
304+
pbar.update(write_count=count)
305305
msg = f'Looks like an error. {count} segs are valid {skips} are invalid: {entry}'
306306
assert count > 0, msg
307307
if skips > count:
308308
pbar_man.emit_log(log.WARNING, msg)
309309
pbar_man.emit_log(log.INFO, f"{entry}: Skips : {skips:,}/{count:,} => {100 * skips / count:.4f}%")
310310
if not has_meta and meta_file.exists():
311-
meta_file.unlink() # remove empty meta file; file may not be empty if its .gz compressed
311+
meta_file.unlink()
312312
flag_file.touch()
313313
return count, skips
314314

@@ -471,6 +471,37 @@ def _run_entries(self, func, tasks, entries, desc=None, fail_on_error=False):
471471
finally:
472472
pbar.update(force=True)
473473

474+
def _run_entries_multi(self, func_task_pairs, desc=None, fail_on_error=False):
475+
"""Run heterogeneous (func, task_dict) pairs through a single worker pool."""
476+
for _, t in func_task_pairs:
477+
t['fail_on_error'] = fail_on_error
478+
total = len(func_task_pairs)
479+
if self.n_jobs == 1:
480+
with pbar_man.counter(total=total, unit='it', desc=desc) as pbar:
481+
for func, task in func_task_pairs:
482+
self._entry_worker(func, task)
483+
pbar.update(force=True)
484+
return
485+
import multiprocessing as mp
486+
progress_queue = mp.Queue()
487+
with pbar_man.counter(total=total, unit='it', desc=desc) as pbar:
488+
with pbar_man.consume_remote(progress_queue):
489+
with concurrent.futures.ProcessPoolExecutor(
490+
max_workers=self.n_jobs,
491+
initializer=_worker_init,
492+
initargs=(progress_queue, log.WARNING)) as executor:
493+
futures = [executor.submit(self._entry_worker, func, task)
494+
for func, task in func_task_pairs]
495+
for future in concurrent.futures.as_completed(futures):
496+
try:
497+
future.result()
498+
except Exception as e:
499+
pbar_man.emit_log(log.ERROR, f"Error in worker: {e}")
500+
if fail_on_error:
501+
raise e
502+
finally:
503+
pbar.update(force=True)
504+
474505
@classmethod
475506
def get_paths(cls, dir_path: Path, entry: Entry, compress=False) -> Union[Tuple[Path, Path], Tuple[Path, Path, Path]]:
476507
"""
@@ -497,43 +528,36 @@ def add_part(self, dir_path: Path, entry: Entry, drop_noise=False, compress=Fals
497528
pbar_man.emit_log(log.INFO, f"{flag_file} exists. Skipping")
498529
return -1, -1
499530
path = self.cache.get_entry(entry)
500-
# swap = entry.is_swap(self.langs)
501531
parser = Parser(path, ext=entry.in_ext or None, ent=entry)
502-
# langs = '_'.join(str(lang) for lang in self.langs)
503-
# Check that files are written in correct order
504532
l1, l2, meta_file = self.get_paths(dir_path, entry, compress=compress)
505533
io_args = dict(encoding='utf-8', errors='ignore')
506-
has_meta = None
507534
with pbar_man.counter(unit='line', desc=f"Processing {entry.did}") as pbar, \
508535
IO.writer(l1, **io_args) as f1, IO.writer(l2, **io_args) as f2, IO.writer(meta_file, **io_args) as f3:
509536
count, skips, noise = 0, 0, 0
537+
has_meta = None
510538
for rec in parser.read_segs(show_pbar=False):
511539
if has_meta is None:
512540
has_meta = bool(len(rec) > 2)
513541
if len(rec) < 2:
514542
skips += 1
515-
pbar.update(incr=1, write_count=count)
543+
pbar.update()
516544
continue
517545
if drop_noise and entry.is_noisy(seg1=rec[0], seg2=rec[1]):
518546
skips += 1
519547
noise += 1
520-
pbar.update(incr=1, write_count=count)
548+
pbar.update()
521549
continue
522-
sent1, sent2 = [s.strip() for s in rec[:2]]
550+
sent1, sent2 = rec[0].strip(), rec[1].strip()
523551
if not sent1 or not sent2:
524552
skips += 1
525-
pbar.update(incr=1, write_count=count)
553+
pbar.update()
526554
continue
527-
sent1 = sent1.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')
528-
sent2 = sent2.replace('\n', ' ').replace('\t', ' ').replace('\r', ' ')
529-
f1.write(f'{sent1}\n')
530-
f2.write(f'{sent2}\n')
555+
f1.write(sent1.replace('\t', ' ').replace('\r', ' ') + '\n')
556+
f2.write(sent2.replace('\t', ' ').replace('\r', ' ') + '\n')
531557
if has_meta:
532-
assert len(rec) >= 3, f'Expected 3 fields, found {len(rec)}: {rec}'
533-
meta = json.dumps(rec[2], ensure_ascii=False, indent=None)
534-
f3.write(f'{meta}\n')
558+
f3.write(json.dumps(rec[2], ensure_ascii=False, indent=None) + '\n')
535559
count += 1
536-
pbar.update(incr=1, write_count=count)
560+
pbar.update(write_count=count)
537561
msg = f'Looks like an error. {count} segs are valid {skips} are invalid: {entry}'
538562
assert count > 0, msg
539563
if skips > count:

mtdata/pbar.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ def counter(self, desc='', total=None, unit='it'):
123123
return
124124
self._start()
125125
task_id = self._add_task(desc, total=total, unit=unit)
126+
pbar = _RichPbar(self, task_id)
126127
try:
127-
yield _RichPbar(self, task_id)
128+
yield pbar
128129
finally:
130+
pbar.flush()
129131
self._stop(task_id)
130132

131133
@contextmanager
@@ -169,19 +171,35 @@ def _consume():
169171

170172

171173
class _RichPbar:
174+
FLUSH_INTERVAL = 1.0 # seconds; configurable via Defaults if needed
175+
172176
def __init__(self, manager, task_id):
173177
self._manager = manager
174178
self._task_id = task_id
179+
self._pending = 0
180+
self._fields = {}
181+
self._last_flush = _time.monotonic()
182+
183+
def update(self, incr=1, force=False, **kwargs):
184+
self._pending += incr
185+
if kwargs:
186+
self._fields.update(kwargs)
187+
if force or _time.monotonic() - self._last_flush >= self.FLUSH_INTERVAL:
188+
self.flush()
175189

176-
def update(self, incr=1, **kwargs):
177-
self._manager._advance(self._task_id, incr=incr, **kwargs)
190+
def flush(self):
191+
if self._pending > 0:
192+
self._manager._advance(self._task_id, incr=self._pending, **self._fields)
193+
self._pending = 0
194+
self._fields.clear()
195+
self._last_flush = _time.monotonic()
178196

179197

180198
class _RemotePbar:
181199
"""Batches progress updates and sends them via queue at most once per FLUSH_INTERVAL."""
182200
_counter = 0
183201
_lock = threading.Lock()
184-
FLUSH_INTERVAL = 0.5 # seconds
202+
FLUSH_INTERVAL = 1.0 # seconds; matches _RichPbar
185203

186204
def __init__(self, queue):
187205
with _RemotePbar._lock:
@@ -192,11 +210,11 @@ def __init__(self, queue):
192210
self._fields = {}
193211
self._last_flush = _time.monotonic()
194212

195-
def update(self, incr=1, **kwargs):
213+
def update(self, incr=1, force=False, **kwargs):
196214
self._pending += incr
197215
if kwargs:
198216
self._fields.update(kwargs)
199-
if _time.monotonic() - self._last_flush >= self.FLUSH_INTERVAL:
217+
if force or _time.monotonic() - self._last_flush >= self.FLUSH_INTERVAL:
200218
self.flush()
201219

202220
def flush(self):
@@ -208,7 +226,7 @@ def flush(self):
208226

209227

210228
class _NoopPbar:
211-
def update(self, incr=1, **kwargs):
229+
def update(self, incr=1, force=False, **kwargs):
212230
pass
213231

214232

0 commit comments

Comments
 (0)