Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mtdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
from mtdata.pbar import get_log_handler # noqa: E402
log.basicConfig(level=log.INFO, datefmt='%Y%m%d %H:%M:%S',
handlers=[get_log_handler()])

_THIRD_PARTY_LOGGERS = ('httpx', 'datasets', 'huggingface_hub', 'fsspec', 'urllib3')

def set_third_party_log_level(level=log.WARNING):
for name in _THIRD_PARTY_LOGGERS:
log.getLogger(name).setLevel(level)

set_third_party_log_level(log.WARNING)
cache_dir = Path(os.environ.get('MTDATA', '~/.mtdata')).expanduser()
recipes_dir = Path(os.getenv('MTDATA_RECIPES', '.')).resolve()
cached_index_file = cache_dir / f'mtdata.index.{__version__}.pkl'
Expand Down
45 changes: 45 additions & 0 deletions mtdata/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def get_hf_dataset(self, url: str, entry=None):
config = entry.meta.get("config", None)
split = entry.meta.get("split", None)
cache_dir = self.root / 'huggingface' / 'datasets'

if isinstance(config, list):
# Cross-config alignment: load two configs, join by a shared field
return self._get_hf_cross_config(entry)

args = dict(
name=config,
split=split,
Expand All @@ -171,8 +176,48 @@ def get_hf_dataset(self, url: str, entry=None):
)
log.debug(f"Loading dataset {hf_id} with args: {args}")
ds = load_dataset(hf_id, **args)
if split is None and hasattr(ds, 'keys'):
# load_dataset returns DatasetDict when split=None
keys = list(ds.keys())
assert len(keys) == 1, (f"Multiple splits found in {hf_id}: {keys}."
f" Specify 'split' in the resource file.")
ds = ds[keys[0]]
return ds

def _get_hf_cross_config(self, entry):
"""Load two HF configs and align rows by a join field, yielding combined dicts."""
from datasets import load_dataset
hf_id = entry.meta["orig_id"]
configs = entry.meta["config"]
split = entry.meta.get("split", None)
assert len(configs) == 2, f"Expected 2 configs for cross-config, got {configs}"
join_field = entry.meta.get("join_field", "id")
src_config, tgt_config = configs

cache_dir = self.root / 'huggingface' / 'datasets'
common_args = dict(cache_dir=cache_dir, streaming=False, trust_remote_code=False)
log.debug(f"Loading cross-config: {hf_id} [{src_config}] + [{tgt_config}]")
ds1 = load_dataset(hf_id, name=src_config, split=split, **common_args)
ds2 = load_dataset(hf_id, name=tgt_config, split=split, **common_args)

# Build lookup from second config, keyed by join field
tgt_lookup = {}
text_field = entry.meta.get("text_field", "text")
for row in ds2:
key = row[join_field]
tgt_lookup[key] = row[text_field]

# Yield aligned rows as dicts with config names as keys
class CrossConfigDataset:
"""Iterable wrapper that yields aligned rows from two HF configs."""
def __iter__(self_inner):
for row in ds1:
key = row[join_field]
if key in tgt_lookup:
yield {src_config: row[text_field], tgt_config: tgt_lookup[key]}

return CrossConfigDataset()

@classmethod
def match_globs(cls, names, globs, meta=''):
result = []
Expand Down
98 changes: 72 additions & 26 deletions mtdata/index/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
"p": 0,
}
HF_EXT = "hfds" # huggingface dataset
HF_URL = "https://huggingface.co/datasets/"


def _parse_hf_id(hf_id):
"""Split hf_id into (group, name). Group is title-cased, name is lowercased with hyphens replaced."""
group, name = hf_id.split("/")
return group.title(), name.lower().replace("-", "_")


# To refresh the data_file from huggingface:
# python -m mtdata.index.huggingface --refresh
Expand All @@ -31,35 +39,73 @@ def load_all(index: Index):
with meta_file.open() as lines:
for line in lines:
line = line.strip()
# skip empty lines and comments starting with # or //
if not line or line.startswith("#") or line.startswith("//"):
continue
data = json.loads(line)
if data['id'] != "google/wmt24pp":
continue # TODO: support other datasets

id_parts = data['id'].split('/')
assert len(id_parts) == 2, f"Invalid dataset id: {data['id']}"
group, name = id_parts
group = group.title()
for config in data['configs']:
split_name = name
if config["split"] != "train":
# assume train is the default
split_name += f"_{config['split']}"
orig_langs = config["name"]
langs = tuple(orig_langs.split("-"))
assert len(langs) in (1, 2), f"Invalid langs: {langs}"
data_id = DatasetId(group=group, name=split_name, version="1", langs=langs)
if data_id in index:
log.warning(f"Duplicate dataset id: {data_id}")
continue
url = "https://huggingface.co/datasets/" + data['id']
fields = dict(source="source", target="target", doc_id="document_id", domain="domain", seg_id="segment_id")
meta = dict(config=config['name'], orig_id=data['id'], split=config["split"], fields=fields)
in_paths = config["paths"]
cite = None
entry = Entry(did=data_id, url=url, in_paths=in_paths, cite=cite, ext=HF_EXT, in_ext=HF_EXT, meta=meta)
index.add_entry(entry)
_load_hf_dataset(index, data)


def _load_hf_dataset(index, data):
hf_id = data['id']
group, base_name = _parse_hf_id(hf_id)
url = f"{HF_URL}{hf_id}"
default_fields = data.get('fields', dict(source="source", target="target"))

for config in data.get('configs', []):
config_name = config['name']
ds_name = config.get('ds_name', base_name)

# splits: list of split names to register. Each gets suffixed to ds_name.
# split: single split. If present, suffixed to ds_name.
# If neither is given, no suffix, and split passed as None to HF API (loads default).
splits = config.get('splits')
if splits:
split_suffix = {s: f"{ds_name}_{s}" for s in splits}
elif 'split' in config:
split_suffix = {config['split']: f"{ds_name}_{config['split']}"}
else:
split_suffix = {None: ds_name}

# Languages: explicit or derived from config name
langs = config.get('langs')
if langs is None:
if isinstance(config_name, str):
langs = tuple(config_name.split("-"))
elif isinstance(config_name, list):
langs = tuple(config_name)
else:
log.warning(f"Skipping {hf_id}: cannot derive langs from config {config_name}")
continue
else:
langs = tuple(langs)

# Fields
if isinstance(config_name, list):
# Cross-config: source/target are the config names
fields = dict(source=config_name[0], target=config_name[1])
else:
fields = config.get('fields', default_fields)

in_paths = config.get('paths', ['default'])

for split, name in split_suffix.items():

if isinstance(config_name, list):
meta = dict(orig_id=hf_id, config=config_name, split=split, fields=fields,
text_field=config.get('text_field', 'text'),
join_field=config.get('join_field', 'id'))
else:
meta = dict(orig_id=hf_id, config=config_name, split=split, fields=fields)

data_id = DatasetId(group=group, name=name, version='1', langs=langs)
if data_id in index:
log.warning(f"Duplicate dataset id: {data_id}")
continue

entry = Entry(did=data_id, url=url, in_paths=in_paths,
ext=HF_EXT, in_ext=HF_EXT, meta=meta)
index.add_entry(entry)


def query_datasets(page_num=0):
Expand Down
1 change: 1 addition & 0 deletions mtdata/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def add_getter_args(parser):

if args.verbose:
mtdata.debug_mode = True
mtdata.set_third_party_log_level(log.DEBUG)
return args


Expand Down
18 changes: 15 additions & 3 deletions mtdata/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def read_tsv(self, path, delim='\t', cols=None, skipheader=False, meta_fields=No
out_row.append(metadata)
yield out_row

@staticmethod
def _nested_get(row, field):
"""Get a value from a dict using dot-separated path for nested access.
e.g. _nested_get(row, "translation.ita") == row["translation"]["ita"]
"""
parts = field.split('.')
val = row
for part in parts:
val = val[part]
return val

def read_hfds(self, ds):
""" Read data from huggingface Dataset
:param ds: huggingface dataset
Expand All @@ -163,10 +174,11 @@ def read_hfds(self, ds):
# in the current version, I am going to retain all fields to see what all fields exist,
# and map the subset of fields as per the dict; so, created rev_map.get(orig,orig)
for row in ds:
out_row = [row.pop(src_field)]
out_row = [self._nested_get(row, src_field)]
if tgt_field is not None:
out_row.append(row.pop(tgt_field))
out_row.append(self._nested_get(row, tgt_field))
# remap meta fields if necessary
metadata = {rev_map.get(k, k): v for k, v in row.items() if k not in (src_field, tgt_field)}
top_keys = {f.split('.')[0] for f in [src_field] + ([tgt_field] if tgt_field else [])}
metadata = {rev_map.get(k, k): v for k, v in row.items() if k not in top_keys}
out_row.append(metadata)
yield out_row
Loading
Loading