|
18 | 18 | if TYPE_CHECKING: |
19 | 19 | from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder |
20 | 20 |
|
# Controls how biases for quantized ops are materialized.
# When True, the biases tensor is always serialized into the program.
# When False, and the zero_point tensor is all zeros, serialization is skipped
# and biases = -scales * 2^(bits-1) is computed during the init chain instead
# (see emit_quantized_biases).
QUANTIZED_SERIALIZE_BIASES = False
| 25 | + |
21 | 26 |
|
22 | 27 | def get_aten_target(target): |
23 | 28 | """ |
@@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S |
168 | 173 | return slot |
169 | 174 |
|
170 | 175 |
|
def emit_quantized_biases(
    P: "MLXProgramBuilder",
    zero_point_key: str,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    bits: int,
    B: torch.Tensor,
    scale_slot: "Slot",
) -> "Slot":
    """Emit the biases tensor for a quantized op.

    When ``zero_point`` is all zeros and ``QUANTIZED_SERIALIZE_BIASES`` is
    False, the biases tensor is not serialized; instead a scalar placeholder
    constant is emitted and ``biases = scales * -(2 ** (bits - 1))`` is
    computed during the init chain via a ``MultiplyNode``.

    Args:
        P: Program builder used to emit constants and init-chain nodes.
        zero_point_key: Stable key used to name the emitted constant.
        scale: Scales tensor; its dtype is used for the lifted offset constant.
        zero_point: Zero-point tensor, inspected to decide whether biases can
            be derived from scales alone. A FakeTensor carries no real values,
            so it always forces serialization of ``B``.
        bits: Quantization bit width (e.g. 4 -> offset 8).
        B: Precomputed biases tensor, serialized when init-time computation is
            not possible.
        scale_slot: Slot holding the already-emitted scales.

    Returns:
        The Slot holding the biases.
    """
    from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode
    from torch._subclasses.fake_tensor import FakeTensor

    # zero_point is "scale only" when every entry is zero; a FakeTensor cannot
    # be inspected, so treat it as not scale-only and serialize B.
    is_scale_only = (
        not isinstance(zero_point, FakeTensor)
        and torch.count_nonzero(zero_point).item() == 0
    )

    if QUANTIZED_SERIALIZE_BIASES or not is_scale_only:
        return P.make_or_get_constant(f"{zero_point_key}_to_biases", B)

    # biases = scales * -(2 ** (bits - 1)): lift the scalar factor once, then
    # multiply in the init chain, writing over the dummy placeholder slot.
    neg_offset = emit_lifted_constant(P, -(1 << (bits - 1)), scale.dtype)
    biases = P.make_or_get_constant(
        f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype)
    )
    P.emit_init(
        MultiplyNode(
            a=P.slot_to_tid(scale_slot),
            b=P.slot_to_tid(neg_offset),
            out=P.slot_to_tid(biases),
        )
    )
    return biases
| 218 | + |
| 219 | + |
171 | 220 | def to_mlx_qparams( |
172 | 221 | qdata: torch.Tensor, |
173 | 222 | scale: torch.Tensor, |
@@ -217,6 +266,34 @@ def to_mlx_qparams( |
217 | 266 | return Q, None |
218 | 267 |
|
219 | 268 |
|
def parse_dequant_nvfp4_node(
    node: Node,
) -> Optional[Tuple[Node, Node, Node, torch.dtype]]:
    """Parse a torchao.dequantize_nvfp4 node.

    Returns (qdata, scale, per_tensor_scale, output_dtype) or None if not a
    dequantize_nvfp4 node or the custom op is not registered.
    """
    target = get_aten_target(node.target)

    # The custom op only exists once the nvfp4 extension module imports.
    try:
        import executorch.extension.llm.export.nvfp4  # noqa: F401
    except ImportError:
        return None

    if target is not torch.ops.torchao.dequantize_nvfp4.default:
        return None

    qdata, scale, per_tensor_scale = node.args[:3]

    # Positional arg 4 takes precedence over the keyword; default float32.
    if len(node.args) > 4:
        output_dtype = node.args[4]
    else:
        output_dtype = node.kwargs.get("output_dtype", torch.float32)

    return qdata, scale, per_tensor_scale, output_dtype
| 295 | + |
| 296 | + |
220 | 297 | def parse_dequant_node( |
221 | 298 | node: Node, |
222 | 299 | ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: |
|
0 commit comments