"""
KimCNN Model for Multi-label Classification
===========================================
This step-by-step example shows how to train and test a KimCNN model via LibMultiLabel.
Import the libraries
----------------------------
Please add the following code to your Python 3 script.
"""
from libmultilabel.nn.data_utils import *
from libmultilabel.nn.nn_utils import *
######################################################################
# Setup device
# --------------------
# If you need to reproduce the results, please use the function ``set_seed``.
# For example, you will get the same results whenever you use the seed ``1337``.
#
# To initialize the hardware device, please use ``init_device`` to assign the device you want to use.
set_seed(1337)
device = init_device() # use gpu by default
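######################################################################
# The effect of fixing a seed can be sketched with Python's built-in
# ``random`` module (a toy stand-in only; ``set_seed`` seeds the random
# number generators used by the training framework itself).

import random

def seeded_draws(seed, n=3):
    # Draw ``n`` pseudo-random numbers after fixing the seed.
    rng = random.Random(seed)
    return [rng.random() for _ in range(n)]

# Two runs with the same seed produce identical numbers.
assert seeded_draws(1337) == seeded_draws(1337)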
######################################################################
# Load and tokenize data
# ------------------------------------------
#
# To run KimCNN, LibMultiLabel tokenizes documents and uses an embedding vector for each word.
# Thus, ``tokenize_text=True`` is set.
#
# We choose ``glove.6B.300d`` from torchtext as embedding vectors.
datasets = load_datasets("data/rcv1/train.txt", "data/rcv1/test.txt", tokenize_text=True)
classes = load_or_build_label(datasets)
word_dict, embed_vecs = load_or_build_text_dict(dataset=datasets["train"], embed_file="glove.6B.300d")
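######################################################################
# Conceptually, ``word_dict`` maps tokens to integer indices, with a
# special index for out-of-vocabulary words. The toy example below is a
# schematic of this lookup, not the library's actual implementation.

toy_vocab = {"<unk>": 0, "interest": 1, "rates": 2, "rise": 3}

def to_ids(tokens, vocab):
    # Map each token to its index, falling back to ``<unk>``.
    return [vocab.get(tok, vocab["<unk>"]) for tok in tokens]

print(to_ids(["interest", "rates", "fall"], toy_vocab))  # [1, 2, 0]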
######################################################################
# Initialize a model
# --------------------------
#
# We consider the following settings for the KimCNN model.
model_name = "KimCNN"
network_config = {"embed_dropout": 0.2, "post_encoder_dropout": 0.2, "filter_sizes": [2, 4, 8], "num_filter_per_size": 128}
learning_rate = 0.0003
model = init_model(
    model_name=model_name,
    network_config=network_config,
    classes=classes,
    word_dict=word_dict,
    embed_vecs=embed_vecs,
    learning_rate=learning_rate,
    monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
)
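######################################################################
# Assuming the standard KimCNN design, each filter contributes one
# max-pooled feature, so the document representation fed to the output
# layer has ``len(filter_sizes) * num_filter_per_size`` dimensions.

filter_sizes = [2, 4, 8]  # as in ``network_config`` above
num_filter_per_size = 128
feature_dim = len(filter_sizes) * num_filter_per_size
print(feature_dim)  # 384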
######################################################################
# * ``model_name`` tells the ``init_model`` function which network architecture to build.
# * ``network_config`` contains the configuration of the network.
# * ``classes`` is the label set of the data.
# * ``word_dict`` and ``embed_vecs`` supply the vocabulary and the pretrained word embeddings used by KimCNN.
# * ``monitor_metrics`` includes the metrics you would like to track.
#
#
# Initialize a trainer
# ----------------------------
#
# We use the function ``init_trainer`` to initialize a trainer.
trainer = init_trainer(checkpoint_dir="runs/NN-example", epochs=15, val_metric="P@5")
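######################################################################
# The validation metric ``P@5`` (precision at 5) can be sketched as
# follows: among the five labels with the highest predicted scores, the
# fraction that are true labels. The helper below is an illustration,
# not the library's implementation.

def precision_at_k(scores, true_labels, k=5):
    # Take the k labels with the highest scores and count hits.
    top_k = sorted(scores, key=scores.get, reverse=True)[:k]
    return sum(label in true_labels for label in top_k) / k

toy_scores = {"A": 0.9, "B": 0.8, "C": 0.7, "D": 0.6, "E": 0.5, "F": 0.4}
print(precision_at_k(toy_scores, {"A", "C", "F"}))  # 0.4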
######################################################################
# In this example, ``checkpoint_dir`` is where the best and the last models are saved during training. We also set the number of training epochs with ``epochs=15`` and the validation metric with ``val_metric='P@5'``.
#
# Create data loaders
# ---------------------------
#
# In most cases, the full dataset cannot be loaded at once due to hardware limitations.
# Therefore, a data loader loads a batch of samples at a time.
loaders = dict()
for split in ["train", "val", "test"]:
    loaders[split] = get_dataset_loader(
        data=datasets[split],
        classes=classes,
        device=device,
        max_seq_length=512,
        batch_size=8,
        shuffle=(split == "train"),
        word_dict=word_dict,
    )
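######################################################################
# The batching a data loader performs can be sketched as slicing the
# dataset into chunks of ``batch_size`` (with shuffling beforehand for
# the training split). A toy illustration:

def batches(samples, batch_size):
    # Yield consecutive chunks of at most ``batch_size`` samples.
    for i in range(0, len(samples), batch_size):
        yield samples[i : i + batch_size]

print([len(b) for b in batches(list(range(20)), 8)])  # [8, 8, 4]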
######################################################################
# This example creates three loaders, with the batch size set by ``batch_size=8``. The other parameters are documented `here <../api/nn.html#libmultilabel.nn.data_utils.get_dataset_loader>`_.
#
# Train and test a model
# ------------------------------
#
# The KimCNN model training process can be started via
trainer.fit(model, loaders["train"], loaders["val"])
######################################################################
# After the training process is finished, we can then run the test process by
trainer.test(model, dataloaders=loaders["test"])
######################################################################
# The test results should be similar to::
#
#     {
#         'Macro-F1': 0.48948464335831743,
#         'Micro-F1': 0.7769773602485657,
#         'P@1': 0.9471677541732788,
#         'P@3': 0.7772253751754761,
#         'P@5': 0.5449321269989014,
#     }