Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13654,13 +13654,32 @@ defmodule Nx do

iex> Nx.slice(Nx.tensor([[1, 2, 3], [4, 5, 6]]), [Nx.tensor(1.0), Nx.tensor(0)], [1, 1])
** (ArgumentError) index must be integer type, got {:f, 32} for axis 0

## Scalars

Slicing a scalar tensor returns the scalar itself:

iex> Nx.slice(Nx.tensor(42), [], [])
#Nx.Tensor<
s32
42
>
"""
@doc type: :indexed
def slice(tensor, start_indices, lengths, opts \\ [])
when is_list(start_indices) and is_list(lengths) and is_list(opts) do
opts = keyword!(opts, strides: 1)
%T{vectorized_axes: vectorized_axes, shape: shape} = tensor = to_tensor(tensor)

# Slicing a scalar tensor is a no-op — return unchanged
if shape == {} and start_indices == [] and lengths == [] do
tensor
else
slice_non_scalar(tensor, start_indices, lengths, opts, vectorized_axes, shape)
end
end

defp slice_non_scalar(tensor, start_indices, lengths, opts, vectorized_axes, shape) do
if Enum.any?(start_indices, &(is_struct(&1, T) and &1.vectorized_axes != [])) do
# if any of the indices is vectorized, we instead treat this slice as a gather
[%{vectorized_axes: [{first_axis, _} | _] = vectorized_axes} | _] =
Expand Down
14 changes: 14 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3574,4 +3574,18 @@ defmodule NxTest do
assert 28 = Nx.bit_size(tensor)
end
end

describe "slice of scalar tensor" do
test "returns scalar" do
t = Nx.tensor(42)
result = Nx.slice(t, [], [])
assert Nx.to_number(result) == 42
end

test "slice of scalar f64 tensor" do
t = Nx.tensor(3.14, type: :f64)
result = Nx.slice(t, [], [])
assert_in_delta Nx.to_number(result), 3.14, 1.0e-10
end
end
end
Loading