Skip to content

Commit 83d3f41

Browse files
authored
rename api.py into typing.py (#393)
* rename api.py into typing.py * fix * fix * fix * fix
1 parent 158f774 commit 83d3f41

13 files changed

Lines changed: 124 additions & 162 deletions

File tree

.github/workflows/pyrefly.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
run: |
2222
pip install pyrefly
2323
pip install -r requirements.txt
24-
pip install transformers pandas matplotlib openpyxl
24+
pip install transformers pandas matplotlib openpyxl onnx-array-api
2525
2626
- name: Run pyrefly
2727
run: pyrefly check

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,7 @@ def _get_cdist_implementation(
180180
opsets: Dict[str, int],
181181
**kwargs: Any,
182182
) -> FunctionProto:
183-
"""
184-
Returns the CDist implementation as a function.
185-
"""
183+
"""Returns the CDist implementation as a function."""
186184
assert len(node_inputs) == 2
187185
assert len(node_outputs) == 1
188186
assert opsets
@@ -191,39 +189,39 @@ def _get_cdist_implementation(
191189
metric = kwargs["metric"]
192190
assert metric in ("euclidean", "sqeuclidean")
193191
# subgraph
194-
nodes = [
195-
oh.make_node("Sub", ["next", "next_in"], ["diff"]),
196-
oh.make_node("Constant", [], ["axis"], value_ints=[1]),
197-
oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0),
198-
oh.make_node("Identity", ["next_in"], ["next_out"]),
199-
]
200192

201193
def make_value(name):
202194
value = ValueInfoProto()
203195
value.name = name
204196
return value
205197

206198
graph = oh.make_graph(
207-
nodes,
199+
[
200+
oh.make_node("Sub", ["next", "next_in"], ["diff"]),
201+
oh.make_node("Constant", [], ["axis"], value_ints=[1]),
202+
oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0),
203+
oh.make_node("Identity", ["next_in"], ["next_out"]),
204+
],
208205
"loop",
209206
[make_value("next_in"), make_value("next")],
210207
[make_value("next_out"), make_value("scan_out")],
211208
)
212209

213-
scan = oh.make_node(
214-
"Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph
215-
)
216-
final = (
217-
oh.make_node("Sqrt", ["zout"], ["z"])
218-
if metric == "euclidean"
219-
else oh.make_node("Identity", ["zout"], ["z"])
220-
)
221210
return oh.make_function(
222211
"npx",
223212
f"CDist_{metric}",
224213
["xa", "xb"],
225214
["z"],
226-
[scan, final],
215+
[
216+
oh.make_node(
217+
"Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph
218+
),
219+
(
220+
oh.make_node("Sqrt", ["zout"], ["z"])
221+
if metric == "euclidean"
222+
else oh.make_node("Identity", ["zout"], ["z"])
223+
),
224+
],
227225
[oh.make_opsetid("", opsets[""])],
228226
)
229227

@@ -234,9 +232,7 @@ def test_iterate_function(self):
234232
)
235233
model = oh.make_model(
236234
oh.make_graph(
237-
[
238-
oh.make_node(proto.name, ["X", "Y"], ["Z"]),
239-
],
235+
[oh.make_node(proto.name, ["X", "Y"], ["Z"])],
240236
"dummy",
241237
[
242238
oh.make_tensor_value_info("X", itype, [None, None]),

_unittests/ut_xrun_doc/test_unit_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
has_cuda,
1717
has_onnxscript,
1818
)
19-
from onnx_diagnostic.api import TensorLike
2019

2120

2221
class TestUnitTest(ExtTestCase):
@@ -111,10 +110,6 @@ def test_measure_time_max(self):
111110
},
112111
)
113112

114-
def test_exc(self):
115-
self.assertRaise(lambda: TensorLike().dtype, NotImplementedError)
116-
self.assertRaise(lambda: TensorLike().shape, NotImplementedError)
117-
118113

119114
if __name__ == "__main__":
120115
unittest.main(verbosity=2)

onnx_diagnostic/api.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)