Skip to content

Commit 878bf9b

Browse files
authored
refactor(autogram): Rename JacobianAccumulator (#471)
* Rename JacobianAccumulator to AutogramNode * Update its docstring
1 parent 165c11b commit 878bf9b

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def hook_module(self, module: nn.Module, gramian_computer: GramianComputer) -> N
5353
"""
5454
Add a module hook used to insert Jacobian accumulation nodes into the backward graph.
5555
56-
The hook injects a JacobianAccumulator function into the computation graph after the module,
56+
The hook injects an AutogramNode function into the computation graph after the module,
5757
enabling Gramian computation.
5858
"""
5959

@@ -125,14 +125,14 @@ def __call__(
125125

126126
self.gramian_computer.track_forward_call()
127127

128-
# We only care about running the JacobianAccumulator node, so we need one of its child
128+
# We only care about running the AutogramNode, so we need one of its child
129129
# edges (the edges of the original outputs of the model) as target. For memory
130130
# efficiency, we select the smallest one (that requires grad).
131131
preference = torch.tensor([t.numel() for t in rg_outputs])
132132
index = cast(int, preference.argmin().item())
133133
self.target_edges.register(get_gradient_edge(rg_outputs[index]))
134134

135-
autograd_fn_rg_outputs = JacobianAccumulator.apply(
135+
autograd_fn_rg_outputs = AutogramNode.apply(
136136
self.gramian_accumulation_phase,
137137
self.gramian_computer,
138138
args,
@@ -147,13 +147,10 @@ def __call__(
147147
return tree_unflatten(flat_outputs, output_spec)
148148

149149

150-
class JacobianAccumulator(torch.autograd.Function):
150+
class AutogramNode(torch.autograd.Function):
151151
"""
152-
Autograd function that accumulates Jacobian Gramians during the first backward pass.
153-
154-
Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
155-
of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
156-
toggle mechanism to activate only during the Gramian accumulation phase.
152+
Autograd function that is the identity on the forward pass and that launches the computation
153+
and accumulation of the Gramian on the backward pass.
157154
"""
158155

159156
generate_vmap_rule = True

0 commit comments

Comments
 (0)