Skip to content

Commit 691cb57

Browse files
authored
Add support to sink input dequants through transparent ops (#18504)
Differential Revision: D97877439 Pull Request resolved: #18504
1 parent 3eba197 commit 691cb57

2 files changed

Lines changed: 227 additions & 1 deletion

File tree

backends/cadence/aot/compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ def quantize_pt2(
239239

240240
# Apply quant fusion to the exported program
241241
program = torch.export.export(converted_gm, inputs, strict=True)
242+
243+
# Sink dequant nodes through transparent ops so they fuse per-branch.
244+
if quant_input_args is not None:
245+
QuantizedInputWrapper.sink_dequants(program)
246+
242247
fused_program = apply_pre_edge_transform_passes(program, quantizer)
243248

244249
if dump_graphs:

backends/cadence/aot/compiler_funcs.py

Lines changed: 222 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import logging
10+
import operator
1011
from typing import Any, Optional, Union
1112

1213
import torch
@@ -303,7 +304,26 @@ def __init__(
303304
if isinstance(input_args, list):
304305
self.quant_args = extract_input_quant_params_from_graph(module, input_args)
305306
elif isinstance(input_args, dict):
306-
self.quant_args = input_args
307+
# dict[int, QuantArgs] — use directly
308+
# dict[int, list[str]] — extract quant params from graph, keyed by input index
309+
first_value = next(iter(input_args.values()), None)
310+
if (
311+
isinstance(first_value, (list, tuple))
312+
and first_value
313+
and isinstance(first_value[0], str)
314+
):
315+
# Values are lists of node names: extract quant params and map
316+
# to the caller-specified input indices.
317+
for input_idx, node_names in input_args.items():
318+
assert isinstance(node_names, list)
319+
extracted = extract_input_quant_params_from_graph(
320+
module, node_names
321+
)
322+
# Use the first extracted quant params for this input index.
323+
if extracted:
324+
self.quant_args[int(input_idx)] = next(iter(extracted.values()))
325+
else:
326+
self.quant_args = {int(k): v for k, v in input_args.items()}
307327

308328
def forward(self, *args: torch.Tensor) -> Any:
309329
"""Run inference, dequantizing configured inputs."""
@@ -349,6 +369,27 @@ def forward(self, *args: torch.Tensor) -> Any:
349369

350370
return self.module(*dequantized_args)
351371

372+
@staticmethod
373+
def sink_dequants(program: torch.export.ExportedProgram) -> None:
374+
"""Sink dequant nodes through transparent ops in an exported program.
375+
376+
If the graph branches through transparent ops (view, split, getitem, etc.)
377+
into paths with different quantization parameters, sink the dequants to be
378+
adjacent to each downstream quant node, enabling per-branch fusion.
379+
380+
Must be called after export() on a QuantizedInputWrapper-wrapped model.
381+
"""
382+
from torch.export.graph_signature import InputKind
383+
384+
user_input_names = {
385+
spec.arg.name
386+
for spec in program.graph_signature.input_specs
387+
if spec.kind == InputKind.USER_INPUT
388+
}
389+
sink_input_dequant_through_transparent_ops(
390+
program.graph_module, user_input_names
391+
)
392+
352393

353394
class QuantizedOutputWrapper(torch.nn.Module):
354395
"""
@@ -379,3 +420,183 @@ def forward(self, *args: torch.Tensor) -> Any:
379420
return torch.ops.quantized_decomposed.quantize_per_tensor.default(
380421
result, scale, zp, qmin, qmax, dtype
381422
)
423+
424+
425+
def _get_transparent_ops() -> set[Any]:
426+
"""Ops that only reshape/index data without changing values.
427+
Safe to pass uint8 data through these."""
428+
return {
429+
torch.ops.aten.view_copy.default,
430+
torch.ops.aten.view.default,
431+
torch.ops.aten.reshape.default,
432+
torch.ops.aten.split.Tensor,
433+
torch.ops.aten.slice_copy.Tensor,
434+
torch.ops.aten.permute_copy.default,
435+
torch.ops.aten.permute.default,
436+
torch.ops.aten.expand_copy.default,
437+
torch.ops.aten.unsqueeze_copy.default,
438+
torch.ops.aten.squeeze_copy.dim,
439+
torch.ops.aten.transpose_copy.int,
440+
torch.ops.aten.clone.default,
441+
operator.getitem,
442+
}
443+
444+
445+
def _get_quantize_ops() -> set[Any]:
446+
ops = {torch.ops.quantized_decomposed.quantize_per_tensor.default}
447+
try:
448+
ops.add(torch.ops.cadence.quantize_per_tensor.default)
449+
except AttributeError:
450+
pass
451+
return ops
452+
453+
454+
def _get_dequantize_ops() -> set[Any]:
455+
ops = {torch.ops.quantized_decomposed.dequantize_per_tensor.default}
456+
try:
457+
ops.add(torch.ops.cadence.dequantize_per_tensor.default)
458+
except AttributeError:
459+
pass
460+
return ops
461+
462+
463+
def _walk_to_downstream_quants(
464+
node: torch.fx.Node,
465+
quantize_ops: set[Any],
466+
transparent_ops: set[Any],
467+
downstream_quants: list[torch.fx.Node],
468+
) -> bool:
469+
"""Walk forward through transparent ops collecting downstream quant nodes.
470+
471+
Returns True if all paths end at a quant node.
472+
"""
473+
all_valid = True
474+
for user in node.users:
475+
if user.op == "call_function" and user.target in quantize_ops:
476+
downstream_quants.append(user)
477+
elif user.op == "call_function" and user.target in transparent_ops:
478+
if not _walk_to_downstream_quants(
479+
user, quantize_ops, transparent_ops, downstream_quants
480+
):
481+
all_valid = False
482+
else:
483+
all_valid = False
484+
return all_valid
485+
486+
487+
def _get_dequant_node_for_placeholder(
488+
placeholder: torch.fx.Node,
489+
input_placeholder_names: set[str] | None,
490+
dequantize_ops: set[Any],
491+
) -> torch.fx.Node | None:
492+
"""Return the single dequant user of a uint8 placeholder, or None."""
493+
if placeholder.op != "placeholder":
494+
return None
495+
if (
496+
input_placeholder_names is not None
497+
and placeholder.name not in input_placeholder_names
498+
):
499+
return None
500+
val = placeholder.meta.get("val")
501+
if val is None or not isinstance(val, torch.Tensor):
502+
return None
503+
if val.dtype != torch.uint8:
504+
return None
505+
if len(placeholder.users) != 1:
506+
return None
507+
dequant_node = next(iter(placeholder.users))
508+
if dequant_node.op == "call_function" and dequant_node.target in dequantize_ops:
509+
return dequant_node
510+
return None
511+
512+
513+
def _sink_dequant_to_quant_nodes(
514+
graph: torch.fx.Graph,
515+
dequant_node: torch.fx.Node,
516+
placeholder: torch.fx.Node,
517+
downstream_quants: list[torch.fx.Node],
518+
) -> None:
519+
"""Insert per-branch dequants before each downstream quant and rewire."""
520+
dequant_op = dequant_node.target
521+
assert callable(dequant_op)
522+
523+
for quant_node in downstream_quants:
524+
quant_input = quant_node.args[0]
525+
assert isinstance(quant_input, torch.fx.Node)
526+
quant_params = quant_node.args[1:]
527+
528+
with graph.inserting_before(quant_node):
529+
new_dequant = graph.call_function(
530+
dequant_op,
531+
args=(quant_input, *quant_params),
532+
)
533+
new_dequant.meta = {**dequant_node.meta}
534+
if "val" in quant_node.meta and isinstance(
535+
quant_node.meta["val"], torch.Tensor
536+
):
537+
quant_val = quant_node.meta["val"]
538+
new_dequant.meta["val"] = torch.empty(quant_val.shape, dtype=torch.float32)
539+
540+
quant_node.replace_input_with(quant_input, new_dequant)
541+
542+
dequant_node.replace_all_uses_with(placeholder)
543+
graph.erase_node(dequant_node)
544+
545+
546+
def sink_input_dequant_through_transparent_ops(
    graph_module: GraphModule,
    input_placeholder_names: set[str] | None = None,
) -> bool:
    """
    Sinks dequantize nodes from quantized input placeholders through transparent ops
    to be adjacent to downstream quantize nodes, enabling dequant-quant fusion.
    This creates per-branch dequants with matching params.

    Args:
        graph_module: The graph module to transform.
        input_placeholder_names: Optional set of placeholder names to consider.
            If provided, only these placeholders are processed (use this to
            restrict to user inputs and avoid touching weight/buffer placeholders).
            If None, all uint8 placeholders are considered.

    Returns True if the graph was modified.
    """
    graph = graph_module.graph
    transparent_ops = _get_transparent_ops()
    quantize_ops = _get_quantize_ops()
    dequantize_ops = _get_dequantize_ops()

    modified = False
    # Snapshot the node list: sinking erases/inserts nodes while we iterate.
    for node in list(graph.nodes):
        dequant = _get_dequant_node_for_placeholder(
            node, input_placeholder_names, dequantize_ops
        )
        if dequant is None:
            continue

        quants: list[torch.fx.Node] = []
        every_path_quantized = _walk_to_downstream_quants(
            dequant, quantize_ops, transparent_ops, quants
        )
        # Only sink when there is at least one quant and no escaping path;
        # otherwise the dequant must stay where it is.
        if not every_path_quantized or not quants:
            continue

        _sink_dequant_to_quant_nodes(graph, dequant, node, quants)
        modified = True
        logger.info(
            "Sunk dequant for input '%s' through transparent ops to %d "
            "downstream quant nodes",
            node.name,
            len(quants),
        )

    if modified:
        graph.lint()
        graph_module.recompile()

    return modified

0 commit comments

Comments
 (0)