Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib
XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB)
EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)

.DEFAULT_GOAL := $(EXLA_SO)

# Build flags
#
# Note that XLA requires c++17, Fine as well
Expand Down Expand Up @@ -86,7 +88,21 @@ else
LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib'
endif

$(EXLA_SO): $(EXLA_CACHE_SO)
# Optional test dylib: registers qr_cpu_custom_call_f32_exla_alias -> same
# handler as qr_cpu_custom_call_f32. Built only when MIX_ENV=test.
TEST_PLUGIN_CC = c_src/exla_test/custom_calls.cc
TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias.so

# Order-only prerequisite (after `|`): the unpacked xla_extension tree must
# exist before compiling, but a newer timestamp on it must not force a
# rebuild of the plugin.
# NOTE(review): compiles C++ with $(CFLAGS) rather than a CXXFLAGS variable —
# presumably matching the rest of this Makefile; confirm -std=c++17 is set.
$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
	@ mkdir -p $(dir $@)
	$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)

# The test plugin becomes a dependency of the main NIF only in the test
# environment, so production builds never compile or ship it.
EXLA_SO_DEPS = $(EXLA_CACHE_SO)
ifeq ($(MIX_ENV),test)
EXLA_SO_DEPS += $(TEST_PLUGIN_SO)
endif

$(EXLA_SO): $(EXLA_SO_DEPS)
@ mkdir -p $(PRIV_DIR)
@ mkdir -p $(PRIV_DIR)/xla_extension
@ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \
Expand Down
163 changes: 163 additions & 0 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <dlfcn.h>

#include <cstring>
#include <fine.hpp>
#include <stdexcept>
Expand Down Expand Up @@ -29,6 +31,12 @@
#include "xla/tsl/platform/statusor.h"
#include "llvm/Support/ThreadPool.h"

#include <vector>

#include "xla/extension/custom_calls/eigh.h"
#include "xla/extension/custom_calls/qr.h"
#include "xla/ffi/ffi_api.h"

namespace exla {

using callback_bridge::Pending;
Expand Down Expand Up @@ -535,6 +543,19 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type,

FINE_NIF(load_pjrt_plugin, 0);

// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run.
//
// The handle is intentionally never dlclose()d: the library must stay mapped
// for the lifetime of the VM so that any FFI handlers it registered at load
// time remain callable.
//
// Throws std::invalid_argument (surfaced to Elixir by Fine) with the
// dlerror() message when the library cannot be loaded.
fine::Ok<> load_dylib(ErlNifEnv *env, std::string path) {
  if (dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL) == nullptr) {
    // dlerror() may return NULL if another thread consumed the error.
    const char *err = dlerror();
    throw std::invalid_argument(err ? err : "dlopen failed");
  }
  return fine::Ok();
}

FINE_NIF(load_dylib, 0);

int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client) {
return client->client()->device_count();
}
Expand Down Expand Up @@ -715,4 +736,146 @@ FINE_NIF(write_to_pointer, 0);

} // namespace exla

// CPU custom-call FFI handlers (QR / Eigh) for integer operands and f32
// results. These stay in EXLA until the same symbols are provided from the
// shared elixir-nx/xla build; not test-only.
namespace {

namespace ffi = xla::ffi;

// Shared implementation behind the qr_cpu_custom_call_{s,u}{8,16,32,64}
// targets: widens the integer operand to f32 one batch item at a time and
// delegates to single_matrix_qr_cpu_custom_call<float> (from
// xla/extension/custom_calls/qr.h).
//
// Matrix shapes are read from the *result* buffers:
//   operand: [...batch, m, n]  (integer dtype kIntDtype)
//   q:       [...batch, m, k]  (f32)
//   r:       [...batch, l, n]  (f32); l == m selects "complete" QR
//
// NOTE(review): assumes operand rank >= 2 (operand_dims.end() - 2 would be
// invalid otherwise) — presumably guaranteed by the emitting side; confirm.
template <ffi::DataType kIntDtype>
ffi::Error QrCpuCustomCallIntegerOperandF32ResultsImpl(
    ffi::Buffer<kIntDtype> operand, ffi::ResultBuffer<ffi::F32> q,
    ffi::ResultBuffer<ffi::F32> r) {
  using IntT = ffi::NativeType<kIntDtype>;
  auto operand_dims = operand.dimensions();
  auto q_dims = q->dimensions();
  auto r_dims = r->dimensions();

  // Trailing two dimensions give the per-item matrix shapes.
  uint64_t m = q_dims[q_dims.size() - 2];
  uint64_t k = q_dims[q_dims.size() - 1];
  uint64_t n = r_dims[r_dims.size() - 1];
  uint64_t l = r_dims[r_dims.size() - 2];

  bool complete = l == m;

  // Collapse all leading (batch) dimensions into one flat count.
  uint64_t batch_items = 1;
  for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
    batch_items *= static_cast<uint64_t>(*it);
  }

  // Per-item element counts for each buffer.
  uint64_t q_stride = m * k;
  uint64_t r_stride = n * l;
  uint64_t inner_stride = m * n;

  // Staging buffer, reused across batch items, holding one f32 operand matrix.
  std::vector<float> tmp(inner_stride);
  const IntT *in_base = operand.typed_data();
  float *q_base = reinterpret_cast<float *>(q->untyped_data());
  float *r_base = reinterpret_cast<float *>(r->untyped_data());

  for (uint64_t b = 0; b < batch_items; b++) {
    const IntT *in = in_base + b * inner_stride;
    // Widen the integer inputs before running the float kernel.
    for (uint64_t j = 0; j < inner_stride; j++) {
      tmp[j] = static_cast<float>(in[j]);
    }
    single_matrix_qr_cpu_custom_call<float>(
        q_base + b * q_stride, r_base + b * r_stride, tmp.data(), m, k, n,
        complete);
  }

  return ffi::Error::Success();
}

// Defines a static FFI handler NAME wrapping the integer->f32 QR
// implementation above for the given ffi::DataType, and registers it on the
// "Host" platform under the string #NAME. (No comments inside the macro
// body: line continuations make them unsafe.)
#define EXLA_REGISTER_QR_INT_F32(DTYPE, NAME)                                 \
  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,              \
                                ffi::ResultBuffer<ffi::F32> q,                \
                                ffi::ResultBuffer<ffi::F32> r) {              \
    return QrCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(operand,   \
                                                                   q, r);     \
  }                                                                           \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                            \
                                ffi::Ffi::Bind()                              \
                                    .Arg<ffi::Buffer<ffi::DTYPE>>()           \
                                    .Ret<ffi::Buffer<ffi::F32>>()             \
                                    .Ret<ffi::Buffer<ffi::F32>>());           \
  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);

// One handler per signed/unsigned integer width; names must match the
// call_target_name values emitted on the Elixir side.
EXLA_REGISTER_QR_INT_F32(S8, qr_cpu_custom_call_s8)
EXLA_REGISTER_QR_INT_F32(S16, qr_cpu_custom_call_s16)
EXLA_REGISTER_QR_INT_F32(S32, qr_cpu_custom_call_s32)
EXLA_REGISTER_QR_INT_F32(S64, qr_cpu_custom_call_s64)
EXLA_REGISTER_QR_INT_F32(U8, qr_cpu_custom_call_u8)
EXLA_REGISTER_QR_INT_F32(U16, qr_cpu_custom_call_u16)
EXLA_REGISTER_QR_INT_F32(U32, qr_cpu_custom_call_u32)
EXLA_REGISTER_QR_INT_F32(U64, qr_cpu_custom_call_u64)

#undef EXLA_REGISTER_QR_INT_F32

// Shared implementation behind the eigh_cpu_custom_call_{s,u}{8,16,32,64}
// targets; same widen-then-delegate scheme as the QR handler, delegating to
// single_matrix_eigh_cpu_custom_call<float> (xla/extension/custom_calls/eigh.h).
//
// Shapes, read from the result buffers:
//   operand:      [...batch, m, n]  (integer dtype kIntDtype)
//   eigenvalues:  [...batch, v]     (f32; v is the buffer's own last dim)
//   eigenvectors: [...batch, m, n]  (f32)
template <ffi::DataType kIntDtype>
ffi::Error EighCpuCustomCallIntegerOperandF32ResultsImpl(
    ffi::Buffer<kIntDtype> operand,
    ffi::ResultBuffer<ffi::F32> eigenvalues,
    ffi::ResultBuffer<ffi::F32> eigenvectors) {
  using IntT = ffi::NativeType<kIntDtype>;
  auto operand_dims = operand.dimensions();
  auto eigenvalues_dims = eigenvalues->dimensions();
  auto eigenvectors_dims = eigenvectors->dimensions();

  uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
  uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

  // Collapse all leading (batch) dimensions into one flat count.
  uint64_t batch_items = 1;
  for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
    batch_items *= static_cast<uint64_t>(*it);
  }

  uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
  uint64_t eigenvectors_stride = m * n;
  uint64_t inner_stride = m * n;

  // Staging buffer, reused across batch items, holding one f32 operand matrix.
  std::vector<float> tmp(inner_stride);
  const IntT *in_base = operand.typed_data();
  float *eval_base = reinterpret_cast<float *>(eigenvalues->untyped_data());
  float *evec_base = reinterpret_cast<float *>(eigenvectors->untyped_data());

  for (uint64_t b = 0; b < batch_items; b++) {
    const IntT *in = in_base + b * inner_stride;
    for (uint64_t j = 0; j < inner_stride; j++) {
      tmp[j] = static_cast<float>(in[j]);
    }
    single_matrix_eigh_cpu_custom_call<float>(
        eval_base + b * eigenvalues_stride, evec_base + b * eigenvectors_stride,
        tmp.data(), m, n);
  }

  return ffi::Error::Success();
}

// Same define-and-register scheme as EXLA_REGISTER_QR_INT_F32, for Eigh.
#define EXLA_REGISTER_EIGH_INT_F32(DTYPE, NAME)                               \
  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,              \
                                ffi::ResultBuffer<ffi::F32> eigenvalues,      \
                                ffi::ResultBuffer<ffi::F32> eigenvectors) {   \
    return EighCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(         \
        operand, eigenvalues, eigenvectors);                                  \
  }                                                                           \
  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                            \
                                ffi::Ffi::Bind()                              \
                                    .Arg<ffi::Buffer<ffi::DTYPE>>()           \
                                    .Ret<ffi::Buffer<ffi::F32>>()             \
                                    .Ret<ffi::Buffer<ffi::F32>>());           \
  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);

EXLA_REGISTER_EIGH_INT_F32(S8, eigh_cpu_custom_call_s8)
EXLA_REGISTER_EIGH_INT_F32(S16, eigh_cpu_custom_call_s16)
EXLA_REGISTER_EIGH_INT_F32(S32, eigh_cpu_custom_call_s32)
EXLA_REGISTER_EIGH_INT_F32(S64, eigh_cpu_custom_call_s64)
EXLA_REGISTER_EIGH_INT_F32(U8, eigh_cpu_custom_call_u8)
EXLA_REGISTER_EIGH_INT_F32(U16, eigh_cpu_custom_call_u16)
EXLA_REGISTER_EIGH_INT_F32(U32, eigh_cpu_custom_call_u32)
EXLA_REGISTER_EIGH_INT_F32(U64, eigh_cpu_custom_call_u64)

#undef EXLA_REGISTER_EIGH_INT_F32

}  // namespace

FINE_INIT("Elixir.EXLA.NIF");
15 changes: 15 additions & 0 deletions exla/c_src/exla_test/custom_calls.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Test-only shared library: registers an alias FFI name that reuses the
// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so.
//
// Loading this library (see load_dylib in exla.cc, which dlopen()s with
// RTLD_GLOBAL) runs the static registration below as a side effect.
#ifndef EXLA_PROD

#include "xla/ffi/api/api.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

// Handler symbol exported by libxla_extension.so; declared directly instead
// of including its header so this file depends only on the FFI API headers.
extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame);

// Register the existing handler under an alias target name on "Host".
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias",
                         "Host", qr_cpu_custom_call_f32);

#endif  // EXLA_PROD
7 changes: 7 additions & 0 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ defmodule EXLA do
* `:highest` - Slowest but most accurate. Performs computations in float32
or float64 as applicable

## Native custom calls (`EXLA.CustomCall`)

Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus
a registered native handler). Implement the `EXLA.CustomCall` protocol for
your block tag struct; see `EXLA.CustomCall` for the `call/4` contract,
including returning `:skip` to fall back to the block's default Elixir callback.

## Clients

The `EXLA` library uses a client for compiling and executing code.
Expand Down
150 changes: 150 additions & 0 deletions exla/lib/exla/custom_call.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
defprotocol EXLA.CustomCall do
  @moduledoc """
  Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls**
  (`stablehlo.custom_call` in MLIR), the same style as helpers on
  `EXLA.MLIR.Value` such as `qr/3` and `eigh/3`.

  Other blocks (for example gather-based `take` or FFT) are lowered inline in
  `EXLA.Defn` and do not use this protocol.

  ## When `EXLA.Defn` calls it

  During compilation with `compiler: EXLA`, when the builder is an MLIR
  `EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is
  passed here. `EXLA.Defn` invokes `call/4` once per block.

  If `call/4` returns `:skip`, EXLA compiles the block's **default callback**
  (the anonymous function body) instead of emitting a custom call.

  ## `call/4` arguments

  * `struct` — the **tag** passed as the first argument to `Nx.block/4`
    (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`).

  * `out` — the **output template** tuple passed to `Nx.block/4` (expression
    metadata for shapes and types, not runtime tensors).

  * `args` — list of **input templates**, in the same order as `inputs` in
    `Nx.block/4`.

  * `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate
    host-only lowerings).

  ## `call/4` return value

  * **`:skip`** — this implementation does not apply (unsupported type,
    non-host platform, wrong arity, etc.). The default block implementation
    is used instead.

  * **`{:ok, %EXLA.CustomCall.Spec{}}`** — emit a StableHLO custom call; see
    `EXLA.CustomCall.Spec` for `call_target_name`, optional `backend_config`,
    and optional `operand_element_types` (operand converts when they differ
    from the lowered inputs).

  ## Dispatch

  The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags
  live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency can
  add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen
  whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback.

  ## Native handlers

  Emitting a custom call in MLIR is only half of the story: the **target name**
  must be registered with XLA on the relevant platform (typically via a native
  library loaded into the process). That registration is **not** configured
  through `config :exla, ...`; you load or link the native code by the same
  means you would for any other NIF-backed extension.

  ## Example

      defmodule MyApp.CustomQrTag do
        defstruct []
      end

      defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do
        def call(_tag, {%{type: {kind, size}}, _r_expr}, [_input], %{platform: :host})
            when kind != :c and kind in [:f, :bf] and size in [16, 32, 64] do
          {:ok, %EXLA.CustomCall.Spec{call_target_name: "my_custom_qr_target"}}
        end

        def call(_, _, _, _), do: :skip
      end

  Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with
  `compiler: EXLA`.
  """

  @fallback_to_any true

  @doc """
  Lowers the `Nx.block/4` tag `struct` to an XLA custom call, or skips it.

  Returns `{:ok, %EXLA.CustomCall.Spec{}}` to emit a `stablehlo.custom_call`,
  or `:skip` to compile the block's default callback instead. See the module
  documentation for the meaning of `out`, `args`, and `client`.
  """
  def call(struct, out, args, client)
end

# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live
# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the
# protocol, applications and libraries can define their own
# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that
# implementation instead of this fallback when the block tag matches (you can
# also target a built-in struct such as `Nx.Block...` from your app if needed).
#
defimpl EXLA.CustomCall, for: Any do
  @moduledoc false

  alias EXLA.CustomCall.Spec

  # Integer input dtypes with native handlers for both QR and Eigh.
  @int_in_types for kind <- [:s, :u], size <- [8, 16, 32, 64], do: {kind, size}

  # QR additionally has f16/f32/f64/bf16 native handlers.
  @qr_in_types [{:f, 16}, {:f, 32}, {:f, 64}, {:bf, 16} | @int_in_types]

  # Eigh only has f32/f64 float handlers (no f16/bf16 kernels).
  @eigh_in_types [{:f, 32}, {:f, 64} | @int_in_types]

  def call(
        %Nx.Block.LinAlg.QR{},
        {%{type: {q_kind, _}}, _r_expr},
        [%{type: {in_kind, _} = in_type} | _],
        %{platform: :host}
      )
      when q_kind != :c and in_kind != :c do
    lower("qr", in_type, @qr_in_types)
  end

  # Native target names depend only on the input dtype; output templates may use
  # different element types (e.g. promotion) and must not change the call target.
  def call(
        %Nx.Block.LinAlg.Eigh{},
        _out,
        [%{type: {in_kind, _} = in_type} | _],
        %{platform: :host}
      )
      when in_kind != :c do
    lower("eigh", in_type, @eigh_in_types)
  end

  def call(_, _, _, _), do: :skip

  # Builds the "<op>_cpu_custom_call_<kind><size>" target name when the input
  # dtype has a registered native handler, otherwise skips the lowering.
  defp lower(op, {kind, size} = type, supported) do
    if type in supported do
      {:ok, %Spec{call_target_name: "#{op}_cpu_custom_call_#{kind}#{size}"}}
    else
      :skip
    end
  end
end
Loading
Loading