diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py index c7ce3b6ff..4a29759ff 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py @@ -703,6 +703,7 @@ def fold_constants( size_threshold=None, should_exclude_node=None, recurse_functions=True, + ort_session_options=None, ): """ Folds constants in-place in the graph. The graph's nodes and functions must be topologically @@ -758,7 +759,8 @@ def fold_constants( recurse_functions (bool): Whether to fold constants in this graph's Functions. Defaults to True. - + ort_session_options (Optional[onnxruntime.SessionOptions]): + SessionOptions object to be used for ONNX Runtime sessions. Returns: self """ @@ -1176,6 +1178,7 @@ def get_out_node_ids(): sess = onnxrt.InferenceSession( export_onnx(part, do_type_check=False).SerializeToString(), + sess_options = ort_session_options, providers=ORT_PROVIDERS, ) values = sess.run(names, {}) @@ -1258,6 +1261,7 @@ def should_eval_foldable(tensor): export_onnx( graph_clone, do_type_check=False ).SerializeToString(), + sess_options = ort_session_options, providers=ORT_PROVIDERS, ) values = sess.run(names, {})