Skip to content

feat(aggregation): Add GradVac aggregator#638

Open
rkhosrowshahi wants to merge 1 commit intoSimplexLab:mainfrom
rkhosrowshahi:feature/gradvac
Open

feat(aggregation): Add GradVac aggregator#638
rkhosrowshahi wants to merge 1 commit intoSimplexLab:mainfrom
rkhosrowshahi:feature/gradvac

Conversation

@rkhosrowshahi
Copy link
Copy Markdown
Contributor

@rkhosrowshahi rkhosrowshahi commented Apr 9, 2026

Summary

Adds Gradient Vaccine (GradVac) from ICLR 2021 as a stateful Aggregator on the full task Jacobian.

Behavior

  • Per-block cosine statistics and EMA targets \bar{\rho}, with the closed-form vaccine update when \rho < \bar{\rho}.
  • group_type: 0 whole model (single block); 1 all_layer via encoder (leaf modules with parameters); 2 all_matrix via shared_params (one block per tensor, iteration order = Jacobian column order).
  • DEFAULT_GRADVAC_EPS and configurable eps (constructor + mutable attribute).
  • Autogram not supported (needs full rows and per-block inner products). Task shuffle uses torch.randperm; use torch.manual_seed for reproducibility.

Files

  • src/torchjd/aggregation/_gradvac.py, export in __init__.py
  • docs/source/docs/aggregation/gradvac.rst + index toctree
  • tests/unit/aggregation/test_gradvac.py

Verification

  • ruff format / ruff check on touched paths
  • ty check on _gradvac.py
  • pytest tests/unit/aggregation/test_gradvac.py tests/unit/aggregation/test_values.py -W error
  • Sphinx HTML build for docs

Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator.
Support group_type 0 (whole model), 1 (all_layer via encoder), and 2
(all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable
eps. Add Sphinx page and unit tests. Autogram is not supported; use
torch.manual_seed for reproducible task shuffle order.

Made-with: Cursor
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: aggregation labels Apr 9, 2026
@rkhosrowshahi
Copy link
Copy Markdown
Contributor Author

Hey @ValerianRey :
Could you please add cc: feat and package: aggregation labels to this PR?

@github-actions github-actions bot changed the title feat(aggregation): add GradVac aggregator feat(aggregation): Add GradVac aggregator Apr 9, 2026
@ValerianRey
Copy link
Copy Markdown
Contributor

I think the link check action failed because of bad luck, re-running the job fixed it.

@ValerianRey
Copy link
Copy Markdown
Contributor

Thanks a lot for the PR, I'll look at it tonight or tomorrow!

@ValerianRey
Copy link
Copy Markdown
Contributor

/opencode:Plan can you please check that the actual implementation (from line 219 to line 251 of _gradvac.py) is equivalent to the implementation in https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/GradVac.py? Please report anything that differs.

@opencode-agent
Copy link
Copy Markdown
Contributor

Let me analyze the differences between the two implementations:

Comparison Summary

After comparing lines 219-251 of _gradvac.py with the LibMTL reference implementation, here are the findings:

✅ Equivalent Behaviors

  1. Core Algorithm: Both implement the same GradVac algorithm:

    • Clone the gradients to pc_grads
    • Iterate over task pairs (i, j) where i ≠ j
    • Random shuffle of j for each i
    • Compute cosine similarity rho_ijk per block
    • Apply vaccine update when rho_ijk < bar (EMA target)
    • Update EMA: bar = (1-beta)*bar + beta*rho_ijk
  2. Weight Formula: Mathematically identical:

    • LibMTL: ||g_i||*(ρ̄*√(1-ρ²)-ρ*√(1-ρ̄²))/(||g_j||*√(1-ρ̄²)+ε)
    • TorchJD: norm_i*(bar*sqrt_1_rho2-rho_ijk*sqrt_1_bar2)/(norm_j*sqrt_1_bar2+eps)
  3. Group Types: Both support the same three grouping strategies (whole_model, all_layer, all_matrix)

🔍 Differences Found

Aspect LibMTL TorchJD (this PR)
RNG for shuffling random.shuffle() (Python stdlib) torch.randperm() (PyTorch)
Epsilon Hardcoded 1e-8 Configurable eps attribute
Numerical safety No clamp before sqrt() .clamp(min=0.0) before sqrt()
Block indexing Runtime sum(self.k_idx[:k]) Precomputed offsets list
Whole model case Special case with k_idx=[-1] Natural via sizes=(n,)
Return value batch_weight (unused) pc_grads.sum(dim=0)
State management self.step counter _state_key tuple

📝 Notes

  1. The clamp operation in TorchJD ((1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt()) is a numerical safety improvement not present in LibMTL - it prevents potential NaN from negative values due to floating-point errors.

  2. The random number generator difference is documented in the TorchJD docstring: "Task-order shuffling uses the global PyTorch RNG".

  3. Both implementations correctly handle the asymmetric nature of the algorithm (updating pc_grads[i] based on comparisons with grads[j]).

Conclusion: The implementations are algorithmically equivalent. TorchJD includes minor improvements (configurable eps, numerical clamp) and follows the codebase's design patterns.

New%20session%20-%202026-04-11T12%3A55%3A53.983Z
opencode session  |  github run

@ValerianRey
Copy link
Copy Markdown
Contributor

/opencode:Plan Please review this PR in detail.

@opencode-agent

This comment was marked as low quality.

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really good work, tyvm! This seems almost ready to merge IMO. Just a bunch of nitpicks that should be really easy to fix, and 1 somewhat important implementation question (iterating over groups in outer loop vs inner loop).

I would also argue that it seems that this could be implemented as a Weighting internally (because we actually act on norms and cosine similarities between gradients, which is what the gramian contains). Also, it's possible to keep track of norms and cosine similarities between projected gradients even if we don't have those gradients, just by making some operations on the gramian. This is what we did to implement PCGrad as a Weighting.

For example, imagine you have g1 and g2 be two gradients. From the gramian, you know ||g1||, ||g2|| (the sqrt of the diag elements), and g1 . g2 (an off-diag element), so you can deduce cos(g1, g2) from that.

If you compute g1' = g1 + w * g2, you can also directly deduce the norm of g1':
||g1'||² = ||g1||² + w² ||g2||² + 2w g1 . g2 (all elements of the right handside are known).

Similarly, you can compute g1' . g2 = (g1 + w * g2) . g2 = g1 . g2 + w g1 . g2.

So even after projection, you still know the dot products between all of your gradients, meaning that you still know the "new" gramian.

I didn't think through it entirely but at a first glance it seems possible to adapt this as a weighting, because of that. The implementation may even be faster actually (because we have fewer norms to recompute). But it may be hard to implement, so IMO we should merge this without even trying to implement it as a Weighting, and we can always improve later. @PierreQuinton what do you think about that?

Comment on lines +14 to +15
#: Default small constant added to denominators for numerical stability.
DEFAULT_GRADVAC_EPS = 1e-8
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need that to be stored in a constant (we never do that for the default value of the params of the other aggregators).

Comment on lines +8 to +17
The constructor argument ``group_type`` (default ``0``) sets **parameter granularity** for the
per-block cosine statistics in GradVac:

* ``0`` — **whole model** (``whole_model``): one block per task gradient row. Omit ``encoder`` and
``shared_params``.
* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under
``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation).
* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use
the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass
to a shared-gradient helper).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is already included in the built documentation by .. autoclass:: torchjd.aggregation.GradVac, so it ends up being duplicated. Btw to look at the built documentation, you can run:

uv run make clean -C docs
uv run make html -C docs

and then open docs/build/html/index.html with a web browser.

device = grads.device
dtype = grads.dtype
self._ensure_state(m, n, sizes, device, dtype)
assert self._rho_t is not None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assert also cannot fail I think, so we can remove it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can return the self._rho_t from self._ensure_state().

     def _ensure_state(
        self,
        m: int,
        n: int,
        sizes: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tensor:
        key = (m, n, sizes, device, dtype)
        num_groups = len(sizes)
        if self._state_key != key or self._phi_t is None:
            phi = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
            self._phi_t = phi
            self._state_key = key
            return phi
        return self._phi_t
Suggested change
assert self._rho_t is not None
phi_t = self._ensure_state(m, n, sizes, device, dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of changing internal state + returning it. I think the intention of the function becomes a bit harder to understand.

If you just get rid of the assert, does it work or do you get a problem reported by ty? If so we could maybe just cast, or even keep the assert.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok then we initialize the state but don't return it:

def _ensure_state(
        self,
        m: int,
        n: int,
        sizes: tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        key = (m, n, sizes, device, dtype)
        num_groups = len(sizes)
        if self._state_key != key or self._phi_t is None:
            self._phi_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
            self._state_key = key

and in forward:

self._ensure_state(m, n, sizes, device, dtype)
phi_t = cast(Tensor, self._phi_t)

cast here is to make sure the data type is Tensor.

Massively Multilingual Models (ICLR 2021 Spotlight)
<https://openreview.net/forum?id=F1vEjWK-lH_>`_.

The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In torchjd we usually denote the Jacobian as J, the number of objectives as (lowercase) m, and the number of parameters as (lowercase) n.

Suggested change
The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
The input matrix is a Jacobian :math:`J \in \mathbb{R}^{m \times n}` whose rows are per-task

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I will also change the other two lines that used D.

each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets
:math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally:

* ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single block.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* ``0``**whole model** (``whole_model``): the full row of length :math:`D` is a single block.
* ``0``**whole model** (``whole_model``): the full row of length :math:`n` is a single block.



def test_eps_can_be_changed_between_steps() -> None:
j = tensor([[1.0, 0.0], [0.0, 1.0]])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tensor here is not gonna be affected by DEVICE and DTYPE. Please use tensor_ (from tests/utils/tensors.py) to ensure that the tensor's device and dtype are gonna change when we change the PYTEST_TORCH_DEVICE and PYTEST_TORCH_DTYPE variables.

Suggested change
j = tensor([[1.0, 0.0], [0.0, 1.0]])
J = tensor_([[1.0, 0.0], [0.0, 1.0]])

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Was the capital J intended to indicate a constant variable or a typo?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we always use uppercase J for the Jacobian, because it matches the mathematical notation. Same as G for the Gramian. Also lowercase j is generally just an index in a for loop.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.



def test_group_type_0_rejects_shared_params() -> None:
p = nn.Parameter(tensor([1.0]))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
p = nn.Parameter(tensor([1.0]))
p = nn.Parameter(tensor_([1.0]))

Comment on lines +145 to +146
out = GradVac()(tensor([]).reshape(0, 3))
assert_close(out, tensor([0.0, 0.0, 0.0]))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out = GradVac()(tensor([]).reshape(0, 3))
assert_close(out, tensor([0.0, 0.0, 0.0]))
out = GradVac()(tensor_([]).reshape(0, 3))
assert_close(out, tensor_([0.0, 0.0, 0.0]))

def test_zero_columns_returns_zero_vector() -> None:
"""Handled inside forward before grouping validation."""

out = GradVac()(tensor([]).reshape(2, 0))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out = GradVac()(tensor([]).reshape(2, 0))
out = GradVac()(tensor_([]).reshape(2, 0))

d = sum(p.numel() for p in net.parameters())
agg = GradVac(group_type=1, encoder=net)
with raises(ValueError, match="Jacobian width"):
agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)]))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)]))
agg(tensor_([[1.0] * (d - 1), [2.0] * (d - 1)]))

@ValerianRey
Copy link
Copy Markdown
Contributor

Opencode's review was quite low quality, but it mentioned something that I missed: we need a test for GradVac in tests/unit/aggregation/test_values.py.

Similarly, i'd like to have GradVac added to tests/plots/interactive_plotter.py.

@rkhosrowshahi
Copy link
Copy Markdown
Contributor Author

Opencode's review was quite low quality, but it mentioned something that I missed: we need a test for GradVac in tests/unit/aggregation/test_values.py.

Similarly, i'd like to have GradVac added to tests/plots/interactive_plotter.py.

Thanks. I added the GradVac to the code and improved the code a bit to be more user-friendly. See the PCGrad and GradVac in the plot, find the same aggregated gradient. If you liked the changes, I can add to the commit as well.
PCGrad vs. GradVac

@ValerianRey
Copy link
Copy Markdown
Contributor

@rkhosrowshahi Thx for all the updates! feel free to commit and push all the code suggestions and other changes you made!

@PierreQuinton
Copy link
Copy Markdown
Contributor

@rkhosrowshahi Very nice, I like all this. I think for me, we need to remove the groupings, we should add a grouping usage example (possibly not with gradvac aggregator), and in the gradvac documentation page explain that this is just the aggregation part of gradvac, and to get the full thing, you need to have grouping, with a link to the usage example, as well as a code example specific to gradvac.

Lastly, as @ValerianRey mentioned, this is most likely a gramian based aggregator, I think the formula he gave were right and the Gramian operation can be deduced (and will probably be more efficient). So I think this should be made into a Gramian weighting and gramian based aggregator pair (see UPGrad for an example).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants