diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py index beed3100e..e710bb60c 100644 --- a/libmultilabel/nn/data_utils.py +++ b/libmultilabel/nn/data_utils.py @@ -1,8 +1,10 @@ import csv import gc import logging +import multiprocessing as mp +import os import warnings -from concurrent.futures import ProcessPoolExecutor +from math import sqrt, floor import pandas as pd import torch @@ -176,9 +178,17 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data data["label"] = data["label"].astype(str).map(lambda s: s.split()) if tokenize_text: - # multiprocessing requires serializable objects - with ProcessPoolExecutor() as executor: - data["text"] = pd.Series(tqdm(executor.map(tokenize, data["text"]), total=len(data["text"]))) + # fork is the fastest start method + start_method = "fork" + if start_method in mp.get_all_start_methods(): + cpu_count = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count() + processes = floor(sqrt(cpu_count)) + with mp.get_context(start_method).Pool(processes=processes) as p: + # imap has worse performance compared to map + # tqdm should not be used as map blocks the main process + data["text"] = pd.Series(p.map(tokenize, data["text"])) + else: + data["text"] = data["text"].map(tokenize) data = data.to_dict("records") if not is_test: num_no_label_data = sum(1 for d in data if len(d["label"]) == 0)