Skip to content

Commit 3cba2e1

Browse files
Fix window_scatter_max/min crash on f64 tensors
The scatter result was not cast to the output type before to_binary, causing a binary size mismatch for f64 (8-byte) tensors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 11c6cc5 commit 3cba2e1

3 files changed

Lines changed: 29 additions & 1 deletion

File tree

nx/lib/nx.ex

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7430,6 +7430,16 @@ defmodule Nx do
74307430
]
74317431
]
74327432
>
7433+
7434+
It also works with f64 tensors:
7435+
7436+
iex> t = Nx.iota({6}, type: :f64)
7437+
iex> s = Nx.iota({3}, type: :f64)
7438+
iex> Nx.window_scatter_max(t, s, 0.0, {2}, strides: [2], padding: :valid)
7439+
#Nx.Tensor<
7440+
f64[6]
7441+
[0.0, 0.0, 0.0, 1.0, 0.0, 2.0]
7442+
>
74337443
"""
74347444
@doc type: :window
74357445
def window_scatter_max(tensor, source, init_value, window_dimensions, opts \\ []) do

nx/lib/nx/binary_backend.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1675,7 +1675,7 @@ defmodule Nx.BinaryBackend do
16751675
{acc_offset, acc_binary} ->
16761676
num_vals_before = div(offset - acc_offset, output_size)
16771677
vals_before = List.duplicate(init_binary, num_vals_before)
1678-
source_val = to_binary(value)
1678+
source_val = value |> Nx.as_type(output_type) |> to_binary()
16791679
new_binary = :erlang.list_to_bitstring([vals_before, source_val])
16801680

16811681
{offset + output_size, <<acc_binary::bitstring, new_binary::bitstring>>}

nx/test/nx_test.exs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,24 @@ defmodule NxTest do
10591059
])
10601060
end
10611061

1062+
test "computes window scatter max with f64" do
1063+
t = Nx.iota({6}, type: :f64)
1064+
s = Nx.iota({3}, type: :f64)
1065+
init = Nx.tensor(0.0, type: :f64)
1066+
result = Nx.window_scatter_max(t, s, init, {2}, strides: [2], padding: :valid)
1067+
assert Nx.type(result) == {:f, 64}
1068+
assert Nx.shape(result) == {6}
1069+
end
1070+
1071+
test "computes window scatter min with f64" do
1072+
t = Nx.iota({6}, type: :f64)
1073+
s = Nx.iota({3}, type: :f64)
1074+
init = Nx.tensor(0.0, type: :f64)
1075+
result = Nx.window_scatter_min(t, s, init, {2}, strides: [2], padding: :valid)
1076+
assert Nx.type(result) == {:f, 64}
1077+
assert Nx.shape(result) == {6}
1078+
end
1079+
10621080
test "computes window reduce (sum of squares) with same padding" do
10631081
t = Nx.iota({4, 4}, type: {:f, 32})
10641082

0 commit comments

Comments
 (0)