Skip to content

Commit 05ce819

Browse files
committed
Highlights matrix multiplication
1 parent 1f0afca commit 05ce819

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/torchjd/aggregation/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def forward(self, matrix: Tensor) -> Tensor:
7575
else:
7676
unit_target_vector = best_direction / best_direction.norm()
7777

78-
length = torch.sum(torch.stack([torch.dot(grad, unit_target_vector) for grad in matrix]))
78+
length = torch.sum(matrix @ unit_target_vector)
7979

8080
return length * unit_target_vector
8181

0 commit comments

Comments
 (0)