Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ dependencies = [
"numba",
"triton",
"pre-commit",
"torchjd",
"torchviz"
]
14 changes: 14 additions & 0 deletions src/recursion/dataset/repeat_after_k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from torch import Tensor


def make_sequence(length: int, k: int) -> tuple[Tensor, Tensor]:
    """Build a random binary "repeat after k" (input, target) pair.

    Draws a random 0/1 sequence of ``length + k`` elements and returns two
    aligned views of it, each of length ``length``, such that the target at
    step ``i`` is the input from ``k`` steps earlier:
    ``target[k:] == input[:-k]`` (and ``target == input`` when ``k == 0``).

    Args:
        length: number of timesteps in the returned tensors.
        k: delay (in steps) that a model must memorize; must be >= 0.

    Returns:
        Tuple ``(input, target)`` of int64 tensors of shape ``(length,)``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    seq = torch.randint(low=0, high=2, size=[length + k])
    # input[i] = seq[i + k], target[i] = seq[i].  Using seq[:length] for the
    # target handles k == 0 and k > 0 uniformly (seq[:-k] would be empty for
    # k == 0), so no branch is needed.
    return seq[k:], seq[:length]
103 changes: 103 additions & 0 deletions src/recursion/models/trivial_memory_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from collections import defaultdict

import torch
from torch import Tensor, nn
from torch.optim import SGD

from recursion.dataset.repeat_after_k import make_sequence


class TrivialMemoryModel(nn.Module):
    """Minimal recurrent cell mapping (input, memory) -> next memory.

    A two-layer MLP applied to the concatenation of the current input (last
    dimension 1) and the current memory vector (last dimension
    ``memory_dim``).  It returns only the updated memory; an external head
    (e.g. ``nn.Linear(memory_dim, 1)``) is expected to produce predictions
    from it.
    """

    def __init__(self, memory_dim: int):
        super().__init__()

        # Hidden layer is twice the input width (1 scalar + memory_dim).
        hidden_size = 2 * (1 + memory_dim)
        self.fc1 = nn.Linear(1 + memory_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, memory_dim)
        self.relu = nn.ReLU()

    def forward(self, input: Tensor, memory: Tensor) -> Tensor:
        """Return the updated memory, last dimension ``memory_dim``.

        Args:
            input: tensor whose last dimension is 1 (the current bit).
            memory: tensor whose last dimension is ``memory_dim``.
        """
        x = torch.cat([input, memory], dim=-1)
        x = self.relu(self.fc1(x))
        # No output activation: the memory may take any real value.
        return self.fc2(x)


# --- Experiment script: step a tiny recurrent memory model through a
# "repeat after k" sequence and, every `update_every` steps, manually
# backpropagate the last loss through time, collecting one gradient per
# unrolled step and aggregating them with UPGrad.
# (Indentation reconstructed: the diff paste stripped leading whitespace.)

input_sequence, target_sequence = make_sequence(7, 3)

memory_dim = 8

model = TrivialMemoryModel(memory_dim)
# Readout head: maps the memory vector to a single logit.
head = nn.Linear(memory_dim, 1)
memory = torch.randn(memory_dim)
criterion = nn.BCEWithLogitsLoss()
optimizer = SGD(model.parameters(), lr=1e-2)
memories = []  # memory produced at each step (still attached to its graph)
memories_wrt = []  # detached, requires-grad copies of each incoming memory

param_to_gradients = defaultdict(list)
torch.set_printoptions(linewidth=200)
update_every = 6

# NOTE(review): mid-file import — should move to the top with the others.
from torchjd.aggregation import UPGradWeighting

weighting = UPGradWeighting()

for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)):
    # Detach the incoming memory so each step owns a separate graph segment;
    # gradients are chained across segments manually below.
    memories_wrt.append(memory.detach().requires_grad_(True))
    memory = model(input.unsqueeze(0).to(dtype=torch.float32), memories_wrt[-1])
    output = head(memory)
    loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32))
    memories.append(memory)

    print(f"{loss.item():.1e}")

    if (i + 1) % update_every == 0:
        # NOTE(review): zero_grad() has no visible effect here — no .grad
        # field is ever populated; gradients are collected manually below.
        optimizer.zero_grad()

        # Seed of the backward chain: d(loss)/d(memory at the last step).
        grad_output = torch.autograd.grad(loss, [memories[-1]])

        # Walk backwards through the last `update_every` steps, chaining
        # vector-Jacobian products through the detached memory copies.
        for j in range(update_every):
            print(j)
            grads = torch.autograd.grad(
                memories[-j - 1],
                list(model.parameters()) + [memories_wrt[-j - 1]],
                grad_outputs=grad_output,
            )
            # Last entry is the grad w.r.t. the incoming (detached) memory;
            # it seeds the previous step's vector-Jacobian product.
            grads_wrt_params = grads[:-1]
            grad_output = grads[-1]
Copy link
Copy Markdown

@PierreQuinton PierreQuinton Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be possible to clone the parameters of the memory model at each call, it should not require more memory. But then if we do backward we obtain a grad for each of the copies, we can stack them. Of course this also works and later on we can also make this quite efficient with hooks.

Copy link
Copy Markdown

@PierreQuinton PierreQuinton Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess there is no actual training in this code? (`.grad` is never assigned, so `optimizer.step()` has nothing to apply)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be possible to clone the parameters of the memory model at each call, it should not require more memory. But then if we do backward we obtain a grad for each of the copies, we can stack them. Of course this also works and later on we can also make this quite efficient with hooks.

I think the current method is almost maximally efficient. But maybe it's not expressive enough (for now, it can't really select only paths of length 1, 2, 4, 8, etc., without also computing 3, 5, 6, 7, ...).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could maybe do what you say with a detached view of the parameters (I think cloning duplicates memory + is differentiable so the gradients would flow back to the original params)

Copy link
Copy Markdown

@PierreQuinton PierreQuinton Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Selecting only some paths is doable only with a residual RNN. But note that if you select only the path to the level-2 memory, then you don't train the interaction between level 1 and level 2, which is typically not what we want.


            # Record this step's per-parameter gradients; each unrolled step
            # later becomes one row of the Jacobian matrix.
            for param, grad in zip(model.parameters(), grads_wrt_params, strict=True):
                param_to_gradients[param].append(grad)

        # One (steps x numel) matrix per parameter, concatenated horizontally
        # into the full per-step Jacobian of the loss w.r.t. all parameters.
        param_to_jacobian_matrix = {
            param: torch.stack([g.flatten() for g in gradients], dim=0)
            for param, gradients in param_to_gradients.items()
        }
        jacobian_matrix = torch.cat([mat for mat in param_to_jacobian_matrix.values()], dim=1)

        # Gram matrix of the per-step gradients; the UPGrad weighting turns
        # it into aggregation weights over the unrolled steps.
        gramian = jacobian_matrix @ jacobian_matrix.T
        weights = weighting(gramian)
        # print(jacobian_matrix.shape)
        print(gramian)
        print(weights)

        # graph = make_dot(loss, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
        # graph.view()

        # graph = make_dot(attached_memories[-1], params=dict(model.named_parameters()), show_attrs=True,
        # show_saved=True)
        # graph.view()

        # loss.backward()

        # print("fc1 weights: ", model.fc1.weight.grad)
        # print("fc1 biases: ", model.fc1.bias.grad)
        #
        # print("fc2 weights: ", model.fc2.weight.grad)
        # print("fc2 biases: ", model.fc2.bias.grad)

        # NOTE(review): step() is a no-op here — param.grad is never set,
        # and the aggregated `weights` are computed but never applied.
        optimizer.step()
        memory = memory.detach()