
Commit f8e8a5e

refactor(aggregation): Improve ConFIG implementation (#281)
* Improve the implementation of the matrix-vector product to simply use a single `@`.
* Change the vector normalization to use `torch.nn.functional.normalize` with the default epsilon of 1e-12. This changes the aggregator's output for very uncertain vectors to be zero instead.
* Add a changelog entry.
1 parent 1dcd85a commit f8e8a5e

2 files changed: 4 additions & 5 deletions
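The normalization change can be illustrated in isolation. This is a standalone sketch, not code from the repository; the vector here is hypothetical. With the default `eps=1e-12`, `torch.nn.functional.normalize` clamps the divisor, so a zero vector is mapped to the zero vector instead of NaNs:

```python
import torch
import torch.nn.functional as F

v = torch.zeros(3)

# Manual normalization divides by a zero norm, producing NaNs:
manual = v / v.norm()
print(torch.isnan(manual).all())  # True

# F.normalize divides by max(norm, eps) with eps=1e-12 by default,
# so a zero vector stays the zero vector:
safe = F.normalize(v, dim=0)
print(torch.equal(safe, torch.zeros(3)))  # True
```

This is why the explicit `if best_direction.norm() == 0` branch in the old code becomes unnecessary.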

File tree

- CHANGELOG.md
- src/torchjd/aggregation/config.py

CHANGELOG.md: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ changes that do not affect the user.
 - Refactored internal verifications in the autojac engine so that they do not run at runtime
   anymore. This should minimally improve the performance and reduce the memory usage of `backward`
   and `mtl_backward`.
+- Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It
+  should slightly improve the performance of `ConFIG` and minimally affect its behavior.
 
 ### Fixed
 

src/torchjd/aggregation/config.py: 2 additions & 5 deletions

@@ -70,12 +70,9 @@ def forward(self, matrix: Tensor) -> Tensor:
         units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
         best_direction = torch.linalg.pinv(units) @ weights
 
-        if best_direction.norm() == 0:
-            unit_target_vector = torch.zeros_like(best_direction)
-        else:
-            unit_target_vector = best_direction / best_direction.norm()
+        unit_target_vector = torch.nn.functional.normalize(best_direction, dim=0)
 
-        length = torch.sum(torch.stack([torch.dot(grad, unit_target_vector) for grad in matrix]))
+        length = torch.sum(matrix @ unit_target_vector)
 
         return length * unit_target_vector
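The replacement of the per-row dot products with a single `@` can be sanity-checked in isolation. The shapes below are made up for illustration (4 gradients of dimension 6), not taken from the repository:

```python
import torch

matrix = torch.randn(4, 6)
unit_target_vector = torch.nn.functional.normalize(torch.randn(6), dim=0)

# Old: one dot product per row, stacked and summed.
old = torch.sum(torch.stack([torch.dot(grad, unit_target_vector) for grad in matrix]))

# New: a single matrix-vector product, then a sum.
new = torch.sum(matrix @ unit_target_vector)

print(torch.allclose(old, new))  # True
```

Both compute the sum of the projections of each gradient onto the target direction; the `@` form avoids the Python-level loop and the intermediate stacked tensor.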

0 commit comments
