Skip to content

Commit ccdc14f

Browse files
committed
Make Diagonalize check_keys check for set equality
1 parent 5e662f5 commit ccdc14f

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/torchjd/autojac/_transform/diagonalize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def __call__(self, tensors: Gradients) -> Jacobians:
2929

3030
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
3131
considered = set(self.considered)
32-
if not considered.issubset(input_keys):
32+
if not considered == input_keys:
3333
raise RequirementError(
34-
f"The input_keys should be a super set of the considered keys. Found input_keys "
35-
f"{input_keys} and considered keys {considered}."
34+
f"The input_keys must match the considered keys. Found input_keys {input_keys} and"
35+
f"considered keys {considered}."
3636
)
3737
return considered

0 commit comments

Comments
 (0)