Skip to content

Commit 81648e5

Browse files
authored
refactor(aggregation): Add gramian_weighting field (#533)
* Add gramian_weighting field to GramianWeightedAggregator * Rename weighting to gramian_weighting in GramianWeightedAggregator * Use gramian_weighting in monitoring example * Add changelog entry
1 parent e443fc9 commit 81648e5

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ changelog does not include internal changes that do not affect the user.
1212

1313
- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose
1414
between `"min"`, `"median"`, and `"rmse"` scaling.
15+
- Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`.
16+
Usage is still the same, `aggregator.gramian_weighting` is just an alias for the (quite confusing)
17+
`aggregator.weighting.weighting` field.
1518

1619
### Changed
1720

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ they have a negative inner product).
4949
optimizer = SGD(params, lr=0.1)
5050
aggregator = UPGrad()
5151
52-
aggregator.weighting.weighting.register_forward_hook(print_weights)
52+
aggregator.gramian_weighting.register_forward_hook(print_weights)
5353
aggregator.register_forward_hook(print_gd_similarity)
5454
5555
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ class GramianWeightedAggregator(WeightedAggregator):
7373
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
7474
Weighting to it.
7575
76-
:param weighting: The object responsible for extracting the vector of weights from the gramian.
76+
:param gramian_weighting: The object responsible for extracting the vector of weights from the
77+
gramian.
7778
"""
7879

79-
def __init__(self, weighting: Weighting[PSDMatrix]):
80-
super().__init__(weighting << compute_gramian)
80+
def __init__(self, gramian_weighting: Weighting[PSDMatrix]):
81+
super().__init__(gramian_weighting << compute_gramian)
82+
self.gramian_weighting = gramian_weighting

tests/doc/test_rst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
308308
optimizer = SGD(params, lr=0.1)
309309
aggregator = UPGrad()
310310

311-
aggregator.weighting.weighting.register_forward_hook(print_weights)
311+
aggregator.gramian_weighting.register_forward_hook(print_weights)
312312
aggregator.register_forward_hook(print_gd_similarity)
313313

314314
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10

0 commit comments

Comments
 (0)