Skip to content

Commit 526681c

Browse files
committed
Make dependence on gramian explicit in PCGrad
1 parent e962c99 commit 526681c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/torchjd/aggregation/pcgrad.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ class _PCGradWeighting(_Weighting):
4141

4242
def forward(self, matrix: Tensor) -> Tensor:
4343
# Pre-compute the inner products
44-
inner_products = matrix @ matrix.T
44+
gramian = matrix @ matrix.T
45+
return self._compute_from_gramian(gramian)
4546

47+
@staticmethod
48+
def _compute_from_gramian(inner_products: Tensor) -> Tensor:
4649
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
47-
device = matrix.device
48-
dtype = matrix.dtype
50+
device = inner_products.device
51+
dtype = inner_products.dtype
4952
cpu = torch.device("cpu")
5053
inner_products = inner_products.to(device=cpu)
5154

0 commit comments

Comments
 (0)