We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5ec9ada commit 7fb2c75Copy full SHA for 7fb2c75
1 file changed
src/torchjd/_linalg/_gramian.py
@@ -52,10 +52,8 @@ def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
52
therefore `G` divided by the sum of its diagonal elements.
53
"""
54
squared_frobenius_norm = gramian.diagonal().sum()
55
- if squared_frobenius_norm < eps:
56
- output = torch.zeros_like(gramian)
57
- else:
58
- output = gramian / squared_frobenius_norm
+ condition = squared_frobenius_norm < eps
+ output = torch.where(condition, torch.zeros_like(gramian), gramian / squared_frobenius_norm)
59
return cast(PSDMatrix, output)
60
61
0 commit comments