11import os
2- from typing import Dict , List , Optional , Tuple
2+ from typing import Any , Dict , List , Optional , Tuple
33import onnx
44import onnx .helper as oh
55import torch
@@ -46,10 +46,10 @@ def __init__(
4646 f"This kernel implementation only work when only one output "
4747 f"is required but { node .output } were."
4848 )
49- self ._cache : Dict [Tuple [int , int ], onnx . ModelProto ] = {}
49+ self ._cache : Dict [Tuple [int , int ], Any ] = {}
5050 self .is_cpu = torch .device ("cpu" ) == self .device
5151
52- def _make_model (self , itype : int , rank : int , has_bias : bool ) -> onnx . ModelProto :
52+ def _make_model (self , itype : int , rank : int , has_bias : bool ) -> Any :
5353 shape = [* ["d{i}" for i in range (rank - 1 )], "last" ]
5454 layer_model = oh .make_model (
5555 oh .make_graph (
@@ -88,6 +88,7 @@ def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
8888 providers = [provider ],
8989 )
9090
91+ # pyrefly: ignore[bad-override]
9192 def run (self , x , scale , bias = None ):
9293 itype = torch_dtype_to_onnx_dtype (x .dtype )
9394 rank = len (x .shape )
@@ -124,7 +125,7 @@ def __init__(
124125 self ._cache : Dict [Tuple [int , int , int ], onnx .ModelProto ] = {}
125126 self .is_cpu = torch .device ("cpu" ) == self .device
126127
127- def _make_model (self , itype : int , ranka : int , rankb : int ) -> onnx . ModelProto :
128+ def _make_model (self , itype : int , ranka : int , rankb : int ) -> Any :
128129 shapea = ["a{i}" for i in range (ranka )]
129130 shapeb = ["b{i}" for i in range (rankb )]
130131 shapec = ["c{i}" for i in range (max (ranka , rankb ))]
@@ -149,6 +150,7 @@ def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
149150 providers = [provider ],
150151 )
151152
153+ # pyrefly: ignore[bad-override]
152154 def run (self , a , b ):
153155 itype = torch_dtype_to_onnx_dtype (a .dtype )
154156 ranka , rankb = len (a .shape ), len (b .shape )
@@ -159,5 +161,6 @@ def run(self, a, b):
159161 if self .verbose :
160162 print (f"[MatMulOrt] running on { self ._provider !r} " )
161163 feeds = dict (A = a .tensor , B = b .tensor )
164+ # pyrefly: ignore[missing-attribute]
162165 got = sess .run (None , feeds )[0 ]
163166 return OpRunTensor (got )
0 commit comments