1- from typing import NamedTuple , Callable , Optional , List , Dict , Union
1+ import logging
2+ from typing import NamedTuple , Optional , List , Dict , Union
23
34import onnx
4- from ebm2onnx import __version__
5- from .utils import get_latest_opset_version
5+ from onnx .helper import make_opsetid
66
77from . import context as _context
88
@@ -14,6 +14,7 @@ class Graph(NamedTuple):
1414 transients : List [onnx .ValueInfoProto ] = []
1515 nodes : List [onnx .NodeProto ] = []
1616 initializers : List [onnx .TensorProto ] = []
17+ opsets : Dict [str , int ] = {}
1718
1819
1920def extend (i , val ):
@@ -52,12 +53,18 @@ def from_onnx(model) -> Graph:
5253 Returns:
5354 A Graph object.
5455 """
56+ opsets = {
57+ op .domain : op .version
58+ for op in model .opset_import
59+ }
60+
5561 return Graph (
5662 context = _context .create (),
5763 inputs = [n for n in model .graph .input ],
5864 outputs = [n for n in model .graph .output ],
5965 nodes = [n for n in model .graph .node ],
6066 initializers = [n for n in model .graph .initializer ],
67+ opsets = opsets ,
6168 )
6269
6370
@@ -74,13 +81,19 @@ def to_onnx(
7481
7582 Args:
7683 graph: The graph object
77- target_opset: the target opset to use when converting ot onnx, can be an int or a dict
84+ target_opset: [Optional][Deprecated] the target opset to use when converting ot onnx, can be an int or a dict
7885 name: [Optional] An existing ONNX model
7986
8087 Returns:
8188 A Graph object.
8289 """
83- #outputs = graph.transients
90+ if target_opset :
91+ logging .warning ("to_onnx: target_opset argument is deprecated" )
92+
93+ opset_imports = [
94+ make_opsetid (domain = domain , version = version )
95+ for domain ,version in graph .opsets .items ()
96+ ]
8497
8598 graph = onnx .helper .make_graph (
8699 nodes = graph .nodes ,
@@ -89,32 +102,15 @@ def to_onnx(
89102 outputs = graph .outputs ,
90103 initializer = graph .initializers ,
91104 )
92- model = onnx .helper .make_model (graph , producer_name = 'ebm2onnx' )
93-
94- #producer_name = "interpretml/ebm2onnx"
95- #producer_version = __version__
96-
97- #domain
98- #model_version
99- #doc_string
100105
101- #metadata_props
102-
103- # set opset versions
104- if target_opset is not None :
105- if type (target_opset ) is int :
106- model .opset_import [0 ].version = target_opset
107- elif type (target_opset ) is dict :
108- del model .opset_import [:]
109-
110- for k , v in target_opset .items ():
111- opset = model .opset_import .add ()
112- opset .domain = k
113- opset .version = v
114- else :
115- raise ValueError (f"ebm2onnx.graph.to_onnx: invalid type for target_opset: { type (target_opset )} ." )
116- else :
117- model .opset_import [0 ].version = get_latest_opset_version ()
106+ # create the onnx model from the graph.
107+ # The onnx library will set the ir version to the minimal required ir that
108+ # is compatible with the opset_imports provided.
109+ model = onnx .helper .make_model_gen_version (
110+ graph ,
111+ producer_name = 'ebm2onnx' ,
112+ opset_imports = opset_imports ,
113+ )
118114
119115 return model
120116
@@ -210,4 +206,10 @@ def merge(*args):
210206 nodes = extend (g .nodes , graph .nodes ),
211207 )
212208
209+ # merge opsets, keep higher version for each domain
210+ for domain ,version in graph .opsets .items ():
211+ cur_version = g .opsets .get (domain , - 1 )
212+ if version > cur_version :
213+ g .opsets [domain ] = version
214+
213215 return g
0 commit comments