Skip to content

Commit 35e90c6

Browse files
authored
Merge pull request #19 from voidful/codex/refactor-code-with-modem-framework
Add optional HuggingFace dataset support
2 parents 99e886b + 8d5d801 commit 35e90c6

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

tfkit/utility/dataset.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from torch.utils import data
88
from tqdm.contrib.concurrent import process_map
99

10+
# Optional dependency: the HuggingFace ``datasets`` library lets callers pass a
# hub dataset name instead of a local file path.  When it is missing (or fails
# to import for any reason), ``load_dataset`` is set to ``None`` and the code
# silently falls back to file-based loading.
try:
    from datasets import load_dataset
except Exception:  # pragma: no cover - optional dependency
    load_dataset = None
1015

1116
def get_dataset(file_path, task_class, tokenizer, parameter):
1217
panel = nlp2.Panel()
@@ -18,6 +23,14 @@ def get_dataset(file_path, task_class, tokenizer, parameter):
1823
# panel.add_element(k=missarg, v=all_arg[missarg], msg=missarg, default=all_arg[missarg])
1924
# filled_arg = panel.get_result_dict()
2025
# parameter.update(filled_arg)
26+
if load_dataset is not None and not os.path.isfile(file_path):
27+
try:
28+
hf_ds = load_dataset(file_path, split=parameter.get('split', 'train'))
29+
return HFDataset(hf_ds, tokenizer=tokenizer,
30+
preprocessor=task_class.Preprocessor,
31+
preprocessing_arg=parameter)
32+
except Exception:
33+
pass
2134
ds = TFKitDataset(fpath=file_path, tokenizer=tokenizer,
2235
preprocessor=task_class.Preprocessor,
2336
preprocessing_arg=parameter)
@@ -76,3 +89,33 @@ def __getitem__(self, idx):
7689
{**{'task_dict': self.task_dict}, **{key: self.sample[key][idx] for key in self.sample.keys()}},
7790
self.tokenizer,
7891
maxlen=self.preprocessor.parameters['maxlen'])
92+
93+
94+
class HFDataset(data.Dataset):
    """Map-style dataset backed by a HuggingFace ``datasets`` object.

    Every raw record from ``hf_dataset`` is expanded through
    ``preprocessor.preprocess`` into zero or more feature dicts, whose values
    are accumulated column-wise in ``self.sample``.  ``__getitem__`` rebuilds
    one row from those columns and feeds it to ``preprocessor.postprocess``.
    """

    def __init__(self, hf_dataset, tokenizer, preprocessor, preprocessing_arg=None):
        # Normalise the optional argument dict without the mutable-default trap.
        arg = preprocessing_arg or {}
        self.task_dict = {}
        self.sample = defaultdict(list)
        self.preprocessor = preprocessor(tokenizer, kwargs=arg)
        self.tokenizer = tokenizer

        print("Start preprocessing with HuggingFace dataset...")
        total = 0
        for record in hf_dataset:
            for feature in self.preprocessor.preprocess(record):
                total += 1
                for key, value in feature.items():
                    self.sample[key].append(value)
        self.length = total
        # Kept for attribute parity with TFKitDataset.
        self.task = self.task_dict

    def __len__(self):
        # Number of preprocessed samples (may differ from len(hf_dataset)).
        return self.length

    def __getitem__(self, idx):
        # task_dict goes in first so that a same-named sample column, if any,
        # takes precedence — identical merge order to TFKitDataset.
        row = {'task_dict': self.task_dict}
        row.update({key: column[idx] for key, column in self.sample.items()})
        return self.preprocessor.postprocess(
            row,
            self.tokenizer,
            maxlen=self.preprocessor.parameters['maxlen'])

0 commit comments

Comments (0)