src/torchjd/aggregation/_aggregator_bases.py (8 changes: 6 additions & 2 deletions)
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from torch import Tensor, nn

@@ -68,7 +69,10 @@ def forward(self, matrix: Matrix, /) -> Tensor:
return vector


class GramianWeightedAggregator(WeightedAggregator):
_T = TypeVar("_T", bound=Weighting[PSDMatrix])


class GramianWeightedAggregator(WeightedAggregator, Generic[_T]):
"""
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
Weighting to it.
@@ -77,6 +81,6 @@ class GramianWeightedAggregator(WeightedAggregator):
gramian.
"""

def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
def __init__(self, gramian_weighting: _T) -> None:
super().__init__(gramian_weighting << compute_gramian)
self.gramian_weighting = gramian_weighting
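To illustrate what the new type parameter buys, here is a minimal, self-contained sketch of the pattern, using simplified stand-in classes rather than the real torchjd bases: parametrizing the base aggregator over its weighting type lets a subclass keep the concrete type of `self.gramian_weighting`, which the updated `CAGrad` in the next file relies on.

```python
# Hedged sketch of the Generic/TypeVar pattern introduced above. All classes
# here are simplified stand-ins, not the actual torchjd implementations.
from typing import Generic, TypeVar


class Weighting:
    """Stand-in for the Weighting base class."""


class CAGradLikeWeighting(Weighting):
    """Stand-in for CAGradWeighting: exposes a mutable `c` attribute."""

    def __init__(self, c: float) -> None:
        self.c = c


_T = TypeVar("_T", bound=Weighting)


class GramianWeightedAggregatorSketch(Generic[_T]):
    """Stand-in for GramianWeightedAggregator, now generic over its weighting."""

    def __init__(self, gramian_weighting: _T) -> None:
        self.gramian_weighting = gramian_weighting


class CAGradSketch(GramianWeightedAggregatorSketch[CAGradLikeWeighting]):
    def __init__(self, c: float) -> None:
        super().__init__(CAGradLikeWeighting(c=c))


aggregator = CAGradSketch(c=0.5)
# Because the base class is parametrized with CAGradLikeWeighting, a static
# type checker knows `.c` exists on `gramian_weighting`; without Generic[_T]
# it would only see the Weighting base type.
aggregator.gramian_weighting.c = 0.25
print(aggregator.gramian_weighting.c)
```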
src/torchjd/aggregation/_cagrad.py (90 changes: 54 additions & 36 deletions)
@@ -18,38 +18,6 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class CAGrad(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
<https://arxiv.org/pdf/2110.14048.pdf>`_.

:param c: The scale of the radius of the ball constraint.
:param norm_eps: A small value to avoid division by zero when normalizing.

.. note::
This aggregator is not installed by default. When not installed, trying to import it should
result in the following error:
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
To install it, use ``pip install "torchjd[cagrad]"``.
"""

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
self._c = c
self._norm_eps = norm_eps

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"

def __str__(self) -> str:
c_str = str(self._c).rstrip("0")
return f"CAGrad{c_str}"


class CAGradWeighting(Weighting[PSDMatrix]):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
@@ -69,13 +37,22 @@

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__()

if c < 0.0:
raise ValueError(f"Parameter `c` should be a non-negative float. Found `c = {c}`.")

self.c = c
self.norm_eps = norm_eps

@property
def c(self) -> float:
return self._c

@c.setter
def c(self, value: float) -> None:
if value < 0.0:
raise ValueError(
f"Parameter `value` should be a non-negative float. Found `value = {value}`."
)

self._c = value

def forward(self, gramian: PSDMatrix, /) -> Tensor:
U, S, _ = torch.svd(normalize(gramian, self.norm_eps))

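A hedged usage sketch of the validating setter added in this hunk, assuming the `cagrad` extra is installed and that `CAGradWeighting` can be imported from the private module this diff touches: the range check now also guards assignments made after construction, not just the constructor argument.

```python
import torch

# Private import path taken from this diff; it may not be part of the public API.
from torchjd.aggregation._cagrad import CAGradWeighting

weighting = CAGradWeighting(c=0.5)  # __init__ assigns through the `c` property
weighting.c = 1.0  # accepted: non-negative values pass the setter's check

try:
    weighting.c = -1.0  # rejected by the setter, even after construction
except ValueError as error:
    print(error)

# The weighting maps a PSD Gramian to a vector of weights (example matrix is arbitrary).
gramian = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
print(weighting(gramian))
```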
@@ -104,3 +81,44 @@
weights = torch.from_numpy(weight_array).to(device=gramian.device, dtype=gramian.dtype)

return weights


class CAGrad(GramianWeightedAggregator[CAGradWeighting]):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
<https://arxiv.org/pdf/2110.14048.pdf>`_.

:param c: The scale of the radius of the ball constraint.
:param norm_eps: A small value to avoid division by zero when normalizing.

.. note::
This aggregator is not installed by default. When not installed, trying to import it should
result in the following error:
``ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'``.
To install it, use ``pip install "torchjd[cagrad]"``.
"""

def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
self.c = c
self._norm_eps = norm_eps

# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def c(self) -> float:
return self._c

[Codecov / codecov/patch warning: added line src/torchjd/aggregation/_cagrad.py#L112 was not covered by tests]

@c.setter
def c(self, value: float) -> None:
self.gramian_weighting.c = value
self._c = value

def __repr__(self) -> str:
return f"{self.__class__.__name__}(c={self._c}, norm_eps={self._norm_eps})"

def __str__(self) -> str:
c_str = str(self._c).rstrip("0")
return f"CAGrad{c_str}"