Skip to content

Commit 99ae2d9

Browse files
committed
Improve RequirementError messages
1 parent 3b8a1b5 commit 99ae2d9

4 files changed

Lines changed: 7 additions & 6 deletions

File tree

src/torchjd/autojac/_transform/aggregate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __call__(self, jacobian_matrices: JacobianMatrices) -> GradientVectors:
5151
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
5252
if not set(self.key_order) == input_keys:
5353
raise RequirementError(
54-
f"The input_keys must match the key_order. Found {input_keys} and {self.key_order}"
54+
f"The input_keys must match the key_order. Found input_keys {input_keys} and"
55+
f"key_order {self.key_order}."
5556
)
5657
return input_keys
5758

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
3131
considered = set(self.considered)
3232
if not considered.issubset(input_keys):
3333
raise RequirementError(
34-
f"The input_keys needs to be a super set of the considered keys. Found {input_keys} "
35-
f"and {considered}"
34+
f"The input_keys should be a super set of the considered keys. Found input_keys "
35+
f"{input_keys} and considered keys {considered}."
3636
)
3737
return considered

src/torchjd/autojac/_transform/init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ def __call__(self, input: EmptyTensorDict) -> Gradients:
2323

2424
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
2525
if input_keys == set():
26-
raise RequirementError(f"Init expects an empty set of input_keys. Found {input_keys}")
26+
raise RequirementError(f"The input_keys should be the empty set. Found {input_keys}.")
2727
return self.values

src/torchjd/autojac/_transform/select.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __call__(self, tensor_dict: _A) -> _A:
1717
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
1818
if not self.keys.issubset(input_keys):
1919
raise RequirementError(
20-
f"The input_keys needs to be a super set of the keys to select. Found {input_keys} "
21-
f"and {self.keys}"
20+
f"The input_keys should be a super set of the keys to select. Found input_keys "
21+
f"{input_keys} and keys to select {self.keys}."
2222
)
2323
return self.keys

0 commit comments

Comments
 (0)