Skip to content

Commit 69b3c1a

Browse files
authored
Merge pull request #179 from thammegowda/tg/new-data
new datasets and wmt26 recipes
2 parents 146332c + 8ecb2c3 commit 69b3c1a

9 files changed

Lines changed: 61208 additions & 97 deletions

mtdata/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
from mtdata.pbar import get_log_handler # noqa: E402
2020
log.basicConfig(level=log.INFO, datefmt='%Y%m%d %H:%M:%S',
2121
handlers=[get_log_handler()])
22+
23+
# Names of the third-party loggers adjusted by set_third_party_log_level().
_THIRD_PARTY_LOGGERS = ('httpx', 'datasets', 'huggingface_hub', 'fsspec', 'urllib3')


def set_third_party_log_level(level=log.WARNING):
    """Set the log level of every logger named in ``_THIRD_PARTY_LOGGERS``.

    :param level: logging level to apply to each of those loggers; defaults to ``log.WARNING``
    """
    for logger_name in _THIRD_PARTY_LOGGERS:
        third_party_logger = log.getLogger(logger_name)
        third_party_logger.setLevel(level)


# Applied once when this module is imported, so the loggers above start out at WARNING.
set_third_party_log_level(log.WARNING)
2230
cache_dir = Path(os.environ.get('MTDATA', '~/.mtdata')).expanduser()
2331
recipes_dir = Path(os.getenv('MTDATA_RECIPES', '.')).resolve()
2432
cached_index_file = cache_dir / f'mtdata.index.{__version__}.pkl'

mtdata/cache.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def get_hf_dataset(self, url: str, entry=None):
162162
config = entry.meta.get("config", None)
163163
split = entry.meta.get("split", None)
164164
cache_dir = self.root / 'huggingface' / 'datasets'
165+
166+
if isinstance(config, list):
167+
# Cross-config alignment: load two configs, join by a shared field
168+
return self._get_hf_cross_config(entry)
169+
165170
args = dict(
166171
name=config,
167172
split=split,
@@ -171,8 +176,48 @@ def get_hf_dataset(self, url: str, entry=None):
171176
)
172177
log.debug(f"Loading dataset {hf_id} with args: {args}")
173178
ds = load_dataset(hf_id, **args)
179+
if split is None and hasattr(ds, 'keys'):
180+
# load_dataset returns DatasetDict when split=None
181+
keys = list(ds.keys())
182+
assert len(keys) == 1, (f"Multiple splits found in {hf_id}: {keys}."
183+
f" Specify 'split' in the resource file.")
184+
ds = ds[keys[0]]
174185
return ds
175186

187+
def _get_hf_cross_config(self, entry):
    """Load two configs of one HF dataset and pair their rows on a shared join field.

    Reads from ``entry.meta``: ``orig_id`` (HF dataset id), ``config`` (the two config
    names; the first is used as the source side, the second as the target side), and
    optionally ``split``, ``join_field`` (default ``"id"``) and ``text_field``
    (default ``"text"``).

    :param entry: index entry whose ``meta`` describes the cross-config dataset
    :return: an iterable object; iterating it yields one dict per aligned row,
        ``{src_config: <text from first config>, tgt_config: <text from second config>}``.
        Rows of the first config whose join key does not occur in the second config
        are skipped. If a join key occurs several times in the second config, the
        last occurrence is the one used.
    """
    from datasets import load_dataset

    meta = entry.meta
    hf_id = meta["orig_id"]
    configs = meta["config"]
    split = meta.get("split", None)
    assert len(configs) == 2, f"Expected 2 configs for cross-config, got {configs}"
    join_field = meta.get("join_field", "id")
    text_field = meta.get("text_field", "text")
    src_config, tgt_config = configs

    cache_dir = self.root / 'huggingface' / 'datasets'
    common_args = dict(cache_dir=cache_dir, streaming=False, trust_remote_code=False)

    def load_config(config_name):
        """Load one config and return a single split's rows."""
        ds = load_dataset(hf_id, name=config_name, split=split, **common_args)
        if split is None and hasattr(ds, 'keys'):
            # load_dataset returns DatasetDict when split=None; iterating that would
            # yield split names instead of rows, so unwrap it the same way
            # get_hf_dataset does for single-config datasets.
            keys = list(ds.keys())
            assert len(keys) == 1, (f"Multiple splits found in {hf_id} [{config_name}]: {keys}."
                                    f" Specify 'split' in the resource file.")
            ds = ds[keys[0]]
        return ds

    log.debug(f"Loading cross-config: {hf_id} [{src_config}] + [{tgt_config}]")
    ds1 = load_config(src_config)
    ds2 = load_config(tgt_config)

    # Lookup of the second config's text, keyed by join field
    tgt_lookup = {row[join_field]: row[text_field] for row in ds2}

    class CrossConfigDataset:
        """Iterable wrapper that yields aligned rows from two HF configs."""

        def __iter__(self_inner):
            for row in ds1:
                key = row[join_field]
                if key in tgt_lookup:
                    # config names double as the keys of the combined row
                    yield {src_config: row[text_field], tgt_config: tgt_lookup[key]}

    return CrossConfigDataset()
220+
176221
@classmethod
177222
def match_globs(cls, names, globs, meta=''):
178223
result = []

mtdata/index/huggingface.py

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020
"p": 0,
2121
}
2222
HF_EXT = "hfds" # huggingface dataset
23+
HF_URL = "https://huggingface.co/datasets/"
24+
25+
26+
def _parse_hf_id(hf_id):
    """Split a HuggingFace dataset id of the form ``group/name`` into ``(group, name)``.

    The group is title-cased; the name is lowercased and its hyphens are replaced with
    underscores, e.g. ``"Helsinki-NLP/opus-100"`` -> ``("Helsinki-Nlp", "opus_100")``.

    :param hf_id: dataset id with exactly one ``/`` separating two non-empty parts
    :return: tuple ``(group, name)``
    :raises ValueError: if ``hf_id`` is not of the form ``group/name``
    """
    parts = hf_id.split("/")
    if len(parts) != 2 or not all(parts):
        # Bare tuple-unpacking would raise ValueError too, but without naming the bad id.
        raise ValueError(f"Invalid dataset id: {hf_id!r}; expected the form 'group/name'")
    group, name = parts
    return group.title(), name.lower().replace("-", "_")
30+
2331

2432
# To refresh the data_file from huggingface:
2533
# python -m mtdata.index.huggingface --refresh
@@ -31,35 +39,73 @@ def load_all(index: Index):
3139
with meta_file.open() as lines:
3240
for line in lines:
3341
line = line.strip()
42+
# skip empty lines and comments starting with # or //
3443
if not line or line.startswith("#") or line.startswith("//"):
3544
continue
3645
data = json.loads(line)
37-
if data['id'] != "google/wmt24pp":
38-
continue # TODO: support other datasets
39-
40-
id_parts = data['id'].split('/')
41-
assert len(id_parts) == 2, f"Invalid dataset id: {data['id']}"
42-
group, name = id_parts
43-
group = group.title()
44-
for config in data['configs']:
45-
split_name = name
46-
if config["split"] != "train":
47-
# assume train is the default
48-
split_name += f"_{config['split']}"
49-
orig_langs = config["name"]
50-
langs = tuple(orig_langs.split("-"))
51-
assert len(langs) in (1, 2), f"Invalid langs: {langs}"
52-
data_id = DatasetId(group=group, name=split_name, version="1", langs=langs)
53-
if data_id in index:
54-
log.warning(f"Duplicate dataset id: {data_id}")
55-
continue
56-
url = "https://huggingface.co/datasets/" + data['id']
57-
fields = dict(source="source", target="target", doc_id="document_id", domain="domain", seg_id="segment_id")
58-
meta = dict(config=config['name'], orig_id=data['id'], split=config["split"], fields=fields)
59-
in_paths = config["paths"]
60-
cite = None
61-
entry = Entry(did=data_id, url=url, in_paths=in_paths, cite=cite, ext=HF_EXT, in_ext=HF_EXT, meta=meta)
62-
index.add_entry(entry)
46+
_load_hf_dataset(index, data)
47+
48+
49+
def _load_hf_dataset(index, data):
    """Register the configs and splits of one HuggingFace dataset record in the index.

    ``data`` is one parsed JSON record. Keys read here:

    * ``id``: HF dataset id, split into group and base name by ``_parse_hf_id``
    * ``fields`` (optional): default field mapping; falls back to
      ``source="source", target="target"``
    * ``configs`` (optional): list of config records, each with

      - ``name``: a config name (str), or a list of exactly two config names for a
        cross-config dataset (the first becomes the source field, the second the target)
      - ``ds_name`` (optional): dataset name to use instead of the base name
      - ``splits`` (optional): list of split names; each is suffixed to the dataset name
      - ``split`` (optional): a single split, suffixed to the dataset name.
        If neither is given there is no suffix and ``split=None`` is stored in the meta
      - ``langs`` (optional): explicit languages; otherwise derived from ``name``
      - ``fields`` (optional): field mapping for this config (ignored for cross-config)
      - ``paths`` (optional): defaults to ``['default']``
      - ``text_field`` / ``join_field`` (optional, cross-config only): default to
        ``'text'`` and ``'id'``

    Configs that cannot be registered are skipped with a warning, as are dataset ids
    that are already present in the index.

    :param index: index to add the entries to
    :param data: parsed record describing one HF dataset
    """
    hf_id = data['id']
    group, base_name = _parse_hf_id(hf_id)
    url = f"{HF_URL}{hf_id}"
    default_fields = data.get('fields', dict(source="source", target="target"))

    for config in data.get('configs', []):
        config_name = config['name']
        is_cross_config = isinstance(config_name, list)
        if is_cross_config and len(config_name) != 2:
            # Exactly two names are required: one becomes the source field, the other the
            # target field. Skip here rather than fail with an IndexError below.
            log.warning(f"Skipping {hf_id}: cross-config needs exactly 2 config names,"
                        f" got {config_name}")
            continue
        ds_name = config.get('ds_name', base_name)

        # Map each split to the dataset name it is registered under.
        splits = config.get('splits')
        if splits:
            names_by_split = {s: f"{ds_name}_{s}" for s in splits}
        elif 'split' in config:
            names_by_split = {config['split']: f"{ds_name}_{config['split']}"}
        else:
            names_by_split = {None: ds_name}

        # Languages: explicit or derived from config name
        langs = config.get('langs')
        if langs is not None:
            langs = tuple(langs)
        elif isinstance(config_name, str):
            langs = tuple(config_name.split("-"))
        elif is_cross_config:
            langs = tuple(config_name)
        else:
            log.warning(f"Skipping {hf_id}: cannot derive langs from config {config_name}")
            continue

        if is_cross_config:
            # Cross-config: source/target are the config names
            fields = dict(source=config_name[0], target=config_name[1])
            cross_meta = dict(text_field=config.get('text_field', 'text'),
                              join_field=config.get('join_field', 'id'))
        else:
            fields = config.get('fields', default_fields)
            cross_meta = {}

        in_paths = config.get('paths', ['default'])

        for split, name in names_by_split.items():
            data_id = DatasetId(group=group, name=name, version='1', langs=langs)
            if data_id in index:
                log.warning(f"Duplicate dataset id: {data_id}")
                continue

            meta = dict(orig_id=hf_id, config=config_name, split=split, fields=fields,
                        **cross_meta)
            entry = Entry(did=data_id, url=url, in_paths=in_paths,
                          ext=HF_EXT, in_ext=HF_EXT, meta=meta)
            index.add_entry(entry)
63109

64110

65111
def query_datasets(page_num=0):

mtdata/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def add_getter_args(parser):
370370

371371
if args.verbose:
372372
mtdata.debug_mode = True
373+
mtdata.set_third_party_log_level(log.DEBUG)
373374
return args
374375

375376

mtdata/parser.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,17 @@ def read_tsv(self, path, delim='\t', cols=None, skipheader=False, meta_fields=No
150150
out_row.append(metadata)
151151
yield out_row
152152

153+
@staticmethod
def _nested_get(row, field):
    """Look up ``field`` in ``row``, treating each dot in ``field`` as a nesting separator.

    ``_nested_get(row, "translation.ita")`` is ``row["translation"]["ita"]``; a field
    without dots is a plain ``row[field]``. Because every dot is a separator, a key
    that itself contains a dot cannot be addressed.

    :param row: mapping (possibly holding nested mappings) to read from
    :param field: a key, or a dot-separated path of keys
    :return: the value stored at that path
    """
    value, remaining = row, field
    while True:
        key, dot, remaining = remaining.partition('.')
        value = value[key]
        if not dot:
            # no separator left: `key` was the last component of the path
            return value
163+
153164
def read_hfds(self, ds):
154165
""" Read data from huggingface Dataset
155166
:param ds: huggingface dataset
@@ -163,10 +174,11 @@ def read_hfds(self, ds):
163174
# in the current version, I am going to retain all fields to see what all fields exist,
164175
# and map the subset of fields as per the dict; so, created rev_map.get(orig,orig)
165176
for row in ds:
166-
out_row = [row.pop(src_field)]
177+
out_row = [self._nested_get(row, src_field)]
167178
if tgt_field is not None:
168-
out_row.append(row.pop(tgt_field))
179+
out_row.append(self._nested_get(row, tgt_field))
169180
# remap meta fields if necessary
170-
metadata = {rev_map.get(k, k): v for k, v in row.items() if k not in (src_field, tgt_field)}
181+
top_keys = {f.split('.')[0] for f in [src_field] + ([tgt_field] if tgt_field else [])}
182+
metadata = {rev_map.get(k, k): v for k, v in row.items() if k not in top_keys}
171183
out_row.append(metadata)
172184
yield out_row

0 commit comments

Comments
 (0)