|
53 | 53 | PQLayerNorm, |
54 | 54 | PQMultiheadAttention, |
55 | 55 | ) |
| 56 | +from pquant.core.torch.quantizer import Quantizer # noqa: E402 |
56 | 57 |
|
57 | 58 | # --------------------------------------------------------------------------- |
58 | 59 | # QONNX Quant node |
@@ -1200,6 +1201,15 @@ def _emit_module( |
1200 | 1201 | module, prefix, current, current, current, nodes, initializers, quant_fn, use_qonnx, store_integer_weights |
1201 | 1202 | ) |
1202 | 1203 | return out |
| 1204 | + if isinstance(module, Quantizer): |
| 1205 | + # Standalone quantizer (e.g. an auto-inserted missing quantizer or a |
| 1206 | + # constant-matrix quantizer): emit a single QDQ node from its k/i/f. |
| 1207 | + k, i, f = module.get_quantization_bits() |
| 1208 | + new_nodes, out = quant_fn( |
| 1209 | + prefix, current, module.round_mode, k, i, f, initializers, overflow_mode=getattr(module, "overflow", "SAT") |
| 1210 | + ) |
| 1211 | + nodes.extend(new_nodes) |
| 1212 | + return out |
1203 | 1213 | raise TypeError(f"Unsupported module type for ONNX export: {type(module).__name__}") |
1204 | 1214 |
|
1205 | 1215 |
|
@@ -1481,6 +1491,8 @@ class _PQTracer(_fx.Tracer): |
1481 | 1491 | PQAvgPool1d, |
1482 | 1492 | PQAvgPool2d, |
1483 | 1493 | PQMultiheadAttention, |
| 1494 | + PQActivation, |
| 1495 | + Quantizer, |
1484 | 1496 | ) |
1485 | 1497 |
|
1486 | 1498 | def is_leaf_module(self, m: nn.Module, qualname: str) -> bool: |
@@ -1516,13 +1528,16 @@ def convert_to_onnx_fx( |
1516 | 1528 | # need to expand torch's two-arg .transpose(d0, d1) into a full ONNX perm. |
1517 | 1529 | from torch.fx.passes.shape_prop import ShapeProp |
1518 | 1530 |
|
| 1531 | + # Build the probe tensor on the model's own device so ShapeProp doesn't hit a |
| 1532 | + # device mismatch when a default device (e.g. CUDA) is set via torch.set_default_device. |
| 1533 | + device = next((p.device for p in model.parameters()), None) |
1519 | 1534 | with torch.no_grad(): |
1520 | | - ShapeProp(gm).propagate(torch.zeros(1, *input_shape)) |
| 1535 | + ShapeProp(gm).propagate(torch.zeros(1, *input_shape, device=device)) |
1521 | 1536 |
|
1522 | 1537 | onnx_nodes: list[onnx.NodeProto] = [] |
1523 | 1538 | initializers: list[onnx.TensorProto] = [] |
1524 | 1539 | node_to_name: dict[_fx.Node, str] = {} |
1525 | | - output_name: str = "" |
| 1540 | + output_names: list[str] = [] |
1526 | 1541 |
|
1527 | 1542 | def _res(arg) -> str: |
1528 | 1543 | if isinstance(arg, _fx.Node): |
@@ -1680,6 +1695,11 @@ def _resolve_perm_dims(args, rank: int) -> list[int]: |
1680 | 1695 | onnx_nodes.append(oh.make_node("Relu", inputs=[_res(node.args[0])], outputs=[out])) |
1681 | 1696 | node_to_name[node] = out |
1682 | 1697 |
|
| 1698 | + elif fn in (_F.sigmoid, torch.sigmoid): |
| 1699 | + out = f"{node.name}_sigmoid" |
| 1700 | + onnx_nodes.append(oh.make_node("Sigmoid", inputs=[_res(node.args[0])], outputs=[out])) |
| 1701 | + node_to_name[node] = out |
| 1702 | + |
1683 | 1703 | elif fn is torch.flatten: |
1684 | 1704 | start_dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("start_dim", 0) |
1685 | 1705 | out = f"{node.name}_flatten" |
@@ -1739,28 +1759,29 @@ def _resolve_perm_dims(args, rank: int) -> list[int]: |
1739 | 1759 |
|
1740 | 1760 | elif node.op == "output": |
1741 | 1761 | ret = node.args[0] |
1742 | | - if isinstance(ret, _fx.Node): |
1743 | | - val = node_to_name[ret] |
| 1762 | + rets = list(ret) if isinstance(ret, (tuple, list)) else [ret] |
| 1763 | + for r in rets: |
| 1764 | + if not isinstance(r, _fx.Node): |
| 1765 | + raise TypeError("FX ONNX export: unsupported (non-tensor) model output") |
| 1766 | + val = node_to_name[r] |
1744 | 1767 | # MHA nodes store a tuple (out, avg_attn); expose only the first element (out) and drop avg_attn. |
1745 | | - output_name = val[0] if isinstance(val, tuple) else val |
1746 | | - elif isinstance(ret, (tuple, list)) and len(ret) == 1: |
1747 | | - val = node_to_name[ret[0]] |
1748 | | - output_name = val[0] if isinstance(val, tuple) else val |
1749 | | - else: |
1750 | | - raise TypeError("Only single-output models are supported for FX ONNX export") |
| 1768 | + output_names.append(val[0] if isinstance(val, tuple) else val) |
1751 | 1769 |
|
1752 | 1770 | with torch.no_grad(): |
1753 | | - dummy_out = model(torch.zeros(1, *input_shape)) |
1754 | | - output_shape = [None] + list(dummy_out.shape[1:]) |
| 1771 | + dummy_out = model(torch.zeros(1, *input_shape, device=device)) |
| 1772 | + dummy_outs = list(dummy_out) if isinstance(dummy_out, (tuple, list)) else [dummy_out] |
1755 | 1773 |
|
1756 | 1774 | batch_dim = oh.make_tensor_value_info("input", TensorProto.FLOAT, [None, *input_shape]) |
1757 | | - output_vi = oh.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) |
| 1775 | + output_vis = [ |
| 1776 | + oh.make_tensor_value_info(name, TensorProto.FLOAT, [None] + list(t.shape[1:])) |
| 1777 | + for name, t in zip(output_names, dummy_outs) |
| 1778 | + ] |
1758 | 1779 |
|
1759 | 1780 | onnx_graph = oh.make_graph( |
1760 | 1781 | nodes=onnx_nodes, |
1761 | 1782 | name="pquant_onnx_fx", |
1762 | 1783 | inputs=[batch_dim], |
1763 | | - outputs=[output_vi], |
| 1784 | + outputs=output_vis, |
1764 | 1785 | initializer=initializers, |
1765 | 1786 | ) |
1766 | 1787 |
|
|
0 commit comments