-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathinit.py
More file actions
25 lines (18 loc) · 871 Bytes
/
init.py
File metadata and controls
25 lines (18 loc) · 871 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from typing import Iterable
import torch
from torch import Tensor
from .base import Transform
from .tensor_dict import EmptyTensorDict, Gradients
class Init(Transform[EmptyTensorDict, Gradients]):
    """
    Transform that takes no input and produces the gradient of each tensor in ``values`` with
    respect to itself.

    :param values: The tensors whose self-gradients should be produced. They are stored as a set,
        so a tensor passed several times is kept only once.
    """

    def __init__(self, values: Iterable[Tensor]):
        self.values = set(values)

    def __call__(self, input: EmptyTensorDict) -> Gradients:
        r"""
        Computes the gradient of each tensor in ``values`` with respect to itself. Returns the
        result as a dictionary whose keys are exactly the tensors in ``values``. Each
        corresponding gradient is a tensor of 1s with the same shape as its key (built with
        ``torch.ones_like``), because :math:`\frac{\partial v}{\partial v} = 1` for any
        :math:`v`.

        :param input: Unused; this transform requires no input tensors.
        """
        return Gradients({value: torch.ones_like(value) for value in self.values})

    def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """
        Returns a pair ``(set(), values)``: presumably the required input keys (none) and the
        produced output keys, as expected by ``Transform`` (TODO confirm against the base class).

        A copy of ``values`` is returned so that callers mutating the returned set cannot alter
        this transform's internal state.
        """
        return set(), set(self.values)