-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathbase.py
More file actions
115 lines (85 loc) · 4.25 KB
/
base.py
File metadata and controls
115 lines (85 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, Sequence
from torch import Tensor
from ._utils import _A, _B, _C, _union
class Transform(Generic[_B, _C], ABC):
    r"""
    Abstract base class from which every transform derives.

    A transform is an elementary building block of the backward phase of a jacobian descent. It
    maps one :class:`~torchjd.transform.tensor_dict.TensorDict` to another one: the keys of the
    input :class:`~torchjd.transform.tensor_dict.TensorDict` are the `required_keys` of the
    transform, and the keys of the output :class:`~torchjd.transform.tensor_dict.TensorDict` are
    its `output_keys`.

    Formally, a transform is a function:

    .. math::

        f:\mathbb R^{n_1+\dots+n_p}\to \mathbb R^{m_1+\dots+m_q}

    where ``p`` is the number of `required_keys`, ``q`` is the number of `output_keys`, ``n_i`` is
    the number of elements of the value associated to the ``i`` th `required_key` of the input
    :class:`~torchjd.transform.tensor_dict.TensorDict`, and ``m_j`` is the number of elements of
    the value associated to the ``j`` th `output_key` of the output
    :class:`~torchjd.transform.tensor_dict.TensorDict`.

    Since transforms are mathematical functions, they can be composed together, provided that
    their domains and ranges match meaningfully.

    The ``<<`` operator is an alias of :meth:`compose` and the ``|`` operator is an alias of
    :meth:`conjunct`.
    """

    @abstractmethod
    def _compute(self, input: _B) -> _C:
        """Applies the transform to the input."""

    @abstractmethod
    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """
        Checks that the transform is valid, and returns a pair containing (in this order) the
        required keys and the output keys of the transform.
        """

    def __call__(self, input: _B) -> _C:
        # Calling a transform simply delegates to the subclass-provided `_compute`.
        return self._compute(input)

    def __str__(self) -> str:
        # By default, a transform is displayed as the name of its concrete class.
        concrete_class = type(self)
        return concrete_class.__name__

    def compose(self, other: Transform[_A, _B]) -> Transform[_A, _C]:
        """
        Returns the :class:`Composition` in which ``self`` is the outer transform and ``other`` is
        the inner transform.

        :param other: The transform to use as the inner transform of the composition.
        """

        return Composition(outer=self, inner=other)

    def conjunct(self, other: Transform[_B, _C]) -> Transform[_B, _C]:
        """
        Returns the :class:`Conjunction` of ``self`` and ``other``, in this order.

        :param other: The transform to put in conjunction with ``self``.
        """

        return Conjunction(transforms=[self, other])

    # Operator aliases: `a << b` is `a.compose(b)` and `a | b` is `a.conjunct(b)`.
    __lshift__ = compose
    __or__ = conjunct
class Composition(Transform[_A, _C]):
    """
    Transform chaining two transforms: ``inner`` is applied to the input first, and ``outer`` is
    then applied to the result of ``inner``.

    :param outer: The transform applied last.
    :param inner: The transform applied first.
    """

    def __init__(self, outer: Transform[_B, _C], inner: Transform[_A, _B]):
        self.outer = outer
        self.inner = inner

    def __str__(self) -> str:
        return f"{self.outer} ∘ {self.inner}"

    def _compute(self, input: _A) -> _C:
        """Applies ``inner`` to the input, then ``outer`` to the intermediate result."""

        intermediate = self.inner(input)
        return self.outer(intermediate)

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """
        Checks that both ``outer`` and ``inner`` are valid and that they can be chained, and returns
        a pair containing (in this order) the required keys of ``inner`` and the output keys of
        ``outer``.

        :raises ValueError: If the output keys of ``inner`` differ from the required keys of
            ``outer``.
        """

        # `outer` is validated before `inner`, so when both are invalid the error raised is the one
        # coming from `outer`.
        outer_required_keys, outer_output_keys = self.outer.check_keys()
        inner_required_keys, inner_output_keys = self.inner.check_keys()

        if outer_required_keys != inner_output_keys:
            # The two sets are reported in the same order as they are named in the sentence, so
            # that each one can be attributed to the right transform.
            raise ValueError(
                "The `output_keys` of `inner` must match with the `required_keys` of "
                f"`outer`. Found {inner_output_keys} and {outer_required_keys}"
            )

        return inner_required_keys, outer_output_keys
class Conjunction(Transform[_A, _B]):
    """
    Transform applying several transforms to the same input and combining their outputs with
    ``_union``.

    :param transforms: The transforms to apply. They are expected to all require the same set of
        keys and to have pairwise disjoint sets of output keys (this is verified by
        :meth:`check_keys`, not at construction).
    """

    def __init__(self, transforms: Sequence[Transform[_A, _B]]):
        self.transforms = transforms

    def __str__(self) -> str:
        def describe(member: Transform[_A, _B]) -> str:
            text = str(member)
            # The string of a nested conjunction is stripped of its first and last characters (its
            # enclosing parentheses), so that nested conjunctions are displayed flat.
            return text[1:-1] if isinstance(member, Conjunction) else text

        return "(" + " | ".join(map(describe, self.transforms)) + ")"

    def _compute(self, tensor_dict: _A) -> _B:
        """Applies each transform to ``tensor_dict`` and combines the results with ``_union``."""

        return _union([member(tensor_dict) for member in self.transforms])

    def check_keys(self) -> tuple[set[Tensor], set[Tensor]]:
        """
        Checks that each transform is valid and that they can be put in conjunction, and returns a
        pair containing (in this order) the union of their required keys and the union of their
        output keys.

        :raises ValueError: If the transforms do not all require the same set of keys, or if their
            sets of output keys are not pairwise disjoint.
        """

        pairs = [member.check_keys() for member in self.transforms]

        # Every transform must require exactly the union of all required keys, which means that all
        # of them require the same set of keys.
        all_required = set().union(*(required for required, _ in pairs))
        if any(required != all_required for required, _ in pairs):
            raise ValueError("All transforms should require the same set of keys.")

        # If a key is output by several transforms, the union has fewer elements than the total
        # number of output keys.
        all_outputs = set().union(*(outputs for _, outputs in pairs))
        if len(all_outputs) != sum(len(outputs) for _, outputs in pairs):
            raise ValueError("The sets of output keys of transforms should be disjoint.")

        return all_required, all_outputs