@@ -1513,12 +1513,18 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     2. Updating the DQ output type to FP16 in value_info
     3. Bypassing and removing the Cast node
 
+    NVFP4 uses a nested DQ chain (the scale is itself a DQ output). When the outer DQ's
+    scale is produced by another DQ, recursively retype the inner DQ's chain so the whole
+    chain produces FP16 tensors under strongly-typed TRT parsing.
+
     Args:
         onnx_model: The ONNX model with DQ -> Cast(FP32->FP16) patterns.
 
     Returns:
         The ONNX model with Cast nodes removed and DQ outputs set to FP16.
     """
+    import numpy as np
+
     dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
 
     # Build a map of tensor name -> producer node
@@ -1532,51 +1538,66 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         init.name: init for init in onnx_model.graph.initializer
     }
 
+    value_info_map: dict[str, onnx.ValueInfoProto] = {
+        vi.name: vi for vi in onnx_model.graph.value_info
+    }
+
+    retyped_dq_outputs: set[str] = set()
+
+    def _convert_fp32_init_to_fp16(init: onnx.TensorProto) -> None:
+        scale_data = np.frombuffer(init.raw_data, dtype=np.float32)
+        if not scale_data.size:
+            scale_data = np.array(init.float_data, dtype=np.float32)
+        init.data_type = onnx.TensorProto.FLOAT16
+        init.raw_data = scale_data.astype(np.float16).tobytes()
+        del init.float_data[:]
+
+    def _retype_dq_chain(dq_node: onnx.NodeProto, depth: int = 0) -> None:
+        """Propagate the FP16 output type down through a DQ's scale chain."""
+        if depth > 4 or len(dq_node.input) < 2:
+            return
+        scale_name = dq_node.input[1]
+        scale_init = initializer_map.get(scale_name)
+        if scale_init is not None:
+            if scale_init.data_type == onnx.TensorProto.FLOAT:
+                _convert_fp32_init_to_fp16(scale_init)
+            return
+        scale_producer = producer_map.get(scale_name)
+        if scale_producer is None or scale_producer.op_type not in dq_ops:
+            return
+        _retype_dq_chain(scale_producer, depth + 1)
+        if scale_name in value_info_map:
+            value_info_map[scale_name].type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
+        retyped_dq_outputs.add(scale_name)
+
     nodes_to_remove = []
     for node in onnx_model.graph.node:
         if node.op_type != "Cast":
             continue
 
-        # Check: Cast target is FP16
         cast_to = None
         for attr in node.attribute:
             if attr.name == "to":
                 cast_to = attr.i
         if cast_to != onnx.TensorProto.FLOAT16:
             continue
 
-        # Check: producer is a DQ node
         producer = producer_map.get(node.input[0])
         if producer is None or producer.op_type not in dq_ops:
             continue
 
-        # Convert the DQ scale initializer from FP32 to FP16
-        # DQ inputs: [input, scale, (zero_point)]
-        if len(producer.input) >= 2:
-            scale_name = producer.input[1]
-            if scale_name in initializer_map:
-                scale_init = initializer_map[scale_name]
-                if scale_init.data_type == onnx.TensorProto.FLOAT:
-                    import numpy as np
-
-                    scale_data = np.frombuffer(scale_init.raw_data, dtype=np.float32)
-                    if not scale_data.size:
-                        scale_data = np.array(scale_init.float_data, dtype=np.float32)
-                    scale_fp16 = scale_data.astype(np.float16)
-                    scale_init.data_type = onnx.TensorProto.FLOAT16
-                    scale_init.raw_data = scale_fp16.tobytes()
-                    del scale_init.float_data[:]
-
-        # Bypass the Cast node
+        _retype_dq_chain(producer)
+
         _bypass_cast_node(onnx_model, node)
         nodes_to_remove.append(node)
 
-        # Update the DQ output type in value_info
         dq_output_name = producer.output[0]
-        for vi in onnx_model.graph.value_info:
-            if vi.name == dq_output_name:
-                vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
-                break
+        retyped_dq_outputs.add(dq_output_name)
+
+    for name in retyped_dq_outputs:
+        vi = value_info_map.get(name)
+        if vi is not None:
+            vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
 
     logger.debug(f"Folded {len(nodes_to_remove)} DQ -> Cast(FP32->FP16) patterns")
     for node in nodes_to_remove:
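
A quick way to sanity-check this pass, as a hedged sketch rather than a definitive test: the import path `modelopt_onnx_utils` is a placeholder for wherever the function actually lives, and it assumes `_bypass_cast_node` rewires the Cast's consumers onto the DQ output. The asserts only cover what the diff itself guarantees: the Cast is removed and the FP32 scale initializer becomes FP16. (For NVFP4's nested chains, `_retype_dq_chain` walks the scale-producing inner DQs recursively before doing the same conversion.)

```python
import onnx
from onnx import TensorProto, helper

# Hypothetical import path; point this at the module containing the pass.
from modelopt_onnx_utils import fold_dq_fp32_to_fp16_casts

# Minimal DQ -> Cast(FP32->FP16) pattern with an FP32 scale initializer.
x = helper.make_tensor_value_info("x", TensorProto.INT8, [4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT16, [4])
scale = helper.make_tensor("scale", TensorProto.FLOAT, [], [0.1])

dq = helper.make_node("DequantizeLinear", ["x", "scale"], ["dq_out"])
cast = helper.make_node("Cast", ["dq_out"], ["y"], to=TensorProto.FLOAT16)

graph = helper.make_graph([dq, cast], "g", [x], [y], initializer=[scale])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])

model = fold_dq_fp32_to_fp16_casts(model)
# The Cast is folded away and the scale initializer is now FP16.
assert all(n.op_type != "Cast" for n in model.graph.node)
assert model.graph.initializer[0].data_type == TensorProto.FLOAT16
```
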
@@ -1585,6 +1606,92 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     return onnx_model
 
 
+def fold_qdq_scale_fp16_to_fp32_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Remove Cast(FP16->FP32) nodes feeding into Q/DQ scale inputs.
+
+    When convert_float_to_float16 blocks QuantizeLinear/DequantizeLinear, it inserts
+    Cast(FP16->FP32) nodes before every scale input. In opset >= 20, Q/DQ natively accept
+    FP16 scales, and leaving the cast in place forces DQ outputs to FP32, breaking
+    downstream FP16 matmul/add operations under strongly-typed TRT parsing.
+
+    This function bypasses each such Cast and, when the upstream Constant is FP16,
+    retypes the DQ output to FP16 in value_info so shape inference stays consistent.
+
+    Args:
+        onnx_model: The ONNX model with Cast(FP16->FP32) -> Q/DQ.scale patterns.
+
+    Returns:
+        The ONNX model with redundant scale-path casts removed.
+    """
+    qdq_ops = {
+        "QuantizeLinear",
+        "DequantizeLinear",
+        "TRT_FP8QuantizeLinear",
+        "TRT_FP8DequantizeLinear",
+    }
+
+    producer_map: dict[str, onnx.NodeProto] = {}
+    consumer_map: dict[str, list[tuple[onnx.NodeProto, int]]] = {}
+    for node in onnx_model.graph.node:
+        for out in node.output:
+            producer_map[out] = node
+        for idx, inp in enumerate(node.input):
+            if inp:
+                consumer_map.setdefault(inp, []).append((node, idx))
+
+    type_map = _build_tensor_type_map(onnx_model)
+
+    nodes_to_remove: list[onnx.NodeProto] = []
+    dq_outputs_retyped: set[str] = set()
+    visited_casts: set[int] = set()
+    for node in onnx_model.graph.node:
+        if node.op_type not in qdq_ops or len(node.input) < 2:
+            continue
+
+        scale_name = node.input[1]
+        cast_node = producer_map.get(scale_name)
+        if cast_node is None or cast_node.op_type != "Cast":
+            continue
+        if id(cast_node) in visited_casts:
+            # Already handled (e.g. a shared scale Cast across paired Q/DQ).
+            if node.op_type.endswith("DequantizeLinear"):
+                dq_outputs_retyped.add(node.output[0])
+            continue
+        if get_cast_to_type(cast_node) != onnx.TensorProto.FLOAT:
+            continue
+        if type_map.get(cast_node.input[0]) != onnx.TensorProto.FLOAT16:
+            continue
+
+        # Only bypass when every consumer of this Cast is a Q/DQ scale input; otherwise
+        # other ops would silently receive FP16 instead of the FP32 they requested.
+        cast_output = cast_node.output[0]
+        consumers = consumer_map.get(cast_output, [])
+        if not consumers or not all(
+            c.op_type in qdq_ops and i == 1 for c, i in consumers
+        ):
+            continue
+
+        # Bypass the cast so the scale stays FP16.
+        _bypass_cast_node(onnx_model, cast_node)
+        nodes_to_remove.append(cast_node)
+        visited_casts.add(id(cast_node))
+
+        # For DQ nodes, the output type follows the scale type; update value_info.
+        if node.op_type.endswith("DequantizeLinear"):
+            dq_outputs_retyped.add(node.output[0])
+
+    for vi in onnx_model.graph.value_info:
+        if vi.name in dq_outputs_retyped:
+            vi.type.tensor_type.elem_type = onnx.TensorProto.FLOAT16
+
+    logger.debug(f"Folded {len(nodes_to_remove)} Cast(FP16->FP32) -> Q/DQ.scale patterns")
+    for cast_node in nodes_to_remove:
+        if cast_node in onnx_model.graph.node:
+            onnx_model.graph.node.remove(cast_node)
+
+    return onnx_model
+
+
 def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
     """Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
 
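
A matching sketch for the new pass, under the same caveats: the import path is a placeholder, `_bypass_cast_node` is assumed to rewire the DQ's scale input to the cast's FP16 source, and `_build_tensor_type_map` is assumed to report initializer dtypes, so the FP16 scale constant is visible in `type_map`.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper

# Hypothetical import path; adjust to the module containing the pass.
from modelopt_onnx_utils import fold_qdq_scale_fp16_to_fp32_casts

# Pattern left behind by convert_float_to_float16 with Q/DQ blocked:
# FP16 scale initializer -> Cast(to=FLOAT) -> DequantizeLinear.scale
w = helper.make_tensor_value_info("w", TensorProto.INT8, [8])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [8])
scale16 = helper.make_tensor(
    "scale16", TensorProto.FLOAT16, [], np.array([0.05], np.float16).tobytes(), raw=True
)

cast = helper.make_node("Cast", ["scale16"], ["scale32"], to=TensorProto.FLOAT)
dq = helper.make_node("DequantizeLinear", ["w", "scale32"], ["out"])

graph = helper.make_graph([cast, dq], "g", [w], [out], initializer=[scale16])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

model = fold_qdq_scale_fp16_to_fp32_casts(model)
# The scale-path Cast is folded away; the DQ consumes the FP16 scale directly.
assert all(n.op_type != "Cast" for n in model.graph.node)
```

The only-consumer check in the pass matters here: if `scale32` also fed a non-Q/DQ op, the cast would be left in place, so a sketch like this one exercises the simplest eligible case.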