@@ -1631,11 +1631,7 @@ defmodule Nx.BinaryBackend do
16311631
16321632 # Compute absolute index in padded space, then adjust back to
16331633 # original (unpadded) coordinates by subtracting low-padding
1634- padded_absolute_index =
1635- anchor
1636- |> Enum . zip ( offset_from_anchor )
1637- |> Enum . map ( fn { x , y } -> x + y end )
1638-
1634+ padded_absolute_index = Enum . zip_with ( anchor , offset_from_anchor , & + / 2 )
16391635 absolute_index = Enum . zip_with ( padded_absolute_index , low_pads , & - / 2 )
16401636
16411637 source_consumed = i * source_size
@@ -1664,7 +1660,8 @@ defmodule Nx.BinaryBackend do
16641660 |> Enum . group_by ( & elem ( & 1 , 1 ) , & elem ( & 1 , 0 ) )
16651661 |> Enum . map ( fn { index , value } ->
16661662 offset = weighted_offset ( output_weighted_shape , index )
1667- { offset , Enum . reduce ( value , init_value , scatter_fn ) }
1663+ tensor = Enum . reduce ( value , init_value , scatter_fn )
1664+ { offset , scalar_to_number ( tensor ) }
16681665 end )
16691666 |> Enum . sort_by ( & elem ( & 1 , 0 ) )
16701667
@@ -1675,7 +1672,12 @@ defmodule Nx.BinaryBackend do
16751672 { acc_offset , acc_binary } ->
16761673 num_vals_before = div ( offset - acc_offset , output_size )
16771674 vals_before = List . duplicate ( init_binary , num_vals_before )
1678- source_val = to_binary ( value )
1675+
1676+ source_val =
1677+ match_types [ output_type ] do
1678+ << write! ( value , 0 ) >>
1679+ end
1680+
16791681 new_binary = :erlang . list_to_bitstring ( [ vals_before , source_val ] )
16801682
16811683 { offset + output_size , << acc_binary :: bitstring , new_binary :: bitstring >> }
0 commit comments