diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index fc5f55dcc3..b8a298c37e 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -13654,6 +13654,16 @@ defmodule Nx do iex> Nx.slice(Nx.tensor([[1, 2, 3], [4, 5, 6]]), [Nx.tensor(1.0), Nx.tensor(0)], [1, 1]) ** (ArgumentError) index must be integer type, got {:f, 32} for axis 0 + + ## Scalars + + Slicing a scalar tensor returns the scalar itself: + + iex> Nx.slice(Nx.tensor(42), [], []) + #Nx.Tensor< + s32 + 42 + > """ @doc type: :indexed def slice(tensor, start_indices, lengths, opts \\ []) @@ -13661,6 +13671,15 @@ defmodule Nx do opts = keyword!(opts, strides: 1) %T{vectorized_axes: vectorized_axes, shape: shape} = tensor = to_tensor(tensor) + # Slicing a scalar tensor is a no-op — return unchanged + if shape == {} and start_indices == [] and lengths == [] do + tensor + else + slice_non_scalar(tensor, start_indices, lengths, opts, vectorized_axes, shape) + end + end + + defp slice_non_scalar(tensor, start_indices, lengths, opts, vectorized_axes, shape) do if Enum.any?(start_indices, &(is_struct(&1, T) and &1.vectorized_axes != [])) do # if any of the indices is vectorized, we instead treat this slice as a gather [%{vectorized_axes: [{first_axis, _} | _] = vectorized_axes} | _] = diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 5d42216132..578f435dbb 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -3574,4 +3574,18 @@ defmodule NxTest do assert 28 = Nx.bit_size(tensor) end end + + describe "slice of scalar tensor" do + test "returns scalar" do + t = Nx.tensor(42) + result = Nx.slice(t, [], []) + assert Nx.to_number(result) == 42 + end + + test "slice of scalar f64 tensor" do + t = Nx.tensor(3.14, type: :f64) + result = Nx.slice(t, [], []) + assert_in_delta Nx.to_number(result), 3.14, 1.0e-10 + end + end end