
Commit 3dc9bfd ("up"), 1 parent 0b25002

7 files changed: 648 additions & 103 deletions

backends/mlx/builder/op_helpers.py

Lines changed: 77 additions & 0 deletions
@@ -18,6 +18,11 @@
 if TYPE_CHECKING:
     from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder

+# When True, always serialize the biases tensor for quantized ops.
+# When False, use init-time computation when zero_point is all zeros,
+# computing biases = -scales * 2^(bits-1) during the init chain.
+QUANTIZED_SERIALIZE_BIASES = False
+

 def get_aten_target(target):
     """
@@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> Slot:
     return slot


+def emit_quantized_biases(
+    P: "MLXProgramBuilder",
+    zero_point_key: str,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    bits: int,
+    B: torch.Tensor,
+    scale_slot: "Slot",
+) -> "Slot":
+    """Emit biases for quantized ops, computing them at init time when possible.
+
+    When zero_point is all zeros and QUANTIZED_SERIALIZE_BIASES is False,
+    avoids serializing the biases tensor by computing biases = scales * -offset
+    during the init chain instead.
+
+    Returns the biases Slot.
+    """
+    from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode
+    from torch._subclasses.fake_tensor import FakeTensor
+
+    # A FakeTensor carries no real data, so zero_point cannot be inspected;
+    # fall back to serializing the biases in that case.
+    is_scale_only = False
+    if not isinstance(zero_point, FakeTensor):
+        if torch.sum(torch.abs(zero_point)).item() == 0:
+            is_scale_only = True
+
+    if QUANTIZED_SERIALIZE_BIASES or not is_scale_only:
+        return P.make_or_get_constant(f"{zero_point_key}_to_biases", B)
+
+    scale_dtype = scale.dtype
+    offset = 1 << (bits - 1)
+    neg_offset = emit_lifted_constant(P, -offset, scale_dtype)
+    # Placeholder constant; the init chain below overwrites its contents
+    # with scales * -offset.
+    biases = P.make_or_get_constant(
+        f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype)
+    )
+    P.emit_init(
+        MultiplyNode(
+            a=P.slot_to_tid(scale_slot),
+            b=P.slot_to_tid(neg_offset),
+            out=P.slot_to_tid(biases),
+        )
+    )
+    return biases
+
+
 def to_mlx_qparams(
     qdata: torch.Tensor,
     scale: torch.Tensor,
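
A note on the arithmetic: for symmetric quantization (zero_point all zeros), dequantization is w = scales * (q - 2^(bits-1)), which folds into w = scales * q + biases with biases = -scales * 2^(bits-1). That product is exactly what the MultiplyNode above computes in the init chain. A standalone sketch of the identity in plain PyTorch (shapes here are illustrative, not taken from the backend):

# Check of the identity behind emit_quantized_biases: with zero_point == 0,
# scales * (q - offset) == scales * q + biases, where offset = 2^(bits - 1)
# and biases = -scales * offset.
import torch

bits = 4
offset = 1 << (bits - 1)                          # 8 for 4-bit codes
scales = torch.rand(16, 1) * 0.1                  # per-group scales
q = torch.randint(0, 2**bits, (16, 32)).float()   # unsigned 4-bit codes

reference = scales * (q - offset)                 # explicit zero-point form
biases = scales * -offset                         # what the init chain materializes
folded = scales * q + biases                      # form the quantized kernel uses

assert torch.allclose(reference, folded)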
@@ -217,6 +266,34 @@ def to_mlx_qparams(
     return Q, None


+def parse_dequant_nvfp4_node(
+    node: Node,
+) -> Optional[Tuple[Node, Node, Node, torch.dtype]]:
+    """Parse a torchao.dequantize_nvfp4 node.
+
+    Returns (qdata, scale, per_tensor_scale, output_dtype) or None if this is
+    not a dequantize_nvfp4 node or the custom op is not registered.
+    """
+    target = get_aten_target(node.target)
+    try:
+        # Importing this module registers the custom op as a side effect.
+        import executorch.extension.llm.export.nvfp4  # noqa: F401
+    except ImportError:
+        return None
+
+    if target is not torch.ops.torchao.dequantize_nvfp4.default:
+        return None
+
+    qdata, scale, per_tensor_scale = node.args[0:3]
+
+    output_dtype = torch.float32
+    if len(node.args) > 4:
+        output_dtype = node.args[4]
+    elif "output_dtype" in node.kwargs:
+        output_dtype = node.kwargs["output_dtype"]
+
+    return qdata, scale, per_tensor_scale, output_dtype
+
+
 def parse_dequant_node(
     node: Node,
 ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]:
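
For context on the three tensors the parser extracts: in the NVFP4 format, qdata holds 4-bit e2m1 codes, scale holds one FP8 (e4m3) scale per 16-element block, and per_tensor_scale is a single FP32 scale applied globally. A rough sketch of the corresponding dequantization semantics, assuming inputs already decoded to float (this illustrates the format, not the torchao.dequantize_nvfp4 kernel itself):

# Illustrative NVFP4 dequantization, assuming codes and block scales are
# already decoded to float. A sketch of the format's semantics only; the
# real torchao kernel operates on packed 4-bit and fp8 storage.
import torch

def dequantize_nvfp4_sketch(
    codes: torch.Tensor,             # decoded fp4 (e2m1) values, shape [N, K]
    block_scale: torch.Tensor,       # decoded fp8 block scales, shape [N, K // 16]
    per_tensor_scale: torch.Tensor,  # scalar fp32 tensor
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Expand each block scale across its 16-element block, then rescale.
    scales = block_scale.repeat_interleave(16, dim=-1)
    return (codes * scales * per_tensor_scale).to(output_dtype)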

backends/mlx/ops.py

Lines changed: 4 additions & 1 deletion
@@ -21,6 +21,7 @@
 import torch
 from executorch.backends.mlx.builder.op_helpers import (
     emit_lifted_constant,
+    emit_quantized_biases,
     parse_dequant_node,
     to_mlx_qparams,
     torch_dtype_to_scalar_type,
@@ -3646,8 +3647,10 @@ def _dequantize_affine_handler(P: MLXProgramBuilder, n: Node) -> Slot:
     B = B.reshape(*leading_dims, B.shape[-1])

     w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q)
-    biases = P.make_or_get_constant(f"{zero_point_target}_to_biases", B)
     scale_const = P.make_or_get_constant(f"{scale_target}_scale", scale_nd)
+    biases = emit_quantized_biases(
+        P, zero_point_target, scale, zero_point, bits, B, scale_const
+    )

     if needs_permute:
         _, dequant_tmp = P.make_tmp_slot()

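Note the reordering at the call site: scale_const is now created before the biases so that its slot can be passed into emit_quantized_biases, which reuses it as one operand of the init-chain multiply when the biases are computed rather than serialized.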