Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1631,11 +1631,7 @@ defmodule Nx.BinaryBackend do

# Compute absolute index in padded space, then adjust back to
# original (unpadded) coordinates by subtracting low-padding
padded_absolute_index =
anchor
|> Enum.zip(offset_from_anchor)
|> Enum.map(fn {x, y} -> x + y end)

padded_absolute_index = Enum.zip_with(anchor, offset_from_anchor, &+/2)
absolute_index = Enum.zip_with(padded_absolute_index, low_pads, &-/2)

source_consumed = i * source_size
Expand Down Expand Up @@ -1664,7 +1660,8 @@ defmodule Nx.BinaryBackend do
|> Enum.group_by(&elem(&1, 1), &elem(&1, 0))
|> Enum.map(fn {index, value} ->
offset = weighted_offset(output_weighted_shape, index)
{offset, Enum.reduce(value, init_value, scatter_fn)}
tensor = Enum.reduce(value, init_value, scatter_fn)
{offset, scalar_to_number(tensor)}
end)
|> Enum.sort_by(&elem(&1, 0))

Expand All @@ -1675,7 +1672,12 @@ defmodule Nx.BinaryBackend do
{acc_offset, acc_binary} ->
num_vals_before = div(offset - acc_offset, output_size)
vals_before = List.duplicate(init_binary, num_vals_before)
source_val = to_binary(value)

source_val =
match_types [output_type] do
<<write!(value, 0)>>
end

new_binary = :erlang.list_to_bitstring([vals_before, source_val])

{offset + output_size, <<acc_binary::bitstring, new_binary::bitstring>>}
Expand Down
18 changes: 18 additions & 0 deletions nx/test/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,24 @@ defmodule NxTest do
])
end

test "computes window scatter max with f64" do
t = Nx.iota({6}, type: :f64)
s = Nx.iota({3}, type: :f64)
init = Nx.tensor(0.0, type: :f64)
result = Nx.window_scatter_max(t, s, init, {2}, strides: [2], padding: :valid)
assert Nx.type(result) == {:f, 64}
assert Nx.shape(result) == {6}
end

test "computes window scatter min with f64" do
t = Nx.iota({6}, type: :f64)
s = Nx.iota({3}, type: :f64)
init = Nx.tensor(0.0, type: :f64)
result = Nx.window_scatter_min(t, s, init, {2}, strides: [2], padding: :valid)
assert Nx.type(result) == {:f, 64}
assert Nx.shape(result) == {6}
end

test "computes window reduce (sum of squares) with same padding" do
t = Nx.iota({4, 4}, type: {:f, 32})

Expand Down
Loading