We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 05ce819 commit 2ce5839Copy full SHA for 2ce5839
1 file changed
src/torchjd/aggregation/config.py
@@ -70,10 +70,7 @@ def forward(self, matrix: Tensor) -> Tensor:
70
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
71
best_direction = torch.linalg.pinv(units) @ weights
72
73
- if best_direction.norm() == 0:
74
- unit_target_vector = torch.zeros_like(best_direction)
75
- else:
76
- unit_target_vector = best_direction / best_direction.norm()
+ unit_target_vector = torch.nn.functional.normalize(best_direction)
77
78
length = torch.sum(matrix @ unit_target_vector)
79
0 commit comments