Skip to content

Commit a7ec069

Browse files
authored
Merge pull request #353 from donglihe-hub/data_utils
Enhance data_utils
2 parents 290734f + f00da0a commit a7ec069

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

libmultilabel/nn/data_utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
159159
This is effective only when is_test=False. Defaults to False.
160160
161161
Returns:
162-
pandas.DataFrame: Data composed of index, label, and tokenized text.
162+
dict: [{(optional: "index": ..., )"label": ..., "text": ...}, ...]
163163
"""
164164
assert isinstance(data, str) or isinstance(data, pd.DataFrame), "Data must be from a file or pandas dataframe."
165165
if isinstance(data, str):
@@ -222,15 +222,12 @@ def load_datasets(
222222
Returns:
223223
dict: A dictionary of datasets.
224224
"""
225-
if isinstance(training_data, str) or isinstance(test_data, str):
226-
assert training_data or test_data, "At least one of `training_data` and `test_data` must be specified."
227-
elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame):
228-
assert (
229-
not training_data.empty or not test_data.empty
230-
), "At least one of `training_data` and `test_data` must be specified."
225+
if training_data is None and test_data is None:
226+
raise ValueError("At least one of `training_data` and `test_data` must be specified.")
231227

232228
datasets = {}
233229
if training_data is not None:
230+
logging.info(f"Loading training data")
234231
datasets["train"] = _load_raw_data(
235232
training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
236233
)
@@ -243,11 +240,12 @@ def load_datasets(
243240
datasets["train"], datasets["val"] = train_test_split(datasets["train"], test_size=val_size, random_state=42)
244241

245242
if test_data is not None:
243+
logging.info(f"Loading test data")
246244
datasets["test"] = _load_raw_data(
247245
test_data, is_test=True, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data
248246
)
249247

250-
if merge_train_val:
248+
if merge_train_val and "val" in datasets:
251249
datasets["train"] = datasets["train"] + datasets["val"]
252250
for i in range(len(datasets["train"])):
253251
datasets["train"][i]["index"] = i

0 commit comments

Comments (0)