Skip to content

Commit c7e10f4

Browse files
committed
[6106576] Address PR review feedback on edgellm shim restoration

- fp4qdq_to_2dq: look up block_size by attribute name instead of position so the shim does not silently use the wrong attribute if TRT_FP4QDQ attribute ordering changes.
- _get_precision_dtype: use onnx.TensorProto.BFLOAT16 instead of the literal 16 for readability.
- nvfp4_exporter: note in the docstrings of _cast_fp4 and _replace_fp4qdq_with_2dq that they are reused by the deprecated qdq_utils.fp4qdq_to_2dq shim, so a future refactor does not silently drop them.
- Add direct smoke tests for quantize_weights_to_int4, quantize_weights_to_mxfp8, and fp4qdq_to_2dq that assert each shim emits a DeprecationWarning and produces the expected end-state graph (the existing tests only exercise the staged exporters).

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent c61ac69 commit c7e10f4

3 files changed

Lines changed: 105 additions & 2 deletions

File tree

modelopt/onnx/export/nvfp4_exporter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
3939
4040
Note: The first dimension of the array must be divisible by 2
4141
as two FP4 values are packed into a single byte.
42+
43+
Also reused by the deprecated ``modelopt.onnx.quantization.qdq_utils.fp4qdq_to_2dq``
44+
compatibility shim. Do not rename or change the signature without updating that
45+
shim (it is a load-bearing re-export for TensorRT-Edge-LLM 0.6.1).
4246
"""
4347
array_f32_t = torch.from_numpy(array)
4448
array_f32_t_shape = array_f32_t.shape
@@ -76,6 +80,10 @@ def _replace_fp4qdq_with_2dq(
7680
):
7781
"""Replaces the given node in the ONNX graph with a subgraph consisting of two DequantizeLinear nodes.
7882
83+
Also reused by the deprecated ``modelopt.onnx.quantization.qdq_utils.fp4qdq_to_2dq``
84+
compatibility shim. Do not rename or change the signature without updating that
85+
shim (it is a load-bearing re-export for TensorRT-Edge-LLM 0.6.1).
86+
7987
Args:
8088
graph: The ONNX graph containing the node to replace.
8189
node: The node to be replaced.

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
15551555
def _get_precision_dtype() -> str:
15561556
precision_dtype = "Half"
15571557
for initializer in graph.initializer:
1558-
if initializer.data_type == 16:
1558+
if initializer.data_type == onnx.TensorProto.BFLOAT16:
15591559
precision_dtype = "BFloat16"
15601560
break
15611561
return precision_dtype
@@ -1570,7 +1570,9 @@ def _get_precision_dtype() -> str:
15701570
for node in fp4_qdq_nodes:
15711571
idx1 = initializer_indices.get(node.input[0], None)
15721572
assert idx1 is not None, f"Initializer for weight '{node.input[0]}' not found."
1573-
block_size = node.attribute[0].i
1573+
block_size_attr = next((attr for attr in node.attribute if attr.name == "block_size"), None)
1574+
assert block_size_attr is not None, f"block_size attribute not found for {node.name}"
1575+
block_size = block_size_attr.i
15741576
initializers_to_delete.append(initializers[idx1].name)
15751577
logger.debug(
15761578
f"Processing FP4QDQ node for weight {node.input[0]} with block size {block_size}"

tests/unit/onnx/quantization/test_qdq_utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,3 +1108,96 @@ def test_constant_node_scale_path_still_patched(self):
11081108
scale_arr = numpy_helper.to_array(value_attr.t)
11091109
assert not (scale_arr == 0).any()
11101110
assert (scale_arr > 0).all()
1111+
1112+
1113+
class TestLegacyEdgeLLMShims:
1114+
"""Smoke tests for the deprecated top-level shims kept for TensorRT-Edge-LLM 0.6.1.
1115+
1116+
These are the functions edgellm 0.6.1 imports from
1117+
``modelopt.onnx.quantization.qdq_utils`` directly (not via the staged exporters).
1118+
Tests verify each shim runs end-to-end on the same fixtures used for the staged
1119+
exporters and emits a ``DeprecationWarning``.
1120+
"""
1121+
1122+
def test_quantize_weights_to_int4_shim(self):
1123+
import warnings
1124+
1125+
from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_int4
1126+
1127+
model = create_test_model_with_int4_dq_reshape_transpose_matmul()
1128+
1129+
with warnings.catch_warnings(record=True) as caught:
1130+
warnings.simplefilter("always")
1131+
quantized_model = quantize_weights_to_int4(model)
1132+
1133+
assert any(
1134+
issubclass(w.category, DeprecationWarning)
1135+
and "quantize_weights_to_int4" in str(w.message)
1136+
for w in caught
1137+
)
1138+
1139+
weight_tensor = next(
1140+
init for init in quantized_model.graph.initializer if init.name == "weight"
1141+
)
1142+
assert weight_tensor.data_type == TensorProto.INT4
1143+
1144+
node_types = [node.op_type for node in quantized_model.graph.node]
1145+
assert "Reshape" not in node_types
1146+
assert "Transpose" not in node_types
1147+
1148+
def test_quantize_weights_to_mxfp8_shim(self):
1149+
import warnings
1150+
1151+
from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_mxfp8
1152+
1153+
model = create_test_model_with_mxfp8_dq()
1154+
1155+
with warnings.catch_warnings(record=True) as caught:
1156+
warnings.simplefilter("always")
1157+
quantized_model = quantize_weights_to_mxfp8(model)
1158+
1159+
assert any(
1160+
issubclass(w.category, DeprecationWarning)
1161+
and "quantize_weights_to_mxfp8" in str(w.message)
1162+
for w in caught
1163+
)
1164+
1165+
weight_tensor = next(
1166+
init for init in quantized_model.graph.initializer if init.name == "linear.weight"
1167+
)
1168+
assert weight_tensor.data_type == TensorProto.FLOAT8E4M3FN
1169+
1170+
gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu")
1171+
approximate_attr = next(attr for attr in gelu_node.attribute if attr.name == "approximate")
1172+
assert approximate_attr.s == b"tanh"
1173+
1174+
@pytest.mark.parametrize("with_transpose", [False, True])
1175+
def test_fp4qdq_to_2dq_shim(self, with_transpose):
1176+
import warnings
1177+
1178+
from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq
1179+
1180+
model = create_test_model_with_nvfp4_qdq(with_transpose=with_transpose)
1181+
1182+
with warnings.catch_warnings(record=True) as caught:
1183+
warnings.simplefilter("always")
1184+
converted_model = fp4qdq_to_2dq(model)
1185+
1186+
assert any(
1187+
issubclass(w.category, DeprecationWarning) and "fp4qdq_to_2dq" in str(w.message)
1188+
for w in caught
1189+
)
1190+
1191+
fp4qdq_nodes = [node for node in converted_model.graph.node if node.op_type == "TRT_FP4QDQ"]
1192+
assert len(fp4qdq_nodes) == 0
1193+
1194+
dq_nodes = [
1195+
node for node in converted_model.graph.node if node.op_type == "DequantizeLinear"
1196+
]
1197+
assert len(dq_nodes) == 2
1198+
1199+
initializer_names = {init.name for init in converted_model.graph.initializer}
1200+
assert "linear.weight_f4" in initializer_names
1201+
assert "linear.weight_f8_scale" in initializer_names
1202+
assert "linear.weight_f8_scale_f32_scale" in initializer_names
1203+
assert "linear.weight" not in initializer_names

0 commit comments

Comments (0)