18 changes: 14 additions & 4 deletions backends/cuda/quantize_op_dispatch/int8_dispatch.py
@@ -18,8 +18,13 @@
dequant + cuBLAS matmul kernels.

Dispatch strategy (determines what gets captured in the export graph):
-Decode (M<=4): Custom op ``executorch_cuda::int8_plain_mm``
-Prefill (M>4): Inline dequant + F.linear (standard PyTorch ops)
+Small M (M<=MATVEC_MAX_M): Custom op ``executorch_cuda::int8_plain_mm``
+Large M (M>MATVEC_MAX_M): Inline dequant + F.linear (standard PyTorch ops)
+
+``MATVEC_MAX_M`` defaults to 4 (decode). The W8A8 dp4a matvec launches one
+block-row per M, so it handles any small M; an export may raise the threshold
+locally (e.g. EAGLE-3 verify at M=chain_len+1) so a dynamic M window does not
+straddle it and force a data-dependent branch.

Keeping INT8 on the same fused dp4a path lets mixed-precision recipes (e.g.
INT8 edge-layer v_proj/down_proj + INT4 elsewhere) keep ALL decode linears on a
@@ -49,9 +54,14 @@
)

# ---------------------------------------------------------------------------
-# Custom op for INT8 decode (M<=4): W8A8 dp4a matvec in C shim.
+# Custom op for INT8 small-M (M<=MATVEC_MAX_M): W8A8 dp4a matvec in C shim.
# ---------------------------------------------------------------------------

+# Max M routed to the custom INT8 op; above this, dequant+cuBLAS wins. Defaults
+# to 4 (decode); an export may raise it for a small dynamic M window (e.g. the
+# EAGLE-3 verify window) so the range stays on one dispatch branch.
+MATVEC_MAX_M = 4
+
_lib.define(
"int8_plain_mm(Tensor self, Tensor qdata, Tensor scale, Tensor zero, int group_size) -> Tensor"
)
@@ -121,7 +131,7 @@ def _(func, types, args, kwargs):
gs = weight_tensor.block_size[-1]

M = x_2d.shape[0]
-if M <= 4:
+if M <= MATVEC_MAX_M:
    out = torch.ops.executorch_cuda.int8_plain_mm(x_2d, qdata, scale, zero, gs)
else:
    out = _dequant_matmul_int8(x_2d, qdata, scale, zero, gs)
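
The threshold logic in this file can be sketched in plain Python to show why a dynamic M window that crosses `MATVEC_MAX_M` is a problem. This is a minimal illustration, not the dispatch code itself: `dispatch_path`, `window_straddles`, and the value `chain_len = 7` are hypothetical names and numbers chosen for the example; only the default of 4 comes from the diff.

```python
# Default threshold taken from the diff; everything else here is illustrative.
MATVEC_MAX_M = 4


def dispatch_path(m, threshold=MATVEC_MAX_M):
    """Name the branch that an input with m rows would take."""
    return "int8_plain_mm" if m <= threshold else "dequant_linear"


def window_straddles(m_min, m_max, threshold=MATVEC_MAX_M):
    """True when the dynamic range [m_min, m_max] covers both branches."""
    return m_min <= threshold < m_max


# Hypothetical verify window: chain_len = 7, so M ranges over 1..8.
chain_len = 7
m_max = chain_len + 1

# With the default threshold the window crosses the boundary.
print(window_straddles(1, m_max))
# Raising the threshold to the top of the window keeps it on one branch.
print(window_straddles(1, m_max, threshold=m_max))
```

Under these assumptions the first check reports a straddle and the second does not, which is the situation the docstring describes: raising the threshold for that export keeps every M in the window on the custom-op path.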
6 changes: 5 additions & 1 deletion examples/models/eagle3/draft.py
@@ -49,7 +49,11 @@ class Eagle3Config:


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(2, dim=-1)
+    # Slice rather than chunk: chunk lowers to aten::split_copy, which the AOTI
+    # CUDA backend has no fallback kernel for.
+    half = x.shape[-1] // 2
+    x1 = x[..., :half]
+    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
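
As a rough check of what the slicing form computes, the same rotation can be written over a plain Python list standing in for the last dimension of the tensor. This is a sketch for illustration only; `rotate_half` below is a hypothetical list-based analogue, not the function in `draft.py`.

```python
def rotate_half(x):
    """List analogue of the sliced version: split in half, return (-x2, x1)."""
    half = len(x) // 2
    x1 = x[:half]   # first half, like x[..., :half]
    x2 = x[half:]   # second half, like x[..., half:]
    return [-v for v in x2] + x1


print(rotate_half([1.0, 2.0, 3.0, 4.0]))
```

For an even last dimension the two slices have the same sizes that `chunk(2, dim=-1)` would produce. For an odd size they would differ, since `chunk` puts the larger piece first, but rotary head dimensions are normally even.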

