-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_utils.py
More file actions
49 lines (37 loc) · 1.53 KB
/
_utils.py
File metadata and controls
49 lines (37 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import Hashable, Iterable, Sequence, TypeVar
import torch
from torch import Tensor
from .tensor_dict import EmptyTensorDict, TensorDict, _least_common_ancestor
_KeyType = TypeVar("_KeyType", bound=Hashable)
_ValueType = TypeVar("_ValueType")
_A = TypeVar("_A", bound=TensorDict)
_B = TypeVar("_B", bound=TensorDict)
_C = TypeVar("_C", bound=TensorDict)
def dicts_union(dicts: Iterable[dict[_KeyType, _ValueType]]) -> dict[_KeyType, _ValueType]:
    """
    Merge an iterable of dicts into a single new dict.

    Entries are applied left to right, so a key present in several dicts takes the value
    from the last dict that contains it. The input dicts are not modified.
    """
    merged: dict[_KeyType, _ValueType] = {}
    for mapping in dicts:
        merged.update(mapping)
    return merged
def _materialize(
optional_tensors: Sequence[Tensor | None], inputs: Sequence[Tensor]
) -> tuple[Tensor, ...]:
"""
Transforms a sequence of optional tensors by changing each None by a tensor of zeros of the same
shape as the corresponding input. Returns the obtained sequence as a tuple.
Note that the name "materialize" comes from the flag `materialize_grads` from
`torch.autograd.grad`, which will be available in future torch releases.
"""
tensors = []
for optional_tensor, input in zip(optional_tensors, inputs):
if optional_tensor is None:
tensors.append(torch.zeros_like(input))
else:
tensors.append(optional_tensor)
return tuple(tensors)
def _union(tensor_dicts: Iterable[_A]) -> _A:
    """
    Merge an iterable of TensorDicts into one.

    Entries are merged left to right via ``|=``, so later dicts win on duplicate keys.
    The result is wrapped in the least common ancestor type of all the merged
    TensorDict subclasses (``EmptyTensorDict`` when the iterable is empty).
    """
    merged: _A = EmptyTensorDict()
    result_type: type[_A] = EmptyTensorDict
    for tensor_dict in tensor_dicts:
        merged |= tensor_dict
        result_type = _least_common_ancestor(result_type, type(tensor_dict))
    return result_type(merged)