77from torch .utils import data
88from tqdm .contrib .concurrent import process_map
99
10+ try :
11+ from datasets import load_dataset
12+ except Exception : # pragma: no cover - optional dependency
13+ load_dataset = None
14+
1015
1116def get_dataset (file_path , task_class , tokenizer , parameter ):
1217 panel = nlp2 .Panel ()
@@ -18,6 +23,14 @@ def get_dataset(file_path, task_class, tokenizer, parameter):
1823 # panel.add_element(k=missarg, v=all_arg[missarg], msg=missarg, default=all_arg[missarg])
1924 # filled_arg = panel.get_result_dict()
2025 # parameter.update(filled_arg)
26+ if load_dataset is not None and not os .path .isfile (file_path ):
27+ try :
28+ hf_ds = load_dataset (file_path , split = parameter .get ('split' , 'train' ))
29+ return HFDataset (hf_ds , tokenizer = tokenizer ,
30+ preprocessor = task_class .Preprocessor ,
31+ preprocessing_arg = parameter )
32+ except Exception :
33+ pass
2134 ds = TFKitDataset (fpath = file_path , tokenizer = tokenizer ,
2235 preprocessor = task_class .Preprocessor ,
2336 preprocessing_arg = parameter )
@@ -76,3 +89,33 @@ def __getitem__(self, idx):
7689 {** {'task_dict' : self .task_dict }, ** {key : self .sample [key ][idx ] for key in self .sample .keys ()}},
7790 self .tokenizer ,
7891 maxlen = self .preprocessor .parameters ['maxlen' ])
92+
93+
class HFDataset(data.Dataset):
    """Dataset wrapper for the HuggingFace datasets library.

    Eagerly runs the task preprocessor over every record of *hf_dataset*
    and stores the resulting fields column-wise in ``self.sample``;
    ``__getitem__`` then re-assembles one example and hands it to the
    preprocessor's ``postprocess`` step.
    """

    def __init__(self, hf_dataset, tokenizer, preprocessor, preprocessing_arg=None):
        # Treat a missing/empty arg dict as an empty config (same as `or {}`).
        config = {} if not preprocessing_arg else preprocessing_arg
        self.task_dict = {}
        # Column-wise storage: one list per preprocessed field name.
        self.sample = defaultdict(list)
        self.preprocessor = preprocessor(tokenizer, kwargs=config)
        self.tokenizer = tokenizer

        print("Start preprocessing with HuggingFace dataset...")
        # A single raw record may expand into several training examples,
        # so count the expanded examples rather than the raw records.
        example_count = 0
        for record in hf_dataset:
            for example in self.preprocessor.preprocess(record):
                example_count += 1
                for field, value in example.items():
                    self.sample[field].append(value)
        self.length = example_count
        # NOTE(review): kept for parity with TFKitDataset — presumably an
        # alias some caller reads; task_dict itself is never populated here.
        self.task = self.task_dict

    def __len__(self):
        """Number of preprocessed examples (post-expansion)."""
        return self.length

    def __getitem__(self, idx):
        """Rebuild example *idx* from the column store and postprocess it."""
        item = {'task_dict': self.task_dict}
        for field in self.sample:
            item[field] = self.sample[field][idx]
        return self.preprocessor.postprocess(
            item,
            self.tokenizer,
            maxlen=self.preprocessor.parameters['maxlen'])
0 commit comments