Skip to content

Commit f693e99

Browse files
committed
Improve error message for cat_default
1 parent 59bcf06 commit f693e99

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

  • src/torchjd/sparse/_aten_function_overrides

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,11 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
175175
ref_tensor = tensors_[0]
176176
ref_strides = ref_tensor.strides
177177
if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]):
178-
raise NotImplementedError()
178+
raise NotImplementedError(
179+
"Override for aten.cat.default does not support SSTs that do not all have the same "
180+
f"strides. Found the following strides:\n{[t.strides for t in tensors_]} and the "
181+
f"following dim: {dim}."
182+
)
179183

180184
# We need to try to find the (pretty sure it either does not exist or is unique) physical
181185
# dimension that makes us only move on virtual dimension dim. It also needs to be such that

0 commit comments

Comments
 (0)