Commit df090ec

Fix TRT strongly-typed parsing for torch_onnx quantized exports
Three related fixes so all torch_onnx tests build a TensorRT engine under --stronglyTyped:

- onnx/utils.py: Add fold_qdq_scale_fp16_to_fp32_casts and extend fold_dq_fp32_to_fp16_casts to propagate FP16 through nested DQ-scale chains (NVFP4 double-DQ). Without this, the outer DQ output was retyped to FP16 while its FP32 scale stayed in the graph, leaving the downstream Gemm in FP32 and mismatching the FP16 bias.
- torch/_deploy/utils/torch_onnx.py: Run fold_qdq_scale_fp16_to_fp32_casts after fold_dq_fp32_to_fp16_casts so Cast(FP16->FP32) nodes injected in front of Q/DQ scale inputs by onnxconverter_common are removed.
- examples/torch_onnx/torch_quant_to_onnx.py: Skip downsample.reduction (Swin/SwinV2 4D Linear incompatible with TRT DynamicQuantize) and pass strict=False to load_calib_amax so quantizers that never saw a tensor with calibration_data_size=1 do not crash calibration.

Signed-off-by: ajrasane <arasane@nvidia.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 76e6be1 · commit df090ec

3 files changed: 143 additions & 28 deletions
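
The commit message references building under --stronglyTyped. A quick way to sanity-check an exported model against a strongly-typed TensorRT parse is sketched below; this is not part of the commit, and it assumes trtexec is on PATH and that the export was saved as model.onnx (both assumptions).

import subprocess

# Build a strongly-typed TensorRT engine from the exported ONNX file (paths assumed).
subprocess.run(
    ["trtexec", "--onnx=model.onnx", "--stronglyTyped"],
    check=True,
)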


examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 7 additions & 3 deletions
@@ -122,11 +122,15 @@ def get_quant_config(quantize_mode):
 
 
 def filter_func(name):
-    """Filter function to exclude certain layers from quantization."""
+    """Filter function to exclude certain layers from quantization.
+
+    ``downsample.reduction`` (Swin/SwinV2) is excluded because it operates on 4D tensors
+    and TRT's DynamicQuantize layer (used for MXFP8/NVFP4) requires 2D/3D input.
+    """
     pattern = re.compile(
         r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
         r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
-        r"maxpool|global_pool).*"
+        r"maxpool|global_pool|downsample\.reduction).*"
     )
     return pattern.match(name) is not None
 

@@ -192,7 +196,7 @@ def _calibrate_uncalibrated_quantizers(model, data_loader):
 
     for quantizer in uncalibrated:
         quantizer.disable_calib()
-        quantizer.load_calib_amax()
+        quantizer.load_calib_amax(strict=False)
 
 
 def quantize_model(model, config, data_loader=None):
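
A quick way to see what the widened regex now excludes (a standalone sketch, not part of the commit; the module paths below are made-up Swin-style examples):

import re

# Same pattern as filter_func above, reproduced standalone for illustration.
pattern = re.compile(
    r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
    r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
    r"maxpool|global_pool|downsample\.reduction).*"
)

# Hypothetical Swin/SwinV2 module names, for illustration only.
print(bool(pattern.match("layers.0.downsample.reduction")))  # True  -> excluded from quantization
print(bool(pattern.match("layers.0.blocks.1.mlp.fc1")))      # False -> still quantized

Excluding the patch-merging Linear keeps its 4D input out of TRT's DynamicQuantize path, per the docstring added above.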

modelopt/onnx/utils.py

Lines changed: 132 additions & 25 deletions
@@ -1513,12 +1513,18 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     2. Updating the DQ output type to FP16 in value_info
     3. Bypassing and removing the Cast node
 
+    NVFP4 uses a nested DQ chain (scale is itself a DQ output). When the outer DQ's scale
+    is produced by another DQ, recursively retype the inner DQ's chain so the whole
+    chain produces FP16 tensors under strongly-typed TRT parsing.
+
     Args:
         onnx_model: The ONNX model with DQ -> Cast(FP32->FP16) patterns.
 
     Returns:
         The ONNX model with Cast nodes removed and DQ outputs set to FP16.
     """
+    import numpy as np
+
     dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
 
     # Build a map of tensor name -> producer node
@@ -1532,51 +1538,66 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         init.name: init for init in onnx_model.graph.initializer
     }
 
+    value_info_map: dict[str, onnx.ValueInfoProto] = {
+        vi.name: vi for vi in onnx_model.graph.value_info
+    }
+
+    retyped_dq_outputs: set[str] = set()
+
+    def _convert_fp32_init_to_fp16(init: onnx.TensorProto) -> None:
+        scale_data = np.frombuffer(init.raw_data, dtype=np.float32)
+        if not scale_data.size:
+            scale_data = np.array(init.float_data, dtype=np.float32)
+        init.data_type = onnx.TensorProto.FLOAT16
+        init.raw_data = scale_data.astype(np.float16).tobytes()
+        del init.float_data[:]
+
+    def _retype_dq_chain(dq_node: onnx.NodeProto, depth: int = 0) -> None:
+        """Propagate FP16 output type down through a DQ's scale chain."""
+        if depth > 4 or len(dq_node.input) < 2:
+            return
+        scale_name = dq_node.input[1]
+        scale_init = initializer_map.get(scale_name)
+        if scale_init is not None:
+            if scale_init.data_type == onnx.TensorProto.FLOAT:
+                _convert_fp32_init_to_fp16(scale_init)
+            return
+        scale_producer = producer_map.get(scale_name)
+        if scale_producer is None or scale_producer.op_type not in dq_ops:
+            return
+        _retype_dq_chain(scale_producer, depth + 1)
+        if scale_name in value_info_map:
+            value_info_map[scale_name].type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
+        retyped_dq_outputs.add(scale_name)
+
     nodes_to_remove = []
     for node in onnx_model.graph.node:
         if node.op_type != "Cast":
             continue
 
-        # Check: Cast target is FP16
         cast_to = None
         for attr in node.attribute:
             if attr.name == "to":
                 cast_to = attr.i
         if cast_to != onnx.TensorProto.FLOAT16:
             continue
 
-        # Check: producer is a DQ node
         producer = producer_map.get(node.input[0])
        if producer is None or producer.op_type not in dq_ops:
            continue
 
-        # Convert the DQ scale initializer from FP32 to FP16
-        # DQ inputs: [input, scale, (zero_point)]
-        if len(producer.input) >= 2:
-            scale_name = producer.input[1]
-            if scale_name in initializer_map:
-                scale_init = initializer_map[scale_name]
-                if scale_init.data_type == onnx.TensorProto.FLOAT:
-                    import numpy as np
-
-                    scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
-                    if not scale_data.size:
-                        scale_data = np.array(scale_init.float_data, dtype=np.float32)
-                    scale_fp16 = scale_data.astype(np.float16)
-                    scale_init.data_type = onnx.TensorProto.FLOAT16
-                    scale_init.raw_data = scale_fp16.tobytes()
-                    del scale_init.float_data[:]
-
-        # Bypass the Cast node
+        _retype_dq_chain(producer)
+
         _bypass_cast_node(onnx_model, node)
         nodes_to_remove.append(node)
 
-        # Update the DQ output type in value_info
         dq_output_name = producer.output[0]
-        for vi in onnx_model.graph.value_info:
-            if vi.name == dq_output_name:
-                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
-                break
+        retyped_dq_outputs.add(dq_output_name)
+
+    for name in retyped_dq_outputs:
+        vi = value_info_map.get(name)
+        if vi is not None:
+            vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
 
     logger.debug(f"Folded {len(nodes_to_remove)} DQ -> Cast(FP32->FP16) patterns")
     for node in nodes_to_remove:
@@ -1585,6 +1606,92 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     return onnx_model
 
 
+def fold_qdq_scale_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove Cast(FP16->FP32) nodes feeding into Q/DQ scale inputs.
+
+    When convert_float_to_float16 blocks QuantizeLinear/DequantizeLinear, it inserts
+    Cast(FP16->FP32) nodes before every scale input. In opset >=20 Q/DQ natively accept
+    FP16 scales, and leaving the cast in place forces DQ outputs to FP32, breaking
+    downstream FP16 matmul/add operations under strongly-typed TRT parsing.
+
+    This function bypasses each such Cast and, when the upstream Constant is FP16,
+    wires the DQ output to FP16 in value_info so shape inference stays consistent.
+
+    Args:
+        onnx_model: The ONNX model with Cast(FP16->FP32) -> Q/DQ.scale patterns.
+
+    Returns:
+        The ONNX model with redundant scale-path casts removed.
+    """
+    qdq_ops = {
+        "QuantizeLinear",
+        "DequantizeLinear",
+        "TRT_FP8QuantizeLinear",
+        "TRT_FP8DequantizeLinear",
+    }
+
+    producer_map: dict[str, onnx.NodeProto] = {}
+    consumer_map: dict[str, list[tuple[onnx.NodeProto, int]]] = {}
+    for node in onnx_model.graph.node:
+        for out in node.output:
+            producer_map[out] = node
+        for idx, inp in enumerate(node.input):
+            if inp:
+                consumer_map.setdefault(inp, []).append((node, idx))
+
+    type_map = _build_tensor_type_map(onnx_model)
+
+    nodes_to_remove: list[onnx.NodeProto] = []
+    dq_outputs_retyped: set[str] = set()
+    visited_casts: set[int] = set()
+    for node in onnx_model.graph.node:
+        if node.op_type not in qdq_ops or len(node.input) < 2:
+            continue
+
+        scale_name = node.input[1]
+        cast_node = producer_map.get(scale_name)
+        if cast_node is None or cast_node.op_type != "Cast":
+            continue
+        if id(cast_node) in visited_casts:
+            # Already handled (e.g. shared scale Cast across paired Q/DQ).
+            if node.op_type.endswith("DequantizeLinear"):
+                dq_outputs_retyped.add(node.output[0])
+            continue
+        if get_cast_to_type(cast_node) != onnx.TensorProto.FLOAT:
+            continue
+        if type_map.get(cast_node.input[0]) != onnx.TensorProto.FLOAT16:
+            continue
+
+        # Only bypass when every consumer of this Cast is a Q/DQ scale input; otherwise
+        # other ops would silently receive FP16 instead of the FP32 they requested.
+        cast_output = cast_node.output[0]
+        consumers = consumer_map.get(cast_output, [])
+        if not consumers or not all(
+            c.op_type in qdq_ops and i == 1 for c, i in consumers
+        ):
+            continue
+
+        # Bypass the cast so the scale stays FP16
+        _bypass_cast_node(onnx_model, cast_node)
+        nodes_to_remove.append(cast_node)
+        visited_casts.add(id(cast_node))
+
+        # For DQ nodes, the output type follows the scale type — update value_info.
+        if node.op_type.endswith("DequantizeLinear"):
+            dq_outputs_retyped.add(node.output[0])
+
+    for vi in onnx_model.graph.value_info:
+        if vi.name in dq_outputs_retyped:
+            vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
+
+    logger.debug(f"Folded {len(nodes_to_remove)} Cast(FP16->FP32) -> Q/DQ.scale patterns")
+    for cast_node in nodes_to_remove:
+        if cast_node in onnx_model.graph.node:
+            onnx_model.graph.node.remove(cast_node)
+
+    return onnx_model
+
+
 def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
     """Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
 
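
For reference, here is a hand-built miniature of the DQ -> Cast(FP32->FP16) weight path that fold_dq_fp32_to_fp16_casts targets. This is an illustrative sketch only, not code from the commit: tensor names, shapes, and the 0.02 scale are invented, and the comments describe the expected "after" state per the docstrings above.

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# "Before" state: an int8 weight dequantized with an FP32 scale, then cast to FP16 for an
# FP16 MatMul. This is the DQ -> Cast(FP32->FP16) pattern the fold removes.
w_q = numpy_helper.from_array(np.zeros((8, 8), dtype=np.int8), name="w_q")
w_scale = numpy_helper.from_array(np.array(0.02, dtype=np.float32), name="w_scale")

dq = helper.make_node("DequantizeLinear", ["w_q", "w_scale"], ["w_fp32"], name="w_dq")
cast = helper.make_node("Cast", ["w_fp32"], ["w_fp16"], to=TensorProto.FLOAT16, name="w_cast")
matmul = helper.make_node("MatMul", ["x", "w_fp16"], ["y"], name="mm")

graph = helper.make_graph(
    [dq, cast, matmul],
    "dq_cast_pattern",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT16, [1, 8])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT16, [1, 8])],
    initializer=[w_q, w_scale],
    value_info=[helper.make_tensor_value_info("w_fp32", TensorProto.FLOAT, [8, 8])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

# Expected effect of fold_dq_fp32_to_fp16_casts(model): the Cast node is bypassed and
# removed, w_scale is rewritten as an FP16 initializer, and the DQ output is retyped to
# FLOAT16 in value_info, so strongly-typed TRT parsing sees a uniformly FP16 weight path.
# For NVFP4, where the scale is itself produced by an inner DequantizeLinear instead of
# being an initializer, _retype_dq_chain walks that inner DQ chain the same way.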

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 4 additions & 0 deletions
@@ -48,6 +48,7 @@
     change_casts_to_fp16,
     check_model_uses_external_data,
     fold_dq_fp32_to_fp16_casts,
+    fold_qdq_scale_fp16_to_fp32_casts,
     get_input_names,
     get_input_shapes,
     get_node_names,
@@ -652,6 +653,9 @@ def get_onnx_bytes_and_metadata(
         onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, op_list)
         # Remove Cast(FP32->FP16) nodes after DQ by setting DQ output to FP16 directly
         onnx_opt_graph = fold_dq_fp32_to_fp16_casts(onnx_opt_graph)
+        # Remove Cast(FP16->FP32) feeding Q/DQ scales so DQ stays FP16 for downstream
+        # MatMul/Add layers under strongly-typed TRT parsing.
+        onnx_opt_graph = fold_qdq_scale_fp16_to_fp32_casts(onnx_opt_graph)
     else:
         onnx_opt_graph = convert_to_f16(
             onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
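
And a matching miniature of the Cast(FP16->FP32) -> Q/DQ.scale pattern that the new pass removes. Again this is an illustrative sketch, not code from the commit: tensor names and values are invented, and the modelopt.onnx.utils import path mentioned in the trailing comment is assumed from the file locations in this diff.

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# "Before" state left by convert_float_to_float16 when Q/DQ ops are block-listed: the
# FP16 scale constant is cast back to FP32 before feeding DequantizeLinear, so the DQ
# output becomes FP32, leaving the Gemm with mixed FP32/FP16 inputs (the mismatch the
# commit message describes).
w_q = numpy_helper.from_array(np.zeros((8, 8), dtype=np.int8), name="w_q")
scale_fp16 = numpy_helper.from_array(np.array(0.05, dtype=np.float16), name="scale_fp16")
bias_fp16 = numpy_helper.from_array(np.zeros(8, dtype=np.float16), name="bias_fp16")

cast = helper.make_node("Cast", ["scale_fp16"], ["scale_fp32"], to=TensorProto.FLOAT, name="scale_cast")
dq = helper.make_node("DequantizeLinear", ["w_q", "scale_fp32"], ["w_deq"], name="w_dq")
gemm = helper.make_node("Gemm", ["x", "w_deq", "bias_fp16"], ["y"], name="gemm")

graph = helper.make_graph(
    [cast, dq, gemm],
    "qdq_scale_cast_pattern",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT16, [1, 8])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT16, [1, 8])],
    initializer=[w_q, scale_fp16, bias_fp16],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

# The deploy path applies the two folds in the order shown in the diff above, e.g.
# (assumed import path):
#   from modelopt.onnx.utils import fold_dq_fp32_to_fp16_casts, fold_qdq_scale_fp16_to_fp32_casts
#   model = fold_dq_fp32_to_fp16_casts(model)
#   model = fold_qdq_scale_fp16_to_fp32_casts(model)
# After the second fold, scale_cast is bypassed, the DQ keeps its FP16 scale, and the DQ
# output is retyped to FP16, so the Gemm sees consistent FP16 inputs under --stronglyTyped.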
