Skip to content

Commit 8947b8c

Browse files
committed
improve typing of _apply_over_batch
1 parent a90a0e3 commit 8947b8c

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

tests/test_funcs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,23 +1704,24 @@ def _apply_over_batch(*argdefs: tuple[str, int]) -> Any:
17041704

17051705
def decorator(f: Any) -> Any:
17061706
def wrapper(
1707-
*args_tuple: tuple[Any] | None,
1708-
**kwargs: dict[str, Any] | None,
1707+
*args_tuple: Any,
1708+
**kwargs: Any,
17091709
) -> Any:
17101710
args = list(args_tuple)
17111711

17121712
# Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs`
17131713
arrays, other_args = args[:n_arrays], args[n_arrays:]
1714+
arrays = cast(list[Array | None], arrays)
17141715
for i, name in enumerate(names):
17151716
if name in kwargs:
17161717
if i + 1 <= len(args):
17171718
message = (
17181719
f"{f.__name__}() got multiple values for argument `{name}`."
17191720
)
17201721
raise ValueError(message)
1721-
arrays.append(kwargs.pop(name)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
1722+
arrays.append(kwargs.pop(name))
17221723

1723-
xp = array_namespace(*arrays) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
1724+
xp = array_namespace(*arrays)
17241725

17251726
# Determine core and batch shapes
17261727
batch_shapes = []
@@ -1781,7 +1782,7 @@ def wrapper(
17811782
return decorator
17821783

17831784

1784-
@_apply_over_batch(("a", 1), ("v", 1)) # type: ignore[misc]
1785+
@_apply_over_batch(("a", 1), ("v", 1)) # type: ignore[untyped-decorator]
17851786
def xp_searchsorted(
17861787
a: Array,
17871788
v: Array,

0 commit comments

Comments
 (0)