
Commit b0adff6

refactor(autojac): Use TypeGuard for TensorWithJac (#521)
1 parent b280b10 commit b0adff6

3 files changed: 16 additions & 19 deletions

File tree

src/torchjd/autojac/_accumulation.py
src/torchjd/autojac/_jac_to_grad.py
tests/utils/asserts.py

src/torchjd/autojac/_accumulation.py

Lines changed: 7 additions & 4 deletions
@@ -1,5 +1,5 @@
 from collections.abc import Iterable
-from typing import cast
+from typing import TypeGuard
 
 from torch import Tensor
 
@@ -14,6 +14,10 @@ class TensorWithJac(Tensor):
     jac: Tensor
 
 
+def is_tensor_with_jac(t: Tensor) -> TypeGuard[TensorWithJac]:
+    return hasattr(t, "jac")
+
+
 def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
     for param, jac in zip(params, jacobians, strict=True):
         _check_expects_grad(param, field_name=".jac")
@@ -26,9 +30,8 @@ def accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
                 " jacobian are the same size"
             )
 
-        if hasattr(param, "jac"):  # No check for None because jac cannot be None
-            param_ = cast(TensorWithJac, param)
-            param_.jac += jac
+        if is_tensor_with_jac(param):
+            param.jac += jac
         else:
             # We do not clone the value to save memory and time, so subsequent modifications of
             # the value of key.jac (subsequent accumulations) will also affect the value of
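
For readers skimming the diff: TypeGuard (PEP 647, in typing since Python 3.10) lets a boolean predicate carry narrowing information, so a True return tells the type checker that the argument may be treated as the guarded type inside that branch, removing the need for cast. A minimal standalone sketch of the pattern, where Plain, WithTag, and has_tag are hypothetical names standing in for Tensor, TensorWithJac, and is_tensor_with_jac:

from typing import TypeGuard


class Plain:
    pass


class WithTag(Plain):
    # Hypothetical attribute set dynamically elsewhere,
    # like jac on TensorWithJac.
    tag: str


def has_tag(obj: Plain) -> TypeGuard[WithTag]:
    # A True return narrows obj to WithTag for the type checker.
    return hasattr(obj, "tag")


def describe(obj: Plain) -> str:
    if has_tag(obj):
        # obj is treated as WithTag here; no cast(WithTag, obj) needed.
        return obj.tag
    return "untagged"

Note that TypeGuard narrowing only applies in the positive branch: in the else branch, obj is still a Plain.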

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 3 additions & 5 deletions
@@ -1,12 +1,11 @@
 from collections.abc import Iterable
-from typing import cast
 
 import torch
 from torch import Tensor
 
 from torchjd.aggregation import Aggregator
 
-from ._accumulation import TensorWithJac, accumulate_grads
+from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
 
 
 def jac_to_grad(
@@ -54,13 +53,12 @@ def jac_to_grad(
 
     tensors_ = list[TensorWithJac]()
     for t in tensors:
-        if not hasattr(t, "jac"):
+        if not is_tensor_with_jac(t):
             raise ValueError(
                 "Some `jac` fields were not populated. Did you use `autojac.backward` or "
                 "`autojac.mtl_backward` before calling `jac_to_grad`?"
             )
-        t_ = cast(TensorWithJac, t)
-        tensors_.append(t_)
+        tensors_.append(t)
 
     if len(tensors_) == 0:
         return
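
A side note on why the cast could be dropped after the early raise (a sketch reusing the hypothetical Plain/WithTag/has_tag definitions above): falling through `if not has_tag(obj): raise ...` implies the guard returned True, so the checker narrows the argument on the lines after the check, exactly as with `t` in the loop above.

def require_tag(obj: Plain) -> str:
    if not has_tag(obj):
        raise ValueError("missing tag")
    # Reaching this line means has_tag(obj) returned True,
    # so obj is narrowed to WithTag from here on.
    return obj.tag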

tests/utils/asserts.py

Lines changed: 6 additions & 10 deletions
@@ -1,26 +1,22 @@
-from typing import cast
-
 import torch
 from torch.testing import assert_close
 
 from torchjd._linalg.matrix import PSDMatrix
-from torchjd.autojac._accumulation import TensorWithJac
+from torchjd.autojac._accumulation import is_tensor_with_jac
 
 
 def assert_has_jac(t: torch.Tensor) -> None:
-    assert hasattr(t, "jac")
-    t_ = cast(TensorWithJac, t)
-    assert t_.jac is not None and t_.jac.shape[1:] == t_.shape
+    assert is_tensor_with_jac(t)
+    assert t.jac is not None and t.jac.shape[1:] == t.shape
 
 
 def assert_has_no_jac(t: torch.Tensor) -> None:
-    assert not hasattr(t, "jac")
+    assert not is_tensor_with_jac(t)
 
 
 def assert_jac_close(t: torch.Tensor, expected_jac: torch.Tensor, **kwargs) -> None:
-    assert hasattr(t, "jac")
-    t_ = cast(TensorWithJac, t)
-    assert_close(t_.jac, expected_jac, **kwargs)
+    assert is_tensor_with_jac(t)
+    assert_close(t.jac, expected_jac, **kwargs)
 
 
 def assert_has_grad(t: torch.Tensor) -> None:
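
The test helpers lean on the same mechanism through assert: after `assert is_tensor_with_jac(t)`, type checkers narrow `t`, so `t.jac` is accessible without a cast. Roughly, in terms of the hypothetical sketch above:

def check_tag(obj: Plain) -> None:
    assert has_tag(obj)
    # After the assert, obj is narrowed to WithTag,
    # so obj.tag type-checks without a cast.
    print(obj.tag)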
