Skip to content

Commit 5e6db83

Browse files
authored
Merge branch 'main' into fix-diagonalize-key-order
2 parents bb69210 + f8e8a5e commit 5e6db83

4 files changed

Lines changed: 6 additions & 6 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
- id: check-merge-conflict # Check for files that contain merge conflict strings.
1111

1212
- repo: https://github.com/PyCQA/flake8
13-
rev: 7.1.2
13+
rev: 7.2.0
1414
hooks:
1515
- id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
1616
args: [

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ changes that do not affect the user.
1919
- Refactored internal verifications in the autojac engine so that they do not run at runtime
2020
anymore. This should minimally improve the performance and reduce the memory usage of `backward`
2121
and `mtl_backward`.
22+
- Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It
23+
should slightly improve the performance of `ConFIG` and minimally affect its behavior.
2224

2325
### Fixed
2426

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TorchJD/torchjd/main.svg)](https://results.pre-commit.ci/latest/github/TorchJD/torchjd/main)
77
[![PyPI - Downloads](https://img.shields.io/pypi/dm/torchjd)](https://pypistats.org/packages/torchjd)
88
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/)
9+
[![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk)
910

1011
TorchJD is a library extending autograd to enable
1112
[Jacobian descent](https://arxiv.org/pdf/2406.16232) with PyTorch. It can be used to train neural

src/torchjd/aggregation/config.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,9 @@ def forward(self, matrix: Tensor) -> Tensor:
7070
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
7171
best_direction = torch.linalg.pinv(units) @ weights
7272

73-
if best_direction.norm() == 0:
74-
unit_target_vector = torch.zeros_like(best_direction)
75-
else:
76-
unit_target_vector = best_direction / best_direction.norm()
73+
unit_target_vector = torch.nn.functional.normalize(best_direction, dim=0)
7774

78-
length = torch.sum(torch.stack([torch.dot(grad, unit_target_vector) for grad in matrix]))
75+
length = torch.sum(matrix @ unit_target_vector)
7976

8077
return length * unit_target_vector
8178

0 commit comments

Comments
 (0)