Skip to content

Commit aacf671

Browse files
committed
Add missing type annotation for chunk_size
1 parent e11de01 commit aacf671

3 files changed

Lines changed: 3 additions & 3 deletions

File tree

tests/unit/autojac/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def test_multiple_tensors() -> None:
268268

269269

270270
@mark.parametrize("chunk_size", [None, 1, 2, 4])
271-
def test_various_valid_chunk_sizes(chunk_size) -> None:
271+
def test_various_valid_chunk_sizes(chunk_size: int | None) -> None:
272272
"""Tests that backward works for various valid values of parallel_chunk_size."""
273273

274274
a1 = tensor_([1.0, 2.0], requires_grad=True)

tests/unit/autojac/test_jac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def test_multiple_tensors() -> None:
268268

269269

270270
@mark.parametrize("chunk_size", [None, 1, 2, 4])
271-
def test_various_valid_chunk_sizes(chunk_size) -> None:
271+
def test_various_valid_chunk_sizes(chunk_size: int | None) -> None:
272272
"""Tests that jac works for various valid values of parallel_chunk_size."""
273273

274274
a1 = tensor_([1.0, 2.0], requires_grad=True)

tests/unit/autojac/test_mtl_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def test_non_scalar_loss_fails() -> None:
378378

379379

380380
@mark.parametrize("chunk_size", [None, 1, 2, 4])
381-
def test_various_valid_chunk_sizes(chunk_size) -> None:
381+
def test_various_valid_chunk_sizes(chunk_size: int | None) -> None:
382382
"""Tests that mtl_backward works for various valid values of parallel_chunk_size."""
383383

384384
p0 = tensor_([1.0, 2.0], requires_grad=True)

0 commit comments

Comments
 (0)