Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions src/torchjd/autojac/_transform/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor

from .tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
from .tensor_dict import TensorDict

_KeyType = TypeVar("_KeyType", bound=Hashable)
_ValueType = TypeVar("_ValueType")
Expand Down Expand Up @@ -38,12 +38,3 @@ def _materialize(
else:
tensors.append(optional_tensor)
return tuple(tensors)


def _union(tensor_dicts: Iterable[_A]) -> _A:
output_type: type[_A] = EmptyTensorDict
output: _A = EmptyTensorDict()
for tensor_dict in tensor_dicts:
output_type = _least_common_ancestor(output_type, type(tensor_dict))
output |= tensor_dict
return output_type(output)
12 changes: 9 additions & 3 deletions src/torchjd/autojac/_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from torch import Tensor

from ._utils import _A, _B, _C, _union
from ._utils import _A, _B, _C
from .tensor_dict import EmptyTensorDict, _least_common_ancestor


class RequirementError(ValueError):
Expand Down Expand Up @@ -99,8 +100,13 @@ def __str__(self) -> str:
return "(" + " | ".join(strings) + ")"

def __call__(self, tensor_dict: _A) -> _B:
output = _union([transform(tensor_dict) for transform in self.transforms])
return output
transformed = [transform(tensor_dict) for transform in self.transforms]
output_type: type[_A] = EmptyTensorDict
output: _A = EmptyTensorDict()
for tensor_dict in transformed:
output_type = _least_common_ancestor(output_type, type(tensor_dict))
output |= tensor_dict
return output_type(output)

def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
Expand Down