We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 59bcf06 commit f693e99Copy full SHA for f693e99
1 file changed
src/torchjd/sparse/_aten_function_overrides/shape.py
@@ -175,7 +175,11 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
175
ref_tensor = tensors_[0]
176
ref_strides = ref_tensor.strides
177
if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]):
178
- raise NotImplementedError()
+ raise NotImplementedError(
179
+ "Override for aten.cat.default does not support SSTs that do not all have the same "
180
+ f"strides. Found the following strides:\n{[t.strides for t in tensors_]} and the "
181
+ f"following dim: {dim}."
182
+ )
183
184
# We need to try to find the (pretty sure it either does not exist or is unique) physical
185
# dimension that makes us only move on virtual dimension dim. It also needs to be such that
0 commit comments