Skip to content

Commit a6532fd

Browse files
committed
fix exceptions
1 parent 821fa76 commit a6532fd

2 files changed

Lines changed: 6 additions & 3 deletions

File tree

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,10 @@ def test_make_model_with_local_functions_bug(self):
793793
meta = node.metadata_props.add()
794794
meta.key = "namespace"
795795
meta.value = "LLL"
796-
self.assertRaise(lambda: make_model_with_local_functions(model, "^LLL$"), ValueError)
796+
self.assertRaise(
797+
lambda: make_model_with_local_functions(model, "^LLL$", allow_extensions=False),
798+
ValueError,
799+
)
797800
check_model(model)
798801

799802
@hide_stdout()

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,7 @@ def make_model_with_local_functions(
19111911
print(" ...")
19121912
functions = []
19131913
new_nodes: List[Optional[NodeProto]] = list(model.graph.node)
1914-
processed = {}
1914+
processed: Dict[str, FunctionProto] = {}
19151915
unique_as_set = {k: set(v) for k, v in unique.items()}
19161916
while len(processed) < len(unique):
19171917
for key, node_indices in unique.items():
@@ -1932,7 +1932,7 @@ def make_model_with_local_functions(
19321932

19331933
function_inputs = unknown_names_within_nodes(function_nodes)
19341934
additional_nodes = check_for_non_recursivity(
1935-
node_indices, new_nodes, function_inputs, outputs, exc=False
1935+
node_indices, new_nodes, function_inputs, outputs, exc=not allow_extensions
19361936
)
19371937
if additional_nodes:
19381938
if not allow_extensions:

0 commit comments

Comments
 (0)