-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path_differentiate.py
More file actions
43 lines (34 loc) · 1.56 KB
/
_differentiate.py
File metadata and controls
43 lines (34 loc) · 1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from abc import ABC, abstractmethod
from typing import Iterable, Sequence
from torch import Tensor
from ._utils import ordered_set
from .base import _A, Transform
class _Differentiate(Transform[_A, _A], ABC):
    """
    Abstract transform that differentiates some output tensors with respect to some input tensors.

    When computed, it reads the values associated with ``outputs`` from its input mapping, hands
    them (in order) to the subclass-defined ``_differentiate``, and returns a new mapping of the
    same type as its input that associates each of ``inputs`` with the corresponding result.

    ``outputs`` is stored as a list. ``inputs`` is stored through ``ordered_set`` (presumably
    de-duplicating while keeping first-occurrence order — TODO confirm against ``_utils``) and its
    elements become the keys of the resulting mapping. ``retain_graph`` and ``create_graph`` are
    only stored here; subclasses presumably forward them to the underlying autograd call.
    """

    def __init__(
        self,
        outputs: Iterable[Tensor],
        inputs: Iterable[Tensor],
        retain_graph: bool,
        create_graph: bool,
    ):
        self.outputs = list(outputs)
        self.inputs = ordered_set(inputs)
        self.retain_graph = retain_graph
        self.create_graph = create_graph

    def _compute(self, tensors: _A) -> _A:
        """
        Differentiates the stored outputs and maps each stored input to its result.

        The value of every tensor in ``self.outputs`` is looked up in ``tensors`` and passed to
        ``_differentiate``. The returned tensors are paired positionally with ``self.inputs`` and
        wrapped in a new object of the same type as ``tensors``.

        Raises a ``ValueError`` if ``_differentiate`` does not return exactly one tensor per input.
        """
        tensor_outputs = [tensors[output] for output in self.outputs]
        differentiated = tuple(self._differentiate(tensor_outputs))

        # zip() would silently drop the unmatched elements on a length mismatch, losing results
        # (or leaving inputs without an entry). Fail loudly so a faulty subclass is caught here.
        inputs = list(self.inputs)
        if len(differentiated) != len(inputs):
            raise ValueError(
                f"Expected `_differentiate` to return {len(inputs)} tensor(s) (one per input), "
                f"but it returned {len(differentiated)}."
            )

        new_differentiations = dict(zip(inputs, differentiated))
        return type(tensors)(new_differentiations)

    @abstractmethod
    def _differentiate(self, tensor_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]:
        """
        Abstract method for differentiating the outputs with respect to the inputs, and applying the
        linear transformations represented by the tensor_outputs to the results.

        The implementation of this method should define what kind of differentiation is performed:
        whether gradients, Jacobians, etc. are computed, and what the dimension of the
        tensor_outputs should be. It must return exactly one tensor per element of ``self.inputs``,
        in the same order.
        """

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """
        Returns the keys this transform reads and the keys it produces.

        The first set holds ``self.outputs``, which are the keys ``_compute`` looks up in its input.
        The second set holds ``self.inputs``, which are the keys of the mapping it returns.
        """
        # outputs in the forward direction become inputs in the backward direction, and vice-versa
        return set(self.outputs), set(self.inputs)