Skip to content

Commit fe1686b

Browse files
authored
refactor(autojac): Improve _transform package structure (#298)
* Move `_A`, `_B` and `_C` to `tensor_dict.py`
* Move `_union` to `Conjunction.__call__`
* Move `dicts_union` to `_stack`
* Move `_KeyType` to `ordered_set.py`
* Rename `_utils.py` to `_materialize.py`
* Rename `_materialize` to `materialize`
* Rename `_Differentiate` to `Differentiate`
1 parent 466f943 commit fe1686b

File tree

11 files changed

+61
-71
lines changed

11 files changed

+61
-71
lines changed

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33

44
from torch import Tensor
55

6-
from .base import _A, RequirementError, Transform
6+
from .base import RequirementError, Transform
77
from .ordered_set import OrderedSet
8+
from .tensor_dict import _A
89

910

10-
class _Differentiate(Transform[_A, _A], ABC):
11+
class Differentiate(Transform[_A, _A], ABC):
1112
def __init__(
1213
self,
1314
outputs: Iterable[Tensor],
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Sequence
2+
3+
import torch
4+
from torch import Tensor
5+
6+
7+
def materialize(
8+
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
9+
) -> tuple[Tensor, ...]:
10+
"""
11+
Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
12+
shape as the corresponding input. Returns the obtained sequence as a tuple.
13+
14+
Note that the name "materialize" comes from the flag `materialize_grads` from
15+
`torch.autograd.grad`, which will be available in future torch releases.
16+
"""
17+
18+
tensors = []
19+
for optional_tensor, input in zip(optional_tensors, inputs):
20+
if optional_tensor is None:
21+
tensors.append(torch.zeros_like(input))
22+
else:
23+
tensors.append(optional_tensor)
24+
return tuple(tensors)

src/torchjd/autojac/_transform/_utils.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

src/torchjd/autojac/_transform/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torch import Tensor
77

8-
from ._utils import _A, _B, _C, _union
8+
from .tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor
99

1010

1111
class RequirementError(ValueError):
@@ -99,8 +99,13 @@ def __str__(self) -> str:
9999
return "(" + " | ".join(strings) + ")"
100100

101101
def __call__(self, tensor_dict: _A) -> _B:
102-
output = _union([transform(tensor_dict) for transform in self.transforms])
103-
return output
102+
tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
103+
output_type: type[_A] = EmptyTensorDict
104+
output: _A = EmptyTensorDict()
105+
for tensor_dict in tensor_dicts:
106+
output_type = _least_common_ancestor(output_type, type(tensor_dict))
107+
output |= tensor_dict
108+
return output_type(output)
104109

105110
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
106111
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]

src/torchjd/autojac/_transform/grad.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import torch
44
from torch import Tensor
55

6-
from ._differentiate import _Differentiate
7-
from ._utils import _materialize
6+
from ._differentiate import Differentiate
7+
from ._materialize import materialize
88
from .tensor_dict import Gradients
99

1010

11-
class Grad(_Differentiate[Gradients]):
11+
class Grad(Differentiate[Gradients]):
1212
def __init__(
1313
self,
1414
outputs: Iterable[Tensor],
@@ -47,5 +47,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
4747
create_graph=self.create_graph,
4848
allow_unused=True,
4949
)
50-
grads = _materialize(optional_grads, inputs)
50+
grads = materialize(optional_grads, inputs)
5151
return grads

src/torchjd/autojac/_transform/jac.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import torch
77
from torch import Size, Tensor
88

9-
from ._differentiate import _Differentiate
10-
from ._utils import _materialize
9+
from ._differentiate import Differentiate
10+
from ._materialize import materialize
1111
from .tensor_dict import Jacobians
1212

1313

14-
class Jac(_Differentiate[Jacobians]):
14+
class Jac(Differentiate[Jacobians]):
1515
def __init__(
1616
self,
1717
outputs: Iterable[Tensor],
@@ -60,7 +60,7 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
6060
create_graph=self.create_graph,
6161
allow_unused=True,
6262
)
63-
grads = _materialize(optional_grads, inputs=inputs)
63+
grads = materialize(optional_grads, inputs=inputs)
6464
return torch.concatenate([grad.reshape([-1]) for grad in grads])
6565

6666
# By the Jacobians constraint, this value should be the same for all jac_outputs.

src/torchjd/autojac/_transform/ordered_set.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections import OrderedDict
2-
from typing import Iterable
2+
from typing import Hashable, Iterable, TypeVar
33

4-
from torchjd.autojac._transform._utils import _KeyType
4+
_KeyType = TypeVar("_KeyType", bound=Hashable)
55

66

77
class OrderedSet(OrderedDict[_KeyType, None]):

src/torchjd/autojac/_transform/select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from torch import Tensor
44

5-
from ._utils import _A
65
from .base import RequirementError, Transform
6+
from .tensor_dict import _A
77

88

99
class Select(Transform[_A, _A]):

src/torchjd/autojac/_transform/stack.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import torch
44
from torch import Tensor
55

6-
from ._utils import _A, _materialize, dicts_union
6+
from ._materialize import materialize
77
from .base import Transform
8-
from .tensor_dict import Gradients, Jacobians
8+
from .tensor_dict import _A, Gradients, Jacobians
99

1010

1111
class Stack(Transform[_A, Jacobians]):
@@ -32,7 +32,10 @@ def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
3232
# It is important to first remove duplicate keys before computing their associated
3333
# stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute
3434
# unique_keys, and only then, we compute the stacked tensors.
35-
unique_keys = dicts_union(gradient_dicts).keys()
35+
union = {}
36+
for d in gradient_dicts:
37+
union |= d
38+
unique_keys = union.keys()
3639
result = Jacobians({key: _stack_one_key(gradient_dicts, key) for key in unique_keys})
3740
return result
3841

@@ -43,6 +46,6 @@ def _stack_one_key(gradient_dicts: list[Gradients], input: Tensor) -> Tensor:
4346
"""
4447

4548
optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts]
46-
gradients = _materialize(optional_gradients, [input] * len(optional_gradients))
49+
gradients = materialize(optional_gradients, [input] * len(optional_gradients))
4750
jacobian = torch.stack(gradients, dim=0)
4851
return jacobian

src/torchjd/autojac/_transform/tensor_dict.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TypeVar
2+
13
from torch import Tensor
24

35

@@ -177,3 +179,8 @@ def _check_corresponding_numel(key: Tensor, value: Tensor, dim: int) -> None:
177179
"the number of elements in the corresponding key tensor. Found pair with key shape "
178180
f"{key.shape} and value shape {value.shape}."
179181
)
182+
183+
184+
_A = TypeVar("_A", bound=TensorDict)
185+
_B = TypeVar("_B", bound=TensorDict)
186+
_C = TypeVar("_C", bound=TensorDict)

0 commit comments

Comments
 (0)