Skip to content

Commit 4bb425a

Browse files
committed
doc: Fix weighting hook registration
1 parent 73c1965 commit 4bb425a

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ they have a negative inner product).
4949
optimizer = SGD(params, lr=0.1)
5050
aggregator = UPGrad()
5151
52-
aggregator.weighting.register_forward_hook(print_weights)
52+
aggregator.weighting.weighting.register_forward_hook(print_weights)
5353
aggregator.register_forward_hook(print_gd_similarity)
5454
5555
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10

tests/doc/test_rst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
300300
optimizer = SGD(params, lr=0.1)
301301
aggregator = UPGrad()
302302

303-
aggregator.weighting.register_forward_hook(print_weights)
303+
aggregator.weighting.weighting.register_forward_hook(print_weights)
304304
aggregator.register_forward_hook(print_gd_similarity)
305305

306306
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10

0 commit comments

Comments
 (0)