-
Notifications
You must be signed in to change notification settings - Fork 28
Expand file tree
/
Copy pathlinear_trainer.py
More file actions
114 lines (102 loc) · 4.67 KB
/
linear_trainer.py
File metadata and controls
114 lines (102 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import logging
from math import ceil
import numpy as np
from tqdm import tqdm
import libmultilabel.linear as linear
from libmultilabel.common_utils import dump_log, is_multiclass_dataset
from libmultilabel.linear.utils import LINEAR_TECHNIQUES
def linear_test(config, model, datasets, label_mapping):
metrics = linear.get_metrics(
config.monitor_metrics,
datasets["test"]["y"].shape[1],
multiclass=not model.is_multilabel,
)
num_instance = datasets["test"]["x"].shape[0]
k = config.save_k_predictions
if k > 0:
labels = np.zeros((num_instance, k), dtype=object)
scores = np.zeros((num_instance, k), dtype="d")
else:
labels = []
scores = []
for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
slice = np.s_[i * config.eval_batch_size : (i + 1) * config.eval_batch_size]
preds = linear.predict_values(model, datasets["test"]["x"][slice])
target = datasets["test"]["y"][slice].toarray()
metrics.update(preds, target)
if k > 0:
labels[slice], scores[slice] = linear.get_topk_labels(preds, label_mapping, config.save_k_predictions)
elif config.save_positive_predictions:
res = linear.get_positive_labels(preds, label_mapping)
labels.append(res[0])
scores.append(res[1])
metric_dict = metrics.compute()
return metric_dict, labels, scores
def linear_train(datasets, config):
# detect task type
is_multilabel = config.get("is_multilabel", "auto")
if is_multilabel == "auto":
is_multilabel = not is_multiclass_dataset(datasets["train"], "y")
elif not isinstance(is_multilabel, bool):
raise ValueError(
f'"is_multilabel" is expected to be either "auto", "True", or "False". But got "{is_multilabel}".'
)
# train
if config.linear_technique == "tree":
if not is_multilabel:
raise ValueError("Tree model should only be used with multilabel datasets.")
model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
datasets["train"]["x"],
config.liblinear_options,
config.tree_degree,
config.tree_max_depth,
)
else:
model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
datasets["train"]["x"],
is_multilabel=is_multilabel,
options=config.liblinear_options,
)
return model
def linear_run(config):
if config.seed is not None:
np.random.seed(config.seed)
if config.eval:
preprocessor, model = linear.load_pipeline(config.checkpoint_path)
datasets = linear.load_dataset(config.data_format, config.training_file, config.test_file)
datasets = preprocessor.transform(datasets)
else:
preprocessor = linear.Preprocessor(config.include_test_labels, config.remove_no_label_data)
datasets = linear.load_dataset(
config.data_format,
config.training_file,
config.test_file,
config.label_file,
)
datasets = preprocessor.fit_transform(datasets)
model = linear_train(datasets, config)
linear.save_pipeline(config.checkpoint_dir, preprocessor, model)
if config.test_file is not None:
assert not (
config.save_positive_predictions and config.save_k_predictions > 0
), """
If save_k_predictions is larger than 0, only top k labels are saved.
Save all labels with decision value larger than 0 by using save_positive_predictions and save_k_predictions=0."""
metric_dict, labels, scores = linear_test(config, model, datasets, preprocessor.label_mapping)
dump_log(config=config, metrics=metric_dict, split="test", log_path=config.log_path)
print(linear.tabulate_metrics(metric_dict, "test"))
if config.save_k_predictions > 0:
with open(config.predict_out_path, "w") as fp:
for label, score in zip(labels, scores):
out_str = " ".join([f"{i}:{s:.4}" for i, s in zip(label, score)])
fp.write(out_str + "\n")
logging.info(f"Saved predictions to: {config.predict_out_path}")
elif config.save_positive_predictions:
with open(config.predict_out_path, "w") as fp:
for batch_labels, batch_scores in zip(labels, scores):
for label, score in zip(batch_labels, batch_scores):
out_str = " ".join([f"{i}:{s:.4}" for i, s in zip(label, score)])
fp.write(out_str + "\n")
logging.info(f"Saved predictions to: {config.predict_out_path}")