Skip to content

Commit b60d75a

Browse files
Fix Nx.gather error message for scalar indices
Nx.gather(tensor, scalar) gave an unhelpful Erlang error because indexed_axes tried to access elem({}, -1) before the shape validation in Nx.Shape.gather could fire. Moved the scalar check earlier. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 11c6cc5 commit b60d75a

2 files changed

Lines changed: 19 additions & 0 deletions

File tree

nx/lib/nx.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14602,6 +14602,10 @@ defmodule Nx do
1460214602
[%T{vectorized_axes: vectorized_axes} = tensor, indices] =
1460314603
broadcast_vectors([tensor, indices], align_ranks: false)
1460414604

14605+
if indices.shape == {} do
14606+
raise ArgumentError, "expected indices rank to be at least 1, got: 0"
14607+
end
14608+
1460514609
axes = indexed_axes(tensor, indices, opts)
1460614610

1460714611
unless Nx.Type.integer?(indices.type) do
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
defmodule Nx.GatherScalarErrorTest do
2+
use ExUnit.Case, async: true
3+
4+
test "gather raises correct error on scalar indices" do
5+
assert_raise ArgumentError, ~r/expected indices rank to be at least 1/, fn ->
6+
Nx.gather(Nx.iota({3}), Nx.tensor(0))
7+
end
8+
end
9+
10+
test "gather with valid indices still works" do
11+
t = Nx.iota({3, 4})
12+
result = Nx.gather(t, Nx.tensor([[0, 0], [2, 3]]))
13+
assert Nx.to_flat_list(result) == [0, 11]
14+
end
15+
end

0 commit comments

Comments
 (0)