Skip to content

Commit 1e5532f

Browse files
blasphemetheusclaudepolvalente
authored
Fix window_scatter_max/min crash on f64 tensors (#1711)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent 674d3bb commit 1e5532f

2 files changed

Lines changed: 27 additions & 7 deletions

File tree

nx/lib/nx/binary_backend.ex

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,11 +1631,7 @@ defmodule Nx.BinaryBackend do
16311631

16321632
# Compute absolute index in padded space, then adjust back to
16331633
# original (unpadded) coordinates by subtracting low-padding
1634-
padded_absolute_index =
1635-
anchor
1636-
|> Enum.zip(offset_from_anchor)
1637-
|> Enum.map(fn {x, y} -> x + y end)
1638-
1634+
padded_absolute_index = Enum.zip_with(anchor, offset_from_anchor, &+/2)
16391635
absolute_index = Enum.zip_with(padded_absolute_index, low_pads, &-/2)
16401636

16411637
source_consumed = i * source_size
@@ -1664,7 +1660,8 @@ defmodule Nx.BinaryBackend do
16641660
|> Enum.group_by(&elem(&1, 1), &elem(&1, 0))
16651661
|> Enum.map(fn {index, value} ->
16661662
offset = weighted_offset(output_weighted_shape, index)
1667-
{offset, Enum.reduce(value, init_value, scatter_fn)}
1663+
tensor = Enum.reduce(value, init_value, scatter_fn)
1664+
{offset, scalar_to_number(tensor)}
16681665
end)
16691666
|> Enum.sort_by(&elem(&1, 0))
16701667

@@ -1675,7 +1672,12 @@ defmodule Nx.BinaryBackend do
16751672
{acc_offset, acc_binary} ->
16761673
num_vals_before = div(offset - acc_offset, output_size)
16771674
vals_before = List.duplicate(init_binary, num_vals_before)
1678-
source_val = to_binary(value)
1675+
1676+
source_val =
1677+
match_types [output_type] do
1678+
<<write!(value, 0)>>
1679+
end
1680+
16791681
new_binary = :erlang.list_to_bitstring([vals_before, source_val])
16801682

16811683
{offset + output_size, <<acc_binary::bitstring, new_binary::bitstring>>}

nx/test/nx_test.exs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,24 @@ defmodule NxTest do
10591059
])
10601060
end
10611061

1062+
test "computes window scatter max with f64" do
1063+
t = Nx.iota({6}, type: :f64)
1064+
s = Nx.iota({3}, type: :f64)
1065+
init = Nx.tensor(0.0, type: :f64)
1066+
result = Nx.window_scatter_max(t, s, init, {2}, strides: [2], padding: :valid)
1067+
assert Nx.type(result) == {:f, 64}
1068+
assert Nx.shape(result) == {6}
1069+
end
1070+
1071+
test "computes window scatter min with f64" do
1072+
t = Nx.iota({6}, type: :f64)
1073+
s = Nx.iota({3}, type: :f64)
1074+
init = Nx.tensor(0.0, type: :f64)
1075+
result = Nx.window_scatter_min(t, s, init, {2}, strides: [2], padding: :valid)
1076+
assert Nx.type(result) == {:f, 64}
1077+
assert Nx.shape(result) == {6}
1078+
end
1079+
10621080
test "computes window reduce (sum of squares) with same padding" do
10631081
t = Nx.iota({4, 4}, type: {:f, 32})
10641082

0 commit comments

Comments
 (0)