44from torch import Tensor
55from torch .utils ._pytree import PyTree
66
7- from torchjd .autogram ._jacobian_computer import JacobianComputer
8-
97
108class GramianComputer (ABC ):
119 @abstractmethod
@@ -34,24 +32,12 @@ def _to_gramian(jacobian: Tensor) -> Tensor:
3432 return jacobian @ jacobian .T
3533
3634
37- class JacobianBasedGramianComputerWithCrossTerms (JacobianBasedGramianComputer ):
35+ class JacobianBasedGramianComputerWithoutCrossTerms (JacobianBasedGramianComputer ):
3836 """
39- Stateful JacobianBasedGramianComputer that waits for all usages to be counted before returning
40- the gramian .
37+ Stateful JacobianBasedGramianComputer that directly returning the gramian without considering
38+ cross-terms (except intra-module cross-terms) .
4139 """
4240
43- def __init__ (self , jacobian_computer : JacobianComputer ):
44- super ().__init__ (jacobian_computer )
45- self .remaining_counter = 0
46- self .summed_jacobian : Optional [Tensor ] = None
47-
48- def reset (self ) -> None :
49- self .remaining_counter = 0
50- self .summed_jacobian = None
51-
52- def track_forward_call (self ) -> None :
53- self .remaining_counter += 1
54-
5541 def __call__ (
5642 self ,
5743 rg_outputs : tuple [Tensor , ...],
@@ -62,17 +48,4 @@ def __call__(
6248 """Compute what we can for a module and optionally return the gramian if it's ready."""
6349
6450 jacobian_matrix = self .jacobian_computer (rg_outputs , grad_outputs , args , kwargs )
65-
66- if self .summed_jacobian is None :
67- self .summed_jacobian = jacobian_matrix
68- else :
69- self .summed_jacobian += jacobian_matrix
70-
71- self .remaining_counter -= 1
72-
73- if self .remaining_counter == 0 :
74- gramian = self ._to_gramian (self .summed_jacobian )
75- del self .summed_jacobian
76- return gramian
77- else :
78- return None
51+ return self ._to_gramian (jacobian_matrix )
0 commit comments