Skip to content

Commit 0b85205

Browse files
authored
fix(autojac): Fix non-deterministic column order (#283)
* Replace the `ordered_set` function and the `_OrderedSet` type alias in `_utils.py` by the `OrderedSet` class in `ordered_set.py`.
* Add `difference_update` and `add` to `OrderedSet`.
* Change `_get_descendant_accumulate_grads` and `_get_leaf_tensors` to work with `OrderedSet`s.
* Make `backward` and `mtl_backward` use `OrderedSet` instead of `set` or `list` for the tensors to differentiate. This should make the `_AggregateMatrices` transform use a deterministic `key_order`, which should fix the column order before aggregation.
* Add changelog entry.

Note: we did not verify that the column ordering was actually non-deterministic, and we think that with `mtl_backward` it could only happen when the parameters were not specified by the user. Still, this should make things much safer.
1 parent 33fc0c9 commit 0b85205

10 files changed

Lines changed: 69 additions & 39 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@ changes that do not affect the user.
2121
and `mtl_backward`.
2222

2323
### Fixed
24+
2425
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they
2526
appear several times in a list of tensors provided as argument). Instead of raising an exception
2627
in these cases, we are now aligned with the behavior of `torch.autograd.backward`. Repeated
2728
tensors that we differentiate lead to repeated rows in the Jacobian, prior to aggregation, and
2829
repeated tensors with respect to which we differentiate count only once.
30+
- Fixed an issue with `backward` and `mtl_backward` that could make the ordering of the columns of
31+
the Jacobians non-deterministic, and that could thus lead to slightly non-deterministic results
32+
with some aggregators.
2933
- Removed arbitrary exception handling in `IMTLG` and `AlignedMTL` when the computation fails. In
3034
practice, this fix should only affect some matrices with extremely large values, which should
3135
not usually happen.

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
from torch import Tensor
55

6-
from ._utils import ordered_set
76
from .base import _A, Transform
7+
from .ordered_set import OrderedSet
88

99

1010
class _Differentiate(Transform[_A, _A], ABC):
@@ -16,7 +16,7 @@ def __init__(
1616
create_graph: bool,
1717
):
1818
self.outputs = list(outputs)
19-
self.inputs = ordered_set(inputs)
19+
self.inputs = OrderedSet(inputs)
2020
self.retain_graph = retain_graph
2121
self.create_graph = create_graph
2222

src/torchjd/autojac/_transform/_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import OrderedDict
2-
from typing import Hashable, Iterable, Sequence, TypeAlias, TypeVar
1+
from typing import Hashable, Iterable, Sequence, TypeVar
32

43
import torch
54
from torch import Tensor
@@ -8,17 +7,12 @@
87

98
_KeyType = TypeVar("_KeyType", bound=Hashable)
109
_ValueType = TypeVar("_ValueType")
11-
_OrderedSet: TypeAlias = OrderedDict[_KeyType, None]
1210

1311
_A = TypeVar("_A", bound=TensorDict)
1412
_B = TypeVar("_B", bound=TensorDict)
1513
_C = TypeVar("_C", bound=TensorDict)
1614

1715

18-
def ordered_set(elements: Iterable[_KeyType]) -> _OrderedSet[_KeyType]:
19-
return OrderedDict.fromkeys(elements, None)
20-
21-
2216
def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
2317
result = {}
2418
for d in dicts:

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from torchjd.aggregation import Aggregator
88

9-
from ._utils import _OrderedSet, ordered_set
109
from .base import Transform
10+
from .ordered_set import OrderedSet
1111
from .tensor_dict import EmptyTensorDict, Gradients, GradientVectors, JacobianMatrices, Jacobians
1212

1313
_KeyType = TypeVar("_KeyType", bound=Hashable)
@@ -32,7 +32,7 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
3232

3333
class _AggregateMatrices(Transform[JacobianMatrices, GradientVectors]):
3434
def __init__(self, aggregator: Aggregator, key_order: Iterable[Tensor]):
35-
self.key_order = ordered_set(key_order)
35+
self.key_order = OrderedSet(key_order)
3636
self.aggregator = aggregator
3737

3838
def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
@@ -54,7 +54,7 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
5454

5555
@staticmethod
5656
def _select_ordered_subdict(
57-
dictionary: dict[_KeyType, _ValueType], ordered_keys: _OrderedSet[_KeyType]
57+
dictionary: dict[_KeyType, _ValueType], ordered_keys: OrderedSet[_KeyType]
5858
) -> OrderedDict[_KeyType, _ValueType]:
5959
"""
6060
Selects a subset of a dictionary corresponding to the keys given by ``ordered_keys``.

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import torch
44
from torch import Tensor
55

6-
from ._utils import ordered_set
76
from .base import Transform
7+
from .ordered_set import OrderedSet
88
from .tensor_dict import Gradients, Jacobians
99

1010

1111
class Diagonalize(Transform[Gradients, Jacobians]):
1212
def __init__(self, considered: Iterable[Tensor]):
13-
self.considered = ordered_set(considered)
13+
self.considered = OrderedSet(considered)
1414
self.indices: list[tuple[int, int]] = []
1515
begin = 0
1616
for tensor in self.considered:
src/torchjd/autojac/_transform/ordered_set.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from collections import OrderedDict
2+
from typing import Iterable
3+
4+
from torchjd.autojac._transform._utils import _KeyType
5+
6+
7+
class OrderedSet(OrderedDict[_KeyType, None]):
    """
    Ordered collection of distinct elements.

    The elements are stored as the keys of an ``OrderedDict`` whose values are all ``None``, so
    iterating over an ``OrderedSet`` yields its elements in insertion order, and an element that is
    provided several times is kept only once, at the position of its first occurrence.
    """

    def __init__(self, elements: Iterable[_KeyType] = ()):
        """
        :param elements: The elements to insert, in order. Defaults to no element, so that the
            inherited ``OrderedDict`` machinery that instantiates the class without arguments
            (``fromkeys``, pickling, ``copy.copy``) keeps working.
        """

        super().__init__((element, None) for element in elements)

    def difference_update(self, elements: Iterable[_KeyType]) -> None:
        """
        Removes all specified elements from the OrderedSet.

        :param elements: The elements to remove. Those that are not in the OrderedSet are ignored.
        """

        if elements is self:
            # Deleting keys while iterating over ourselves would raise a RuntimeError.
            self.clear()
            return

        for element in elements:
            self.pop(element, None)

    def add(self, element: _KeyType) -> None:
        """Adds the specified element to the OrderedSet."""

        # Re-adding an element that is already present leaves its position unchanged.
        self[element] = None

src/torchjd/autojac/_utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from torch import Tensor
55
from torch.autograd.graph import Node
66

7+
from ._transform.ordered_set import OrderedSet
8+
79

810
def _check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
911
if not (parallel_chunk_size is None or parallel_chunk_size > 0):
@@ -21,7 +23,7 @@ def _as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
2123
return output
2224

2325

24-
def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> set[Tensor]:
26+
def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:
2527
"""
2628
Gets the leaves of the autograd graph of all specified ``tensors``.
2729
@@ -39,15 +41,17 @@ def _get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) ->
3941
raise ValueError("All `excluded` tensors should have a `grad_fn`.")
4042

4143
accumulate_grads = _get_descendant_accumulate_grads(
42-
roots={tensor.grad_fn for tensor in tensors},
44+
roots=OrderedSet([tensor.grad_fn for tensor in tensors]),
4345
excluded_nodes={tensor.grad_fn for tensor in excluded},
4446
)
45-
leaves = {g.variable for g in accumulate_grads}
47+
leaves = OrderedSet([g.variable for g in accumulate_grads])
4648

4749
return leaves
4850

4951

50-
def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]) -> set[Node]:
52+
def _get_descendant_accumulate_grads(
53+
roots: OrderedSet[Node], excluded_nodes: set[Node]
54+
) -> OrderedSet[Node]:
5155
"""
5256
Gets the AccumulateGrad descendants of the specified nodes.
5357
@@ -56,8 +60,9 @@ def _get_descendant_accumulate_grads(roots: set[Node], excluded_nodes: set[Node]
5660
"""
5761

5862
excluded_nodes = set(excluded_nodes) # Re-instantiate set to avoid modifying input
59-
result = set()
60-
nodes_to_traverse = deque(roots - excluded_nodes)
63+
result = OrderedSet([])
64+
roots.difference_update(excluded_nodes)
65+
nodes_to_traverse = deque(roots)
6166

6267
# This implementation more or less follows what is advised in
6368
# https://discuss.pytorch.org/t/autograd-graph-traversal/213658 and what was suggested in

src/torchjd/autojac/backward.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchjd.aggregation import Aggregator
66

77
from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform
8+
from ._transform.ordered_set import OrderedSet
89
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
910

1011

@@ -76,7 +77,7 @@ def backward(
7677
if inputs is None:
7778
inputs = _get_leaf_tensors(tensors=tensors, excluded=set())
7879
else:
79-
inputs = set(inputs)
80+
inputs = OrderedSet(inputs)
8081

8182
backward_transform = _create_transform(
8283
tensors=tensors,
@@ -92,7 +93,7 @@ def backward(
9293
def _create_transform(
9394
tensors: list[Tensor],
9495
aggregator: Aggregator,
95-
inputs: set[Tensor],
96+
inputs: OrderedSet[Tensor],
9697
retain_graph: bool,
9798
parallel_chunk_size: int | None,
9899
) -> Transform[EmptyTensorDict, EmptyTensorDict]:

src/torchjd/autojac/mtl_backward.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Stack,
1717
Transform,
1818
)
19+
from ._transform.ordered_set import OrderedSet
1920
from ._utils import _as_tensor_list, _check_optional_positive_chunk_size, _get_leaf_tensors
2021

2122

@@ -84,8 +85,12 @@ def mtl_backward(
8485

8586
if shared_params is None:
8687
shared_params = _get_leaf_tensors(tensors=features, excluded=[])
88+
else:
89+
shared_params = OrderedSet(shared_params)
8790
if tasks_params is None:
8891
tasks_params = [_get_leaf_tensors(tensors=[loss], excluded=features) for loss in losses]
92+
else:
93+
tasks_params = [OrderedSet(task_params) for task_params in tasks_params]
8994

9095
if len(features) == 0:
9196
raise ValueError("`features` cannot be empty.")
@@ -115,8 +120,8 @@ def _create_transform(
115120
losses: Sequence[Tensor],
116121
features: list[Tensor],
117122
aggregator: Aggregator,
118-
tasks_params: list[Iterable[Tensor]],
119-
shared_params: set[Tensor],
123+
tasks_params: list[OrderedSet[Tensor]],
124+
shared_params: OrderedSet[Tensor],
120125
retain_graph: bool,
121126
parallel_chunk_size: int | None,
122127
) -> Transform[EmptyTensorDict, EmptyTensorDict]:
@@ -126,9 +131,6 @@ def _create_transform(
126131
task-specific parameters).
127132
"""
128133

129-
shared_params = list(shared_params)
130-
tasks_params = [list(task_params) for task_params in tasks_params]
131-
132134
# Task-specific transforms. Each of them computes and accumulates the gradient of the task's
133135
# loss w.r.t. the task's specific parameters, and computes and backpropagates the gradient of
134136
# the losses w.r.t. the shared representations.
@@ -160,12 +162,13 @@ def _create_transform(
160162

161163
def _create_task_transform(
162164
features: list[Tensor],
163-
task_params: list[Tensor],
165+
task_params: OrderedSet[Tensor],
164166
loss: Tensor,
165167
retain_graph: bool,
166168
) -> Transform[EmptyTensorDict, Gradients]:
167169
# Tensors with respect to which we compute the gradients.
168-
to_differentiate = task_params + features
170+
to_differentiate = OrderedSet(task_params) # Re-instantiate set to avoid modifying input
171+
to_differentiate.update(OrderedSet(features))
169172

170173
# Transform that initializes the gradient output to 1.
171174
init = Init([loss])

tests/unit/autojac/test_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_simple_get_leaf_tensors():
1515
y2 = (a1**2).sum() + a2.norm()
1616

1717
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set())
18-
assert leaves == {a1, a2}
18+
assert set(leaves) == {a1, a2}
1919

2020

2121
def test_get_leaf_tensors_excluded_1():
@@ -36,7 +36,7 @@ def test_get_leaf_tensors_excluded_1():
3636
y2 = b1
3737

3838
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
39-
assert leaves == {a1}
39+
assert set(leaves) == {a1}
4040

4141

4242
def test_get_leaf_tensors_excluded_2():
@@ -57,7 +57,7 @@ def test_get_leaf_tensors_excluded_2():
5757
y2 = b1
5858

5959
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={b1, b2})
60-
assert leaves == {a1, a2}
60+
assert set(leaves) == {a1, a2}
6161

6262

6363
def test_get_leaf_tensors_leaf_not_requiring_grad():
@@ -72,7 +72,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad():
7272
y2 = (a1**2).sum() + a2.norm()
7373

7474
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded=set())
75-
assert leaves == {a1}
75+
assert set(leaves) == {a1}
7676

7777

7878
def test_get_leaf_tensors_model():
@@ -91,7 +91,7 @@ def test_get_leaf_tensors_model():
9191
losses = loss_fn(y_hat, y)
9292

9393
leaves = _get_leaf_tensors(tensors=[losses], excluded=set())
94-
assert leaves == set(model.parameters())
94+
assert set(leaves) == set(model.parameters())
9595

9696

9797
def test_get_leaf_tensors_model_excluded_2():
@@ -112,7 +112,7 @@ def test_get_leaf_tensors_model_excluded_2():
112112
losses = loss_fn(z_hat, z)
113113

114114
leaves = _get_leaf_tensors(tensors=[losses], excluded={y})
115-
assert leaves == set(model2.parameters())
115+
assert set(leaves) == set(model2.parameters())
116116

117117

118118
def test_get_leaf_tensors_single_root():
@@ -122,14 +122,14 @@ def test_get_leaf_tensors_single_root():
122122
y = p * 2
123123

124124
leaves = _get_leaf_tensors(tensors=[y], excluded=set())
125-
assert leaves == {p}
125+
assert set(leaves) == {p}
126126

127127

128128
def test_get_leaf_tensors_empty_roots():
129129
"""Tests that _get_leaf_tensors returns no leaves when roots is the empty set."""
130130

131131
leaves = _get_leaf_tensors(tensors=[], excluded=set())
132-
assert leaves == set()
132+
assert set(leaves) == set({})
133133

134134

135135
def test_get_leaf_tensors_excluded_root():
@@ -142,7 +142,7 @@ def test_get_leaf_tensors_excluded_root():
142142
y2 = (a1**2).sum()
143143

144144
leaves = _get_leaf_tensors(tensors=[y1, y2], excluded={y1})
145-
assert leaves == {a1}
145+
assert set(leaves) == {a1}
146146

147147

148148
@mark.parametrize("depth", [100, 1000, 10000])
@@ -155,7 +155,7 @@ def test_get_leaf_tensors_deep(depth: int):
155155
sum_ = sum_ + one
156156

157157
leaves = _get_leaf_tensors(tensors=[sum_], excluded=set())
158-
assert leaves == {one}
158+
assert set(leaves) == {one}
159159

160160

161161
def test_get_leaf_tensors_leaf():

0 commit comments

Comments
 (0)