1818if TYPE_CHECKING :
1919 from executorch .backends .mlx .builder .program_builder import MLXProgramBuilder
2020
21+ # When True, always serialize the biases tensor for quantized ops.
22+ # When False, use init-time computation when zero_point is all zeros,
23+ # computing biases = -scales * 2^(bits-1) during the init chain.
24+ QUANTIZED_SERIALIZE_BIASES = False
25+
2126
2227def get_aten_target (target ):
2328 """
@@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S
168173 return slot
169174
170175
def emit_quantized_biases(
    P: "MLXProgramBuilder",
    zero_point_key: str,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    bits: int,
    B: torch.Tensor,
    scale_slot: "Slot",
) -> "Slot":
    """Emit biases for quantized ops, computing at init time when possible.

    When zero_point is all zeros and QUANTIZED_SERIALIZE_BIASES is False,
    avoids serializing the biases tensor by computing biases = scales * -offset
    during the init chain instead, where offset = 2**(bits - 1).

    Args:
        P: Program builder used to register constants and init-chain nodes.
        zero_point_key: Base key used to derive the names of the emitted
            constants (``{key}_to_biases`` / ``{key}_to_biases_dummy``).
        scale: Per-group scales tensor; its dtype is reused for the lifted
            negative-offset scalar.
        zero_point: Zero-point tensor. A FakeTensor carries no real data, so
            its values cannot be inspected and the serialized path is taken.
        bits: Quantization bit width.
        B: Precomputed biases tensor, serialized on the fallback path.
        scale_slot: Slot already holding the scales; multiply input ``a``.

    Returns the biases Slot.
    """
    from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode
    from torch._subclasses.fake_tensor import FakeTensor

    # A FakeTensor has no concrete values to sum, so only real tensors can be
    # proven all-zero; FakeTensors fall through to the serialized path.
    is_scale_only = False
    if not isinstance(zero_point, FakeTensor):
        if torch.sum(torch.abs(zero_point)).item() == 0:
            is_scale_only = True

    if QUANTIZED_SERIALIZE_BIASES or not is_scale_only:
        # Fallback: serialize the precomputed biases tensor as-is.
        return P.make_or_get_constant(f"{zero_point_key}_to_biases", B)

    # Init-time path: biases = scales * -(2**(bits - 1)).
    scale_dtype = scale.dtype
    offset = 1 << (bits - 1)
    neg_offset = emit_lifted_constant(P, -offset, scale_dtype)
    # NOTE(review): the "_dummy" scalar appears to exist only to reserve a
    # named slot/tensor id for the multiply's output — confirm that the
    # runtime overwrites (rather than keeps) this 0.0 placeholder value.
    biases = P.make_or_get_constant(
        f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype)
    )
    # Emit the multiply onto the init chain after the neg_offset constant so
    # both inputs exist by the time the node runs.
    P.emit_init(
        MultiplyNode(
            a=P.slot_to_tid(scale_slot),
            b=P.slot_to_tid(neg_offset),
            out=P.slot_to_tid(biases),
        )
    )
    return biases
218+
219+
171220def to_mlx_qparams (
172221 qdata : torch .Tensor ,
173222 scale : torch .Tensor ,
@@ -194,21 +243,36 @@ def to_mlx_qparams(
194243 """
195244 assert qdata .dtype == torch .int8
196245 offset = 2 ** (bits - 1 )
197- Q = qdata .to (torch .int32 ) + offset
198246
199247 # Pack data tightly into uint32
200248 assert 32 % bits == 0
201249 vals_per_uint32 = 32 // bits
202250 assert qdata .shape [1 ] % vals_per_uint32 == 0
203-
204- Q = Q .reshape (- 1 , vals_per_uint32 )
205- shifts = torch .arange (0 , 32 , bits , dtype = torch .int64 )
206-
207- # Convert to int64 for shift/packing
208- Q = Q .to (torch .int64 )
209- Q = (Q << shifts ).sum (dim = - 1 )
210- Q = Q .to (torch .uint32 )
211- Q = Q .reshape (qdata .shape [0 ], - 1 )
251+ rows , cols = qdata .shape
252+
253+ if bits == 4 :
254+ # 4-bit: view(uint8) + wrapping add + pack 2 nibbles per byte → view as uint32
255+ q = qdata .view (torch .uint8 ) + offset
256+ q3 = q .reshape (rows , cols // 2 , 2 )
257+ Q = (q3 [:, :, 0 ] | (q3 [:, :, 1 ] << 4 )).view (torch .uint32 )
258+ elif bits == 2 :
259+ # 2-bit: pack 4×2-bit values per byte in uint8, then view as uint32
260+ Q = ((qdata .view (torch .uint8 ) + offset ) & 0x3 ).reshape (rows , cols // 4 , 4 )
261+ packed = Q [:, :, 0 ] | (Q [:, :, 1 ] << 2 ) | (Q [:, :, 2 ] << 4 ) | (Q [:, :, 3 ] << 6 )
262+ Q = packed .contiguous ().view (torch .uint32 )
263+ elif bits == 8 :
264+ # 8-bit: each byte maps 1:1 to a uint32 slot — no shifting needed
265+ q = qdata .view (torch .uint8 ) + offset
266+ Q = q .contiguous ().view (torch .uint32 ).reshape (rows , - 1 )
267+ else :
268+ # General fallback for other bit widths
269+ Q = (qdata .to (torch .int32 ) + offset ).reshape (- 1 , vals_per_uint32 )
270+ shifts = torch .arange (0 , 32 , bits , dtype = torch .int32 )
271+ shifted = Q << shifts
272+ packed = shifted [:, 0 ]
273+ for i in range (1 , vals_per_uint32 ):
274+ packed = packed | shifted [:, i ]
275+ Q = packed .view (torch .uint32 ).reshape (rows , - 1 )
212276
213277 if compute_biases :
214278 B = - scale * (zero_point .to (scale .dtype ) + offset )
@@ -217,6 +281,34 @@ def to_mlx_qparams(
217281 return Q , None
218282
219283
def parse_dequant_nvfp4_node(
    node: Node,
) -> Optional[Tuple[Node, Node, Node, torch.dtype]]:
    """Parse a torchao.dequantize_nvfp4 node.

    Returns (qdata, scale, per_tensor_scale, output_dtype) or None if not a
    dequantize_nvfp4 node or the custom op is not registered.
    """
    target = get_aten_target(node.target)

    # Importing this module registers the custom op; without it,
    # torch.ops.torchao.dequantize_nvfp4 would not even resolve.
    try:
        import executorch.extension.llm.export.nvfp4  # noqa: F401
    except ImportError:
        return None

    if target is not torch.ops.torchao.dequantize_nvfp4.default:
        return None

    qdata, scale, per_tensor_scale = node.args[:3]

    # output_dtype may arrive positionally (index 4) or as a keyword;
    # default to float32 when absent.
    if len(node.args) > 4:
        output_dtype = node.args[4]
    else:
        output_dtype = node.kwargs.get("output_dtype", torch.float32)

    return qdata, scale, per_tensor_scale, output_dtype
310+
311+
220312def parse_dequant_node (
221313 node : Node ,
222314) -> Optional [Tuple [Node , Node , Node , int , int , Optional [torch .dtype ], int ]]:
@@ -244,11 +336,11 @@ def parse_dequant_node(
244336 quantized_dim , group_size = non_one [0 ]
245337 if group_size not in [32 , 64 , 128 ]:
246338 return None
247- if qmin == - 8 and qmax == 7 :
248- bits = 4
249- elif qmin == - 128 and qmax == 127 :
250- bits = 8
251- else :
339+
340+ # TODO: MLX supports 3, 5, and 7, but we need to figure out the
341+ # packing story in to_mlx_qparams to use them
342+ bits = ( qmax - qmin + 1 ). bit_length () - 1
343+ if bits not in [ 2 , 4 , 8 ] :
252344 return None
253345 return qdata , scale , zero_point , group_size , bits , out_dtype , quantized_dim
254346
0 commit comments