-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaudio_cnn_pytorch.py
More file actions
72 lines (57 loc) · 2.02 KB
/
Copy pathaudio_cnn_pytorch.py
File metadata and controls
72 lines (57 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from models import CNN, CNN2
from dataloader import prepare_dataset
# Training script: builds the selected CNN, trains it with Adam and
# cross-entropy loss, reports validation accuracy periodically, and writes
# the final weights to 'model.pth'.

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Name of the architecture to train; 'CNN2' selects CNN2, anything else CNN.
model_name = 'CNN'

# Instantiate the architecture that matches model_name.
# NOTE(review): the earlier version had the two branches swapped, so
# checkpoints it saved hold the other architecture's weights — confirm that
# any script loading 'model.pth' selects the same class.
if model_name == 'CNN2':
    model = CNN2().to(device)
else:
    model = CNN().to(device)

# presumably the two arguments are test/validation split fractions —
# TODO confirm against dataloader.prepare_dataset
train_loader, test_loader, val_loader = prepare_dataset(0.25, 0.2)

learning_rate = 0.0001
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
epochs = 500

# Progress bar and validation accuracy are shown once every this many epochs.
REPORT_EVERY = 20

print("Training Start")
for epoch in range(epochs):
    report = epoch % REPORT_EVERY == 0
    running_loss = 0
    running_corrects = 0
    num_samples = 0

    # Build a fresh iterator every epoch. Reusing the wrapper created on an
    # earlier reporting epoch would iterate an already exhausted enumerate,
    # so no batch would be trained on. The bar is hidden on other epochs.
    loop = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        disable=not report,
    )
    for i, (inputs, labels) in loop:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if report:
            loop.set_description(f"Epoch [{epoch}/{epochs}]")
            # Mean loss over the batches processed so far in this epoch.
            loop.set_postfix(loss=running_loss / (i + 1))

    # Evaluating on validation set
    model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels).item()
            # Count every batch, not only the last one, so the ratio below
            # is the accuracy over the whole validation set.
            num_samples += labels.size(0)
    # Guard against an empty validation loader.
    val_acc = running_corrects / num_samples if num_samples else 0.0
    if report:
        print(f"epoch:{epoch}, Validation Accuracy {val_acc :.4f}")
    # Switch back to training mode for the next epoch.
    model.train()

torch.save(model.state_dict(), 'model.pth')
print("Model saved")