|
18 | 18 | from executorch.exir.dialects._ops import ops as exir_ops |
19 | 19 |
|
20 | 20 |
|
# Edge-dialect operator targets that the pattern detectors in this file compare
# node.target against.

# Quantized-decomposed 4-bit embedding op.
embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype

# Plain aten embedding op; anchor node of the torchao quantized embedding
# pattern.
embedding_target = exir_ops.edge.aten.embedding.default

# torchao affine dequantize op; expected producer of the embedding weight in
# the torchao quantized embedding pattern.
torchao_dequantize_affine_target = exir_ops.edge.torchao.dequantize_affine.default
| 24 | + |
| 25 | + |
21 | 26 | class QuantizedEmbeddingMatch(PatternMatch): |
22 | 27 | def __init__(self, node: torch.fx.Node) -> None: |
23 | 28 | self.anchor_node = node |
@@ -68,9 +73,6 @@ def __init__(self, node: torch.fx.Node) -> None: |
68 | 73 | self.match_found = True |
69 | 74 |
|
70 | 75 |
|
71 | | -embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype |
72 | | - |
73 | | - |
74 | 76 | def _detect_tied_linear_weight( |
75 | 77 | ep: ExportedProgram, |
76 | 78 | weight_node: torch.fx.Node, |
@@ -175,3 +177,163 @@ def replace_quantized_embedding_patterns( |
175 | 177 |
|
176 | 178 | embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"] |
177 | 179 | match.anchor_node.replace_all_uses_with(embedding_q4gsw_node) |
| 180 | + |
| 181 | + |
class TorchAOQuantizedEmbeddingMatch(PatternMatch):
    """Pattern match for a torchao 4-bit weight-only quantized embedding.

    The match is anchored on an aten.embedding node whose weight operand is
    produced by torchao.dequantize_affine. A successful match is later
    rewritten as a single et_vk.embedding_q4gsw.default node.

    A candidate is accepted only when the dequantize node declares
    quant_min=-8 and quant_max=7 and uses a per-row groupwise block_size of
    [1, G]. The weight is expected to be unpacked int8 [vocab, embed_dim] with
    values in [-8, 7] and a zero point of 0, because the runtime shader assumes
    a fixed subtract-8 offset. The zero point is NOT checked here; it is
    verified on the real tensor by the replacement function.

    Attributes are assigned progressively, so those past ``all_nodes`` are
    only guaranteed to exist when ``match_found`` is True:
        anchor_node: the aten.embedding node.
        match_found: True only when every structural check passed.
        all_nodes: anchor, dequantize node and the traced argument chains.
        indices_node: second positional argument of the embedding node.
        group_size: block_size[1] converted to int.
        weight_node / scales_node: placeholders traced from the dequantize
            node's args[0] / args[2].
        zero_point_node: result of tracing the dequantize node's args[3];
            may be None.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        self.anchor_node = node
        self.match_found = False
        self.all_nodes = [node]

        # aten.embedding.default positional args: (weight, indices, *)
        producer = node.args[0]
        self.indices_node = node.args[1]

        # The embedding weight must come straight out of a torchao dequantize.
        if (
            not isinstance(producer, torch.fx.Node)
            or producer.target != torchao_dequantize_affine_target
        ):
            return

        self.all_nodes.append(producer)

        # torchao.dequantize_affine positional args:
        #   (input, block_size, scale, zero_point, input_dtype, quant_min,
        #    quant_max, ...)
        dq_args = producer.args
        arg_count = len(dq_args)
        block_size = dq_args[1]
        quant_min = dq_args[5] if arg_count > 5 else None
        quant_max = dq_args[6] if arg_count > 6 else None

        # The shader hardcodes the signed 4-bit offset (subtract 8), i.e.
        # quant_min=-8, quant_max=7, zero_point=0.
        if quant_min != -8 or quant_max != 7:
            return

        # Only per-row groupwise blocks are supported: [1, group_size].
        if (
            not isinstance(block_size, (list, tuple))
            or len(block_size) != 2
            or block_size[0] != 1
        ):
            return
        self.group_size = int(block_size[1])

        # Walk weight (args[0]), scales (args[2]) and zero_point (args[3]) back
        # to their placeholders. Symmetry (zero_point == 0) is deliberately not
        # tested on the fake meta tensor here, since that would trigger a
        # data-dependent guard error; the replacement function checks the real
        # tensor once the ExportedProgram is available.
        weight_placeholder, weight_chain = utils.trace_args_until_placeholder(
            producer.args[0]
        )
        if weight_placeholder is None:
            return
        self.weight_node = weight_placeholder
        self.all_nodes.extend(weight_chain)

        scales_placeholder, scales_chain = utils.trace_args_until_placeholder(
            producer.args[2]
        )
        if scales_placeholder is None:
            return
        self.scales_node = scales_placeholder
        self.all_nodes.extend(scales_chain)

        # A missing zero-point placeholder does not invalidate the match.
        zero_point_placeholder, zero_point_chain = (
            utils.trace_args_until_placeholder(producer.args[3])
        )
        self.zero_point_node = zero_point_placeholder
        self.all_nodes.extend(zero_point_chain)

        self.match_found = True
| 255 | + |
| 256 | + |
@register_pattern_detector("torchao_quantized_embedding")
def find_torchao_quantized_embedding_patterns(
    node: torch.fx.Node,
) -> Optional[TorchAOQuantizedEmbeddingMatch]:
    """Detect the torchao quantized embedding pattern anchored at ``node``.

    Returns a TorchAOQuantizedEmbeddingMatch when ``node`` is an
    aten.embedding call and the match object reports ``match_found``;
    otherwise returns None.
    """
    # Only aten.embedding nodes can anchor this pattern.
    if node.target != embedding_target:
        return None

    candidate = TorchAOQuantizedEmbeddingMatch(node)
    return candidate if candidate.match_found else None
| 268 | + |
| 269 | + |
@register_pattern_replacement("torchao_quantized_embedding")
def replace_torchao_quantized_embedding_patterns(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: TorchAOQuantizedEmbeddingMatch,
):
    """Rewrite a matched torchao dequantize_affine -> embedding pair as one
    et_vk.embedding_q4gsw.default node.

    The unpacked weight behind ``match.weight_node`` is repacked in place into
    two 4-bit values per byte (at most once per backing parameter), and a new
    embedding_q4gsw node taking (weight, scales, group_size, indices, False)
    replaces every use of the original embedding node.

    Args:
        ep: exported program that owns the parameter/buffer tensors.
        graph_module: graph module containing the matched nodes.
        match: successful match produced by the pattern detector.

    Raises:
        AssertionError: if the weight tensor cannot be resolved, the zero
            point is not all zeros, or the weight cannot be packed (not 2-D,
            odd embedding width, or values outside [-8, 7]).
    """
    weight_tensor = get_param_tensor(ep, match.weight_node)
    assert weight_tensor is not None

    # The shader applies a fixed signed-4-bit offset (subtract 8), which is
    # only correct for symmetric quantization (zero_point == 0). Verify this on
    # the real tensor for EVERY match, independent of whether the weight still
    # needs repacking: a second call site sharing an already-packed weight
    # carries its own zero point and must not be rewritten unchecked.
    if match.zero_point_node is not None:
        zero_point_tensor = get_param_tensor(ep, match.zero_point_node)
        if zero_point_tensor is not None:
            assert torch.all(
                zero_point_tensor == 0
            ), "embedding_q4gsw requires symmetric quantization (zero_point == 0)"

    # The repack overwrites the state dict entry in place, so it must run
    # exactly once per backing storage; repacking an already-packed weight
    # would corrupt it. The repack (align_width_and_update_state_dict ->
    # update_program_state_dict) locates the entry by the param/buffer FQN
    # behind the placeholder, and utils.register_param_mutation keys its
    # idempotency guard on that same FQN. This covers one placeholder shared by
    # several call sites as well as distinct placeholders resolving to the same
    # storage, while distinct weights still each pack once. The guard also
    # raises if the same weight is later mutated under a different tag (an
    # incompatible packing format).
    if utils.register_param_mutation(ep, match.weight_node, "embedding_q4gsw"):
        # Two values share one byte, so the weight must be a 2-D matrix with an
        # even number of columns; otherwise the even/odd column slices below
        # would not line up.
        assert weight_tensor.dim() == 2 and weight_tensor.shape[1] % 2 == 0, (
            "embedding_q4gsw requires a 2-D weight with an even embedding "
            f"dimension, got shape {tuple(weight_tensor.shape)}"
        )
        # Anything outside the signed 4-bit range would overflow its nibble and
        # silently corrupt the neighbouring value.
        assert torch.all(
            (weight_tensor >= -8) & (weight_tensor <= 7)
        ), "embedding_q4gsw requires weight values in [-8, 7]"

        # Repack the unpacked weight [vocab, embed_dim] (values in [-8, 7])
        # into the flat 4-bit packed format [vocab, embed_dim / 2] that the
        # non-linear embedding_q4gsw path expects. Packing convention (must
        # match the runtime shader and embedding_q4gsw_impl):
        #     packed_byte = (even_val + 8) << 4 | (odd_val + 8)
        # i.e. the even-index value goes in the high nibble, odd-index in the
        # low. The +8 bias is applied before the unsigned cast so every value
        # is already in [0, 15] when it becomes uint8.
        nibbles = (weight_tensor + 8).to(torch.uint8)
        packed_weight = ((nibbles[:, ::2] << 4) | nibbles[:, 1::2]).to(
            torch.uint8
        )

        # Update the weight placeholder's state dict entry and fake-tensor meta
        # to the repacked tensor; align_to=1 with force_update=True just forces
        # the update.
        utils.align_width_and_update_state_dict(
            ep, match.weight_node, packed_weight, align_to=1, force_update=True
        )

    # Scales are symmetric per-group with layout [vocab, num_groups], matching
    # the scale layout embedding_q4gsw expects (no transpose).
    group_size = match.group_size

    with graph_module.graph.inserting_before(match.anchor_node):
        embedding_q4gsw_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.embedding_q4gsw.default,
            args=(
                match.weight_node,
                match.scales_node,
                group_size,
                match.indices_node,
                False,
            ),
        )

    # The fused node stands in for the original embedding output, so it reuses
    # that node's fake-tensor metadata before taking over all of its uses.
    embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(embedding_q4gsw_node)
0 commit comments