Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/examples/monitoring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ they have a negative inner product).
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

aggregator.weighting.weighting.register_forward_hook(print_weights)
aggregator.gramian_weighting.register_forward_hook(print_weights)
aggregator.register_forward_hook(print_gd_similarity)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
Expand Down
8 changes: 5 additions & 3 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ class GramianWeightedAggregator(WeightedAggregator):
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
Weighting to it.

:param weighting: The object responsible for extracting the vector of weights from the gramian.
:param gramian_weighting: The object responsible for extracting the vector of weights from the
gramian.
"""

def __init__(self, weighting: Weighting[PSDMatrix]):
super().__init__(weighting << compute_gramian)
def __init__(self, gramian_weighting: Weighting[PSDMatrix]):
super().__init__(gramian_weighting << compute_gramian)
self.gramian_weighting = gramian_weighting
2 changes: 1 addition & 1 deletion tests/doc/test_rst.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

aggregator.weighting.weighting.register_forward_hook(print_weights)
aggregator.gramian_weighting.register_forward_hook(print_weights)
aggregator.register_forward_hook(print_gd_similarity)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
Expand Down