Skip to content

Commit 62bda0c

Browse files
committed
Add and use JacobianBasedGramianComputerWithoutCrossTerms
1 parent 4676f11 commit 62bda0c

File tree

2 files changed

+7
-33
lines changed

2 files changed

+7
-33
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ._edge_registry import EdgeRegistry
88
from ._gramian_accumulator import GramianAccumulator
9-
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms
9+
from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithoutCrossTerms
1010
from ._gramian_utils import movedim_gramian, reshape_gramian
1111
from ._jacobian_computer import (
1212
AutogradJacobianComputer,
@@ -205,9 +205,10 @@ def _make_gramian_computer(self, module: nn.Module) -> GramianComputer:
205205
jacobian_computer: JacobianComputer
206206
if self._batch_dim is not None:
207207
jacobian_computer = FunctionalJacobianComputer(module)
208+
gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)
208209
else:
209210
jacobian_computer = AutogradJacobianComputer(module)
210-
gramian_computer = JacobianBasedGramianComputerWithCrossTerms(jacobian_computer)
211+
gramian_computer = JacobianBasedGramianComputerWithoutCrossTerms(jacobian_computer)
211212

212213
return gramian_computer
213214

src/torchjd/autogram/_gramian_computer.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from torch import Tensor
55
from torch.utils._pytree import PyTree
66

7-
from torchjd.autogram._jacobian_computer import JacobianComputer
8-
97

108
class GramianComputer(ABC):
119
@abstractmethod
@@ -34,24 +32,12 @@ def _to_gramian(jacobian: Tensor) -> Tensor:
3432
return jacobian @ jacobian.T
3533

3634

37-
class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
35+
class JacobianBasedGramianComputerWithoutCrossTerms(JacobianBasedGramianComputer):
3836
"""
39-
Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
40-
the gramian.
37+
Stateful JacobianBasedGramianComputer that directly returning the gramian without considering
38+
cross-terms (except intra-module cross-terms).
4139
"""
4240

43-
def __init__(self, jacobian_computer: JacobianComputer):
44-
super().__init__(jacobian_computer)
45-
self.remaining_counter = 0
46-
self.summed_jacobian: Optional[Tensor] = None
47-
48-
def reset(self) -> None:
49-
self.remaining_counter = 0
50-
self.summed_jacobian = None
51-
52-
def track_forward_call(self) -> None:
53-
self.remaining_counter += 1
54-
5541
def __call__(
5642
self,
5743
rg_outputs: tuple[Tensor, ...],
@@ -62,17 +48,4 @@ def __call__(
6248
"""Compute what we can for a module and optionally return the gramian if it's ready."""
6349

6450
jacobian_matrix = self.jacobian_computer(rg_outputs, grad_outputs, args, kwargs)
65-
66-
if self.summed_jacobian is None:
67-
self.summed_jacobian = jacobian_matrix
68-
else:
69-
self.summed_jacobian += jacobian_matrix
70-
71-
self.remaining_counter -= 1
72-
73-
if self.remaining_counter == 0:
74-
gramian = self._to_gramian(self.summed_jacobian)
75-
del self.summed_jacobian
76-
return gramian
77-
else:
78-
return None
51+
return self._to_gramian(jacobian_matrix)

0 commit comments

Comments
 (0)