Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added `pref_vector`, `norm_eps`, and `reg_eps` getters and setters to `UPGrad` and
`UPGradWeighting`. The setters for `norm_eps` and `reg_eps` validate that the assigned value is
non-negative.

## [0.10.0] - 2026-04-16

### Added
Expand Down
71 changes: 62 additions & 9 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,50 @@ def __init__(
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
self.pref_vector = pref_vector
self.norm_eps = norm_eps
self.reg_eps = reg_eps
self.solver: SUPPORTED_SOLVER = solver

def forward(self, gramian: PSDMatrix, /) -> Tensor:
    """
    Compute the aggregation weights from the Gramian of the Jacobian.

    :param gramian: The PSD Gramian matrix to compute the weights for.
    :return: A 1-D tensor of weights obtained by summing the projected weight matrix over its
        first dimension.
    """
    # Diagonal matrix whose entries are the weights given by the inner weighting.
    U = torch.diag(self.weighting(gramian))
    # Normalize then regularize the Gramian before projecting (uses the private attributes
    # directly; the public properties return the same values).
    G = regularize(normalize(gramian, self._norm_eps), self._reg_eps)
    W = project_weights(U, G, self.solver)
    return torch.sum(W, dim=0)

@property
def pref_vector(self) -> Tensor | None:
    """The preference vector used to derive the underlying weighting."""
    return self._pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
    """Set the preference vector and rebuild the underlying weighting accordingly."""
    # Keep ``weighting`` in sync with the stored preference vector.
    self.weighting = pref_vector_to_weighting(value, default=MeanWeighting())
    self._pref_vector = value
Comment thread
ValerianRey marked this conversation as resolved.

@property
def norm_eps(self) -> float:
    """Epsilon value passed to the Gramian normalization."""
    return self._norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
    """Set ``norm_eps``, raising a ``ValueError`` if the value is negative."""
    if value < 0:
        message = f"norm_eps must be non-negative, but got {value}."
        raise ValueError(message)
    self._norm_eps = value

@property
def reg_eps(self) -> float:
    """Epsilon value used to regularize the normalized Gramian."""
    return self._reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:
    """Set ``reg_eps``, raising a ``ValueError`` if the value is negative."""
    if value < 0:
        message = f"reg_eps must be non-negative, but got {value}."
        raise ValueError(message)
    self._reg_eps = value


class UPGrad(GramianWeightedAggregator):
r"""
Expand Down Expand Up @@ -73,9 +105,6 @@ def __init__(
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
) -> None:
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
self._solver: SUPPORTED_SOLVER = solver

super().__init__(
Expand All @@ -85,11 +114,35 @@ def __init__(
# This prevents considering the computed weights as constant w.r.t. the matrix.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector

@pref_vector.setter
def pref_vector(self, value: Tensor | None) -> None:
self.gramian_weighting.pref_vector = value

@property
def norm_eps(self) -> float:
    """The ``norm_eps`` value, delegated to the underlying ``gramian_weighting``."""
    return self.gramian_weighting.norm_eps

@norm_eps.setter
def norm_eps(self, value: float) -> None:
    """Forward the assignment to the underlying ``gramian_weighting``."""
    self.gramian_weighting.norm_eps = value

@property
def reg_eps(self) -> float:
    """The ``reg_eps`` value, delegated to the underlying ``gramian_weighting``."""
    return self.gramian_weighting.reg_eps

@reg_eps.setter
def reg_eps(self, value: float) -> None:
    """Forward the assignment to the underlying ``gramian_weighting``."""
    self.gramian_weighting.reg_eps = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(pref_vector={repr(self._pref_vector)}, norm_eps="
f"{self._norm_eps}, reg_eps={self._reg_eps}, solver={repr(self._solver)})"
f"{self.__class__.__name__}(pref_vector={repr(self.pref_vector)}, norm_eps="
f"{self.norm_eps}, reg_eps={self.reg_eps}, solver={repr(self._solver)})"
)

def __str__(self) -> str:
    """Return a short human-readable name, with a suffix derived from ``pref_vector``."""
    # Uses the public property so the displayed value always matches the underlying weighting.
    return f"UPGrad{pref_vector_to_str_suffix(self.pref_vector)}"
46 changes: 45 additions & 1 deletion tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from pytest import mark
from pytest import mark, raises
from torch import Tensor
from utils.tensors import ones_

from torchjd.aggregation import UPGrad
from torchjd.aggregation._upgrad import UPGradWeighting

from ._asserts import (
assert_expected_structure,
Expand Down Expand Up @@ -67,3 +68,46 @@ def test_representations() -> None:
"solver='quadprog')"
)
assert str(A) == "UPGrad([1., 2., 3.])"


def test_pref_vector_setter_updates_value() -> None:
    """Assigning ``pref_vector`` must store the exact tensor object."""
    aggregator = UPGrad()
    vector = torch.tensor([1.0, 2.0, 3.0])
    aggregator.pref_vector = vector
    assert aggregator.pref_vector is vector
Comment thread
ValerianRey marked this conversation as resolved.


def test_norm_eps_setter_updates_value() -> None:
    """A non-negative ``norm_eps`` assignment must be reflected by the getter."""
    aggregator = UPGrad()
    aggregator.norm_eps = 0.25
    assert aggregator.norm_eps == 0.25


def test_reg_eps_setter_updates_value() -> None:
    """A non-negative ``reg_eps`` assignment must be reflected by the getter."""
    aggregator = UPGrad()
    aggregator.reg_eps = 0.25
    assert aggregator.reg_eps == 0.25


def test_norm_eps_setter_rejects_negative() -> None:
    """Assigning a negative ``norm_eps`` must raise a ``ValueError``."""
    aggregator = UPGrad()
    with raises(ValueError, match="norm_eps"):
        aggregator.norm_eps = -1e-9


def test_reg_eps_setter_rejects_negative() -> None:
    """Assigning a negative ``reg_eps`` must raise a ``ValueError``."""
    aggregator = UPGrad()
    with raises(ValueError, match="reg_eps"):
        aggregator.reg_eps = -1e-9


def test_weighting_norm_eps_setter_rejects_negative() -> None:
    """Assigning a negative ``norm_eps`` on the weighting must raise a ``ValueError``."""
    weighting = UPGradWeighting()
    with raises(ValueError, match="norm_eps"):
        weighting.norm_eps = -1e-9


def test_weighting_reg_eps_setter_rejects_negative() -> None:
    """Assigning a negative ``reg_eps`` on the weighting must raise a ``ValueError``."""
    weighting = UPGradWeighting()
    with raises(ValueError, match="reg_eps"):
        weighting.reg_eps = -1e-9
Loading