Skip to content

Commit 505cb4d

Browse files
cyyevermeta-codesync[bot]
authored andcommitted
Fix pyre type annotations in test_utils.py (pytorch#5660)
Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2602 Pull Request resolved: pytorch#5660 Reviewed By: henrylhtsang Differential Revision: D101473074 Pulled By: q10 fbshipit-source-id: 72ba6dd29a47bb8e4c042cf43aab30d461a2a417
1 parent 9374e3f commit 505cb4d

1 file changed

Lines changed: 22 additions & 31 deletions

File tree

fbgemm_gpu/test/test_utils.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import os
1111
import subprocess
1212
import unittest
13+
from collections.abc import Callable, Generator
1314
from contextlib import contextmanager
1415
from functools import wraps
15-
from typing import Any, Callable, Optional, Union
16+
from typing import Any
1617

1718
import fbgemm_gpu
1819
import hypothesis.strategies as st
@@ -27,8 +28,7 @@
2728

2829
# Skip pt2 compliant tag test for certain operators
2930
# TODO: remove this once the operators are pt2 compliant
30-
# pyre-ignore
31-
additional_decorators: dict[str, list[Callable]] = {
31+
additional_decorators: dict[str, list[Callable[..., Any]]] = {
3232
# vbe_generate_metadata_cpu return different values from vbe_generate_metadata_meta
3333
# this fails fake_tensor test as the test expects them to be the same
3434
# fake_tensor test is added in failures_dict but failing fake_tensor test still cause pt2_compliant tag test to fail
@@ -115,14 +115,12 @@ class optests:
115115
# ...
116116
#
117117
@staticmethod
118-
# pyre-ignore[3]
119118
def generate_opcheck_tests(
120-
test_class: Optional[unittest.TestCase] = None,
119+
test_class: unittest.TestCase | None = None,
121120
*,
122121
fast: bool = False,
123-
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
124-
additional_decorators: Optional[dict[str, Callable]] = None,
125-
):
122+
additional_decorators: dict[str, Callable[..., Any]] | None = None,
123+
) -> unittest.TestCase | Callable[[unittest.TestCase], unittest.TestCase]:
126124
if additional_decorators is None:
127125
additional_decorators = {}
128126

@@ -176,8 +174,9 @@ def is_inside_opcheck_mode() -> bool:
176174
return optests.is_inside_opcheck_mode()
177175

178176
@staticmethod
179-
# pyre-ignore[3]
180-
def dontGenerateOpCheckTests(reason: str):
177+
def dontGenerateOpCheckTests(
178+
reason: str,
179+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
181180
if hasattr(fbgemm_gpu, "open_source"):
182181
return lambda fun: fun
183182
import torch.testing._internal.optests as optests
@@ -187,10 +186,10 @@ def dontGenerateOpCheckTests(reason: str):
187186

188187
class TestSuite(unittest.TestCase):
189188
@contextmanager
190-
# pyre-ignore[2]
191-
def assertNotRaised(self, exc_type) -> None:
189+
def assertNotRaised(
190+
self, exc_type: type[BaseException]
191+
) -> Generator[None, None, None]:
192192
try:
193-
# pyre-ignore[7]
194193
yield None
195194
except exc_type as e:
196195
raise self.failureException(e)
@@ -200,10 +199,8 @@ def assertNotRaised(self, exc_type) -> None:
200199
# The problem with just torch.autograd.gradcheck is that it results in
201200
# very slow tests when composed with generate_opcheck_tests.
202201
def gradcheck(
203-
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
204-
f: Callable,
205-
# pyre-ignore[2]
206-
inputs: Union[torch.Tensor, tuple[Any, ...]],
202+
f: Callable[..., Any],
203+
inputs: torch.Tensor | tuple[Any, ...],
207204
*args: Any,
208205
**kwargs: Any,
209206
) -> None:
@@ -241,14 +238,12 @@ def gpu_memory_lt_gb(x: int) -> tuple[bool, str]:
241238
)
242239

243240

244-
# pyre-fixme[3]: Return annotation cannot be `Any`.
245-
def skipIfRocm(reason: str = "Test currently doesn't work on the ROCm stack") -> Any:
246-
# pyre-fixme[3]: Return annotation cannot be `Any`.
247-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
248-
def decorator(fn: Callable) -> Any:
241+
def skipIfRocm(
242+
reason: str = "Test currently doesn't work on the ROCm stack",
243+
) -> Callable[[Callable[..., None]], Callable[..., None]]:
244+
def decorator(fn: Callable[..., None]) -> Callable[..., None]:
249245
@wraps(fn)
250-
# pyre-fixme[3]: Return annotation cannot be `Any`.
251-
def wrapper(*args: Any, **kwargs: Any) -> Any:
246+
def wrapper(*args: Any, **kwargs: Any) -> None:
252247
if TEST_WITH_ROCM:
253248
raise unittest.SkipTest(reason)
254249
else:
@@ -259,16 +254,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
259254
return decorator
260255

261256

262-
# pyre-fixme[3]: Return annotation cannot be `Any`.
263257
def skipIfNotRocm(
264258
reason: str = "Test currently doesn work only on the ROCm stack",
265-
) -> Any:
266-
# pyre-fixme[3]: Return annotation cannot be `Any`.
267-
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
268-
def decorator(fn: Callable) -> Any:
259+
) -> Callable[[Callable[..., None]], Callable[..., None]]:
260+
def decorator(fn: Callable[..., None]) -> Callable[..., None]:
269261
@wraps(fn)
270-
# pyre-fixme[3]: Return annotation cannot be `Any`.
271-
def wrapper(*args: Any, **kwargs: Any) -> Any:
262+
def wrapper(*args: Any, **kwargs: Any) -> None:
272263
if TEST_WITH_ROCM:
273264
fn(*args, **kwargs)
274265
else:

0 commit comments

Comments
 (0)