We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents b5ca226 + f30a835 commit 0baa914Copy full SHA for 0baa914
1 file changed
src/torchjd/_linalg/_gramian.py
@@ -62,11 +62,12 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
62
sqrt of the sum of the diagonal elements. The gramian of the (Frobenius) normalization of `A` is
63
therefore `G` divided by the sum of its diagonal elements.
64
"""
65
+
66
squared_frobenius_norm = gramian.diagonal().sum()
- if squared_frobenius_norm < eps:
67
- output = torch.zeros_like(gramian)
68
- else:
69
- output = gramian / squared_frobenius_norm
+ condition = squared_frobenius_norm < eps
+ # Use torch.where rather than a if-else to avoid cuda synchronization.
70
+ output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm)
71
return cast(PSDMatrix, output)
72
73
0 commit comments