Skip to content

Commit 2884b27

Browse files
committed
Develop _union into Conjunction.__call__
1 parent 466f943 commit 2884b27

2 files changed

Lines changed: 10 additions & 13 deletions

File tree

src/torchjd/autojac/_transform/_utils.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch import Tensor
55

6-
from .tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
6+
from .tensor_dict import TensorDict
77

88
_KeyType = TypeVar("_KeyType", bound=Hashable)
99
_ValueType = TypeVar("_ValueType")
@@ -38,12 +38,3 @@ def _materialize(
3838
else:
3939
tensors.append(optional_tensor)
4040
return tuple(tensors)
41-
42-
43-
def _union(tensor_dicts: Iterable[_A]) -> _A:
44-
output_type: type[_A] = EmptyTensorDict
45-
output: _A = EmptyTensorDict()
46-
for tensor_dict in tensor_dicts:
47-
output_type = _least_common_ancestor(output_type, type(tensor_dict))
48-
output |= tensor_dict
49-
return output_type(output)

src/torchjd/autojac/_transform/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from torch import Tensor
77

8-
from ._utils import _A, _B, _C, _union
8+
from ._utils import _A, _B, _C
9+
from .tensor_dict import EmptyTensorDict, _least_common_ancestor
910

1011

1112
class RequirementError(ValueError):
@@ -99,8 +100,13 @@ def __str__(self) -> str:
99100
return "(" + " | ".join(strings) + ")"
100101

101102
def __call__(self, tensor_dict: _A) -> _B:
102-
output = _union([transform(tensor_dict) for transform in self.transforms])
103-
return output
103+
transformed = [transform(tensor_dict) for transform in self.transforms]
104+
output_type: type[_A] = EmptyTensorDict
105+
output: _A = EmptyTensorDict()
106+
for tensor_dict in transformed:
107+
output_type = _least_common_ancestor(output_type, type(tensor_dict))
108+
output |= tensor_dict
109+
return output_type(output)
104110

105111
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
106112
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]

0 commit comments

Comments
 (0)