Skip to content

Commit 7a0f9b6

Browse files
lijiaqi5OYCN
authored andcommitted
Support provide ORT session options
Signed-off-by: opluss <opluss@qq.com>
1 parent 94e2b9e commit 7a0f9b6

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

  • tools/onnx-graphsurgeon/onnx_graphsurgeon/ir

tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ def fold_constants(
699699
size_threshold=None,
700700
should_exclude_node=None,
701701
recurse_functions=True,
702+
ort_session_options=None,
702703
):
703704
"""
704705
Folds constants in-place in the graph. The graph's nodes and functions must be topologically
@@ -754,7 +755,8 @@ def fold_constants(
754755
recurse_functions (bool):
755756
Whether to fold constants in this graph's Functions.
756757
Defaults to True.
757-
758+
ort_session_options (Optional[onnxruntime.SessionOptions]):
759+
SessionOptions object to be used for ONNX Runtime sessions.
758760
Returns:
759761
self
760762
"""
@@ -1172,6 +1174,7 @@ def get_out_node_ids():
11721174

11731175
sess = onnxrt.InferenceSession(
11741176
export_onnx(part, do_type_check=False).SerializeToString(),
1177+
sess_options = ort_session_options,
11751178
providers=ORT_PROVIDERS,
11761179
)
11771180
values = sess.run(names, {})
@@ -1254,6 +1257,7 @@ def should_eval_foldable(tensor):
12541257
export_onnx(
12551258
graph_clone, do_type_check=False
12561259
).SerializeToString(),
1260+
sess_options = ort_session_options,
12571261
providers=ORT_PROVIDERS,
12581262
)
12591263
values = sess.run(names, {})

0 commit comments

Comments
 (0)