-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpc_variants.py
More file actions
61 lines (43 loc) · 1.8 KB
/
pc_variants.py
File metadata and controls
61 lines (43 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from types import MethodType
import torch.nn.functional as F
from pc_e import PCE
# Define CELoss version of PCE
def use_CrossEntropyLoss(pc_module):
    """Replace ``pc_module.class_loss`` with a summed cross-entropy loss.

    The original author's note: CE loss is used to avoid vanishing gradients
    with state optimization.

    The new loss is bound to this one instance only (the class is left
    untouched), and the same object is returned so the call can be chained.

    Args:
        pc_module: Any object that accepts a ``class_loss`` attribute.

    Returns:
        ``pc_module`` itself, with ``class_loss(y_pred, y)`` now returning
        ``F.cross_entropy(y_pred, y, reduction="sum")``.
    """

    def class_loss(self, y_pred, y):
        # ``self`` is unused; it exists so the function can be bound as a method.
        return F.cross_entropy(y_pred, y, reduction="sum")

    # Bind to the instance so ``pc_module.class_loss(y_pred, y)`` behaves like a method.
    bound_loss = MethodType(class_loss, pc_module)
    pc_module.class_loss = bound_loss
    return pc_module
# Define state optim version of PCE
class PC_States(PCE):
    """State-optimization variant of ``PCE``.

    Only training is changed: the error-energy minimization step is redirected
    to the parent's ``minimize_state_energy``, and the local energy to
    ``E_states_only``. ``forward`` and ``y_pred`` are inherited unchanged.

    The original author's note: for prediction they set all errors to zero,
    so only the training procedure needs adapting.
    """

    def minimize_error_energy(self, x, y):
        """Run state optimization in place of error optimization.

        Reuses the error-optimization settings ``self.iters`` and ``self.e_lr``
        for ``minimize_state_energy``. Caches the resulting final states on
        ``self.states`` for ``E_local``.

        NOTE(review): returns None; confirm no ``PCE`` caller relies on a
        return value from ``minimize_error_energy``.
        """
        final_states = super().minimize_state_energy(x, y, self.iters, self.e_lr)
        self.states = final_states

    def E_local(self, x, y):
        """Return the parent's ``E_states_only`` evaluated at the cached states.

        NOTE(review): ``self.states`` is only assigned in
        ``minimize_error_energy``; presumably the training loop always calls
        that first — confirm in ``PCE``.
        """
        return super().E_states_only(x, y, self.states)
# Define backprop version of PCE
class BackpropMSE(PCE):
    """Backprop version of ``PCE``: trains directly on the output loss.

    Despite the name, the loss is whatever ``self.class_loss`` is on the
    instance. It is cross-entropy if the instance was passed through
    ``use_CrossEntropyLoss``.
    """

    def training_step(self, batch, batch_idx):
        """Compute the batch loss after a forward pass that zeroes all errors.

        Args:
            batch: Mapping with inputs under ``"img"`` and targets under ``"y"``.
            batch_idx: Unused; kept for the training-step signature.

        Returns:
            ``self.class_loss(self.y_pred(x), y)`` divided by ``self.batch_size``.
        """
        inputs = batch["img"]
        targets = batch["y"]
        # forward() sets all errors to 0 before the prediction is read.
        self.forward(inputs)
        summed_loss = self.class_loss(self.y_pred(inputs), targets)
        # NOTE(review): normalizes by the configured batch_size, not len(targets).
        # A smaller final batch would be down-weighted — confirm this matches
        # how PCE normalizes its own loss.
        return summed_loss / self.batch_size
def get_pc_variant(algorithm: str, USE_CROSSENTROPY_INSTEAD_OF_MSE: bool):
    """Return a factory that builds the PC model variant named by ``algorithm``.

    Args:
        algorithm: One of the variant codes below.
            - ``"EO"``: ``PCE``
            - ``"SO"``: ``PC_States`` (state optimization)
            - ``"BP"``: ``BackpropMSE`` (backprop)
        USE_CROSSENTROPY_INSTEAD_OF_MSE: If truthy, every built instance is
            passed through ``use_CrossEntropyLoss`` before being returned.

    Returns:
        A callable that forwards ``*args, **kwargs`` to the selected class
        constructor and returns the (optionally CE-patched) instance.

    Raises:
        NotImplementedError: If ``algorithm`` matches none of the codes.
    """
    variant_table = (("EO", PCE), ("SO", PC_States), ("BP", BackpropMSE))
    for code, variant_cls in variant_table:
        if algorithm == code:
            break
    else:
        raise NotImplementedError("Choose one of these options: EO | SO | BP")

    def pc_maker(*args, **kwargs):
        instance = variant_cls(*args, **kwargs)
        return use_CrossEntropyLoss(instance) if USE_CROSSENTROPY_INSTEAD_OF_MSE else instance

    return pc_maker