|
8 | 8 | from ._utils import _A, _B, _C, _union |
9 | 9 |
|
10 | 10 |
|
| 11 | +class RequirementError(ValueError): |
| 12 | + pass |
| 13 | + |
| 14 | + |
11 | 15 | class Transform(Generic[_B, _C], ABC): |
12 | 16 | r""" |
13 | 17 | Abstract base class for all transforms. Transforms are elementary building blocks of a jacobian |
@@ -77,7 +81,7 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: |
77 | 81 | outer_required_keys, outer_output_keys = self.outer.check_and_get_keys() |
78 | 82 | inner_required_keys, inner_output_keys = self.inner.check_and_get_keys() |
79 | 83 | if outer_required_keys != inner_output_keys: |
80 | | - raise ValueError( |
| 84 | + raise RequirementError( |
81 | 85 | "The `output_keys` of `inner` must match with the `required_keys` of " |
82 | 86 | f"outer. Found {outer_required_keys} and {inner_output_keys}" |
83 | 87 | ) |
@@ -108,12 +112,12 @@ def check_and_get_keys(self) -> tuple[set[Tensor], set[Tensor]]: |
108 | 112 | required_keys = set(key for required_keys, _ in keys_pairs for key in required_keys) |
109 | 113 | for transform_required_keys, _ in keys_pairs: |
110 | 114 | if transform_required_keys != required_keys: |
111 | | - raise ValueError("All transforms should require the same set of keys.") |
| 115 | + raise RequirementError("All transforms should require the same set of keys.") |
112 | 116 |
|
113 | 117 | output_keys_with_duplicates = [key for _, output_keys in keys_pairs for key in output_keys] |
114 | 118 | output_keys = set(output_keys_with_duplicates) |
115 | 119 |
|
116 | 120 | if len(output_keys) != len(output_keys_with_duplicates): |
117 | | - raise ValueError("The sets of output keys of transforms should be disjoint.") |
| 121 | + raise RequirementError("The sets of output keys of transforms should be disjoint.") |
118 | 122 |
|
119 | 123 | return required_keys, output_keys |
0 commit comments