Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/torchjd/autojac/_transform/_differentiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from torch import Tensor

from .base import _A, RequirementError, Transform
from .base import RequirementError, Transform
from .ordered_set import OrderedSet
from .tensor_dict import _A


class _Differentiate(Transform[_A, _A], ABC):
class Differentiate(Transform[_A, _A], ABC):
def __init__(
self,
outputs: Iterable[Tensor],
Expand Down
24 changes: 24 additions & 0 deletions src/torchjd/autojac/_transform/_materialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Sequence

import torch
from torch import Tensor


def materialize(
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
) -> tuple[Tensor, ...]:
"""
Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
shape as the corresponding input. Returns the obtained sequence as a tuple.

Note that the name "materialize" comes from the flag `materialize_grads` from
`torch.autograd.grad`, which will be available in future torch releases.
"""

tensors = []
for optional_tensor, input in zip(optional_tensors, inputs):
if optional_tensor is None:
tensors.append(torch.zeros_like(input))
else:
tensors.append(optional_tensor)
return tuple(tensors)
49 changes: 0 additions & 49 deletions src/torchjd/autojac/_transform/_utils.py

This file was deleted.

11 changes: 8 additions & 3 deletions src/torchjd/autojac/_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch import Tensor

from ._utils import _A, _B, _C, _union
from .tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor


class RequirementError(ValueError):
Expand Down Expand Up @@ -99,8 +99,13 @@ def __str__(self) -> str:
return "(" + " | ".join(strings) + ")"

def __call__(self, tensor_dict: _A) -> _B:
output = _union([transform(tensor_dict) for transform in self.transforms])
return output
tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
output_type: type[_A] = EmptyTensorDict
output: _A = EmptyTensorDict()
for tensor_dict in tensor_dicts:
output_type = _least_common_ancestor(output_type, type(tensor_dict))
output |= tensor_dict
return output_type(output)

def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
Expand Down
8 changes: 4 additions & 4 deletions src/torchjd/autojac/_transform/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import torch
from torch import Tensor

from ._differentiate import _Differentiate
from ._utils import _materialize
from ._differentiate import Differentiate
from ._materialize import materialize
from .tensor_dict import Gradients


class Grad(_Differentiate[Gradients]):
class Grad(Differentiate[Gradients]):
def __init__(
self,
outputs: Iterable[Tensor],
Expand Down Expand Up @@ -47,5 +47,5 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
create_graph=self.create_graph,
allow_unused=True,
)
grads = _materialize(optional_grads, inputs)
grads = materialize(optional_grads, inputs)
return grads
8 changes: 4 additions & 4 deletions src/torchjd/autojac/_transform/jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import torch
from torch import Size, Tensor

from ._differentiate import _Differentiate
from ._utils import _materialize
from ._differentiate import Differentiate
from ._materialize import materialize
from .tensor_dict import Jacobians


class Jac(_Differentiate[Jacobians]):
class Jac(Differentiate[Jacobians]):
def __init__(
self,
outputs: Iterable[Tensor],
Expand Down Expand Up @@ -60,7 +60,7 @@ def _get_vjp(grad_outputs: Sequence[Tensor], retain_graph: bool) -> Tensor:
create_graph=self.create_graph,
allow_unused=True,
)
grads = _materialize(optional_grads, inputs=inputs)
grads = materialize(optional_grads, inputs=inputs)
return torch.concatenate([grad.reshape([-1]) for grad in grads])

# By the Jacobians constraint, this value should be the same for all jac_outputs.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/autojac/_transform/ordered_set.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import OrderedDict
from typing import Iterable
from typing import Hashable, Iterable, TypeVar

from torchjd.autojac._transform._utils import _KeyType
_KeyType = TypeVar("_KeyType", bound=Hashable)


class OrderedSet(OrderedDict[_KeyType, None]):
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_transform/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from torch import Tensor

from ._utils import _A
from .base import RequirementError, Transform
from .tensor_dict import _A


class Select(Transform[_A, _A]):
Expand Down
11 changes: 7 additions & 4 deletions src/torchjd/autojac/_transform/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import torch
from torch import Tensor

from ._utils import _A, _materialize, dicts_union
from ._materialize import materialize
from .base import Transform
from .tensor_dict import Gradients, Jacobians
from .tensor_dict import _A, Gradients, Jacobians


class Stack(Transform[_A, Jacobians]):
Expand All @@ -32,7 +32,10 @@ def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
# It is important to first remove duplicate keys before computing their associated
# stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute
# unique_keys, and only then, we compute the stacked tensors.
unique_keys = dicts_union(gradient_dicts).keys()
union = {}
for d in gradient_dicts:
union |= d
unique_keys = union.keys()
result = Jacobians({key: _stack_one_key(gradient_dicts, key) for key in unique_keys})
return result

Expand All @@ -43,6 +46,6 @@ def _stack_one_key(gradient_dicts: list[Gradients], input: Tensor) -> Tensor:
"""

optional_gradients = [gradients.get(input, None) for gradients in gradient_dicts]
gradients = _materialize(optional_gradients, [input] * len(optional_gradients))
gradients = materialize(optional_gradients, [input] * len(optional_gradients))
jacobian = torch.stack(gradients, dim=0)
return jacobian
7 changes: 7 additions & 0 deletions src/torchjd/autojac/_transform/tensor_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TypeVar

from torch import Tensor


Expand Down Expand Up @@ -177,3 +179,8 @@ def _check_corresponding_numel(key: Tensor, value: Tensor, dim: int) -> None:
"the number of elements in the corresponding key tensor. Found pair with key shape "
f"{key.shape} and value shape {value.shape}."
)


_A = TypeVar("_A", bound=TensorDict)
_B = TypeVar("_B", bound=TensorDict)
_C = TypeVar("_C", bound=TensorDict)
3 changes: 1 addition & 2 deletions tests/unit/autojac/_transform/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from pytest import raises
from torch import Tensor

from torchjd.autojac._transform._utils import _B, _C
from torchjd.autojac._transform.base import Conjunction, RequirementError, Transform
from torchjd.autojac._transform.tensor_dict import TensorDict
from torchjd.autojac._transform.tensor_dict import _B, _C, TensorDict


class FakeTransform(Transform[_B, _C]):
Expand Down