Skip to content

Commit 430a8a2

Browse files
committed
Use Tensor.split in _disunit_gradient
1 parent 0e8add2 commit 430a8a2

1 file changed

Lines changed: 2 additions & 7 deletions

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,8 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
8888
def _disunite_gradient(
8989
gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac]
9090
) -> list[Tensor]:
91-
gradients = list[Tensor]()
92-
start = 0
93-
for jacobian, t in zip(jacobians, tensors, strict=True):
94-
end = start + jacobian[0].numel()
95-
current_gradient_vector = gradient_vector[start:end]
96-
gradients.append(current_gradient_vector.view(t.shape))
97-
start = end
91+
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
92+
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)]
9893
return gradients
9994

10095

0 commit comments

Comments
 (0)