Skip to content

Commit 1443dcd

Browse files
authored
docs(aggregation): Fix hook example type hint (#342)
1 parent 24a5c54 commit 1443dcd

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ they have a negative inner product).
2929
"""Prints the extracted weights."""
3030
print(f"Weights: {weights}")
3131
32-
def print_similarity_with_gd(_, inputs: torch.Tensor, aggregation: torch.Tensor) -> None:
32+
def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None:
3333
"""Prints the cosine similarity between the aggregation and the average gradient."""
3434
matrix = inputs[0]
3535
gd_output = matrix.mean(dim=0)

tests/doc/test_rst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def print_weights(_, __, weights: torch.Tensor) -> None:
222222
"""Prints the extracted weights."""
223223
print(f"Weights: {weights}")
224224

225-
def print_similarity_with_gd(_, inputs: torch.Tensor, aggregation: torch.Tensor) -> None:
225+
def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None:
226226
"""Prints the cosine similarity between the aggregation and the average gradient."""
227227
matrix = inputs[0]
228228
gd_output = matrix.mean(dim=0)

0 commit comments

Comments
 (0)