Skip to content

Commit b1f27e9

Browse files
Fix Nx.slice crash on scalar tensor
Slicing a scalar tensor is a valid no-op — return the tensor unchanged when shape is {} and start_indices/lengths are empty. The check is done in Nx.slice itself (not in BinaryBackend) so all backends get the fix without needing separate implementations. NumPy does the same: np.array(5)[()] returns 5. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c9406de commit b1f27e9

3 files changed

Lines changed: 10 additions & 3 deletions

File tree

nx/lib/nx.ex

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13670,6 +13670,15 @@ defmodule Nx do
1367013670
opts = keyword!(opts, strides: 1)
1367113671
%T{vectorized_axes: vectorized_axes, shape: shape} = tensor = to_tensor(tensor)
1367213672

13673+
# Slicing a scalar tensor is a no-op — return unchanged
13674+
if shape == {} and start_indices == [] and lengths == [] do
13675+
tensor
13676+
else
13677+
slice_non_scalar(tensor, start_indices, lengths, opts, vectorized_axes, shape)
13678+
end
13679+
end
13680+
13681+
defp slice_non_scalar(tensor, start_indices, lengths, opts, vectorized_axes, shape) do
1367313682
if Enum.any?(start_indices, &(is_struct(&1, T) and &1.vectorized_axes != [])) do
1367413683
# if any of the indices is vectorized, we instead treat this slice as a gather
1367513684
[%{vectorized_axes: [{first_axis, _} | _] = vectorized_axes} | _] =

nx/lib/nx/binary_backend.ex

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,8 +1849,6 @@ defmodule Nx.BinaryBackend do
18491849
|> then(&from_binary(out, &1))
18501850
end
18511851

1852-
defp bin_slice(data, _shape, _size, [], [], [], _output_shape), do: data
1853-
18541852
defp bin_slice(data, shape, size, start_indices, lengths, strides, output_shape) do
18551853
start_indices = clamp_indices(start_indices, shape, lengths)
18561854

nx/test/nx_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3585,7 +3585,7 @@ defmodule NxTest do
35853585
test "slice of scalar f64 tensor" do
35863586
t = Nx.tensor(3.14, type: :f64)
35873587
result = Nx.slice(t, [], [])
3588-
assert_in_delta Nx.to_number(result), 3.15, 1.0e-10
3588+
assert_in_delta Nx.to_number(result), 3.14, 1.0e-10
35893589
end
35903590
end
35913591
end

0 commit comments

Comments
 (0)