diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5c189bd28b..411e8b075c 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,7 @@ Changelog - Fix Minitron pruning (``mcore_minitron``) for MoE models. Importance estimation hooks were incorrectly registered for MoE modules and NAS step was hanging before this. - Fix TRT support for remote autotuning in ONNX Autotune from 10.16+ to 10.15+ and fix TRT versioning check to the ``trtexec`` version instead of the TRT Python API when using ``trtexec`` backend. +- Exclude MatMul/Gemm nodes with K or N < 16 from ONNX INT8 and FP8 quantization. Such small-dimension GEMMs cannot efficiently use INT8/FP8 Tensor Cores and the added Q/DQ layers cause perf regressions in TensorRT. Honors Gemm ``transB`` when deriving K. **Misc** diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index e8b15e3059..164af24839 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -1089,11 +1089,14 @@ def find_nodes_from_matmul_to_exclude( calibration_eps: list[str] = ["cpu", "cuda:0", "trt"], calibration_shapes: str | dict | None = None, ) -> list[str]: - """Find MatMul nodes that meets gemv condition to exclude. + """Find MatMul nodes that meet gemv or small-gemm conditions and should be excluded. - Either of m or n in matmul is 1, this matmul cannot utilize - TensorCores. The perf of adding Q/DQ layers is not good in - TRT. Thus, in this case, do not add Q/DQ layers to this matmul. + A MatMul is excluded if either: + + - m or n in the output is 1 (GEMV): cannot utilize TensorCores; or + - K or N is smaller than ``_MIN_MATMUL_DIM`` (16): both INT8 and FP8 Tensor Core + kernels need K/N >= 16 to be efficient, and adding Q/DQ layers on such small + GEMMs causes TRT perf regressions. Args: onnx_path: Path to the onnx model. @@ -1143,6 +1146,10 @@ def find_nodes_from_matmul_to_exclude( _MIN_CHANNELS_FP8 = 16 +# Minimum K/N dim for MatMul/Gemm under INT8 or FP8 quantization. Both INT8 and FP8 +# Tensor Core kernels need K/N >= 16 to be efficient; adding Q/DQ layers on smaller +# GEMMs causes TRT perf regressions. +_MIN_MATMUL_DIM = 16 def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"): @@ -1231,10 +1238,47 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"): return unsupported_conv_nodes +def _get_inp_b_k_dim( + matmul_node, value_info_map: dict | None = None, output_map: dict | None = None +): + """Get the K dimension from the second input of a MatMul/Gemm node. + + Tries Constant shape first, then falls back to shape inference (value_info_map) + or runtime inference (output_map). For Gemm nodes, honors the ``transB`` attribute: + when ``transB=1``, B has shape ``[N, K]`` so K lives at axis -1; otherwise B is + ``[..., K, N]`` and K is at axis -2. + + Returns: + The K dimension value, or None if it cannot be determined. + """ + # For Gemm, transB=1 means B is [N, K] (K is last axis); default/MatMul is [K, N]. + trans_b = bool(matmul_node.attrs.get("transB", 0)) if matmul_node.op == "Gemm" else False + k_axis = -1 if trans_b else -2 + + inp_b = matmul_node.inputs[1] + if hasattr(inp_b, "values") and inp_b.values is not None: + inp_b_shape = inp_b.values.shape + if len(inp_b_shape) >= 2: + return inp_b_shape[k_axis] + if value_info_map is not None: + inp_b_info = value_info_map.get(inp_b.name) + if inp_b_info: + inp_b_dims = inp_b_info.type.tensor_type.shape.dim + if len(inp_b_dims) >= 2: + return inp_b_dims[k_axis].dim_value + if output_map is not None and inp_b.name in output_map: + inp_b_out = output_map[inp_b.name] + if len(inp_b_out.shape) >= 2: + return inp_b_out.shape[k_axis] + return None + + def _exclude_matmuls_by_shape_inference( - model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None + model: onnx.ModelProto, + matmul_nodes: list, + calibration_shapes: str | dict | None = None, ) -> list[str]: - """Use shape inference to find MatMuls with dimension 1.""" + """Use shape inference to find MatMuls with dimension 1 or small K/N.""" # Prepare model for symbolic inference for graph_input in model.graph.input: for dim in graph_input.type.tensor_type.shape.dim: @@ -1263,7 +1307,10 @@ def _exclude_matmuls_by_shape_inference( dim.dim_value = new_dim_value model = infer_shapes(model) - value_info_map = {vi.name: vi for vi in model.graph.value_info} + # Include graph inputs, value_info, and outputs so B that comes from a graph input + # is visible when deriving K. + value_info_map = {vi.name: vi for vi in model.graph.input} + value_info_map.update({vi.name: vi for vi in model.graph.value_info}) value_info_map.update({vi.name: vi for vi in model.graph.output}) nodes_to_exclude = [] @@ -1280,8 +1327,23 @@ def _exclude_matmuls_by_shape_inference( if dims[-1].dim_value == 1 or dims[-2].dim_value == 1: nodes_to_exclude.append(matmul_node.name) + continue elif len(dims) < 3 and any(out.dim_value == 1 for out in dims): nodes_to_exclude.append(matmul_node.name) + continue + + # Small-gemm check: applies to both INT8 and FP8 quantization. + n_dim = dims[-1].dim_value if len(dims) >= 2 else 0 + k_dim = _get_inp_b_k_dim(matmul_node, value_info_map=value_info_map) + small_n = 0 < n_dim < _MIN_MATMUL_DIM + small_k = k_dim is not None and 0 < k_dim < _MIN_MATMUL_DIM + + if small_n or small_k: + logger.debug( + f"Excluding small-dim MatMul from quantization: {matmul_node.name} " + f"(N={n_dim}, K={k_dim}, threshold={_MIN_MATMUL_DIM})" + ) + nodes_to_exclude.append(matmul_node.name) return nodes_to_exclude @@ -1295,10 +1357,20 @@ def _exclude_matmuls_by_inference( calibration_data_reader: CalibrationDataReader, calibration_eps: list[str], ) -> list[str]: - """Use actual inference to find MatMuls with dimension 1.""" - # Add matmul outputs to model outputs + """Use actual inference to find MatMuls with dimension 1 or small K/N.""" + # Add matmul outputs and second-input outputs to model outputs + existing_output_names = {out.name for out in model.graph.output} for matmul_node in matmul_nodes: - model.graph.output.extend([onnx.ValueInfoProto(name=matmul_node.outputs[0].name)]) + out_name = matmul_node.outputs[0].name + if out_name not in existing_output_names: + model.graph.output.extend([onnx.ValueInfoProto(name=out_name)]) + existing_output_names.add(out_name) + # Also add second input for K-dimension check (only if it's a Variable, not a Constant) + if isinstance(matmul_node.inputs[1], Variable): + inp_b_name = matmul_node.inputs[1].name + if inp_b_name not in existing_output_names: + model.graph.output.extend([onnx.ValueInfoProto(name=inp_b_name)]) + existing_output_names.add(inp_b_name) output_map = get_extended_model_outputs( onnx_path, @@ -1319,8 +1391,23 @@ def _exclude_matmuls_by_inference( or matmul_output.shape[-2] == 1 ): nodes_to_exclude.append(matmul_node.name) + continue elif len(matmul_output.shape) < 3 and any(out == 1 for out in matmul_output.shape): nodes_to_exclude.append(matmul_node.name) + continue + + # Small-gemm check: applies to both INT8 and FP8 quantization. + n_dim = matmul_output.shape[-1] if len(matmul_output.shape) >= 2 else 0 + k_dim = _get_inp_b_k_dim(matmul_node, output_map=output_map) + small_n = 0 < n_dim < _MIN_MATMUL_DIM + small_k = k_dim is not None and 0 < k_dim < _MIN_MATMUL_DIM + + if small_n or small_k: + logger.debug( + f"Excluding small-dim MatMul from quantization: {matmul_node.name} " + f"(N={n_dim}, K={k_dim}, threshold={_MIN_MATMUL_DIM})" + ) + nodes_to_exclude.append(matmul_node.name) return nodes_to_exclude diff --git a/tests/unit/onnx/quantization/test_graph_utils.py b/tests/unit/onnx/quantization/test_graph_utils.py index 1deaa1b8d3..60f67dba7a 100644 --- a/tests/unit/onnx/quantization/test_graph_utils.py +++ b/tests/unit/onnx/quantization/test_graph_utils.py @@ -13,11 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import numpy as np import onnx_graphsurgeon as gs import pytest +from onnx import TensorProto, helper -from modelopt.onnx.quantization.graph_utils import find_nodes_from_convs_to_exclude +from modelopt.onnx.quantization.graph_utils import ( + _exclude_matmuls_by_inference, + _exclude_matmuls_by_shape_inference, + _get_inp_b_k_dim, + find_nodes_from_convs_to_exclude, +) def _make_conv_graph(output_channels, input_channels, kernel_shape=(3, 3), name="Conv_0"): @@ -85,3 +93,260 @@ def test_fp8_channels_below_16_excluded_by_general_check(oc, ic): graph = _make_conv_graph(output_channels=oc, input_channels=ic, kernel_shape=(3, 3)) excluded = find_nodes_from_convs_to_exclude(graph, quantize_mode="fp8") assert "Conv_0" in excluded + + +def _make_matmul_model(m, k, n, name="MatMul_0", inp_b_constant=True): + """Build a minimal ONNX model with a single MatMul: [M, K] x [K, N] -> [M, N].""" + inp_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [m, k]) + out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [m, n]) + + if inp_b_constant: + b_init = helper.make_tensor("B", TensorProto.FLOAT, [k, n], np.ones(k * n).tolist()) + matmul = helper.make_node("MatMul", ["A", "B"], ["Y"], name=name) + graph = helper.make_graph([matmul], "test", [inp_a], [out], initializer=[b_init]) + else: + inp_b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [k, n]) + matmul = helper.make_node("MatMul", ["A", "B"], ["Y"], name=name) + graph = helper.make_graph([matmul], "test", [inp_a, inp_b], [out]) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + return model + + +def _get_nodes_by_op(model, op): + """Import an ONNX model and return its gs.Nodes whose op matches ``op``.""" + graph = gs.import_onnx(model) + return [n for n in graph.nodes if n.op == op] + + +def test_get_inp_b_k_dim_constant(): + """K dimension should be read from the Constant weight shape.""" + model = _make_matmul_model(m=32, k=8, n=64) + nodes = _get_nodes_by_op(model, "MatMul") + assert _get_inp_b_k_dim(nodes[0]) == 8 + + +def test_get_inp_b_k_dim_variable_with_output_map(): + """K dimension should be read from output_map for Variable inputs.""" + model = _make_matmul_model(m=32, k=10, n=64, inp_b_constant=False) + nodes = _get_nodes_by_op(model, "MatMul") + output_map = {"B": np.zeros((10, 64))} + assert _get_inp_b_k_dim(nodes[0], output_map=output_map) == 10 + + +def test_get_inp_b_k_dim_returns_none_when_unknown(): + """Should return None if K cannot be determined.""" + model = _make_matmul_model(m=32, k=8, n=64, inp_b_constant=False) + nodes = _get_nodes_by_op(model, "MatMul") + assert _get_inp_b_k_dim(nodes[0]) is None + + +@pytest.mark.parametrize( + ("m", "k", "n", "expected_excluded"), + [ + (32, 64, 8, True), + (32, 64, 15, True), + (32, 8, 64, True), + (32, 15, 64, True), + (32, 8, 8, True), + (32, 64, 16, False), + (32, 16, 64, False), + (32, 64, 64, False), + (32, 32, 32, False), + ], +) +def test_matmul_small_gemm_exclusion(m, k, n, expected_excluded): + """MatMuls with N or K < 16 should be excluded by shape inference.""" + model = _make_matmul_model(m=m, k=k, n=n) + nodes = _get_nodes_by_op(model, "MatMul") + calibration_shapes = {"A": [m, k]} + excluded = _exclude_matmuls_by_shape_inference(model, nodes, calibration_shapes) + if expected_excluded: + assert "MatMul_0" in excluded + else: + assert "MatMul_0" not in excluded + + +def test_matmul_gemv_excluded(): + """MatMul with N=1 (GEMV) should be excluded regardless of other dims.""" + model = _make_matmul_model(m=32, k=64, n=1) + nodes = _get_nodes_by_op(model, "MatMul") + calibration_shapes = {"A": [32, 64]} + excluded = _exclude_matmuls_by_shape_inference(model, nodes, calibration_shapes) + assert "MatMul_0" in excluded + + +def test_matmul_large_dims_not_excluded(): + """MatMul with all large dims should not be excluded.""" + model = _make_matmul_model(m=128, k=256, n=64) + nodes = _get_nodes_by_op(model, "MatMul") + calibration_shapes = {"A": [128, 256]} + excluded = _exclude_matmuls_by_shape_inference(model, nodes, calibration_shapes) + assert "MatMul_0" not in excluded + + +def _make_gemm_model(m, k, n, trans_b, name="Gemm_0"): + """Build a minimal ONNX model with a single Gemm node and a constant B. + + If trans_b is 1, B has shape [N, K] (K is last axis). + Otherwise B has shape [K, N]. + """ + inp_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [m, k]) + out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [m, n]) + + b_shape = [n, k] if trans_b else [k, n] + b_init = helper.make_tensor( + "B", TensorProto.FLOAT, b_shape, np.ones(b_shape[0] * b_shape[1]).tolist() + ) + gemm = helper.make_node("Gemm", ["A", "B"], ["Y"], name=name, transB=trans_b) + graph = helper.make_graph([gemm], "test", [inp_a], [out], initializer=[b_init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + return model + + +@pytest.mark.parametrize("trans_b", [0, 1]) +def test_get_inp_b_k_dim_gemm_transb_constant(trans_b): + """Gemm should honor transB when deriving K from a Constant B.""" + model = _make_gemm_model(m=32, k=10, n=64, trans_b=trans_b) + nodes = _get_nodes_by_op(model, "Gemm") + assert _get_inp_b_k_dim(nodes[0]) == 10 + + +@pytest.mark.parametrize("trans_b", [0, 1]) +def test_get_inp_b_k_dim_gemm_transb_output_map(trans_b): + """Gemm should honor transB when deriving K from an output_map.""" + # Build with a Variable B so the node's input is not a Constant. + inp_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [32, 10]) + inp_b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [64, 10] if trans_b else [10, 64]) + out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 64]) + gemm = helper.make_node("Gemm", ["A", "B"], ["Y"], name="Gemm_0", transB=trans_b) + graph = helper.make_graph([gemm], "test", [inp_a, inp_b], [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + nodes = _get_nodes_by_op(model, "Gemm") + + b_runtime_shape = (64, 10) if trans_b else (10, 64) + output_map = {"B": np.zeros(b_runtime_shape)} + assert _get_inp_b_k_dim(nodes[0], output_map=output_map) == 10 + + +def test_gemm_small_k_excluded_with_transb(): + """Gemm with transB=1 and small K should be excluded (regression: prior code read N).""" + # N=64 is large; K=8 is small. With transB=1, B=[N,K]=[64,8], K axis is -1. + # If _get_inp_b_k_dim ignored transB it would read 64 (N) and not exclude. + model = _make_gemm_model(m=32, k=8, n=64, trans_b=1) + nodes = _get_nodes_by_op(model, "Gemm") + calibration_shapes = {"A": [32, 8]} + excluded = _exclude_matmuls_by_shape_inference(model, nodes, calibration_shapes) + assert "Gemm_0" in excluded + + +def test_gemm_large_dims_not_excluded_with_transb(): + """Gemm with transB=1 and all large dims should NOT be excluded.""" + model = _make_gemm_model(m=32, k=64, n=64, trans_b=1) + nodes = _get_nodes_by_op(model, "Gemm") + calibration_shapes = {"A": [32, 64]} + excluded = _exclude_matmuls_by_shape_inference(model, nodes, calibration_shapes) + assert "Gemm_0" not in excluded + + +def _make_matmul_model_graph_input_b(m, k, n, name="MatMul_0"): + """MatMul where B is a graph input (its shape lives in model.graph.input only).""" + inp_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [m, k]) + inp_b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [k, n]) + out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [m, n]) + matmul = helper.make_node("MatMul", ["A", "B"], ["Y"], name=name) + graph = helper.make_graph([matmul], "test", [inp_a, inp_b], [out]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + +def test_matmul_small_k_graph_input_b_excluded(): + """Small-K MatMul whose B is a graph input should still be excluded. + + Regression: previous value_info_map only covered model.graph.value_info/output, + missing graph inputs, so K was undetectable and the MatMul wasn't excluded. + """ + model = _make_matmul_model_graph_input_b(m=32, k=8, n=64) + nodes = _get_nodes_by_op(model, "MatMul") + calibration_shapes = {"A": [32, 8], "B": [8, 64]} + excluded = _exclude_matmuls_by_shape_inference(model, nodes, calibration_shapes) + assert "MatMul_0" in excluded + + +@pytest.mark.parametrize( + ("k", "n", "expected_excluded"), + [ + (8, 64, True), + (64, 8, True), + (64, 64, False), + ], +) +def test_exclude_matmuls_by_inference_runtime_path(k, n, expected_excluded): + """Exercise the runtime-inference path with B as a graph input (read from output_map).""" + m = 32 + model = _make_matmul_model_graph_input_b(m=m, k=k, n=n) + nodes = _get_nodes_by_op(model, "MatMul") + + # Mock get_extended_model_outputs to return a synthetic output_map so we don't + # need an actual ORT session. + fake_output_map = { + "Y": np.zeros((m, n), dtype=np.float32), + "B": np.zeros((k, n), dtype=np.float32), + } + with mock.patch( + "modelopt.onnx.quantization.graph_utils.get_extended_model_outputs", + return_value=fake_output_map, + ): + excluded = _exclude_matmuls_by_inference( + onnx_path="unused.onnx", + model=model, + matmul_nodes=nodes, + use_external_data_format=False, + intermediate_generated_files=[], + calibration_data_reader=None, + calibration_eps=["cpu"], + ) + if expected_excluded: + assert "MatMul_0" in excluded + else: + assert "MatMul_0" not in excluded + + +def test_exclude_matmuls_by_inference_dedupes_added_outputs(): + """Two MatMuls sharing the same Variable B must not create duplicate graph outputs.""" + # Build two MatMuls sharing B as a graph input. + m, k, n = 32, 8, 64 + inp_a1 = helper.make_tensor_value_info("A1", TensorProto.FLOAT, [m, k]) + inp_a2 = helper.make_tensor_value_info("A2", TensorProto.FLOAT, [m, k]) + inp_b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [k, n]) + out1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [m, n]) + out2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [m, n]) + mm1 = helper.make_node("MatMul", ["A1", "B"], ["Y1"], name="MatMul_0") + mm2 = helper.make_node("MatMul", ["A2", "B"], ["Y2"], name="MatMul_1") + graph = helper.make_graph([mm1, mm2], "test", [inp_a1, inp_a2, inp_b], [out1, out2]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + nodes = _get_nodes_by_op(model, "MatMul") + + fake_output_map = { + "Y1": np.zeros((m, n), dtype=np.float32), + "Y2": np.zeros((m, n), dtype=np.float32), + "B": np.zeros((k, n), dtype=np.float32), + } + with mock.patch( + "modelopt.onnx.quantization.graph_utils.get_extended_model_outputs", + return_value=fake_output_map, + ): + excluded = _exclude_matmuls_by_inference( + onnx_path="unused.onnx", + model=model, + matmul_nodes=nodes, + use_external_data_format=False, + intermediate_generated_files=[], + calibration_data_reader=None, + calibration_eps=["cpu"], + ) + output_names = [o.name for o in model.graph.output] + # B should appear only once in the graph outputs. + assert output_names.count("B") == 1 + # Both MatMuls should be excluded (small K). + assert "MatMul_0" in excluded + assert "MatMul_1" in excluded