Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions modelopt/onnx/quantization/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,10 +1089,11 @@ def find_nodes_from_matmul_to_exclude(
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
calibration_shapes: str | dict | None = None,
) -> list[str]:
"""Find MatMul nodes that meets gemv condition to exclude.
"""Find MatMul nodes that meets gemv or small-gemm condition to exclude.

Either of m or n in matmul is 1, this matmul cannot utilize
TensorCores. The perf of adding Q/DQ layers is not good in
Either of m or n in matmul is 1, or K or N is smaller than
_MIN_MATMUL_DIM_INT8 (16), this matmul cannot efficiently utilize
INT8 kernels. The perf of adding Q/DQ layers is not good in
TRT. Thus, in this case, do not add Q/DQ layers to this matmul.

Args:
Expand Down Expand Up @@ -1143,6 +1144,7 @@ def find_nodes_from_matmul_to_exclude(


# Minimum channel count for FP8 Conv quantization; Convs with fewer channels
# are excluded from quantization (see find_nodes_from_convs_to_exclude).
_MIN_CHANNELS_FP8 = 16
# Minimum N/K dimension for INT8 MatMul quantization — INT8 kernels need >= 16;
# smaller GEMMs are excluded from Q/DQ insertion.
_MIN_MATMUL_DIM_INT8 = 16


def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
Expand Down Expand Up @@ -1231,10 +1233,39 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
return unsupported_conv_nodes


def _get_inp_b_k_dim(
matmul_node, value_info_map: dict | None = None, output_map: dict | None = None
):
"""Get the K dimension from the second input of a MatMul node.

Tries Constant shape first, then falls back to shape inference (value_info_map)
or runtime inference (output_map).

Returns:
The K dimension value, or None if it cannot be determined.
"""
inp_b = matmul_node.inputs[1]
if hasattr(inp_b, "values") and inp_b.values is not None:
inp_b_shape = inp_b.values.shape
if len(inp_b_shape) >= 2:
return inp_b_shape[-2]
if value_info_map is not None:
inp_b_info = value_info_map.get(inp_b.name)
if inp_b_info:
inp_b_dims = inp_b_info.type.tensor_type.shape.dim
if len(inp_b_dims) >= 2:
return inp_b_dims[-2].dim_value
if output_map is not None and inp_b.name in output_map:
inp_b_out = output_map[inp_b.name]
if len(inp_b_out.shape) >= 2:
return inp_b_out.shape[-2]
return None


def _exclude_matmuls_by_shape_inference(
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
) -> list[str]:
"""Use shape inference to find MatMuls with dimension 1."""
"""Use shape inference to find MatMuls with dimension 1 or small K/N."""
# Prepare model for symbolic inference
for graph_input in model.graph.input:
for dim in graph_input.type.tensor_type.shape.dim:
Expand Down Expand Up @@ -1280,8 +1311,22 @@ def _exclude_matmuls_by_shape_inference(

if dims[-1].dim_value == 1 or dims[-2].dim_value == 1:
nodes_to_exclude.append(matmul_node.name)
continue
elif len(dims) < 3 and any(out.dim_value == 1 for out in dims):
nodes_to_exclude.append(matmul_node.name)
continue
# Small-gemm check: exclude if N or K < 16 (INT8 kernels need >= 16).
n_dim = dims[-1].dim_value if len(dims) >= 2 else 0
k_dim = _get_inp_b_k_dim(matmul_node, value_info_map=value_info_map)
small_n = 0 < n_dim < _MIN_MATMUL_DIM_INT8
small_k = k_dim is not None and 0 < k_dim < _MIN_MATMUL_DIM_INT8

if small_n or small_k:
logger.debug(
f"Excluding small-dim MatMul from INT8 quantization: {matmul_node.name} "
f"(N={n_dim}, K={k_dim}, threshold={_MIN_MATMUL_DIM_INT8})"
)
nodes_to_exclude.append(matmul_node.name)

return nodes_to_exclude

Expand All @@ -1295,10 +1340,20 @@ def _exclude_matmuls_by_inference(
calibration_data_reader: CalibrationDataReader,
calibration_eps: list[str],
) -> list[str]:
"""Use actual inference to find MatMuls with dimension 1."""
# Add matmul outputs to model outputs
"""Use actual inference to find MatMuls with dimension 1 or small K/N (INT8)."""
# Add matmul outputs and second-input outputs to model outputs
existing_output_names = {out.name for out in model.graph.output}
for matmul_node in matmul_nodes:
model.graph.output.extend([onnx.ValueInfoProto(name=matmul_node.outputs[0].name)])
out_name = matmul_node.outputs[0].name
if out_name not in existing_output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=out_name)])
existing_output_names.add(out_name)
# Also add second input for K-dimension check (only if it's a Variable, not a Constant)
if isinstance(matmul_node.inputs[1], Variable):
inp_b_name = matmul_node.inputs[1].name
if inp_b_name not in existing_output_names:
model.graph.output.extend([onnx.ValueInfoProto(name=inp_b_name)])
existing_output_names.add(inp_b_name)

output_map = get_extended_model_outputs(
onnx_path,
Expand All @@ -1319,8 +1374,23 @@ def _exclude_matmuls_by_inference(
or matmul_output.shape[-2] == 1
):
nodes_to_exclude.append(matmul_node.name)
continue
elif len(matmul_output.shape) < 3 and any(out == 1 for out in matmul_output.shape):
nodes_to_exclude.append(matmul_node.name)
continue

# Small-gemm check: exclude if N or K < 16 (INT8 kernels need >= 16).
n_dim = matmul_output.shape[-1] if len(matmul_output.shape) >= 2 else 0
k_dim = _get_inp_b_k_dim(matmul_node, output_map=output_map)
small_n = 0 < n_dim < _MIN_MATMUL_DIM_INT8
small_k = k_dim is not None and 0 < k_dim < _MIN_MATMUL_DIM_INT8

if small_n or small_k:
logger.debug(
f"Excluding small-dim MatMul from INT8 quantization: {matmul_node.name} "
f"(N={n_dim}, K={k_dim}, threshold={_MIN_MATMUL_DIM_INT8})"
)
nodes_to_exclude.append(matmul_node.name)

return nodes_to_exclude

Expand Down
97 changes: 96 additions & 1 deletion tests/unit/onnx/quantization/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
import numpy as np
import onnx_graphsurgeon as gs
import pytest
from onnx import TensorProto, helper

from modelopt.onnx.quantization.graph_utils import find_nodes_from_convs_to_exclude
from modelopt.onnx.quantization.graph_utils import (
_exclude_matmuls_by_shape_inference,
_get_inp_b_k_dim,
find_nodes_from_convs_to_exclude,
)


def _make_conv_graph(output_channels, input_channels, kernel_shape=(3, 3), name="Conv_0"):
Expand Down Expand Up @@ -85,3 +90,93 @@ def test_fp8_channels_below_16_excluded_by_general_check(oc, ic):
graph = _make_conv_graph(output_channels=oc, input_channels=ic, kernel_shape=(3, 3))
excluded = find_nodes_from_convs_to_exclude(graph, quantize_mode="fp8")
assert "Conv_0" in excluded


def _make_matmul_model(m, k, n, name="MatMul_0", inp_b_constant=True):
    """Build a minimal ONNX model with a single MatMul: [M, K] x [K, N] -> [M, N]."""
    inp_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [m, k])
    out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [m, n])
    matmul = helper.make_node("MatMul", ["A", "B"], ["Y"], name=name)

    if inp_b_constant:
        # B is an initializer (Constant weight) of all-ones.
        b_init = helper.make_tensor("B", TensorProto.FLOAT, [k, n], np.ones(k * n).tolist())
        graph = helper.make_graph([matmul], "test", [inp_a], [out], initializer=[b_init])
    else:
        # B is a graph input (Variable).
        inp_b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [k, n])
        graph = helper.make_graph([matmul], "test", [inp_a, inp_b], [out])

    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])


def _get_matmul_nodes(model):
    """Import an ONNX model and return its MatMul gs.Nodes."""
    return [node for node in gs.import_onnx(model).nodes if node.op == "MatMul"]


def test_get_inp_b_k_dim_constant():
    """K dimension should be read from the Constant weight shape."""
    matmuls = _get_matmul_nodes(_make_matmul_model(m=32, k=8, n=64))
    assert _get_inp_b_k_dim(matmuls[0]) == 8


def test_get_inp_b_k_dim_variable_with_output_map():
    """K dimension should be read from output_map for Variable inputs."""
    matmuls = _get_matmul_nodes(_make_matmul_model(m=32, k=10, n=64, inp_b_constant=False))
    runtime_outputs = {"B": np.zeros((10, 64))}
    assert _get_inp_b_k_dim(matmuls[0], output_map=runtime_outputs) == 10


def test_get_inp_b_k_dim_returns_none_when_unknown():
    """Should return None if K cannot be determined."""
    matmuls = _get_matmul_nodes(_make_matmul_model(m=32, k=8, n=64, inp_b_constant=False))
    assert _get_inp_b_k_dim(matmuls[0]) is None


@pytest.mark.parametrize(
    ("m", "k", "n", "expected_excluded"),
    [
        (32, 64, 8, True),
        (32, 64, 15, True),
        (32, 8, 64, True),
        (32, 15, 64, True),
        (32, 8, 8, True),
        (32, 64, 16, False),
        (32, 16, 64, False),
        (32, 64, 64, False),
        (32, 32, 32, False),
    ],
)
def test_matmul_small_gemm_exclusion(m, k, n, expected_excluded):
    """MatMuls with N or K < 16 should be excluded by shape inference."""
    model = _make_matmul_model(m=m, k=k, n=n)
    excluded = _exclude_matmuls_by_shape_inference(
        model, _get_matmul_nodes(model), {"A": [m, k]}
    )
    assert ("MatMul_0" in excluded) == expected_excluded


def test_matmul_gemv_excluded():
    """MatMul with N=1 (GEMV) should be excluded regardless of other dims."""
    model = _make_matmul_model(m=32, k=64, n=1)
    excluded = _exclude_matmuls_by_shape_inference(
        model, _get_matmul_nodes(model), {"A": [32, 64]}
    )
    assert "MatMul_0" in excluded


def test_matmul_large_dims_not_excluded():
    """MatMul with all large dims should not be excluded."""
    model = _make_matmul_model(m=128, k=256, n=64)
    excluded = _exclude_matmuls_by_shape_inference(
        model, _get_matmul_nodes(model), {"A": [128, 256]}
    )
    assert "MatMul_0" not in excluded