-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
73 lines (58 loc) · 2.31 KB
/
inference.py
File metadata and controls
73 lines (58 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# -*- encoding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertConfig, BertModel, BertTokenizer
from train import TARGET_NAMES, TextDataset, TextClassifyModel, eval
# Where the fine-tuned tokenizer/config and the model weights are stored.
SAVE_DIR = 'saved_model'
SAVE_CKPT = f'{SAVE_DIR}/pytorch_model.bin'

# Number of texts per DataLoader batch.
BATCH_SIZE = 64
# Passed to the tokenizer as max_length: inputs are truncated/padded to this.
MAXLEN = 32

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class InferDataset(Dataset):
    """Custom dataset that tokenizes raw texts on access for inference.

    Args:
        data: Indexable, sized collection of input texts.
        tokenizer: Callable tokenizer invoked as
            ``tokenizer(text, max_length=..., truncation=..., padding=...,
            return_tensors=...)``.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        """Return the number of stored texts."""
        return len(self.data)

    def text_2_id(self, text):
        """Tokenize ``text``, truncating/padding it to ``MAXLEN`` tokens.

        Returns whatever the tokenizer produces when asked for PyTorch
        tensors (``return_tensors='pt'``).
        """
        encode_options = {
            'max_length': MAXLEN,
            'truncation': True,
            'padding': 'max_length',
            'return_tensors': 'pt',
        }
        return self.tokenizer(text, **encode_options)

    def __getitem__(self, index):
        """Return the tokenized form of the text stored at ``index``."""
        sample = self.data[index]
        return self.text_2_id(sample)
def infer(model, dataloader):
    """Run the model over every batch and return the predicted label names.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            token_type_ids)``; the index of the largest value along dim 1 of
            its output is taken as the predicted class.
        dataloader: Iterable of batches exposing ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` tensors through ``.get``.

    Returns:
        A list with one ``TARGET_NAMES`` entry per sample, in dataloader order.
    """
    model.eval()
    # Seed with an empty integer tensor so an empty dataloader still yields [].
    collected = [torch.tensor([], dtype=int, device=DEVICE)]
    field_names = ('input_ids', 'attention_mask', 'token_type_ids')
    with torch.no_grad():
        for batch in dataloader:
            # squeeze(1) drops dim 1 when it has size 1 -- presumably the
            # per-sample axis added by return_tensors='pt'; confirm upstream.
            features = [batch.get(name).squeeze(1).to(DEVICE) for name in field_names]
            scores = model(*features)  # original note: [batch, 10]
            _, top_idx = torch.max(scores, 1)
            collected.append(top_idx)
    label_ids = torch.cat(collected, dim=0).cpu().numpy()
    return [TARGET_NAMES[label_id] for label_id in label_ids]
if __name__ == '__main__':
    # Sample texts to classify; each is printed next to its predicted label.
    texts = [
        '今年春晚节目单出来了',
        '刘翔夺冠了',
        '基金今天涨了5个点',
        '国家总统访问美国',
        '中小学马上要开学了',
        '男子救人落水不幸逝世',
        '白酒股最近跌了很多',
        '吴亦凡入狱了, 大快人心!',
    ]

    # load data
    tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
    dataloader = DataLoader(InferDataset(texts, tokenizer), batch_size=BATCH_SIZE)

    # load model
    model = TextClassifyModel(pretrained_model=SAVE_DIR).to(DEVICE)
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # saved during a GPU run still loads when only the CPU is available.
    state_dict = torch.load(SAVE_CKPT, map_location=DEVICE)
    model.load_state_dict(state_dict)

    # infer
    res = infer(model, dataloader)
    for text, pred in zip(texts, res):
        print(f'{text} {pred}')