Skip to content

Commit 7fb2c75

Browse files
committed
perf(aggregation): Prevent cuda sync in normalize
1 parent 5ec9ada commit 7fb2c75

1 file changed

Lines changed: 2 additions & 4 deletions

File tree

src/torchjd/_linalg/_gramian.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,8 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
5252
therefore `G` divided by the sum of its diagonal elements.
5353
"""
5454
squared_frobenius_norm = gramian.diagonal().sum()
55-
if squared_frobenius_norm < eps:
56-
output = torch.zeros_like(gramian)
57-
else:
58-
output = gramian / squared_frobenius_norm
55+
condition = squared_frobenius_norm < eps
56+
output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm)
5957
return cast(PSDMatrix, output)
6058

6159

0 commit comments

Comments
 (0)