Skip to content

Commit 1d319b4

Browse files
Fix Nx.gather error message for scalar indices
Nx.gather(tensor, scalar) gave an unhelpful Erlang error because indexed_axes tried to access elem({}, -1) before the shape validation in Nx.Shape.gather could fire. Moved the scalar check earlier. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 11c6cc5 commit 1d319b4

2 files changed

Lines changed: 13 additions & 0 deletions

File tree

nx/lib/nx.ex

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14594,6 +14594,9 @@ defmodule Nx do
1459414594
1459514595
iex> Nx.gather(Nx.tensor([[1, 2], [3, 4]]), Nx.tensor([[0, 0]], type: :f32))
1459614596
** (ArgumentError) indices must be an integer tensor, got {:f, 32}
14597+
14598+
iex> Nx.gather(Nx.iota({3}), Nx.tensor(0))
14599+
** (ArgumentError) expected indices rank to be at least 1, got: 0
1459714600
"""
1459814601
@doc type: :indexed
1459914602
def gather(tensor, indices, opts \\ []) do
@@ -14602,6 +14605,10 @@ defmodule Nx do
1460214605
[%T{vectorized_axes: vectorized_axes} = tensor, indices] =
1460314606
broadcast_vectors([tensor, indices], align_ranks: false)
1460414607

14608+
if indices.shape == {} do
14609+
raise ArgumentError, "expected indices rank to be at least 1, got: 0"
14610+
end
14611+
1460514612
axes = indexed_axes(tensor, indices, opts)
1460614613

1460714614
unless Nx.Type.integer?(indices.type) do

nx/test/nx_test.exs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,6 +2863,12 @@ defmodule NxTest do
28632863
Nx.gather(t, Nx.tensor([[0, -1]]))
28642864
end
28652865
end
2866+
2867+
test "raises correct error on scalar indices" do
2868+
assert_raise ArgumentError, ~r/expected indices rank to be at least 1/, fn ->
2869+
Nx.gather(Nx.iota({3}), Nx.tensor(0))
2870+
end
2871+
end
28662872
end
28672873

28682874
describe "variance/1" do

0 commit comments

Comments
 (0)