|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | 9 | import logging |
| 10 | +import operator |
10 | 11 | from typing import Any, Optional, Union |
11 | 12 |
|
12 | 13 | import torch |
@@ -303,7 +304,26 @@ def __init__( |
303 | 304 | if isinstance(input_args, list): |
304 | 305 | self.quant_args = extract_input_quant_params_from_graph(module, input_args) |
305 | 306 | elif isinstance(input_args, dict): |
306 | | - self.quant_args = input_args |
| 307 | + # dict[int, QuantArgs] — use directly |
| 308 | + # dict[int, list[str]] — extract quant params from graph, keyed by input index |
| 309 | + first_value = next(iter(input_args.values()), None) |
| 310 | + if ( |
| 311 | + isinstance(first_value, (list, tuple)) |
| 312 | + and first_value |
| 313 | + and isinstance(first_value[0], str) |
| 314 | + ): |
| 315 | + # Values are lists of node names: extract quant params and map |
| 316 | + # to the caller-specified input indices. |
| 317 | + for input_idx, node_names in input_args.items(): |
| 318 | + assert isinstance(node_names, list) |
| 319 | + extracted = extract_input_quant_params_from_graph( |
| 320 | + module, node_names |
| 321 | + ) |
| 322 | + # Use the first extracted quant params for this input index. |
| 323 | + if extracted: |
| 324 | + self.quant_args[int(input_idx)] = next(iter(extracted.values())) |
| 325 | + else: |
| 326 | + self.quant_args = {int(k): v for k, v in input_args.items()} |
307 | 327 |
|
308 | 328 | def forward(self, *args: torch.Tensor) -> Any: |
309 | 329 | """Run inference, dequantizing configured inputs.""" |
@@ -349,6 +369,27 @@ def forward(self, *args: torch.Tensor) -> Any: |
349 | 369 |
|
350 | 370 | return self.module(*dequantized_args) |
351 | 371 |
|
| 372 | + @staticmethod |
| 373 | + def sink_dequants(program: torch.export.ExportedProgram) -> None: |
| 374 | + """Sink dequant nodes through transparent ops in an exported program. |
| 375 | +
|
| 376 | + If the graph branches through transparent ops (view, split, getitem, etc.) |
| 377 | + into paths with different quantization parameters, sink the dequants to be |
| 378 | + adjacent to each downstream quant node, enabling per-branch fusion. |
| 379 | +
|
| 380 | + Must be called after export() on a QuantizedInputWrapper-wrapped model. |
| 381 | + """ |
| 382 | + from torch.export.graph_signature import InputKind |
| 383 | + |
| 384 | + user_input_names = { |
| 385 | + spec.arg.name |
| 386 | + for spec in program.graph_signature.input_specs |
| 387 | + if spec.kind == InputKind.USER_INPUT |
| 388 | + } |
| 389 | + sink_input_dequant_through_transparent_ops( |
| 390 | + program.graph_module, user_input_names |
| 391 | + ) |
| 392 | + |
352 | 393 |
|
353 | 394 | class QuantizedOutputWrapper(torch.nn.Module): |
354 | 395 | """ |
@@ -379,3 +420,183 @@ def forward(self, *args: torch.Tensor) -> Any: |
379 | 420 | return torch.ops.quantized_decomposed.quantize_per_tensor.default( |
380 | 421 | result, scale, zp, qmin, qmax, dtype |
381 | 422 | ) |
| 423 | + |
| 424 | + |
| 425 | +def _get_transparent_ops() -> set[Any]: |
| 426 | + """Ops that only reshape/index data without changing values. |
| 427 | + Safe to pass uint8 data through these.""" |
| 428 | + return { |
| 429 | + torch.ops.aten.view_copy.default, |
| 430 | + torch.ops.aten.view.default, |
| 431 | + torch.ops.aten.reshape.default, |
| 432 | + torch.ops.aten.split.Tensor, |
| 433 | + torch.ops.aten.slice_copy.Tensor, |
| 434 | + torch.ops.aten.permute_copy.default, |
| 435 | + torch.ops.aten.permute.default, |
| 436 | + torch.ops.aten.expand_copy.default, |
| 437 | + torch.ops.aten.unsqueeze_copy.default, |
| 438 | + torch.ops.aten.squeeze_copy.dim, |
| 439 | + torch.ops.aten.transpose_copy.int, |
| 440 | + torch.ops.aten.clone.default, |
| 441 | + operator.getitem, |
| 442 | + } |
| 443 | + |
| 444 | + |
| 445 | +def _get_quantize_ops() -> set[Any]: |
| 446 | + ops = {torch.ops.quantized_decomposed.quantize_per_tensor.default} |
| 447 | + try: |
| 448 | + ops.add(torch.ops.cadence.quantize_per_tensor.default) |
| 449 | + except AttributeError: |
| 450 | + pass |
| 451 | + return ops |
| 452 | + |
| 453 | + |
| 454 | +def _get_dequantize_ops() -> set[Any]: |
| 455 | + ops = {torch.ops.quantized_decomposed.dequantize_per_tensor.default} |
| 456 | + try: |
| 457 | + ops.add(torch.ops.cadence.dequantize_per_tensor.default) |
| 458 | + except AttributeError: |
| 459 | + pass |
| 460 | + return ops |
| 461 | + |
| 462 | + |
| 463 | +def _walk_to_downstream_quants( |
| 464 | + node: torch.fx.Node, |
| 465 | + quantize_ops: set[Any], |
| 466 | + transparent_ops: set[Any], |
| 467 | + downstream_quants: list[torch.fx.Node], |
| 468 | +) -> bool: |
| 469 | + """Walk forward through transparent ops collecting downstream quant nodes. |
| 470 | +
|
| 471 | + Returns True if all paths end at a quant node. |
| 472 | + """ |
| 473 | + all_valid = True |
| 474 | + for user in node.users: |
| 475 | + if user.op == "call_function" and user.target in quantize_ops: |
| 476 | + downstream_quants.append(user) |
| 477 | + elif user.op == "call_function" and user.target in transparent_ops: |
| 478 | + if not _walk_to_downstream_quants( |
| 479 | + user, quantize_ops, transparent_ops, downstream_quants |
| 480 | + ): |
| 481 | + all_valid = False |
| 482 | + else: |
| 483 | + all_valid = False |
| 484 | + return all_valid |
| 485 | + |
| 486 | + |
| 487 | +def _get_dequant_node_for_placeholder( |
| 488 | + placeholder: torch.fx.Node, |
| 489 | + input_placeholder_names: set[str] | None, |
| 490 | + dequantize_ops: set[Any], |
| 491 | +) -> torch.fx.Node | None: |
| 492 | + """Return the single dequant user of a uint8 placeholder, or None.""" |
| 493 | + if placeholder.op != "placeholder": |
| 494 | + return None |
| 495 | + if ( |
| 496 | + input_placeholder_names is not None |
| 497 | + and placeholder.name not in input_placeholder_names |
| 498 | + ): |
| 499 | + return None |
| 500 | + val = placeholder.meta.get("val") |
| 501 | + if val is None or not isinstance(val, torch.Tensor): |
| 502 | + return None |
| 503 | + if val.dtype != torch.uint8: |
| 504 | + return None |
| 505 | + if len(placeholder.users) != 1: |
| 506 | + return None |
| 507 | + dequant_node = next(iter(placeholder.users)) |
| 508 | + if dequant_node.op == "call_function" and dequant_node.target in dequantize_ops: |
| 509 | + return dequant_node |
| 510 | + return None |
| 511 | + |
| 512 | + |
| 513 | +def _sink_dequant_to_quant_nodes( |
| 514 | + graph: torch.fx.Graph, |
| 515 | + dequant_node: torch.fx.Node, |
| 516 | + placeholder: torch.fx.Node, |
| 517 | + downstream_quants: list[torch.fx.Node], |
| 518 | +) -> None: |
| 519 | + """Insert per-branch dequants before each downstream quant and rewire.""" |
| 520 | + dequant_op = dequant_node.target |
| 521 | + assert callable(dequant_op) |
| 522 | + |
| 523 | + for quant_node in downstream_quants: |
| 524 | + quant_input = quant_node.args[0] |
| 525 | + assert isinstance(quant_input, torch.fx.Node) |
| 526 | + quant_params = quant_node.args[1:] |
| 527 | + |
| 528 | + with graph.inserting_before(quant_node): |
| 529 | + new_dequant = graph.call_function( |
| 530 | + dequant_op, |
| 531 | + args=(quant_input, *quant_params), |
| 532 | + ) |
| 533 | + new_dequant.meta = {**dequant_node.meta} |
| 534 | + if "val" in quant_node.meta and isinstance( |
| 535 | + quant_node.meta["val"], torch.Tensor |
| 536 | + ): |
| 537 | + quant_val = quant_node.meta["val"] |
| 538 | + new_dequant.meta["val"] = torch.empty(quant_val.shape, dtype=torch.float32) |
| 539 | + |
| 540 | + quant_node.replace_input_with(quant_input, new_dequant) |
| 541 | + |
| 542 | + dequant_node.replace_all_uses_with(placeholder) |
| 543 | + graph.erase_node(dequant_node) |
| 544 | + |
| 545 | + |
def sink_input_dequant_through_transparent_ops(
    graph_module: GraphModule,
    input_placeholder_names: set[str] | None = None,
) -> bool:
    """
    Sinks dequantize nodes from quantized input placeholders through transparent ops
    to be adjacent to downstream quantize nodes, enabling dequant-quant fusion.
    This creates per-branch dequants with matching params.

    Args:
        graph_module: The graph module to transform.
        input_placeholder_names: Optional set of placeholder names to consider.
            If provided, only these placeholders are processed (use this to
            restrict to user inputs and avoid touching weight/buffer
            placeholders). If None, all uint8 placeholders are considered.

    Returns True if the graph was modified.
    """
    transparent_ops: set[Any] = _get_transparent_ops()
    quantize_ops: set[Any] = _get_quantize_ops()
    dequantize_ops: set[Any] = _get_dequantize_ops()

    graph = graph_module.graph
    modified = False

    # Snapshot the node list: sinking mutates the graph while we iterate.
    for node in list(graph.nodes):
        dequant = _get_dequant_node_for_placeholder(
            node, input_placeholder_names, dequantize_ops
        )
        if dequant is None:
            continue

        quants: list[torch.fx.Node] = []
        every_path_ok = _walk_to_downstream_quants(
            dequant, quantize_ops, transparent_ops, quants
        )
        # Only sink when the dequant feeds exclusively into quant-terminated
        # paths; otherwise leave the graph untouched for this input.
        if not quants or not every_path_ok:
            continue

        _sink_dequant_to_quant_nodes(graph, dequant, node, quants)
        modified = True
        logger.info(
            "Sunk dequant for input '%s' through transparent ops to %d "
            "downstream quant nodes",
            node.name,
            len(quants),
        )

    if modified:
        graph.lint()
        graph_module.recompile()

    return modified
0 commit comments