Skip to content

Commit 2ce5839

Browse files
committed
Highlights vector normalization
1 parent 05ce819 commit 2ce5839

1 file changed

Lines changed: 1 addition & 4 deletions

File tree

src/torchjd/aggregation/config.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ def forward(self, matrix: Tensor) -> Tensor:
7070
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
7171
best_direction = torch.linalg.pinv(units) @ weights
7272

73-
if best_direction.norm() == 0:
74-
unit_target_vector = torch.zeros_like(best_direction)
75-
else:
76-
unit_target_vector = best_direction / best_direction.norm()
73+
unit_target_vector = torch.nn.functional.normalize(best_direction)
7774

7875
length = torch.sum(matrix @ unit_target_vector)
7976

0 commit comments

Comments
 (0)