-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
40 lines (32 loc) · 1.42 KB
/
train.py
File metadata and controls
40 lines (32 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from dataset import MyTrainDataset, MyValandTestDataset, MyValandTestBlockDataset, MyDDataset
from model import RFNet,VRCNN,VRCNN_H,DenseNet,MyNet
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from tensorboardX import SummaryWriter
import math
def train(args, model, device, train_loader, optimizer, epoch, writer, fold=0):
    """Run one training epoch and log the mean per-batch MSE loss.

    Args:
        args: Options object; this function reads ``log_interval``,
            ``epochs`` and ``dry_run`` from it.
        model: Module being optimised; it is put into training mode.
        device: Device that every input/target batch is moved to.
        train_loader: Sized iterable yielding dict batches with the keys
            ``'image1'`` (model input) and ``'target'``; it must also expose
            a ``dataset`` attribute (used for the progress message only).
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        epoch: Epoch index inside the current fold.
        writer: Object providing ``add_scalar(tag, value, step)``.
        fold: Fold index; the global step used for printing and logging is
            ``args.epochs * fold + epoch``.
    """
    model.train()
    global_epoch = args.epochs * fold + epoch
    train_loss = 0.0
    batches_seen = 0
    for batch_idx, batch_data in enumerate(train_loader):
        data1 = batch_data['image1'].float()
        # Per the original note: [batch-size, channels, height, width],
        # cast to torch.float32 by .float() — layout not verifiable from here.
        target = batch_data['target'].float()
        data1, target = data1.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data1)
        loss = F.mse_loss(output, target)
        batch_loss = loss.item()
        train_loss += batch_loss
        batches_seen += 1
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                global_epoch, batch_idx * len(data1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss))
            # Batch 0 always reaches this branch, so a dry run stops after
            # exactly one optimisation step.
            if args.dry_run:
                break

    if batches_seen == 0:
        # Empty loader: there is no loss to average, so log nothing instead
        # of dividing by zero.
        return

    # Average over the batches actually processed; dividing by
    # len(train_loader) would under-report the loss when dry_run exits early.
    train_loss /= batches_seen
    writer.add_scalar('Train/Loss', train_loss, global_epoch)