We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e962c99 commit 526681cCopy full SHA for 526681c
src/torchjd/aggregation/pcgrad.py
@@ -41,11 +41,14 @@ class _PCGradWeighting(_Weighting):
41
42
def forward(self, matrix: Tensor) -> Tensor:
43
# Pre-compute the inner products
44
- inner_products = matrix @ matrix.T
+ gramian = matrix @ matrix.T
45
+ return self._compute_from_gramian(gramian)
46
47
+ @staticmethod
48
+ def _compute_from_gramian(inner_products: Tensor) -> Tensor:
49
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
- device = matrix.device
- dtype = matrix.dtype
50
+ device = inner_products.device
51
+ dtype = inner_products.dtype
52
cpu = torch.device("cpu")
53
inner_products = inner_products.to(device=cpu)
54
0 commit comments