Skip to content

Commit d5cce9d

Browse files
authored
initial tracing of model (#42)
Add torch.fx tracing of the model. The tracing can be triggered when replacing layers with their compressed variants, or run on its own by calling the tracing function directly.
1 parent b1c3ec8 commit d5cce9d

5 files changed

Lines changed: 907 additions & 17 deletions

File tree

src/pquant/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,14 @@
1919
pdp_config,
2020
wanda_config,
2121
)
22-
from .core.torch import activations, layers, optimizers, pruning_methods, quantizer
22+
from .core.torch import (
23+
activations,
24+
layers,
25+
optimizers,
26+
pruning_methods,
27+
quantizer,
28+
tracing,
29+
)
2330
from .core.torch.layers import (
2431
add_compression_layers,
2532
apply_final_compression,
@@ -29,9 +36,10 @@
2936
load_torch_hgq_model,
3037
post_training_prune,
3138
)
39+
from .core.torch.tracing import check_quantization, print_quantization_check
3240
from .core.torch.train import train_model
3341

34-
_forwards = ["activations", "layers", "quantizer", "optimizers"]
42+
_forwards = ["activations", "layers", "quantizer", "optimizers", "tracing"]
3543

3644
for name in _forwards:
3745
mod = importlib.import_module(f".core.torch.{name}", package="pquant")
@@ -57,6 +65,8 @@
5765
_forwards.append("load_from_dictionary")
5866
_forwards.append("get_ebops")
5967
_forwards.append("load_torch_hgq_model")
68+
_forwards.append("check_quantization")
69+
_forwards.append("print_quantization_check")
6070
_forwards.append("PQConfig")
6171
__all__ = _forwards
6272

src/pquant/core/torch/convert_to_onnx.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
PQLayerNorm,
5454
PQMultiheadAttention,
5555
)
56+
from pquant.core.torch.quantizer import Quantizer # noqa: E402
5657

5758
# ---------------------------------------------------------------------------
5859
# QONNX Quant node
@@ -1200,6 +1201,15 @@ def _emit_module(
12001201
module, prefix, current, current, current, nodes, initializers, quant_fn, use_qonnx, store_integer_weights
12011202
)
12021203
return out
1204+
if isinstance(module, Quantizer):
1205+
# Standalone quantizer (e.g. an auto-inserted missing quantizer or a
1206+
# constant-matrix quantizer): emit a single QDQ node from its k/i/f.
1207+
k, i, f = module.get_quantization_bits()
1208+
new_nodes, out = quant_fn(
1209+
prefix, current, module.round_mode, k, i, f, initializers, overflow_mode=getattr(module, "overflow", "SAT")
1210+
)
1211+
nodes.extend(new_nodes)
1212+
return out
12031213
raise TypeError(f"Unsupported module type for ONNX export: {type(module).__name__}")
12041214

12051215

@@ -1481,6 +1491,8 @@ class _PQTracer(_fx.Tracer):
14811491
PQAvgPool1d,
14821492
PQAvgPool2d,
14831493
PQMultiheadAttention,
1494+
PQActivation,
1495+
Quantizer,
14841496
)
14851497

14861498
def is_leaf_module(self, m: nn.Module, qualname: str) -> bool:
@@ -1516,13 +1528,16 @@ def convert_to_onnx_fx(
15161528
# need to expand torch's two-arg .transpose(d0, d1) into a full ONNX perm.
15171529
from torch.fx.passes.shape_prop import ShapeProp
15181530

1531+
# Build the probe tensor on the model's own device so ShapeProp doesn't hit a
1532+
# device mismatch when a default device (e.g. CUDA) is set via torch.set_default_device.
1533+
device = next((p.device for p in model.parameters()), None)
15191534
with torch.no_grad():
1520-
ShapeProp(gm).propagate(torch.zeros(1, *input_shape))
1535+
ShapeProp(gm).propagate(torch.zeros(1, *input_shape, device=device))
15211536

15221537
onnx_nodes: list[onnx.NodeProto] = []
15231538
initializers: list[onnx.TensorProto] = []
15241539
node_to_name: dict[_fx.Node, str] = {}
1525-
output_name: str = ""
1540+
output_names: list[str] = []
15261541

15271542
def _res(arg) -> str:
15281543
if isinstance(arg, _fx.Node):
@@ -1680,6 +1695,11 @@ def _resolve_perm_dims(args, rank: int) -> list[int]:
16801695
onnx_nodes.append(oh.make_node("Relu", inputs=[_res(node.args[0])], outputs=[out]))
16811696
node_to_name[node] = out
16821697

1698+
elif fn in (_F.sigmoid, torch.sigmoid):
1699+
out = f"{node.name}_sigmoid"
1700+
onnx_nodes.append(oh.make_node("Sigmoid", inputs=[_res(node.args[0])], outputs=[out]))
1701+
node_to_name[node] = out
1702+
16831703
elif fn is torch.flatten:
16841704
start_dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("start_dim", 0)
16851705
out = f"{node.name}_flatten"
@@ -1739,28 +1759,29 @@ def _resolve_perm_dims(args, rank: int) -> list[int]:
17391759

17401760
elif node.op == "output":
17411761
ret = node.args[0]
1742-
if isinstance(ret, _fx.Node):
1743-
val = node_to_name[ret]
1762+
rets = list(ret) if isinstance(ret, (tuple, list)) else [ret]
1763+
for r in rets:
1764+
if not isinstance(r, _fx.Node):
1765+
raise TypeError("FX ONNX export: unsupported (non-tensor) model output")
1766+
val = node_to_name[r]
17441767
# MHA nodes store a tuple (out, avg_attn); expose the attention output.
1745-
output_name = val[0] if isinstance(val, tuple) else val
1746-
elif isinstance(ret, (tuple, list)) and len(ret) == 1:
1747-
val = node_to_name[ret[0]]
1748-
output_name = val[0] if isinstance(val, tuple) else val
1749-
else:
1750-
raise TypeError("Only single-output models are supported for FX ONNX export")
1768+
output_names.append(val[0] if isinstance(val, tuple) else val)
17511769

17521770
with torch.no_grad():
1753-
dummy_out = model(torch.zeros(1, *input_shape))
1754-
output_shape = [None] + list(dummy_out.shape[1:])
1771+
dummy_out = model(torch.zeros(1, *input_shape, device=device))
1772+
dummy_outs = list(dummy_out) if isinstance(dummy_out, (tuple, list)) else [dummy_out]
17551773

17561774
batch_dim = oh.make_tensor_value_info("input", TensorProto.FLOAT, [None, *input_shape])
1757-
output_vi = oh.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape)
1775+
output_vis = [
1776+
oh.make_tensor_value_info(name, TensorProto.FLOAT, [None] + list(t.shape[1:]))
1777+
for name, t in zip(output_names, dummy_outs)
1778+
]
17581779

17591780
onnx_graph = oh.make_graph(
17601781
nodes=onnx_nodes,
17611782
name="pquant_onnx_fx",
17621783
inputs=[batch_dim],
1763-
outputs=[output_vi],
1784+
outputs=output_vis,
17641785
initializer=initializers,
17651786
)
17661787

src/pquant/core/torch/layers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,15 @@ def extra_repr(self):
644644
return s.format(**self.__dict__)
645645

646646

647-
def add_compression_layers(model, config, input_shape=None):
647+
def add_compression_layers(model, config, input_shape=None, add_missing_quantizers=False):
648648
model = add_quantized_activations_to_model_layer(model, config)
649649
model = add_pruning_to_model(model, config)
650+
if add_missing_quantizers:
651+
# Imported here (not at module top) to avoid a circular import: tracing.py
652+
# imports the layer classes defined in this module.
653+
from pquant.core.torch.tracing import check_quantization
654+
655+
model = check_quantization(model, add_missing_quantizers=True, config=config)
650656
model.to("cuda")
651657
if input_shape is not None:
652658
model(torch.rand(input_shape).to("cuda"))

0 commit comments

Comments
 (0)