forked from SamsungSAILMontreal/TinyRecursiveModels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrivial_memory_model.py
More file actions
103 lines (76 loc) · 3.18 KB
/
trivial_memory_model.py
File metadata and controls
103 lines (76 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from collections import defaultdict
import torch
from torch import Tensor, nn
from torch.optim import SGD
from recursion.dataset.repeat_after_k import make_sequence
class TrivialMemoryModel(nn.Module):
    """Tiny recurrent cell: maps (scalar input, memory vector) to a new memory vector.

    The network is a two-layer MLP over the concatenation of the current
    input and the previous memory state. It emits only the next memory;
    a separate read-out head (e.g. ``nn.Linear(memory_dim, 1)``) is
    expected to produce the task output.
    """

    def __init__(self, memory_dim: int):
        """Build the cell.

        Args:
            memory_dim: Size of the recurrent memory vector.
        """
        super().__init__()
        # Hidden width is twice the concatenated (input + memory) size.
        hidden_size = 2 * (1 + memory_dim)
        self.fc1 = nn.Linear(1 + memory_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, memory_dim)
        self.relu = nn.ReLU()

    def forward(self, input: Tensor, memory: Tensor) -> Tensor:
        """Compute the next memory state.

        Args:
            input: Current input, last dimension of size 1.
            memory: Previous memory, last dimension of size ``memory_dim``.

        Returns:
            Next memory state with last dimension ``memory_dim``.
            (Fixed: the annotation previously claimed a 2-tuple, but a
            single tensor is returned and callers rely on that.)
        """
        x = torch.cat([input, memory], dim=-1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# --- Experiment setup -------------------------------------------------------
# NOTE(review): indentation below is reconstructed from syntax (the source was
# whitespace-flattened) — confirm against the original file.
input_sequence, target_sequence = make_sequence(7, 3)
memory_dim = 8
model = TrivialMemoryModel(memory_dim)
# Read-out head mapping memory -> single logit; NOTE(review): its parameters
# are not in the optimizer, so the head is never trained — confirm intended.
head = nn.Linear(memory_dim, 1)
memory = torch.randn(memory_dim)
criterion = nn.BCEWithLogitsLoss()
optimizer = SGD(model.parameters(), lr=1e-2)
memories = []       # memory states produced by the model (graph-attached)
memories_wrt = []   # detached leaf copies used as grad targets per step
param_to_gradients = defaultdict(list)  # per-parameter grads, one per unrolled step
torch.set_printoptions(linewidth=200)
update_every = 6    # truncation window: backprop through this many steps
from torchjd.aggregation import UPGradWeighting
weighting = UPGradWeighting()

# --- Unrolled recurrence with manual backprop-through-time ------------------
for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)):
    # Fresh leaf tensor for this step's incoming memory, so gradients w.r.t.
    # each step's memory can be requested individually.
    memories_wrt.append(memory.detach().requires_grad_(True))
    memory = model(input.unsqueeze(0).to(dtype=torch.float32), memories_wrt[-1])
    output = head(memory)
    loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32))
    memories.append(memory)
    print(f"{loss.item():.1e}")
    if (i + 1) % update_every == 0:
        optimizer.zero_grad()
        # Seed the backward chain: d(loss)/d(latest memory).
        grad_output = torch.autograd.grad(loss, [memories[-1]])
        # Walk backwards through the window, chaining grad_outputs from one
        # step to the previous step's memory (manual BPTT).
        for j in range(update_every):
            print(j)
            grads = torch.autograd.grad(
                memories[-j - 1],
                list(model.parameters()) + [memories_wrt[-j - 1]],
                grad_outputs=grad_output,
            )
            grads_wrt_params = grads[:-1]
            # Gradient w.r.t. this step's incoming memory seeds the next
            # (earlier) step of the chain.
            grad_output = grads[-1]
            for param, grad in zip(model.parameters(), grads_wrt_params, strict=True):
                # NOTE(review): never cleared between windows, so the Jacobian
                # rows accumulate across updates — confirm this is intended.
                param_to_gradients[param].append(grad)
        # One flattened gradient row per unrolled step, per parameter.
        param_to_jacobian_matrix = {
            param: torch.stack([g.flatten() for g in gradients], dim=0)
            for param, gradients in param_to_gradients.items()
        }
        # Full Jacobian: rows = steps, cols = all parameters concatenated.
        jacobian_matrix = torch.cat([mat for mat in param_to_jacobian_matrix.values()], dim=1)
        gramian = jacobian_matrix @ jacobian_matrix.T
        weights = weighting(gramian)
        # print(jacobian_matrix.shape)
        print(gramian)
        print(weights)
        # graph = make_dot(loss, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
        # graph.view()
        # graph = make_dot(attached_memories[-1], params=dict(model.named_parameters()), show_attrs=True,
        # show_saved=True)
        # graph.view()
        # loss.backward()
        # print("fc1 weights: ", model.fc1.weight.grad)
        # print("fc1 biases: ", model.fc1.bias.grad)
        #
        # print("fc2 weights: ", model.fc2.weight.grad)
        # print("fc2 biases: ", model.fc2.bias.grad)
        # NOTE(review): grads were computed via autograd.grad, which does NOT
        # populate param.grad; with .grad unset, SGD.step() is a no-op here —
        # confirm whether the aggregated weights were meant to be written into
        # param.grad before stepping.
        optimizer.step()
    # Break the graph so the next step starts from a constant memory value.
    memory = memory.detach()