Skip to content

Commit 8a0fb0e

Browse files
committed
Make _disunite_gradient use less memory
1 parent 674f6ad commit 8a0fb0e

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,13 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
8888
def _disunite_gradient(
    gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
) -> list[Tensor]:
    """
    Split a flat ``gradient_vector`` back into one gradient per tensor.

    Each jacobian contributes ``jacobian[0].numel()`` consecutive elements of
    ``gradient_vector``; the corresponding slice is reshaped to the matching
    tensor's shape via ``view`` (a zero-copy reshape of the slice).

    :param gradient_vector: Flat 1-D tensor holding all gradients back to back.
    :param jacobians: Jacobians whose first-row sizes define the slice lengths.
    :param tensors: Tensors whose shapes the resulting gradients must take.
    :return: One gradient tensor per input tensor, in order.
    """
    gradients: list[Tensor] = []
    offset = 0
    # jacobians and tensors are walked in lockstep; strict=True guarantees
    # they have the same length.
    for jacobian, tensor in zip(jacobians, tensors, strict=True):
        length = jacobian[0].numel()
        chunk = gradient_vector[offset : offset + length]
        gradients.append(chunk.view(tensor.shape))
        offset += length
    return gradients
10099

101100

0 commit comments

Comments
 (0)