-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathmain.py
More file actions
86 lines (74 loc) · 2.93 KB
/
main.py
File metadata and controls
86 lines (74 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import numpy as np
import logging
import sys
import torch
import torch.optim as optim
import random
import shutil
# from tensorboardX import SummaryWriter
from utils.general import init_logging
from model.x2x import Tree2TreeModel
from config import parser
from trainer import Trainer
from model.utils import device_map_location
# Batch index that the 'start_batch' mode passes to Trainer.train as st_batch.
# NOTE(review): magic value carried over unchanged from the original script;
# consider exposing it as a command-line option in config.parser.
_DEBUG_START_BATCH = 59


def _set_random_seeds(args):
    """Seed the Python, NumPy and torch (CPU and, if enabled, CUDA) RNGs.

    Args:
        args: parsed options; reads ``random_seed`` and ``cuda``.
    """
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed)
        # cuDNN autotuning speeds up training. Per the PyTorch
        # reproducibility notes, it may pick non-deterministic kernels, so
        # GPU runs are not guaranteed bit-identical even with a fixed seed.
        torch.backends.cudnn.benchmark = True


def _load_dataset(args):
    """Import the loader selected by ``args.dataset`` and return its splits.

    Dataset modules are imported inside the branches, so only the selected
    one is imported.

    Args:
        args: parsed options; reads ``dataset`` and is forwarded to the
            dataset's ``load_dataset``.

    Returns:
        The ``(train_data, dev_data, test_data)`` triple returned by the
        dataset module's ``load_dataset(args)``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset == 'hs':
        from datasets.hs import load_dataset
    elif args.dataset == 'django':
        from datasets.django import load_dataset
    else:
        raise ValueError('Dataset {} is not prepared yet'.format(args.dataset))
    return load_dataset(args)


def _build_model(args, train_data):
    """Load a saved model from ``args.model``, or create a new Tree2TreeModel.

    A new model is built from the embeddings stored in
    ``<args.data_dir>/word_embeddings.pth`` plus the terminal vocabulary and
    grammar of ``train_data``. The result is moved to the GPU when
    ``args.cuda`` is set.

    Args:
        args: parsed options; reads ``cuda``, ``model`` and ``data_dir``.
            It is also passed to the Tree2TreeModel constructor.
        train_data: training split; reads ``terminal_vocab`` and ``grammar``.

    Returns:
        The model object.
    """
    # device_map_location lets tensors saved on GPU be loaded in a CPU
    # environment and vice versa.
    map_location = device_map_location(args.cuda)
    if args.model:
        logging.info('Loading model: %s', args.model)
        model = torch.load(args.model, map_location)
    else:
        logging.info('Creating new model')
        emb_file = os.path.join(args.data_dir, 'word_embeddings.pth')
        # Same map_location as the checkpoint load above, so embeddings saved
        # from GPU tensors can still be loaded on a CPU-only machine.
        emb = torch.load(emb_file, map_location)
        model = Tree2TreeModel(args, emb, train_data.terminal_vocab, train_data.grammar)
    if args.cuda:
        model = model.cuda()
    return model


def _run(trainer, args, train_data, dev_data, test_data):
    """Dispatch on ``args.mode``.

    Modes:
        train: ``trainer.train_all`` with all three splits and
            ``args.output_dir``.
        validate: ``trainer.validate`` on the test split, writing into a
            freshly recreated ``<args.output_dir>/tmp`` directory.
        start_batch: ``trainer.train`` on the training split with
            ``st_batch=_DEBUG_START_BATCH``.

    Raises:
        ValueError: for any other mode.
    """
    if args.mode == 'train':
        trainer.train_all(train_data, dev_data, test_data, args.output_dir)
    elif args.mode == 'validate':
        tmp_epoch_dir = os.path.join(args.output_dir, 'tmp')
        # Start from an empty directory so stale outputs from an earlier run
        # are not mixed with this one.
        if os.path.exists(tmp_epoch_dir):
            shutil.rmtree(tmp_epoch_dir)
        os.mkdir(tmp_epoch_dir)
        # NOTE(review): validates on test_data rather than dev_data; confirm
        # this is intended.
        trainer.validate(test_data, 0, tmp_epoch_dir)
    elif args.mode == 'start_batch':
        trainer.train(train_data, 0, st_batch=_DEBUG_START_BATCH)
    else:
        raise ValueError('Unknown mode: {}'.format(args.mode))


def main():
    """Parse options, set up logging, data, model and optimizer, then run."""
    args = parser.parse_args()
    # Fall back to CPU when CUDA was requested but is not available.
    args.cuda = args.cuda and torch.cuda.is_available()
    _set_random_seeds(args)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    init_logging(os.path.join(args.output_dir, 'parser.log'))
    logging.info('command line: %s', ' '.join(sys.argv))
    logging.info('current config: %s', args)

    logging.info('loading dataset [%s]', args.dataset)
    train_data, dev_data, test_data = _load_dataset(args)
    # Sizes derived from the training split are stored on args, which is
    # then handed to the model constructor and the Trainer.
    args.source_vocab_size = train_data.vocab.size()
    args.target_vocab_size = train_data.terminal_vocab.size()
    args.rule_num = len(train_data.grammar.rules)
    args.node_num = len(train_data.grammar.node_type_to_id)

    model = _build_model(args, train_data)
    # Only optimise parameters that require gradients (frozen ones are skipped).
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=args.lr)
    trainer = Trainer(model, args, optimizer)
    _run(trainer, args, train_data, dev_data, test_data)


if __name__ == '__main__':
    main()