33import torch
44from torch import Tensor
55
6- from ._utils import _A , _materialize , dicts_union
6+ from ._materialize import materialize
77from .base import Transform
8- from .tensor_dict import Gradients , Jacobians
8+ from .tensor_dict import _A , Gradients , Jacobians
99
1010
1111class Stack (Transform [_A , Jacobians ]):
@@ -32,7 +32,10 @@ def _stack(gradient_dicts: list[Gradients]) -> Jacobians:
3232 # It is important to first remove duplicate keys before computing their associated
3333 # stacked tensor. Otherwise, some computations would be duplicated. Therefore, we first compute
3434 # unique_keys, and only then, we compute the stacked tensors.
35- unique_keys = dicts_union (gradient_dicts ).keys ()
35+ union = {}
36+ for d in gradient_dicts :
37+ union |= d
38+ unique_keys = union .keys ()
3639 result = Jacobians ({key : _stack_one_key (gradient_dicts , key ) for key in unique_keys })
3740 return result
3841
@@ -43,6 +46,6 @@ def _stack_one_key(gradient_dicts: list[Gradients], input: Tensor) -> Tensor:
4346 """
4447
4548 optional_gradients = [gradients .get (input , None ) for gradients in gradient_dicts ]
46- gradients = _materialize (optional_gradients , [input ] * len (optional_gradients ))
49+ gradients = materialize (optional_gradients , [input ] * len (optional_gradients ))
4750 jacobian = torch .stack (gradients , dim = 0 )
4851 return jacobian
0 commit comments