-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy path_gramian_accumulator.py
More file actions
36 lines (26 loc) · 1.1 KB
/
_gramian_accumulator.py
File metadata and controls
36 lines (26 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Optional
from torchjd._linalg import PSDMatrix
class GramianAccumulator:
    """
    Keeps a running sum of Gramian matrices.

    Meant to be used while differentiating in reverse mode: each Gramian handed to
    :meth:`accumulate_gramian` is folded into a single total, so the individual matrices do not
    need to be kept around by the accumulator.

    The first Gramian received after construction (or after :meth:`reset`) is stored as-is, without
    being copied; every later one is added to that same object through its ``add_`` method.
    """

    def __init__(self) -> None:
        # ``None`` means that nothing has been accumulated yet.
        self._gramian: Optional[PSDMatrix] = None

    def reset(self) -> None:
        """Forget the accumulated total, returning to the "nothing accumulated" state."""

        self._gramian = None

    def accumulate_gramian(self, gramian: PSDMatrix) -> None:
        """
        Add a Gramian to the running total.

        :param gramian: The matrix to accumulate. If it is the first one, the accumulator keeps a
            reference to this very object and later calls will invoke ``add_`` on it.
        """

        running_total = self._gramian

        # First contribution: adopt it directly as the total instead of allocating a new matrix.
        if running_total is None:
            self._gramian = gramian
            return

        running_total.add_(gramian)

    @property
    def gramian(self) -> Optional[PSDMatrix]:
        """
        The total accumulated so far.

        :returns: Accumulated Gramian matrix of shape (batch_size, batch_size) or None if nothing
            was accumulated yet.
        """

        return self._gramian