
refactor(aggregation): Make Stateful an iff and add Stochastic #656

Open
PierreQuinton wants to merge 11 commits into main from refactor-stateful

Conversation

@PierreQuinton
Contributor

This is failing and I'm not sure why. @ValerianRey, if you have any ideas I would take them. Also, I think RNG is handled differently on CPU and GPU, so maybe we need to be careful about that.

@PierreQuinton PierreQuinton requested review from a team and ValerianRey as code owners April 17, 2026 08:12
@PierreQuinton PierreQuinton added the labels `package: aggregation` and `cc: refactor` (Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements) Apr 17, 2026
@github-actions github-actions bot changed the title Make Stateful an iff and add Stochastic refactor(aggregation): Make Stateful an iff and add Stochastic Apr 17, 2026
Comment thread src/torchjd/aggregation/_gradvac.py Outdated
Comment on lines +16 to +18
class GradVac(GramianWeightedAggregator, Stochastic):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation._mixins.Stochastic`
Contributor

I'd like this to inherit from both Stochastic and Stateful, if that's possible. Or at least to say in the docstring: "Stateful Stochastic Aggregator ..."

Comment thread src/torchjd/aggregation/_gradvac.py Outdated
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
    weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed)
    GramianWeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
Contributor

I think this is very weird. We should have a single generator, in the weighting, instead of one in the weighting and one in the aggregator. Similarly, the reset method of the aggregator should not call Stochastic.reset, because that is already called when resetting the weighting.

Contributor Author

So the aggregator should not be Stochastic?

Contributor

I think there should be a distinction between being stochastic (what the aggregator is) and directly owning a generator (what the aggregator shouldn't do).
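One way to encode this distinction, sketched here as a hypothetical design (not the actual torchjd code): make `Stochastic` an abstract interface that declares a `generator` property without storing one, so an aggregator *is* stochastic while only the weighting *owns* the generator.

```python
from abc import ABC, abstractmethod


class Generator:
    """Stub standing in for torch.Generator."""


class Stochastic(ABC):
    """Marks a class as stochastic without making it own a generator."""

    @property
    @abstractmethod
    def generator(self) -> Generator:
        """The generator driving this object's randomness."""


class RandomWeighting:
    def __init__(self) -> None:
        self.generator = Generator()  # the single owner of the generator


class Random(Stochastic):
    def __init__(self) -> None:
        self.weighting = RandomWeighting()

    @property
    def generator(self) -> Generator:
        return self.weighting.generator  # delegation, no second generator
```

`isinstance(agg, Stochastic)` then answers "is this aggregator stochastic?", while the generator itself lives in exactly one place.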

Contributor

Not sure how we can do that nicely.

Comment thread src/torchjd/aggregation/_pcgrad.py Outdated
def __init__(self, seed: int | None = None) -> None:
    weighting = PCGradWeighting(seed=seed)
    GramianWeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
Contributor

Same comment as in GradVac.

Comment thread src/torchjd/aggregation/_random.py Outdated
def __init__(self, seed: int | None = None) -> None:
    weighting = RandomWeighting(seed=seed)
    WeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
Contributor

Same comment as in GradVac.

@ValerianRey

This comment was marked as resolved.

@ValerianRey
Contributor

> if you have any ideas I would take it.

Could be because we have 2 generators per aggregator (see my other comments). Or maybe the Stochastic.reset method doesn't behave as expected.

if generator is not None:
    self.generator = generator
else:
    self.generator = torch.Generator()
Contributor

torch.Generator requires a device, so this won't work for CUDA, I think. And we don't know the device at that point, so I don't think it's easy to fix.

Contributor

This leads to `RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'`.

We need a different implementation for Stochastic, I think.

Comment thread tests/unit/aggregation/_asserts.py Outdated
