Skip to content

Commit 9910215

Browse files
authored
Use ir methods to replace onnx helper (#2091)
Ban onnx.helper and onnx.numpy_helper because they can be slow. Selectively enable usages of some with `noqa: TID251` and update usages of the rest. Fix `ir.tensor` to generate float32 tensors when a plain Python float is provided.
1 parent c60c090 commit 9910215

21 files changed

Lines changed: 193 additions & 188 deletions

File tree

onnxscript/_internal/autocast.py

Lines changed: 9 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77

88
import numpy as np
99
import onnx
10-
from onnx import helper, numpy_helper
1110
from onnx.defs import OpSchema
1211

13-
from onnxscript import tensor
12+
from onnxscript import ir, tensor
1413

1514
if TYPE_CHECKING:
1615
from onnxscript import converter
@@ -24,42 +23,8 @@
2423
# Utilities to convert a python value to TensorProto (for use by the script converter)
2524

2625

27-
def _py_type_to_onnx_type(pytype: type):
28-
if pytype is bool:
29-
return onnx.TensorProto.BOOL
30-
if pytype is int:
31-
return onnx.TensorProto.INT64
32-
if pytype is float:
33-
return onnx.TensorProto.FLOAT
34-
if pytype is str:
35-
return onnx.TensorProto.STRING
36-
raise ValueError(f"Tensor element of type {pytype} not supported")
37-
38-
3926
def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
40-
if isinstance(pyvalue, np.ndarray):
41-
return numpy_helper.from_array(pyvalue, tensor_name)
42-
if isinstance(pyvalue, list):
43-
if len(pyvalue) == 0:
44-
raise ValueError("Cannot convert an empty list to tensor")
45-
pytype = type(pyvalue[0])
46-
if not all(isinstance(e, pytype) for e in pyvalue):
47-
raise ValueError(
48-
"Cannot convert an list with elements of different types to tensor"
49-
)
50-
return helper.make_tensor(
51-
tensor_name,
52-
_py_type_to_onnx_type(pytype),
53-
[len(pyvalue)],
54-
pyvalue,
55-
)
56-
onnx_type = _py_type_to_onnx_type(type(pyvalue))
57-
if onnx_type is onnx.TensorProto.BOOL:
58-
return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)])
59-
if onnx_type is onnx.TensorProto.STRING:
60-
return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")])
61-
62-
return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue])
27+
return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))
6328

6429

6530
_REPEATED_ATTRIBUTE_TYPES = frozenset(
@@ -103,7 +68,13 @@ def pyvalue_to_onnx_attribute(
10368
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
10469
)
10570
else:
106-
return onnx.helper.make_attribute(key, value)
71+
attr = ir.convenience.convert_attribute(
72+
key,
73+
value,
74+
attr_type=ir.AttributeType(attr_type) if attr_type is not None else None,
75+
)
76+
assert isinstance(attr, ir.Attr)
77+
return ir.serde.serialize_attribute(attr)
10778

10879

10980
# Utilities to convert python values into onnxscript tensors.

onnxscript/_internal/utils.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import numpy as np
99
import onnx
10-
import onnx.helper
1110

1211
from onnxscript import tensor
1312

@@ -65,26 +64,26 @@ def add(k, v):
6564
def value_to_type_proto(val):
6665
"""Return the ONNX type of a python-value."""
6766
if isinstance(val, (np.ndarray, tensor.Tensor)):
68-
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)
67+
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251
6968
shape = val.shape
70-
return onnx.helper.make_tensor_type_proto(elem_type, shape)
69+
return onnx.helper.make_tensor_type_proto(elem_type, shape) # noqa: TID251
7170
if isinstance(val, int):
72-
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])
71+
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251
7372
if isinstance(val, (float, np.float32)):
74-
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])
73+
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251
7574
if isinstance(val, list):
7675
if len(val) > 0:
77-
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
76+
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251
7877
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
7978
# Should be using a typed-value instead.
8079
# Treated as a sequence of tensors of float-type.
81-
return onnx.helper.make_sequence_type_proto(
82-
onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
80+
return onnx.helper.make_sequence_type_proto( # noqa: TID251
81+
onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None) # noqa: TID251
8382
)
8483
if isinstance(val, numbers.Number):
8584
nparray = np.array(val)
86-
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)
87-
return onnx.helper.make_tensor_type_proto(elem_type, [])
85+
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251
86+
return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251
8887
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
8988

9089

@@ -93,7 +92,7 @@ def values_to_value_infos(name_values):
9392
skipping any None values.
9493
"""
9594
return [
96-
onnx.helper.make_value_info(name, value_to_type_proto(val))
95+
onnx.helper.make_value_info(name, value_to_type_proto(val)) # noqa: TID251
9796
for (name, val) in name_values
9897
if val is not None
9998
]

onnxscript/_legacy_ir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def value_as_np_array(self) -> np.ndarray | None:
142142
if isinstance(self.value, np.ndarray):
143143
return self.value
144144
if isinstance(self.value, onnx.TensorProto):
145-
return onnx.numpy_helper.to_array(self.value)
145+
return onnx.numpy_helper.to_array(self.value) # noqa: TID251
146146
return None
147147

148148
def def_node(self) -> Node | None:

onnxscript/_legacy_ir/visitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3+
# ruff: noqa: TID251
34
from __future__ import annotations
45

56
import dataclasses

onnxscript/backend/onnx_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
3+
# ruff: noqa: TID251
44

55
import os
66
import textwrap

onnxscript/backend/onnx_export.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy
88
import onnx
99
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
10-
from onnx.helper import make_node
1110

1211
import onnxscript.onnx_types
1312
import onnxscript.type_annotation
@@ -68,10 +67,10 @@ def _get_const_repr(const_node):
6867
if tensor_proto.data_type in {TensorProto.FLOAT, TensorProto.INT64}:
6968
rank = len(tensor_proto.dims)
7069
if rank == 0:
71-
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1)
70+
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251
7271
return repr(array[0])
7372
if rank == 1 and tensor_proto.dims[0] < 5:
74-
return repr(list(onnx.numpy_helper.to_array(tensor_proto)))
73+
return repr(list(onnx.numpy_helper.to_array(tensor_proto))) # noqa: TID251
7574
return None
7675

7776

@@ -161,7 +160,7 @@ def _attribute_value(attr: onnx.AttributeProto):
161160
if onnx.external_data_helper.uses_external_data(tensor_proto):
162161
return tensor_proto
163162
else:
164-
return onnx.numpy_helper.to_array(tensor_proto)
163+
return onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251
165164
# TODO:
166165
# - onnx.AttributeProto.GRAPH
167166
# - onnx.AttributeProto.SPARSE_TENSOR
@@ -348,7 +347,7 @@ def _translate_graph_body(self, graph, opsets, indent=0):
348347
)
349348
self.skipped_initializers[init_py_name] = init
350349
continue
351-
node = make_node(
350+
node = onnx.helper.make_node( # noqa: TID251
352351
"Constant",
353352
[],
354353
[self._translate_onnx_var(init.name)], # type: ignore[list-item]

onnxscript/evaluator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import numpy as np
2121
import onnx
2222
import onnx.defs
23-
import onnx.helper
2423
import onnx.reference
2524
from typing_extensions import TypeAlias
2625

@@ -430,21 +429,22 @@ def make_tensor_name() -> str:
430429
num_outputs = compute_num_outputs(schema, args, kwargs)
431430
outputs = [f"output{i}" for i in range(num_outputs)]
432431

433-
node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)
432+
node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251
434433
node.attribute.extend(
435434
make_attr(key, value) for key, value in kwargs.items() if value is not None
436435
)
437436
input_value_infos = utils.values_to_value_infos(zip(inputs, args))
438437
implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
439438
output_value_infos = [
440-
onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs
439+
onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251
440+
for name in outputs
441441
]
442442

443-
graph = onnx.helper.make_graph(
443+
graph = onnx.helper.make_graph( # noqa: TID251
444444
[node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
445445
)
446-
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)
447-
model = onnx.helper.make_model(
446+
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251
447+
model = onnx.helper.make_model( # noqa: TID251
448448
graph,
449449
opset_imports=[opset_id],
450450
ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),

onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3+
# ruff: noqa: TID251
34
"""Graph building functions for torchscript graph backend."""
45

56
from __future__ import annotations

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import math
1818
from typing import Optional, Sequence, Tuple, TypeVar, Union
1919

20-
import onnx
21-
2220
from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
2321
from onnxscript.function_libs.torch_lib.ops import common as common_ops
2422
from onnxscript.function_libs.torch_lib.registration import torch_op
@@ -1798,15 +1796,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
17981796
op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
17991797
)
18001798
logsumexp = op.Expand(0.0, query_first_three_dims)
1801-
# TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
1802-
empty_tensor_int = op.Cast(
1803-
op.ConstantOfShape(
1804-
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
1805-
),
1806-
to=INT64.dtype,
1799+
empty_tensor_int = op.ConstantOfShape(
1800+
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
18071801
)
18081802
empty_tensor_float = op.ConstantOfShape(
1809-
op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
1803+
op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT))
18101804
)
18111805
empty_int = op.Constant(value_int=0)
18121806

@@ -1881,11 +1875,8 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
18811875
logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))
18821876

18831877
# See Note [Seed and Offset]:
1884-
empty_tensor_int = op.Cast(
1885-
op.ConstantOfShape(
1886-
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
1887-
),
1888-
to=INT64.dtype,
1878+
empty_tensor_int = op.ConstantOfShape(
1879+
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
18891880
)
18901881

18911882
return logsum_exp, empty_tensor_int

onnxscript/ir/_convenience/__init__.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_core.RefAttr,
3636
_protocols.GraphProtocol,
3737
Sequence[_protocols.GraphProtocol],
38+
onnx.GraphProto,
3839
_protocols.TypeProtocol,
3940
Sequence[_protocols.TypeProtocol],
4041
None,
@@ -60,10 +61,15 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
6061
if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
6162
# Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
6263
return _enums.AttributeType.TENSOR
63-
if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
64+
if isinstance(attr, Sequence) and all(
65+
isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
66+
for x in attr
67+
):
68+
return _enums.AttributeType.TENSORS
69+
if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
6470
return _enums.AttributeType.GRAPH
6571
if isinstance(attr, Sequence) and all(
66-
isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
72+
isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
6773
):
6874
return _enums.AttributeType.GRAPHS
6975
if isinstance(
@@ -145,11 +151,27 @@ def convert_attribute(
145151
if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
146152
return _core.AttrTensor(name, attr)
147153
if isinstance(attr, onnx.TensorProto):
148-
return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
154+
return _core.AttrTensor(name, serde.deserialize_tensor(attr))
155+
if attr_type == _enums.AttributeType.TENSORS:
156+
tensors = []
157+
for t in attr: # type: ignore[union-attr]
158+
if isinstance(t, onnx.TensorProto):
159+
tensors.append(_core.AttrTensor(name, serde.deserialize_tensor(t)))
160+
else:
161+
tensors.append(t) # type: ignore[arg-type]
162+
return _core.AttrTensors(name, tensors) # type: ignore[arg-type]
149163
if attr_type == _enums.AttributeType.GRAPH:
164+
if isinstance(attr, onnx.GraphProto):
165+
attr = serde.deserialize_graph(attr)
150166
return _core.AttrGraph(name, attr) # type: ignore[arg-type]
151167
if attr_type == _enums.AttributeType.GRAPHS:
152-
return _core.AttrGraphs(name, attr) # type: ignore[arg-type]
168+
graphs = []
169+
for graph in attr: # type: ignore[union-attr]
170+
if isinstance(graph, onnx.GraphProto):
171+
graphs.append(serde.deserialize_graph(graph))
172+
else:
173+
graphs.append(graph) # type: ignore[arg-type]
174+
return _core.AttrGraphs(name, graphs) # type: ignore[arg-type]
153175
if attr_type == _enums.AttributeType.TYPE_PROTO:
154176
return _core.AttrTypeProto(name, attr) # type: ignore[arg-type]
155177
if attr_type == _enums.AttributeType.TYPE_PROTOS:

0 commit comments

Comments
 (0)