Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 226 additions & 2 deletions c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <numeric>
#include <optional>
#include <random>
#include <string>

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

using namespace mlx::core;

Expand Down Expand Up @@ -785,6 +791,219 @@ NIF(cumulative_min) {
TENSOR(mlx::core::NATIVE_OP(*a, *b, device)); \
}

// Produce a fresh POSIX shared-memory name: "/emlx_" followed by 16 lowercase
// hex digits. POSIX requires the leading '/', and at 22 characters the name
// fits every platform's limit (macOS caps shm names at 31 chars).
// thread_local RNG state keeps concurrent NIF invocations race-free.
static std::string generate_shm_name() {
  thread_local std::mt19937_64 rng(std::random_device{}());
  thread_local std::uniform_int_distribution<uint64_t> draw;
  uint64_t tag = draw(rng);
  char name[32];
  snprintf(name, sizeof(name), "/emlx_%016llx",
           static_cast<unsigned long long>(tag));
  return std::string(name);
}

// ── IPC shared-memory interop ─────────────────────────────────────────────────
//
// These two NIFs implement the :ipc mode of Nx.Backend.to_pointer / from_pointer
// using POSIX shared memory (shm_open + mmap). MLX arrays are immutable, so
// the sender *copies* tensor data into the shm segment — there is no zero-copy
// path here (documented as copy semantics in the Elixir layer).
//
// Lifecycle:
// Sender (tensor_to_shm): creates the shm, memcpys, then munmaps and closes
// its fd. The shm name persists until the receiver unlinks it.
// Receiver (array_from_shm): shm_open, mmap, then shm_unlink immediately (the
// open fd keeps the object alive); wraps the mapping in an mlx::array whose
// deleter munmaps and closes the fd on GC.

// Creates a POSIX shm segment containing a contiguous copy of the tensor's
// data. MLX arrays are immutable, so this is copy (memcpy) semantics — there
// is no zero-copy cross-process path (documented in the Elixir layer).
// argv[0]: tensor_ref (must already be eval'd — Elixir calls eval before this NIF)
// argv[1]: permissions (mode_t expressed as uint64, e.g. 0o400 = 256)
// Returns: {:ok, {name_binary, byte_size}} on success.
NIF(tensor_to_shm) {
  TENSOR_PARAM(0, t);
  PARAM(1, size_t, permissions);

  if (t->data<void>() == nullptr) {
    return nx::nif::error(env,
        "Tensor not evaluated; call EMLX.eval/1 before to_pointer with mode: :ipc");
  }

  // Ensure contiguous layout before exposing to shared memory.
  size_t byte_size;
  void *src_ptr;

  // Use optional to avoid default-constructing mlx::core::array.
  std::optional<mlx::core::array> ct_opt;
  if (t->flags().row_contiguous) {
    byte_size = t->nbytes();
    src_ptr = t->data<void>();
  } else {
    ct_opt.emplace(mlx::core::contiguous(*t));
    mlx::core::eval(*ct_opt);
    byte_size = ct_opt->nbytes();
    src_ptr = ct_opt->data<void>();
  }

  // O_EXCL ensures we create a fresh segment; retry on the rare collision.
  int fd = -1;
  std::string shm_name;
  for (int attempt = 0; attempt < 10; ++attempt) {
    shm_name = generate_shm_name();
    fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, (mode_t)permissions);
    if (fd != -1 || errno != EEXIST) break;
  }
  if (fd == -1) {
    return nx::nif::error(env, "shm_open failed in tensor_to_shm");
  }

  if (ftruncate(fd, (off_t)byte_size) == -1) {
    close(fd);
    shm_unlink(shm_name.c_str());
    return nx::nif::error(env, "ftruncate failed in tensor_to_shm");
  }

  // Zero-sized tensors: POSIX mmap fails with EINVAL for length 0, and there
  // are no bytes to copy anyway, so publish an empty segment without mapping.
  if (byte_size > 0) {
    void *ptr = mmap(NULL, byte_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    if (ptr == MAP_FAILED) {
      close(fd);
      shm_unlink(shm_name.c_str());
      return nx::nif::error(env, "mmap failed in tensor_to_shm");
    }
    std::memcpy(ptr, src_ptr, byte_size);
    munmap(ptr, byte_size);
  }
  close(fd);
  // shm object persists under shm_name until the receiver calls shm_unlink.

  // Return the shm name as a binary (not a charlist) so %Nx.Pointer{handle: ...}
  // holds a conventional Elixir binary that other backends can consume.
  // ERTS API: enif_make_new_binary(env, size, &term) returns the writable ptr.
  ERL_NIF_TERM name_term;
  unsigned char *bin_data = enif_make_new_binary(env, shm_name.size(), &name_term);
  std::memcpy(bin_data, shm_name.data(), shm_name.size());
  ERL_NIF_TERM size_term = nx::nif::make(env, byte_size);
  return nx::nif::ok(env, enif_make_tuple2(env, name_term, size_term));
}

// Opens an existing POSIX shm segment and wraps it as an MLX array.
// The shm is unlinked immediately after mmap so cleanup is automatic:
// when the returned MLX array is GC'd, its deleter calls munmap + close.
// argv[0]: name_binary (POSIX shm name string)
// argv[1]: shape (tuple of ints)
// argv[2]: dtype (atom, e.g. :float32)
// argv[3]: byte_size (uint64, validated against the shm object's actual size)
// Returns: {:ok, tensor_ref} on success.
NIF(array_from_shm) {
  std::string shm_name;
  if (!nx::nif::get(env, argv[0], shm_name))
    return nx::nif::error(env, "Unable to get shm name param");

  SHAPE_PARAM(1, shape);
  TYPE_PARAM(2, dtype);
  PARAM(3, size_t, byte_size);

  if (shm_name.empty()) {
    return nx::nif::error(env, "Empty shm name");
  }

  // Try read-write first; fall back to read-only if permission denied.
  int writable = 1;
  int fd = shm_open(shm_name.c_str(), O_RDWR, 0);
  if (fd == -1 && errno == EACCES) {
    fd = shm_open(shm_name.c_str(), O_RDONLY, 0);
    writable = 0;
  }
  if (fd == -1) {
    return nx::nif::error(env, "shm_open failed in array_from_shm");
  }

  // Reject a byte_size larger than the object itself: mmap would succeed,
  // but touching pages past the object's end raises SIGBUS, which would kill
  // the VM instead of surfacing a catchable error.
  struct stat st;
  if (fstat(fd, &st) == -1 || st.st_size < 0 ||
      (size_t)st.st_size < byte_size) {
    close(fd);
    return nx::nif::error(env, "shm segment smaller than requested byte_size");
  }

  int prot = writable ? (PROT_READ | PROT_WRITE) : PROT_READ;
  void *ptr = mmap(NULL, byte_size, prot, MAP_SHARED, fd, 0);
  if (ptr == MAP_FAILED) {
    close(fd);
    return nx::nif::error(env, "mmap failed in array_from_shm");
  }

  // Unlink immediately: the name is removed, but the object lives as long as
  // this fd (and thus this mmap) is open. The deleter owns cleanup.
  shm_unlink(shm_name.c_str());

  try {
    // Capture fd and byte_size in the deleter; MLX calls it exactly once.
    auto deleter = [fd, byte_size](void *p) {
      munmap(p, byte_size);
      close(fd);
    };
    auto arr = mlx::core::array(ptr, to_shape(shape), dtype, deleter);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(arr)));
  }
  CATCH()
}

// Best-effort removal of a POSIX shm segment by name. Needed when the
// receiver never opens the pointer returned by tensor_to_shm — without it the
// shm name would linger in /dev/shm until the next reboot. A segment that is
// already gone (ENOENT) counts as success.
// argv[0]: name binary (the handle from %Nx.Pointer{kind: :ipc, handle: name})
NIF(shm_unlink_handle) {
  std::string name;
  if (!nx::nif::get(env, argv[0], name))
    return nx::nif::error(env, "Unable to get shm name param");

  int rc = shm_unlink(name.c_str());
  if (rc == 0 || errno == ENOENT) {
    return nx::nif::ok(env);
  }
  return nx::nif::error(env, "shm_unlink failed");
}

// Exposes an evaluated tensor's raw data pointer as an {address, byte_size}
// tuple of uint64 values. The Elixir caller runs EMLX.eval/1 first so that
// data<void>() is non-null and stable (MLX arrays are immutable once
// materialised). On Apple Silicon unified memory makes the address reachable
// from both CPU and GPU. Primary use case: handing tensors to Python MLX via
// Pythonx through the Nx.Backend.to_pointer/from_pointer protocol.
NIF(tensor_data_ptr) {
  TENSOR_PARAM(0, t);

  void *data = t->data<void>();
  if (data == nullptr) {
    return nx::nif::error(
        env, "Tensor has not been evaluated; call EMLX.eval/1 before to_pointer");
  }

  size_t byte_size = t->nbytes();
  ERL_NIF_TERM addr_term =
      nx::nif::make(env, reinterpret_cast<size_t>(data));
  ERL_NIF_TERM size_term = nx::nif::make(env, byte_size);
  return nx::nif::ok(env, enif_make_tuple2(env, addr_term, size_term));
}

// Wraps an external raw pointer as an MLX array without taking ownership:
// the deleter is a no-op, so the caller must keep the backing buffer alive
// for the duration of use (see include/emlx.h for the lifetime contract).
// argv[0]: address (uint64 / size_t)
// argv[1]: shape (tuple of ints)
// argv[2]: dtype (atom, e.g. :float32)
// argv[3]: byte_size (uint64, validated but not used by the MLX ctor)
// argv[4]: deleter (reserved / ignored; pass nil)
NIF(array_from_ptr) {
  PARAM(0, size_t, raw_addr);
  SHAPE_PARAM(1, shape);
  TYPE_PARAM(2, dtype);
  // argv[3] byte_size and argv[4] deleter are accepted but deferred.

  if (raw_addr == 0) {
    return nx::nif::error(env, "Null pointer passed to array_from_ptr");
  }

  try {
    void *buffer = reinterpret_cast<void *>(raw_addr);
    // No-op deleter: ownership stays with the caller.
    auto keep_alive = [](void *) {};
    auto arr = mlx::core::array(buffer, to_shape(shape), dtype, keep_alive);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(arr)));
  }
  CATCH()
}

static int open_resources(ErlNifEnv *env) {
const char *mod = "EMLX";
if (!open_resource<mlx::core::array>(env, mod, "MLXArray")) {
Expand Down Expand Up @@ -1459,8 +1678,13 @@ static ErlNifFunc nif_funcs[] = {
{"clear_cache", 0, clear_cache},
{"reset_peak_memory", 0, reset_peak_memory},
{"set_memory_limit", 1, set_memory_limit},
{"set_cache_limit", 1, set_cache_limit}
{"set_cache_limit", 1, set_cache_limit},
{"tensor_data_ptr", 1, tensor_data_ptr, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"array_from_ptr", 5, array_from_ptr, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"tensor_to_shm", 2, tensor_to_shm, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"array_from_shm", 4, array_from_shm, ERL_NIF_DIRTY_JOB_CPU_BOUND},
{"shm_unlink_handle", 1, shm_unlink_handle}
};

// Update the NIF initialization
ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)

40 changes: 40 additions & 0 deletions lib/emlx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,46 @@ defmodule EMLX do
EMLX.NIF.to_blob(ref, limit) |> unwrap!()
end

@doc """
Returns `{address, byte_size}` for the tensor's raw data buffer.

The tensor is evaluated first (mirroring `to_blob/1`). The returned address
stays valid while the Elixir tensor term is alive and no further MLX
evaluation is triggered on the array. On Apple Silicon, unified memory makes
the address accessible from both CPU and GPU.
"""
def tensor_data_ptr({device, ref} = tensor) when is_tensor(device, ref) do
  eval(tensor)

  ref
  |> EMLX.NIF.tensor_data_ptr()
  |> unwrap!()
end

@doc """
Copies the tensor's bytes into a freshly created POSIX shared-memory segment
and returns `{shm_name, byte_size}`.

Because MLX arrays are immutable, a **memcpy** is unavoidable — zero-copy
cross-process sharing is not possible. `permissions` is a Unix mode integer
such as `0o400` (owner read-only).

The segment name stays registered until the receiver opens and unlinks it,
which `EMLX.NIF.array_from_shm/4` does automatically.
"""
def tensor_to_shm({device, ref} = tensor, permissions) when is_tensor(device, ref) do
  eval(tensor)

  ref
  |> EMLX.NIF.tensor_to_shm(permissions)
  |> unwrap!()
end

@doc """
Removes a POSIX shared-memory segment identified by its handle name.

Use this when the receiver never opens the `%Nx.Pointer{kind: :ipc}` produced
by `Nx.to_pointer/2`; otherwise the segment name lingers until the next
reboot. Unlinking an already-removed segment is safe (ENOENT is ignored).
"""
def shm_unlink(name) when is_binary(name) do
  name
  |> EMLX.NIF.shm_unlink_handle()
  |> unwrap!()
end

# Normalizes NIF results: passes :ok through, extracts the payload from
# {:ok, value}, and raises EMLX.NIFError on {:error, charlist} (NIF error
# reasons arrive as Erlang charlists, hence List.to_string/1).
defp unwrap!(:ok), do: :ok
defp unwrap!({:ok, value}), do: value
defp unwrap!({:error, reason}), do: raise(EMLX.NIFError, List.to_string(reason))
Expand Down
Loading
Loading