-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
76 lines (52 loc) · 3.06 KB
/
main.py
File metadata and controls
76 lines (52 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import argparse
from datetime import datetime
from src.helpers import load_model, load_sentence_encoder, load_losses, load_optimizer_and_scheduler, train, evaluator, save_ckpt, load_dataloaders, load_summary_writer
def build_parser():
    """Build the command-line parser for the image/text encoder training script.

    Arguments can also be read from a file by prefixing its path with ``£``
    (e.g. ``python main.py £args.txt``), via argparse's ``fromfile_prefix_chars``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(fromfile_prefix_chars='£')
    # Train parameters
    parser.add_argument("--unique", action="store_true",
                        help="Choose whether using a unique encoder or dual encoders")
    parser.add_argument("--epochs", default=1, type=int, help='number of epochs')
    parser.add_argument("--bs", default=2, type=int, help='batch size for training')
    # Loss choice
    parser.add_argument("--loss", default="MSE", type=str, help="loss function to be used for training",
                        choices=['MSE', 'MAE', 'Hinge', 'contrastive'])
    parser.add_argument("--dataset", default="fikr30k", type=str, help="Dataset to be used for the training",
                        choices=['fikr30k', 'Anime'])
    # A single option string is registered; the parsed attribute is `split_size`.
    parser.add_argument("--split-size", default=0.3, type=float,
                        help='split size of the dataset')
    parser.add_argument("--sentence-encoder", default="all-MiniLM-L12-v2", type=str,
                        help="name of the pretrained Sentence Encoder",
                        choices=['all-MiniLM-L12-v2', 'paraphrase-MiniLM-L12-v2'])
    parser.add_argument("--model", default="SwinEncoder", type=str, help="Architecture of the model",
                        choices=['SwinEncoder', 'ResNet'])
    parser.add_argument("--optimizer", default="AdamW", type=str, help="name of the optimizer to be used for training",
                        choices=['AdamW', 'RMSprop'])
    # `--lr` is the first long option, so the parsed attribute is `args.lr`
    # (and likewise `args.wd` below).
    parser.add_argument("--lr", "--learning-rate", default=0.000357, type=float,
                        help='max learning rate')
    parser.add_argument("--wd", "--weight-decay", default=0.1, type=float,
                        help='weight decay')
    return parser


def main(argv=None):
    """Train the image model against the sentence encoder, evaluating and checkpointing each epoch.

    Args:
        argv: optional list of command-line arguments; ``None`` parses ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    # Run name = loss name + month_day_hour_minute timestamp.
    # NOTE(review): two runs started in the same minute with the same loss share a name.
    name = f"{args.loss}_{datetime.now().strftime('%m_%d_%H_%M')}"
    writer = load_summary_writer(name)
    try:
        image_model = load_model(args)
        text_model = load_sentence_encoder(args)
        dataloader_tr, dataloader_ts = load_dataloaders(args)
        # Only the image model is handed to the optimizer/scheduler factory;
        # steps_per_epoch is the number of training batches per epoch.
        optimizer, scheduler = load_optimizer_and_scheduler(
            args, image_model, steps_per_epoch=len(dataloader_tr))
        # NOTE(review): the models are not moved to `device` in this script;
        # presumably train()/evaluator() handle placement — confirm in src.helpers.
        device = "cuda" if torch.cuda.is_available() else 'cpu'
        criterion = load_losses(args)

        for epoch in range(args.epochs):
            image_model, optimizer, scheduler, writer = train(
                epoch, dataloader_tr, image_model, text_model, writer,
                optimizer, scheduler, criterion, device)
            # `value` is the evaluation result returned by evaluator(); it is
            # forwarded to save_ckpt together with the current epoch.
            value, writer = evaluator(args, image_model, text_model, dataloader_ts,
                                      criterion, epoch, device, writer)
            save_ckpt(args, name, image_model, value, epoch)
    finally:
        # NOTE(review): load_summary_writer presumably returns a TensorBoard
        # SummaryWriter; close it (when supported) so buffered events are flushed
        # even if training raises.
        close = getattr(writer, "close", None)
        if callable(close):
            close()


if __name__ == '__main__':
    main()