diff --git a/src/torchjd/aggregation/graddrop.py b/src/torchjd/aggregation/graddrop.py index acf0cd70c..97d671e76 100644 --- a/src/torchjd/aggregation/graddrop.py +++ b/src/torchjd/aggregation/graddrop.py @@ -72,8 +72,8 @@ def _check_matrix_has_enough_rows(self, matrix: Tensor) -> None: n_rows = matrix.shape[0] if self.leak is not None and n_rows != len(self.leak): raise ValueError( - f"Parameter `matrix` should be a matrix of at least {len(self.leak)} rows " - f"(i.e. the number of leak scalars). Found `matrix` of shape `{matrix.shape}`." + f"Parameter `matrix` should be a matrix of exactly {len(self.leak)} rows (i.e. the " + f"number of leak scalars). Found `matrix` of shape `{matrix.shape}`." ) def __repr__(self) -> str: