# run_experiment_regularmnist.py
# Forked from ZJUCDSYangKaifan/GEVit.
import datetime
import os
import sys
import numpy as np
import torch
import wandb
# args
from absl import app, flags
from ml_collections.config_flags import config_flags
import dataset
import tester
import trainer
from model import get_model
from path_handler import model_path
from torchvision.datasets import MNIST
from torchvision import transforms
# Handle to absl's global flag registry; main() reads the parsed config from it.
FLAGS = flags.FLAGS
# Register a flag named "config" whose default value is "config.py".
# The resulting object is accessed in main() as FLAGS.config.
config_flags.DEFINE_config_file("config", default="config.py")
def _build_dataloaders(config, num_workers=4, validation_fraction=0.2, split_seed=42):
    """Build the MNIST dataloaders consumed by the trainer and the tester.

    Args:
        config: Experiment configuration; only ``config.batch_size`` is read.
        num_workers: Number of worker processes per dataloader.
        validation_fraction: Fraction of the MNIST test split placed in the
            validation subset.
        split_seed: Seed of the private generator that shuffles the test-split
            indices, so the validation subset is reproducible.

    Returns:
        A dict with the keys ``"train"`` (shuffled MNIST training split) and
        ``"test"`` (unshuffled validation subset of the MNIST test split).
    """
    data_mean = (0.1307,)
    data_stddev = (0.3081,)
    # Training and evaluation images go through the same preprocessing.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(data_mean, data_stddev),
        ]
    )

    train_set = MNIST(
        root="../data/mnistreal", train=True, download=True, transform=transform
    )
    held_out_set = MNIST(
        root="../data/mnistreal", train=False, download=True, transform=transform
    )

    # Shuffle with a dedicated generator instead of torch.manual_seed(): the
    # global RNG was already seeded from config.seed and must not be reset
    # here, otherwise every run would train with the same randomness.
    validation_size = int(validation_fraction * len(held_out_set))
    split_rng = torch.Generator().manual_seed(split_seed)
    shuffled = torch.randperm(len(held_out_set), generator=split_rng).tolist()
    validation_set = torch.utils.data.Subset(held_out_set, shuffled[:validation_size])

    # NOTE(review): the "test" loader wraps the validation subset; the
    # remaining indices (shuffled[validation_size:]) are never evaluated.
    # Confirm this is intended against trainer/tester before relying on the
    # reported test metrics.
    training_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        validation_set,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return {"train": training_loader, "test": test_loader}


def main(_):
    """Run one MNIST experiment: set up, optionally train, then evaluate.

    Reads the configuration from ``FLAGS.config``, seeds torch and numpy,
    starts a Weights & Biases run, builds the model and dataloaders,
    optionally loads pretrained weights, optionally trains, and finally runs
    the tester on the ``"test"`` dataloader.

    Args:
        _: Positional argument supplied by ``absl.app.run``; unused.
    """
    if "absl.logging" in sys.modules:
        import absl.logging

        absl.logging.set_verbosity("info")
        absl.logging.set_stderrthreshold("info")

    config = FLAGS.config
    print(config)

    # Set the seed
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Credentials are deliberately not embedded in the source: provide them
    # through the WANDB_API_KEY environment variable or `wandb login`.
    # Any key that was previously committed to this file should be revoked.
    wandb.init(
        project="wouters_eq_vit",
        group="rotmnist",
        entity="ge_vit_DL2",
    )

    # Define the device to be used and move model to that device
    config["device"] = (
        "cuda:0" if (config.device == "cuda" and torch.cuda.is_available()) else "cpu"
    )
    model = get_model(config)

    # Define transforms and create dataloaders
    dataloaders = _build_dataloaders(config, num_workers=4)

    # Create model directory and instantiate config.path
    model_path(config)

    if config.pretrained:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only host (and vice versa).
        # NOTE(review): `.module` presumes get_model returns a wrapper such as
        # DataParallel - confirm.
        state_dict = torch.load(config.path, map_location=config["device"])
        model.module.load_state_dict(state_dict, strict=False)

    if config.train:
        # Print arguments (Sanity check)
        print(config)
        print(datetime.datetime.now())
        trainer.train(model, dataloaders, config)

    # Test model
    tester.test(model, dataloaders["test"], config)
# Script entry point: let absl parse the command-line flags, then call main().
if __name__ == "__main__":
    app.run(main)