Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,211 @@ NIF(as_strided) {
TENSOR(mlx::core::as_strided(*t, to_shape(shape), to_strides(strides), offset, device));
}

// Build a sliding window view of a padded tensor.
// padded: [...] of ndim n; window/strides: per-axis lists of length n.
// Returns a view of shape [o0,...,on-1, w0,...,wn-1] where
// oi = (padded_shape[i] - window[i]) / strides[i] + 1.
// Build a sliding-window view of a padded tensor.
//
// For an n-dim input, the result has 2n dims: the leading n axes enumerate
// window positions, the trailing n axes enumerate elements within a window.
// Output extent on axis i is (padded_shape[i] - window[i]) / strides[i] + 1.
static mlx::core::array sliding_window_view_cpp(
    const mlx::core::array &padded,
    const std::vector<int> &window,
    const std::vector<int> &strides,
    const mlx::core::Device &device) {
  const int ndim = padded.ndim();
  const auto &in_shape = padded.shape();
  const auto &in_strides = padded.strides();

  // Duplicate the element strides: stepping along an output-position axis is
  // the same memory walk as stepping along its matching window axis.
  std::vector<int64_t> dup_strides;
  dup_strides.reserve(2 * ndim);
  dup_strides.insert(dup_strides.end(), in_strides.begin(), in_strides.end());
  dup_strides.insert(dup_strides.end(), in_strides.begin(), in_strides.end());

  // Dense (stride-1 position) view shape: [ps[i]-window[i]+1 ...,  w0...wn-1].
  std::vector<int> dense_shape(2 * ndim);
  for (int i = 0; i < ndim; ++i) {
    dense_shape[i] = in_shape[i] - window[i] + 1;
    dense_shape[ndim + i] = window[i];
  }

  auto dense_view = mlx::core::as_strided(padded, to_shape(dense_shape),
                                          to_strides(dup_strides), 0, device);

  // Subsample the position axes by the caller's strides; keep every element
  // of each window (step 1 on the trailing axes).
  std::vector<int> begin(2 * ndim, 0);
  std::vector<int> step(2 * ndim, 1);
  for (int i = 0; i < ndim; ++i) step[i] = strides[i];

  return mlx::core::slice(dense_view, to_shape(begin), to_shape(dense_shape),
                          to_shape(step), device);
}

// Shared implementation for window_scatter_max/min.
// When scatter_max=true: first-occurrence argmax.
// When scatter_max=false: last-occurrence argmin via mask*arange trick.
// Shared implementation for window_scatter_max/min.
// When scatter_max=true: first-occurrence argmax.
// When scatter_max=false: last-occurrence argmin via mask*arange trick.
//
// Algorithm: pad the input with init_value, build a strided sliding-window
// view, arg-reduce over the flattened window axis to pick one winner per
// window, recover that winner's absolute per-axis index in the padded tensor,
// scatter-add tensor_source at those indices into a buffer pre-filled with
// init_value, then strip the padding.
//
// Parameters:
//   tensor_t          - n-dim input tensor.
//   tensor_source     - one update value per window position; expected shape
//                       [o0,...,on-1] (the window-position grid).
//   tensor_init_value - scalar used as both pad value and output fill value.
//   window, low_pad, high_pad, strides - per-axis lists of length n.
//   scatter_max       - true for max semantics, false for min.
//
// NOTE(review): the output buffer starts at init_value and source values are
// *added* on top (scatter_add), so overlapping windows accumulate — this
// mirrors the previous Elixir implementation's Nx.indexed_add behavior.
static mlx::core::array window_scatter_impl(
    const mlx::core::array &tensor_t,
    const mlx::core::array &tensor_source,
    const mlx::core::array &tensor_init_value,
    const std::vector<int> &window,
    const std::vector<int> &low_pad,
    const std::vector<int> &high_pad,
    const std::vector<int> &strides,
    bool scatter_max,
    const mlx::core::Device &device) {
  int n = tensor_t.ndim();

  // 1. Cast init_value to the input dtype.
  auto init_casted =
      mlx::core::astype(tensor_init_value, tensor_t.dtype(), device);

  // 2. Pad input with init_value on all axes.
  std::vector<int> all_axes(n);
  std::iota(all_axes.begin(), all_axes.end(), 0);
  auto padded =
      mlx::core::pad(tensor_t, all_axes, to_shape(low_pad), to_shape(high_pad),
                     init_casted, "constant", device);

  auto padded_shape = padded.shape();
  std::vector<int> padded_shape_vec(padded_shape.begin(), padded_shape.end());

  // 3. Sliding window view: [o0,...,on-1, w0,...,wn-1].
  auto window_view =
      sliding_window_view_cpp(padded, window, strides, device);

  // out_shape = first n dims of window_view
  std::vector<int> out_shape(window_view.shape().begin(),
                             window_view.shape().begin() + n);

  // K = product of window dims
  int K = 1;
  for (int w : window) K *= w;

  // 4. Flatten window dims: [..., K]
  std::vector<int> flat_shape = out_shape;
  flat_shape.push_back(K);
  auto windows_flat =
      mlx::core::reshape(window_view, to_shape(flat_shape), device);

  // 5. Find argmax / tie-broken argmin over last axis.
  // (axis index n is the flattened window axis of the (n+1)-dim tensor.)
  auto arg_idx = [&]() -> mlx::core::array {
    if (scatter_max) {
      // argmax alone gives first-occurrence tie-breaking.
      return mlx::core::argmax(windows_flat, n, false, device);
    }
    // Tie-broken argmin (last-occurrence):
    // m = min over last axis (keepdims), mask where equal, argmax(mask*arange).
    // Where the mask is 1 the weighted value is the position itself, so argmax
    // selects the largest (last) position attaining the minimum; an all-zero
    // row cannot occur because the min is always attained at least once.
    // NOTE(review): assumes K fits comfortably in uint32 — confirm for very
    // large window products.
    auto m = mlx::core::min(windows_flat, std::vector<int>{n}, true, device);
    auto mask = mlx::core::astype(
        mlx::core::equal(windows_flat, m, device), mlx::core::uint32, device);
    auto arange_k = mlx::core::astype(
        mlx::core::arange(0, K, 1, device), mlx::core::uint32, device);
    std::vector<int> arange_shape(n + 1, 1);
    arange_shape[n] = K;
    auto arange_k_nd =
        mlx::core::reshape(arange_k, to_shape(arange_shape), device);
    auto weighted = mlx::core::multiply(mask, arange_k_nd, device);
    return mlx::core::argmax(weighted, n, false, device);
  }();

  // 6. Expand arg_idx to [..., 1] for take_along_axis.
  std::vector<int> arg_exp_shape = out_shape;
  arg_exp_shape.push_back(1);
  auto arg_idx_exp =
      mlx::core::reshape(arg_idx, to_shape(arg_exp_shape), device);

  // 7. For each axis, compute absolute padded-tensor indices.
  // Trick: run an axis-a iota through the *same* windowing transform as the
  // data, then gather at the winner's flat position to read off the winner's
  // coordinate along axis a.
  std::vector<mlx::core::array> abs_indices;
  for (int a = 0; a < n; ++a) {
    // 1-D iota along axis a of the padded shape.
    auto arange_a = mlx::core::astype(
        mlx::core::arange(0, (int)padded_shape[a], 1, device),
        mlx::core::int32, device);

    // Reshape to [1,...,padded_shape[a],...,1] (size pd[a] at axis a).
    std::vector<int> iota_shape(n, 1);
    iota_shape[a] = (int)padded_shape[a];
    auto iota_nd =
        mlx::core::reshape(arange_a, to_shape(iota_shape), device);

    // Broadcast to full padded shape.
    // NOTE: assumes padded is contiguous (as returned by mlx::pad), so its
    // element strides are dense. The doubled-strides trick in
    // sliding_window_view_cpp works correctly on broadcast_to's zero strides:
    // for axis-a iota, the zero stride on non-a dims keeps the value constant,
    // which is exactly the intended iota semantics.
    auto iota_bc =
        mlx::core::broadcast_to(iota_nd, to_shape(padded_shape_vec), device);

    // Apply same sliding-window view + flatten.
    auto iota_view =
        sliding_window_view_cpp(iota_bc, window, strides, device);
    auto iota_flat =
        mlx::core::reshape(iota_view, to_shape(flat_shape), device);

    // Pick the element at arg_idx position.
    auto abs_a =
        mlx::core::take_along_axis(iota_flat, arg_idx_exp, n, device);
    // Squeeze last dim: [..., 1] → [o0,...,on-1]
    abs_indices.push_back(
        mlx::core::reshape(abs_a, to_shape(out_shape), device));
  }

  // 8. Scatter source into a buffer filled with init_value.
  // MLX scatter_add requires: updates.ndim == array.ndim + indices[0].ndim.
  // array.ndim = n (padded), indices[0].ndim = n (out_shape), so we need 2n.
  // Reshape source [o0,...,on-1] → [o0,...,on-1, 1,...,1] (n trailing singletons).
  auto source_shape_2n = std::vector<int>(tensor_source.shape().begin(),
                                          tensor_source.shape().end());
  for (int i = 0; i < n; ++i) source_shape_2n.push_back(1);
  auto updates =
      mlx::core::reshape(tensor_source, to_shape(source_shape_2n), device);

  // Output buffer: a scalar init_value broadcast to the padded shape.
  auto buffer = mlx::core::broadcast_to(
      mlx::core::reshape(init_casted, to_shape(std::vector<int>{}), device),
      to_shape(padded_shape_vec), device);

  std::vector<int> scatter_axes(n);
  std::iota(scatter_axes.begin(), scatter_axes.end(), 0);
  auto scattered = mlx::core::scatter_add(buffer, abs_indices, updates,
                                          scatter_axes, device);

  // 9. Slice back to original shape (strip padding).
  auto orig_shape = tensor_t.shape();
  std::vector<int> slice_starts = low_pad;
  std::vector<int> slice_stops(n);
  for (int i = 0; i < n; ++i)
    slice_stops[i] = low_pad[i] + (int)orig_shape[i];
  std::vector<int> slice_ones(n, 1);

  return mlx::core::slice(scattered, to_shape(slice_starts),
                          to_shape(slice_stops), to_shape(slice_ones), device);
}

// NIF entry point backing EMLX.window_scatter_max/8 (used by the Elixir
// backend's Nx.window_scatter_max implementation — see lib/emlx/backend.ex).
// Argument order: tensor, source, init_value, window, low_pad, high_pad,
// strides, device. Delegates to window_scatter_impl with scatter_max = true
// (first-occurrence argmax tie-breaking).
NIF(window_scatter_max) {
  TENSOR_PARAM(0, tensor_t);
  TENSOR_PARAM(1, tensor_source);
  TENSOR_PARAM(2, tensor_init_value);
  LIST_PARAM(3, std::vector<int>, window);
  LIST_PARAM(4, std::vector<int>, low_pad);
  LIST_PARAM(5, std::vector<int>, high_pad);
  LIST_PARAM(6, std::vector<int>, strides);
  DEVICE_PARAM(7, device);

  TENSOR(window_scatter_impl(*tensor_t, *tensor_source, *tensor_init_value,
                             window, low_pad, high_pad, strides, true, device));
}

// NIF entry point backing EMLX.window_scatter_min/8 (used by the Elixir
// backend's Nx.window_scatter_min implementation — see lib/emlx/backend.ex).
// Same argument order as window_scatter_max; delegates to window_scatter_impl
// with scatter_max = false (last-occurrence argmin tie-breaking, matching the
// previous Elixir code's Nx.argmin(..., tie_break: :high)).
NIF(window_scatter_min) {
  TENSOR_PARAM(0, tensor_t);
  TENSOR_PARAM(1, tensor_source);
  TENSOR_PARAM(2, tensor_init_value);
  LIST_PARAM(3, std::vector<int>, window);
  LIST_PARAM(4, std::vector<int>, low_pad);
  LIST_PARAM(5, std::vector<int>, high_pad);
  LIST_PARAM(6, std::vector<int>, strides);
  DEVICE_PARAM(7, device);

  TENSOR(window_scatter_impl(*tensor_t, *tensor_source, *tensor_init_value,
                             window, low_pad, high_pad, strides, false,
                             device));
}

static ErlNifFunc nif_funcs[] = {
{"strides", 1, strides},
{"as_strided", 5, as_strided},
Expand Down Expand Up @@ -1248,6 +1453,8 @@ static ErlNifFunc nif_funcs[] = {
{"linalg_pinv", 2, linalg_pinv},
{"linalg_solve", 3, linalg_solve},
{"linalg_solve_triangular", 4, linalg_solve_triangular},
{"window_scatter_max", 8, window_scatter_max},
{"window_scatter_min", 8, window_scatter_min},
{"memory_info", 0, memory_info},
{"clear_cache", 0, clear_cache},
{"reset_peak_memory", 0, reset_peak_memory},
Expand Down
21 changes: 21 additions & 0 deletions lib/emlx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,27 @@ defmodule EMLX do
deftensor linalg_solve(tensorA, tensorB)
deftensor linalg_solve_triangular(tensorA, tensorB, upper)

## Native pooling (window scatter) ops
#
# Thin wrappers over the C NIFs of the same name/arity (c_src/emlx_nif.cpp).
# `deftensor` is a project macro — presumably it generates the NIF-dispatch
# plumbing like the surrounding linalg_* declarations; confirm against its
# definition. `window`, `low_pad`, `high_pad`, and `strides` are per-axis
# integer lists; `tensor_init_value` is a scalar tensor used as both the pad
# value and the output fill value.
deftensor window_scatter_max(
  tensor_t,
  tensor_source,
  tensor_init_value,
  window,
  low_pad,
  high_pad,
  strides
)

deftensor window_scatter_min(
  tensor_t,
  tensor_source,
  tensor_init_value,
  window,
  low_pad,
  high_pad,
  strides
)

deftensor conv_general(
tensor_input,
tensor_kernel,
Expand Down
108 changes: 26 additions & 82 deletions lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1527,93 +1527,37 @@ defmodule EMLX.Backend do
end

@impl true
def window_scatter_min(out, tensor, source, init_value, window_dims_tuple, opts) do
window_scatter_function(
&Nx.argmin(&1, axis: -1, tie_break: :high),
out,
tensor,
source,
init_value,
window_dims_tuple,
opts
def window_scatter_min(out, tensor, source, init_value, window_dims, opts) do
{low_pad, high_pad} = Enum.unzip(opts[:padding])
strides = opts[:strides] || List.duplicate(1, tuple_size(window_dims))

EMLX.window_scatter_min(
from_nx(tensor),
from_nx(source),
init_value |> Nx.backend_transfer(EMLX.Backend) |> from_nx(),
Tuple.to_list(window_dims),
low_pad,
high_pad,
strides
)
|> to_nx(out)
end

@impl true
def window_scatter_max(out, tensor, source, init_value, window_dims_tuple, opts) do
window_scatter_function(
&Nx.argmax(&1, axis: -1),
out,
tensor,
source,
init_value,
window_dims_tuple,
opts
def window_scatter_max(out, tensor, source, init_value, window_dims, opts) do
{low_pad, high_pad} = Enum.unzip(opts[:padding])
strides = opts[:strides] || List.duplicate(1, tuple_size(window_dims))

EMLX.window_scatter_max(
from_nx(tensor),
from_nx(source),
init_value |> Nx.backend_transfer(EMLX.Backend) |> from_nx(),
Tuple.to_list(window_dims),
low_pad,
high_pad,
strides
)
end

defp window_scatter_function(function, out, tensor, source, init_value, window_dims_tuple, opts) do
unfold_flat = fn tensor ->
{device, _} = t_mx = from_nx(tensor)
pad_value_mx = EMLX.scalar_tensor(0, EMLX.scalar_type(t_mx), device)

{low_pad, high_pad} = Enum.unzip(opts[:padding])

padded_mx = EMLX.pad(t_mx, Nx.axes(tensor), low_pad, high_pad, pad_value_mx)

unfolded_mx =
sliding_window_view(
padded_mx,
EMLX.shape(padded_mx),
window_dims_tuple,
opts[:strides]
)

unfolded_shape = EMLX.shape(unfolded_mx)
unfolded = to_nx(unfolded_mx)

{to_keep, to_flatten} =
unfolded_shape
|> Tuple.to_list()
|> Enum.split(-tuple_size(window_dims_tuple))

flat_shape =
to_keep
|> List.to_tuple()
|> then(&Tuple.insert_at(&1, tuple_size(&1), Enum.product(to_flatten)))

Nx.reshape(unfolded, flat_shape)
end

arg_idx =
tensor
|> then(unfold_flat)
|> then(function)

indices_to_flatten =
tensor
|> Nx.axes()
|> Enum.map(fn axis ->
tensor
|> Nx.shape()
|> Nx.iota(axis: axis, backend: EMLX.Backend)
|> then(unfold_flat)
|> Nx.take_along_axis(Nx.new_axis(arg_idx, -1), axis: -1)
end)
|> Nx.concatenate(axis: -1)

num_axes = tuple_size(out.shape)
num_rows = div(Nx.size(indices_to_flatten), num_axes)
indices = Nx.reshape(indices_to_flatten, {num_rows, num_axes})

flat_source = Nx.flatten(source)

init_value
|> Nx.backend_transfer(EMLX.Backend)
|> Nx.broadcast(out.shape)
|> Nx.indexed_add(indices, flat_source)
|> Nx.as_type(out.type)
|> Nx.rename(out.names)
|> to_nx(out)
end

@impl true
Expand Down
Loading
Loading