-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain_zh.py
More file actions
67 lines (53 loc) · 1.99 KB
/
main_zh.py
File metadata and controls
67 lines (53 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import torch
from utils import generate_data, check_data,word_ids_to_sentence
import torchtext
from torchtext import data
import torch.nn as nn
import torch.nn.functional as F
from train_eval import train,test, test_one_sentence
from torch.autograd import Variable as V
from model import RNNModel
from transformer import TransformerModel
class Config(object):
    """Hyper-parameters and file locations for training the language model.

    Any attribute can be overridden by keyword at construction time, e.g.
    ``Config(batch_size=32, mode="test")``. Unknown keywords raise ``TypeError``
    so a misspelled setting is not silently ignored.
    """

    def __init__(self, **overrides):
        self.model_name = "lm_model"
        # Root data directory; the *_path entries below are bare file names,
        # presumably joined onto this root by generate_data -- TODO confirm.
        self.data_ori = "/mnt/data3/wuchunsheng/data_all/data_mine/lm_data/"
        # Local alternative: self.data_ori = "E:/data/word_nlp/cnews_data/"
        self.train_path = "train_0.csv"
        # NOTE(review): validation uses the same file as training; confirm intended.
        self.valid_path = "train_0.csv"
        self.test_path = "test_100.csv"
        self.sen_max_length = 150
        # Alternative: self.embedding_path = "need_bertembedding"
        self.embedding_path = "bert_embedding"
        self.embedding_dim = 768
        self.vocab_maxsize = 4000
        self.vocab_minfreq = 10
        self.save_path = "lm_ckpt"
        self.batch_size = 64
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.hidden_size = 200
        self.nlayers = 1
        # Single dropout rate (previously assigned twice with the same value).
        self.dropout = 0.5
        self.epoch = 100
        # Dataset-size placeholders; initialised to 0 here and not set by this class.
        self.train_len = 0
        self.test_len = 0
        self.valid_len = 0
        self.mode = "train"
        # Transformer parameters.
        self.max_len = 5000
        self.nhead = 2

        # Apply caller overrides last so they win over the defaults above.
        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(
                    "Config got an unexpected keyword argument {!r}".format(key)
                )
            setattr(self, key, value)
config = Config()
train_iter, valid_iter, test_iter, TEXT = generate_data(config)
# Debug aid: check_data(train_iter, TEXT) inspects the loaded data.

# Reuse the device chosen in Config rather than re-deriving it, so there is a
# single source of truth for where the model lives.
device = config.device
# To train the recurrent baseline instead: model = RNNModel(config, TEXT).to(device)
model = TransformerModel(config, TEXT).to(device)
train(config, model, train_iter, valid_iter, test_iter)

# test(config, model, TEXT, test_iter) evaluates a whole batch; this script
# runs the single-sentence variant instead.
res = test_one_sentence(config, model, TEXT, test_iter)
print(res)