refactor(aggregation): Make Stateful an iff and add Stochastic #656
PierreQuinton wants to merge 11 commits into main from
Conversation
class GradVac(GramianWeightedAggregator, Stochastic):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation._mixins.Stochastic`
I'd like this to inherit from both Stochastic and Stateful, if that's possible. Or at least to say in the docstring: "Stateful Stochastic Aggregator ..."
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
    weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed)
    GramianWeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
I think this is very weird. We should have a single generator, owned by the weighting, instead of one in the weighting and one in the aggregator. Similarly, the aggregator's reset method should not call Stochastic.reset, because that is already done when the weighting is reset.
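A minimal sketch of the single-generator arrangement this comment suggests. The class and method names (`_Weighting`, `_Aggregator`, `reset`) are illustrative stand-ins, not the actual torchjd API, and `random.Random` stands in for `torch.Generator`:

```python
import random


class _Weighting:
    """Owns the only generator; resetting reseeds it."""

    def __init__(self, seed=None):
        self.seed = seed
        self.generator = random.Random(seed)  # stand-in for torch.Generator

    def reset(self):
        self.generator.seed(self.seed)


class _Aggregator:
    """Delegates to the weighting instead of holding a second generator."""

    def __init__(self, weighting):
        self.weighting = weighting

    @property
    def generator(self):
        # Expose the weighting's generator rather than duplicating it.
        return self.weighting.generator

    def reset(self):
        # Reset the weighting only; do NOT also reset a second generator,
        # since resetting the weighting already reseeds the shared one.
        self.weighting.reset()


w = _Weighting(seed=0)
agg = _Aggregator(w)
first = agg.generator.random()
agg.reset()
assert agg.generator is w.generator     # a single shared generator
assert agg.generator.random() == first  # reset reproduces the stream
```

With this shape, a double reset (and the divergence it can cause between the two RNG streams) is impossible by construction, because there is only one generator object.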
So the aggregator should not be Stochastic?
I think there should be a distinction between being stochastic (what the aggregator is) and directly owning a generator (what the aggregator shouldn't do).
Not sure how we can do that nicely.
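One way to express that distinction, sketched under the assumption that `Stochastic` could be reworked into an abstract marker (this is a hypothetical redesign, not the PR's current mixin): `Stochastic` requires a `generator` property but owns no RNG state, so a subclass declares where the generator lives without creating a second one. `random.Random` again stands in for `torch.Generator`:

```python
import random
from abc import ABC, abstractmethod


class Stochastic(ABC):
    """Marks a class as stochastic without owning any RNG state."""

    @property
    @abstractmethod
    def generator(self):
        """Return the generator this object draws randomness from."""


class Weighting:
    def __init__(self, seed=None):
        self._generator = random.Random(seed)  # stand-in for torch.Generator

    @property
    def generator(self):
        return self._generator


class Aggregator(Stochastic):
    """Stochastic through its weighting: no second generator is created."""

    def __init__(self, weighting):
        self.weighting = weighting

    @property
    def generator(self):
        return self.weighting.generator


agg = Aggregator(Weighting(seed=42))
assert isinstance(agg, Stochastic)                   # the aggregator IS stochastic
assert agg.generator is agg.weighting.generator      # but does not OWN a generator
```

The aggregator then keeps its "is stochastic" type identity (useful for `isinstance` checks and documentation) while the generator remains a single object owned by the weighting.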
def __init__(self, seed: int | None = None) -> None:
    weighting = PCGradWeighting(seed=seed)
    GramianWeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
Same comment as in GradVac.
def __init__(self, seed: int | None = None) -> None:
    weighting = RandomWeighting(seed=seed)
    WeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
Same comment as in GradVac.
This comment was marked as resolved.
Could be because we have 2 generators per aggregator (see my other comments). Or maybe the Stochastic.reset method doesn't behave as expected.
if generator is not None:
    self.generator = generator
else:
    self.generator = torch.Generator()
Generator requires a device, so this won't work for CUDA, I think. And we don't know the device at that point, so I don't think it's easy to fix.
This leads to: RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'.
We need a different implementation for Stochastic, I think.
This is failing, not sure why. @ValerianRey if you have any ideas I would take it. Also I think RNG is handled differently on CPU and GPU, so maybe we need to be careful about that.
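One possible direction for the device problem, as a sketch only: defer generator creation until the first call that actually knows the device, and cache one generator per device, all seeded from the same seed. The `LazyGenerators` class and its `get` method are hypothetical names, and `random.Random` stands in for `torch.Generator(device=...)`:

```python
import random


class LazyGenerators:
    """Create a generator per device on first use, instead of in __init__
    where the device is not yet known (the CUDA problem described above)."""

    def __init__(self, seed=None):
        self.seed = seed
        self._per_device = {}  # device string -> generator

    def get(self, device):
        if device not in self._per_device:
            # In real code this would be torch.Generator(device=device),
            # seeded with self.seed; random.Random stands in here.
            self._per_device[device] = random.Random(self.seed)
        return self._per_device[device]


rngs = LazyGenerators(seed=123)
cpu_gen = rngs.get("cpu")
assert rngs.get("cpu") is cpu_gen          # cached: one generator per device
assert rngs.get("cuda:0") is not cpu_gen   # a separate generator per device
```

This sidesteps the "Expected a 'cuda' device type for generator but found 'cpu'" error because no generator is built before the device is known. The caveat raised above still applies: CPU and CUDA RNGs use different algorithms in PyTorch, so the same seed does not produce the same stream across devices.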