Skip to content

Commit 922e746

Browse files
committed
Add RequirementError
1 parent 7252056 commit 922e746

5 files changed

Lines changed: 16 additions & 10 deletions

File tree

src/torchjd/autojac/_transform/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .accumulate import Accumulate
22
from .aggregate import Aggregate
3-
from .base import Composition, Conjunction, Transform
3+
from .base import Composition, Conjunction, RequirementError, Transform
44
from .diagonalize import Diagonalize
55
from .grad import Grad
66
from .init import Init

src/torchjd/autojac/_transform/accumulate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from torch import Tensor
44

5-
from .base import Transform
5+
from .base import RequirementError, Transform
66
from .tensor_dict import EmptyTensorDict, Gradients
77

88

@@ -34,7 +34,7 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
3434

3535
def _check_expects_grad(tensor: Tensor) -> None:
3636
if not _expects_grad(tensor):
37-
raise ValueError(
37+
raise RequirementError(
3838
"Cannot populate the .grad field of a Tensor that does not satisfy:"
3939
"`tensor.requires_grad and (tensor.is_leaf or tensor.retains_grad)`."
4040
)

src/torchjd/autojac/_transform/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from ._utils import _A, _B, _C, _union
99

1010

11+
class RequirementError(ValueError):
12+
pass
13+
14+
1115
class Transform(Generic[_B, _C], ABC):
1216
r"""
1317
Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian
@@ -77,7 +81,7 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
7781
outer_required_keys, outer_output_keys = self.outer.check_and_get_keys()
7882
inner_required_keys, inner_output_keys = self.inner.check_and_get_keys()
7983
if outer_required_keys != inner_output_keys:
80-
raise ValueError(
84+
raise RequirementError(
8185
"The `output_keys` of `inner` must match with the `required_keys` of "
8286
f"outer. Found {outer_required_keys} and {inner_output_keys}"
8387
)
@@ -108,12 +112,12 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
108112
required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys)
109113
for transform_required_keys, _ in keys_pairs:
110114
if transform_required_keys != required_keys:
111-
raise ValueError("All transforms should require the same set of keys.")
115+
raise RequirementError("All transforms should require the same set of keys.")
112116

113117
output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys]
114118
output_keys = set(output_keys_with_duplicates)
115119

116120
if len(output_keys) != len(output_keys_with_duplicates):
117-
raise ValueError("The sets of output keys of transforms should be disjoint.")
121+
raise RequirementError("The sets of output keys of transforms should be disjoint.")
118122

119123
return required_keys, output_keys

src/torchjd/autojac/_transform/select.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import Tensor
44

55
from ._utils import _A
6-
from .base import Transform
6+
from .base import RequirementError, Transform
77

88

99
class Select(Transform[_A, _A]):
@@ -18,6 +18,8 @@ def __call__(self, tensor_dict: _A) -> _A:
1818
def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
1919
required_keys = self._required_keys
2020
if not self.keys.issubset(required_keys):
21-
raise ValueError("Parameter `keys` should be a subset of parameter `required_keys`")
21+
raise RequirementError(
22+
"Parameter `keys` should be a subset of parameter `required_keys`"
23+
)
2224

2325
return required_keys, self.keys

src/torchjd/autojac/_transform/stack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66
from ._utils import _A, _materialize, dicts_union
7-
from .base import Transform
7+
from .base import RequirementError, Transform
88
from .tensor_dict import Gradients, Jacobians
99

1010

@@ -25,7 +25,7 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]:
2525

2626
for transform_required_keys, _ in keys_pairs:
2727
if transform_required_keys != required_keys:
28-
raise ValueError("All transforms should require the same set of keys.")
28+
raise RequirementError("All transforms should require the same set of keys.")
2929

3030
return required_keys, output_keys
3131

0 commit comments

Comments
 (0)