|
1 | 1 | from collections.abc import Callable, Iterator |
2 | 2 | from types import ModuleType |
3 | | -from typing import cast |
| 3 | +from typing import Any, cast |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import pytest |
| 7 | +from typing_extensions import override |
7 | 8 |
|
8 | 9 | from array_api_extra._lib._backends import Backend |
9 | 10 | from array_api_extra._lib._testing import ( |
@@ -322,32 +323,34 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend): |
322 | 323 |
|
323 | 324 |
|
324 | 325 | class A: |
325 | | - def __init__(self, x): |
| 326 | + def __init__(self, x: Array): |
326 | 327 | xp = array_namespace(x) |
327 | | - self._xp = xp |
328 | | - self.x = np.asarray(x) |
| 328 | + self._xp: ModuleType = xp |
| 329 | + self.x: Any = np.asarray(x) |
329 | 330 |
|
330 | | - def f(self, y): |
331 | | - y = np.asarray(y) |
332 | | - return self._xp.asarray(np.matmul(self.x, y)) |
| 331 | + def f(self, y: Array) -> Array: |
| 332 | + return self._xp.asarray(np.matmul(self.x, np.asarray(y))) |
333 | 333 |
|
334 | | - def g(self, y, z): |
| 334 | + def g(self, y: Array, z: Array) -> Array: |
335 | 335 | return self.f(y) + self.f(z) |
336 | 336 |
|
337 | 337 |
|
338 | 338 | class B(A): |
339 | | - def __init__(self, x): |
| 339 | + @override |
| 340 | + def __init__(self, x: Array): # pyright: ignore[reportMissingSuperCall] |
340 | 341 | xp = array_namespace(x) |
341 | | - self._xp = xp |
342 | | - self.x = xp.asarray(x) |
| 342 | + self._xp: ModuleType = xp |
| 343 | + self.x: Any = xp.asarray(x) |
343 | 344 |
|
344 | | - def f(self, y): |
| 345 | + @override |
| 346 | + def f(self, y: Array) -> Array: |
345 | 347 | return self._xp.matmul(self.x, y) |
346 | 348 |
|
347 | 349 |
|
348 | 350 | lazy_xp_function((B, "g")) |
349 | 351 |
|
350 | | -def test_lazy_xp_function_class_inheritance(xp: ModuleType): |
| 352 | + |
| 353 | +def test_lazy_xp_function_class_inheritance(): |
351 | 354 | assert hasattr(B.g, "_lazy_xp_function") |
352 | 355 | assert not hasattr(A.g, "_lazy_xp_function") |
353 | 356 |
|
|
0 commit comments