diff --git a/exla/Makefile b/exla/Makefile
index b875ef59f7..0183a3ea0e 100644
--- a/exla/Makefile
+++ b/exla/Makefile
@@ -21,6 +21,8 @@ EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib
 XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB)
 EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
 
+.DEFAULT_GOAL := $(EXLA_SO)
+
 # Build flags
 #
 # Note that XLA requires c++17, Fine as well
@@ -86,7 +88,21 @@ else
 LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib'
 endif
 
-$(EXLA_SO): $(EXLA_CACHE_SO)
+# Optional test dylib: registers qr_cpu_custom_call_f32_exla_alias -> same
+# handler as qr_cpu_custom_call_f32. Built only when MIX_ENV=test.
+TEST_PLUGIN_CC = c_src/exla_test/custom_calls.cc
+TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias.so
+
+$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
+	@ mkdir -p $(dir $@)
+	$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)
+
+EXLA_SO_DEPS = $(EXLA_CACHE_SO)
+ifeq ($(MIX_ENV),test)
+EXLA_SO_DEPS += $(TEST_PLUGIN_SO)
+endif
+
+$(EXLA_SO): $(EXLA_SO_DEPS)
 	@ mkdir -p $(PRIV_DIR)
 	@ mkdir -p $(PRIV_DIR)/xla_extension
 	@ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \
diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc
index c4e9085833..f6f3782d02 100644
--- a/exla/c_src/exla/exla.cc
+++ b/exla/c_src/exla/exla.cc
@@ -1,3 +1,5 @@
+#include <dlfcn.h>
+
 #include
 #include
 #include
@@ -29,6 +31,12 @@
 #include "xla/tsl/platform/statusor.h"
 #include "llvm/Support/ThreadPool.h"
 
+#include <vector>
+
+#include "xla/extension/custom_calls/eigh.h"
+#include "xla/extension/custom_calls/qr.h"
+#include "xla/ffi/ffi_api.h"
+
 namespace exla {
 
 using callback_bridge::Pending;
@@ -535,6 +543,19 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type,
 
 FINE_NIF(load_pjrt_plugin, 0);
 
+// Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run.
+fine::Ok<> load_dylib(ErlNifEnv *env, std::string path) {
+  void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);
+  if (handle == nullptr) {
+    const char *err = dlerror();
+    throw std::invalid_argument(err ? err : "dlopen failed");
+  }
+  // Intentionally never dlclose(3)d: the registrations must stay alive.
+  return fine::Ok();
+}
+
+FINE_NIF(load_dylib, 0);
+
 int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client) {
   return client->client()->device_count();
 }
@@ -715,4 +736,146 @@ FINE_NIF(write_to_pointer, 0);
 
 } // namespace exla
 
+// CPU custom-call FFI handlers (QR / Eigh) for integer operands and f32
+// results. These stay in EXLA until the same symbols are provided by the
+// shared elixir-nx/xla build; they are not test-only.
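+//
+// For reference, each EXLA_REGISTER_QR_INT_F32(S32, qr_cpu_custom_call_s32)
+// style line below expands to roughly the following (sketch, modulo exact
+// macro output):
+//
+//   XLA_FFI_DEFINE_HANDLER_SYMBOL(qr_cpu_custom_call_s32, /* impl fn */ ...,
+//                                 ffi::Ffi::Bind()
+//                                     .Arg<ffi::Buffer<ffi::S32>>()   // operand
+//                                     .Ret<ffi::Buffer<ffi::F32>>()   // q
+//                                     .Ret<ffi::Buffer<ffi::F32>>()); // r
+//   XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_s32",
+//                            "Host", qr_cpu_custom_call_s32);
+//
+// i.e. the string "qr_cpu_custom_call_s32" becomes resolvable as a
+// custom_call target on the Host platform once this object is loaded.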
+namespace {
+
+namespace ffi = xla::ffi;
+
+template <ffi::DataType DTYPE>
+ffi::Error QrCpuCustomCallIntegerOperandF32ResultsImpl(
+    ffi::Buffer<DTYPE> operand, ffi::ResultBuffer<ffi::F32> q,
+    ffi::ResultBuffer<ffi::F32> r) {
+  using IntT = ffi::NativeType<DTYPE>;
+  auto operand_dims = operand.dimensions();
+  auto q_dims = q->dimensions();
+  auto r_dims = r->dimensions();
+
+  uint64_t m = q_dims[q_dims.size() - 2];
+  uint64_t k = q_dims[q_dims.size() - 1];
+  uint64_t n = r_dims[r_dims.size() - 1];
+  uint64_t l = r_dims[r_dims.size() - 2];
+
+  bool complete = l == m;
+
+  uint64_t batch_items = 1;
+  for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
+    batch_items *= static_cast<uint64_t>(*it);
+  }
+
+  uint64_t q_stride = m * k;
+  uint64_t r_stride = n * l;
+  uint64_t inner_stride = m * n;
+
+  std::vector<float> tmp(inner_stride);
+  const IntT *in_base = operand.typed_data();
+  float *q_base = reinterpret_cast<float *>(q->untyped_data());
+  float *r_base = reinterpret_cast<float *>(r->untyped_data());
+
+  for (uint64_t b = 0; b < batch_items; b++) {
+    const IntT *in = in_base + b * inner_stride;
+    for (uint64_t j = 0; j < inner_stride; j++) {
+      tmp[j] = static_cast<float>(in[j]);
+    }
+    single_matrix_qr_cpu_custom_call(
+        q_base + b * q_stride, r_base + b * r_stride, tmp.data(), m, k, n,
+        complete);
+  }
+
+  return ffi::Error::Success();
+}
+
+#define EXLA_REGISTER_QR_INT_F32(DTYPE, NAME)                               \
+  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,            \
+                                ffi::ResultBuffer<ffi::F32> q,              \
+                                ffi::ResultBuffer<ffi::F32> r) {            \
+    return QrCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(operand, \
+                                                                   q, r);   \
+  }                                                                         \
+  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                          \
+                                ffi::Ffi::Bind()                            \
+                                    .Arg<ffi::Buffer<ffi::DTYPE>>()         \
+                                    .Ret<ffi::Buffer<ffi::F32>>()           \
+                                    .Ret<ffi::Buffer<ffi::F32>>());         \
+  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);
+
+EXLA_REGISTER_QR_INT_F32(S8, qr_cpu_custom_call_s8)
+EXLA_REGISTER_QR_INT_F32(S16, qr_cpu_custom_call_s16)
+EXLA_REGISTER_QR_INT_F32(S32, qr_cpu_custom_call_s32)
+EXLA_REGISTER_QR_INT_F32(S64, qr_cpu_custom_call_s64)
+EXLA_REGISTER_QR_INT_F32(U8, qr_cpu_custom_call_u8)
+EXLA_REGISTER_QR_INT_F32(U16, qr_cpu_custom_call_u16)
+EXLA_REGISTER_QR_INT_F32(U32, qr_cpu_custom_call_u32)
+EXLA_REGISTER_QR_INT_F32(U64, qr_cpu_custom_call_u64)
+
+#undef EXLA_REGISTER_QR_INT_F32
+
+template <ffi::DataType DTYPE>
+ffi::Error EighCpuCustomCallIntegerOperandF32ResultsImpl(
+    ffi::Buffer<DTYPE> operand,
+    ffi::ResultBuffer<ffi::F32> eigenvalues,
+    ffi::ResultBuffer<ffi::F32> eigenvectors) {
+  using IntT = ffi::NativeType<DTYPE>;
+  auto operand_dims = operand.dimensions();
+  auto eigenvalues_dims = eigenvalues->dimensions();
+  auto eigenvectors_dims = eigenvectors->dimensions();
+
+  uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
+  uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
+
+  uint64_t batch_items = 1;
+  for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
+    batch_items *= static_cast<uint64_t>(*it);
+  }
+
+  uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
+  uint64_t eigenvectors_stride = m * n;
+  uint64_t inner_stride = m * n;
+
+  std::vector<float> tmp(inner_stride);
+  const IntT *in_base = operand.typed_data();
+  float *eval_base = reinterpret_cast<float *>(eigenvalues->untyped_data());
+  float *evec_base = reinterpret_cast<float *>(eigenvectors->untyped_data());
+
+  for (uint64_t b = 0; b < batch_items; b++) {
+    const IntT *in = in_base + b * inner_stride;
+    for (uint64_t j = 0; j < inner_stride; j++) {
+      tmp[j] = static_cast<float>(in[j]);
+    }
+    single_matrix_eigh_cpu_custom_call(
+        eval_base + b * eigenvalues_stride,
+        evec_base + b * eigenvectors_stride, tmp.data(), m, n);
+  }
+
+  return ffi::Error::Success();
+}
+
+#define EXLA_REGISTER_EIGH_INT_F32(DTYPE, NAME)                             \
+  static ffi::Error NAME##_impl(ffi::Buffer<ffi::DTYPE> operand,            \
+                                ffi::ResultBuffer<ffi::F32> eigenvalues,    \
+                                ffi::ResultBuffer<ffi::F32> eigenvectors) { \
+    return EighCpuCustomCallIntegerOperandF32ResultsImpl<ffi::DTYPE>(       \
+        operand, eigenvalues, eigenvectors);                                \
+  }                                                                         \
+  XLA_FFI_DEFINE_HANDLER_SYMBOL(NAME, NAME##_impl,                          \
+                                ffi::Ffi::Bind()                            \
+                                    .Arg<ffi::Buffer<ffi::DTYPE>>()         \
+                                    .Ret<ffi::Buffer<ffi::F32>>()           \
+                                    .Ret<ffi::Buffer<ffi::F32>>());         \
+  XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), #NAME, "Host", NAME);
+
+EXLA_REGISTER_EIGH_INT_F32(S8, eigh_cpu_custom_call_s8)
+EXLA_REGISTER_EIGH_INT_F32(S16, eigh_cpu_custom_call_s16)
+EXLA_REGISTER_EIGH_INT_F32(S32, eigh_cpu_custom_call_s32)
+EXLA_REGISTER_EIGH_INT_F32(S64, eigh_cpu_custom_call_s64)
+EXLA_REGISTER_EIGH_INT_F32(U8, eigh_cpu_custom_call_u8)
+EXLA_REGISTER_EIGH_INT_F32(U16, eigh_cpu_custom_call_u16)
+EXLA_REGISTER_EIGH_INT_F32(U32, eigh_cpu_custom_call_u32)
+EXLA_REGISTER_EIGH_INT_F32(U64, eigh_cpu_custom_call_u64)
+
+#undef EXLA_REGISTER_EIGH_INT_F32
+
+} // namespace
+
 FINE_INIT("Elixir.EXLA.NIF");
diff --git a/exla/c_src/exla_test/custom_calls.cc b/exla/c_src/exla_test/custom_calls.cc
new file mode 100644
index 0000000000..e54b095bf1
--- /dev/null
+++ b/exla/c_src/exla_test/custom_calls.cc
@@ -0,0 +1,15 @@
+// Test-only shared library: registers an alias FFI name that reuses the
+// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so.
+#ifndef EXLA_PROD
+
+#include "xla/ffi/api/api.h"
+#include "xla/ffi/ffi_api.h"
+
+namespace ffi = xla::ffi;
+
+extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame);
+
+XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias",
+                         "Host", qr_cpu_custom_call_f32);
+
+#endif
diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex
index b558b0fd83..9be5dc13db 100644
--- a/exla/lib/exla.ex
+++ b/exla/lib/exla.ex
@@ -75,6 +75,13 @@ defmodule EXLA do
     * `:highest` - Slowest but most accurate. Performs computations in float32
       or float64 as applicable
 
+  ## Native custom calls (`EXLA.CustomCall`)
+
+  Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus
+  a registered native handler). Implement the `EXLA.CustomCall` protocol for
+  your block tag struct; see `EXLA.CustomCall` for the `call/4` contract,
+  including returning `:skip` to fall back to the block's default Elixir callback.
+
   ## Clients
 
   The `EXLA` library uses a client for compiling and executing code.
diff --git a/exla/lib/exla/custom_call.ex b/exla/lib/exla/custom_call.ex
new file mode 100644
index 0000000000..6ffdb5863c
--- /dev/null
+++ b/exla/lib/exla/custom_call.ex
@@ -0,0 +1,150 @@
+defprotocol EXLA.CustomCall do
+  @moduledoc """
+  Extension point for lowering selected `Nx.block/4` tags to **XLA custom
+  calls** (`stablehlo.custom_call` in MLIR), emitted through
+  `EXLA.MLIR.Value.custom_call/4`.
+
+  Other blocks (for example gather-based `take` or FFT) are lowered inline in
+  `EXLA.Defn` and do not use this protocol.
+
+  ## When `EXLA.Defn` calls it
+
+  During compilation with `compiler: EXLA`, when the builder is an MLIR
+  `EXLA.MLIR.Function`, `EXLA.Defn` invokes `call/4` once for every
+  `Nx.block(tag, inputs, outputs, fn ... end)` it encounters.
+
+  If `call/4` returns `:skip`, EXLA compiles the block's **default callback**
+  (the anonymous function body) instead of emitting a custom call.
+
+  ## `call/4` arguments
+
+  * `struct` — the **tag** passed as the first argument to `Nx.block/4`
+    (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`).
+
+  * `out` — the **output template** tuple passed to `Nx.block/4` (expression
+    metadata for shapes and types, not runtime tensors).
+
+  * `args` — list of **input templates**, in the same order as `inputs` in
+    `Nx.block/4`.
+
+  * `client` — the active `EXLA.Client` (match on e.g. `client.platform` to
+    gate host-only lowerings).
+
+  ## `call/4` return value
+
+  * **`:skip`** — this implementation does not apply (unsupported type,
+    non-host platform, wrong arity, etc.). The default block implementation
+    is used instead.
+
+  * **`{:ok, %EXLA.CustomCall.Spec{}}`** — emit a StableHLO custom call; see
+    `EXLA.CustomCall.Spec` for `call_target_name`, optional `backend_config`,
+    and optional `operand_element_types` (operands are converted when their
+    element types differ from the lowered inputs).
+
+  ## Dispatch
+
+  The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags
+  live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency
+  can add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is
+  chosen whenever the block tag is a `%YourStruct{}`, instead of the `Any`
+  fallback.
+
+  ## Native handlers
+
+  Emitting a custom call in MLIR is only half of the story: the **target name**
+  must also be registered with XLA on the relevant platform (typically via a
+  native library loaded into the process). That registration is **not**
+  configured through `config :exla, ...`; you load or link the native code by
+  the same means you would for any other NIF-backed extension (for example via
+  `EXLA.NIF.load_dylib/1`).
+
+  ## Example
+
+      defmodule MyApp.CustomQrTag do
+        defstruct []
+      end
+
+      defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do
+        def call(_tag, {%{type: {kind, size}}, _r_expr}, [_input], %{platform: :host})
+            when kind in [:f, :bf] and size in [16, 32, 64] do
+          {:ok, %EXLA.CustomCall.Spec{call_target_name: "my_custom_qr_target"}}
+        end
+
+        def call(_, _, _, _), do: :skip
+      end
+
+  Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with
+  `compiler: EXLA`.
+  """
+
+  @fallback_to_any true
+
+  @doc """
+  Returns `:skip` or `{:ok, %EXLA.CustomCall.Spec{}}`.
+  """
+  def call(struct, out, args, client)
+end
+
+# Default EXLA lowerings for C-backed custom_call `Nx.block/4` tags. With
+# `@fallback_to_any true` on the protocol, applications and libraries can
+# define their own `defimpl EXLA.CustomCall, for: SomeStruct`; protocol
+# dispatch then prefers that implementation over this fallback whenever the
+# block tag matches (a built-in `Nx.Block...` struct can be targeted the same
+# way).
+defimpl EXLA.CustomCall, for: Any do
+  @moduledoc false
+
+  alias EXLA.CustomCall.Spec
+
+  def call(
+        %Nx.Block.LinAlg.QR{},
+        {%{type: q_type}, _r_expr},
+        [%{type: in_type} | _],
+        %{platform: :host}
+      )
+      when elem(q_type, 0) != :c and elem(in_type, 0) != :c do
+    qr_cpu_custom_call(in_type)
+  end
+
+  # Native target names depend only on the input dtype; output templates may
+  # use different element types (e.g. after promotion) and must not change the
+  # call target.
+ def call(%Nx.Block.LinAlg.Eigh{}, _out, [%{type: in_type} | _], %{platform: :host}) + when elem(in_type, 0) != :c do + eigh_cpu_custom_call(in_type) + end + + def call(_, _, _, _), do: :skip + + defp qr_cpu_custom_call(in_type) do + case in_type do + {:f, 32} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f32"}} + {:f, 64} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f64"}} + {:f, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f16"}} + {:bf, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_bf16"}} + {:s, 8} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_s8"}} + {:s, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_s16"}} + {:s, 32} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_s32"}} + {:s, 64} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_s64"}} + {:u, 8} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_u8"}} + {:u, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_u16"}} + {:u, 32} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_u32"}} + {:u, 64} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_u64"}} + _ -> :skip + end + end + + defp eigh_cpu_custom_call(in_type) do + case in_type do + {:f, 32} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_f32"}} + {:f, 64} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_f64"}} + {:s, 8} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_s8"}} + {:s, 16} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_s16"}} + {:s, 32} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_s32"}} + {:s, 64} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_s64"}} + {:u, 8} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_u8"}} + {:u, 16} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_u16"}} + {:u, 32} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_u32"}} + {:u, 64} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_u64"}} + _ -> :skip + end + end +end diff --git a/exla/lib/exla/custom_call/spec.ex b/exla/lib/exla/custom_call/spec.ex new file mode 100644 index 0000000000..4cf96dbb31 --- /dev/null +++ b/exla/lib/exla/custom_call/spec.ex @@ -0,0 +1,33 @@ +defmodule EXLA.CustomCall.Spec do + @moduledoc """ + Result of `EXLA.CustomCall.call/4` when lowering a tagged `Nx.block/4` to + `stablehlo.custom_call`. + + * **`call_target_name`** — XLA FFI handler name (`call_target_name` on the op). + + * **`backend_config`** — Optional StableHLO dictionary attribute (`nil` omits it). + Same constraints as `EXLA.MLIR.Value.custom_call/4`. + + * **`operand_element_types`** — How operand SSA values are presented to the handler: + + * **`:infer`** (default) — use each lowered operand’s element type as produced + from the block inputs. No extra converts. + + * **`[Nx.Type.t(), ...]`** — one type per block input, same order and length as + `Nx.block/4` inputs. Before building the custom call, each operand is + converted (StableHLO `convert`) when its element type differs from the + requested type; shapes are unchanged. Use this when the native kernel’s + FFI signature expects dtypes that may differ from the traced expression + types (for example after promotion rules). 
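+
+  ## Example
+
+  A spec for a hypothetical handler registered as `"my_kernel_f64"` whose FFI
+  signature expects `f64` operands (the target name is illustrative, not a
+  handler shipped with EXLA):
+
+      %EXLA.CustomCall.Spec{
+        call_target_name: "my_kernel_f64",
+        backend_config: %{epsilon: 1.0e-6},
+        operand_element_types: [{:f, 64}]
+      }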
+ """ + + @enforce_keys [:call_target_name] + + defstruct [:call_target_name, backend_config: nil, operand_element_types: :infer] + + @type t :: %__MODULE__{ + call_target_name: String.t(), + backend_config: map() | nil, + operand_element_types: :infer | [Nx.Type.t()] + } +end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 602b4972c6..0aca19cf64 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -9,6 +9,7 @@ defmodule EXLA.Defn do alias EXLA.Typespec alias EXLA.MLIR.Value alias EXLA.MLIR.Function + alias EXLA.CustomCall.Spec, as: CustomCallSpec @doc false def __partitions_options__(options) do @@ -600,78 +601,8 @@ defmodule EXLA.Defn do {fun_computation(args, expr, type, state), cache} end - defp cached_recur_operator( - :block, - %T{ - data: %Expr{ - args: [ - %Nx.Block.LinAlg.QR{}, - [tensor], - {%{type: {type_kind, _}} = q_expr, r_expr}, - _callback - ] - } - }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, - cache - ) - when type_kind != :c do - # We match only on platform: :host for MLIR, as we want to support - # QR-on-cpu as a custom call only in this case - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - tensor = - if op_type(tensor) != q_expr.type do - to_type(tensor, q_expr.type) - else - tensor - end - - {q, r} = Value.qr(tensor, expr_to_typespec(q_expr), expr_to_typespec(r_expr)) - {[q, r], cache} - end - - defp cached_recur_operator( - :block, - %T{ - data: %Expr{ - args: [ - %Nx.Block.LinAlg.Eigh{}, - [tensor], - {%{type: {evec_type_kind, _}} = eigenvals_expr, - %{type: {eval_type_kind, _}} = eigenvecs_expr}, - _callback - ] - } - }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, - cache - ) - when evec_type_kind != :c and eval_type_kind != :c do - # We match only on platform: :host for MLIR, as we want to support - # eigh-on-cpu as a custom call only in this case - {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - # convert to float and ensure that we're either using f32 or f64, because Eigen - # only supports f32 and f64 easily. - out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) - - tensor = - if op_type(tensor) != out_type do - to_type(tensor, out_type) - else - tensor - end - - {eigenvals, eigenvecs} = - Value.eigh( - tensor, - expr_to_typespec(%{eigenvals_expr | type: out_type}), - expr_to_typespec(%{eigenvecs_expr | type: out_type}) - ) - - {[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache} - end + # StableHLO-style lowering (gather, top_k, fft): not the C custom_call path; + # see `EXLA.CustomCall` for blocks that delegate to native CPU kernels. 
 defp cached_recur_operator(
        :block,
@@ -736,15 +667,9 @@
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
 
     opts = [lengths: fft2_struct.lengths, axes: fft2_struct.axes]
+    opts = if eps = fft2_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts
 
-    opts =
-      if eps = fft2_struct.eps do
-        Keyword.put(opts, :eps, eps)
-      else
-        opts
-      end
-
-    {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr, state), cache}
+    {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr), cache}
   end
 
   defp cached_recur_operator(
@@ -756,15 +681,9 @@
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
 
     opts = [lengths: ifft2_struct.lengths, axes: ifft2_struct.axes]
+    opts = if eps = ifft2_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts
 
-    opts =
-      if eps = ifft2_struct.eps do
-        Keyword.put(opts, :eps, eps)
-      else
-        opts
-      end
-
-    {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache}
+    {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr), cache}
   end
 
   defp cached_recur_operator(
@@ -776,19 +695,11 @@
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
 
     opts = [length: rfft_struct.length, axis: rfft_struct.axis]
+    opts = if eps = rfft_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts
 
-    opts =
-      if eps = rfft_struct.eps do
-        Keyword.put(opts, :eps, eps)
-      else
-        opts
-      end
-
-    # expr.type is complex; input tensor is real
     input_type = Nx.Type.to_real(expr.type)
 
-    {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr, state),
-     cache}
+    {fft(&Value.fft(&1, :rfft, &2, &3), input_type, expr.type, [tensor, opts], expr), cache}
   end
 
   defp cached_recur_operator(
@@ -800,16 +711,8 @@
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
 
     opts = [length: irfft_struct.length, axis: irfft_struct.axis]
+    opts = if eps = irfft_struct.eps, do: Keyword.put(opts, :eps, eps), else: opts
 
-    opts =
-      if eps = irfft_struct.eps do
-        Keyword.put(opts, :eps, eps)
-      else
-        opts
-      end
-
-    # expr.type is real; input tensor is complex.
-    # pad_n = div(n,2)+1 (the expected input size), while fft_n = n (the output length).
     n = irfft_struct.length
     input_type = Nx.Type.to_complex(expr.type)
 
@@ -819,44 +722,53 @@
         expr.type,
         div(n, 2) + 1,
         [tensor, opts],
-        expr,
-        state
+        expr
       ),
      cache}
   end
 
-  defp cached_recur_operator(:block, %T{data: %Expr{args: args}}, state, cache) do
-    [struct, in_args, expr, _callback] = args
-    %module{} = struct
-
+  # Blocks backed by C custom calls (QR, Eigh, ...) dispatch through
+  # `EXLA.CustomCall`; when it returns :skip, the block's default callback is
+  # compiled instead.
+ defp cached_recur_operator( + :block, + %T{data: %Expr{args: [struct, in_args, out, _callback]}}, + %{client: client, builder: %Function{}} = state, + cache + ) do {call_args, cache} = Enum.map_reduce(in_args, cache, &recur_operator(&1, state, &2)) - key = computation_key(module, [struct | call_args]) - {call_body, cache} = - case cache do - %{^key => computation} -> - {computation, cache} + case EXLA.CustomCall.call(struct, out, in_args, client) do + :skip -> + default_block_implementation(struct, call_args, out, state, cache) - %{} -> - {computation, cache} = - block_computation( - block_subfunction_description(struct), - call_args, - expr, - state, - cache - ) + {:ok, %CustomCallSpec{} = spec} -> + backend_config = + case spec.backend_config do + nil -> + nil - {computation, Map.put(cache, key, computation)} - end + %{} = map -> + map - if token = get_token(cache) do - typespecs = [Typespec.token() | container_to_typespecs(expr)] - [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) - {wrap_tuple_result(result, expr), update_token(cache, token)} - else - typespecs = container_to_typespecs(expr) - result = Value.call(state.builder, call_args, call_body, typespecs) - {wrap_tuple_result(result, expr), cache} + other -> + raise ArgumentError, + "EXLA.CustomCall.Spec backend_config must be map() | nil, got: #{inspect(other)}" + end + + call_args = cast_custom_call_operands(call_args, spec.operand_element_types) + + out_typespecs = + [out] + |> Composite.flatten_list() + |> Enum.map(&expr_to_typespec/1) + + lowered = + Value.custom_call(call_args, out_typespecs, spec.call_target_name, backend_config) + |> wrap_tuple_result(out) + + {lowered, cache} + + other -> + raise ArgumentError, + "EXLA.CustomCall.call/4 must return :skip or {:ok, %EXLA.CustomCall.Spec{}}, got: #{inspect(other)}" end end @@ -998,6 +910,65 @@ defmodule EXLA.Defn do {to_operator(op, args, expr, state), cache} end + defp cast_custom_call_operands(call_args, :infer), do: call_args + + defp cast_custom_call_operands(call_args, types) when is_list(types) do + n = length(call_args) + + if length(types) != n do + raise ArgumentError, + "EXLA.CustomCall.Spec operand_element_types must be a list of length #{n} (one per block input), got length #{length(types)}" + end + + Enum.zip_with(call_args, types, fn arg, desired -> + ts = Value.get_typespec(arg) + + if ts.type == desired do + arg + else + Value.convert(arg, Typespec.tensor(desired, ts.shape)) + end + end) + end + + defp cast_custom_call_operands(_call_args, other) do + raise ArgumentError, + "EXLA.CustomCall.Spec operand_element_types must be :infer or a list of Nx types, got: #{inspect(other)}" + end + + defp default_block_implementation(struct, call_args, expr, state, cache) do + %module{} = struct + key = computation_key(module, [struct | call_args]) + + {call_body, cache} = + case cache do + %{^key => computation} -> + {computation, cache} + + %{} -> + {computation, cache} = + block_computation( + block_subfunction_description(struct), + call_args, + expr, + state, + cache + ) + + {computation, Map.put(cache, key, computation)} + end + + if token = get_token(cache) do + typespecs = [Typespec.token() | container_to_typespecs(expr)] + [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) + {wrap_tuple_result(result, expr), update_token(cache, token)} + else + typespecs = container_to_typespecs(expr) + result = Value.call(state.builder, call_args, call_body, typespecs) + {wrap_tuple_result(result, 
expr), cache} + end + end + ## to_operator creation defp to_operator(:constant, [constant], ans, state) do @@ -1289,11 +1260,11 @@ defmodule EXLA.Defn do apply(Value, op, [to_type(arg, type), expr_to_typespec(ans)]) end - defp to_operator(:fft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out, state) + defp to_operator(:fft, [%Value{} | _] = args, out, _state), + do: fft(&Value.fft(&1, :fft, &2, &3), out.type, out.type, args, out) - defp to_operator(:ifft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out, state) + defp to_operator(:ifft, [%Value{} | _] = args, out, _state), + do: fft(&Value.fft(&1, :ifft, &2, &3), out.type, out.type, args, out) defp to_operator(:is_nan, [%Value{} = arg], out, _state), do: Value.is_nan(arg, expr_to_typespec(out)) @@ -1618,7 +1589,8 @@ defmodule EXLA.Defn do EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans, state) do + @doc false + def fft(exla_op, input_type, output_type, pad_n \\ nil, [%Value{} = tensor, opts], ans) do fft_n = opts[:length] pad_n = pad_n || fft_n axis = opts[:axis] @@ -1627,7 +1599,7 @@ defmodule EXLA.Defn do shape = op_shape(tensor) m = elem(shape, axis) - tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type, state) + tensor = fft_pad_or_slice(tensor, m, pad_n, axis, shape, input_type) last_axis = tuple_size(shape) - 1 @@ -1662,7 +1634,8 @@ defmodule EXLA.Defn do end end - defp fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do + @doc false + def fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans) do [l1, l2] = lengths = opts[:lengths] [ax1, ax2] = axes = opts[:axes] output_type = Nx.Type.to_complex(type) @@ -1672,8 +1645,8 @@ defmodule EXLA.Defn do m1 = elem(shape, ax1) m2 = elem(shape, ax2) - tensor = fft_pad_or_slice(tensor, m1, l1, ax1, shape, output_type, state) - tensor = fft_pad_or_slice(tensor, m2, l2, ax2, op_shape(tensor), output_type, state) + tensor = fft_pad_or_slice(tensor, m1, l1, ax1, shape, output_type) + tensor = fft_pad_or_slice(tensor, m2, l2, ax2, op_shape(tensor), output_type) last_axis = tuple_size(shape) - 1 penultimate_axis = last_axis - 1 @@ -1701,7 +1674,7 @@ defmodule EXLA.Defn do end end - defp fft_pad_or_slice(tensor, m, n, axis, shape, output_type, state) do + defp fft_pad_or_slice(%Value{function: builder} = tensor, m, n, axis, shape, output_type) do cond do m == n -> tensor @@ -1726,7 +1699,7 @@ defmodule EXLA.Defn do zero_value = if Nx.Type.complex?(output_type), do: Complex.new(0), else: 0 zero = - Value.constant(state.builder, [zero_value], Typespec.tensor(output_type, {})) + Value.constant(builder, [zero_value], Typespec.tensor(output_type, {})) padding_config = {0, 0, 0} diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 9d028ff6dd..e5a6d6e1b5 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -719,70 +719,29 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.return", values, []) end - def eigh(%Value{function: func} = value, eigenvals_typespec, eigenvecs_typespec) do - %{type: op_type} = get_typespec(value) - - operands = [value] - result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) - - call_target_name = - case op_type do - {:f, 32} -> - "eigh_cpu_custom_call_f32" - - {:f, 64} -> - "eigh_cpu_custom_call_f64" - - type -> - # Due to matching on 
EXLA.Defn, we are sure that the device here is always :host - raise "Eigh decomposition not supported on :host device for type #{inspect(type)}" - end + @doc false + def custom_call( + [%Value{function: func} | _] = operands, + typespecs, + call_target_name, + backend_config \\ nil + ) + when is_binary(call_target_name) and is_list(typespecs) do + result_types = typespecs_to_mlir_types(typespecs) attributes = [ call_target_name: attr_string(call_target_name), api_version: attr_i32(4) ] - [eigenvals, eigenvecs] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - - {eigenvals, eigenvecs} - end - - def qr(%Value{function: func} = value, q_typespec, r_typespec) do - %{type: op_type} = get_typespec(value) - - operands = [value] - result_types = typespecs_to_mlir_types([q_typespec, r_typespec]) - - call_target_name = - case op_type do - {:f, 32} -> - "qr_cpu_custom_call_f32" - - {:f, 64} -> - "qr_cpu_custom_call_f64" - - {:f, 16} -> - "qr_cpu_custom_call_f16" - - {:bf, 16} -> - "qr_cpu_custom_call_bf16" - - type -> - # Due to matching on EXLA.Defn, we are sure that the device here is always :host - raise "QR decomposition not supported on :host device for type #{inspect(type)}" + attributes = + if is_map(backend_config) do + Keyword.put(attributes, :backend_config, backend_config_to_attr(backend_config)) + else + attributes end - attributes = [ - call_target_name: attr_string(call_target_name), - api_version: attr_i32(4) - ] - - [q, r] = - op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) - - {q, r} + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) end def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do @@ -1088,6 +1047,44 @@ defmodule EXLA.MLIR.Value do "{#{content}}" end + defp backend_config_to_attr(map) when is_map(map) do + map + |> Enum.map(fn {k, v} -> {attr_dict_key(k), backend_config_value_to_attr(v)} end) + |> attr_dict() + end + + defp backend_config_value_to_attr(v) when is_boolean(v), do: attr_boolean(v) + defp backend_config_value_to_attr(v) when is_integer(v), do: attr_i64(v) + defp backend_config_value_to_attr(v) when is_float(v), do: "#{v} : f64" + defp backend_config_value_to_attr(v) when is_binary(v), do: attr_string(v) + + defp backend_config_value_to_attr(v) when is_list(v) do + "[" <> Enum.map_join(v, ", ", &backend_config_value_to_attr/1) <> "]" + end + + defp backend_config_value_to_attr(v) when is_map(v), do: backend_config_to_attr(v) + + defp backend_config_value_to_attr(v) do + raise ArgumentError, + "custom_call backend_config value is not encodable to MLIR DictionaryAttr: #{inspect(v)}" + end + + defp attr_dict_key(key) when is_atom(key), do: Atom.to_string(key) + + defp attr_dict_key(key) when is_binary(key) do + if Regex.match?(~r/^[A-Za-z_][A-Za-z0-9_]*$/, key) do + key + else + raise ArgumentError, + "custom_call backend_config key must match [A-Za-z_][A-Za-z0-9_]*, got: #{inspect(key)}" + end + end + + defp attr_dict_key(key) do + raise ArgumentError, + "custom_call backend_config key must be an atom or string, got: #{inspect(key)}" + end + defp join_list(list) do "[" <> Enum.join(list, ", ") <> "]" end diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 2a0a99f1ef..9fab99366c 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -79,6 +79,7 @@ defmodule EXLA.NIF do def get_tpu_client(), do: err!() def get_c_api_client(_device_type), do: err!() def load_pjrt_plugin(_device_type, _library_path), do: err!() + def 
load_dylib(_path), do: err!()
   def get_device_count(_client), do: err!()
   def get_supported_platforms, do: err!()
   def run_cpu(_executable, _arguments, _device_id, _callback_server_pid), do: err!()
diff --git a/exla/test/exla/custom_call_alias_test.exs b/exla/test/exla/custom_call_alias_test.exs
new file mode 100644
index 0000000000..5eab675e76
--- /dev/null
+++ b/exla/test/exla/custom_call_alias_test.exs
@@ -0,0 +1,106 @@
+defmodule EXLA.CustomCallAliasTest do
+  use EXLA.Case, async: false
+
+  defmodule BuiltinFun do
+    import Nx.Defn
+
+    defn qr(t), do: Nx.LinAlg.qr(t)
+  end
+
+  defmodule Fun do
+    import Nx.Defn
+
+    alias EXLA.Test.QRAliasBlock
+
+    defn qr_alias_fn(t) do
+      q_out = Nx.template({3, 3}, {:f, 32})
+      r_out = Nx.template({3, 4}, {:f, 32})
+
+      Nx.block(%QRAliasBlock{}, [t], {q_out, r_out}, fn _, t2 ->
+        Nx.LinAlg.qr(t2, mode: :reduced)
+      end)
+    end
+  end
+
+  @plugin_relative ~c"test/exla_qr_alias.so"
+
+  defp plugin_path do
+    :filename.join(:code.priv_dir(:exla), @plugin_relative)
+  end
+
+  defp mlir_via_jit_apply!(fun, args) when is_function(fun) and is_list(args) do
+    try do
+      Nx.Defn.jit_apply(fun, args,
+        compiler: EXLA,
+        module_compilation: :to_mlir
+      )
+    catch
+      :throw, {:mlir_module, ref, used_inputs, output_container} ->
+        %{
+          mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}),
+          used_inputs: used_inputs,
+          output_container: output_container
+        }
+    end
+  end
+
+  defp load_plugin! do
+    path = List.to_string(plugin_path())
+
+    unless File.exists?(path) do
+      flunk("""
+      Missing #{path}. Build EXLA with MIX_ENV=test so the alias dylib is compiled \
+      (see Makefile target exla_qr_alias.so).
+      """)
+    end
+
+    case EXLA.NIF.load_dylib(path) do
+      :ok ->
+        :ok
+
+      other ->
+        flunk("load_dylib(#{path}) expected :ok, got: #{inspect(other)}")
+    end
+  end
+
+  test "builtin QR lowering includes qr_cpu_custom_call_f32 in MLIR" do
+    arg = Nx.iota({3, 4}, type: {:f, 32})
+    assert %{mlir_module: mlir} = mlir_via_jit_apply!(&BuiltinFun.qr/1, [arg])
+
+    assert mlir =~ "@qr_cpu_custom_call_f32("
+    refute mlir =~ "qr_cpu_custom_call_f32_exla_alias"
+  end
+
+  test "builtin QR lowering includes qr_cpu_custom_call_s32 in MLIR" do
+    arg = Nx.iota({3, 4}, type: {:s, 32})
+    assert %{mlir_module: mlir} = mlir_via_jit_apply!(&BuiltinFun.qr/1, [arg])
+
+    assert mlir =~ "@qr_cpu_custom_call_s32("
+    refute mlir =~ "qr_cpu_custom_call_f32_exla_alias"
+  end
+
+  test "QR alias plugin: MLIR uses alias name and not the builtin target string" do
+    load_plugin!()
+
+    arg = Nx.iota({3, 4}, type: {:f, 32})
+    assert %{mlir_module: mlir} = mlir_via_jit_apply!(&Fun.qr_alias_fn/1, [arg])
+
+    assert mlir =~ "qr_cpu_custom_call_f32_exla_alias"
+    refute mlir =~ "@qr_cpu_custom_call_f32("
+  end
+
+  test "QR alias plugin: JIT result matches builtin QR" do
+    load_plugin!()
+
+    t = Nx.iota({3, 4}, type: {:f, 32})
+    {q_exp, r_exp} = EXLA.jit(fn t -> Nx.LinAlg.qr(t) end).(t)
+    {q_act, r_act} = EXLA.jit(&Fun.qr_alias_fn/1).(t)
+
+    # Nx.all_close/3 returns a 0/1 u8 tensor, so compare the scalar value
+    # instead of asserting on the (always-truthy) tensor struct.
+    assert Nx.to_number(Nx.all_close(q_exp, q_act, atol: 1.0e-4, rtol: 1.0e-4)) == 1
+    assert Nx.to_number(Nx.all_close(r_exp, r_act, atol: 1.0e-4, rtol: 1.0e-4)) == 1
+  end
+end
diff --git a/exla/test/support/exla_test_qr_alias_block.ex b/exla/test/support/exla_test_qr_alias_block.ex
new file mode 100644
index 0000000000..975f069380
--- /dev/null
+++ b/exla/test/support/exla_test_qr_alias_block.ex
@@ -0,0 +1,16 @@
+# Test-only block tag + `EXLA.CustomCall` impl used to emit a StableHLO custom_call
+# with `call_target_name` `qr_cpu_custom_call_f32_exla_alias` (registered by
+# `priv/test/exla_qr_alias.so` when built with `MIX_ENV=test`).
+defmodule EXLA.Test.QRAliasBlock do
+  @moduledoc false
+  defstruct []
+end
+
+defimpl EXLA.CustomCall, for: EXLA.Test.QRAliasBlock do
+  def call(_, {%{type: {q_kind, q_size}}, _r_expr}, [_tensor], %{platform: :host})
+      when q_kind != :c and q_size == 32 do
+    {:ok, %EXLA.CustomCall.Spec{call_target_name: "qr_cpu_custom_call_f32_exla_alias"}}
+  end
+
+  def call(_, _, _, _), do: :skip
+end
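+
+# Usage sketch (illustrative; mirrors custom_call_alias_test.exs):
+#
+#     path =
+#       :code.priv_dir(:exla)
+#       |> List.to_string()
+#       |> Path.join("test/exla_qr_alias.so")
+#
+#     :ok = EXLA.NIF.load_dylib(path)
+#
+# Once the dylib is loaded, any defn compiled with `compiler: EXLA` that uses
+# Nx.block(%EXLA.Test.QRAliasBlock{}, ...) lowers to the alias target above.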