diff --git a/Makefile b/Makefile index 017b232..4ed2020 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ MAKE_JOBS ?= $(MAKE_DEFAULT_JOBS) # Source files SOURCES = c_src/emlx_nif.cpp -HEADERS = c_src/nx_nif_utils.hpp +HEADERS = c_src/nx_nif_utils.hpp c_src/emlx_worker.hpp c_src/emlx_async.hpp OBJECTS = $(patsubst c_src/%.cpp,$(BUILD_DIR)/%.o,$(SOURCES)) # Main targets diff --git a/c_src/emlx_async.hpp b/c_src/emlx_async.hpp new file mode 100644 index 0000000..7e7ec16 --- /dev/null +++ b/c_src/emlx_async.hpp @@ -0,0 +1,177 @@ +// Async NIF dispatch built on top of emlx::Worker. +// +// MLX 0.31.2 makes both Metal CommandEncoders (mlx/backend/metal/device.cpp: +// `static thread_local std::unordered_map encoders;`) +// and the per-device default Stream (mlx/stream.cpp: `static thread_local +// auto default_streams = ...`) thread-local. Because `mlx::core::eval` walks +// the tape and calls `gpu::eval(arr)` *directly* on the calling thread (it +// is NOT trampolined to a `scheduler::StreamThread`; see +// mlx/transforms.cpp:eval_impl), every op for a given GPU stream — both +// graph construction AND eval — must happen on the OS thread that called +// `mlx::core::new_stream(d)` for that stream. Otherwise the eval thread's +// thread-local encoder map will not contain an entry for the stream's +// index, producing the runtime error +// "There is no Stream(gpu, N) in current thread." +// +// Consequence for EMLX: every NIF that touches the MLX graph must run on +// the worker thread that owns the stream. We achieve this without +// rewriting each NIF body by: +// +// 1. Defining each "sync" NIF (e.g. `add`, `reshape`, `eval`, ...) as +// a normal C++ function with the ERTS NIF signature. +// 2. Registering an `_async` wrapper in `nif_funcs[]` whose arity is +// `original_arity + 1` (worker is `argv[0]`). +// 3. The wrapper extracts the worker, copies `argv[1..]` into a +// process-independent `msg_env`, captures the caller pid + a fresh +// ref, and posts a lambda to `worker->post(...)`. +// 4. The worker thread runs the original sync NIF body against +// `msg_env` + the shifted argv, takes its `{:ok, _}` / `{:error, _}` +// tagged tuple result, wraps it as `{ref, payload}`, and +// `enif_send`s it back to the caller. +// +// The worker's `thread_main` calls `mlx::core::set_default_stream(stream)` +// before signalling ready, so any sync NIF body that resolves a +// `StreamOrDevice` from a `:cpu` / `:gpu` device atom (via `DEVICE_PARAM`) +// picks up the worker's stream automatically through MLX's +// `to_stream(s, default_) -> default_stream(default_)` lookup. No +// per-NIF code change is required. +// +// Lifetime invariants this helper relies on: +// +// * `enif_self(env, &caller)` MUST be called on the BEAM scheduler +// thread. We capture the resulting `ErlNifPid` by value into the +// lambda; the worker thread (a non-scheduler OS thread) MUST NOT +// call `enif_self` itself (the BEAM has no scheduler context for it). +// +// * `enif_make_copy(msg_env, term)` for a resource ref bumps the +// resource's ERTS refcount, so the resource (and the embedded MLX +// array, function, or worker it backs) stays alive at least until +// `msg_env` is freed at the end of the lambda. We do not need +// additional `enif_keep_resource` bumps. +// +// * `enif_send` with a non-NULL `msg_env` does not transfer ownership +// of the env object itself. We always `enif_free_env(msg_env)` after +// `enif_send` returns, regardless of success/failure. 
(Successful +// `enif_send` invalidates the terms in `msg_env` but the env handle +// itself remains owned by the caller.) +// +// * If `worker->post` throws (worker is stopping), we must reclaim +// `msg_env` and propagate the error to the BEAM caller synchronously. + +#pragma once + +#include "emlx_worker.hpp" +#include "erl_nif.h" +#include "nx_nif_utils.hpp" + +#include +#include +#include + +namespace emlx { + +// Build an `{:error, ""}` tuple in `msg_env`. Uses +// `enif_make_string` to mirror nx::nif::error so the Elixir side can +// `List.to_string/1` it uniformly. +inline ERL_NIF_TERM make_error_term(ErlNifEnv *msg_env, const char *what) { + return enif_make_tuple2(msg_env, enif_make_atom(msg_env, "error"), + enif_make_string(msg_env, what, ERL_NIF_LATIN1)); +} + +// Build an error tuple from the currently-thrown exception (must be +// called from inside a `catch` block). +inline ERL_NIF_TERM error_from_current_exception(ErlNifEnv *msg_env) { + try { + throw; + } catch (const std::exception &e) { + return make_error_term(msg_env, e.what()); + } catch (...) { + return make_error_term(msg_env, "Unknown error"); + } +} + +// Run `SyncOp(msg_env, op_argc, op_argv)` on `worker`'s thread and +// `enif_send` its tagged result back to the calling Elixir process. +// Returns the job ref synchronously for the caller to `receive` on. +// +// `argv[0]` MUST be the worker resource ref. `argv[1..argc-1]` are the +// op's actual arguments and are forwarded (after `enif_make_copy` into +// `msg_env`) to `SyncOp`. +// +// `SyncOp` is an existing sync-style NIF function that returns either +// `{:ok, value}` or `{:error, reason}`. The wrapper does not introspect +// the tuple — it is forwarded as-is as the second element of +// `{job_ref, payload}`. +template +ERL_NIF_TERM async_dispatch(ErlNifEnv *env, int argc, + const ERL_NIF_TERM argv[]) { + if (argc < 1) { + return enif_make_badarg(env); + } + + emlx::Worker *worker; + if (!enif_get_resource(env, argv[0], resource_object::type, + (void **)&worker)) { + return nx::nif::error(env, "Invalid command queue ref"); + } + + ErlNifPid caller_pid; + enif_self(env, &caller_pid); + + ErlNifEnv *msg_env = enif_alloc_env(); + if (!msg_env) { + return nx::nif::error(env, "Failed to allocate msg env"); + } + + ERL_NIF_TERM job_ref_msg = enif_make_ref(msg_env); + ERL_NIF_TERM job_ref_caller = enif_make_copy(env, job_ref_msg); + + // Copy the op's arguments (everything past argv[0]) into msg_env. + // For resource refs this also bumps the resource's ERTS refcount, + // keeping the underlying MLX array / function / worker alive for the + // duration of the lambda. + // We need to do this because the worker is async and might outlive + // the NIF env. + int op_argc = argc - 1; + std::vector op_argv; + op_argv.reserve(op_argc); + for (int i = 0; i < op_argc; ++i) { + op_argv.push_back(enif_make_copy(msg_env, argv[i + 1])); + } + + try { + worker->post([msg_env, job_ref_msg, caller_pid, + op_argv = std::move(op_argv)]() mutable { + ERL_NIF_TERM payload; + try { + payload = SyncOp(msg_env, static_cast(op_argv.size()), + op_argv.data()); + } catch (...) { + // The sync NIF should normally translate its own C++ exceptions + // into `{:error, _}` via the `CATCH()` macro, but defensively + // wrap anything that escapes so the caller's `receive` never + // hangs. 
+ payload = error_from_current_exception(msg_env); + } + + ERL_NIF_TERM reply = + enif_make_tuple2(msg_env, job_ref_msg, payload); + ErlNifPid pid = caller_pid; + enif_send(NULL, &pid, msg_env, reply); + enif_free_env(msg_env); + }); + } catch (const std::exception &e) { + // Worker is stopping or rejected the job; reclaim msg_env and + // surface the error synchronously so the caller's wrapper can + // raise without ever entering its `receive`. + enif_free_env(msg_env); + return nx::nif::error(env, e.what()); + } catch (...) { + enif_free_env(msg_env); + return nx::nif::error(env, "Unknown error posting to worker"); + } + + return nx::nif::ok(env, job_ref_caller); +} + +} // namespace emlx diff --git a/c_src/emlx_nif.cpp b/c_src/emlx_nif.cpp index 491ff1f..24bc781 100644 --- a/c_src/emlx_nif.cpp +++ b/c_src/emlx_nif.cpp @@ -1,3 +1,5 @@ +#include "emlx_async.hpp" +#include "emlx_worker.hpp" #include "erl_nif.h" #include "mlx/mlx.h" #include "nx_nif_utils.hpp" @@ -200,6 +202,15 @@ ERL_NIF_TERM create_function_resource(ErlNifEnv *env, emlx::function function) { #define NIF(NAME) \ ERL_NIF_TERM NAME(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) +// One-line async wrapper: declare `NIF(OP) { ... }` then `ASYNC_NIF(OP)`. +// Register in `nif_funcs[]` at `original_arity + 1` (command queue is argv[0]). +// Example: {"add", 4, add_async} // was {"add", 3, add} +#define ASYNC_NIF(OP) \ + ERL_NIF_TERM OP##_async(ErlNifEnv *env, int argc, \ + const ERL_NIF_TERM argv[]) { \ + return emlx::async_dispatch(env, argc, argv); \ + } + #define PARAM(ARGN, TYPE, VAR) \ TYPE VAR; \ GET(ARGN, VAR) @@ -281,6 +292,48 @@ NIF(astype) { TENSOR(mlx::core::astype(*t, type, device)); } +// Builds the resource binary for `to_blob` in `out_env`. Used by both the +// legacy direct-call NIF and `command_queue_post_to_blob` (which builds the +// term inside the worker thread for delivery via enif_send). May throw +// std::runtime_error on resource allocation failure. +// +// `out_env` is the env the binary will live in (the caller env for the +// legacy path; a process-independent msg_env for the worker path). +static ERL_NIF_TERM to_blob_term(ErlNifEnv *out_env, mlx::core::array *t, + size_t byte_size) { + if (t->flags().row_contiguous) { + // Zero-copy: alias the MLX buffer via the existing tensor resource. + // Invariant: lib/emlx.ex calls eval(tensor) before this NIF, so + // data() is guaranteed non-null and stable (MLX arrays are immutable + // once materialised). enif_make_resource_binary keeps the resource alive + // until the binary is GC'd, decoupling the binary lifetime from Elixir GC + // of the tensor term. + return enif_make_resource_binary(out_env, static_cast(t), + t->data(), byte_size); + } + + // Non-contiguous: materialise a fresh row-major copy, wrap it in a minimal + // ERTS resource, and alias that buffer zero-copy. + // The resource holds only sizeof(mlx::core::array) — no TensorP refcount/ + // deleted-flag tail — because it is never exposed to TensorP or the Elixir + // side; only the binary holds a reference and default_dtor + // (~array()) is the sole destructor path. 
+ auto ct = mlx::core::contiguous(*t); + mlx::core::eval(ct); + + auto *ct_ptr = static_cast(enif_alloc_resource( + resource_object::type, sizeof(mlx::core::array))); + if (!ct_ptr) { + throw std::runtime_error("Unable to allocate contiguous-copy resource"); + } + + new (ct_ptr) mlx::core::array(std::move(ct)); + ERL_NIF_TERM resource_bin = enif_make_resource_binary( + out_env, ct_ptr, ct_ptr->data(), byte_size); + enif_release_resource(ct_ptr); + return resource_bin; +} + NIF(to_blob) { TENSOR_PARAM(0, t); @@ -290,38 +343,10 @@ NIF(to_blob) { byte_size = static_cast(param_limit) * t->itemsize(); } - ERL_NIF_TERM resource_bin; - - if (t->flags().row_contiguous) { - // Zero-copy: alias the MLX buffer via the existing tensor resource. - // Invariant: lib/emlx.ex calls eval(tensor) before this NIF, so - // data() is guaranteed non-null and stable (MLX arrays are immutable - // once materialised). enif_make_resource_binary keeps the resource alive - // until the binary is GC'd, decoupling the binary lifetime from Elixir GC - // of the tensor term. - resource_bin = enif_make_resource_binary(env, t_tp.resource_ptr(), - t->data(), byte_size); - } else { - // Non-contiguous: materialise a fresh row-major copy, wrap it in a minimal - // ERTS resource, and alias that buffer zero-copy. - // The resource holds only sizeof(mlx::core::array) — no TensorP refcount/ - // deleted-flag tail — because it is never exposed to TensorP or the Elixir - // side; only the binary holds a reference and default_dtor - // (~array()) is the sole destructor path. - auto ct = mlx::core::contiguous(*t); - mlx::core::eval(ct); - - auto *ct_ptr = static_cast(enif_alloc_resource( - resource_object::type, sizeof(mlx::core::array))); - if (!ct_ptr) - return enif_make_badarg(env); - - new (ct_ptr) mlx::core::array(std::move(ct)); - resource_bin = - enif_make_resource_binary(env, ct_ptr, ct_ptr->data(), byte_size); - enif_release_resource(ct_ptr); + try { + return nx::nif::ok(env, to_blob_term(env, t, byte_size)); } - return nx::nif::ok(env, resource_bin); + CATCH() } uint64_t elem_count(std::vector shape) { @@ -1004,6 +1029,125 @@ NIF(array_from_ptr) { CATCH() } +// ─── Worker / EMLX.CommandQueue NIFs ──────────────────────────────────────── +// +// Lifecycle for posted jobs: +// 1. Caller's NIF env (`env`) is short-lived — we cannot hand its terms +// to the worker thread. +// 2. We allocate a process-independent ErlNifEnv (`msg_env`) per job. +// The job_ref and the reply tuple live in `msg_env`. +// 3. The job_ref is also enif_make_copy'd into the caller's `env` so +// the wrapper in lib/emlx.ex can `receive {^job_ref, _}`. +// 4. The lambda posted to the worker captures `t_ptr`, `t_refcount`, +// `caller_pid`, `msg_env`, and `job_ref_msg` by value. Tensor +// lifetime across the post boundary is held two ways: +// - enif_keep_resource on t_ptr (ERTS resource refcount) +// - ++(*t_refcount) on the embedded TensorP refcount +// Both are dropped at the end of the lambda. +// 5. After enif_send, msg_env is freed inside the lambda. +// +// Stop semantics: if the Worker is destroyed (NIF resource refcount drops +// to zero), pending jobs already in the queue at destructor time still +// run and still send their reply. Jobs posted *after* the destructor +// begins throw on post() and the NIF returns {:error, _} synchronously. 
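+// Caller-side half of this protocol, as a sketch: it mirrors await_worker/1
+// in lib/emlx.ex, and the variable names are illustrative only.
+//
+//     {:ok, job_ref} = EMLX.NIF.add(worker, a_ref, b_ref, device)
+//
+//     receive do
+//       {^job_ref, :ok} -> :ok
+//       {^job_ref, {:ok, result}} -> result
+//       {^job_ref, {:error, reason}} -> raise EMLX.NIFError, List.to_string(reason)
+//     end
+//
+// There is deliberately no `after` clause: a job that post() accepted always
+// sends exactly one reply, even if the worker is already stopping.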
+ +NIF(command_queue_new) { + ATOM_PARAM(0, device_atom); + + try { + mlx::core::Device device = string2device(device_atom); + + // Guard before spawning any threads: mlx::core::new_stream on an + // unavailable device (e.g. gpu on Linux/CPU-only libmlx) does not + // throw — it calls std::terminate() internally, which we cannot + // catch. Check availability here and surface a clean {:error, _} + // so EMLX.Application can skip the GPU worker on unsupported hosts. + if (!mlx::core::is_available(device)) { + return nx::nif::error(env, "device not available"); + } + + auto *worker_ptr = static_cast(enif_alloc_resource( + resource_object::type, sizeof(emlx::Worker))); + if (!worker_ptr) { + return enif_make_badarg(env); + } + + try { + new (worker_ptr) emlx::Worker(device); + } catch (...) { + // Placement new failed before constructor completed; the resource + // memory is uninitialised so we must NOT call ~Worker. Just release + // the bare allocation and re-throw to the outer handler. + enif_release_resource(worker_ptr); + throw; + } + + ERL_NIF_TERM ref = enif_make_resource(env, worker_ptr); + enif_release_resource(worker_ptr); + return nx::nif::ok(env, ref); + } + CATCH() +} + +// Posts a no-op barrier job that calls mx::synchronize(stream). The +// Elixir wrapper blocks in `receive` until the reply lands, which only +// happens after every preceding job on this worker has completed AND +// MLX has flushed the GPU command buffer. +NIF(command_queue_synchronize) { + emlx::Worker *worker; + if (!enif_get_resource(env, argv[0], resource_object::type, + (void **)&worker)) { + return nx::nif::error(env, "Invalid command queue ref"); + } + + ErlNifPid caller_pid; + enif_self(env, &caller_pid); + + ErlNifEnv *msg_env = enif_alloc_env(); + if (!msg_env) { + return nx::nif::error(env, "Failed to allocate msg env"); + } + ERL_NIF_TERM job_ref_msg = enif_make_ref(msg_env); + ERL_NIF_TERM job_ref_caller = enif_make_copy(env, job_ref_msg); + + mlx::core::Stream stream = worker->stream(); + + try { + worker->post([stream, caller_pid, msg_env, job_ref_msg]() mutable { + ERL_NIF_TERM result; + try { + mlx::core::synchronize(stream); + result = enif_make_tuple2(msg_env, enif_make_atom(msg_env, "ok"), + enif_make_atom(msg_env, "ok")); + } catch (const std::exception &e) { + result = enif_make_tuple2( + msg_env, enif_make_atom(msg_env, "error"), + enif_make_string(msg_env, e.what(), ERL_NIF_LATIN1)); + } catch (...) { + result = enif_make_tuple2( + msg_env, enif_make_atom(msg_env, "error"), + enif_make_string(msg_env, "Unknown error in synchronize", + ERL_NIF_LATIN1)); + } + + ERL_NIF_TERM msg = enif_make_tuple2(msg_env, job_ref_msg, result); + ErlNifPid pid = caller_pid; + enif_send(NULL, &pid, msg_env, msg); + enif_free_env(msg_env); + }); + } catch (const std::exception &e) { + enif_free_env(msg_env); + return nx::nif::error(env, e.what()); + } + + return nx::nif::ok(env, job_ref_caller); +} + +// `eval` and `to_blob` are now worker-routed via `ASYNC_NIF` (see the +// async wrapper block near `nif_funcs[]`). The bespoke +// `command_queue_post_eval` / `command_queue_post_to_blob` NIFs were +// removed in favour of that uniform dispatch. + static int open_resources(ErlNifEnv *env) { const char *mod = "EMLX"; if (!open_resource(env, mod, "MLXArray")) { @@ -1014,6 +1158,13 @@ static int open_resources(ErlNifEnv *env) { return -1; } + // emlx::Worker — backs EMLX.CommandQueue and the application default + // worker. 
Default destructor (~Worker) signals stop, drains pending + // jobs, and joins the OS thread. + if (!open_resource(env, mod, "CommandQueue")) { + return -1; + } + return 0; } @@ -1545,146 +1696,308 @@ NIF(window_scatter_min) { device)); } +// ─── Async wrappers ──────────────────────────────────────────────────────── +// +// MLX 0.31.2's thread-local Metal CommandEncoders + thread-local default +// streams force every graph-touching op for a given GPU stream to run on +// the OS thread that created the stream. Each NIF below is the existing +// sync body wrapped by `emlx::async_dispatch`: the wrapper extracts +// the worker from `argv[0]`, copies `argv[1..]` into a process-independent +// `msg_env`, posts a lambda to the worker thread, and `enif_send`s +// `{job_ref, payload}` back to the caller. The sync body runs unchanged +// on the worker thread; `DEVICE_PARAM` resolutions transparently pick up +// the worker's stream via MLX's `to_stream(s, default_) -> default_stream` +// thread-local lookup (the worker called `set_default_stream` for its +// stream during `thread_main`). +// +// Sync (non-routed) NIFs preserved below. These either touch no MLX graph state +// or read fields that are safe across threads (resource allocator, cache +// limits, evaluated buffer pointers). + +ASYNC_NIF(eval) +ASYNC_NIF(to_blob) +ASYNC_NIF(tensor_to_shm) +ASYNC_NIF(item) +ASYNC_NIF(from_blob) +ASYNC_NIF(scalar_tensor) + +ASYNC_NIF(ones) +ASYNC_NIF(full) +ASYNC_NIF(arange) +ASYNC_NIF(eye) +ASYNC_NIF(reshape) +ASYNC_NIF(astype) +ASYNC_NIF(view) +ASYNC_NIF(broadcast_to) +ASYNC_NIF(transpose) +ASYNC_NIF(pad) +ASYNC_NIF(sort) +ASYNC_NIF(argsort) +ASYNC_NIF(slice) +ASYNC_NIF(slice_update) +ASYNC_NIF(squeeze) +ASYNC_NIF(as_strided) + +ASYNC_NIF(stack) +ASYNC_NIF(where) +ASYNC_NIF(concatenate) +ASYNC_NIF(take_along_axis) +ASYNC_NIF(take) +ASYNC_NIF(gather) +ASYNC_NIF(scatter_add) +ASYNC_NIF(scatter) + +ASYNC_NIF(all) +ASYNC_NIF(any) +ASYNC_NIF(sum) +ASYNC_NIF(product) +ASYNC_NIF(argmax) +ASYNC_NIF(argmin) +ASYNC_NIF(cumulative_sum) +ASYNC_NIF(cumulative_product) +ASYNC_NIF(cumulative_max) +ASYNC_NIF(cumulative_min) +ASYNC_NIF(max) +ASYNC_NIF(min) +ASYNC_NIF(clip) + +ASYNC_NIF(abs) +ASYNC_NIF(ceil) +ASYNC_NIF(conjugate) +ASYNC_NIF(floor) +ASYNC_NIF(negate) +ASYNC_NIF(round) +ASYNC_NIF(sign) +ASYNC_NIF(real) +ASYNC_NIF(imag) +ASYNC_NIF(is_nan) +ASYNC_NIF(is_infinity) +ASYNC_NIF(logical_not) +ASYNC_NIF(sigmoid) +ASYNC_NIF(asin) +ASYNC_NIF(asinh) +ASYNC_NIF(acos) +ASYNC_NIF(acosh) +ASYNC_NIF(cos) +ASYNC_NIF(cosh) +ASYNC_NIF(atan) +ASYNC_NIF(atanh) +ASYNC_NIF(erf) +ASYNC_NIF(erf_inv) +ASYNC_NIF(exp) +ASYNC_NIF(expm1) +ASYNC_NIF(log) +ASYNC_NIF(log1p) +ASYNC_NIF(rsqrt) +ASYNC_NIF(sin) +ASYNC_NIF(sinh) +ASYNC_NIF(sqrt) +ASYNC_NIF(tan) +ASYNC_NIF(tanh) + +ASYNC_NIF(add) +ASYNC_NIF(subtract) +ASYNC_NIF(multiply) +ASYNC_NIF(pow) +ASYNC_NIF(remainder) +ASYNC_NIF(divide) +ASYNC_NIF(atan2) +ASYNC_NIF(bitwise_and) +ASYNC_NIF(bitwise_or) +ASYNC_NIF(bitwise_xor) +ASYNC_NIF(bitwise_not) +ASYNC_NIF(left_shift) +ASYNC_NIF(right_shift) +ASYNC_NIF(minimum) +ASYNC_NIF(maximum) +ASYNC_NIF(quotient) +ASYNC_NIF(equal) +ASYNC_NIF(not_equal) +ASYNC_NIF(greater) +ASYNC_NIF(less) +ASYNC_NIF(greater_equal) +ASYNC_NIF(less_equal) +ASYNC_NIF(logical_and) +ASYNC_NIF(logical_or) +ASYNC_NIF(logical_xor) + +ASYNC_NIF(emlx_fft) +ASYNC_NIF(ifft) +ASYNC_NIF(emlx_fft2) +ASYNC_NIF(ifft2) +ASYNC_NIF(allclose) +ASYNC_NIF(isclose) +ASYNC_NIF(tri_inv) + +ASYNC_NIF(linalg_lu) +ASYNC_NIF(linalg_qr) +ASYNC_NIF(linalg_svd) +ASYNC_NIF(linalg_cholesky) 
+ASYNC_NIF(linalg_eigh) +ASYNC_NIF(linalg_inv) +ASYNC_NIF(linalg_pinv) +ASYNC_NIF(linalg_solve) +ASYNC_NIF(linalg_solve_triangular) +ASYNC_NIF(conv_general) +ASYNC_NIF(einsum) +ASYNC_NIF(tensordot) + +ASYNC_NIF(window_scatter_max) +ASYNC_NIF(window_scatter_min) + static ErlNifFunc nif_funcs[] = { + {"eval", 2, eval_async}, + {"to_blob", 2, to_blob_async}, + {"to_blob", 3, to_blob_async}, + {"tensor_to_shm", 3, tensor_to_shm_async}, + {"item", 2, item_async}, + {"from_blob", 5, from_blob_async}, + {"scalar_tensor", 4, scalar_tensor_async}, + + {"ones", 4, ones_async}, + {"full", 5, full_async}, + {"arange", 6, arange_async}, + {"eye", 5, eye_async}, + {"reshape", 4, reshape_async}, + {"astype", 4, astype_async}, + {"view", 4, view_async}, + {"broadcast_to", 4, broadcast_to_async}, + {"transpose", 4, transpose_async}, + {"pad", 7, pad_async}, + {"sort", 4, sort_async}, + {"argsort", 4, argsort_async}, + {"slice", 6, slice_async}, + {"slice_update", 6, slice_update_async}, + {"squeeze", 4, squeeze_async}, + {"as_strided", 6, as_strided_async}, + + {"stack", 4, stack_async}, + {"where", 5, where_async}, + {"concatenate", 4, concatenate_async}, + {"take_along_axis", 5, take_along_axis_async}, + {"take", 5, take_async}, + {"gather", 6, gather_async}, + {"scatter_add", 6, scatter_add_async}, + {"scatter", 6, scatter_async}, + + {"all", 5, all_async}, + {"any", 5, any_async}, + {"sum", 5, sum_async}, + {"product", 5, product_async}, + {"argmax", 4, argmax_async}, + {"argmax", 5, argmax_async}, + {"argmin", 4, argmin_async}, + {"argmin", 5, argmin_async}, + {"cumulative_sum", 6, cumulative_sum_async}, + {"cumulative_product", 6, cumulative_product_async}, + {"cumulative_max", 6, cumulative_max_async}, + {"cumulative_min", 6, cumulative_min_async}, + {"max", 5, max_async}, + {"min", 5, min_async}, + {"clip", 5, clip_async}, + + {"abs", 3, abs_async}, + {"ceil", 3, ceil_async}, + {"conjugate", 3, conjugate_async}, + {"floor", 3, floor_async}, + {"negate", 3, negate_async}, + {"round", 3, round_async}, + {"sign", 3, sign_async}, + {"real", 3, real_async}, + {"imag", 3, imag_async}, + {"is_nan", 3, is_nan_async}, + {"is_infinity", 3, is_infinity_async}, + {"logical_not", 3, logical_not_async}, + {"sigmoid", 3, sigmoid_async}, + {"asin", 3, asin_async}, + {"asinh", 3, asinh_async}, + {"acos", 3, acos_async}, + {"acosh", 3, acosh_async}, + {"cos", 3, cos_async}, + {"cosh", 3, cosh_async}, + {"atan", 3, atan_async}, + {"atanh", 3, atanh_async}, + {"erf", 3, erf_async}, + {"erf_inv", 3, erf_inv_async}, + {"exp", 3, exp_async}, + {"expm1", 3, expm1_async}, + {"log", 3, log_async}, + {"log1p", 3, log1p_async}, + {"rsqrt", 3, rsqrt_async}, + {"sin", 3, sin_async}, + {"sinh", 3, sinh_async}, + {"sqrt", 3, sqrt_async}, + {"tan", 3, tan_async}, + {"tanh", 3, tanh_async}, + + {"add", 4, add_async}, + {"subtract", 4, subtract_async}, + {"multiply", 4, multiply_async}, + {"pow", 4, pow_async}, + {"remainder", 4, remainder_async}, + {"divide", 4, divide_async}, + {"atan2", 4, atan2_async}, + {"bitwise_and", 4, bitwise_and_async}, + {"bitwise_or", 4, bitwise_or_async}, + {"bitwise_xor", 4, bitwise_xor_async}, + {"bitwise_not", 3, bitwise_not_async}, + {"left_shift", 4, left_shift_async}, + {"right_shift", 4, right_shift_async}, + {"minimum", 4, minimum_async}, + {"maximum", 4, maximum_async}, + {"quotient", 4, quotient_async}, + {"equal", 4, equal_async}, + {"not_equal", 4, not_equal_async}, + {"greater", 4, greater_async}, + {"less", 4, less_async}, + {"greater_equal", 4, greater_equal_async}, + {"less_equal", 4, 
less_equal_async}, + {"logical_and", 4, logical_and_async}, + {"logical_or", 4, logical_or_async}, + {"logical_xor", 4, logical_xor_async}, + + {"fft", 5, emlx_fft_async}, + {"ifft", 5, ifft_async}, + {"fft2", 5, emlx_fft2_async}, + {"ifft2", 5, ifft2_async}, + {"allclose", 7, allclose_async}, + {"isclose", 7, isclose_async}, + {"tri_inv", 4, tri_inv_async}, + + {"linalg_lu", 3, linalg_lu_async}, + {"linalg_qr", 3, linalg_qr_async}, + {"linalg_svd", 4, linalg_svd_async}, + {"linalg_cholesky", 4, linalg_cholesky_async}, + {"linalg_eigh", 4, linalg_eigh_async}, + {"linalg_inv", 3, linalg_inv_async}, + {"linalg_pinv", 3, linalg_pinv_async}, + {"linalg_solve", 4, linalg_solve_async}, + {"linalg_solve_triangular", 5, linalg_solve_triangular_async}, + {"conv_general", 10, conv_general_async}, + {"einsum", 5, einsum_async}, + {"tensordot", 6, tensordot_async}, + + {"window_scatter_max", 9, window_scatter_max_async}, + {"window_scatter_min", 9, window_scatter_min_async}, + + // ── Sync (non-routed) NIFs. {"strides", 1, strides}, - {"as_strided", 5, as_strided}, {"scalar_type", 1, scalar_type}, - {"eval", 1, eval}, - {"view", 3, view}, - {"stack", 3, stack}, - {"where", 4, where}, - {"concatenate", 3, concatenate}, - {"take_along_axis", 4, take_along_axis}, - {"take", 4, take}, - {"gather", 5, gather}, - {"scatter_add", 5, scatter_add}, - {"scatter", 5, scatter}, - {"slice", 5, slice}, - {"slice_update", 5, slice_update}, - {"squeeze", 3, squeeze}, - {"item", 1, item, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"all", 4, all}, - {"any", 4, any}, - {"sum", 4, sum}, - {"product", 4, product}, - {"argmax", 3, argmax}, - {"argmax", 4, argmax}, - {"argmin", 3, argmin}, - {"argmin", 4, argmin}, - {"cumulative_sum", 5, cumulative_sum}, - {"cumulative_product", 5, cumulative_product}, - {"cumulative_max", 5, cumulative_max}, - {"cumulative_min", 5, cumulative_min}, {"shape", 1, shape}, - {"reshape", 3, reshape}, - {"astype", 3, astype}, - {"to_blob", 1, to_blob, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"to_blob", 2, to_blob, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"from_blob", 4, from_blob}, - {"scalar_tensor", 3, scalar_tensor}, - {"ones", 3, ones}, - {"full", 4, full}, - {"arange", 5, arange}, - {"eye", 4, eye}, - {"broadcast_to", 3, broadcast_to}, - {"tensordot", 5, tensordot}, - {"einsum", 4, einsum}, - {"conv_general", 9, conv_general}, - {"transpose", 3, transpose}, - {"pad", 6, pad}, - {"sort", 3, sort}, - {"argsort", 3, argsort}, - {"abs", 2, abs}, - {"ceil", 2, ceil}, - {"conjugate", 2, conjugate}, - {"floor", 2, floor}, - {"negate", 2, negate}, - {"round", 2, round}, - {"sign", 2, sign}, - {"real", 2, real}, - {"imag", 2, imag}, - {"is_nan", 2, is_nan}, - {"is_infinity", 2, is_infinity}, - {"logical_not", 2, logical_not}, - {"sigmoid", 2, sigmoid}, - {"asin", 2, asin}, - {"asinh", 2, asinh}, - {"acos", 2, acos}, - {"acosh", 2, acosh}, - {"cos", 2, cos}, - {"cosh", 2, cosh}, - {"atan", 2, atan}, - {"atanh", 2, atanh}, - {"erf", 2, erf}, - {"erf_inv", 2, erf_inv}, - {"exp", 2, exp}, - {"expm1", 2, expm1}, - {"log", 2, log}, - {"log1p", 2, log1p}, - {"rsqrt", 2, rsqrt}, - {"sin", 2, sin}, - {"sinh", 2, sinh}, - {"sqrt", 2, sqrt}, - {"tan", 2, tan}, - {"tanh", 2, tanh}, - {"add", 3, add}, - {"subtract", 3, subtract}, - {"multiply", 3, multiply}, - {"pow", 3, pow}, - {"remainder", 3, remainder}, - {"divide", 3, divide}, - {"atan2", 3, atan2}, - {"bitwise_and", 3, bitwise_and}, - {"bitwise_or", 3, bitwise_or}, - {"bitwise_xor", 3, bitwise_xor}, - {"bitwise_not", 2, bitwise_not}, - {"left_shift", 3, left_shift}, - 
{"right_shift", 3, right_shift}, - {"minimum", 3, minimum}, - {"maximum", 3, maximum}, - {"quotient", 3, quotient}, - {"equal", 3, equal}, - {"not_equal", 3, not_equal}, - {"greater", 3, greater}, - {"less", 3, less}, - {"greater_equal", 3, greater_equal}, - {"less_equal", 3, less_equal}, - {"logical_and", 3, logical_and}, - {"logical_or", 3, logical_or}, - {"logical_xor", 3, logical_xor}, - {"fft", 4, emlx_fft}, - {"ifft", 4, ifft}, - {"fft2", 4, emlx_fft2}, - {"ifft2", 4, ifft2}, - {"allclose", 6, allclose}, - {"isclose", 6, isclose}, {"deallocate", 1, deallocate}, - {"max", 4, max}, - {"min", 4, min}, - {"clip", 4, clip}, - {"tri_inv", 3, tri_inv}, - {"linalg_lu", 2, linalg_lu}, - {"linalg_qr", 2, linalg_qr}, - {"linalg_svd", 3, linalg_svd}, - {"linalg_cholesky", 3, linalg_cholesky}, - {"linalg_eigh", 3, linalg_eigh}, - {"linalg_inv", 2, linalg_inv}, - {"linalg_pinv", 2, linalg_pinv}, - {"linalg_solve", 3, linalg_solve}, - {"linalg_solve_triangular", 4, linalg_solve_triangular}, - {"window_scatter_max", 8, window_scatter_max}, - {"window_scatter_min", 8, window_scatter_min}, + {"array_from_shm", 4, array_from_shm, ERL_NIF_DIRTY_JOB_CPU_BOUND}, + {"shm_unlink_handle", 1, shm_unlink_handle}, + {"tensor_data_ptr", 1, tensor_data_ptr, ERL_NIF_DIRTY_JOB_CPU_BOUND}, + {"array_from_ptr", 5, array_from_ptr, ERL_NIF_DIRTY_JOB_CPU_BOUND}, {"memory_info", 0, memory_info}, {"clear_cache", 0, clear_cache}, {"reset_peak_memory", 0, reset_peak_memory}, {"set_memory_limit", 1, set_memory_limit}, {"set_cache_limit", 1, set_cache_limit}, - {"tensor_data_ptr", 1, tensor_data_ptr, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"array_from_ptr", 5, array_from_ptr, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"tensor_to_shm", 2, tensor_to_shm, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"array_from_shm", 4, array_from_shm, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"shm_unlink_handle", 1, shm_unlink_handle} -}; + + // ── Worker control NIFs. + {"command_queue_new", 1, command_queue_new}, + {"command_queue_synchronize", 1, command_queue_synchronize}}; ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL) diff --git a/c_src/emlx_worker.hpp b/c_src/emlx_worker.hpp new file mode 100644 index 0000000..fabbd25 --- /dev/null +++ b/c_src/emlx_worker.hpp @@ -0,0 +1,168 @@ +#pragma once + +// emlx::Worker — one OS thread + one mlx::core::Stream. +// +// Each Worker owns a dedicated std::thread that pulls Jobs from a +// FIFO queue and runs them. The thread sets the worker's MLX Stream +// as its current-thread default on startup, so any MLX op invoked +// from inside a Job (mx::eval, mx::synchronize, etc.) dispatches to +// that stream. +// +// This is the dispatch primitive that backs both the application +// default worker (held in :persistent_term, see lib/emlx/application.ex) +// and per-context EMLX.CommandQueue instances created by +// Nx.Serving partitions or user code. +// +// The class is header-only because it is co-located with c_src/emlx_nif.cpp +// (single translation unit per the project's c_src layout convention). + +#include "erl_nif.h" +#include "mlx/mlx.h" +#include "mlx/scheduler.h" +#include "mlx/stream.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace emlx { + +class Worker { +public: + using Job = std::function; + + // Spawns the worker thread. 
The fresh mlx::core::Stream is allocated + // *inside* the worker thread — MLX 0.31.2 has thread-local state for + // GPU streams (mlx-lm#1090, mlx-lm#1179: "There is no Stream(gpu, N) + // in current thread"), so a stream created on thread A cannot be + // synchronized or dispatched to from thread B. We block here until + // the worker thread has created the stream and signalled ready, so + // that callers can rely on stream() / device() being valid the + // moment the constructor returns. + explicit Worker(mlx::core::Device device) + : device_(device), stream_(/*placeholder index*/ -1, device) { + std::promise stream_promise; + auto stream_future = stream_promise.get_future(); + thread_ = std::thread(&Worker::thread_main, this, std::move(stream_promise)); + // Blocks (or rethrows) until the worker thread has registered its + // stream and signalled ready. + // + // If stream_future.get() throws (worker thread set an exception), + // the thread has already exited cleanly. We must join before + // re-throwing, because ~thread() on a joinable thread calls + // std::terminate() — which is the "terminate called without an + // active exception" crash seen on Linux with no GPU. + try { + stream_ = stream_future.get(); + } catch (...) { + if (thread_.joinable()) { + thread_.join(); + } + throw; + } + } + + // Sets the stop flag, drains any in-flight jobs already past the + // pop, then joins the OS thread. Pending jobs in the queue at stop + // time are still executed (so callers awaiting an enif_send reply + // receive it). Jobs posted *after* the destructor begins are + // rejected by post(). + ~Worker() { + { + std::lock_guard lock(mutex_); + stop_ = true; + } + cv_.notify_all(); + if (thread_.joinable()) { + thread_.join(); + } + } + + Worker(const Worker &) = delete; + Worker &operator=(const Worker &) = delete; + Worker(Worker &&) = delete; + Worker &operator=(Worker &&) = delete; + + // Pushes a job onto the queue. Throws std::runtime_error if the + // worker is already stopping (the caller should translate this to + // an `{:error, "worker stopped"}` NIF return). + void post(Job job) { + { + std::lock_guard lock(mutex_); + if (stop_.load(std::memory_order_acquire)) { + throw std::runtime_error("Worker is stopping; cannot post job"); + } + queue_.push_back(std::move(job)); + } + cv_.notify_one(); + } + + mlx::core::Device device() const { return device_; } + mlx::core::Stream stream() const { return stream_; } + +private: + void thread_main(std::promise stream_promise) { + // Allocate the stream on THIS thread (MLX 0.31.2 thread-locality + // requirement — see constructor comment). Pin all MLX ops issued + // from this thread to our stream by making it the per-thread + // default. Graph-construction NIFs continue to run on BEAM + // scheduler threads (and use those threads' defaults); only ops + // invoked *inside* a posted job (currently mx::eval and + // mx::synchronize) inherit this binding. + try { + mlx::core::Stream stream = mlx::core::new_stream(device_); + mlx::core::set_default_stream(stream); + stream_promise.set_value(stream); + } catch (...) 
{ + stream_promise.set_exception(std::current_exception()); + return; + } + + while (true) { + Job job; + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { + return !queue_.empty() || stop_.load(std::memory_order_acquire); + }); + if (queue_.empty() && stop_.load(std::memory_order_acquire)) { + break; + } + job = std::move(queue_.front()); + queue_.pop_front(); + } + + // Jobs are responsible for their own exception handling and for + // delivering a reply to the calling BEAM process via enif_send. + // We never let an exception escape into the std::thread wrapper + // (which would call std::terminate). + try { + job(); + } catch (...) { + // Swallow. The job contract is: catch your own errors and + // turn them into a {:error, _} reply. Anything that escapes + // here would orphan the calling process's receive. + } + } + + // Release any per-thread MLX resources we hold (the StreamThread + // for stream_ in the global scheduler will be torn down when the + // last reference to its index drops). + mlx::core::clear_streams(); + } + + mlx::core::Device device_; + mlx::core::Stream stream_; + std::thread thread_; + std::mutex mutex_; + std::condition_variable cv_; + std::deque queue_; + std::atomic stop_{false}; +}; + +} // namespace emlx diff --git a/lib/emlx.ex b/lib/emlx.ex index 18ff663..632cef8 100644 --- a/lib/emlx.ex +++ b/lib/emlx.ex @@ -27,6 +27,11 @@ defmodule EMLX.Macro do @doc """ Function that receives a device and allocates a tensor. + + Routes through an `EMLX.CommandQueue` worker because MLX 0.31.2 pins + every GPU stream to the OS thread that created it. The macro injects + `worker = EMLX.resolve_worker(device)` and prepends `worker` to the + NIF call, then awaits the tagged-ref reply via `await_worker/1`. """ defmacro defdevice(call) do {name, args} = Macro.decompose_call(call) @@ -42,14 +47,21 @@ defmodule EMLX.Macro do end quote do - @mlx_function {unquote(name), unquote(length(args))} + # NIF arity is original + 1 because the wrapper prepends `worker` as + # argv[0]. The original `device` arg is preserved (kept for MLX's + # `to_stream(device) -> default_stream(device)` lookup, which on the + # worker thread returns the worker's stream — see emlx_async.hpp). + @mlx_function {unquote(name), unquote(length(args) + 1)} def unquote(name)(unquote_splicing(args)) do unquote(tensors) {user_device, index} = normalize_device!(var!(device)) - var!(device) = mlx_device!(user_device, index) + {worker, effective_device} = resolve_worker(user_device) + var!(device) = mlx_device!(effective_device, index) - EMLX.NIF.unquote(name)(unquote_splicing(args)) - |> unwrap_tensor!(user_device) + job_ref = + EMLX.NIF.unquote(name)(worker, unquote_splicing(args)) |> unwrap!() + + await_worker(job_ref) |> wrap_tensor(effective_device) end end end @@ -57,22 +69,45 @@ defmodule EMLX.Macro do @doc """ Generates a call that returns a tensor (or a tuple/list of tensors). - All tensor variables must start with the name tensor. + All tensor variables must start with the name tensor. Routes through an + `EMLX.CommandQueue` worker (see `defdevice/1`). """ defmacro deftensor(call) do - defcall(call, :unwrap_tensor!, [Macro.var(:device, __MODULE__)]) + {name, args} = Macro.decompose_call(call) + tensors = tensors(args) + + if tensors == [] do + raise ArgumentError, "at least one tensor required in #{name}/#{length(args)}" + end + + quote do + # NIF arity = original + 2: leading `worker` + trailing `device`. The + # device atom is still passed through so the underlying sync NIF body + # (e.g. 
`mlx::core::add(*a, *b, device)`) gets a `StreamOrDevice` that + # MLX resolves to the worker's stream via the worker thread's default + # stream slot. + @mlx_function {unquote(name), unquote(length(args) + 2)} + def unquote(name)(unquote_splicing(args)) do + {unquote(tensors), device} = prepare_tensors!(unquote(tensors)) + {worker, effective_device} = resolve_worker(device) + + job_ref = + EMLX.NIF.unquote(name)(worker, unquote_splicing(args), effective_device) + |> unwrap!() + + await_worker(job_ref) |> wrap_tensor(effective_device) + end + end end @doc """ - Generates a call that returns a value (not a tensor). - - All tensor variables must start with the name tensor. + Generates a call that returns a value (not a tensor). NOT worker-routed — + use only for pure metadata / refcount NIFs (`scalar_type`, `shape`, + `strides`, `deallocate`). Graph-touching value NIFs (e.g. `item`) must be + hand-written so they thread a worker through `EMLX.NIF.(worker, ...)` + and await a tagged-ref reply. """ defmacro defvalue(call) do - defcall(call, :unwrap!, []) - end - - defp defcall(call, unwrapper, extra) do {name, args} = Macro.decompose_call(call) tensors = tensors(args) @@ -81,12 +116,10 @@ defmodule EMLX.Macro do end quote do - @mlx_function {unquote(name), unquote(length(args) + length(extra))} + @mlx_function {unquote(name), unquote(length(args))} def unquote(name)(unquote_splicing(args)) do - {unquote(tensors), device} = prepare_tensors!(unquote(tensors)) - - EMLX.NIF.unquote(name)(unquote_splicing(args ++ extra)) - |> unquote(unwrapper)(unquote_splicing(extra)) + {unquote(tensors), _device} = prepare_tensors!(unquote(tensors)) + EMLX.NIF.unquote(name)(unquote_splicing(args)) |> unwrap!() end end end @@ -291,15 +324,22 @@ defmodule EMLX do defvalue shape(tensor) def to_blob({device, ref} = tensor) when is_tensor(device, ref) do - # Two-step to_blob: eval on main scheduler, then copy on dirty scheduler + # Eval first so the underlying MLX array is materialised; then ask the + # worker for the contiguous-copy + zero-copy resource binary. Both + # operations are routed through the same worker resolution path so + # that the contiguous fallback in `to_blob_term` runs on the same OS + # thread that owns the tensor's stream encoder. eval(tensor) - EMLX.NIF.to_blob(ref) |> unwrap!() + {worker, _effective_device} = resolve_worker(device) + job_ref = EMLX.NIF.to_blob(worker, ref) |> unwrap!() + await_worker(job_ref) end def to_blob({device, ref} = tensor, limit) when is_tensor(device, ref) do - # Two-step to_blob: eval on main scheduler, then copy on dirty scheduler eval(tensor) - EMLX.NIF.to_blob(ref, limit) |> unwrap!() + {worker, _effective_device} = resolve_worker(device) + job_ref = EMLX.NIF.to_blob(worker, ref, limit) |> unwrap!() + await_worker(job_ref) end @doc """ @@ -328,7 +368,12 @@ defmodule EMLX do """ def tensor_to_shm({device, ref} = tensor, permissions) when is_tensor(device, ref) do eval(tensor) - EMLX.NIF.tensor_to_shm(ref, permissions) |> unwrap!() + {worker, _effective_device} = resolve_worker(device) + # Worker-routed: `tensor_to_shm`'s NIF body may call `mx::contiguous` + # + `mx::eval` on a non-contiguous tensor, which both touch the + # thread-local Metal encoder. Must run on the worker. 
+ job_ref = EMLX.NIF.tensor_to_shm(worker, ref, permissions) |> unwrap!() + await_worker(job_ref) end @doc """ @@ -346,18 +391,16 @@ defmodule EMLX do defp unwrap!({:ok, result}), do: result defp unwrap!({:error, error}), do: raise(EMLX.NIFError, List.to_string(error)) - defp unwrap_tensor!(tagged_result, device) do - case unwrap!(tagged_result) do - ref when is_reference(ref) -> - {device, ref} + # Wraps a worker-thread payload in {device, ref} envelopes. + # Already-unwrapped (no leading {:ok, _}) — `await_worker/1` peels that + # off when the worker delivers the reply. + defp wrap_tensor(ref, device) when is_reference(ref), do: {device, ref} - list when is_list(list) -> - Enum.map(list, &{device, &1}) + defp wrap_tensor(list, device) when is_list(list), + do: Enum.map(list, &{device, &1}) - tuple when is_tuple(tuple) -> - tuple |> Tuple.to_list() |> Enum.map(&{device, &1}) |> List.to_tuple() - end - end + defp wrap_tensor(tuple, device) when is_tuple(tuple), + do: tuple |> Tuple.to_list() |> Enum.map(&{device, &1}) |> List.to_tuple() defp prepare_tensors_list!(tensors_list, device) do Enum.map_reduce(tensors_list, device, fn @@ -387,24 +430,143 @@ defmodule EMLX do defp merge_device(_, _), do: :cpu defvalue deallocate(tensor_ref) - defvalue eval(tensor) + + @doc """ + Evaluates a (possibly lazy) MLX tensor by routing the work through an + `EMLX.CommandQueue`. Blocks the caller until the worker thread has + finished `mlx::core::eval/1` for this tensor. + + Resolves the queue via `resolve_worker/1`: + + 1. If the calling process has bound a queue with + `EMLX.CommandQueue.with_queue/2`, that queue is used. + 2. Otherwise the application-default worker for the tensor's device + (CPU or GPU) is used — see `EMLX.Application`. + """ + def eval({device, ref}) when is_tensor(device, ref) do + {worker, _effective_device} = resolve_worker(device) + job_ref = EMLX.NIF.eval(worker, ref) |> unwrap!() + await_worker(job_ref) + end + + # ── Worker resolution ────────────────────────────────────────────────────── + # + # `Process.get(:emlx_command_queue)` is set by EMLX.CommandQueue.with_queue/2. + # The value is `{worker_ref, device}` so we know the queue's device without a + # second lookup. Returns `{worker, effective_device}` — callers must use + # `effective_device` for both the NIF device argument and `wrap_tensor/2`. + # + # When no queue is bound, we fall back to the application-default worker for + # the requested device. + + @doc false + def resolve_worker(device) do + case Process.get(:emlx_command_queue) do + nil -> {EMLX.Application.default_worker(device), device} + {worker, ^device} -> {worker, device} + {worker, bound_device} -> resolve_cross_device(device, worker, bound_device) + end + end + + # CPU and GPU operations do not share thread-local Metal encoder state, so + # routing a CPU tensor through a GPU queue (or vice-versa) is safe — MLX + # inserts the necessary cross-stream synchronization internally. We therefore + # do NOT force an intermediate eval; we let MLX manage the graph dependency. 
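+  # Example configuration (illustrative values; both keys default to false):
+  #
+  #     config :emlx,
+  #       cross_device_promotion: true,
+  #       warn_cross_device: true
+  #
+  # With promotion enabled, a :cpu tensor touched inside a :gpu-bound
+  # with_queue/2 block is routed to the bound GPU queue (and logged when
+  # warnings are on); with it disabled, the call falls back to the
+  # application-default worker for the tensor's own device.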
+ defp resolve_cross_device(requested, worker, bound) do + if Application.get_env(:emlx, :cross_device_promotion, false) do + if Application.get_env(:emlx, :warn_cross_device, false) do + require Logger + + Logger.warning( + "[EMLX] cross-device promotion: #{requested} tensor routed to #{bound} queue" + ) + end + + {worker, bound} + else + {EMLX.Application.default_worker(requested), requested} + end + end + + defp await_worker(job_ref) do + receive do + # Worker NIFs (sync bodies) return one of: + # nx::nif::ok(env) => :ok + # nx::nif::ok(env, term) => {:ok, term} + # nx::nif::error(env, msg) => {:error, msg} + # The async wrapper forwards the payload as-is in {ref, payload}. + {^job_ref, :ok} -> :ok + {^job_ref, {:ok, result}} -> result + {^job_ref, {:error, reason}} -> raise(EMLX.NIFError, List.to_string(reason)) + end + end deftensor slice(tensor, starts, stops, strides) deftensor slice_update(tensor, tensor_updates, starts, stops) deftensor squeeze(tensor, axes) - defvalue item(tensor) defvalue strides(tensor) + @doc """ + Returns the scalar value of a 0-d tensor as a number. + + Worker-routed: the NIF body calls `mlx::core::eval(*t)` and `t->item()`, + both of which require running on the OS thread that owns the tensor's + stream encoder. + """ + def item({device, ref}) when is_tensor(device, ref) do + {worker, _effective_device} = resolve_worker(device) + job_ref = EMLX.NIF.item(worker, ref) |> unwrap!() + await_worker(job_ref) + end + @behaviour Nx.Defn.Compiler + # Known EMLX-specific compiler opts. `:command_queue` is injected by + # `__partitions_options__/1` but may also be passed directly by callers + # that manage their own queues (equivalent to a manual `with_queue`). + @valid_compiler_keys [:device, :max_concurrency, :command_queue] + @impl Nx.Defn.Compiler - defdelegate __jit__(key, vars, fun, args_list, opts), to: Nx.Defn.Evaluator + def __jit__(key, vars, fun, args_list, opts) do + __compile__(key, vars, fun, opts).(args_list) + end @impl Nx.Defn.Compiler - defdelegate __compile__(key, vars, fun, opts), to: Nx.Defn.Evaluator + def __compile__(key, vars, fun, opts) do + Keyword.validate!(opts, @valid_compiler_keys) + {compiler_opts, rest_opts} = split_compiler_opts(opts) + queue = Keyword.get(compiler_opts, :command_queue) + + inner = + Nx.Defn.Evaluator.__compile__( + key, + vars, + fun, + Keyword.put(rest_opts, :compiler, Nx.Defn.Evaluator) + ) + + if queue do + # Capture the queue ref in a closure so each invocation of the compiled + # function routes through the correct CommandQueue. The queue lives as + # long as the Nx.Serving module_state that holds this compiled function. + fn inputs -> EMLX.CommandQueue.with_queue(queue, fn -> inner.(inputs) end) end + else + inner + end + end @impl Nx.Defn.Compiler - defdelegate __partitions_options__(opts), to: Nx.Defn.Evaluator + def __partitions_options__(opts) do + n = Keyword.get(opts, :max_concurrency, 1) + device = Keyword.get(opts, :device, :gpu) + + # Allocate one CommandQueue (and its OS thread) per partition. This runs + # inside Nx.Serving's GenServer init/1 — queues are owned by module_state. + # For N ≤ 8 the synchronous pthread_create calls take a few milliseconds. + for _i <- 1..n do + [device: device, command_queue: EMLX.CommandQueue.new!(device)] + end + end @impl Nx.Defn.Compiler def __to_backend__(opts) do @@ -412,6 +574,12 @@ defmodule EMLX do {EMLX.Backend, device: device} end + # Splits opts into {emlx_compiler_opts, rest_opts}. 
The rest_opts are + # forwarded to Nx.Defn.Evaluator; EMLX-specific keys are consumed here. + defp split_compiler_opts(opts) do + Enum.split_with(opts, fn {k, _v} -> k in @valid_compiler_keys end) + end + @doc """ Returns a map with current memory usage information. diff --git a/lib/emlx/application.ex b/lib/emlx/application.ex new file mode 100644 index 0000000..89f61c8 --- /dev/null +++ b/lib/emlx/application.ex @@ -0,0 +1,82 @@ +defmodule EMLX.Application do + @moduledoc """ + OTP Application for EMLX. + + Allocates the application-default `EMLX.CommandQueue` (one per device: + `:cpu` and `:gpu`) at boot and stashes the NIF resource references in + `:persistent_term`. These are the workers used by `EMLX.eval/1` and + `EMLX.to_blob/1` for any process that has not bound its own queue via + `EMLX.CommandQueue.with_queue/2`. + + See `clean-room-import/01-worker-thread-dispatch.md` for the rationale + behind `:persistent_term` instead of a `GenServer` + `Registry`. + + ## Idempotency + + `start/2` is safe to call more than once (e.g. from an umbrella where + `:emlx` is referenced by several apps). The first call to allocate a + worker for a given device wins; subsequent calls observe the existing + `:persistent_term` entry and skip. This matters because every + `:persistent_term.put/2` triggers a global GC scan across every + process heap on the node — we do exactly two puts (CPU + GPU) over + the BEAM's lifetime. + + ## GPU absence + + On platforms where MLX cannot allocate a GPU stream (e.g. Linux without + Metal), `EMLX.NIF.command_queue_new(:gpu)` returns `{:error, _}` and + this module silently skips the GPU worker. Subsequent + `EMLX.eval/1` calls on a GPU tensor will raise at use time with the + underlying `:persistent_term` `ArgumentError`. + """ + + use Application + + @doc false + def start(_type, _args) do + ensure_default_worker!(:cpu, _gpu_optional? = false) + ensure_default_worker!(:gpu, _gpu_optional? = true) + Supervisor.start_link([], strategy: :one_for_one, name: __MODULE__) + end + + @doc """ + Returns the application-default `EMLX.CommandQueue` NIF resource for + the given device. + + Raises `ArgumentError` if no worker has been allocated for the device + (e.g. asking for `:gpu` on a system without Metal). + + **Never** call `:persistent_term.put/2` on the underlying key + (`{EMLX, :default_worker, device}`) — overwriting a `:persistent_term` + value triggers a node-wide GC. + """ + @spec default_worker(:cpu | :gpu) :: reference() + def default_worker(device) when device in [:cpu, :gpu] do + :persistent_term.get(persistent_term_key(device)) + end + + defp ensure_default_worker!(device, gpu_optional?) do + key = persistent_term_key(device) + + case :persistent_term.get(key, :unset) do + :unset -> + case EMLX.NIF.command_queue_new(device) do + {:ok, ref} -> + :persistent_term.put(key, ref) + + {:error, _reason} when gpu_optional? -> + :ok + + {:error, reason} -> + raise EMLX.NIFError, + "EMLX.Application could not allocate default #{device} worker: " <> + List.to_string(reason) + end + + _existing -> + :ok + end + end + + defp persistent_term_key(device), do: {EMLX, :default_worker, device} +end diff --git a/lib/emlx/command_queue.ex b/lib/emlx/command_queue.ex new file mode 100644 index 0000000..dd67cef --- /dev/null +++ b/lib/emlx/command_queue.ex @@ -0,0 +1,133 @@ +defmodule EMLX.CommandQueue do + @moduledoc """ + A handle to an `emlx::Worker` OS thread + `mlx::core::Stream` pair. 
+ + Each `CommandQueue` wraps a NIF resource reference and records the device + (`:cpu` or `:gpu`) the underlying stream was created for. The device tag is + necessary because the raw NIF reference carries no device metadata, and + `EMLX.resolve_worker/1` needs it to decide whether cross-device promotion + applies. + + ## Usage + + {:ok, q} = EMLX.CommandQueue.new(:gpu) + + EMLX.CommandQueue.with_queue(q, fn -> + # All EMLX / Nx operations in this block route through `q` + Nx.add(a, b) + end) + + ## Process binding + + `with_queue/2` stores `{ref, device}` in the calling process's dictionary + under `:emlx_command_queue` for the duration of the given function, then + restores the previous value (supporting nested calls). `EMLX.resolve_worker/1` + reads this key to bypass the application-default worker. + + ## Inherited queue and `EMLX.Backend` + + Every EMLX NIF call (including those made from inside `EMLX.Backend` + callbacks) inherits the bound queue automatically through + `EMLX.resolve_worker/1`. This means that if your code calls a standard + `Nx` operation inside a `with_queue/2` block — even indirectly via a + library that uses `EMLX.Backend` — the work routes through the bound + queue. This is almost always the desired behaviour (e.g. all ops in a + Bumblebee forward pass go to the same Metal command queue). + + There is deliberately **no opt-out mechanism**: if you need work to run on + a different queue, open a nested `with_queue/2` for the inner scope. The + inner binding shadows the outer one for its duration and is restored on + exit. + + ## Cross-device promotion + + When a tensor's device does not match the bound queue's device, the + behaviour is controlled by application config: + + config :emlx, + cross_device_promotion: false, # default: fall back to device-default worker + warn_cross_device: false # default: no Logger.warning on mismatch + + See `EMLX.resolve_worker/1` for the promotion logic. + """ + + @enforce_keys [:ref, :device] + defstruct [:ref, :device] + + @type t :: %__MODULE__{ref: reference(), device: :cpu | :gpu} + + @doc """ + Allocates a new `CommandQueue` for `device`. + + Spawns a dedicated OS thread and a new `mlx::core::Stream` on it. Returns + `{:error, reason}` if the NIF cannot allocate the stream (e.g. `:gpu` on a + system without Metal). + """ + @spec new(:cpu | :gpu) :: {:ok, t()} | {:error, list()} + def new(device) do + case EMLX.NIF.command_queue_new(device) do + {:ok, ref} -> {:ok, %__MODULE__{ref: ref, device: device}} + {:error, _} = err -> err + end + end + + @doc """ + Like `new/1` but raises `EMLX.NIFError` on failure instead of returning + `{:error, reason}`. + + Useful in one-liner contexts (e.g. `__partitions_options__/1`) where the + caller has no reasonable recovery path if the device is unavailable. + """ + @spec new!(:cpu | :gpu) :: t() + def new!(device) do + case new(device) do + {:ok, q} -> q + {:error, reason} -> raise(EMLX.NIFError, List.to_string(reason)) + end + end + + @doc """ + Blocks the calling process until all previously enqueued jobs on `queue` + have finished and MLX has flushed its GPU command buffer. + + Internally posts a barrier job that calls `mlx::core::synchronize(stream)` + on the worker thread, then blocks in `receive` until the reply arrives. 
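+
+  A minimal sketch (assumes `EMLX.Backend` is the active Nx backend; the
+  tensor work shown is illustrative):
+
+      q = EMLX.CommandQueue.new!(:gpu)
+
+      EMLX.CommandQueue.with_queue(q, fn ->
+        Nx.add(Nx.iota({1024}), 1)
+      end)
+
+      # Blocks until the addition above has finished on the queue's stream.
+      :ok = EMLX.CommandQueue.synchronize(q)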
+ """ + @spec synchronize(t()) :: :ok + def synchronize(%__MODULE__{ref: ref}) do + case EMLX.NIF.command_queue_synchronize(ref) do + {:ok, job_ref} -> + receive do + {^job_ref, {:ok, _}} -> :ok + {^job_ref, {:error, reason}} -> raise(EMLX.NIFError, List.to_string(reason)) + end + + {:error, reason} -> + raise(EMLX.NIFError, List.to_string(reason)) + end + end + + @doc """ + Runs `fun` with `queue` bound to the calling process for the duration. + + Stores `{ref, device}` under `:emlx_command_queue` in the process + dictionary, then restores the previous value (or deletes the key) in an + `after` block. Nested calls are safe — each level saves and restores + independently. + """ + @spec with_queue(t(), (-> result)) :: result when result: term() + def with_queue(%__MODULE__{ref: ref, device: device}, fun) when is_function(fun, 0) do + previous = Process.get(:emlx_command_queue) + Process.put(:emlx_command_queue, {ref, device}) + + try do + fun.() + after + if previous do + Process.put(:emlx_command_queue, previous) + else + Process.delete(:emlx_command_queue) + end + end + end +end diff --git a/lib/emlx/nif.ex b/lib/emlx/nif.ex index f723202..552d106 100644 --- a/lib/emlx/nif.ex +++ b/lib/emlx/nif.ex @@ -17,11 +17,23 @@ defmodule EMLX.NIF do :erlang.load_nif(path, 0) end - def to_blob(_tensor) do + # Worker-routed NIF stubs. The first argument is always an + # EMLX.CommandQueue resource ref; the C++ wrapper (emlx_async.hpp) + # extracts it and posts the rest to the worker thread, returning a + # job ref for the caller to `receive` on. + def to_blob(_worker, _tensor) do :erlang.nif_error(:nif_not_loaded) end - def to_blob(_tensor, _limit) do + def to_blob(_worker, _tensor, _limit) do + :erlang.nif_error(:nif_not_loaded) + end + + def eval(_worker, _tensor_ref) do + :erlang.nif_error(:nif_not_loaded) + end + + def item(_worker, _tensor_ref) do :erlang.nif_error(:nif_not_loaded) end @@ -53,7 +65,7 @@ defmodule EMLX.NIF do :erlang.nif_error(:nif_not_loaded) end - def tensor_to_shm(_tensor, _permissions) do + def tensor_to_shm(_worker, _tensor, _permissions) do :erlang.nif_error(:nif_not_loaded) end @@ -64,4 +76,17 @@ defmodule EMLX.NIF do def shm_unlink_handle(_name) do :erlang.nif_error(:nif_not_loaded) end + + # ── Worker / EMLX.CommandQueue control NIFs ─────────────────────────────── + # `command_queue_post_eval` and `command_queue_post_to_blob` were folded + # into the generic async dispatch path; their public entry points are + # `eval/2` and `to_blob/{2,3}` above. + + def command_queue_new(_device) do + :erlang.nif_error(:nif_not_loaded) + end + + def command_queue_synchronize(_queue_ref) do + :erlang.nif_error(:nif_not_loaded) + end end diff --git a/mix.exs b/mix.exs index d8536f9..3b4576f 100644 --- a/mix.exs +++ b/mix.exs @@ -54,7 +54,8 @@ defmodule EMLX.MixProject do def application do [ - extra_applications: [:logger, :inets, :ssl, :public_key, :crypto] + extra_applications: [:logger, :inets, :ssl, :public_key, :crypto], + mod: {EMLX.Application, []} ] end diff --git a/test/emlx/command_queue_nif_test.exs b/test/emlx/command_queue_nif_test.exs new file mode 100644 index 0000000..c526d67 --- /dev/null +++ b/test/emlx/command_queue_nif_test.exs @@ -0,0 +1,178 @@ +defmodule EMLX.CommandQueueNIFTest do + # Unit tests for the worker / command_queue NIFs introduced in + # clean-room-import/01-worker-thread-dispatch.md. These exercise the + # C++ primitives directly via EMLX.NIF before the Elixir-side + # EMLX.CommandQueue / EMLX.Application wrappers are exposed. 
+ # + # async: true is fine — every test allocates its own worker (no + # shared global state). + use ExUnit.Case, async: true + + @moduletag :metal + + alias EMLX.NIF + + defp ok!({:ok, value}), do: value + defp ok!(:ok), do: :ok + + describe "command_queue_new/1" do + test "returns a reference for :gpu" do + ref = NIF.command_queue_new(:gpu) |> ok!() + assert is_reference(ref) + end + + test "returns a reference for :cpu" do + ref = NIF.command_queue_new(:cpu) |> ok!() + assert is_reference(ref) + end + + test "rejects unknown device atoms" do + assert {:error, _} = NIF.command_queue_new(:tpu) + end + + test "two queues yield distinct refs" do + a = NIF.command_queue_new(:gpu) |> ok!() + b = NIF.command_queue_new(:gpu) |> ok!() + assert a != b + end + end + + describe "command_queue_synchronize/1" do + test "round-trip on an idle queue is prompt" do + queue = NIF.command_queue_new(:gpu) |> ok!() + job_ref = NIF.command_queue_synchronize(queue) |> ok!() + assert is_reference(job_ref) + + assert_receive {^job_ref, {:ok, :ok}}, 1_000 + end + + test "sequential synchronize calls each get their own reply" do + queue = NIF.command_queue_new(:gpu) |> ok!() + + refs = + for _ <- 1..5 do + NIF.command_queue_synchronize(queue) |> ok!() + end + + for ref <- refs do + assert_receive {^ref, {:ok, :ok}}, 1_000 + end + end + + test "rejects an invalid queue ref" do + bogus = make_ref() + assert {:error, _} = NIF.command_queue_synchronize(bogus) + end + end + + # All graph-touching NIFs are now async-routed via emlx::async_dispatch. + # The smoke tests below validate the routing for a representative few: + # eval, to_blob, ones (creation), add (binary), reshape (manipulation). + # The full Elixir API is exercised by the rest of the suite. + describe "graph-touching NIFs route through worker (smoke)" do + test "ones/4 + eval/2 on a private queue" do + queue = NIF.command_queue_new(:gpu) |> ok!() + + tensor_ref = NIF.ones(queue, {4, 4}, :float32, :gpu) |> ok!() + assert is_reference(tensor_ref) + + assert_receive {^tensor_ref, {:ok, ref}}, 5_000 + assert is_reference(ref) + + eval_ref = NIF.eval(queue, ref) |> ok!() + assert_receive {^eval_ref, :ok}, 5_000 + end + + test "add/4 produces a tensor whose to_blob/2 matches Nx" do + queue = NIF.command_queue_new(:gpu) |> ok!() + + a_job = NIF.ones(queue, {4}, :float32, :gpu) |> ok!() + assert_receive {^a_job, {:ok, a}} + b_job = NIF.ones(queue, {4}, :float32, :gpu) |> ok!() + assert_receive {^b_job, {:ok, b}} + + sum_job = NIF.add(queue, a, b, :gpu) |> ok!() + assert_receive {^sum_job, {:ok, sum_ref}}, 5_000 + + eval_job = NIF.eval(queue, sum_ref) |> ok!() + assert_receive {^eval_job, :ok}, 5_000 + + blob_job = NIF.to_blob(queue, sum_ref) |> ok!() + assert_receive {^blob_job, {:ok, blob}}, 5_000 + + reference = Nx.broadcast(2.0, {4}) |> Nx.as_type(:f32) |> Nx.to_binary() + assert blob == reference + end + + test "rejects an invalid worker ref" do + bogus = make_ref() + assert {:error, _} = NIF.eval(bogus, make_ref()) + end + end + + describe "ordering and concurrency" do + test "jobs posted in order to one queue reply in FIFO order" do + queue = NIF.command_queue_new(:gpu) |> ok!() + + refs = + for _ <- 1..20 do + NIF.command_queue_synchronize(queue) |> ok!() + end + + # The worker is single-threaded; replies are emitted in the + # order jobs ran, which equals the order they were posted. 
+ received = + for _ <- refs do + receive do + {ref, {:ok, :ok}} -> ref + after + 2_000 -> flunk("Timed out awaiting worker reply") + end + end + + assert received == refs + end + + test "two queues handle interleaved load without crashing" do + # Each task owns its own queue and uses ONLY that queue for both + # graph construction and eval — this is the critical invariant in + # MLX 0.31.2 (encoders are thread-local). Sharing a lazy tensor + # across queues would cross threads and crash. + run_on_own_queue = fn -> + q = NIF.command_queue_new(:gpu) |> ok!() + + a_job = NIF.ones(q, {64, 64}, :float32, :gpu) |> ok!() + assert_receive {^a_job, {:ok, a}}, 5_000 + b_job = NIF.ones(q, {64, 64}, :float32, :gpu) |> ok!() + assert_receive {^b_job, {:ok, b}}, 5_000 + + prod_job = NIF.add(q, a, b, :gpu) |> ok!() + assert_receive {^prod_job, {:ok, prod}}, 5_000 + + eval_ref = NIF.eval(q, prod) |> ok!() + + receive do + {^eval_ref, :ok} -> :ok + after + 5_000 -> flunk("Timed out awaiting eval") + end + end + + tasks = for _ <- 1..6, do: Task.async(run_on_own_queue) + + for task <- tasks do + assert :ok = Task.await(task, 10_000) + end + end + end + + describe "garbage collection" do + test "dropping the queue ref triggers thread join without crashing" do + _queue = NIF.command_queue_new(:gpu) |> ok!() + :erlang.garbage_collect(self()) + # If the destructor double-frees or fails to join, BEAM crashes + # before this assertion runs. + assert true + end + end +end diff --git a/test/emlx/command_queue_test.exs b/test/emlx/command_queue_test.exs new file mode 100644 index 0000000..b0740c9 --- /dev/null +++ b/test/emlx/command_queue_test.exs @@ -0,0 +1,254 @@ +defmodule EMLX.CommandQueueTest do + # Tests for the EMLX.CommandQueue struct + with_queue/2 wrapper and the + # resolve_worker/1 dispatch logic (process-dict binding, cross-device + # promotion, and optional Logger.warning). + # + # async: false because several tests temporarily mutate Application env or + # the process dictionary in ways that could interfere if run concurrently. 
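+  #
+  # Dispatch rules these tests pin down, roughly (a sketch only; the real
+  # EMLX.resolve_worker/1 may be structured differently):
+  #
+  #     case Process.get(:emlx_command_queue) do
+  #       nil ->
+  #         {EMLX.Application.default_worker(device), device}
+  #
+  #       {ref, ^device} ->
+  #         {ref, device}
+  #
+  #       {ref, bound_device} ->
+  #         if Application.get_env(:emlx, :cross_device_promotion, false),
+  #           do: {ref, bound_device},
+  #           else: {EMLX.Application.default_worker(device), device}
+  #     end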
+ use ExUnit.Case, async: false + + import ExUnit.CaptureLog + + alias EMLX.CommandQueue + + # ── CommandQueue.new/1 ───────────────────────────────────────────────────── + + describe "CommandQueue.new/1" do + @tag :metal + test "returns a struct with a reference and the requested device for :gpu" do + assert {:ok, %CommandQueue{ref: ref, device: :gpu}} = CommandQueue.new(:gpu) + assert is_reference(ref) + end + + test "returns a struct with a reference and the requested device for :cpu" do + assert {:ok, %CommandQueue{ref: ref, device: :cpu}} = CommandQueue.new(:cpu) + assert is_reference(ref) + end + + test "returns {:error, _} for unknown device" do + assert {:error, _} = CommandQueue.new(:tpu) + end + + @tag :metal + test "two queues for the same device yield distinct refs" do + {:ok, a} = CommandQueue.new(:gpu) + {:ok, b} = CommandQueue.new(:gpu) + assert a.ref != b.ref + end + end + + # ── CommandQueue.new!/1 ──────────────────────────────────────────────────── + + describe "CommandQueue.new!/1" do + @tag :metal + test "returns the struct directly for a valid device" do + assert %CommandQueue{ref: ref, device: :gpu} = CommandQueue.new!(:gpu) + assert is_reference(ref) + end + + test "raises EMLX.NIFError for an unknown device" do + assert_raise EMLX.NIFError, fn -> CommandQueue.new!(:tpu) end + end + end + + # ── CommandQueue.synchronize/1 ───────────────────────────────────────────── + + describe "CommandQueue.synchronize/1" do + test "returns :ok on a freshly created CPU queue" do + q = CommandQueue.new!(:cpu) + assert :ok = CommandQueue.synchronize(q) + end + + @tag :metal + test "returns :ok on a freshly created GPU queue" do + q = CommandQueue.new!(:gpu) + assert :ok = CommandQueue.synchronize(q) + end + + @tag :metal + test "returns :ok after enqueuing work via with_queue" do + q = CommandQueue.new!(:gpu) + + CommandQueue.with_queue(q, fn -> + Nx.add( + Nx.tensor([1, 2, 3], backend: EMLX.Backend), + Nx.tensor([4, 5, 6], backend: EMLX.Backend) + ) + end) + + assert :ok = CommandQueue.synchronize(q) + end + end + + # ── CommandQueue.with_queue/2 ────────────────────────────────────────────── + + describe "CommandQueue.with_queue/2" do + @tag :metal + test "sets :emlx_command_queue to {ref, device} for the duration" do + {:ok, q} = CommandQueue.new(:gpu) + assert Process.get(:emlx_command_queue) == nil + + CommandQueue.with_queue(q, fn -> + assert Process.get(:emlx_command_queue) == {q.ref, :gpu} + end) + end + + @tag :metal + test "restores nil after the block" do + {:ok, q} = CommandQueue.new(:gpu) + CommandQueue.with_queue(q, fn -> :ok end) + assert Process.get(:emlx_command_queue) == nil + end + + @tag :metal + test "restores the previous value when nested" do + {:ok, outer} = CommandQueue.new(:gpu) + {:ok, inner} = CommandQueue.new(:cpu) + + CommandQueue.with_queue(outer, fn -> + assert Process.get(:emlx_command_queue) == {outer.ref, :gpu} + + CommandQueue.with_queue(inner, fn -> + assert Process.get(:emlx_command_queue) == {inner.ref, :cpu} + end) + + assert Process.get(:emlx_command_queue) == {outer.ref, :gpu} + end) + + assert Process.get(:emlx_command_queue) == nil + end + + @tag :metal + test "restores previous value even when the block raises" do + {:ok, q} = CommandQueue.new(:gpu) + + assert_raise RuntimeError, fn -> + CommandQueue.with_queue(q, fn -> raise "boom" end) + end + + assert Process.get(:emlx_command_queue) == nil + end + end + + # ── EMLX.resolve_worker/1 ────────────────────────────────────────────────── + + describe "resolve_worker/1 with no bound queue" 
do + @tag :metal + test "returns the application-default worker for :gpu" do + Process.delete(:emlx_command_queue) + {worker, device} = EMLX.resolve_worker(:gpu) + assert is_reference(worker) + assert device == :gpu + end + + test "returns the application-default worker for :cpu" do + Process.delete(:emlx_command_queue) + {worker, device} = EMLX.resolve_worker(:cpu) + assert is_reference(worker) + assert device == :cpu + end + end + + describe "resolve_worker/1 with a matching bound queue" do + @tag :metal + test "returns the bound queue worker and the requested device" do + {:ok, q} = CommandQueue.new(:gpu) + + CommandQueue.with_queue(q, fn -> + {worker, device} = EMLX.resolve_worker(:gpu) + assert worker == q.ref + assert device == :gpu + end) + end + + test "CPU queue is returned for a CPU tensor" do + {:ok, q} = CommandQueue.new(:cpu) + + CommandQueue.with_queue(q, fn -> + {worker, device} = EMLX.resolve_worker(:cpu) + assert worker == q.ref + assert device == :cpu + end) + end + end + + # ── Cross-device promotion ───────────────────────────────────────────────── + + describe "resolve_worker/1 cross-device promotion disabled (default)" do + setup do + Application.delete_env(:emlx, :cross_device_promotion) + Application.delete_env(:emlx, :warn_cross_device) + + on_exit(fn -> + Application.delete_env(:emlx, :cross_device_promotion) + Application.delete_env(:emlx, :warn_cross_device) + end) + end + + @tag :metal + test "returns the app-default CPU worker when GPU queue is bound but :cpu is requested" do + {:ok, gpu_q} = CommandQueue.new(:gpu) + app_cpu_worker = EMLX.Application.default_worker(:cpu) + + CommandQueue.with_queue(gpu_q, fn -> + {worker, device} = EMLX.resolve_worker(:cpu) + assert worker == app_cpu_worker + assert device == :cpu + end) + end + end + + describe "resolve_worker/1 cross-device promotion enabled" do + setup do + Application.put_env(:emlx, :cross_device_promotion, true) + Application.delete_env(:emlx, :warn_cross_device) + + on_exit(fn -> + Application.delete_env(:emlx, :cross_device_promotion) + Application.delete_env(:emlx, :warn_cross_device) + end) + end + + @tag :metal + test "routes a :cpu request through the bound GPU queue" do + {:ok, gpu_q} = CommandQueue.new(:gpu) + + CommandQueue.with_queue(gpu_q, fn -> + {worker, device} = EMLX.resolve_worker(:cpu) + assert worker == gpu_q.ref + assert device == :gpu + end) + end + + @tag :metal + test "does not emit a warning when warn_cross_device is false" do + {:ok, gpu_q} = CommandQueue.new(:gpu) + + log = + capture_log(fn -> + CommandQueue.with_queue(gpu_q, fn -> + EMLX.resolve_worker(:cpu) + end) + end) + + refute log =~ "[EMLX] cross-device promotion" + end + + @tag :metal + test "emits a warning when warn_cross_device is true" do + Application.put_env(:emlx, :warn_cross_device, true) + {:ok, gpu_q} = CommandQueue.new(:gpu) + + log = + capture_log(fn -> + CommandQueue.with_queue(gpu_q, fn -> + EMLX.resolve_worker(:cpu) + end) + end) + + assert log =~ "[EMLX] cross-device promotion" + assert log =~ "cpu" + assert log =~ "gpu" + end + end +end diff --git a/test/emlx/compiler_test.exs b/test/emlx/compiler_test.exs new file mode 100644 index 0000000..bd4ab49 --- /dev/null +++ b/test/emlx/compiler_test.exs @@ -0,0 +1,175 @@ +defmodule EMLX.CompilerTest do + # Tests for EMLX's Nx.Defn.Compiler callbacks: + # __jit__/5, __compile__/4, __partitions_options__/1, __to_backend__/1 + # + # async: false because tests inspect process-dict state set by CommandQueue + # and check Logger output — both are process-global. 
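+  #
+  # Queue wrapping expected by the tests below (a sketch; the compiler's
+  # actual wrapping code may differ): when :command_queue is given, the
+  # jitted/compiled closure should behave like
+  #
+  #     fn args ->
+  #       EMLX.CommandQueue.with_queue(queue, fn -> original.(args) end)
+  #     end
+  #
+  # so the queue is bound only while the closure runs, which is why the
+  # tests assert Process.get(:emlx_command_queue) is nil both before and
+  # after invocation.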
+ use ExUnit.Case, async: false + + alias EMLX.CommandQueue + + describe "__to_backend__/1" do + test "defaults to EMLX.Backend on :gpu" do + assert {EMLX.Backend, device: :gpu} = EMLX.__to_backend__([]) + end + + test "honours an explicit :device opt" do + assert {EMLX.Backend, device: :cpu} = EMLX.__to_backend__(device: :cpu) + end + end + + describe "__partitions_options__/1" do + @tag :metal + test "returns a single-element list by default" do + partitions = EMLX.__partitions_options__([]) + assert length(partitions) == 1 + end + + @tag :metal + test "default partition has a :device key and a :command_queue struct" do + [opts] = EMLX.__partitions_options__([]) + assert Keyword.get(opts, :device) == :gpu + assert %CommandQueue{device: :gpu} = Keyword.get(opts, :command_queue) + end + + test "honours :device opt" do + [opts] = EMLX.__partitions_options__(device: :cpu) + assert Keyword.get(opts, :device) == :cpu + assert %CommandQueue{device: :cpu} = Keyword.get(opts, :command_queue) + end + + @tag :metal + test "max_concurrency: 2 returns two partitions" do + partitions = EMLX.__partitions_options__(max_concurrency: 2) + assert length(partitions) == 2 + end + + @tag :metal + test "partitions for max_concurrency: 2 have distinct queue refs" do + [opts1, opts2] = EMLX.__partitions_options__(max_concurrency: 2) + q1 = Keyword.get(opts1, :command_queue) + q2 = Keyword.get(opts2, :command_queue) + assert %CommandQueue{} = q1 + assert %CommandQueue{} = q2 + assert q1.ref != q2.ref + end + end + + describe "option validation" do + test "unknown opt in __compile__ raises ArgumentError" do + defmodule IdentFn do + import Nx.Defn + defn ident(x), do: x + end + + assert_raise ArgumentError, ~r/unknown_opt/, fn -> + Nx.Defn.compile(&IdentFn.ident/1, [Nx.template({}, :f32)], + compiler: EMLX, + unknown_opt: true + ) + end + end + + test "valid opts :device and :max_concurrency do not raise" do + defmodule IdentFn2 do + import Nx.Defn + defn ident(x), do: x + end + + compiled = + Nx.Defn.compile(&IdentFn2.ident/1, [Nx.template({}, :f32)], + compiler: EMLX, + device: :gpu, + max_concurrency: 1 + ) + + result = compiled.(Nx.tensor(1.0, backend: EMLX.Backend)) + assert_in_delta Nx.to_number(result), 1.0, 1.0e-6 + end + end + + describe "__jit__/5 queue wrapping" do + test "without :command_queue, the jit closure is returned as-is" do + defmodule SumJitFn do + import Nx.Defn + + defn add(a, b), do: Nx.add(a, b) + end + + jitted = Nx.Defn.jit(&SumJitFn.add/2, compiler: EMLX) + + result = + jitted.(Nx.tensor(1.0, backend: EMLX.Backend), Nx.tensor(2.0, backend: EMLX.Backend)) + + assert_in_delta Nx.to_number(result), 3.0, 1.0e-6 + end + + @tag :metal + test "with :command_queue, the jit closure installs the queue during execution" do + q = CommandQueue.new!(:gpu) + + defmodule AddJitFn do + import Nx.Defn + + defn add(a, b), do: Nx.add(a, b) + end + + jitted = Nx.Defn.jit(&AddJitFn.add/2, compiler: EMLX, command_queue: q) + + assert Process.get(:emlx_command_queue) == nil + + result = + jitted.(Nx.tensor(1.0, backend: EMLX.Backend), Nx.tensor(2.0, backend: EMLX.Backend)) + + assert Process.get(:emlx_command_queue) == nil + assert_in_delta Nx.to_number(result), 3.0, 1.0e-6 + end + end + + describe "__compile__/4 queue wrapping" do + test "without :command_queue, the compiled closure is returned as-is" do + # Compile a minimal defn function and verify the closure can be called. 
+ defmodule SumFn do + import Nx.Defn + + defn add(a, b), do: Nx.add(a, b) + end + + compiled = + Nx.Defn.compile(&SumFn.add/2, [Nx.template({}, :f32), Nx.template({}, :f32)], + compiler: EMLX + ) + + result = + compiled.(Nx.tensor(1.0, backend: EMLX.Backend), Nx.tensor(2.0, backend: EMLX.Backend)) + + assert_in_delta Nx.to_number(result), 3.0, 1.0e-6 + end + + @tag :metal + test "with :command_queue, the compiled closure installs the queue during execution" do + q = CommandQueue.new!(:gpu) + + defmodule AddFn do + import Nx.Defn + + defn add(a, b), do: Nx.add(a, b) + end + + compiled = + Nx.Defn.compile(&AddFn.add/2, [Nx.template({}, :f32), Nx.template({}, :f32)], + compiler: EMLX, + command_queue: q + ) + + # Call the wrapped closure; the queue should be active during execution. + assert Process.get(:emlx_command_queue) == nil + + result = + compiled.(Nx.tensor(1.0, backend: EMLX.Backend), Nx.tensor(2.0, backend: EMLX.Backend)) + + assert Process.get(:emlx_command_queue) == nil + assert_in_delta Nx.to_number(result), 3.0, 1.0e-6 + end + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs index 95d2f6a..50a69bb 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -43,4 +43,10 @@ distributed_exclude = [:distributed] end -ExUnit.start(exclude: distributed_exclude) +gpu_exclude = + case EMLX.NIF.command_queue_new(:gpu) do + {:ok, _} -> [] + {:error, _} -> [:metal] + end + +ExUnit.start(exclude: distributed_exclude ++ gpu_exclude) diff --git a/validation/test/distilbert_test.exs b/validation/test/distilbert_test.exs index 95b66e1..3b476cd 100644 --- a/validation/test/distilbert_test.exs +++ b/validation/test/distilbert_test.exs @@ -14,7 +14,7 @@ defmodule EMLX.Validation.DistilBERTTest do mix test test/distilbert_test.exs --only validation """ - use EMLX.ValidationCase, async: false + use EMLX.ValidationCase, async: true @moduletag :validation @moduletag capture_log: true diff --git a/validation/test/llama_test.exs b/validation/test/llama_test.exs index a4c5709..b75b438 100644 --- a/validation/test/llama_test.exs +++ b/validation/test/llama_test.exs @@ -17,7 +17,7 @@ defmodule EMLX.Validation.LlamaTest do mix test test/llama_test.exs --only validation """ - use EMLX.ValidationCase, async: false + use EMLX.ValidationCase, async: true @moduletag :validation @moduletag capture_log: true diff --git a/validation/test/support/gpu_pool.ex b/validation/test/support/gpu_pool.ex deleted file mode 100644 index c4af2c8..0000000 --- a/validation/test/support/gpu_pool.ex +++ /dev/null @@ -1,58 +0,0 @@ -defmodule EMLX.GPUPool do - @moduledoc """ - Single-slot semaphore for GPU access during the validation test suite. - - MLX streams are thread-local; only one test should hold an active GPU - context at a time. Call `checkout/0` at the start of a test and - `checkin/0` when done (or use the `on_exit` hook in `ValidationCase`). - - Waiting callers block on `checkout/0` with no timeout — the test runner's - own module timeout is the effective deadline. - """ - - use GenServer - - def start_link(_opts \\ []) do - GenServer.start_link(__MODULE__, :queue.new(), name: __MODULE__) - end - - @doc "Acquire the GPU slot, blocking until available." - def checkout do - GenServer.call(__MODULE__, :checkout, :infinity) - end - - @doc "Release the GPU slot." 
- def checkin do - GenServer.cast(__MODULE__, :checkin) - end - - # -- GenServer callbacks --------------------------------------------------- - - # State: `{locked?, pending_callers_queue}` - - @impl true - def init(_queue) do - {:ok, {false, :queue.new()}} - end - - @impl true - def handle_call(:checkout, _from, {false, queue}) do - {:reply, :ok, {true, queue}} - end - - def handle_call(:checkout, from, {true, queue}) do - {:noreply, {true, :queue.in(from, queue)}} - end - - @impl true - def handle_cast(:checkin, {_, queue}) do - case :queue.out(queue) do - {{:value, next}, rest} -> - GenServer.reply(next, :ok) - {:noreply, {true, rest}} - - {:empty, _} -> - {:noreply, {false, :queue.new()}} - end - end -end diff --git a/validation/test/support/validation_case.ex b/validation/test/support/validation_case.ex index 20d4da9..c086761 100644 --- a/validation/test/support/validation_case.ex +++ b/validation/test/support/validation_case.ex @@ -8,8 +8,6 @@ defmodule EMLX.ValidationCase do end setup do - EMLX.GPUPool.checkout() - on_exit(fn -> EMLX.GPUPool.checkin() end) Nx.default_backend({EMLX.Backend, device: :gpu}) :ok end diff --git a/validation/test/test_helper.exs b/validation/test/test_helper.exs index 69c091e..869559e 100644 --- a/validation/test/test_helper.exs +++ b/validation/test/test_helper.exs @@ -1,3 +1 @@ ExUnit.start() - -{:ok, _} = EMLX.GPUPool.start_link() diff --git a/validation/test/vit_test.exs b/validation/test/vit_test.exs index ccea08c..aba0142 100644 --- a/validation/test/vit_test.exs +++ b/validation/test/vit_test.exs @@ -14,7 +14,7 @@ defmodule EMLX.Validation.ViTTest do mix test test/vit_test.exs --only validation """ - use EMLX.ValidationCase, async: false + use EMLX.ValidationCase, async: true @moduletag :validation @moduletag capture_log: true diff --git a/validation/test/whisper_test.exs b/validation/test/whisper_test.exs index 43882f5..68990b7 100644 --- a/validation/test/whisper_test.exs +++ b/validation/test/whisper_test.exs @@ -15,7 +15,7 @@ defmodule EMLX.Validation.WhisperTest do mix test test/whisper_test.exs --only validation """ - use EMLX.ValidationCase, async: false + use EMLX.ValidationCase, async: true @moduletag :validation @moduletag capture_log: true