Skip to content

Commit a48b0dd

Browse files
committed
Optimize compute_gramian for when contracted_dims=-1
1 parent 453971a commit a48b0dd

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/torchjd/_linalg/_gramian.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
3030
first dimension).
3131
"""
3232

33-
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
34-
indices_source = list(range(t.ndim - contracted_dims))
35-
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
36-
transposed = t.movedim(indices_source, indices_dest)
37-
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
33+
# Optimization: it's faster to do that than moving dims and using tensordot, and this case
34+
# happens very often, sometimes hundreds of times for a single jac_to_grad.
35+
if contracted_dims == -1:
36+
matrix = t.flatten(start_dim=1)
37+
gramian = matrix @ matrix.T
38+
39+
else:
40+
contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim
41+
indices_source = list(range(t.ndim - contracted_dims))
42+
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
43+
transposed = t.movedim(indices_source, indices_dest)
44+
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
3845
return cast(PSDTensor, gramian)
3946

4047

0 commit comments

Comments
 (0)