File tree Expand file tree Collapse file tree
src/torchjd/autojac/_transform Expand file tree Collapse file tree Original file line number Diff line number Diff line change 33import torch
44from torch import Tensor
55
6- from .tensor_dict import EmptyTensorDict , TensorDict , _least_common_ancestor
6+ from .tensor_dict import TensorDict
77
88_KeyType = TypeVar ("_KeyType" , bound = Hashable )
99_ValueType = TypeVar ("_ValueType" )
@@ -38,12 +38,3 @@ def _materialize(
3838 else :
3939 tensors .append (optional_tensor )
4040 return tuple (tensors )
41-
42-
43- def _union (tensor_dicts : Iterable [_A ]) -> _A :
44- output_type : type [_A ] = EmptyTensorDict
45- output : _A = EmptyTensorDict ()
46- for tensor_dict in tensor_dicts :
47- output_type = _least_common_ancestor (output_type , type (tensor_dict ))
48- output |= tensor_dict
49- return output_type (output )
Original file line number Diff line number Diff line change 55
66from torch import Tensor
77
8- from ._utils import _A , _B , _C , _union
8+ from ._utils import _A , _B , _C
9+ from .tensor_dict import EmptyTensorDict , _least_common_ancestor
910
1011
1112class RequirementError (ValueError ):
@@ -99,8 +100,13 @@ def __str__(self) -> str:
99100 return "(" + " | " .join (strings ) + ")"
100101
101102 def __call__ (self , tensor_dict : _A ) -> _B :
102- output = _union ([transform (tensor_dict ) for transform in self .transforms ])
103- return output
103+ transformed = [transform (tensor_dict ) for transform in self .transforms ]
104+ output_type : type [_A ] = EmptyTensorDict
105+ output : _A = EmptyTensorDict ()
106+ for tensor_dict in transformed :
107+ output_type = _least_common_ancestor (output_type , type (tensor_dict ))
108+ output |= tensor_dict
109+ return output_type (output )
104110
105111 def check_keys (self , input_keys : set [Tensor ]) -> set [Tensor ]:
106112 output_keys_list = [key for t in self .transforms for key in t .check_keys (input_keys )]
You can’t perform that action at this time.
0 commit comments