Skip to content

Commit 2523462

Browse files
authored
typing: Fix typing error of test parametrizations (#611)
1 parent 18c4fd7 commit 2523462

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tests/unit/aggregation/test_values.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
],
5252
)
5353

54-
AGGREGATOR_PARAMETRIZATIONS = [
54+
AGGREGATOR_PARAMETRIZATIONS: list[tuple] = [
5555
(AlignedMTL(), J_base, tensor([0.2133, 0.9673, 0.9673])),
5656
(ConFIG(), J_base, tensor([0.1588, 2.0706, 2.0706])),
5757
(Constant(tensor([1.0, 2.0])), J_base, tensor([8.0, 3.0, 3.0])),
@@ -71,7 +71,7 @@
7171
G_base = J_base @ J_base.T
7272
G_Krum = J_Krum @ J_Krum.T
7373

74-
WEIGHTING_PARAMETRIZATIONS = [
74+
WEIGHTING_PARAMETRIZATIONS: list[tuple] = [
7575
(AlignedMTLWeighting(), G_base, tensor([0.5591, 0.4083])),
7676
(ConstantWeighting(tensor([1.0, 2.0])), G_base, tensor([1.0, 2.0])),
7777
(DualProjWeighting(), G_base, tensor([0.6109, 0.5000])),

tests/unit/autojac/test_jac_to_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def dummy_backward_pre_hook(_module, _grad_output) -> Tensor:
191191
assert not _has_forward_hook(module)
192192

193193

194-
_PARAMETRIZATIONS = [
194+
_PARAMETRIZATIONS: list[tuple] = [
195195
(AlignedMTL(), True),
196196
(DualProj(), True),
197197
(IMTLG(), True),

0 commit comments

Comments
 (0)