Skip to content

Commit f99d83e

Browse files
gcunhaseclaude
andauthored
[Auto-23/24][ONNX][Autocast] Clear stale Cast-output type metadata before ORT InferenceSession load (#1565)
### What does this PR do? Type of change: Bug fix **Error:** ``` onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (tensor(float16)) of output arg (node_5bc985fa) of node (node_5bc985fa) does not match expected type (tensor(float)). ``` **Root cause:** - Some ONNX exporters emit `graph.output` / `value_info` entries whose dtype disagrees with the upstream `Cast` node's `to` attribute. - ORT's type checker rejects such models on session load. **Fix:** - New helper `modelopt.onnx.utils.clear_stale_value_info()` reconciles each `graph.output` elem_type to its producing Cast's `to`, then clears `value_info` so ORT recomputes intermediate types. - Called from `autocast/referencerunner.py` and `quantization/quantize.py::_preprocess_onnx`. ### Usage ```python # Internal fix; no new flag introduced. Generic CLI to exercise the affected Autocast path: $ python -m modelopt.onnx.autocast --onnx=model.onnx ``` ### Testing ``` pytest tests/unit/onnx/test_onnx_utils.py::test_clear_stale_value_info ``` ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌ <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Automatic cleaning and reconciliation of stale ONNX type metadata before runtime and quantization; reconciled model files are produced and used when inconsistencies are found. * **Tests** * New unit tests covering metadata-cleaning behavior across cast/type scenarios to ensure correctness and prevent regressions. <!-- review_stack_entry_start --> [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1565?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> Co-authored-by: modelopt-fix-agent-bot (Claude Opus 4.7) <noreply@anthropic.com>
1 parent dd8314b commit f99d83e

4 files changed

Lines changed: 79 additions & 0 deletions

File tree

modelopt/onnx/autocast/referencerunner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def run(self, inputs=None):
295295
ort.set_default_logger_severity(3)
296296

297297
model_copy = copy.deepcopy(self.model)
298+
# Clear stale type metadata to prevent type check failures in ORT
299+
onnx_utils.clear_stale_value_info(model_copy)
298300
modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL)
299301

300302
# Load the modified model and create an inference session

modelopt/onnx/quantization/quantize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from modelopt.onnx.utils import (
7474
BASE_MIN_OPSET,
7575
QDQ_PRECISION_MIN_OPSET,
76+
clear_stale_value_info,
7677
duplicate_shared_constants,
7778
get_opset_version,
7879
name_onnx_nodes,
@@ -118,6 +119,13 @@ def _preprocess_onnx(
118119
use_external_data_format,
119120
intermediate_generated_files,
120121
)
122+
123+
# Clear stale type metadata to prevent type check failures in ORT
124+
if clear_stale_value_info(onnx_model):
125+
onnx_path = os.path.join(output_dir, f"{model_name}_reconciled.onnx")
126+
save_onnx(onnx_model, onnx_path, use_external_data_format)
127+
intermediate_generated_files.append(onnx_path)
128+
121129
if has_custom_op:
122130
onnx_path = os.path.join(output_dir, f"{model_name}_ort_support.onnx")
123131
save_onnx(onnx_model, onnx_path, use_external_data_format)

modelopt/onnx/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,3 +1860,37 @@ def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) ->
18601860
break
18611861

18621862
return model
1863+
1864+
1865+
def clear_stale_value_info(model: onnx.ModelProto) -> int:
1866+
"""Clear stale type metadata that would otherwise trip ORT's type checker.
1867+
1868+
Walks every ``Cast`` node and forces the ``elem_type`` of any
1869+
``graph.output`` entry produced by that Cast to match the Cast's ``to``
1870+
attribute (the spec-defined contract for a Cast's output dtype). Then
1871+
clears ``value_info`` wholesale so ORT/shape-inference re-derives
1872+
intermediate-tensor types from the operator graph during session setup.
1873+
1874+
Args:
1875+
model: Loaded in-memory onnx ModelProto.
1876+
1877+
Returns:
1878+
Total number of entries reconciled or cleared.
1879+
"""
1880+
cast_to_by_output = {
1881+
node.output[0]: get_cast_to_type(node)
1882+
for node in model.graph.node
1883+
if node.op_type == "Cast" and node.output
1884+
}
1885+
1886+
fixed_outputs = 0
1887+
for o in model.graph.output:
1888+
to_attr = cast_to_by_output.get(o.name)
1889+
if to_attr is not None and o.type.tensor_type.elem_type != to_attr:
1890+
o.type.tensor_type.elem_type = to_attr
1891+
fixed_outputs += 1
1892+
1893+
n_vi = len(model.graph.value_info)
1894+
if n_vi:
1895+
del model.graph.value_info[:]
1896+
return fixed_outputs + n_vi

tests/unit/onnx/test_onnx_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from modelopt.onnx.trt_utils import load_onnx_model
3232
from modelopt.onnx.utils import (
33+
clear_stale_value_info,
3334
get_input_names_from_bytes,
3435
get_output_names_from_bytes,
3536
randomize_weights_onnx_bytes,
@@ -329,3 +330,37 @@ def test_ir_version_support(tmp_path):
329330
assert model_reload.ir_version == 10, (
330331
f"The maximum supported IR version is 10, but version {model_reload.ir_version} was detected."
331332
)
333+
334+
335+
def _make_cast_model(cast_to, output_elem_type, with_value_info=False):
336+
"""Build a tiny X -> Cast(to=cast_to) -> Y model."""
337+
nodes = [make_node("Cast", ["X"], ["Y"], to=cast_to, name="cast")]
338+
inputs = [make_tensor_value_info("X", onnx.TensorProto.FLOAT16, [1, 4])]
339+
outputs = [make_tensor_value_info("Y", output_elem_type, [1, 4])]
340+
value_info = (
341+
[make_tensor_value_info("Y", onnx.TensorProto.FLOAT16, [1, 4])] if with_value_info else []
342+
)
343+
graph = make_graph(nodes, "cast_graph", inputs, outputs, value_info=value_info)
344+
return make_model(graph, producer_name="modelopt test", opset_imports=[make_opsetid("", 17)])
345+
346+
347+
@pytest.mark.parametrize(
348+
("output_elem_type", "with_value_info", "expected_count"),
349+
[
350+
(onnx.TensorProto.FLOAT16, True, 2), # stale output + value_info: reconcile + clear
351+
(onnx.TensorProto.FLOAT, False, 0), # output already matches Cast.to: no-op
352+
],
353+
ids=["stale_output_and_value_info", "no_op_when_matching"],
354+
)
355+
def test_clear_stale_value_info(output_elem_type, with_value_info, expected_count):
356+
model = _make_cast_model(
357+
cast_to=onnx.TensorProto.FLOAT,
358+
output_elem_type=output_elem_type,
359+
with_value_info=with_value_info,
360+
)
361+
362+
count = clear_stale_value_info(model)
363+
364+
assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.FLOAT
365+
assert len(model.graph.value_info) == 0
366+
assert count == expected_count

0 commit comments

Comments
 (0)