-
Notifications
You must be signed in to change notification settings - Fork 354
Add EfficientViT support for torch_onnx quantization workflow #1254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1504,6 +1504,130 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto: | |
| return onnx_model | ||
|
|
||
|
|
||
| def fix_fp16_fp32_mismatches(model: onnx.ModelProto) -> onnx.ModelProto: | ||
| """Insert Cast nodes to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion. | ||
|
|
||
| After convert_float_to_float16 with an op_block_list, FP32 data from blocked ops | ||
| (e.g., QDQ paths) can flow into nodes whose other inputs are FP16. TensorRT | ||
| --stronglyTyped rejects such mismatches. This function propagates "real" types | ||
| through the graph and inserts FP32->FP16 Cast nodes where needed. | ||
|
|
||
| Note: value_info types are unreliable after convert_float_to_float16 with blocked ops | ||
| (metadata may say FP16 even when actual data is FP32), so this function re-derives | ||
| types by following op semantics. | ||
|
|
||
| Args: | ||
| model: The ONNX model to fix. | ||
|
|
||
| Returns: | ||
| The modified ONNX model with Cast nodes inserted to resolve mismatches. | ||
| """ | ||
| FLOAT = onnx.TensorProto.FLOAT | ||
| FLOAT16 = onnx.TensorProto.FLOAT16 | ||
|
|
||
| # Ops whose data inputs must all have the same type in TRT stronglyTyped mode. | ||
| _ELEMENTWISE_OPS = { | ||
| "Add", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Equal", "Less", | ||
| "Greater", "Where", "Sum", "Mean", "Concat", | ||
| } | ||
|
|
||
| # Ops that are FP32-only (QDQ) — never cast their I/O. | ||
| _BLOCKED_OPS = {"QuantizeLinear", "DequantizeLinear"} | ||
|
|
||
| # --- Step 1: Propagate real element types through the graph. --- | ||
| real_type: dict[str, int] = {} | ||
|
|
||
| # Seed from graph inputs and initializers (these are authoritative). | ||
| for inp in model.graph.input: | ||
| real_type[inp.name] = inp.type.tensor_type.elem_type | ||
| for init in model.graph.initializer: | ||
| real_type[init.name] = init.data_type | ||
|
|
||
| # Process nodes in topological order. | ||
| for node in model.graph.node: | ||
| if node.op_type == "Constant": | ||
| for attr in node.attribute: | ||
| if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR: | ||
| for out in node.output: | ||
| real_type[out] = attr.t.data_type | ||
| continue | ||
|
|
||
| if node.op_type == "Cast": | ||
| cast_to = get_cast_to_type(node) | ||
| for out in node.output: | ||
| real_type[out] = cast_to | ||
| continue | ||
|
|
||
| if node.op_type in _BLOCKED_OPS: | ||
| for out in node.output: | ||
| real_type[out] = FLOAT | ||
| continue | ||
|
|
||
| # For other ops: output type matches the predominant data-input type. | ||
| data_types = [] | ||
| for inp_name in node.input: | ||
| if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16): | ||
| data_types.append(real_type[inp_name]) | ||
|
|
||
| if data_types: | ||
| out_type = FLOAT if FLOAT in data_types else FLOAT16 | ||
| else: | ||
| out_type = FLOAT16 | ||
|
|
||
| for out in node.output: | ||
| real_type[out] = out_type | ||
|
Comment on lines
+1528
to
+1578
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain 🏁 Script executed: #!/bin/bash
python - <<'PY'
import onnx
for op in ("Equal", "Less", "Greater"):
schema = onnx.defs.get_schema(op, 20)
print(f"\n{op}")
for out in schema.outputs:
print(" output:", out.name, out.type_str)
for constraint in schema.type_constraints:
print(" constraint:", constraint.type_param_str, list(constraint.allowed_type_strs))
PYRepository: NVIDIA/Model-Optimizer Length of output: 182 🏁 Script executed: # First, examine the actual code in the file to understand the current state
cd /tmp && find . -name "utils.py" -path "*/modelopt/onnx/*" 2>/dev/null | head -1Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Let's locate and examine the file
find . -name "utils.py" -path "*/modelopt/onnx/*" | head -1 | xargs wc -lRepository: NVIDIA/Model-Optimizer Length of output: 103 🏁 Script executed: # Examine the specific lines mentioned in the review (1528-1578)
find . -name "utils.py" -path "*/modelopt/onnx/*" | head -1 | xargs sed -n '1500,1600p'Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Try installing onnx and running the verification script with more context
pip install -q onnx 2>&1 | tail -5
python - <<'PY'
try:
import onnx
print("ONNX installed successfully")
for op in ("Equal", "Less", "Greater", "Add"):
schema = onnx.defs.get_schema(op, 20)
print(f"\n{op}:")
for out in schema.outputs:
print(f" output: {out.name}, type: {out.type_str}")
except Exception as e:
print(f"Error: {e}")
PYRepository: NVIDIA/Model-Optimizer Length of output: 464 🏁 Script executed: # Get full ONNX schema with all type constraints for comparison ops
python - <<'PY'
import onnx
for op in ("Equal", "Less", "Greater"):
schema = onnx.defs.get_schema(op, 20)
print(f"\n{op}:")
print(f" Outputs:")
for out in schema.outputs:
print(f" {out.name}: {out.type_str}")
print(f" Type Constraints:")
for constraint in schema.type_constraints:
print(f" {constraint.type_param_str}: {constraint.allowed_type_strs}")
PYRepository: NVIDIA/Model-Optimizer Length of output: 982 🏁 Script executed: # Now locate and read the actual file with the function
fd utils.py modelopt/onnx | head -1 | xargs wc -lRepository: NVIDIA/Model-Optimizer Length of output: 101 🏁 Script executed: # Read the actual function around lines 1528-1578
fd utils.py modelopt/onnx | head -1 | xargs sed -n '1520,1590p'Repository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Find the correct file - look for utils.py in modelopt/onnx (not just autocast subdir)
find modelopt/onnx -name "utils.py" -type fRepository: NVIDIA/Model-Optimizer Length of output: 165 🏁 Script executed: # Also search for the function name to locate it
rg "fix_fp16_fp32_mismatches" modelopt/Repository: NVIDIA/Model-Optimizer Length of output: 343 🏁 Script executed: # Read the actual function from the correct file
sed -n '1500,1650p' modelopt/onnx/utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 5799 🏁 Script executed: # Verify if there are other comparison operators that should also output BOOL
python - <<'PY'
import onnx
comparison_ops = ["Equal", "Less", "Greater", "LessOrEqual", "GreaterOrEqual"]
for op in comparison_ops:
try:
schema = onnx.defs.get_schema(op, 20)
print(f"{op}:")
for constraint in schema.type_constraints:
if constraint.type_param_str == "T1" or "bool" in str(constraint.allowed_type_strs):
print(f" {constraint.type_param_str}: {constraint.allowed_type_strs}")
except Exception as e:
print(f"{op}: Error - {e}")
PYRepository: NVIDIA/Model-Optimizer Length of output: 483 🏁 Script executed: # Check if LessOrEqual and GreaterOrEqual are also in the code
grep -n "LessOrEqual\|GreaterOrEqual" modelopt/onnx/utils.pyRepository: NVIDIA/Model-Optimizer Length of output: 48 🏁 Script executed: # Double-check the logic of the fallback case in the code
# to confirm comparison ops are being treated like regular ops
sed -n '1555,1575p' modelopt/onnx/utils.py

Repository: NVIDIA/Model-Optimizer Length of output: 804

Special-case comparison outputs as BOOL. The type propagation in Step 1 treats comparison ops (Equal, Less, Greater) like regular float ops, but per the ONNX schema their outputs are always BOOL. 🐛 Suggested fix for node in model.graph.node:
if node.op_type == "Constant":
@@
if node.op_type in _BLOCKED_OPS:
for out in node.output:
real_type[out] = FLOAT
continue
+
+ if node.op_type in {"Equal", "Less", "Greater"}:
+ for out in node.output:
+ real_type[out] = onnx.TensorProto.BOOL
+ continue
# For other ops: output type matches the predominant data-input type.
data_types = []🤖 Prompt for AI Agents |
||
|
|
||
| # --- Step 2: Find nodes with mixed real types and insert Casts. --- | ||
| nodes_to_insert: list[tuple[int, onnx.NodeProto]] = [] | ||
|
|
||
| for node_idx, node in enumerate(model.graph.node): | ||
| if node.op_type not in _ELEMENTWISE_OPS: | ||
| continue | ||
|
|
||
| input_real_types = [] | ||
| for inp_name in node.input: | ||
| if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16): | ||
| input_real_types.append((inp_name, real_type[inp_name])) | ||
|
|
||
| if not input_real_types: | ||
| continue | ||
|
|
||
| has_fp32 = any(t == FLOAT for _, t in input_real_types) | ||
| has_fp16 = any(t == FLOAT16 for _, t in input_real_types) | ||
| if not (has_fp32 and has_fp16): | ||
| continue | ||
|
|
||
| # Insert Cast(FP32 -> FP16) for each FP32 input. | ||
| # Reuse existing Cast if the same input was already cast (avoids duplicate names). | ||
| for inp_idx, inp_name in enumerate(node.input): | ||
| if not inp_name or inp_name not in real_type: | ||
| continue | ||
| if real_type[inp_name] != FLOAT: | ||
| continue | ||
| cast_out_name = inp_name + "_cast_to_fp16" | ||
| if cast_out_name not in real_type: | ||
| cast_node = onnx.helper.make_node( | ||
| "Cast", | ||
| inputs=[inp_name], | ||
| outputs=[cast_out_name], | ||
| to=FLOAT16, | ||
| ) | ||
| real_type[cast_out_name] = FLOAT16 | ||
| nodes_to_insert.append((node_idx, cast_node)) | ||
| node.input[inp_idx] = cast_out_name | ||
|
|
||
| # Insert cast nodes in reverse order so positions stay valid. | ||
| for pos, cast_node in sorted(nodes_to_insert, key=lambda x: x[0], reverse=True): | ||
| model.graph.node.insert(pos, cast_node) | ||
|
|
||
| if nodes_to_insert: | ||
| logger.info( | ||
| f"Inserted {len(nodes_to_insert)} Cast node(s) to fix FP32/FP16 mismatches" | ||
| ) | ||
|
|
||
| return model | ||
|
|
||
|
|
||
| def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto: | ||
| """Remove `training_mode` attribute and extra training outputs from nodes of a given op type. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`--no_pretrained` is no longer honored in the standard path. Because Line 421 now always calls
`load_calibration_data()`, the non-auto flow will still instantiate `timm.create_model(..., pretrained=True)` from Line 135 even when `--no_pretrained` is set. That turns a local/offline smoke run into a networked weights fetch. 💡 Suggested fix
data_loader = load_calibration_data( args.timm_model_name, args.calibration_data_size, args.batch_size, device, with_labels=False, + pretrained=not args.no_pretrained, )🤖 Prompt for AI Agents