diff --git a/src/torchjd/autojac/_transform/grad.py b/src/torchjd/autojac/_transform/grad.py index f6275cad0..c084a33df 100644 --- a/src/torchjd/autojac/_transform/grad.py +++ b/src/torchjd/autojac/_transform/grad.py @@ -37,12 +37,7 @@ def _differentiate(self, grad_outputs: Sequence[Tensor]) -> tuple[Tensor, ...]: return tuple() if len(outputs) == 0: - return tuple( - [ - torch.empty(input.shape, device=input.device, dtype=input.dtype) - for input in inputs - ] - ) + return tuple([torch.zeros_like(input) for input in inputs]) optional_grads = torch.autograd.grad( outputs,