1010import os
1111import subprocess
1212import unittest
13+ from collections .abc import Callable , Generator
1314from contextlib import contextmanager
1415from functools import wraps
15- from typing import Any , Callable , Optional , Union
16+ from typing import Any
1617
1718import fbgemm_gpu
1819import hypothesis .strategies as st
2728
2829# Skip pt2 compliant tag test for certain operators
2930# TODO: remove this once the operators are pt2 compliant
30- # pyre-ignore
31- additional_decorators : dict [str , list [Callable ]] = {
31+ additional_decorators : dict [str , list [Callable [..., Any ]]] = {
3232 # vbe_generate_metadata_cpu return different values from vbe_generate_metadata_meta
3333 # this fails fake_tensor test as the test expects them to be the same
3434 # fake_tensor test is added in failures_dict but failing fake_tensor test still cause pt2_compliant tag test to fail
@@ -115,14 +115,12 @@ class optests:
115115 # ...
116116 #
117117 @staticmethod
118- # pyre-ignore[3]
119118 def generate_opcheck_tests (
120- test_class : Optional [ unittest .TestCase ] = None ,
119+ test_class : unittest .TestCase | None = None ,
121120 * ,
122121 fast : bool = False ,
123- # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
124- additional_decorators : Optional [dict [str , Callable ]] = None ,
125- ):
122+ additional_decorators : dict [str , Callable [..., Any ]] | None = None ,
123+ ) -> unittest .TestCase | Callable [[unittest .TestCase ], unittest .TestCase ]:
126124 if additional_decorators is None :
127125 additional_decorators = {}
128126
@@ -176,8 +174,9 @@ def is_inside_opcheck_mode() -> bool:
176174 return optests .is_inside_opcheck_mode ()
177175
178176 @staticmethod
179- # pyre-ignore[3]
180- def dontGenerateOpCheckTests (reason : str ):
177+ def dontGenerateOpCheckTests (
178+ reason : str ,
179+ ) -> Callable [[Callable [..., Any ]], Callable [..., Any ]]:
181180 if hasattr (fbgemm_gpu , "open_source" ):
182181 return lambda fun : fun
183182 import torch .testing ._internal .optests as optests
@@ -187,10 +186,10 @@ def dontGenerateOpCheckTests(reason: str):
187186
188187class TestSuite (unittest .TestCase ):
189188 @contextmanager
190- # pyre-ignore[2]
191- def assertNotRaised (self , exc_type ) -> None :
189+ def assertNotRaised (
190+ self , exc_type : type [BaseException ]
191+ ) -> Generator [None , None , None ]:
192192 try :
193- # pyre-ignore[7]
194193 yield None
195194 except exc_type as e :
196195 raise self .failureException (e )
@@ -200,10 +199,8 @@ def assertNotRaised(self, exc_type) -> None:
200199# The problem with just torch.autograd.gradcheck is that it results in
201200# very slow tests when composed with generate_opcheck_tests.
202201def gradcheck (
203- # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
204- f : Callable ,
205- # pyre-ignore[2]
206- inputs : Union [torch .Tensor , tuple [Any , ...]],
202+ f : Callable [..., Any ],
203+ inputs : torch .Tensor | tuple [Any , ...],
207204 * args : Any ,
208205 ** kwargs : Any ,
209206) -> None :
@@ -241,14 +238,12 @@ def gpu_memory_lt_gb(x: int) -> tuple[bool, str]:
241238 )
242239
243240
244- # pyre-fixme[3]: Return annotation cannot be `Any`.
245- def skipIfRocm (reason : str = "Test currently doesn't work on the ROCm stack" ) -> Any :
246- # pyre-fixme[3]: Return annotation cannot be `Any`.
247- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
248- def decorator (fn : Callable ) -> Any :
241+ def skipIfRocm (
242+ reason : str = "Test currently doesn't work on the ROCm stack" ,
243+ ) -> Callable [[Callable [..., None ]], Callable [..., None ]]:
244+ def decorator (fn : Callable [..., None ]) -> Callable [..., None ]:
249245 @wraps (fn )
250- # pyre-fixme[3]: Return annotation cannot be `Any`.
251- def wrapper (* args : Any , ** kwargs : Any ) -> Any :
246+ def wrapper (* args : Any , ** kwargs : Any ) -> None :
252247 if TEST_WITH_ROCM :
253248 raise unittest .SkipTest (reason )
254249 else :
@@ -259,16 +254,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
259254 return decorator
260255
261256
262- # pyre-fixme[3]: Return annotation cannot be `Any`.
263257def skipIfNotRocm (
264258 reason : str = "Test currently doesn work only on the ROCm stack" ,
265- ) -> Any :
266- # pyre-fixme[3]: Return annotation cannot be `Any`.
267- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
268- def decorator (fn : Callable ) -> Any :
259+ ) -> Callable [[Callable [..., None ]], Callable [..., None ]]:
260+ def decorator (fn : Callable [..., None ]) -> Callable [..., None ]:
269261 @wraps (fn )
270- # pyre-fixme[3]: Return annotation cannot be `Any`.
271- def wrapper (* args : Any , ** kwargs : Any ) -> Any :
262+ def wrapper (* args : Any , ** kwargs : Any ) -> None :
272263 if TEST_WITH_ROCM :
273264 fn (* args , ** kwargs )
274265 else :
0 commit comments