Skip to content

Commit 7dda66f

Browse files
committed
Delete jac field instead of setting to None
1 parent 99b4260 commit 7dda66f

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/torchjd/utils/_accumulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def _accumulate_jacs(params: Iterable[Tensor], jacobians: Iterable[Tensor]) -> None:
77
for param, jac in zip(params, jacobians, strict=True):
88
_check_expects_grad(param)
9-
if hasattr(param, "jac") and param.jac is not None:
9+
if hasattr(param, "jac"):
1010
param.jac += jac
1111
else:
1212
# TODO: this could be a serious memory issue

src/torchjd/utils/_jac_to_grad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ def _disunite_gradient(
6666

6767
def _free_jacs(params: Iterable[Tensor]) -> None:
6868
"""
69-
Clears the ``.jac`` fields of the provided parameters by setting them to ``None``.
69+
Deletes the ``.jac`` field of the provided parameters.
7070
7171
:param params: The parameters whose ``.jac`` fields should be cleared.
7272
"""
7373

7474
for p in params:
75-
p.jac = None
75+
del p.jac

0 commit comments

Comments
 (0)