-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrainParser.py
More file actions
243 lines (202 loc) · 9.14 KB
/
trainParser.py
File metadata and controls
243 lines (202 loc) · 9.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#-*- coding: UTF-8 -*-
import argparse
import os
#import nni
from tqdm import trange
import time
from Model.FModel import FModel
from Model.SModel import SModel
from Model.SNERModel import SNERModel
from utils import MyDataset, BucketDataLoader
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import Adam
from transformers import AutoTokenizer, XLNetModel
import logging
from torch.utils.tensorboard import SummaryWriter
from utils.util import batch_computeF1, get_useful_ones
# TensorBoard writer shared by the training loop; created at import time.
writer = SummaryWriter()
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Logger used by the script entry point to record uncaught exceptions.
logger = logging.getLogger('NER')
# Directory that save_dict() joins checkpoint file names onto.
# NOTE(review): machine-specific absolute path -- confirm before running elsewhere.
path_prefix = r'/home/mgliu/work_NER/NewNER/checkpoint'
def get_paras():
    """Parse the training hyperparameters from the command line.

    Arguments the parser does not recognise are ignored rather than treated
    as errors, because ``parse_known_args`` is used.

    Returns:
        argparse.Namespace with the attributes ``lr``, ``batch_size``,
        ``epoch``, ``dropout``, ``d_in``, ``d_hid``, ``n_layers`` and ``redo``.
    """
    # One row per option: (long flag, short flag, type, help text, default).
    option_table = (
        ('--lr', '-l', float, "lr must", 0.001),
        ('--batch_size', '-b', int, "batch_size must", 4),
        ('--epoch', '-e', int, "epoch must", 200),
        ('--dropout', '-d', float, "dropout must", 0.3),
        ('--d_in', '-i', int, "in_size must", 1024),
        ('--d_hid', '-g', int, "g_size must", 150),
        ('--n_layers', '-k', int, "kernel must", 2),
        ('--redo', '-r', int, "reload model", 0),
    )

    cli = argparse.ArgumentParser(description="hyperparameters")
    for long_flag, short_flag, caster, help_text, fallback in option_table:
        cli.add_argument(long_flag, short_flag, type=caster, help=help_text, default=fallback)

    known, _unknown = cli.parse_known_args()
    return known
def load_dict(path, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        path: Path of a checkpoint holding the keys 'model_state_dict',
            'optimizer_state_dict', 'epoch' and 'loss' (the layout written
            by save_dict). The path is used as given.
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the saved state.

    Returns:
        The same ``(model, optimizer)`` objects, updated in place.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # still be opened on a CPU-only host; load_state_dict then copies the
    # values into the tensors the model/optimizer already own.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print("load: epoch ", str(epoch) + " loss " + str(loss))
    return model, optimizer
def save_dict(model, path, optimizer, epoch, loss):
    """Write a training checkpoint under the module-level ``path_prefix``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        path: Checkpoint file name, joined onto ``path_prefix``. Callers can
            build it from the hyperparameters to tell runs apart.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch value recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
    """
    target = os.path.join(path_prefix, path)
    # torch.save does not create missing directories, so make sure the
    # destination exists before writing.
    directory = os.path.dirname(target)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        target,
    )
    print("++++Checkpoint Saved at :", target)
def run(args):
    """Train FModel on the training set and checkpoint the best epochs.

    Each epoch runs one pass over the training loader, then one evaluation
    pass over the dev loader. A checkpoint is written when the dev F1
    improves and when the dev loss improves, and per-epoch metrics are
    logged to TensorBoard through the module-level ``writer``.

    Args:
        args: Mapping of hyperparameters. Keys read here: 'epoch',
            'batch_size', 'd_in', 'd_hid', 'n_layers', 'dropout', 'redo'.
            When ``args['redo'] == 1`` the best-F1 checkpoint for the same
            (n_layers, d_hid, batch_size) is reloaded before training.
    """
    epoch = args['epoch']
    batch_size = args['batch_size']

    dataset = MyDataset(path="./utils/train/", count=2515)
    trainLoader = BucketDataLoader(dataset, batch_size, True, True)
    # NOTE(review): the dev loader wraps the same dataset object as the train
    # loader, so "dev" metrics are computed on training data -- confirm intended.
    devLoader = BucketDataLoader(dataset, batch_size, True, False)

    model = FModel(d_in=args['d_in'], d_hid=args['d_hid'],
                   d_class=len(dataset.cateDict) + 1, n_layers=args['n_layers'],
                   dropout=args['dropout']).to(device)

    # Make sure every parameter whose name contains "model" is trainable.
    for name, param in model.named_parameters():
        if "model" in name:
            param.requires_grad = True

    # Parameters whose name contains "model" get a smaller learning rate
    # (2e-5); all remaining parameters use the optimizer default (1e-4).
    # NOTE(review): args['lr'] is parsed by get_paras() but not used here.
    new_layer = ["model"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in new_layer)], "lr": 2e-5},
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in new_layer)]},
    ]
    optimizer = Adam(optimizer_grouped_parameters, lr=1e-4, weight_decay=5e-4)
    lossFunc = nn.CrossEntropyLoss(reduction='sum')

    # Checkpoint file names are derived from the hyperparameters; the
    # "200e"/"roberta" parts are fixed labels kept for compatibility with
    # checkpoints written by earlier runs.
    ckpt_stem = (str(args["n_layers"]) + 'l-' + str(args["d_hid"]) + 'h-'
                 + str(args["batch_size"]) + 'b-')
    best_f1_ckpt = ckpt_stem + "200e-roberta-finetune.pt"
    min_loss_ckpt = ckpt_stem + "200e-minloss-roberta-finetune.pt"

    if args['redo'] == 1:
        # save_dict() writes under path_prefix, so resume from the same place.
        # NOTE(review): resumes from the best-F1 checkpoint -- confirm this is
        # the one wanted rather than the min-loss checkpoint.
        model, optimizer = load_dict(os.path.join(path_prefix, best_f1_ckpt), model, optimizer)

    def timeSince(start_time):
        """Format the wall-clock time elapsed since ``start_time``."""
        minutes, seconds = divmod(time.time() - start_time, 60)
        return "{} min {} sec".format(int(minutes), int(seconds))

    def forwardBatch(passage, mask, label):
        """Move one batch to ``device``, run the model and compute the loss."""
        passage = passage.long().to(device)
        mask = mask.to(device)
        label = label.to(device)
        # A 1-D passage gets a leading dimension added, together with its mask.
        if len(passage.shape) < 2:
            passage = passage.unsqueeze(0)
            mask = mask.unsqueeze(0)
        out = model(passage, mask)
        # presumably keeps only the positions selected by mask -- confirm in utils.util
        tmp_out, tmp_label = get_useful_ones(out, label, mask)
        loss = lossFunc(tmp_out, tmp_label)
        return out, mask, label, loss

    def evalTrainer():
        """Evaluate on the dev loader; return mean (loss, F1, precision, recall) per batch."""
        epochLoss = 0.0
        Fscore = 0.0
        precision = 0.0
        recall = 0.0
        cycle = 0
        model.eval()
        # No gradients are needed for evaluation; skipping autograd saves memory.
        with torch.no_grad():
            for passage, mask, label in devLoader:
                out, mask, label, loss = forwardBatch(passage, mask, label)
                epochLoss += loss.item()
                cycle += 1
                Fscore_tmp, precision_tmp, recall_tmp = batch_computeF1(label, out, mask)
                Fscore += Fscore_tmp
                precision += precision_tmp
                recall += recall_tmp
        # Guard against an empty loader instead of dividing by zero.
        cycle = max(cycle, 1)
        return epochLoss / cycle, Fscore / cycle, precision / cycle, recall / cycle

    def trainTrainer(epoch):
        """Run ``epoch`` training epochs with per-epoch evaluation and checkpointing."""
        evalLoss = 9999
        evalFscore = 0.0
        for i in trange(epoch):
            start_time = time.time()
            model.train()
            # All running sums restart every epoch so one epoch's averages
            # never leak into the next.
            epochLoss = 0.0
            Fscore = 0.0
            precision = 0.0
            recall = 0.0
            cycle = 0
            for passage, mask, label in trainLoader:
                optimizer.zero_grad()
                out, mask, label, loss = forwardBatch(passage, mask, label)
                loss.backward()
                optimizer.step()
                epochLoss += loss.item()
                cycle += 1
                Fscore_tmp, precision_tmp, recall_tmp = batch_computeF1(label, out, mask)
                Fscore += Fscore_tmp
                precision += precision_tmp
                recall += recall_tmp
            # Guard against an empty loader instead of dividing by zero.
            cycle = max(cycle, 1)
            epochLoss = epochLoss / cycle
            Fscore = Fscore / cycle
            precision = precision / cycle
            recall = recall / cycle

            evalTrainerLoss, evalTrainerF1, evalTrainerP, evalTrainerR = evalTrainer()
            # Store the dev loss of the epoch being saved, not the previous best.
            if evalTrainerF1 > evalFscore:
                save_dict(model, best_f1_ckpt, optimizer, i, evalTrainerLoss)
            if evalTrainerLoss < evalLoss:
                save_dict(model, min_loss_ckpt, optimizer, i, evalTrainerLoss)
            evalLoss = min(evalTrainerLoss, evalLoss)
            evalFscore = max(evalFscore, evalTrainerF1)

            # The 'Accuracy/*' tags carry F1 scores; tag names are kept so
            # existing TensorBoard runs stay comparable.
            writer.add_scalar('Loss/train', epochLoss, i)
            writer.add_scalar('Loss/test', evalTrainerLoss, i)
            writer.add_scalar('Accuracy/train', Fscore, i)
            writer.add_scalar('Accuracy/test', evalTrainerF1, i)
            writer.add_scalar("Precision/train", precision, i)
            writer.add_scalar("Precision/test", evalTrainerP, i)
            writer.add_scalar("Recall/train", recall, i)
            writer.add_scalar("Recall/test", evalTrainerR, i)
            print("====Epoch: {} epoch_F: {} dev_F: {}".format(i + 1, Fscore, evalTrainerF1))
            print(" Time used: {}".format(timeSince(start_time)))
        # Push any buffered scalars to disk once training is finished.
        writer.flush()

    trainTrainer(epoch)
    # TODO(Yotta): to evaluation tasks
if __name__ == '__main__':
    # The NNI tuner hooks (nni.get_next_parameter / params.update) are
    # disabled; hyperparameters come from the command line only.
    try:
        hyper_params = vars(get_paras())
        print(hyper_params)
        run(hyper_params)
    except Exception as err:
        # Record the traceback through the 'NER' logger, then let the
        # exception propagate so the process still fails.
        logger.exception(err)
        raise