accumulate.py
from typing import Iterable

from torch import Tensor

from .base import Transform
from .tensor_dict import EmptyTensorDict, Gradients


class Accumulate(Transform[Gradients, EmptyTensorDict]):
    def __init__(self, required_keys: Iterable[Tensor]):
        self._required_keys = set(required_keys)

    def _compute(self, gradients: Gradients) -> EmptyTensorDict:
        """
        Accumulates the gradients with respect to the keys into their ``.grad`` fields.
        """

        for key in gradients.keys():
            _check_expects_grad(key)
            if hasattr(key, "grad") and key.grad is not None:
                key.grad += gradients[key]
            else:
                # We clone the value because we do not want subsequent accumulations to also
                # affect this value (in case it is still used outside). We do not detach from
                # the computation graph because the value can have a grad_fn that we want to
                # keep track of (in case it was obtained via create_graph=True and a
                # differentiable aggregator).
                key.grad = gradients[key].clone()

        return EmptyTensorDict()

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        return self._required_keys, set()


def _check_expects_grad(tensor: Tensor) -> None:
    if not _expects_grad(tensor):
        raise ValueError(
            "Cannot populate the .grad field of a Tensor that does not satisfy: "
            "`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
        )


def _expects_grad(tensor: Tensor) -> bool:
    """
    Determines whether a Tensor expects its ``.grad`` attribute to be populated.
    See https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf for more information.
    """
    return tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)
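
As a quick illustration of the two behaviours above, here is a minimal standalone sketch in plain PyTorch (no torchjd imports; the tensors, values, and the local `expects_grad` helper are made up for this example and are not part of this module's API): the first accumulation clones the incoming gradient into ``.grad``, subsequent accumulations add in place, and only tensors satisfying the leaf/retains-grad condition should be populated.

```python
import torch

# First accumulation: .grad is None, so a clone of the incoming gradient is
# stored. The clone keeps the later in-place += from mutating g1.
param = torch.zeros(3, requires_grad=True)
g1 = torch.ones(3)
g2 = torch.full((3,), 2.0)

param.grad = g1.clone()  # first accumulation
param.grad += g2         # second accumulation, in place
print(param.grad)        # tensor([3., 3., 3.])
print(g1)                # tensor([1., 1., 1.]) -- unchanged thanks to the clone

# The condition checked by _expects_grad: leaves with requires_grad expect
# their .grad to be populated; intermediate results do not, unless
# retain_grad() was explicitly called on them.
def expects_grad(t: torch.Tensor) -> bool:
    return t.requires_grad and (t.is_leaf or t.retains_grad)

leaf = torch.randn(2, requires_grad=True)
intermediate = leaf * 2  # requires_grad but not a leaf
retained = leaf * 3
retained.retain_grad()   # non-leaf that explicitly retains its grad

print(expects_grad(leaf))          # True
print(expects_grad(intermediate))  # False: Accumulate would raise ValueError
print(expects_grad(retained))      # True
```

Cloning on the first accumulation is what keeps the incoming gradient usable elsewhere while still preserving its `grad_fn`, which matters when the gradients were produced with `create_graph=True` and a differentiable aggregator.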