Skip to content

Commit 2b26100

Browse files
refactor(autojac): Improve typing (#312)
* Make OrderedSet inherit from collections.abc.Set
* Make Init take Set[Tensor] keys
* Make Select take Set[Tensor] keys
* Make Differentiate, Jac and Grad take OrderedSet[Tensor] outputs and inputs
* Move the cast to list of outputs and inputs from Jac._differentiate and Grad._differentiate to Differentiate.__init__
* Change as_tensor_list to as_checked_ordered_set and make it check that the provided tensors are unique
* Stop allowing repeated tensors to differentiate in backward and mtl_backward
* Remove changelog entry saying that we allow repeated outputs
* Change the implementation of backward and mtl_backward to reduce the number of casts and use only OrderedSets
* Add changelog entry saying that we refactored internal types and that this should improve performance

---------

Co-authored-by: Pierre Quinton <pierre.quinton@gmail.com>
1 parent 739757d commit 2b26100

File tree

17 files changed

+205
-190
lines changed

17 files changed

+205
-190
lines changed

CHANGELOG.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,13 @@ changes that do not affect the user.
1919
- Refactored internal verifications in the autojac engine so that they do not run at runtime
2020
anymore. This should minimally improve the performance and reduce the memory usage of `backward`
2121
and `mtl_backward`.
22+
- Refactored internal typing in the autojac engine so that fewer casts are made and so that code is
23+
simplified. This should slightly improve the performance of `backward` and `mtl_backward`.
2224
- Improved the implementation of `ConFIG` to be simpler and safer when normalizing vectors. It
2325
should slightly improve the performance of `ConFIG` and minimally affect its behavior.
2426

2527
### Fixed
2628

27-
- Fixed the behavior of `backward` and `mtl_backward` when some tensors are repeated (i.e. when they
28-
appear several times in a list of tensors provided as argument). Instead of raising an exception
29-
in these cases, we are now aligned with the behavior of `torch.autograd.backward`. Repeated
30-
tensors that we differentiate lead to repeated rows in the Jacobian, prior to aggregation, and
31-
repeated tensors with respect to which we differentiate count only once.
3229
- Fixed an issue with `backward` and `mtl_backward` that could make the ordering of the columns of
3330
the Jacobians non-deterministic, and that could thus lead to slightly non-deterministic results
3431
with some aggregators.

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Iterable, Sequence
2+
from typing import Sequence
33

44
from torch import Tensor
55

@@ -11,13 +11,17 @@
1111
class Differentiate(Transform[_A, _A], ABC):
1212
def __init__(
1313
self,
14-
outputs: Iterable[Tensor],
15-
inputs: Iterable[Tensor],
14+
outputs: OrderedSet[Tensor],
15+
inputs: OrderedSet[Tensor],
1616
retain_graph: bool,
1717
create_graph: bool,
1818
):
19+
# The order of outputs and inputs only matters because we have no guarantee that
20+
# torch.autograd.grad is *exactly* equivariant to input permutations and invariant to
21+
# output (with their corresponding grad_output) permutations.
22+
1923
self.outputs = list(outputs)
20-
self.inputs = OrderedSet(inputs)
24+
self.inputs = list(inputs)
2125
self.retain_graph = retain_graph
2226
self.create_graph = create_graph
2327

src/torchjd/autojac/_transform/grad.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from typing import Iterable, Sequence
1+
from typing import Sequence
22

33
import torch
44
from torch import Tensor
55

66
from ._differentiate import Differentiate
77
from ._materialize import materialize
8+
from .ordered_set import OrderedSet
89
from .tensor_dict import Gradients
910

1011

1112
class Grad(Differentiate[Gradients]):
1213
def __init__(
1314
self,
14-
outputs: Iterable[Tensor],
15-
inputs: Iterable[Tensor],
15+
outputs: OrderedSet[Tensor],
16+
inputs: OrderedSet[Tensor],
1617
retain_graph: bool = False,
1718
create_graph: bool = False,
1819
):
@@ -30,22 +31,19 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
3031
the same shape as the corresponding output.
3132
"""
3233

33-
outputs = list(self.outputs)
34-
inputs = list(self.inputs)
35-
36-
if len(inputs) == 0:
34+
if len(self.inputs) == 0:
3735
return tuple()
3836

39-
if len(outputs) == 0:
40-
return tuple([torch.zeros_like(input) for input in inputs])
37+
if len(self.outputs) == 0:
38+
return tuple([torch.zeros_like(input) for input in self.inputs])
4139

4240
optional_grads = torch.autograd.grad(
43-
outputs,
44-
inputs,
41+
self.outputs,
42+
self.inputs,
4543
grad_outputs=grad_outputs,
4644
retain_graph=self.retain_graph,
4745
create_graph=self.create_graph,
4846
allow_unused=True,
4947
)
50-
grads = materialize(optional_grads, inputs)
48+
grads = materialize(optional_grads, self.inputs)
5149
return grads

src/torchjd/autojac/_transform/init.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable
1+
from collections.abc import Set
22

33
import torch
44
from torch import Tensor
@@ -8,8 +8,8 @@
88

99

1010
class Init(Transform[EmptyTensorDict, Gradients]):
11-
def __init__(self, values: Iterable[Tensor]):
12-
self.values = set(values)
11+
def __init__(self, values: Set[Tensor]):
12+
self.values = values
1313

1414
def __call__(self, input: EmptyTensorDict) -> Gradients:
1515
r"""
@@ -26,4 +26,4 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
2626
raise RequirementError(
2727
f"The input_keys should be the empty set. Found input_keys {input_keys}."
2828
)
29-
return self.values
29+
return set(self.values)

src/torchjd/autojac/_transform/jac.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import math
22
from functools import partial
33
from itertools import accumulate
4-
from typing import Callable, Iterable, Sequence
4+
from typing import Callable, Sequence
55

66
import torch
77
from torch import Size, Tensor
88

99
from ._differentiate import Differentiate
1010
from ._materialize import materialize
11+
from .ordered_set import OrderedSet
1112
from .tensor_dict import Jacobians
1213

1314

1415
class Jac(Differentiate[Jacobians]):
1516
def __init__(
1617
self,
17-
outputs: Iterable[Tensor],
18-
inputs: Iterable[Tensor],
18+
outputs: OrderedSet[Tensor],
19+
inputs: OrderedSet[Tensor],
1920
chunk_size: int | None,
2021
retain_graph: bool = False,
2122
create_graph: bool = False,
@@ -37,30 +38,27 @@ def _differentiate(self, jac_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
3738
jac_outputs.
3839
"""
3940

40-
outputs = list(self.outputs)
41-
inputs = list(self.inputs)
42-
43-
if len(inputs) == 0:
41+
if len(self.inputs) == 0:
4442
return tuple()
4543

46-
if len(outputs) == 0:
44+
if len(self.outputs) == 0:
4745
return tuple(
4846
[
4947
torch.empty((0,) + input.shape, device=input.device, dtype=input.dtype)
50-
for input in inputs
48+
for input in self.inputs
5149
]
5250
)
5351

5452
def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
5553
optional_grads = torch.autograd.grad(
56-
outputs,
57-
inputs,
54+
self.outputs,
55+
self.inputs,
5856
grad_outputs=grad_outputs,
5957
retain_graph=retain_graph,
6058
create_graph=self.create_graph,
6159
allow_unused=True,
6260
)
63-
grads = materialize(optional_grads, inputs=inputs)
61+
grads = materialize(optional_grads, inputs=self.inputs)
6462
return torch.concatenate([grad.reshape([-1]) for grad in grads])
6563

6664
# By the Jacobians constraint, this value should be the same for all jac_outputs.
@@ -86,10 +84,10 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
8684
jac_matrix_chunks.append(_get_jac_matrix_chunk(jac_outputs_chunk, get_vjp_last))
8785

8886
jac_matrix = torch.vstack(jac_matrix_chunks)
89-
lengths = [input.numel() for input in inputs]
87+
lengths = [input.numel() for input in self.inputs]
9088
jac_matrices = _extract_sub_matrices(jac_matrix, lengths)
9189

92-
shapes = [input.shape for input in inputs]
90+
shapes = [input.shape for input in self.inputs]
9391
jacs = _reshape_matrices(jac_matrices, shapes)
9492

9593
return tuple(jacs)

src/torchjd/autojac/_transform/ordered_set.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from collections import OrderedDict
2+
from collections.abc import Set
23
from typing import Hashable, Iterable, TypeVar
34

45
_KeyType = TypeVar("_KeyType", bound=Hashable)
56

67

7-
class OrderedSet(OrderedDict[_KeyType, None]):
8+
class OrderedSet(OrderedDict[_KeyType, None], Set[_KeyType]):
89
"""Ordered collection of distinct elements."""
910

1011
def __init__(self, elements: Iterable[_KeyType]):
src/torchjd/autojac/_transform/select.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable
1+
from collections.abc import Set
22

33
from torch import Tensor
44

@@ -7,17 +7,18 @@
77

88

99
class Select(Transform[_A, _A]):
10-
def __init__(self, keys: Iterable[Tensor]):
11-
self.keys = set(keys)
10+
def __init__(self, keys: Set[Tensor]):
11+
self.keys = keys
1212

1313
def __call__(self, tensor_dict: _A) -> _A:
1414
output = {key: tensor_dict[key] for key in self.keys}
1515
return type(tensor_dict)(output)
1616

1717
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
18-
if not self.keys.issubset(input_keys):
18+
keys = set(self.keys)
19+
if not keys.issubset(input_keys):
1920
raise RequirementError(
2021
f"The input_keys should be a super set of the keys to select. Found input_keys "
21-
f"{input_keys} and keys to select {self.keys}."
22+
f"{input_keys} and keys to select {keys}."
2223
)
23-
return self.keys
24+
return keys

src/torchjd/autojac/_utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@ def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:
1515
)
1616

1717

18-
def as_tensor_list(tensors: Sequence[Tensor] | Tensor) -> list[Tensor]:
18+
def as_checked_ordered_set(
19+
tensors: Sequence[Tensor] | Tensor, variable_name: str
20+
) -> OrderedSet[Tensor]:
1921
if isinstance(tensors, Tensor):
20-
output = [tensors]
21-
else:
22-
output = list(tensors)
23-
return output
22+
tensors = [tensors]
23+
24+
original_length = len(tensors)
25+
output = OrderedSet(tensors)
26+
27+
if len(output) != original_length:
28+
raise ValueError(f"`{variable_name}` should contain unique elements.")
29+
30+
return OrderedSet(tensors)
2431

2532

2633
def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]:

src/torchjd/autojac/backward.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform
88
from ._transform.ordered_set import OrderedSet
9-
from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors
9+
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
1010

1111

1212
def backward(
@@ -69,7 +69,7 @@ def backward(
6969
"""
7070
check_optional_positive_chunk_size(parallel_chunk_size)
7171

72-
tensors = as_tensor_list(tensors)
72+
tensors = as_checked_ordered_set(tensors, "tensors")
7373

7474
if len(tensors) == 0:
7575
raise ValueError("`tensors` cannot be empty")
@@ -91,7 +91,7 @@ def backward(
9191

9292

9393
def _create_transform(
94-
tensors: list[Tensor],
94+
tensors: OrderedSet[Tensor],
9595
aggregator: Aggregator,
9696
inputs: OrderedSet[Tensor],
9797
retain_graph: bool,
@@ -103,7 +103,7 @@ def _create_transform(
103103
init = Init(tensors)
104104

105105
# Transform that turns the gradients into Jacobians.
106-
diag = Diagonalize(OrderedSet(tensors))
106+
diag = Diagonalize(tensors)
107107

108108
# Transform that computes the required Jacobians.
109109
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)

src/torchjd/autojac/mtl_backward.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Transform,
1818
)
1919
from ._transform.ordered_set import OrderedSet
20-
from ._utils import as_tensor_list, check_optional_positive_chunk_size, get_leaf_tensors
20+
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
2121

2222

2323
def mtl_backward(
@@ -81,7 +81,8 @@ def mtl_backward(
8181

8282
check_optional_positive_chunk_size(parallel_chunk_size)
8383

84-
features = as_tensor_list(features)
84+
losses = as_checked_ordered_set(losses, "losses")
85+
features = as_checked_ordered_set(features, "features")
8586

8687
if shared_params is None:
8788
shared_params = get_leaf_tensors(tensors=features, excluded=[])
@@ -117,8 +118,8 @@ def mtl_backward(
117118

118119

119120
def _create_transform(
120-
losses: Sequence[Tensor],
121-
features: list[Tensor],
121+
losses: OrderedSet[Tensor],
122+
features: OrderedSet[Tensor],
122123
aggregator: Aggregator,
123124
tasks_params: list[OrderedSet[Tensor]],
124125
shared_params: OrderedSet[Tensor],
@@ -138,7 +139,7 @@ def _create_transform(
138139
_create_task_transform(
139140
features,
140141
task_params,
141-
loss,
142+
OrderedSet([loss]),
142143
retain_graph,
143144
)
144145
for task_params, loss in zip(tasks_params, losses)
@@ -161,21 +162,21 @@ def _create_transform(
161162

162163

163164
def _create_task_transform(
164-
features: list[Tensor],
165+
features: OrderedSet[Tensor],
165166
task_params: OrderedSet[Tensor],
166-
loss: Tensor,
167+
loss: OrderedSet[Tensor], # contains a single scalar loss
167168
retain_graph: bool,
168169
) -> Transform[EmptyTensorDict, Gradients]:
169170
# Tensors with respect to which we compute the gradients.
170171
to_differentiate = OrderedSet(task_params) # Re-instantiate set to avoid modifying input
171-
to_differentiate.update(OrderedSet(features))
172+
to_differentiate.update(features)
172173

173174
# Transform that initializes the gradient output to 1.
174-
init = Init([loss])
175+
init = Init(loss)
175176

176177
# Transform that computes the gradients of the loss w.r.t. the task-specific parameters and
177178
# the features.
178-
grad = Grad([loss], to_differentiate, retain_graph)
179+
grad = Grad(loss, to_differentiate, retain_graph)
179180

180181
# Transform that accumulates the gradients w.r.t. the task-specific parameters into their
181182
# .grad fields.

0 commit comments

Comments
 (0)