Skip to content

Commit f0fe529

Browse files
committed
Add kwargs to assert_jac_close and assert_grad_close
These functions really are wrappers around assert_close, so we'd like them to always also take the parameters of assert_close, even if those change in the future, and to have the same default values. So I think kwargs is justified here. Also it's not user facing so the lack of documentation of the expected types will not be visible.
1 parent 430a8a2 commit f0fe529

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

tests/unit/autojac/_asserts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def assert_has_no_jac(t: torch.Tensor) -> None:
1616
assert not hasattr(t, "jac")
1717

1818

19-
def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor) -> None:
19+
def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None:
2020
assert hasattr(t, "jac")
2121
t_ = cast(TensorWithJac, t)
22-
assert_close(t_.jac, expected_jac)
22+
assert_close(t_.jac, expected_jac, **kwargs)
2323

2424

2525
def assert_has_grad(t: torch.Tensor) -> None:
@@ -30,6 +30,6 @@ def assert_has_no_grad(t: torch.Tensor) -> None:
3030
assert t.grad is None
3131

3232

33-
def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor) -> None:
33+
def assert_grad_close(t: torch.Tensor, expected_grad: torch.Tensor, **kwargs) -> None:
3434
assert t.grad is not None
35-
assert_close(t.grad, expected_grad)
35+
assert_close(t.grad, expected_grad, **kwargs)

0 commit comments

Comments
 (0)