forked from igeti/ner_task
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
74 lines (66 loc) · 2.81 KB
/
model.py
File metadata and controls
74 lines (66 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel
class NERModel(nn.Module):
    """Token-classification head on top of a pretrained transformer encoder.

    Emits per-token logits flattened to shape ``(batch * seq_len, num_class)``
    and, when labels are supplied, a cross-entropy loss that skips positions
    labelled ``-1``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # The config is loaded only to read hidden_size for the linear head;
        # the encoder weights themselves come from from_pretrained below.
        config = AutoConfig.from_pretrained(args.model_type, num_labels=args.num_class)
        self.model = AutoModel.from_pretrained(args.model_type)
        self.dropout = nn.Dropout(args.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, args.num_class)
        # ignore_index=-1 masks padding / special-token positions in the loss.
        self.loss_fnt = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(logits,)`` or ``(loss, logits)``; logits are (B*T, C)."""
        hidden, *_ = self.model(input_ids, attention_mask, return_dict=False)
        hidden = self.dropout(hidden)
        flat_logits = self.classifier(hidden).view(-1, self.args.num_class)
        result = (flat_logits,)
        if labels is not None:
            loss = self.loss_fnt(flat_logits, labels.view(-1))
            result = (loss,) + result
        return result
def kl_div(p, q):
    """KL divergence KL(p || q) summed over the last dimension.

    An epsilon of 1e-5 is added inside each log for numerical safety;
    inputs are expected to be probability tensors over the last axis.
    """
    eps = 1e-5
    log_p = (p + eps).log()
    log_q = (q + eps).log()
    return (p * (log_p - log_q)).sum(dim=-1)
class ModelConstruct(nn.Module):
    """Co-regularized ensemble of two NERModel instances.

    At inference only the first sub-model is used. During training every
    sub-model runs on the batch; the returned loss is the mean supervised
    cross-entropy across models plus ``args.alpha_t`` times a KL agreement
    term that pulls each model's per-token distribution towards the ensemble
    average (positions labelled ``-1`` are masked out of the regularizer).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.device = args.device
        self.loss_fnt = nn.CrossEntropyLoss()
        self.models = nn.ModuleList()
        # Two identically-configured sub-models whose disagreement is
        # penalized during training (loop index was unused; ensemble size
        # is fixed at 2 as in the original).
        for _ in range(2):
            model = NERModel(args)
            model.to(self.device)
            self.models.append(model)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(logits,)`` at inference or ``(loss, logits_0)`` in training."""
        if labels is None:
            # Inference path: a single sub-model's prediction suffices.
            return self.models[0](input_ids=input_ids, attention_mask=attention_mask)

        num_models = len(self.models)
        outputs = []
        for model in self.models:
            output = model(
                input_ids=input_ids.to(self.device),
                attention_mask=attention_mask.to(self.device),
                # labels is guaranteed non-None here; the original's
                # `if labels is not None else None` guard was dead code.
                labels=labels.to(self.device),
            )
            outputs.append(tuple(o.to(self.device) for o in output))

        # Mean supervised loss across sub-models.
        loss = sum(output[0] for output in outputs) / num_models

        # Agreement regularizer: average KL(avg_prob || prob_i), restricted
        # to valid (label != -1) token positions.
        logits = [output[1] for output in outputs]
        probs = [F.softmax(logit, dim=-1) for logit in logits]
        avg_prob = torch.stack(probs, dim=0).mean(0)
        mask = (labels.view(-1) != -1).to(logits[0])
        reg_loss = sum(kl_div(avg_prob, prob) * mask for prob in probs) / num_models
        # 1e-3 guards against an all-masked batch dividing by zero.
        reg_loss = reg_loss.sum() / (mask.sum() + 1e-3)

        loss = loss + self.args.alpha_t * reg_loss
        # Keep the first model's logits tail, as the original did.
        return (loss,) + outputs[0][1:]