Skip to content

Commit 41b2e83

Browse files
notactuallytreyanastasioclaudepolvalente
authored
feat: Add 4-bit quantization support for LLM inference on Apple Silicon (#96)
* feat: Add 4-bit quantization support for LLM inference on Apple Silicon This PR adds quantized tensor operations to EMLX, enabling efficient large language model inference on Apple Silicon GPUs. It powers a pure Elixir LLM inference stack achieving 135 tok/s on Qwen3-8B-4bit. ## Motivation Running 8B parameter models requires 16GB+ at fp16. With 4-bit quantization, the same model fits in ~5GB, enabling inference on consumer Macs. This work is part of a broader effort to bring production LLM inference to the Elixir ecosystem: - bobby_posts: Pure Elixir Qwen3-8B inference (135 tok/s) - bobby_posts_adapters: LoRA fine-tuning for personalized generation - bumblebee_quantized: Quantized model loading for Bumblebee - safetensors_ex: MLX 4-bit safetensors format support ## Implementation ### NIFs (c_src/emlx_nif.cpp) Three new NIFs wrapping MLX's quantization functions: - quantized_matmul(x, w, scales, biases, transpose, group_size, bits) - dequantize(w, scales, biases, group_size, bits) - quantize(w, group_size, bits) ### Backend Integration (lib/emlx/backend.ex) Per Paulo's feedback, quantization metadata is stored directly on the Backend struct (not a nested map): defstruct [:ref, :shape, :type, :data, :scales, :biases, :group_size] When Nx.dot detects a quantized tensor (scales != nil), it automatically dispatches to quantized_matmul. The tensor type {:s, 4} carries the bit width, so bits is not stored separately. ### User API (lib/emlx/quantization.ex) Clean user-facing module with comprehensive documentation: # Quantize weights {q_weight, scales, biases} = EMLX.Quantization.quantize(weight) # Create tensor for Nx operations qt = EMLX.Quantization.tensor(q_weight, scales, biases, shape) # Nx.dot automatically uses quantized_matmul result = Nx.dot(input, qt) ### Elixir API (lib/emlx.ex) Low-level functions for direct NIF access: - EMLX.quantized_matmul/7 - EMLX.dequantize/5 - EMLX.quantize/3 - EMLX.quantized_tensor/5 ## MLX 4-bit Format MLX uses group-wise affine quantization: dequantized[i] = scales[i/group_size] * (packed_int4[i] - biases[i/group_size]) Weights are packed as uint32 (8 int4 values per uint32). With group_size=64: - Weight [out, in] becomes [out, in/8] as uint32 - Scales: [out, in/group_size] as bfloat16 - Biases: [out, in/group_size] as bfloat16 ## Tests 33 tests covering: - Low-level NIF operations (6 tests) - Backend integration with Nx.dot (9 tests) - EMLX.Quantization module API (18 tests) - End-to-end LLM inference patterns ## Performance On Apple M-series with Qwen3-8B-4bit: - Single-token latency: ~135 tok/s - Memory: 4-5GB vs 16GB for fp16 - 14x faster than Python mlx_lm (9.5 tok/s) ## Bumblebee Integration Path With this merged, quantized models can use EMLX as a pure backend: 1. Model loader detects quantized safetensors 2. Creates EMLX.Quantization.tensor for each quantized weight 3. Model definition unchanged - Nx.dot works transparently 4. EMLX backend handles all dispatch This enables upstreaming quantized model support to Bumblebee without changing the serving interface. ## References - Use case: https://github.com/notactuallytreyanastasio/bobby_posts - PR discussion: #96 - MLX quantization: https://ml-explore.github.io/mlx/build/html/python/nn.html Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * docs: PR feedback suggestions Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> * Apply suggestion from @polvalente Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> * green out after merge * refactor: enable quantized ops to run in defn * feat: fused quantized matmul in backend * fix: tag metal tests * refactor default device * fix: from pointer use default device * fix: propagate env * fix ensure_all_started * fix ensure all started call * fix again --------- Co-authored-by: Trey Anastasio <notactuallytreyanastasio@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent 3ceaf34 commit 41b2e83

11 files changed

Lines changed: 1242 additions & 47 deletions

c_src/emlx_nif.cpp

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,67 @@ NIF(as_strided) {
14911491
TENSOR(mlx::core::as_strided(*t, to_shape(shape), to_strides(strides), offset, device));
14921492
}
14931493

1494+
// ============================================================================
1495+
// Quantization Operations (for 4-bit model support)
1496+
// ============================================================================
1497+
1498+
// quantized_matmul - Multiplies x with a quantized weight matrix w
1499+
// This is the key operation for efficient 4-bit inference
1500+
// MLX API: quantized_matmul(x, w, scales, biases, transpose, group_size, bits, stream)
1501+
NIF(quantized_matmul) {
1502+
TENSOR_PARAM(0, x); // Input tensor [batch, seq, hidden]
1503+
TENSOR_PARAM(1, w); // Quantized weights [out/8, in] (uint32 packed)
1504+
TENSOR_PARAM(2, scales); // Scales [out/group_size, in] (bfloat16)
1505+
TENSOR_PARAM(3, biases); // Biases [out/group_size, in] (bfloat16)
1506+
PARAM(4, bool, transpose);
1507+
PARAM(5, int, group_size);
1508+
PARAM(6, int, bits);
1509+
DEVICE_PARAM(7, device);
1510+
1511+
TENSOR(mlx::core::quantized_matmul(
1512+
*x, *w, *scales, *biases, transpose, group_size, bits, "affine", device));
1513+
}
1514+
1515+
// dequantize - Converts quantized weights back to float
1516+
// Useful for debugging and verification
1517+
// MLX API: dequantize(w, scales, biases, group_size, bits, stream)
1518+
NIF(dequantize) {
1519+
TENSOR_PARAM(0, w); // Quantized weights (uint32 packed)
1520+
TENSOR_PARAM(1, scales); // Scales (bfloat16)
1521+
TENSOR_PARAM(2, biases); // Biases (bfloat16)
1522+
PARAM(3, int, group_size);
1523+
PARAM(4, int, bits);
1524+
DEVICE_PARAM(5, device);
1525+
1526+
TENSOR(mlx::core::dequantize(*w, *scales, *biases, group_size, bits, "affine", std::nullopt, std::nullopt, device));
1527+
}
1528+
1529+
// quantize - Quantizes a float tensor to packed format
1530+
// Returns tuple of {weights, scales, biases}
1531+
// MLX API: quantize(w, group_size, bits, stream) -> tuple<array, array, array>
1532+
NIF(quantize) {
1533+
TENSOR_PARAM(0, w); // Float weights to quantize
1534+
PARAM(1, int, group_size);
1535+
PARAM(2, int, bits);
1536+
DEVICE_PARAM(3, device);
1537+
1538+
try {
1539+
auto result = mlx::core::quantize(*w, group_size, bits, "affine", std::nullopt, device);
1540+
1541+
ERL_NIF_TERM result_tuple[3];
1542+
result_tuple[0] = create_tensor_resource(env, result[0]);
1543+
result_tuple[1] = create_tensor_resource(env, result[1]);
1544+
result_tuple[2] = create_tensor_resource(env, result[2]);
1545+
1546+
return nx::nif::ok(env, enif_make_tuple3(env, result_tuple[0], result_tuple[1], result_tuple[2]));
1547+
}
1548+
CATCH()
1549+
}
1550+
1551+
ASYNC_NIF(quantized_matmul)
1552+
ASYNC_NIF(dequantize)
1553+
ASYNC_NIF(quantize)
1554+
14941555
// Build a sliding window view of a padded tensor.
14951556
// padded: [...] of ndim n; window/strides: per-axis lists of length n.
14961557
// Returns a view of shape [o0,...,on-1, w0,...,wn-1] where
@@ -1997,7 +2058,11 @@ static ErlNifFunc nif_funcs[] = {
19972058

19982059
// ── Worker control NIFs.
19992060
{"command_queue_new", 1, command_queue_new},
2000-
{"command_queue_synchronize", 1, command_queue_synchronize}};
2061+
{"command_queue_synchronize", 1, command_queue_synchronize},
2062+
// Quantization operations (async — must run on a worker thread)
2063+
{"quantized_matmul", 9, quantized_matmul_async},
2064+
{"dequantize", 7, dequantize_async},
2065+
{"quantize", 5, quantize_async}};
20012066

20022067
ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)
20032068

lib/emlx.ex

Lines changed: 251 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,231 @@ defmodule EMLX do
323323
defvalue scalar_type(tensor)
324324
defvalue shape(tensor)
325325

326+
## Quantization operations (for 4-bit model support)
327+
328+
@doc """
329+
Performs quantized matrix multiplication.
330+
331+
This is the key operation for efficient 4-bit inference. It multiplies `x` with
332+
quantized weights `w` (packed as uint32), using scales and biases for
333+
dequantization during the computation.
334+
335+
## Parameters
336+
- `x` - Input tensor (e.g., {batch, seq, hidden})
337+
- `w` - Quantized weights as uint32 (8 int4 values packed per uint32)
338+
- `scales` - Per-group scale factors (bfloat16)
339+
- `biases` - Per-group zero points (bfloat16)
340+
- `transpose` - Whether to transpose weights (default: true)
341+
- `group_size` - Number of weights per scale/bias group (default: 64)
342+
- `bits` - Quantization bits (default: 4)
343+
"""
344+
@mlx_function {:quantized_matmul, 9}
345+
def quantized_matmul(
346+
{dev_x, ref_x} = _tensor_x,
347+
{dev_w, ref_w} = _tensor_w,
348+
{dev_s, ref_s} = _tensor_scales,
349+
{dev_b, ref_b} = _tensor_biases,
350+
transpose \\ true,
351+
group_size \\ 64,
352+
bits \\ 4
353+
)
354+
when is_tensor(dev_x, ref_x) and is_tensor(dev_w, ref_w) and
355+
is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
356+
device = merge_device(merge_device(dev_x, dev_w), merge_device(dev_s, dev_b))
357+
{worker, effective_device} = resolve_worker(device)
358+
359+
job_ref =
360+
EMLX.NIF.quantized_matmul(
361+
worker,
362+
ref_x,
363+
ref_w,
364+
ref_s,
365+
ref_b,
366+
transpose,
367+
group_size,
368+
bits,
369+
effective_device
370+
)
371+
|> unwrap!()
372+
373+
await_worker(job_ref) |> wrap_tensor(effective_device)
374+
end
375+
376+
@doc """
377+
Dequantizes packed weights to floating point.
378+
379+
Converts quantized weights back to their original floating point representation.
380+
Useful for debugging and verification.
381+
382+
## Parameters
383+
- `w` - Quantized weights as uint32 (packed int4 values)
384+
- `scales` - Per-group scale factors
385+
- `biases` - Per-group zero points
386+
- `group_size` - Number of weights per group (default: 64)
387+
- `bits` - Quantization bits (default: 4)
388+
"""
389+
@mlx_function {:dequantize, 7}
390+
def dequantize(
391+
{dev_w, ref_w} = _tensor_w,
392+
{dev_s, ref_s} = _tensor_scales,
393+
{dev_b, ref_b} = _tensor_biases,
394+
group_size,
395+
bits
396+
)
397+
when is_tensor(dev_w, ref_w) and is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
398+
device = merge_device(dev_w, merge_device(dev_s, dev_b))
399+
{worker, effective_device} = resolve_worker(device)
400+
401+
job_ref =
402+
EMLX.NIF.dequantize(worker, ref_w, ref_s, ref_b, group_size, bits, effective_device)
403+
|> unwrap!()
404+
405+
await_worker(job_ref) |> wrap_tensor(effective_device)
406+
end
407+
408+
@doc """
409+
Quantizes a floating point tensor to packed format.
410+
411+
Returns a tuple of `{quantized_weights, scales, biases}` where:
412+
- `quantized_weights` - Packed uint32 tensor (8 int4 values per uint32)
413+
- `scales` - Per-group scale factors
414+
- `biases` - Per-group zero points
415+
416+
## Parameters
417+
- `w` - Float tensor to quantize
418+
- `group_size` - Number of weights per group (default: 64)
419+
- `bits` - Quantization bits (default: 4)
420+
"""
421+
@mlx_function {:quantize, 5}
422+
def quantize({dev_w, ref_w}, group_size, bits)
423+
when is_tensor(dev_w, ref_w) do
424+
device = dev_w
425+
{worker, effective_device} = resolve_worker(device)
426+
427+
{weights_ref, scales_ref, biases_ref} =
428+
EMLX.NIF.quantize(worker, ref_w, group_size, bits, effective_device)
429+
|> unwrap!()
430+
|> await_worker()
431+
432+
{{effective_device, weights_ref}, {effective_device, scales_ref},
433+
{effective_device, biases_ref}}
434+
end
435+
436+
@doc """
437+
Quantize a dense 2-D `Nx.Tensor` and return an annotated quantized tensor.
438+
439+
The returned tensor carries the original logical shape and type (e.g.
440+
`{:s, 4}`). Its backend stores the packed uint32 data and a
441+
`EMLX.Quantization.Config` with scales, biases, `group_size`, and `bits`.
442+
443+
## Options
444+
445+
* `:type` — storage type: `{:s, 2}`, `{:s, 4}` (default), or `{:s, 8}`.
446+
* `:group_size` — 32, 64, or 128 (default 64). Must evenly divide the last
447+
dimension of `tensor`.
448+
"""
449+
@spec quantize(Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
450+
def quantize(%Nx.Tensor{} = tensor, opts) when is_list(opts) do
451+
type = Keyword.get(opts, :type, {:s, 4})
452+
{_, bits} = type
453+
group_size = Keyword.get(opts, :group_size, 64)
454+
455+
unless Nx.rank(tensor) == 2 do
456+
raise ArgumentError,
457+
"EMLX.quantize/2 requires a rank-2 tensor, got rank #{Nx.rank(tensor)}"
458+
end
459+
460+
{_out_features, in_features} = Nx.shape(tensor)
461+
462+
unless rem(in_features, group_size) == 0 do
463+
raise ArgumentError,
464+
"EMLX.quantize/2 requires the last dimension (#{in_features}) " <>
465+
"to be divisible by group_size (#{group_size})"
466+
end
467+
468+
device_ref = EMLX.Backend.from_nx(tensor)
469+
{weight_ref, scales_ref, biases_ref} = EMLX.quantize(device_ref, group_size, bits)
470+
471+
scales = EMLX.Backend.to_nx(scales_ref)
472+
biases = EMLX.Backend.to_nx(biases_ref)
473+
474+
config = %EMLX.Quantization.Config{
475+
scales: scales,
476+
biases: biases,
477+
group_size: group_size,
478+
bits: bits
479+
}
480+
481+
weight_shape = EMLX.shape(weight_ref)
482+
template = Nx.template(Nx.shape(tensor), type)
483+
484+
%Nx.Tensor{
485+
template
486+
| data: %EMLX.Backend{
487+
ref: weight_ref,
488+
shape: weight_shape,
489+
type: {:u, 32},
490+
quantization_config: config
491+
}
492+
}
493+
end
494+
495+
@doc """
496+
Dequantize a quantized `Nx.Tensor` (created by `EMLX.quantize/2`) to a
497+
dense float tensor by calling `mx::dequantize`.
498+
"""
499+
@spec dequantize(Nx.Tensor.t()) :: Nx.Tensor.t()
500+
def dequantize(
501+
%Nx.Tensor{
502+
data: %EMLX.Backend{ref: weight_ref, quantization_config: cfg}
503+
} = _qw
504+
)
505+
when not is_nil(cfg) do
506+
EMLX.dequantize(
507+
weight_ref,
508+
EMLX.Backend.from_nx(cfg.scales),
509+
EMLX.Backend.from_nx(cfg.biases),
510+
cfg.group_size,
511+
cfg.bits
512+
)
513+
|> EMLX.Backend.to_nx()
514+
end
515+
516+
@doc """
517+
Run `activation @ dequantize(qw)` using `mx::quantized_matmul`.
518+
519+
`qw` must be a quantized tensor produced by `EMLX.quantize/2`. Raises
520+
`ArgumentError` if both arguments are quantized.
521+
"""
522+
@spec quantized_matmul(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
523+
def quantized_matmul(%Nx.Tensor{} = activation, %Nx.Tensor{} = qw) do
524+
cfg = qw.data.quantization_config
525+
526+
if is_nil(cfg) do
527+
raise ArgumentError,
528+
"EMLX.quantized_matmul/2: second argument must be a quantized tensor"
529+
end
530+
531+
if not is_nil(activation.data.quantization_config) do
532+
raise ArgumentError,
533+
"EMLX.quantized_matmul/2 requires a dense activation as the first " <>
534+
"argument; got two quantized tensors. Dequantize one of them first."
535+
end
536+
537+
result =
538+
EMLX.quantized_matmul(
539+
EMLX.Backend.from_nx(activation),
540+
qw.data.ref,
541+
EMLX.Backend.from_nx(cfg.scales),
542+
EMLX.Backend.from_nx(cfg.biases),
543+
true,
544+
cfg.group_size,
545+
cfg.bits
546+
)
547+
548+
EMLX.Backend.to_nx(result)
549+
end
550+
326551
def to_blob({device, ref} = tensor) when is_tensor(device, ref) do
327552
# Eval first so the underlying MLX array is materialised; then ask the
328553
# worker for the contiguous-copy + zero-copy resource binary. Both
@@ -506,6 +731,18 @@ defmodule EMLX do
506731
deftensor squeeze(tensor, axes)
507732
defvalue strides(tensor)
508733

734+
@doc """
735+
Converts an EMLX device ref back to an Nx.Tensor.
736+
737+
## Example
738+
739+
result_ref = EMLX.some_operation(input)
740+
result_tensor = EMLX.to_nx(result_ref)
741+
"""
742+
def to_nx({device, ref} = device_ref) when is_atom(device) and is_reference(ref) do
743+
EMLX.Backend.to_nx(device_ref)
744+
end
745+
509746
@doc """
510747
Returns the scalar value of a 0-d tensor as a number.
511748
@@ -558,7 +795,7 @@ defmodule EMLX do
558795
@impl Nx.Defn.Compiler
559796
def __partitions_options__(opts) do
560797
n = Keyword.get(opts, :max_concurrency, 1)
561-
device = Keyword.get(opts, :device, :gpu)
798+
device = Keyword.get(opts, :device, default_device())
562799

563800
# Allocate one CommandQueue (and its OS thread) per partition. This runs
564801
# inside Nx.Serving's GenServer init/1 — queues are owned by module_state.
@@ -570,10 +807,22 @@ defmodule EMLX do
570807

571808
@impl Nx.Defn.Compiler
572809
def __to_backend__(opts) do
573-
device = Keyword.get(opts, :device, :gpu)
810+
device = Keyword.get(opts, :device, default_device())
574811
{EMLX.Backend, device: device}
575812
end
576813

814+
@doc """
815+
Returns the default MLX device for this process.
816+
817+
Reads `:default_device` from the `:emlx` application environment, falling
818+
back to `:gpu`. Override in tests or config via:
819+
820+
Application.put_env(:emlx, :default_device, :cpu)
821+
"""
822+
def default_device do
823+
Application.get_env(:emlx, :default_device, :gpu)
824+
end
825+
577826
# Splits opts into {emlx_compiler_opts, rest_opts}. The rest_opts are
578827
# forwarded to Nx.Defn.Evaluator; EMLX-specific keys are consumed here.
579828
defp split_compiler_opts(opts) do

0 commit comments

Comments
 (0)