Skip to content

Commit 387a672

Browse files
authored
fix(aggregation): Make non-differentiability explicit (#334)
* Add _non_differentiable.py with NonDifferentiableError and raise_non_differentiable_error
* Register raise_non_differentiable_error as a full backward pre-hook of CAGrad, ConFIG, DualProj, GradDrop, IMTLG, NashMTL, PCGrad and UPGrad
* Add NonDifferentiableProperty tester
* Give NonDifferentiableProperty to CAGrad, ConFIG, DualProj, GradDrop, IMTLG, NashMTL, PCGrad and UPGrad
* Add changelog entry
1 parent b6a0a2d commit 387a672

19 files changed

+100
-13
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ changes that do not affect the user.
1515
torchjd[nash_mtl]` or `pip install torchjd[full]` to install TorchJD alongside those dependencies.
1616
This should make TorchJD more lightweight.
1717

18+
### Fixed
19+
20+
- Made some aggregators (`CAGrad`, `ConFIG`, `DualProj`, `GradDrop`, `IMTLG`, `NashMTL`, `PCGrad`
21+
and `UPGrad`) raise a `NonDifferentiableError` whenever one tries to differentiate through them.
22+
Before this change, trying to differentiate through them led to wrong gradients or unclear
23+
errors.
24+
1825
## [0.6.0] - 2025-04-19
1926

2027
### Added
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from torch import Tensor, nn
2+
3+
4+
class NonDifferentiableError(RuntimeError):
5+
def __init__(self, module: nn.Module):
6+
super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")
7+
8+
9+
def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...]) -> None:
10+
raise NonDifferentiableError(module)

src/torchjd/aggregation/cagrad.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch import Tensor
99

1010
from ._gramian_utils import compute_gramian, normalize
11+
from ._non_differentiable import raise_non_differentiable_error
1112
from .bases import _WeightedAggregator, _Weighting
1213

1314

@@ -44,6 +45,9 @@ class CAGrad(_WeightedAggregator):
4445
def __init__(self, c: float, norm_eps: float = 0.0001):
4546
super().__init__(weighting=_CAGradWeighting(c=c, norm_eps=norm_eps))
4647

48+
# This prevents considering the computed weights as constant w.r.t. the matrix.
49+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
50+
4751
def __repr__(self) -> str:
4852
return (
4953
f"{self.__class__.__name__}(c={self.weighting.c}, norm_eps={self.weighting.norm_eps})"

src/torchjd/aggregation/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import torch
2929
from torch import Tensor
3030

31+
from ._non_differentiable import raise_non_differentiable_error
3132
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
3233
from .bases import Aggregator
3334
from .sum import _SumWeighting
@@ -65,6 +66,9 @@ def __init__(self, pref_vector: Tensor | None = None):
6566
self.weighting = pref_vector_to_weighting(pref_vector, default=_SumWeighting())
6667
self._pref_vector = pref_vector
6768

69+
# This prevents computing gradients that can be very wrong.
70+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
71+
6872
def forward(self, matrix: Tensor) -> Tensor:
6973
weights = self.weighting(matrix)
7074
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)

src/torchjd/aggregation/dualproj.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from ._dual_cone_utils import project_weights
66
from ._gramian_utils import compute_gramian, normalize, regularize
7+
from ._non_differentiable import raise_non_differentiable_error
78
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
89
from .bases import _WeightedAggregator, _Weighting
910
from .mean import _MeanWeighting
@@ -56,6 +57,9 @@ def __init__(
5657
)
5758
)
5859

60+
# This prevents considering the computed weights as constant w.r.t. the matrix.
61+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
62+
5963
def __repr__(self) -> str:
6064
return (
6165
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="

src/torchjd/aggregation/graddrop.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from torch import Tensor
55

6+
from ._non_differentiable import raise_non_differentiable_error
67
from .bases import Aggregator
78

89

@@ -47,6 +48,9 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
4748
self.f = f
4849
self.leak = leak
4950

51+
# This prevents computing gradients that can be very wrong.
52+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
53+
5054
def forward(self, matrix: Tensor) -> Tensor:
5155
self._check_is_matrix(matrix)
5256
self._check_matrix_has_enough_rows(matrix)

src/torchjd/aggregation/imtl_g.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch import Tensor
33

44
from ._gramian_utils import compute_gramian
5+
from ._non_differentiable import raise_non_differentiable_error
56
from .bases import _WeightedAggregator, _Weighting
67

78

@@ -30,6 +31,9 @@ class IMTLG(_WeightedAggregator):
3031
def __init__(self):
3132
super().__init__(weighting=_IMTLGWeighting())
3233

34+
# This prevents computing gradients that can be very wrong.
35+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
36+
3337

3438
class _IMTLGWeighting(_Weighting):
3539
"""

src/torchjd/aggregation/nash_mtl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from cvxpy import Expression
3434
from torch import Tensor
3535

36+
from ._non_differentiable import raise_non_differentiable_error
3637
from .bases import _WeightedAggregator, _Weighting
3738

3839

@@ -95,6 +96,9 @@ def __init__(
9596
)
9697
)
9798

99+
# This prevents considering the computed weights as constant w.r.t. the matrix.
100+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
101+
98102
def reset(self) -> None:
99103
"""Resets the internal state of the algorithm."""
100104
self.weighting.reset()

src/torchjd/aggregation/pcgrad.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch import Tensor
33

44
from ._gramian_utils import compute_gramian
5+
from ._non_differentiable import raise_non_differentiable_error
56
from .bases import _WeightedAggregator, _Weighting
67

78

@@ -28,6 +29,9 @@ class PCGrad(_WeightedAggregator):
2829
def __init__(self):
2930
super().__init__(weighting=_PCGradWeighting())
3031

32+
# This prevents running into a RuntimeError due to modifying stored tensors in place.
33+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
34+
3135

3236
class _PCGradWeighting(_Weighting):
3337
"""

src/torchjd/aggregation/upgrad.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ._dual_cone_utils import project_weights
77
from ._gramian_utils import compute_gramian, normalize, regularize
8+
from ._non_differentiable import raise_non_differentiable_error
89
from ._pref_vector_utils import pref_vector_to_str_suffix, pref_vector_to_weighting
910
from .bases import _WeightedAggregator, _Weighting
1011
from .mean import _MeanWeighting
@@ -56,6 +57,9 @@ def __init__(
5657
)
5758
)
5859

60+
# This prevents considering the computed weights as constant w.r.t. the matrix.
61+
self.register_full_backward_pre_hook(raise_non_differentiable_error)
62+
5963
def __repr__(self) -> str:
6064
return (
6165
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="

0 commit comments

Comments
 (0)