Skip to content

Commit c4e0f1f

Browse files
author
ssjia
committed
[ET-VK][patterns] Fuse torchao 4-bit quantized embedding to embedding_q4gsw
Pull Request resolved: #20381. TISO and other torchao-quantized models emit a `torchao.dequantize_affine -> aten.embedding` subgraph for their weight-only int4 quantized embedding. The existing `QuantizedEmbeddingMatch` only matches the `quantized_decomposed.embedding_4bit.dtype` fused op, so the torchao embedding was never fused: its `dequantize_affine` const-folded to an fp32 weight, the resulting `aten.embedding` exceeded the buffer-element limit and fell back to CPU, and the fp32 constant bloated the serialized model. This adds a separate `TorchAOQuantizedEmbeddingMatch` matcher that recognizes the torchao int4 `dequantize_affine -> aten.embedding` shape (qmin=-8/qmax=7, per-row group block_size `[1, G]`) and rewrites it to the existing `et_vk.embedding_q4gsw.default` op, repacking the unpacked int8 weight into the packed 4-bit layout. It asserts symmetric quantization (zero_point == 0, which the shader assumes) and guards against repacking a shared/tied weight more than once via `utils.register_param_mutation`, keyed on the FQN of the param/buffer that backs the weight placeholder. It is kept as a separate class from `QuantizedEmbeddingMatch` because the two dialects produce different graph shapes (one fused op vs a split dequant+gather), so a single class would only co-locate two disjoint parse paths. On the en_US TISO backbone the embedding now delegates to Vulkan instead of falling back to CPU, and the serialized `.pte` drops from 418 MiB to 348 MiB. This change was authored with Claude. ghstack-source-id: 396618161 @exported-using-ghexport Differential Revision: [D108457797](https://our.internmc.facebook.com/intern/diff/D108457797/)
1 parent ddb5aee commit c4e0f1f

3 files changed

Lines changed: 434 additions & 3 deletions

File tree

backends/vulkan/patterns/quantized_embedding.py

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
from executorch.exir.dialects._ops import ops as exir_ops
1919

2020

21+
# Edge-dialect op matched by QuantizedEmbeddingMatch (the fused 4-bit embedding).
embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
22+
# Anchor op checked by find_torchao_quantized_embedding_patterns.
embedding_target = exir_ops.edge.aten.embedding.default
23+
# Producer op that TorchAOQuantizedEmbeddingMatch requires for the embedding weight.
torchao_dequantize_affine_target = exir_ops.edge.torchao.dequantize_affine.default
24+
25+
2126
class QuantizedEmbeddingMatch(PatternMatch):
2227
def __init__(self, node: torch.fx.Node) -> None:
2328
self.anchor_node = node
@@ -68,9 +73,6 @@ def __init__(self, node: torch.fx.Node) -> None:
6873
self.match_found = True
6974

7075

71-
embedding_4bit_target = exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
72-
73-
7476
def _detect_tied_linear_weight(
7577
ep: ExportedProgram,
7678
weight_node: torch.fx.Node,
@@ -175,3 +177,163 @@ def replace_quantized_embedding_patterns(
175177

176178
embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"]
177179
match.anchor_node.replace_all_uses_with(embedding_q4gsw_node)
180+
181+
182+
class TorchAOQuantizedEmbeddingMatch(PatternMatch):
    """Match a torchao 4-bit weight-only quantized embedding.

    The recognized graph shape is a split torchao.dequantize_affine ->
    aten.embedding, whose weight is unpacked int8 [vocab, embed_dim] with values
    in [-8, 7]. This requires symmetric 4-bit signed quantization (quant_min=-8,
    quant_max=7, zero_point=0) and per-row groupwise blocks (block_size=[1, G]),
    which the runtime shader assumes via a fixed subtract-8 offset.

    Only the structural checks are done here. The zero_point == 0 requirement
    is verified later, on the real tensor, by the replacement function.

    Attributes:
        anchor_node: The aten.embedding node the match is anchored on.
        match_found: True only if every check in ``__init__`` passed. The
            attributes below are only guaranteed to exist when this is True.
        all_nodes: The anchor node, the dequantize node and the nodes traversed
            while tracing the dequantize inputs back to placeholders.
        indices_node: Second positional arg of the embedding node.
        group_size: ``block_size[1]`` of the dequantize node, as an int.
        weight_node: Placeholder the dequantize input traces back to.
        scales_node: Placeholder the dequantize scale traces back to.
        zero_point_node: Result of tracing the dequantize zero_point. It is not
            required for a match; the replacement function skips the
            zero-point check when it is None.
    """

    def __init__(self, node: torch.fx.Node) -> None:
        self.anchor_node = node
        self.match_found = False
        self.all_nodes = [node]

        # aten.embedding.default args: (weight, indices, *)
        dequant_node = node.args[0]
        self.indices_node = node.args[1]

        if not isinstance(dequant_node, torch.fx.Node):
            return
        if dequant_node.target != torchao_dequantize_affine_target:
            return

        self.all_nodes.append(dequant_node)

        # torchao.dequantize_affine args:
        #   (input, block_size, scale, zero_point, input_dtype, quant_min,
        #    quant_max, ...)
        # Everything up to quant_max is read positionally below. Decline the
        # match when any of them is missing rather than raising IndexError out
        # of the detector.
        dequant_args = dequant_node.args
        if len(dequant_args) < 7:
            return
        block_size = dequant_args[1]
        quant_min = dequant_args[5]
        quant_max = dequant_args[6]

        # The shader hardcodes the 4-bit signed offset (subtract 8), which
        # corresponds to quant_min=-8, quant_max=7, zero_point=0.
        if quant_min != -8 or quant_max != 7:
            return

        # block_size must be per-row groupwise: [1, group_size]
        if not isinstance(block_size, (list, tuple)) or len(block_size) != 2:
            return
        if block_size[0] != 1:
            return
        self.group_size = int(block_size[1])

        # Trace weight (args[0]), scales (args[2]) and zero_point (args[3]) to
        # their placeholders. The symmetric (zero_point == 0) requirement is
        # verified on the real tensor in the replacement function, where the
        # ExportedProgram is available; checking the fake meta tensor here would
        # trigger a data-dependent guard error.
        weight_node, arg_chain = utils.trace_args_until_placeholder(
            dequant_args[0]
        )
        if weight_node is None:
            return
        self.weight_node = weight_node
        self.all_nodes.extend(arg_chain)

        scales_node, arg_chain = utils.trace_args_until_placeholder(
            dequant_args[2]
        )
        if scales_node is None:
            return
        self.scales_node = scales_node
        self.all_nodes.extend(arg_chain)

        # A zero_point that cannot be traced does not reject the match.
        self.zero_point_node, arg_chain = utils.trace_args_until_placeholder(
            dequant_args[3]
        )
        self.all_nodes.extend(arg_chain)

        self.match_found = True
255+
256+
257+
@register_pattern_detector("torchao_quantized_embedding")
def find_torchao_quantized_embedding_patterns(
    node: torch.fx.Node,
) -> Optional[TorchAOQuantizedEmbeddingMatch]:
    """Detect a torchao 4-bit quantized embedding anchored at ``node``.

    Args:
        node: Candidate anchor node.

    Returns:
        A ``TorchAOQuantizedEmbeddingMatch`` when ``node`` is an
        aten.embedding call and the matcher reports ``match_found``;
        otherwise ``None``.
    """
    # Only embedding nodes are worth building a matcher for.
    if node.target == embedding_target:
        candidate = TorchAOQuantizedEmbeddingMatch(node)
        if candidate.match_found:
            return candidate
    return None
268+
269+
270+
def _pack_q4gsw_embedding_weight(weight_tensor: torch.Tensor) -> torch.Tensor:
    """Pack an unpacked 4-bit weight two values per byte.

    Repacks the unpacked weight [vocab, embed_dim] (values in [-8, 7]) into the
    flat 4-bit packed format [vocab, embed_dim / 2] that the non-linear
    embedding_q4gsw path expects. Packing convention (must match the runtime
    shader and embedding_q4gsw_impl):
        packed_byte = (even_val + 8) << 4 | (odd_val + 8)
    i.e. the even-index value goes in the high nibble, odd-index in the low.

    Raises:
        ValueError: If the weight is not 2D or its last dimension is odd.
    """
    # With an odd embed_dim the even/odd column slices have different widths.
    # When the odd slice has width 1 the `|` below would broadcast instead of
    # failing, silently producing a corrupt packed weight, so reject it here.
    if weight_tensor.dim() != 2 or weight_tensor.shape[1] % 2 != 0:
        raise ValueError(
            "embedding_q4gsw weight must be 2D with an even embed_dim to be "
            f"packed two values per byte, got shape {tuple(weight_tensor.shape)}"
        )

    unpacked_u8 = weight_tensor.to(torch.uint8) + 8
    return (unpacked_u8[:, ::2] << 4 | unpacked_u8[:, 1::2]).to(torch.uint8)


@register_pattern_replacement("torchao_quantized_embedding")
def replace_torchao_quantized_embedding_patterns(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: TorchAOQuantizedEmbeddingMatch,
):
    """Rewrite a matched torchao embedding as et_vk.embedding_q4gsw.default.

    Args:
        ep: Exported program that owns the parameter tensors.
        graph_module: Graph module in which the new node is created.
        match: A successful ``TorchAOQuantizedEmbeddingMatch``.

    Raises:
        AssertionError: If the weight tensor cannot be resolved, or the
            zero_point tensor (when resolvable) is not all zeros.
        ValueError: If the weight cannot be packed two values per byte.
    """
    weight_tensor = get_param_tensor(ep, match.weight_node)
    assert weight_tensor is not None

    # The shader applies a fixed signed-4-bit offset (subtract 8), which
    # assumes symmetric quantization (zero_point == 0). Verify on the real
    # tensor. This runs before anything is registered or mutated, and for every
    # call site (including ones that share an already-packed weight). It raises
    # explicitly instead of using `assert` so the check survives `python -O`.
    if match.zero_point_node is not None:
        zero_point_tensor = get_param_tensor(ep, match.zero_point_node)
        if zero_point_tensor is not None and not torch.all(zero_point_tensor == 0):
            raise AssertionError(
                "embedding_q4gsw requires symmetric quantization (zero_point == 0)"
            )

    # The weight repack mutates the state dict entry in place, so it must run
    # exactly once per backing storage; a second repack of the already-packed
    # weight would corrupt it. The repack
    # (align_width_and_update_state_dict -> update_program_state_dict) locates
    # the entry to overwrite by the param/buffer FQN that backs the placeholder,
    # so the idempotency guard keys on that same FQN (via
    # utils.register_param_mutation). This dedups not only one placeholder
    # shared by multiple call sites, but also distinct placeholder nodes that
    # resolve to the same state dict storage (whose per-node meta would otherwise
    # diverge). Distinct weights (distinct FQNs) still each pack once. The guard
    # also raises if the same weight is later re-mutated with a different tag
    # (i.e. an incompatible packing format), surfacing corruption loudly.
    if utils.register_param_mutation(ep, match.weight_node, "embedding_q4gsw"):
        packed_weight = _pack_q4gsw_embedding_weight(weight_tensor)

        # Update the weight placeholder's state dict entry and fake-tensor meta
        # to the repacked tensor. align_to=1 with force_update just forces the
        # update; the packed width (embed_dim / 2) is expected to already be a
        # multiple of 4 (not verified here).
        utils.align_width_and_update_state_dict(
            ep, match.weight_node, packed_weight, align_to=1, force_update=True
        )

    # Scales are symmetric per-group with layout [vocab, num_groups], matching
    # the scale layout embedding_q4gsw expects (no transpose).
    with graph_module.graph.inserting_before(match.anchor_node):
        embedding_q4gsw_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.embedding_q4gsw.default,
            args=(
                match.weight_node,
                match.scales_node,
                match.group_size,
                match.indices_node,
                False,
            ),
        )

    # The fused node produces the same output as the embedding it replaces, so
    # reuse its fake-tensor metadata and redirect all consumers.
    embedding_q4gsw_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(embedding_q4gsw_node)

0 commit comments

Comments
 (0)