File tree Expand file tree Collapse file tree 1 file changed +12
-5
lines changed
Expand file tree Collapse file tree 1 file changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -30,11 +30,18 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
3030 first dimension).
3131 """
3232
33- contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t .ndim
34- indices_source = list (range (t .ndim - contracted_dims ))
35- indices_dest = list (range (t .ndim - 1 , contracted_dims - 1 , - 1 ))
36- transposed = t .movedim (indices_source , indices_dest )
37- gramian = torch .tensordot (t , transposed , dims = contracted_dims )
33+ # Optimization: it's faster to do that than moving dims and using tensordot, and this case
34+ # happens very often, sometimes hundreds of times for a single jac_to_grad.
35+ if contracted_dims == - 1 :
36+ matrix = t .flatten (start_dim = 1 )
37+ gramian = matrix @ matrix .T
38+
39+ else :
40+ contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t .ndim
41+ indices_source = list (range (t .ndim - contracted_dims ))
42+ indices_dest = list (range (t .ndim - 1 , contracted_dims - 1 , - 1 ))
43+ transposed = t .movedim (indices_source , indices_dest )
44+ gramian = torch .tensordot (t , transposed , dims = contracted_dims )
3845 return cast (PSDTensor , gramian )
3946
4047
You can’t perform that action at this time.
0 commit comments