3535 _core .RefAttr ,
3636 _protocols .GraphProtocol ,
3737 Sequence [_protocols .GraphProtocol ],
38+ onnx .GraphProto ,
3839 _protocols .TypeProtocol ,
3940 Sequence [_protocols .TypeProtocol ],
4041 None ,
@@ -60,10 +61,15 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
6061 if isinstance (attr , (_core .TensorBase , onnx .TensorProto , _protocols .TensorProtocol )):
6162 # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
6263 return _enums .AttributeType .TENSOR
63- if isinstance (attr , (_core .Graph , _protocols .GraphProtocol )):
64+ if isinstance (attr , Sequence ) and all (
65+ isinstance (x , (_core .TensorBase , onnx .TensorProto , _protocols .TensorProtocol ))
66+ for x in attr
67+ ):
68+ return _enums .AttributeType .TENSORS
69+ if isinstance (attr , (_core .Graph , onnx .GraphProto , _protocols .GraphProtocol )):
6470 return _enums .AttributeType .GRAPH
6571 if isinstance (attr , Sequence ) and all (
66- isinstance (x , (_core .Graph , _protocols .GraphProtocol )) for x in attr
72+ isinstance (x , (_core .Graph , onnx . GraphProto , _protocols .GraphProtocol )) for x in attr
6773 ):
6874 return _enums .AttributeType .GRAPHS
6975 if isinstance (
@@ -145,11 +151,27 @@ def convert_attribute(
145151 if isinstance (attr , (_core .TensorBase , _protocols .TensorProtocol )):
146152 return _core .AttrTensor (name , attr )
147153 if isinstance (attr , onnx .TensorProto ):
148- return _core .AttrTensor (name , serde .TensorProtoTensor (attr ))
154+ return _core .AttrTensor (name , serde .deserialize_tensor (attr ))
155+ if attr_type == _enums .AttributeType .TENSORS :
156+ tensors = []
157+ for t in attr : # type: ignore[union-attr]
158+ if isinstance (t , onnx .TensorProto ):
159+ tensors .append (_core .AttrTensor (name , serde .deserialize_tensor (t )))
160+ else :
161+ tensors .append (t ) # type: ignore[arg-type]
162+ return _core .AttrTensors (name , tensors ) # type: ignore[arg-type]
149163 if attr_type == _enums .AttributeType .GRAPH :
164+ if isinstance (attr , onnx .GraphProto ):
165+ attr = serde .deserialize_graph (attr )
150166 return _core .AttrGraph (name , attr ) # type: ignore[arg-type]
151167 if attr_type == _enums .AttributeType .GRAPHS :
152- return _core .AttrGraphs (name , attr ) # type: ignore[arg-type]
168+ graphs = []
169+ for graph in attr : # type: ignore[union-attr]
170+ if isinstance (graph , onnx .GraphProto ):
171+ graphs .append (serde .deserialize_graph (graph ))
172+ else :
173+ graphs .append (graph ) # type: ignore[arg-type]
174+ return _core .AttrGraphs (name , graphs ) # type: ignore[arg-type]
153175 if attr_type == _enums .AttributeType .TYPE_PROTO :
154176 return _core .AttrTypeProto (name , attr ) # type: ignore[arg-type]
155177 if attr_type == _enums .AttributeType .TYPE_PROTOS :
0 commit comments