Skip to content

Commit d76bbe3

Browse files
authored
Advance quant above cat (pytorch#19926)
Differential Revision: D107179344 Pull Request resolved: pytorch#19926
1 parent ea8037c commit d76bbe3

2 files changed

Lines changed: 139 additions & 6 deletions

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,22 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
248248
@register_cadence_pass(CadencePassAttribute(opt_level=1))
249249
class AdvanceQuantizeOpAboveDefChainPass(ExportPass):
250250
"""
251-
If the input to quantize op is linear chain of view, transpose, permute, or
252-
slice ops that are trivially quantized, we can convert the pattern
253-
view/transpose/permute/slice(fp32) -> quantize(int8/uint8) to
254-
quantize(int8/uint8) -> view/transpose/permute/slice(int8/uint8).
255-
The benefit of such reordering is that the view/transpose/permute/slice
256-
will move far less data.
251+
Advances a quantize op above data-movement ops to reduce data volume.
252+
253+
Handles two cases:
254+
255+
1. Linear chain: if the input to a quantize op is a chain of trivially
256+
quantizable ops (view, transpose, permute, slice), rewrite
257+
data_movement(fp32) -> quantize to quantize -> data_movement(quantized)
258+
so the data movement operates on smaller quantized tensors.
259+
260+
2. Cat: if the input to a quantize op is a cat with a single user (the
261+
quantize), advance the quantize above the cat by quantizing each cat
262+
input individually. A later pass can clean up any redundant
263+
dequant-quant pairs on the inputs.
264+
265+
For the cat case, SplitDequantizedCatPass should run first to ensure
266+
each cat has at most one quantize consumer.
257267
"""
258268

259269
def __init__(self):
@@ -302,6 +312,47 @@ def advancing_feasible(self, quant_node: torch.fx.Node):
302312
# All the conditions satisfied, we advance.
303313
return True
304314

315+
def _advance_above_cat(
316+
self, quant_node: torch.fx.Node, cat_node: torch.fx.Node
317+
) -> None:
318+
"""Advance a quantize op above a cat by quantizing each cat input."""
319+
graph = quant_node.graph
320+
quant_params = quant_node.args[1:]
321+
322+
cat_inputs = cat_node.args[0]
323+
assert isinstance(cat_inputs, (list, tuple))
324+
325+
new_inputs: list[torch.fx.Node] = []
326+
for inp in cat_inputs:
327+
# cat concatenates tensors, so every input must be a node.
328+
assert isinstance(inp, torch.fx.Node)
329+
330+
with graph.inserting_before(cat_node):
331+
new_quant = graph.call_function(
332+
# pyre-ignore[6]
333+
quant_node.target,
334+
args=(inp, *quant_params),
335+
)
336+
# This copies the fp32 input's meta, so meta["val"] keeps the
337+
# fp32 dtype rather than the quantized output dtype. That's fine:
338+
# nothing in this pass reads dtype from meta (only shape, which
339+
# is correct), and call() re-runs super().call() to re-propagate
340+
# fake tensors, making meta dtype-consistent before we return.
341+
new_quant.meta = inp.meta.copy()
342+
new_inputs.append(new_quant)
343+
344+
dim = get_arg(cat_node, "dim", int)
345+
with graph.inserting_before(quant_node):
346+
new_cat = graph.call_function(
347+
# pyre-ignore[6]
348+
cat_node.target,
349+
args=(new_inputs, dim),
350+
)
351+
new_cat.meta = quant_node.meta.copy()
352+
353+
quant_node.replace_all_uses_with(new_cat)
354+
graph.erase_node(quant_node)
355+
305356
def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
306357
graph = graph_module.graph
307358
modified = False
@@ -314,6 +365,17 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
314365
):
315366
continue
316367

368+
inp = node.args[0]
369+
if (
370+
isinstance(inp, torch.fx.Node)
371+
and get_overload_packet(inp.target)
372+
in (exir_ops.edge.aten.cat, torch.ops.aten.cat)
373+
and len(inp.users) == 1
374+
):
375+
self._advance_above_cat(node, inp)
376+
modified = True
377+
continue
378+
317379
if not self.advancing_feasible(node):
318380
continue
319381

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,3 +1268,74 @@ def test_two_quant_outputs_different_params_separate_cats(self) -> None:
12681268
)
12691269
quant_cat_inputs = {node.args[0] for node in quant_nodes}
12701270
self.assertEqual(len(quant_cat_inputs), 2)
1271+
1272+
1273+
class TestAdvanceQuantAboveCat(unittest.TestCase):
1274+
def test_float_inputs_get_quantized(self) -> None:
1275+
"""Float (non-dq) inputs to cat should get a quant inserted."""
1276+
builder = GraphBuilder()
1277+
a = builder.placeholder("a", torch.randn(2, 4))
1278+
b = builder.placeholder("b", torch.randn(2, 4))
1279+
cat = builder.call_operator(exir_ops.edge.aten.cat.default, args=([a, b], 0))
1280+
q = builder.call_operator(
1281+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
1282+
args=(cat, 0.01, 0, -128, 127, torch.int8),
1283+
)
1284+
builder.output([q])
1285+
gm = builder.get_graph_module()
1286+
1287+
result = AdvanceQuantizeOpAboveDefChainPass().call(gm)
1288+
1289+
self.assertTrue(result.modified)
1290+
converted = result.graph_module
1291+
1292+
# Two new quants (one per input) should exist; the original post-cat quant is gone.
1293+
self.assertEqual(
1294+
count_node(
1295+
converted,
1296+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
1297+
),
1298+
2,
1299+
)
1300+
1301+
# Cat should take quantized inputs.
1302+
cat_nodes = converted.graph.find_nodes(
1303+
op="call_function", target=exir_ops.edge.aten.cat.default
1304+
)
1305+
self.assertEqual(len(cat_nodes), 1)
1306+
for inp in cat_nodes[0].args[0]:
1307+
self.assertEqual(
1308+
inp.target,
1309+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
1310+
)
1311+
1312+
def test_cat_with_multiple_users_not_advanced(self) -> None:
1313+
"""Cat with multiple users should not be advanced (split pass handles this first)."""
1314+
builder = GraphBuilder()
1315+
x_int8 = builder.placeholder(
1316+
"x_int8", torch.randint(-128, 127, (2, 4), dtype=torch.int8)
1317+
)
1318+
dq = builder.call_operator(
1319+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
1320+
args=(x_int8, 0.02, -5, -128, 127, torch.int8),
1321+
)
1322+
b = builder.placeholder("b", torch.randn(2, 4))
1323+
cat = builder.call_operator(exir_ops.edge.aten.cat.default, args=([dq, b], 0))
1324+
sliced = builder.call_operator(
1325+
exir_ops.edge.aten.slice_copy.Tensor, args=(cat, 0, 0, 2)
1326+
)
1327+
q = builder.call_operator(
1328+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
1329+
args=(cat, 0.02, -5, -128, 127, torch.int8),
1330+
)
1331+
q_dq = builder.call_operator(
1332+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
1333+
args=(q, 0.02, -5, -128, 127, torch.int8),
1334+
)
1335+
builder.output([sliced, q_dq])
1336+
gm = builder.get_graph_module()
1337+
1338+
result = AdvanceQuantizeOpAboveDefChainPass().call(gm)
1339+
1340+
self.assertFalse(result.modified)
1341+
self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

0 commit comments

Comments
 (0)