Skip to content

Commit 0baa914

Browse files
authored
Merge branch 'main' into optimize_jac_to_grad
2 parents b5ca226 + f30a835 commit 0baa914

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/torchjd/_linalg/_gramian.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,12 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
6262
sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
6363
therefore `G` divided by the sum of its diagonal elements.
6464
"""
65+
6566
squared_frobenius_norm = gramian.diagonal().sum()
66-
if squared_frobenius_norm < eps:
67-
output = torch.zeros_like(gramian)
68-
else:
69-
output = gramian / squared_frobenius_norm
67+
condition = squared_frobenius_norm < eps
68+
69+
# Use torch.where rather than a if-else to avoid cuda synchronization.
70+
output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm)
7071
return cast(PSDMatrix, output)
7172

7273

0 commit comments

Comments
 (0)