Skip to content

Commit fac9c72

Browse files
committed
Fix creation of SST in autogram
1 parent fc3339c commit fac9c72

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/torchjd/autogram/_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
175175
)
176176

177177
output_dims = list(range(output.ndim))
178-
v_to_ps = [[dim] for dim in output_dims * 2]
179-
jac_output = make_sst(torch.ones_like(output), v_to_ps)
178+
identity = torch.eye(output.ndim, dtype=torch.int64)
179+
strides = torch.concatenate([identity, identity.clone()], dim=0)
180+
jac_output = make_sst(torch.ones_like(output), strides)
180181

181182
vmapped_diff = differentiation
182183
for _ in output_dims:

0 commit comments

Comments
 (0)