Skip to content

Commit f44da13

Browse files
authored
Merge branch 'main' into revamp-config
2 parents 083a6b4 + 1dcd85a commit f44da13

30 files changed

Lines changed: 337 additions & 281 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
- id: check-merge-conflict # Check for files that contain merge conflict strings.
1111

1212
- repo: https://github.com/PyCQA/flake8
13-
rev: 7.1.2
13+
rev: 7.2.0
1414
hooks:
1515
- id: flake8 # Check style and syntax. Does not modify code, issues have to be solved manually.
1616
args: [

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@ changes that do not affect the user.
2121
and `mtl_backward`.
2222

2323
### Fixed
24+
2425
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they
2526
appear several times in a list of tensors provided as argument). Instead of raising an exception
2627
in these cases, we are now aligned with the behavior of `torch.autograd.backward`. Repeated
2728
tensors that we differentiate lead to repeated rows in the Jacobian, prior to aggregation, and
2829
repeated tensors with respect to which we differentiate count only once.
30+
- Fixed an issue with `backward` and `mtl_backward` that could make the ordering of the columns of
31+
the Jacobians non-deterministic, and that could thus lead to slightly non-deterministic results
32+
with some aggregators.
2933
- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
3034
practice, this fix should only affect some matrices with extremely large values, which should
3135
not usually happen.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/TorchJD/torchjd/main.svg)](https://results.pre-commit.ci/latest/github/TorchJD/torchjd/main)
77
[![PyPI - Downloads](https://img.shields.io/pypi/dm/torchjd)](https://pypistats.org/packages/torchjd)
88
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchjd)](https://pypi.org/project/torchjd/)
9+
[![Static Badge](https://img.shields.io/badge/Discord%20-%20community%20-%20%235865F2?logo=discord&logoColor=%23FFFFFF&label=Discord)](https://discord.gg/76KkRnb3nk)
910

1011
TorchJD is a library extending autograd to enable
1112
[Jacobian descent](https://arxiv.org/pdf/2406.16232) with PyTorch. It can be used to train neural

src/torchjd/autojac/_transform/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .accumulate import Accumulate
22
from .aggregate import Aggregate
3-
from .base import Composition, Conjunction, Transform
3+
from .base import Composition, Conjunction, RequirementError, Transform
44
from .diagonalize import Diagonalize
55
from .grad import Grad
66
from .init import Init

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
from torch import Tensor
55

6-
from ._utils import ordered_set
7-
from .base import _A, Transform
6+
from .base import _A, RequirementError, Transform
7+
from .ordered_set import OrderedSet
88

99

1010
class _Differentiate(Transform[_A, _A], ABC):
@@ -16,7 +16,7 @@ def __init__(
1616
create_graph: bool,
1717
):
1818
self.outputs = list(outputs)
19-
self.inputs = ordered_set(inputs)
19+
self.inputs = OrderedSet(inputs)
2020
self.retain_graph = retain_graph
2121
self.create_graph = create_graph
2222

@@ -38,6 +38,11 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]
3838
tensor_outputs should be.
3939
"""
4040

41-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
42-
# outputs in the forward direction become inputs in the backward direction, and vice-versa
43-
return set(self.outputs), set(self.inputs)
41+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
42+
outputs = set(self.outputs)
43+
if not outputs == input_keys:
44+
raise RequirementError(
45+
f"The input_keys must match the expected outputs. Found input_keys {input_keys} and "
46+
f"outputs {outputs}."
47+
)
48+
return set(self.inputs)

src/torchjd/autojac/_transform/_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Hashable, Iterable, Sequence, TypeAlias, TypeVar
1+
from typing import Hashable, Iterable, Sequence, TypeVar
32

43
import torch
54
from torch import Tensor
@@ -8,17 +7,12 @@
87

98
_KeyType = TypeVar("_KeyType", bound=Hashable)
109
_ValueType = TypeVar("_ValueType")
11-
_OrderedSet: TypeAlias = OrderedDict[_KeyType, None]
1210

1311
_A = TypeVar("_A", bound=TensorDict)
1412
_B = TypeVar("_B", bound=TensorDict)
1513
_C = TypeVar("_C", bound=TensorDict)
1614

1715

18-
def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
19-
return OrderedDict.fromkeys(elements, None)
20-
21-
2216
def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
2317
result = {}
2418
for d in dicts:

src/torchjd/autojac/_transform/accumulate.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1-
from typing import Iterable
2-
31
from torch import Tensor
42

53
from .base import Transform
64
from .tensor_dict import EmptyTensorDict, Gradients
75

86

97
class Accumulate(Transform[Gradients, EmptyTensorDict]):
10-
def __init__(self, required_keys: Iterable[Tensor]):
11-
self._required_keys = set(required_keys)
12-
138
def __call__(self, gradients: Gradients) -> EmptyTensorDict:
149
"""
1510
Accumulates gradients with respect to keys in their ``.grad`` field.
@@ -28,8 +23,8 @@ def __call__(self, gradients: Gradients) -> EmptyTensorDict:
2823

2924
return EmptyTensorDict()
3025

31-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
32-
return self._required_keys, set()
26+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
27+
return set()
3328

3429

3530
def _check_expects_grad(tensor: Tensor) -> None:

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
11
from collections import OrderedDict
2-
from typing import Hashable, Iterable, TypeVar
2+
from typing import Hashable, TypeVar
33

44
import torch
55
from torch import Tensor
66

77
from torchjd.aggregation import Aggregator
88

9-
from ._utils import _OrderedSet, ordered_set
10-
from .base import Transform
9+
from .base import RequirementError, Transform
10+
from .ordered_set import OrderedSet
1111
from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
1212

1313
_KeyType = TypeVar("_KeyType", bound=Hashable)
1414
_ValueType = TypeVar("_ValueType")
1515

1616

1717
class Aggregate(Transform[Jacobians, Gradients]):
18-
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
19-
matrixify = _Matrixify(key_order)
18+
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
19+
matrixify = _Matrixify()
2020
aggregate_matrices = _AggregateMatrices(aggregator, key_order)
21-
reshape = _Reshape(key_order)
21+
reshape = _Reshape()
2222

2323
self._aggregator_str = str(aggregator)
2424
self.transform = reshape << aggregate_matrices << matrixify
2525

2626
def __call__(self, input: Jacobians) -> Gradients:
2727
return self.transform(input)
2828

29-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
30-
return self.transform.check_and_get_keys()
29+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
30+
return self.transform.check_keys(input_keys)
3131

3232

3333
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
34-
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
35-
self.key_order = ordered_set(key_order)
34+
def __init__(self, aggregator: Aggregator, key_order: OrderedSet[Tensor]):
35+
self.key_order = OrderedSet(key_order)
3636
self.aggregator = aggregator
3737

3838
def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
@@ -48,13 +48,17 @@ def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
4848
ordered_matrices = self._select_ordered_subdict(jacobian_matrices, self.key_order)
4949
return self._aggregate_group(ordered_matrices, self.aggregator)
5050

51-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
52-
keys = set(self.key_order)
53-
return keys, keys
51+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
52+
if not set(self.key_order) == input_keys:
53+
raise RequirementError(
54+
f"The input_keys must match the key_order. Found input_keys {input_keys} and"
55+
f"key_order {self.key_order}."
56+
)
57+
return input_keys
5458

5559
@staticmethod
5660
def _select_ordered_subdict(
57-
dictionary: dict[_KeyType, _ValueType], ordered_keys: _OrderedSet[_KeyType]
61+
dictionary: dict[_KeyType, _ValueType], ordered_keys: OrderedSet[_KeyType]
5862
) -> OrderedDict[_KeyType, _ValueType]:
5963
"""
6064
Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``.
@@ -108,29 +112,23 @@ def _disunite(
108112

109113

110114
class _Matrixify(Transform[Jacobians, JacobianMatrices]):
111-
def __init__(self, required_keys: Iterable[Tensor]):
112-
self._required_keys = set(required_keys)
113-
114115
def __call__(self, jacobians: Jacobians) -> JacobianMatrices:
115116
jacobian_matrices = {
116117
key: jacobian.view(jacobian.shape[0], -1) for key, jacobian in jacobians.items()
117118
}
118119
return JacobianMatrices(jacobian_matrices)
119120

120-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
121-
return self._required_keys, self._required_keys
121+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
122+
return input_keys
122123

123124

124125
class _Reshape(Transform[GradientVectors, Gradients]):
125-
def __init__(self, required_keys: Iterable[Tensor]):
126-
self._required_keys = set(required_keys)
127-
128126
def __call__(self, gradient_vectors: GradientVectors) -> Gradients:
129127
gradients = {
130128
key: gradient_vector.view(key.shape)
131129
for key, gradient_vector in gradient_vectors.items()
132130
}
133131
return Gradients(gradients)
134132

135-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
136-
return self._required_keys, self._required_keys
133+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
134+
return input_keys

src/torchjd/autojac/_transform/base.py

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from ._utils import _A, _B, _C, _union
99

1010

11+
class RequirementError(ValueError):
12+
"""Inappropriate set of input keys."""
13+
14+
pass
15+
16+
1117
class Transform(Generic[_B, _C], ABC):
1218
r"""
1319
Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
@@ -44,17 +50,16 @@ def __call__(self, input: _B) -> _C:
4450
"""Applies the transform to the input."""
4551

4652
@abstractmethod
47-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
53+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
4854
"""
49-
Returns a pair containing (in order) the required keys and the output keys of the Transform
50-
and recursively checks that the transform is valid.
55+
Checks that the provided input_keys satisfy the transform's requirements and returns the
56+
corresponding output keys for recursion.
5157
52-
The required keys are the set of keys that the transform requires to be present in its input
53-
TensorDicts. The output keys are the set of keys that will be present in the output
54-
TensorDicts of the transform.
58+
If the provided input_keys do not satisfy the transform's requirements, raises a
59+
RequirementError.
5560
56-
Since the computation of the required and output keys and the verification that the
57-
transform is valid are sometimes intertwined operations, we do them in a single method.
61+
The output keys are the set of keys of the output TensorDict of the transform when the input
62+
TensorDict's keys are input_keys.
5863
"""
5964

6065
__lshift__ = compose
@@ -73,15 +78,10 @@ def __call__(self, input: _A) -> _C:
7378
intermediate = self.inner(input)
7479
return self.outer(intermediate)
7580

76-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
77-
outer_required_keys, outer_output_keys = self.outer.check_and_get_keys()
78-
inner_required_keys, inner_output_keys = self.inner.check_and_get_keys()
79-
if outer_required_keys != inner_output_keys:
80-
raise ValueError(
81-
"The `output_keys` of `inner` must match with the `required_keys` of "
82-
f"outer. Found {outer_required_keys} and {inner_output_keys}"
83-
)
84-
return inner_required_keys, outer_output_keys
81+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
82+
intermediate_keys = self.inner.check_keys(input_keys)
83+
output_keys = self.outer.check_keys(intermediate_keys)
84+
return output_keys
8585

8686

8787
class Conjunction(Transform[_A, _B]):
@@ -102,18 +102,11 @@ def __call__(self, tensor_dict: _A) -> _B:
102102
output = _union([transform(tensor_dict) for transform in self.transforms])
103103
return output
104104

105-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
106-
keys_pairs = [transform.check_and_get_keys() for transform in self.transforms]
107-
108-
required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
109-
for transform_required_keys, _ in keys_pairs:
110-
if transform_required_keys != required_keys:
111-
raise ValueError("All transforms should require the same set of keys.")
112-
113-
output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys]
114-
output_keys = set(output_keys_with_duplicates)
105+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
106+
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
107+
output_keys = set(output_keys_list)
115108

116-
if len(output_keys) != len(output_keys_with_duplicates):
117-
raise ValueError("The sets of output keys of transforms should be disjoint.")
109+
if len(output_keys) != len(output_keys_list):
110+
raise RequirementError("The sets of output keys of transforms should be disjoint.")
118111

119-
return required_keys, output_keys
112+
return output_keys

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import torch
44
from torch import Tensor
55

6-
from ._utils import ordered_set
7-
from .base import Transform
6+
from .base import RequirementError, Transform
7+
from .ordered_set import OrderedSet
88
from .tensor_dict import Gradients, Jacobians
99

1010

1111
class Diagonalize(Transform[Gradients, Jacobians]):
1212
def __init__(self, considered: Iterable[Tensor]):
13-
self.considered = ordered_set(considered)
13+
self.considered = OrderedSet(considered)
1414
self.indices: list[tuple[int, int]] = []
1515
begin = 0
1616
for tensor in self.considered:
@@ -27,6 +27,11 @@ def __call__(self, tensors: Gradients) -> Jacobians:
2727
}
2828
return Jacobians(diagonalized_tensors)
2929

30-
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
31-
keys = set(self.considered)
32-
return keys, keys
30+
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
31+
considered = set(self.considered)
32+
if not considered == input_keys:
33+
raise RequirementError(
34+
f"The input_keys must match the considered keys. Found input_keys {input_keys} and"
35+
f"considered keys {considered}."
36+
)
37+
return considered

0 commit comments

Comments
 (0)