Skip to content

Commit e08e45f

Browse files
committed
[WIP] Fix jac undefined errors
1 parent 7dda66f commit e08e45f

File tree

3 files changed

+31
-12
lines changed

3 files changed

+31
-12
lines changed

src/torchjd/utils/_accumulation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
from collections.abc import Iterable
2+
from typing import cast
23

34
from torch import Tensor
45

6+
from torchjd.utils._tensor_with_jac import TensorWithJac
7+
58

69
def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
710
for param, jac in zip(params, jacobians, strict=True):
811
_check_expects_grad(param)
9-
if hasattr(param, "jac"):
10-
param.jac += jac
12+
if hasattr(param, "jac"): # No check for None because jac cannot be None
13+
param_ = cast(TensorWithJac, param)
14+
param_.jac += jac
1115
else:
1216
# TODO: this could be a serious memory issue
1317
# We clone the value because we do not want subsequent accumulations to also affect
1418
# this value (in case it is still used outside). We do not detach from the
1519
# computation graph because the value can have grad_fn that we want to keep track of
1620
# (in case it was obtained via create_graph=True and a differentiable aggregator).
17-
param.jac = jac.clone()
21+
param.__setattr__("jac", jac.clone())
1822

1923

2024
def _accumulate_grads(params: Iterable[Tensor], gradients: Iterable[Tensor]) -> None:

src/torchjd/utils/_jac_to_grad.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from collections.abc import Iterable
2+
from typing import cast
23

34
import torch
45
from torch import Tensor
56

67
from torchjd.aggregation import Aggregator
78
from torchjd.utils._accumulation import _accumulate_grads
9+
from torchjd.utils._tensor_with_jac import TensorWithJac
810

911

1012
def jac_to_grad(
@@ -20,17 +22,19 @@ def jac_to_grad(
2022
:param retain_jacs: Whether to preserve the ``.jac`` fields of the parameters.
2123
"""
2224

23-
params_ = list(params)
25+
params_ = list[TensorWithJac]()
26+
for p in params:
27+
if not hasattr(p, "jac"):
28+
raise ValueError(
29+
"Some `jac` fields were not populated. Did you use `autojac.backward` before"
30+
"calling `jac_to_grad`?"
31+
)
32+
p_ = cast(TensorWithJac, p)
33+
params_.append(p_)
2434

2535
if len(params_) == 0:
2636
return
2737

28-
if not all([hasattr(p, "jac") and p.jac is not None for p in params_]):
29-
raise ValueError(
30-
"Some `jac` fields were not populated. Did you use `autojac.backward` before calling "
31-
"`jac_to_grad`?"
32-
)
33-
3438
jacobians = [p.jac for p in params_]
3539

3640
# TODO: check that the Jacobian shapes match
@@ -51,7 +55,7 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
5155

5256

5357
def _disunite_gradient(
54-
gradient_vector: Tensor, jacobians: list[Tensor], params: list[Tensor]
58+
gradient_vector: Tensor, jacobians: list[Tensor], params: list[TensorWithJac]
5559
) -> list[Tensor]:
5660
gradient_vectors = []
5761
start = 0
@@ -64,7 +68,7 @@ def _disunite_gradient(
6468
return gradients
6569

6670

67-
def _free_jacs(params: Iterable[Tensor]) -> None:
71+
def _free_jacs(params: Iterable[TensorWithJac]) -> None:
6872
"""
6973
Deletes the ``.jac`` field of the provided parameters.
7074
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from torch import Tensor
2+
3+
4+
class TensorWithJac(Tensor):
    """
    Tensor that is known to carry a populated ``jac`` field.

    Should not be instantiated directly: it exists purely as a static-typing
    aid, to be used as a type hint or as the target of ``typing.cast``.
    """

    # Accumulated Jacobian for this tensor; populated externally.
    jac: Tensor

0 commit comments

Comments
 (0)