|
| 1 | +# |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | +# |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union |
| 12 | + |
| 13 | +import torch |
| 14 | +from executorch.backends.mlx.builder.slot_manager import Slot |
| 15 | +from executorch.exir.scalar_type import ScalarType |
| 16 | +from torch.fx.node import Node |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder |
| 20 | + |
| 21 | + |
| 22 | +def get_aten_target(target): |
| 23 | + """ |
| 24 | + Unwrap EdgeOpOverload to get the underlying ATen op. |
| 25 | +
|
| 26 | + In Edge IR, ops are wrapped in EdgeOpOverload. This extracts the |
| 27 | + underlying ATen op for consistent comparison. |
| 28 | + """ |
| 29 | + if hasattr(target, "_op") and "EdgeOpOverload" in type(target).__name__: |
| 30 | + return target._op |
| 31 | + return target |
| 32 | + |
| 33 | + |
| 34 | +# Mapping from _copy variants to their non-copy equivalents. |
| 35 | +# Edge IR uses _copy variants for certain ops, but for pattern matching |
| 36 | +# we want to compare against the semantic operation. |
| 37 | +_COPY_TO_NON_COPY = { |
| 38 | + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, |
| 39 | + torch.ops.aten.transpose_copy.int: torch.ops.aten.transpose.int, |
| 40 | + torch.ops.aten.view_copy.default: torch.ops.aten.view.default, |
| 41 | + torch.ops.aten.permute_copy.default: torch.ops.aten.permute.default, |
| 42 | + torch.ops.aten.unsqueeze_copy.default: torch.ops.aten.unsqueeze.default, |
| 43 | + torch.ops.aten.squeeze_copy.dim: torch.ops.aten.squeeze.dim, |
| 44 | + torch.ops.aten.squeeze_copy.dims: torch.ops.aten.squeeze.dims, |
| 45 | + torch.ops.aten.squeeze_copy.default: torch.ops.aten.squeeze.default, |
| 46 | + torch.ops.aten.expand_copy.default: torch.ops.aten.expand.default, |
| 47 | + torch.ops.aten.alias_copy.default: torch.ops.aten.alias.default, |
| 48 | +} |
| 49 | + |
| 50 | + |
| 51 | +def get_aten_target_normalized(target): |
| 52 | + """ |
| 53 | + Get ATen target, mapping _copy variants to their non-copy equivalents. |
| 54 | +
|
| 55 | + Use this for pattern matching where Edge IR uses _copy variants but |
| 56 | + we want to match the semantic operation. |
| 57 | +
|
| 58 | + E.g., aten.transpose_copy.int -> aten.transpose.int |
| 59 | + """ |
| 60 | + target = get_aten_target(target) |
| 61 | + return _COPY_TO_NON_COPY.get(target, target) |
| 62 | + |
| 63 | + |
| 64 | +def emit_stop_position( |
| 65 | + P: "MLXProgramBuilder", |
| 66 | + start: "Union[int, Slot]", |
| 67 | + length_tensor: "Slot", |
| 68 | + length_dim: int, |
| 69 | + length_meta: "Optional[torch.Tensor]" = None, |
| 70 | +) -> "Union[int, Slot]": |
| 71 | + """ |
| 72 | + Emit nodes to compute stop = start + length for slice operations. |
| 73 | +
|
| 74 | + May emit SymSizeNode and/or AddIntNode depending on whether |
| 75 | + start and length are static or dynamic. |
| 76 | +
|
| 77 | + Args: |
| 78 | + P: The program builder |
| 79 | + start: Start position (int or Slot) |
| 80 | + length_tensor: The tensor slot whose dimension gives the length |
| 81 | + length_dim: Which dimension of length_tensor contains the length |
| 82 | + length_meta: Optional tensor metadata for static length extraction |
| 83 | +
|
| 84 | + Returns: |
| 85 | + stop position as int (if fully static) or Slot (if any dynamic) |
| 86 | + """ |
| 87 | + from executorch.backends.mlx.serialization.mlx_graph_schema import ( |
| 88 | + AddIntNode, |
| 89 | + IntOrVid, |
| 90 | + SymSizeNode, |
| 91 | + ) |
| 92 | + |
| 93 | + # Check if seq_len is symbolic (dynamic) |
| 94 | + seq_len_is_symbolic = False |
| 95 | + seq_len_concrete = None |
| 96 | + |
| 97 | + if length_meta is not None: |
| 98 | + seq_len_dim = length_meta.shape[length_dim] |
| 99 | + if hasattr(seq_len_dim, "node"): |
| 100 | + seq_len_is_symbolic = True |
| 101 | + else: |
| 102 | + seq_len_concrete = int(seq_len_dim) |
| 103 | + |
| 104 | + if seq_len_is_symbolic or length_meta is None: |
| 105 | + # Dynamic seq_len: emit SymSizeNode to get length at runtime |
| 106 | + _, seq_len_slot = P.slot_manager.make_tmp_value_slot() |
| 107 | + P.emit( |
| 108 | + SymSizeNode( |
| 109 | + a=P.slot_to_tid(length_tensor), |
| 110 | + dim=length_dim, |
| 111 | + out=P.slot_to_vid(seq_len_slot), |
| 112 | + ) |
| 113 | + ) |
| 114 | + _, stop_slot = P.slot_manager.make_tmp_value_slot() |
| 115 | + if isinstance(start, Slot): |
| 116 | + start_iov = P.to_int_or_vid(start) |
| 117 | + else: |
| 118 | + start_iov = IntOrVid.from_literal(int(start)) |
| 119 | + P.emit( |
| 120 | + AddIntNode( |
| 121 | + a=start_iov, |
| 122 | + b=IntOrVid.from_vid(P.slot_to_vid(seq_len_slot)), |
| 123 | + out=P.slot_to_vid(stop_slot), |
| 124 | + ) |
| 125 | + ) |
| 126 | + return stop_slot |
| 127 | + else: |
| 128 | + # Static seq_len |
| 129 | + if isinstance(start, Slot): |
| 130 | + # Dynamic start + static length |
| 131 | + _, stop_slot = P.slot_manager.make_tmp_value_slot() |
| 132 | + P.emit( |
| 133 | + AddIntNode( |
| 134 | + a=P.to_int_or_vid(start), |
| 135 | + b=IntOrVid.from_literal(seq_len_concrete), |
| 136 | + out=P.slot_to_vid(stop_slot), |
| 137 | + ) |
| 138 | + ) |
| 139 | + return stop_slot |
| 140 | + else: |
| 141 | + # Both static - just return the sum |
| 142 | + return start + seq_len_concrete |
| 143 | + |
| 144 | + |
| 145 | +def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> Slot: |
| 146 | + """Lift a scalar to a 0-D tensor. |
| 147 | +
|
| 148 | + Concrete scalars (int/float/bool) become deduplicated constants. |
| 149 | + Dynamic values (SymInt Slots) emit a FullNode at runtime. |
| 150 | + """ |
| 151 | + |
| 152 | + if isinstance(value, (int, float, bool)): |
| 153 | + return P.make_or_get_constant( |
| 154 | + f"_scalar_{value}", torch.tensor(value, dtype=dtype) # 0-D |
| 155 | + ) |
| 156 | + |
| 157 | + from executorch.backends.mlx.serialization.mlx_graph_schema import FullNode |
| 158 | + |
| 159 | + _, slot = P.make_tmp_slot() |
| 160 | + P.emit( |
| 161 | + FullNode( |
| 162 | + shape=[], |
| 163 | + v=P.to_float_or_vid(value), |
| 164 | + scalar_type=torch_dtype_to_scalar_type(dtype), |
| 165 | + out=P.slot_to_tid(slot), |
| 166 | + ) |
| 167 | + ) |
| 168 | + return slot |
| 169 | + |
| 170 | + |
| 171 | +def to_mlx_qparams( |
| 172 | + qdata: torch.Tensor, |
| 173 | + scale: torch.Tensor, |
| 174 | + zero_point: torch.Tensor, |
| 175 | + bits: int, |
| 176 | + compute_biases: bool = True, |
| 177 | +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
| 178 | + """ |
| 179 | + Convert TorchAO quantization params to MLX format. |
| 180 | +
|
| 181 | + TorchAO uses: s * (q - z), with q signed |
| 182 | + MLX uses: S * Q + B, with Q unsigned |
| 183 | +
|
| 184 | + s * (q - z) |
| 185 | + = s ((q + offset) - (z + offset)) |
| 186 | + = s Q + B, |
| 187 | + where Q = q + offset, B = -s * (z + offset) |
| 188 | +
|
| 189 | + Args: |
| 190 | + compute_biases: If False, skip bias computation (for scale_only mode). |
| 191 | + Returns (Q, None) in this case. This is valid when |
| 192 | + zero_point is all zeros, as the C++ runtime will compute |
| 193 | + biases = -scales * 2^(bits-1). |
| 194 | + """ |
| 195 | + assert qdata.dtype == torch.int8 |
| 196 | + offset = 2 ** (bits - 1) |
| 197 | + Q = qdata.to(torch.int32) + offset |
| 198 | + |
| 199 | + # Pack data tightly into uint32 |
| 200 | + assert 32 % bits == 0 |
| 201 | + vals_per_uint32 = 32 // bits |
| 202 | + assert qdata.shape[1] % vals_per_uint32 == 0 |
| 203 | + |
| 204 | + Q = Q.reshape(-1, vals_per_uint32) |
| 205 | + shifts = torch.arange(0, 32, bits, dtype=torch.int64) |
| 206 | + |
| 207 | + # Convert to int64 for shift/packing |
| 208 | + Q = Q.to(torch.int64) |
| 209 | + Q = (Q << shifts).sum(dim=-1) |
| 210 | + Q = Q.to(torch.uint32) |
| 211 | + Q = Q.reshape(qdata.shape[0], -1) |
| 212 | + |
| 213 | + if compute_biases: |
| 214 | + B = -scale * (zero_point.to(scale.dtype) + offset) |
| 215 | + return Q, B |
| 216 | + else: |
| 217 | + return Q, None |
| 218 | + |
| 219 | + |
| 220 | +def parse_dequant_node( |
| 221 | + node: Node, |
| 222 | +) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: |
| 223 | + """Parse a torchao.dequantize_affine node. |
| 224 | +
|
| 225 | + Accepts N-dimensional block_size with a single non-1 element identifying |
| 226 | + the quantized dimension and group_size. For example: |
| 227 | + - Linear weights (2D): block_size=[1, 32] → quantized_dim=1 |
| 228 | + - Conv2d weights (4D): block_size=[1, 32, 1, 1] → quantized_dim=1 |
| 229 | +
|
| 230 | + Returns (qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim) |
| 231 | + or None if unsupported. |
| 232 | + """ |
| 233 | + qdata, block_size, scale, zero_point, dtype, qmin, qmax = node.args[0:7] |
| 234 | + out_dtype = ( |
| 235 | + node.args[7] if len(node.args) > 7 else node.kwargs.get("output_dtype", None) |
| 236 | + ) |
| 237 | + if dtype != torch.int8: |
| 238 | + return None |
| 239 | + if len(block_size) < 2: |
| 240 | + return None |
| 241 | + non_one = [(i, d) for i, d in enumerate(block_size) if d != 1] |
| 242 | + if len(non_one) != 1: |
| 243 | + return None |
| 244 | + quantized_dim, group_size = non_one[0] |
| 245 | + if group_size not in [32, 64, 128]: |
| 246 | + return None |
| 247 | + if qmin == -8 and qmax == 7: |
| 248 | + bits = 4 |
| 249 | + elif qmin == -128 and qmax == 127: |
| 250 | + bits = 8 |
| 251 | + else: |
| 252 | + return None |
| 253 | + return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim |
| 254 | + |
| 255 | + |
| 256 | +# Mapping from torch dtype to ET ScalarType int value |
| 257 | +# See executorch/exir/scalar_type.py for ScalarType enum |
| 258 | +_TORCH_DTYPE_TO_SCALAR_TYPE: Dict[torch.dtype, int] = { |
| 259 | + torch.float16: ScalarType.HALF, |
| 260 | + torch.float32: ScalarType.FLOAT, |
| 261 | + torch.bfloat16: ScalarType.BFLOAT16, |
| 262 | + torch.int32: ScalarType.INT, |
| 263 | + torch.int64: ScalarType.LONG, |
| 264 | + torch.uint32: ScalarType.UINT32, |
| 265 | + torch.uint8: ScalarType.BYTE, |
| 266 | + torch.bool: ScalarType.BOOL, |
| 267 | + torch.int8: ScalarType.CHAR, |
| 268 | +} |
| 269 | + |
| 270 | + |
| 271 | +def torch_dtype_to_scalar_type(dtype: torch.dtype) -> int: |
| 272 | + """Convert torch dtype to ET ScalarType int value.""" |
| 273 | + if dtype not in _TORCH_DTYPE_TO_SCALAR_TYPE: |
| 274 | + raise ValueError(f"Unsupported dtype: {dtype}") |
| 275 | + return int(_TORCH_DTYPE_TO_SCALAR_TYPE[dtype]) |
0 commit comments