Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.aggregation._mixins import ResettableMixin

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVac(GramianWeightedAggregator):
class GradVac(GramianWeightedAggregator, ResettableMixin):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Expand Down Expand Up @@ -71,7 +72,7 @@ def __repr__(self) -> str:
return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"


class GradVacWeighting(Weighting[PSDMatrix]):
class GradVacWeighting(Weighting[PSDMatrix], ResettableMixin):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.
Expand Down
9 changes: 9 additions & 0 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod


class ResettableMixin(ABC):
    """Mixin declaring a ``reset`` hook.

    Subclasses must override :meth:`reset` to clear any internal state they
    accumulate across calls (e.g. running statistics kept between aggregation
    steps). Being an :class:`~abc.ABC` with an abstract method, this mixin
    cannot be instantiated directly, and any concrete subclass that fails to
    implement ``reset`` will raise ``TypeError`` on instantiation.
    """

    @abstractmethod
    def reset(self) -> None:
        """Reset the internal state.

        Concrete implementations should restore the object to the state it
        had immediately after construction. Returns nothing.
        """
9 changes: 5 additions & 4 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.

from typing import cast

from torchjd._linalg import Matrix
from torchjd.aggregation._mixins import ResettableMixin

from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import Weighting

check_dependencies_are_installed(["cvxpy", "ecos"])

from typing import cast

import cvxpy as cp
import numpy as np
import torch
Expand All @@ -20,7 +21,7 @@
from ._utils.non_differentiable import raise_non_differentiable_error


class NashMTL(WeightedAggregator):
class NashMTL(WeightedAggregator, ResettableMixin):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of
`Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_.
Expand Down Expand Up @@ -83,7 +84,7 @@ def __repr__(self) -> str:
)


class _NashMTLWeighting(Weighting[Matrix]):
class _NashMTLWeighting(Weighting[Matrix], ResettableMixin):
"""
:class:`~torchjd.aggregation.Weighting` that extracts weights using the step decision
of Algorithm 1 of `Multi-Task Learning as a Bargaining Game
Expand Down
Loading