Skip to content

Commit c8c0ebb

Browse files
authored
feat: EMLX.Fast and EMLXAxon (#109)
* feat: emlx.fast * feat: add emlx axon * feat: add emlx axon rewrites * improve emlx fast * close the gap * wip: 50tok/s with axon * feat: add more rewrites * more improvements * fix: make benchmarks work with github bumblebee * feat: add compiler * more compilation * feat: add compiler * feat: add bench * bench * close more gaps * revert compiler * remove bench * cleanup * docs * wip * clean up code * split nifs * fix nifs * formatter * config * chore: remove validation * fix: profile eval * chore: format
1 parent 41b2e83 commit c8c0ebb

46 files changed

Lines changed: 6198 additions & 615 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ emlx-*.tar
2727
/cache
2828

2929
# Validation subproject build artifacts (not the mix.lock — that is committed).
30-
/validation/_build/
31-
/validation/deps/
30+
/**/_build/
31+
/**/deps/

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ endif
4040
MAKE_JOBS ?= $(MAKE_DEFAULT_JOBS)
4141

4242
# Source files
43-
SOURCES = c_src/emlx_nif.cpp
44-
HEADERS = c_src/nx_nif_utils.hpp c_src/emlx_worker.hpp c_src/emlx_async.hpp
43+
SOURCES = c_src/emlx_nif.cpp c_src/emlx_fast.cpp
44+
HEADERS = c_src/nx_nif_utils.hpp c_src/emlx_worker.hpp c_src/emlx_async.hpp c_src/emlx_nif_shared.hpp
4545
OBJECTS = $(patsubst c_src/%.cpp,$(BUILD_DIR)/%.o,$(SOURCES))
4646

4747
# Main targets

bench/mx_compile_bench.exs

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
#!/usr/bin/env elixir
2+
# C03 — graph_capture / graph_replay benchmark
3+
#
4+
# Tests graph_capture/replay speedup across kernels of increasing graph depth
5+
# and NIF-call count, to determine whether NIF dispatch or GPU execution
6+
# dominates, and at what graph size compile gives ≥2×.
7+
#
8+
# Kernels (differing NIF-call counts and NIF-dispatch/GPU-work ratios):
9+
# K1 — elementwise chain : 100 sequential Nx ops, small GPU work per op
10+
# K2 — FFN block : 4 Nx ops (dot, sigmoid, dot, add), large GPU matmuls (Qwen3-0.6B proxy)
11+
# K3 — 28-layer FFN stack : 28× K2 on the same inputs (full decode proxy)
12+
# K4 — SVD 512×512 : 1 Nx op, complex internal graph in MLX
13+
# K5 — SVD 1024×1024 : 1 Nx op, larger SVD (more internal GPU kernels)
14+
#
15+
# Run:
16+
# mix run bench/mx_compile_bench.exs
17+
# EMLX_BENCH_ITERS=200 mix run bench/mx_compile_bench.exs
18+
19+
defmodule MxCompileBench do
  @moduledoc """
  Benchmarks `EMLX.NIF.graph_capture/3` + `EMLX.NIF.graph_replay/2` against
  plain op-by-op Nx dispatch on `EMLX.Backend`.

  Each `bench_*` kernel builds a computation once, captures it, then times
  "dispatch + eval" (re-running the Nx ops) versus "replay + eval" (replaying
  the captured graph) and prints the speedup against a per-kernel gate.

  The number of timed iterations comes from the `EMLX_BENCH_ITERS`
  environment variable (default `200`).
  """

  # Read once, when the script is compiled.
  iters = String.to_integer(System.get_env("EMLX_BENCH_ITERS", "200"))

  # A zero or negative count would make every `(t1 - t0) / n` below end in a
  # division by zero, so fail early with an explicit message instead.
  if iters < 1 do
    raise ArgumentError, "EMLX_BENCH_ITERS must be a positive integer, got: #{iters}"
  end

  @iters iters
  @warmup 15

  # ── Helpers ─────────────────────────────────────────────────────────────

  # `t.data.ref` is a 2-tuple: element 0 is used as the device and element 1
  # as the raw reference handed to the NIFs.
  defp raw_ref(t), do: elem(t.data.ref, 1)
  defp dev(t), do: elem(t.data.ref, 0)

  # Builds a tensor of `shape` filled with the scalar `val` on EMLX.Backend.
  defp tensor(val, shape, type \\ :f32),
    do: Nx.broadcast(Nx.tensor(val, type: type, backend: EMLX.Backend), shape)

  # Calls `fun` `warmup` times untimed, then `n` times timed, and returns the
  # mean wall-clock time per timed call in microseconds (always a float).
  defp time_per_iter(warmup, n, fun) do
    for _ <- 1..warmup, do: fun.()
    t0 = System.monotonic_time(:microsecond)
    for _ <- 1..n, do: fun.()
    t1 = System.monotonic_time(:microsecond)
    (t1 - t0) / n
  end

  # Times `fun` (after @warmup untimed calls), prints one labelled line and
  # returns the mean microseconds per iteration.
  defp bench(label, n, fun) do
    per = time_per_iter(@warmup, n, fun)
    IO.puts(" #{String.pad_trailing(label, 22)}: #{Float.round(per, 1)} μs/iter")
    per
  end

  # Captures the graph from `input_tensors` to `output_tensors`.
  # Returns `{compiled_ref, input_raw_refs, capture_time_us}`; the raw input
  # refs are returned so callers can pass them straight to graph_replay.
  defp capture(input_tensors, output_tensors) do
    inputs = Enum.map(input_tensors, &raw_ref/1)
    outputs = Enum.map(output_tensors, &raw_ref/1)
    {us, {:ok, cr}} = :timer.tc(fn -> EMLX.NIF.graph_capture(inputs, outputs, false) end)
    {cr, inputs, us}
  end

  # Replays the captured graph and evaluates every output.
  # Returns the `{device, raw_ref}` of the first output.
  defp replay_and_eval(cr, input_refs, dev) do
    {:ok, out_raws} = EMLX.NIF.graph_replay(cr, input_refs)
    Enum.each(out_raws, fn r -> EMLX.eval({dev, r}) end)
    {dev, hd(out_raws)}
  end

  # Max-abs-diff between `ref_tensor` and the FIRST output of a fresh replay,
  # which is interpreted as a tensor of the given `shape` and `type`.
  defp check_correctness(ref_tensor, cr, input_refs, dev, shape, type) do
    {:ok, [r | _]} = EMLX.NIF.graph_replay(cr, input_refs)
    rep = {dev, r} |> EMLX.Backend.to_nx(Nx.template(shape, type))
    Nx.subtract(ref_tensor, rep) |> Nx.abs() |> Nx.reduce_max() |> Nx.to_number()
  end

  # Numbers are rounded to 4 decimals. Anything else is passed through so the
  # result line still prints instead of raising in `* 1.0`.
  # NOTE(review): Nx.to_number/1 is expected to return atoms such as :nan or
  # :infinity for non-finite values — confirm against the Nx version in use.
  defp format_diff(diff) when is_number(diff), do: Float.round(diff * 1.0, 4)
  defp format_diff(diff), do: diff

  # Prints the summary line for one kernel and returns the speedup
  # (baseline / compiled). `gate` is the minimum speedup reported as PASS.
  defp print_result(baseline, compiled, capture_us, max_diff, gate) do
    speedup = baseline / compiled
    pass = if speedup >= gate, do: "✓ PASS", else: "✗ FAIL"

    IO.puts(
      " baseline #{Float.round(baseline, 1)} μs │ " <>
        "compiled #{Float.round(compiled, 1)} μs │ " <>
        "speedup #{Float.round(speedup, 2)}× │ " <>
        "capture #{capture_us} μs │ " <>
        "max_diff #{format_diff(max_diff)} │ " <>
        "#{pass} (gate #{gate}×)"
    )

    speedup
  end

  # Shared body of the SVD kernels: square `size`×`size` input, all three SVD
  # outputs captured and evaluated, correctness checked on U only.
  defp bench_svd(header, size) do
    IO.puts(header)
    m = tensor(0.1, {size, size})
    {u, s, vt} = Nx.LinAlg.svd(m, full_matrices?: false)
    {cr, input_refs, cap_us} = capture([m], [u, s, vt])
    d = dev(m)

    t_base =
      bench("dispatch+eval", @iters, fn ->
        {u2, s2, vt2} = Nx.LinAlg.svd(m, full_matrices?: false)
        EMLX.eval(u2.data.ref)
        EMLX.eval(s2.data.ref)
        EMLX.eval(vt2.data.ref)
      end)

    t_comp =
      bench("replay+eval", @iters, fn ->
        {:ok, [ur, sr, vtr]} = EMLX.NIF.graph_replay(cr, input_refs)
        EMLX.eval({d, ur})
        EMLX.eval({d, sr})
        EMLX.eval({d, vtr})
      end)

    # U is the first captured output, which is what check_correctness compares.
    {u_ref, _s, _vt} = Nx.LinAlg.svd(m, full_matrices?: false)
    diff = check_correctness(u_ref, cr, input_refs, d, {size, size}, :f32)
    print_result(t_base, t_comp, cap_us, diff, 1.5)
  end

  # ── Kernel K1: elementwise chain ─────────────────────────────────────────

  @doc """
  K1: a chain of 100 elementwise multiplies on a 1024-element vector.

  Prints the timings and returns the measured speedup (gate 1.5×).
  """
  @spec bench_elementwise_chain() :: float()
  def bench_elementwise_chain do
    IO.puts("\n── K1 Elementwise chain (100 sequential ops, 1024-elem vector) ──")
    n = 100
    x = tensor(1.0, {1024})

    # NOTE(review): the scale tensors are rebuilt on every call, so the
    # "dispatch+eval" baseline below also pays for constructing them.
    chain_fn = fn x ->
      Enum.reduce(1..n, x, fn i, acc ->
        scale = tensor(1.0 + i * 0.001, {1024})
        Nx.multiply(acc, scale)
      end)
    end

    out_trace = chain_fn.(x)
    # Only `x` is passed as a graph input; the n scale tensors are not inputs
    # (presumably recorded as constants by graph_capture — confirm in the NIF).
    {cr, input_refs, cap_us} = capture([x], [out_trace])
    d = dev(x)

    t_base =
      bench("dispatch+eval", @iters, fn ->
        EMLX.eval(chain_fn.(x).data.ref)
      end)

    t_comp =
      bench("replay+eval", @iters, fn ->
        replay_and_eval(cr, input_refs, d)
      end)

    diff = check_correctness(chain_fn.(x), cr, input_refs, d, {1024}, :f32)
    print_result(t_base, t_comp, cap_us, diff, 1.5)
  end

  # ── Kernel K2: single FFN block ───────────────────────────────────────────

  @doc """
  K2: one FFN block (`dot → sigmoid → dot → add`) with hidden=1024 and
  intermediate=2816.

  Prints the timings and returns the measured speedup (gate 2.0×).
  """
  @spec bench_ffn_block() :: float()
  def bench_ffn_block do
    IO.puts("\n── K2 Single FFN block (Qwen3-0.6B, decode seq_len=1) ──────────")
    # hidden=1024, intermediate=2816
    x = tensor(0.1, {1, 1024})
    w1 = tensor(0.01, {1024, 2816})
    w2 = tensor(0.01, {2816, 1024})
    b = tensor(0.0, {1, 1024})

    ffn = fn x, w1, w2, b ->
      x |> Nx.dot(w1) |> Nx.sigmoid() |> Nx.dot(w2) |> Nx.add(b)
    end

    out_trace = ffn.(x, w1, w2, b)
    {cr, input_refs, cap_us} = capture([x, w1, w2, b], [out_trace])
    d = dev(x)

    t_base =
      bench("dispatch+eval", @iters, fn ->
        EMLX.eval(ffn.(x, w1, w2, b).data.ref)
      end)

    t_comp =
      bench("replay+eval", @iters, fn ->
        replay_and_eval(cr, input_refs, d)
      end)

    diff = check_correctness(ffn.(x, w1, w2, b), cr, input_refs, d, {1, 1024}, :f32)
    print_result(t_base, t_comp, cap_us, diff, 2.0)
  end

  # ── Kernel K3: 28-layer FFN stack ─────────────────────────────────────────

  @doc """
  K3: the K2 FFN block applied 28 times in sequence with the same weights.

  Prints the timings and returns the measured speedup (gate 2.0×).
  """
  @spec bench_ffn_stack() :: float()
  def bench_ffn_stack do
    IO.puts("\n── K3 28-layer FFN stack (28× K2, full decode proxy) ───────────")
    x = tensor(0.1, {1, 1024})
    w1 = tensor(0.01, {1024, 2816})
    w2 = tensor(0.01, {2816, 1024})
    b = tensor(0.0, {1, 1024})

    stack_fn = fn x, w1, w2, b ->
      Enum.reduce(1..28, x, fn _, acc ->
        acc |> Nx.dot(w1) |> Nx.sigmoid() |> Nx.dot(w2) |> Nx.add(b)
      end)
    end

    out_trace = stack_fn.(x, w1, w2, b)
    {cr, input_refs, cap_us} = capture([x, w1, w2, b], [out_trace])
    d = dev(x)

    t_base =
      bench("dispatch+eval", @iters, fn ->
        EMLX.eval(stack_fn.(x, w1, w2, b).data.ref)
      end)

    t_comp =
      bench("replay+eval", @iters, fn ->
        replay_and_eval(cr, input_refs, d)
      end)

    diff = check_correctness(stack_fn.(x, w1, w2, b), cr, input_refs, d, {1, 1024}, :f32)
    print_result(t_base, t_comp, cap_us, diff, 2.0)
  end

  # ── Kernel K4: SVD 512×512 ─────────────────────────────────────────────

  @doc """
  K4: reduced SVD of a 512×512 matrix; all of `{u, s, vt}` are captured.

  Prints the timings and returns the measured speedup (gate 1.5×).
  """
  @spec bench_svd_512() :: float()
  def bench_svd_512 do
    bench_svd("\n── K4 SVD 512×512 (single NIF, complex internal MLX graph) ────", 512)
  end

  # ── Kernel K5: SVD 1024×1024 ───────────────────────────────────────────

  @doc """
  K5: reduced SVD of a 1024×1024 matrix; all of `{u, s, vt}` are captured.

  Prints the timings and returns the measured speedup (gate 1.5×).
  """
  @spec bench_svd_1024() :: float()
  def bench_svd_1024 do
    bench_svd("\n── K5 SVD 1024×1024 (larger SVD, more internal GPU kernels) ────", 1024)
  end

  # ── Overhead breakdown ─────────────────────────────────────────────────

  @doc """
  Prints a rough split of per-call cost for a `dot → sigmoid` step: op
  dispatch without eval, graph replay without eval, eval of an already-built
  array, and dispatch followed by a fresh eval.
  """
  @spec bench_overhead() :: :ok
  def bench_overhead do
    IO.puts("\n── Overhead breakdown ────────────────────────────────────────────")
    n = 1000
    x = tensor(0.1, {1, 1024})
    w1 = tensor(0.01, {1024, 2816})
    ffn1 = fn x, w1 -> x |> Nx.dot(w1) |> Nx.sigmoid() end
    out = ffn1.(x, w1)
    {cr, input_refs, _} = capture([x, w1], [out])

    # The first three measurements use n warmup calls and n timed calls.
    nif_dispatch_only = time_per_iter(n, n, fn -> ffn1.(x, w1) end)

    graph_replay_only =
      time_per_iter(n, n, fn -> EMLX.NIF.graph_replay(cr, input_refs) end)

    # Eval same pre-built array (GPU time lower bound — cached result, no-op)
    eval_only = time_per_iter(n, n, fn -> EMLX.eval(out.data.ref) end)

    # Fresh eval (actually computes on GPU)
    fresh_eval =
      time_per_iter(@warmup, @iters, fn -> EMLX.eval(ffn1.(x, w1).data.ref) end)

    IO.puts(" NIF dispatch only (no eval) : #{Float.round(nif_dispatch_only, 1)} μs")
    IO.puts(" graph_replay only (no eval) : #{Float.round(graph_replay_only, 1)} μs")
    IO.puts(" eval (cached/no-op) : #{Float.round(eval_only, 1)} μs")
    IO.puts(" dispatch + fresh eval : #{Float.round(fresh_eval, 1)} μs")
    IO.puts(" GPU time (approx) : #{Float.round(fresh_eval - nif_dispatch_only, 1)} μs")
    IO.puts(" NIF share of total : #{Float.round(100 * nif_dispatch_only / fresh_eval, 1)}%")
  end

  # ── Main ──────────────────────────────────────────────────────────────────

  @doc """
  Prints the banner, then runs the overhead breakdown and kernels K1–K5 in order.
  """
  @spec run() :: :ok
  def run do
    IO.puts("""
    ╔══════════════════════════════════════════════════════════════════╗
    C03 graph_capture / graph_replay benchmark (#{@iters} iters, #{@warmup} warmup)
    ╚══════════════════════════════════════════════════════════════════╝
    """)

    bench_overhead()
    bench_elementwise_chain()
    bench_ffn_block()
    bench_ffn_stack()
    bench_svd_512()
    bench_svd_1024()

    IO.puts("\nDone.")
  end
end
293+
294+
# Script entry point: runs the overhead breakdown and every kernel benchmark
# as soon as this file is evaluated.
MxCompileBench.run()

0 commit comments

Comments
 (0)