-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
44 lines (38 loc) · 1.24 KB
/
main.py
File metadata and controls
44 lines (38 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch.nn as nn
import torch.optim as optim
from data import test_data_process, train_val_data_process
from eval import test_model_process
from models import get_model
from trainers import Trainer
from utils import get_device, get_result_dir, save_summary_json, set_seed
def main(
    *,
    dataset_name="fashionmnist",
    model_name="lenet",
    seed=42,
    lr=0.001,
    epochs=10,
):
    """Train a model, evaluate it on the test split, and write a JSON summary.

    Every argument is keyword-only and defaults to the value that used to be
    hard-coded here, so calling ``main()`` with no arguments runs exactly the
    original experiment.

    Args:
        dataset_name: Dataset identifier forwarded to ``get_model``, the data
            helpers and ``get_result_dir``.
        model_name: Model identifier forwarded to ``get_model``, the data
            helpers and ``get_result_dir``.
        seed: Value passed to ``set_seed`` before any model or data is built.
        lr: Learning rate for the Adam optimizer.
        epochs: Value passed to ``Trainer.fit`` (presumably the number of
            training epochs — confirm against ``trainers.Trainer``).
    """
    # Seed before building the model or loaders so that anything they
    # randomize is governed by ``seed``.
    set_seed(seed)

    model = get_model(model_name, dataset_name)
    # The data helpers take the model name as well as the dataset name —
    # presumably so preprocessing can depend on the model; verify in ``data``.
    train_loader, val_loader = train_val_data_process(model_name, dataset_name)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    device = get_device()
    # ``result_dir`` must support ``/`` with a str (see the summary path
    # below); presumably a ``pathlib.Path`` — confirm in ``utils``.
    result_dir = get_result_dir(dataset_name, model_name)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        result_dir=result_dir,
    )
    trainer.fit(epochs)

    # Evaluate the same ``model`` object the trainer worked on. Whether it
    # holds the last or the best weights depends on ``Trainer`` — confirm.
    test_loader = test_data_process(model_name, dataset_name)
    test_acc = test_model_process(model, test_loader)

    save_summary_json(
        {
            "dataset_name": dataset_name,
            "model_name": model_name,
            "test_acc": test_acc,
        },
        result_dir / "summary.json",
    )
# Run the experiment only when this file is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    main()