-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
129 lines (96 loc) · 3.39 KB
/
main.py
File metadata and controls
129 lines (96 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from dataset import create_wall_dataloader
from evaluator import ProbingEvaluator
import torch
from models import MockModel, JEPA
import glob
def get_device():
"""Check for GPU availability."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
return device
def load_data(device):
data_path = "/scratch/DL24FA"
probe_train_ds = create_wall_dataloader(
data_path=f"{data_path}/probe_normal/train",
probing=True,
device=device,
train=True,
)
probe_val_normal_ds = create_wall_dataloader(
data_path=f"{data_path}/probe_normal/val",
probing=True,
device=device,
train=False,
)
probe_val_wall_ds = create_wall_dataloader(
data_path=f"{data_path}/probe_wall/val",
probing=True,
device=device,
train=False,
)
probe_val_wall_other_ds = create_wall_dataloader(
data_path=f"{data_path}/probe_wall_other/val",
probing=True,
device=device,
train=False,
)
probe_val_ds = {
"normal": probe_val_normal_ds,
"wall": probe_val_wall_ds,
"wall_other": probe_val_wall_other_ds,
}
return probe_train_ds, probe_val_ds
def load_expert_data(device):
data_path = "/scratch/DL24FA"
probe_train_expert_ds = create_wall_dataloader(
data_path=f"{data_path}/probe_expert/train",
probing=True,
device=device,
train=True,
)
probe_val_expert_ds = {
"expert": create_wall_dataloader(
data_path=f"{data_path}/probe_expert/val",
probing=True,
device=device,
train=False,
)
}
return probe_train_expert_ds, probe_val_expert_ds
def load_model():
"""Load or initialize the JEPA model."""
from models import ActionRegularizationJEPA2Dv0, FinalModel
from configs import JEPAConfig
config = JEPAConfig.parse_from_file("config/areg_jepa_2d_config_v0.yaml")
config.steps_per_epoch = 999
# model = ActionRegularizationJEPA2Dv0(config)
# model.load_state_dict(torch.load('weights/best_expert_model_epoch_4_train_iter_76_normal_loss_21.21573_wall_loss_19.79489_expert_loss_85.14044.pt'))
# model.cuda()
model = FinalModel(config)
model.load_state_dict(torch.load('model_weights.pth'))
model.cuda()
return model
def evaluate_model(device, model, probe_train_ds, probe_val_ds):
evaluator = ProbingEvaluator(
device=device,
model=model,
probe_train_ds=probe_train_ds,
probe_val_ds=probe_val_ds,
quick_debug=False,
)
prober = evaluator.train_pred_prober()
avg_losses = evaluator.evaluate_all(prober=prober)
for probe_attr, loss in avg_losses.items():
print(f"{probe_attr} loss: {loss}")
return avg_losses
if __name__ == "__main__":
device = get_device()
model = load_model()
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Trainable Parameters: {total_params:,}")
probe_train_ds, probe_val_ds = load_data(device)
model.set_repr_dim(probe_train_ds)
evaluate_model(device, model, probe_train_ds, probe_val_ds)
probe_train_expert_ds, probe_val_expert_ds = load_expert_data(device)
model.set_repr_dim(probe_train_expert_ds)
evaluate_model(device, model, probe_train_expert_ds, probe_val_expert_ds)