# trivial_memory_model.py
# Forked from SamsungSAILMontreal/TinyRecursiveModels.

from collections import defaultdict

import torch
from torch import Tensor, nn
from torch.nn.functional import cosine_similarity
from torch.optim import SGD
from torchjd.aggregation import UPGrad
from torchjd.autojac._transform import (
    Accumulate,
    Aggregate,
    Diagonalize,
    Init,
    Jac,
    OrderedSet,
    Select,
)

from recursion.dataset.repeat_after_k import make_sequences
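
# Overview: a small recurrent memory model is trained on sequences from
# recursion.dataset.repeat_after_k (presumably: output the input seen k steps
# earlier). Instead of calling loss.backward(), per-sample Jacobians are
# propagated manually through a window of `update_every` time steps and
# aggregated with torchjd's UPGrad before each SGD step on the model; the
# linear head is trained with ordinary gradients.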


class ResidualMemoryModel(nn.Module):
    """Two-layer MLP that residually updates a memory vector from a scalar input."""

    def __init__(self, memory_dim: int):
        super().__init__()
        hidden_size = 2 * (1 + memory_dim)
        self.fc1 = nn.Linear(1 + memory_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, memory_dim)
        self.relu = nn.ReLU()

    def forward(self, input: Tensor, memory: Tensor) -> Tensor:
        # Concatenate the input with the current memory, pass the result through
        # the MLP, and add it to the memory (residual update).
        x = torch.cat([input, memory], dim=-1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x + memory
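
# Minimal shape sketch (hypothetical usage, not executed by this script):
#
#     m = ResidualMemoryModel(memory_dim=8)
#     out = m(torch.zeros(16, 1), torch.zeros(16, 8))
#     assert out.shape == (16, 8)  # the updated memory keeps the memory's shape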


batch_size = 16
k = 2
input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size)

memory_dim = 8
model = ResidualMemoryModel(memory_dim)
head = nn.Linear(memory_dim, 1)  # Maps the memory to a single logit.
memory = torch.zeros(batch_size, memory_dim)
criterion = nn.BCEWithLogitsLoss(reduction="none")  # Keep per-sample losses.
optimizer = SGD(model.parameters(), lr=5e-03)
head_optimizer = SGD(head.parameters(), lr=5e-03)

memories = []  # Output memories of the current window, newest last.
memories_wrt = []  # Detached input memories, used as differentiation targets.
param_to_jacobians = defaultdict(list)
torch.set_printoptions(linewidth=200)
update_every = 3  # Number of time steps per parameter update (truncation window).
aggregator = UPGrad()
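
# UPGrad treats each row of a Jacobian as a separate objective: per the torchjd
# documentation, it projects each row onto the dual cone of all the rows and
# averages the projections, yielding an update that conflicts with no row.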


def hook(_, args: tuple[Tensor], __) -> None:
    """Prints the ratio between two diagonal entries of the Jacobian's Gramian
    (squared row norms): row 0, a contribution from the newest step, and the row
    k * batch_size positions later, the matching contribution from k steps earlier."""
    jacobian = args[0]
    gramian = jacobian @ jacobian.T
    print(gramian[0, 0] / gramian[k * batch_size, k * batch_size])


def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
    """Prints the cosine similarity between the aggregation and the average gradient."""
    matrix = inputs[0]
    gd_output = matrix.mean(dim=0)
    similarity = cosine_similarity(aggregation, gd_output, dim=0)
    print(f"Cosine similarity: {similarity.item():.4f}")


aggregator.register_forward_hook(hook)
aggregator.register_forward_hook(print_gd_similarity)
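
# Both hooks fire each time the aggregator is applied, i.e. once per parameter
# update: `hook` shows how much a loss's gradient contribution decays k steps
# back through the recurrence, and `print_gd_similarity` compares UPGrad's
# output with the plain averaged gradient that gradient descent would use.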


for i, (input, target) in enumerate(zip(input_sequences.T, target_sequences.T, strict=True)):
    # Detach the incoming memory so each window gets its own graph, but keep a
    # leaf tensor so Jacobians can later be chained across steps.
    memories_wrt.append(memory.detach().requires_grad_(True))
    memory = model(input.unsqueeze(1).to(dtype=torch.float32), memories_wrt[-1])
    output = head(memory)
    losses = criterion(output, target.unsqueeze(1).to(dtype=torch.float32))
    loss = losses.mean()
    memories.append(memory)

    # Aggregates the stacked Jacobians with UPGrad and accumulates the result
    # into the parameters' .grad fields; applied at the end of each window.
    transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(model.parameters())))

    print(f"{loss.item():.1e}")

    if (i + 1) % update_every == 0:
        # grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True)

        # Seed the backward pass with one row per per-sample loss of the newest
        # step, differentiated with respect to that step's output memory.
        ordered_set = OrderedSet(losses)
        init = Init(ordered_set)
        diag = Diagonalize(ordered_set)
        jac = Jac(ordered_set, OrderedSet([memories[-1]]), chunk_size=None, retain_graph=True)
        trans = jac << diag << init
        trans.check_keys(set())
        jac_output = trans({})

        # Walk backwards through the window, collecting at each step the Jacobian
        # with respect to the model parameters and chaining the Jacobian with
        # respect to the step's input memory into the previous step.
        for j in range(update_every):
            new_jac = Jac(
                OrderedSet([memories[-j - 1]]),
                OrderedSet(list(model.parameters()) + [memories_wrt[-j - 1]]),
                chunk_size=None,
            )
            select_jac_wrt_model = Select(OrderedSet(list(model.parameters())))
            select_jac_wrt_memory = Select(OrderedSet([memories_wrt[-j - 1]]))
            jacobians = new_jac(jac_output)
            jac_output = select_jac_wrt_memory(jacobians)
            if j < update_every - 1:
                # memories_wrt[-j - 1] is the detached copy of memories[-j - 2], so
                # re-key the Jacobian to keep backpropagating through that step.
                jac_output = {memories[-j - 2]: jac_output[memories_wrt[-j - 1]]}
            jac_wrt_params = select_jac_wrt_model(jacobians)
            for param, jacob in jac_wrt_params.items():
                param_to_jacobians[param].append(jacob)

        # Stack the per-step Jacobians into one matrix per parameter.
        param_to_jacobian = {
            param: torch.cat(jacobs, dim=0) for param, jacobs in param_to_jacobians.items()
        }

        optimizer.zero_grad()
        transform(param_to_jacobian)  # This stores the aggregated Jacobian in the .grad fields
        optimizer.step()

        memories = []
        memories_wrt = []
        param_to_jacobians = defaultdict(list)

    # Update the linear head every step with plain gradient descent; inputs=
    # restricts gradient accumulation to the head's parameters.
    head_optimizer.zero_grad()
    torch.autograd.backward(loss, inputs=list(head.parameters()))
    head_optimizer.step()
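
# Assumed invocation (not stated in the source): run from the repository root so
# that the `recursion` package is importable, e.g. `python trivial_memory_model.py`.
# Note that torchjd.autojac._transform is a private module, so this script is
# tied to the torchjd version it was written against.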