2020 "p" : 0 ,
2121}
2222HF_EXT = "hfds" # huggingface dataset
# Base URL of the Hugging Face dataset hub; a dataset id is appended to it to build an entry's URL.
23+ HF_URL = "https://huggingface.co/datasets/"
24+
25+
26+ def _parse_hf_id (hf_id ):
27+ """Split hf_id into (group, name). Group is title-cased, name is lowercased with hyphens replaced."""
28+ group , name = hf_id .split ("/" )
29+ return group .title (), name .lower ().replace ("-" , "_" )
30+
2331
2432# To refresh the data_file from huggingface:
2533# python -m mtdata.index.huggingface --refresh
@@ -31,35 +39,73 @@ def load_all(index: Index):
3139 with meta_file .open () as lines :
3240 for line in lines :
3341 line = line .strip ()
42+ # skip empty lines and comments starting with # or //
3443 if not line or line .startswith ("#" ) or line .startswith ("//" ):
3544 continue
3645 data = json .loads (line )
37- if data ['id' ] != "google/wmt24pp" :
38- continue # TODO: support other datasets
39-
40- id_parts = data ['id' ].split ('/' )
41- assert len (id_parts ) == 2 , f"Invalid dataset id: { data ['id' ]} "
42- group , name = id_parts
43- group = group .title ()
44- for config in data ['configs' ]:
45- split_name = name
46- if config ["split" ] != "train" :
47- # assume train is the default
48- split_name += f"_{ config ['split' ]} "
49- orig_langs = config ["name" ]
50- langs = tuple (orig_langs .split ("-" ))
51- assert len (langs ) in (1 , 2 ), f"Invalid langs: { langs } "
52- data_id = DatasetId (group = group , name = split_name , version = "1" , langs = langs )
53- if data_id in index :
54- log .warning (f"Duplicate dataset id: { data_id } " )
55- continue
56- url = "https://huggingface.co/datasets/" + data ['id' ]
57- fields = dict (source = "source" , target = "target" , doc_id = "document_id" , domain = "domain" , seg_id = "segment_id" )
58- meta = dict (config = config ['name' ], orig_id = data ['id' ], split = config ["split" ], fields = fields )
59- in_paths = config ["paths" ]
60- cite = None
61- entry = Entry (did = data_id , url = url , in_paths = in_paths , cite = cite , ext = HF_EXT , in_ext = HF_EXT , meta = meta )
62- index .add_entry (entry )
46+ _load_hf_dataset (index , data )
47+
48+
def _load_hf_dataset(index, data):
    """Register the configs and splits of one Hugging Face dataset record in ``index``.

    Keys read from ``data``: ``id`` (required), ``fields`` (optional default
    field mapping) and ``configs`` (optional list of config dicts).

    Keys read from each config:
      * ``name`` (required): a string, or a list for a cross-config entry whose
        first two items are used as the source and target field names.
      * ``ds_name``: base dataset name; defaults to the name parsed from ``id``.
      * ``splits`` / ``split``: split names to register, each suffixed to the
        dataset name. If neither is given, one entry is registered with no
        suffix and ``split=None`` in its meta.
      * ``langs``: explicit languages; otherwise derived from ``name``.
      * ``fields``: field mapping for string-named configs; defaults to the
        record-level mapping.
      * ``paths``: input paths; defaults to ``['default']``.
      * ``text_field`` / ``join_field``: cross-config only; default to
        ``'text'`` and ``'id'``.

    Configs whose languages cannot be derived, or whose cross-config name has
    fewer than two items, are skipped with a warning. Dataset ids already
    present in ``index`` are skipped with a warning.

    Args:
        index: object supporting ``in`` membership tests on dataset ids and an
            ``add_entry`` method.
        data: one parsed dataset record (a dict).
    """
    hf_id = data['id']
    group, base_name = _parse_hf_id(hf_id)
    url = f"{HF_URL}{hf_id}"
    default_fields = data.get('fields', dict(source="source", target="target"))

    for config in data.get('configs', []):
        config_name = config['name']
        is_cross_config = isinstance(config_name, list)
        if is_cross_config and len(config_name) < 2:
            # Source/target are taken from the first two names; fewer would
            # otherwise fail with a bare IndexError below.
            log.warning(f"Skipping {hf_id}: cross-config entry needs two config names, got {config_name}")
            continue

        ds_name = config.get('ds_name', base_name)

        # splits: list of split names to register. Each gets suffixed to ds_name.
        # split: single split. If present, suffixed to ds_name.
        # If neither is given, no suffix, and split is recorded as None.
        splits = config.get('splits')
        if splits:
            names_by_split = {s: f"{ds_name}_{s}" for s in splits}
        elif 'split' in config:
            names_by_split = {config['split']: f"{ds_name}_{config['split']}"}
        else:
            names_by_split = {None: ds_name}

        # Languages: explicit or derived from config name
        langs = config.get('langs')
        if langs is None:
            if isinstance(config_name, str):
                langs = tuple(config_name.split("-"))
            elif is_cross_config:
                langs = tuple(config_name)
            else:
                log.warning(f"Skipping {hf_id}: cannot derive langs from config {config_name}")
                continue
        elif isinstance(langs, str):
            # tuple("en-de") would yield single characters; split the string
            # the same way a string config name is split.
            langs = tuple(langs.split("-"))
        else:
            langs = tuple(langs)

        # Fields and any cross-config extras are the same for every split of
        # this config, so compute them once here.
        if is_cross_config:
            # Cross-config: source/target are the config names
            fields = dict(source=config_name[0], target=config_name[1])
            extra_meta = dict(text_field=config.get('text_field', 'text'),
                              join_field=config.get('join_field', 'id'))
        else:
            fields = config.get('fields', default_fields)
            extra_meta = {}

        in_paths = config.get('paths', ['default'])

        for split, name in names_by_split.items():
            data_id = DatasetId(group=group, name=name, version='1', langs=langs)
            if data_id in index:
                log.warning(f"Duplicate dataset id: {data_id}")
                continue

            # A fresh meta dict per entry so entries never share this container.
            meta = dict(orig_id=hf_id, config=config_name, split=split, fields=fields,
                        **extra_meta)
            entry = Entry(did=data_id, url=url, in_paths=in_paths,
                          ext=HF_EXT, in_ext=HF_EXT, meta=meta)
            index.add_entry(entry)
63109
64110
65111def query_datasets (page_num = 0 ):
0 commit comments