
Commit 76e6be1

Fix pre-commit lint issues (N806 + ruff format)
- Rename local `FP8_MAX` to `fp8_max` in `FP8QuantExporter._insert_conv_weight_dq_nodes`.
- Rename local `DQ_OPS` to `dq_ops` in `fold_dq_fp32_to_fp16_casts`.
- Apply ruff-format to collapse a broken-up chained expression in fp8_exporter.py.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 15a24d8 commit 76e6be1

2 files changed

Lines changed: 5 additions & 7 deletions
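For reference, N806 is the pep8-naming rule (enforced here through Ruff) that flags non-lowercase variable names assigned inside a function body; that spelling is conventional for module-level constants but not for function locals, which is what triggered the two renames in this commit. A minimal illustration (hypothetical function, not from this repo):

def example():  # hypothetical
    FP8_MAX = 448.0  # flagged by N806: variable in function should be lowercase
    fp8_max = 448.0  # compliant
    return fp8_max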

File tree

modelopt/onnx/export/fp8_exporter.py

Lines changed: 3 additions & 5 deletions
@@ -121,7 +121,7 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int:
     Returns:
         Number of Conv weight DQ nodes inserted.
     """
-    FP8_MAX = 448.0
+    fp8_max = 448.0
     count = 0
 
     for node in list(graph.nodes):
@@ -142,12 +142,10 @@ def _quantize_conv_weights_to_fp8(graph: gs.Graph) -> int:
         amax = torch_weights.abs().max().float()
         if amax == 0:
             continue
-        scale_val = (amax / FP8_MAX).item()
+        scale_val = (amax / fp8_max).item()
 
         # Quantize weights to FP8 (WAR: numpy doesn't support fp8)
-        fp8_data = (
-            (torch_weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
-        )
+        fp8_data = (torch_weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
         fp8_tensor = onnx.TensorProto()
         fp8_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
         fp8_tensor.dims.extend(fp8_data.shape)
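For context, the hunk above reformats the per-tensor FP8 quantization step: the scale is the weight's absolute maximum divided by 448.0 (the largest finite magnitude in float8_e4m3fn), and the fp8 payload is viewed as uint8 because numpy has no fp8 dtype. A minimal standalone sketch of that step, assuming PyTorch >= 2.1 for torch.float8_e4m3fn; the helper name quantize_weight_to_fp8_bytes is hypothetical, not the exporter's API:

import torch

def quantize_weight_to_fp8_bytes(weights: torch.Tensor) -> tuple[bytes, float]:
    # Per-tensor scale so the largest weight maps exactly to the fp8 max.
    fp8_max = 448.0  # largest finite value representable in float8_e4m3fn
    amax = weights.abs().max().float()
    if amax == 0:
        raise ValueError("all-zero weights; nothing to quantize")
    scale_val = (amax / fp8_max).item()
    # WAR: numpy doesn't support fp8, so view the fp8 bytes as uint8 first.
    fp8_data = (weights / scale_val).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
    return fp8_data.tobytes(), scale_val

A consumer then dequantizes by multiplying the decoded fp8 values back by scale_val.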

modelopt/onnx/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1519,7 +1519,7 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     Returns:
         The ONNX model with Cast nodes removed and DQ outputs set to FP16.
     """
-    DQ_OPS = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
+    dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}
 
     # Build a map of tensor name -> producer node
     producer_map: dict[str, onnx.NodeProto] = {}
@@ -1547,7 +1547,7 @@ def fold_dq_fp32_to_fp16_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
 
         # Check: producer is a DQ node
         producer = producer_map.get(node.input[0])
-        if producer is None or producer.op_type not in DQ_OPS:
+        if producer is None or producer.op_type not in dq_ops:
            continue
 
         # Convert the DQ scale initializer from FP32 to FP16
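Per the hunk context, fold_dq_fp32_to_fp16_casts walks Cast nodes and folds away FP32-to-FP16 casts whose input is produced by a DQ node. A sketch of the producer-map lookup pattern the rename touches, assuming only the standard onnx package; the wrapper name find_dq_to_fp32_casts and the Cast-filter framing are illustrative, not the utility's actual shape:

import onnx

def find_dq_to_fp32_casts(onnx_model: onnx.ModelProto) -> list[onnx.NodeProto]:
    dq_ops = {"DequantizeLinear", "TRT_FP8DequantizeLinear"}

    # Build a map of tensor name -> producer node (as in the diff context).
    producer_map: dict[str, onnx.NodeProto] = {}
    for node in onnx_model.graph.node:
        for output_name in node.output:
            producer_map[output_name] = node

    # Collect Cast nodes fed directly by a DQ node.
    candidates: list[onnx.NodeProto] = []
    for node in onnx_model.graph.node:
        if node.op_type != "Cast" or not node.input:
            continue
        producer = producer_map.get(node.input[0])
        if producer is None or producer.op_type not in dq_ops:
            continue
        candidates.append(node)
    return candidates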
