Skip to content

Commit e2abd9d

Browse files
committed
Fix test failures
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 40876b6 commit e2abd9d

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
import os
2222
import shutil
2323
import tempfile
24-
from contextlib import nullcontext
24+
from contextlib import nullcontext, suppress
2525
from typing import Any
2626

2727
import onnx
28+
import onnxconverter_common.float16 as _f16_module
2829
import torch
2930
import torch.nn as nn
3031
from onnx import ModelProto
@@ -58,6 +59,17 @@
5859

5960
from ..utils.onnx_optimizer import Optimizer
6061

62+
# Monkey-patch to fix onnxconverter_common bug where downstream_node is a list
63+
_original_remove_unnecessary_cast_node = _f16_module.remove_unnecessary_cast_node
64+
65+
66+
def _patched_remove_unnecessary_cast_node(graph):
67+
with suppress(AttributeError):
68+
_original_remove_unnecessary_cast_node(graph)
69+
70+
71+
_f16_module.remove_unnecessary_cast_node = _patched_remove_unnecessary_cast_node
72+
6173
ModelMetadata = dict[str, Any]
6274
ModelType = Any
6375
ValueInfoType = Any

0 commit comments

Comments
 (0)