Skip to content

Commit 158f774

Browse files
authored
minor modifications (#392)
* minor modifications * fix * fix * types * fix * miss * fix * fix * fix doc
1 parent 1032e24 commit 158f774

24 files changed

Lines changed: 118 additions & 67 deletions

.github/workflows/pyrefly.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ jobs:
2121
run: |
2222
pip install pyrefly
2323
pip install -r requirements.txt
24+
pip install transformers pandas matplotlib openpyxl
2425
2526
- name: Run pyrefly
2627
run: pyrefly check

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def linkcode_resolve(domain, info):
132132
("py:class", "Module"),
133133
("py:class", "np.ndarray"),
134134
("py:class", "onnx_ir.Tuple"),
135+
("py:class", "pandas.api.typing.DataFrameGroupBy"),
135136
("py:class", "pandas.core.groupby.generic.DataFrameGroupBy"),
136137
("py:class", "pipeline.Pipeline"),
137138
("py:class", "torch._guards.Source"),

_doc/recipes/plot_dynamic_shapes_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def flatten_unflatten_like_dynamic_shapes(obj):
8181
subtrees.append(value)
8282
start = end
8383
if spec.type is dict:
84-
# This a dictionary.
84+
# This is a dictionary.
8585
return dict(zip(spec.context, subtrees))
8686
if spec.type is tuple:
8787
return tuple(subtrees)

onnx_diagnostic/export/validate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def _get(a):
8080
)
8181
got = modep(*_get(args), **_get(kwargs))
8282
if verbose:
83+
# pyrefly: ignore[unbound-name]
8384
d = time.perf_counter() - begin
8485
print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
8586
if mod:
@@ -89,6 +90,7 @@ def _get(a):
8990
expected = mod(*_get(args), **_get(kwargs))
9091
diff = max_diff(expected, got)
9192
if verbose:
93+
# pyrefly: ignore[unbound-name]
9294
d = time.perf_counter() - begin
9395
print(
9496
f"[compare_modules] done in {d} with "

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def unit_test_going(self) -> bool:
780780

781781
@property
782782
def verbose(self) -> int:
783-
"Returns the the value of environment variable ``VERBOSE``."
783+
"Returns the value of environment variable ``VERBOSE``."
784784
return int(os.environ.get("VERBOSE", "0"))
785785

786786
@classmethod

onnx_diagnostic/helpers/args_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def get_parsed_args(
105105
default=tries,
106106
)
107107
for k, v in kwargs.items():
108+
assert isinstance(v, tuple) # type
108109
parser.add_argument(
109110
f"--{k}",
110111
help=f"{v[1]}, default is {v[0]}",

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def flatten_unflatten_for_dynamic_shapes(
9090
the context gives the dictionary keys but it is not expressed
9191
in the dynamic shapes, these specifications seems to be different
9292
for the strict and non strict mode. It also preserves tuple.
93-
:param change_function: to modifies the tensor in the structure itself,
93+
:param change_function: to modify the tensor in the structure itself,
9494
like replace them by a shape
9595
:return: the serialized object
9696
"""
@@ -110,7 +110,7 @@ def flatten_unflatten_for_dynamic_shapes(
110110
start = end
111111
if use_dict:
112112
if spec.type is dict:
113-
# This a dictionary.
113+
# This is a dictionary.
114114
return dict(zip(spec.context, subtrees))
115115
if spec.type is tuple:
116116
return tuple(subtrees)

onnx_diagnostic/helpers/doc_helper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Dict, List, Optional, Tuple
2+
from typing import Any, Dict, List, Optional, Tuple
33
import onnx
44
import onnx.helper as oh
55
import torch
@@ -46,10 +46,10 @@ def __init__(
4646
f"This kernel implementation only work when only one output "
4747
f"is required but {node.output} were."
4848
)
49-
self._cache: Dict[Tuple[int, int], onnx.ModelProto] = {}
49+
self._cache: Dict[Tuple[int, int], Any] = {}
5050
self.is_cpu = torch.device("cpu") == self.device
5151

52-
def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
52+
def _make_model(self, itype: int, rank: int, has_bias: bool) -> Any:
5353
shape = [*["d{i}" for i in range(rank - 1)], "last"]
5454
layer_model = oh.make_model(
5555
oh.make_graph(
@@ -88,6 +88,7 @@ def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
8888
providers=[provider],
8989
)
9090

91+
# pyrefly: ignore[bad-override]
9192
def run(self, x, scale, bias=None):
9293
itype = torch_dtype_to_onnx_dtype(x.dtype)
9394
rank = len(x.shape)
@@ -124,7 +125,7 @@ def __init__(
124125
self._cache: Dict[Tuple[int, int, int], onnx.ModelProto] = {}
125126
self.is_cpu = torch.device("cpu") == self.device
126127

127-
def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
128+
def _make_model(self, itype: int, ranka: int, rankb: int) -> Any:
128129
shapea = ["a{i}" for i in range(ranka)]
129130
shapeb = ["b{i}" for i in range(rankb)]
130131
shapec = ["c{i}" for i in range(max(ranka, rankb))]
@@ -149,6 +150,7 @@ def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
149150
providers=[provider],
150151
)
151152

153+
# pyrefly: ignore[bad-override]
152154
def run(self, a, b):
153155
itype = torch_dtype_to_onnx_dtype(a.dtype)
154156
ranka, rankb = len(a.shape), len(b.shape)
@@ -159,5 +161,6 @@ def run(self, a, b):
159161
if self.verbose:
160162
print(f"[MatMulOrt] running on {self._provider!r}")
161163
feeds = dict(A=a.tensor, B=b.tensor)
164+
# pyrefly: ignore[missing-attribute]
162165
got = sess.run(None, feeds)[0]
163166
return OpRunTensor(got)

onnx_diagnostic/helpers/graph_helper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def computation_order(
3636
:return: computation order
3737
"""
3838
assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
39-
f"This algorithme is not yet implemented if the sequence contains "
39+
f"This algorithm is not yet implemented if the sequence contains "
4040
f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
4141
)
4242
number = {e: start - 1 for e in (existing or [])} # noqa: C420
@@ -131,14 +131,14 @@ def text_positions(
131131
@property
132132
def nodes(self) -> List[onnx.NodeProto]:
133133
"Returns the list of nodes"
134-
return (
134+
return list(
135135
self.proto.graph.node
136136
if isinstance(self.proto, onnx.ModelProto)
137137
else self.proto.node
138138
)
139139

140140
@property
141-
def start_names(self) -> List[onnx.NodeProto]:
141+
def start_names(self) -> List[str]:
142142
"Returns the list of known names, inputs and initializer"
143143
graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
144144
input_names = (
@@ -151,15 +151,15 @@ def start_names(self) -> List[onnx.NodeProto]:
151151
if isinstance(graph, onnx.FunctionProto)
152152
else [
153153
*[i.name for i in graph.initializer],
154-
*[i.name for i in graph.sparse_initializer],
154+
*[i.values.name for i in graph.sparse_initializer],
155155
]
156156
)
157157
return [*input_names, *init_names]
158158

159159
@property
160160
def input_names(self) -> List[str]:
161161
"Returns the list of input names."
162-
return (
162+
return list(
163163
self.proto.input
164164
if isinstance(self.proto, onnx.FunctionProto)
165165
else [
@@ -173,7 +173,7 @@ def input_names(self) -> List[str]:
173173
@property
174174
def output_names(self) -> List[str]:
175175
"Returns the list of output names."
176-
return (
176+
return list(
177177
self.proto.output
178178
if isinstance(self.proto, onnx.FunctionProto)
179179
else [

0 commit comments

Comments
 (0)