Skip to content

Commit 5c8669d

Browse files
authored
Read-only (covariant) list parameter annotations (#1745)
* Add failing test * Add covariant list for use in parameter annotations * Fix hashability and pyright comment moved by black * type-check only * Address code review
1 parent 406b31e commit 5c8669d

4 files changed

Lines changed: 51 additions & 3 deletions

File tree

pandas-stubs/_typing.pyi

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import sys
1717
from typing import (
1818
TYPE_CHECKING,
1919
Any,
20+
ClassVar,
2021
Generic,
2122
Literal,
2223
Protocol,
@@ -1230,6 +1231,21 @@ class Just(Protocol, Generic[T]):
12301231
@override
12311232
def __class__(self, t: type[T], /) -> None: ...
12321233

1234+
# Read-only (covariant) list for use in parameter annotations (See GH #1745)
1235+
class CovariantList(Protocol[_T_co]):
1236+
__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]
1237+
@property # type: ignore[override]
1238+
def __class__(self) -> type[list[Any]]: ... # pyrefly: ignore[bad-override]
1239+
@__class__.setter
1240+
def __class__( # pyright: ignore[reportIncompatibleMethodOverride]
1241+
self, value: type[list[Any]], /
1242+
) -> None: ...
1243+
def __iter__(self) -> Iterator[_T_co]: ...
1244+
# copy() is only TEMPORARILY needed because `__class__` is a property
1245+
# and ty doesn't support property protocol members. Remove when
1246+
# https://github.com/astral-sh/ty/issues/1379 is resolved
1247+
def copy(self) -> list[Any]: ...
1248+
12331249
class SupportsTrueDiv(Protocol[_T_contra, _T_co]):
12341250
def __truediv__(self, x: _T_contra, /) -> _T_co: ...
12351251

pandas-stubs/core/frame.pyi

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ from pandas._typing import (
9999
CalculationMethod,
100100
ColspaceArgType,
101101
CompressionOptions,
102+
CovariantList,
102103
DropKeep,
103104
Dtype,
104105
FilePath,
@@ -2092,9 +2093,11 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
20922093
@overload
20932094
def get(self, key: Hashable, default: _T) -> Series | _T: ...
20942095
@overload
2095-
def get(self, key: list[Hashable], default: None = None) -> Self | None: ...
2096+
def get(
2097+
self, key: CovariantList[Hashable], default: None = None
2098+
) -> Self | None: ...
20962099
@overload
2097-
def get(self, key: list[Hashable], default: _T) -> Self | _T: ...
2100+
def get(self, key: CovariantList[Hashable], default: _T) -> Self | _T: ...
20982101
def gt(
20992102
self,
21002103
other: complex | ListLike | DataFrame,

tests/frame/test_frame.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3921,6 +3921,9 @@ def test_get() -> None:
39213921
)
39223922
check(assert_type(df.get(["z"], default=1), pd.DataFrame | int), int)
39233923

3924+
key = ["a", "b"]
3925+
check(assert_type(df.get(key), pd.DataFrame | None), pd.DataFrame)
3926+
39243927

39253928
def test_info() -> None:
39263929
df = pd.DataFrame()

tests/test_typing.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Mapping
22
from typing import (
3+
TYPE_CHECKING,
34
Any,
45
Literal,
56
)
@@ -10,9 +11,15 @@
1011

1112
from pandas.core.dtypes.base import ExtensionDtype
1213

13-
from tests import get_dtype
14+
from tests import (
15+
TYPE_CHECKING_INVALID_USAGE,
16+
get_dtype,
17+
)
1418
from tests.dtypes import DTYPE_ARG_ALIAS_MAPS
1519

20+
if TYPE_CHECKING:
21+
from pandas._typing import CovariantList
22+
1623

1724
def test_get_dtype() -> None:
1825
alias_union = (
@@ -36,3 +43,22 @@ def test_dtype_arg_aliases(dtype_arg: Any, alias_map: Mapping[Any, Any]) -> None
3643
assert set(get_dtype(dtype_arg)) == {
3744
type(t) if isinstance(t, ExtensionDtype) else t for t in alias_map
3845
}
46+
47+
48+
def test_covariant_list() -> None:
49+
def f(_: "CovariantList[float]") -> None: ...
50+
51+
good1: list[float] = [1.0, 2.0, 3.0] # OK, trivial case
52+
good2: list[int] = [1, 2, 3] # OK, list[int] < list[float] due to covariance
53+
bad1: tuple[float, ...] = (1.0, 2.0, 3.0) # Error, tuple is not a subtype of list
54+
bad2: list[str] = ["a", "b", "c"] # Error, list[str] !< list[float]
55+
bad3: list[object] = [1, "a", 3.0] # Error, list[object] !< list[float]
56+
bad4: float = 1.0 # Error, float !< list[float]
57+
58+
f(good1)
59+
f(good2)
60+
if TYPE_CHECKING_INVALID_USAGE:
61+
f(bad1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type]
62+
f(bad2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type]
63+
f(bad3) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type]
64+
f(bad4) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] # pyrefly: ignore[bad-argument-type] # ty: ignore[invalid-argument-type]

0 commit comments

Comments
 (0)