Skip to content

Commit 3e9e7d4

Browse files
committed
Fix usage of unsqueeze on SLT to call unsqueeze_default instead
1 parent 2e641c7 commit 3e9e7d4

File tree

1 file changed

+1
-1
lines changed
  • src/torchjd/sparse/_aten_function_overrides

1 file changed

+1
-1
lines changed

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT
223223

224224
# Add as many dimensions as needed at the beginning of the tensor (as torch.expand works)
225225
for _ in range(len(sizes) - t.ndim):
226-
t = t.unsqueeze(0)
226+
t = unsqueeze_default(t, 0)
227227

228228
# Try to expand each dimension to its new size
229229
new_physical = t.physical

0 commit comments

Comments
 (0)