Skip to content

Commit d763425

Browse files
committed
feat: Support nested data in lit
Playing catch up with #3424 @FBruzzesi The `list` stuff was the main wart left in `_parse.py`. Was easier implement the `list` support in `lit`, than try to make sense of whatever logic I landed on before 😅 Up next, is to fix the leaky `DataFrame.filter` list parsing! https://github.com/narwhals-dev/narwhals/blob/6e43487472a13000d99361e7ccc3fcb8cc3e656f/narwhals/_plan/dataframe.py#L571-L579
1 parent 6e43487 commit d763425

14 files changed

Lines changed: 261 additions & 126 deletions

File tree

narwhals/_plan/_guards.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
bytes,
3535
Decimal,
3636
)
37-
_PYTHON_LITERAL_TPS = (*_NON_NESTED_LITERAL_TPS, list, tuple, type(None))
37+
_PYTHON_LITERAL_TPS = (*_NON_NESTED_LITERAL_TPS, list, tuple, dict, type(None))
3838

3939

4040
def _ir(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202
@@ -57,6 +57,10 @@ def is_python_literal(obj: Any) -> TypeIs[PythonLiteral]:
5757
return isinstance(obj, _PYTHON_LITERAL_TPS)
5858

5959

60+
def is_python_literal_type(tp: type[Any]) -> TypeIs[type[PythonLiteral]]:
61+
return tp in _PYTHON_LITERAL_TPS
62+
63+
6064
def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]:
6165
return isinstance(obj, _series().Series)
6266

narwhals/_plan/_parse.py

Lines changed: 5 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -63,61 +63,9 @@
6363
"sort_by_into_seq_of_expr_ir",
6464
]
6565

66-
_RaisesInvalidIntoExprError: TypeAlias = "Any"
67-
"""
68-
Placeholder for multiple `Iterable[IntoExpr]`.
69-
70-
We only support cases `a`, `b`, but the typing for most contexts is more permissive:
71-
72-
>>> import polars as pl
73-
>>> df = pl.DataFrame({"one": ["A", "B", "A"], "two": [1, 2, 3], "three": [4, 5, 6]})
74-
>>> a = ("one", "two")
75-
>>> b = (["one", "two"],)
76-
>>>
77-
>>> c = ("one", ["two"])
78-
>>> d = (["one"], "two")
79-
>>> [df.select(*into) for into in (a, b, c, d)]
80-
[shape: (3, 2)
81-
┌─────┬─────┐
82-
│ one ┆ two │
83-
│ --- ┆ --- │
84-
│ str ┆ i64 │
85-
╞═════╪═════╡
86-
│ A ┆ 1 │
87-
│ B ┆ 2 │
88-
│ A ┆ 3 │
89-
└─────┴─────┘,
90-
shape: (3, 2)
91-
┌─────┬─────┐
92-
│ one ┆ two │
93-
│ --- ┆ --- │
94-
│ str ┆ i64 │
95-
╞═════╪═════╡
96-
│ A ┆ 1 │
97-
│ B ┆ 2 │
98-
│ A ┆ 3 │
99-
└─────┴─────┘,
100-
shape: (3, 2)
101-
┌─────┬───────────┐
102-
│ one ┆ literal │
103-
│ --- ┆ --- │
104-
│ str ┆ list[str] │
105-
╞═════╪═══════════╡
106-
│ A ┆ ["two"] │
107-
│ B ┆ ["two"] │
108-
│ A ┆ ["two"] │
109-
└─────┴───────────┘,
110-
shape: (3, 2)
111-
┌───────────┬─────┐
112-
│ literal ┆ two │
113-
│ --- ┆ --- │
114-
│ list[str] ┆ i64 │
115-
╞═══════════╪═════╡
116-
│ ["one"] ┆ 1 │
117-
│ ["one"] ┆ 2 │
118-
│ ["one"] ┆ 3 │
119-
└───────────┴─────┘]
120-
"""
66+
# TODO @dangotbanned: Simplify the `list` special-casing, now that it is supported
67+
Incomplete: TypeAlias = "Any"
68+
"""Artifact from previous `lit(list)` rejection"""
12169

12270

12371
def into_expr_ir(
@@ -146,7 +94,7 @@ def into_expr_ir(
14694

14795
def into_seq_of_expr_ir(
14896
first_input: OneOrIterable[IntoExpr] = (),
149-
*more_inputs: IntoExpr | _RaisesInvalidIntoExprError,
97+
*more_inputs: IntoExpr | Incomplete,
15098
**named_inputs: IntoExpr,
15199
) -> Seq[ExprIR]:
152100
"""Parse variadic inputs into a flat sequence of expressions."""
@@ -157,7 +105,7 @@ def into_seq_of_expr_ir(
157105

158106
def predicates_constraints_into_expr_ir(
159107
first_predicate: OneOrIterable[IntoExprColumn] | list[bool] = (),
160-
*more_predicates: IntoExprColumn | list[bool] | _RaisesInvalidIntoExprError,
108+
*more_predicates: IntoExprColumn | list[bool] | Incomplete,
161109
_list_as_series: PartialSeries | None = None,
162110
**constraints: IntoExpr,
163111
) -> ExprIR:

narwhals/_plan/arrow/namespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def col(self, node: ir.Column, frame: Frame, name: str) -> Expr:
117117
frame.native.column(node.name), name, version=frame.version
118118
)
119119

120-
def lit(self, node: ir.Lit[NonNestedLiteral], frame: Frame, name: str) -> Scalar:
120+
def lit(self, node: ir.Lit[PythonLiteral], frame: Frame, name: str) -> Scalar:
121121
return self._scalar.from_python(
122122
node.value, name, dtype=node.dtype, version=frame.version
123123
)

narwhals/_plan/common.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from __future__ import annotations
22

3-
import datetime as dt
43
import sys
54
from collections.abc import Iterable
65
from copy import deepcopy
7-
from decimal import Decimal
86
from io import BytesIO
97
from secrets import token_hex
108
from types import MappingProxyType
@@ -14,7 +12,6 @@
1412
from narwhals._utils import _hasattr_static, qualified_type_name
1513
from narwhals.dtypes import DType
1614
from narwhals.exceptions import NarwhalsError
17-
from narwhals.utils import Version
1815

1916
if TYPE_CHECKING:
2017
import reprlib
@@ -34,7 +31,7 @@
3431
Seq,
3532
)
3633
from narwhals._utils import _StoresColumns
37-
from narwhals.typing import FileSource, NonNestedDType, NonNestedLiteral
34+
from narwhals.typing import FileSource
3835

3936
T = TypeVar("T")
4037

@@ -52,24 +49,6 @@ def replace(obj: T, /, **changes: Any) -> T:
5249
return func(obj, **changes) # type: ignore[no-any-return]
5350

5451

55-
def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType:
56-
dtypes = version.dtypes
57-
mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = {
58-
int: dtypes.Int64,
59-
float: dtypes.Float64,
60-
str: dtypes.String,
61-
bool: dtypes.Boolean,
62-
dt.datetime: dtypes.Datetime,
63-
dt.date: dtypes.Date,
64-
dt.time: dtypes.Time,
65-
dt.timedelta: dtypes.Duration,
66-
bytes: dtypes.Binary,
67-
Decimal: dtypes.Decimal,
68-
type(None): dtypes.Unknown,
69-
}
70-
return mapping.get(type(obj), dtypes.Unknown)()
71-
72-
7352
@overload
7453
def into_dtype(dtype: type[NonNestedDTypeT], /) -> NonNestedDTypeT: ...
7554
@overload

narwhals/_plan/compliant/namespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from narwhals._plan.expressions import FunctionExpr, boolean, functions as F
2424
from narwhals._plan.expressions.strings import ConcatStr
2525
from narwhals._utils import Implementation
26-
from narwhals.typing import NonNestedLiteral
26+
from narwhals.typing import PythonLiteral
2727

2828
Incomplete: TypeAlias = Any
2929

@@ -61,7 +61,7 @@ def coalesce(
6161
) -> ExprT_co | ScalarT_co: ...
6262
def len(self, node: ir.Len, frame: FrameT, name: str) -> ScalarT_co: ...
6363
def lit(
64-
self, node: ir.Lit[NonNestedLiteral], frame: FrameT, name: str
64+
self, node: ir.Lit[PythonLiteral], frame: FrameT, name: str
6565
) -> ScalarT_co: ...
6666
def max_horizontal(
6767
self, node: FunctionExpr[F.MaxHorizontal], frame: FrameT, name: str

narwhals/_plan/exceptions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ def function_arg_non_scalar_error(
6969
return ShapeError(msg)
7070

7171

72-
def list_literal_error(value: Any) -> TypeError:
73-
msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}."
72+
def literal_type_error(value: Any) -> TypeError:
73+
msg = f"{qualified_type_name(value)!r} is not supported in `nw.lit`"
74+
if not isinstance(value, type):
75+
msg = f"{msg}, got: {value!r}."
7476
return TypeError(msg)
7577

7678

narwhals/_plan/expressions/literal.py

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,33 @@
22
# see `LiteralExpr.value`
33
from __future__ import annotations
44

5+
import datetime as dt
6+
from decimal import Decimal
7+
from functools import cache
58
from typing import TYPE_CHECKING, Any, Generic, final
69

710
from narwhals._plan import common
811
from narwhals._plan._dispatch import DispatcherOptions
912
from narwhals._plan._dtype import ResolveDType
1013
from narwhals._plan._expr_ir import ExprIR
14+
from narwhals._plan._guards import is_python_literal_type
15+
from narwhals._plan.exceptions import literal_type_error
1116
from narwhals._plan.typing import (
1217
LiteralT_co,
1318
NativeSeriesT,
1419
NativeSeriesT_co,
15-
NonNestedLiteralT,
16-
NonNestedLiteralT_co,
20+
PythonLiteralT,
21+
PythonLiteralT_co,
1722
)
23+
from narwhals._utils import Version
24+
from narwhals.dtypes import Field, List, Struct, Unknown
1825

1926
if TYPE_CHECKING:
2027
from collections.abc import Iterator
2128

2229
from narwhals._plan.series import Series
23-
from narwhals._utils import Version
2430
from narwhals.dtypes import DType
25-
from narwhals.typing import IntoDType
31+
from narwhals.typing import IntoDType, NonNestedDType, PythonLiteral
2632

2733
__all__ = ["Lit", "LitSeries", "lit", "lit_series"]
2834

@@ -58,7 +64,7 @@ def iter_output_name(self) -> Iterator[ExprIR]:
5864

5965

6066
@final
61-
class Lit(LiteralExpr[NonNestedLiteralT_co], dispatch=namespaced()):
67+
class Lit(LiteralExpr[PythonLiteralT_co], dispatch=namespaced()):
6268
"""An expression representing a scalar literal value.
6369
6470
>>> import narwhals._plan as nw
@@ -78,6 +84,12 @@ def __repr__(self) -> str:
7884
v = self.value
7985
return f"lit({'null' if v is None else f'{type(v).__name__}: {v!s}'})"
8086

87+
@property
88+
def __immutable_values__(self) -> Iterator[Any]:
89+
dtype = self.dtype
90+
value: Any = self.value
91+
yield from (id(value) if dtype.is_nested() else value, dtype)
92+
8193

8294
@final
8395
class LitSeries(LiteralExpr["Series[NativeSeriesT_co]"], dispatch=namespaced()):
@@ -128,14 +140,73 @@ def __immutable_values__(self) -> Iterator[Any]:
128140
yield from (self.name, self.dtype, id(self.value))
129141

130142

131-
def lit(
132-
value: NonNestedLiteralT, dtype: IntoDType | None = None
133-
) -> Lit[NonNestedLiteralT]:
134-
if dtype is None:
135-
dtype = common.py_to_narwhals_dtype(value)
136-
else:
137-
dtype = common.into_dtype(dtype)
143+
lit_series = LitSeries.from_series
144+
145+
146+
def lit(value: PythonLiteralT, dtype: IntoDType | None = None) -> Lit[PythonLiteralT]:
147+
dtype = (
148+
_py_value_to_dtype(value, Version.MAIN, allow_null=True)
149+
if dtype is None
150+
else common.into_dtype(dtype)
151+
)
138152
return Lit(value=value, dtype=dtype)
139153

140154

141-
lit_series = LitSeries.from_series
155+
def _py_value_to_dtype(
156+
obj: PythonLiteral, version: Version = Version.MAIN, *, allow_null: bool
157+
) -> DType:
158+
# NOTE: Surely mypy must have fixed `_lru_cache_wrapper` hashable in a new version?
159+
if dtype := _py_type_to_dtype(type(obj), version): # type: ignore[arg-type]
160+
if allow_null or not isinstance(dtype, Unknown):
161+
return dtype
162+
msg = "Nested dtypes containing nulls are not yet supported"
163+
raise TypeError(msg)
164+
if not isinstance(obj, (list, dict, tuple)):
165+
# Just a type narrowing issue
166+
msg = f"Expected unreachable, got {obj!r}"
167+
raise NotImplementedError(msg)
168+
169+
if not obj:
170+
msg = "Cannot infer dtype for empty nested structure. Please provide an explicit dtype parameter."
171+
raise TypeError(msg)
172+
if not isinstance(obj, dict):
173+
first_value = next((el for el in obj if el is not None), None)
174+
return List(_py_value_to_dtype(first_value, version, allow_null=False))
175+
return Struct(
176+
[
177+
Field(k, _py_value_to_dtype(v, version, allow_null=False))
178+
for k, v in obj.items()
179+
]
180+
)
181+
182+
183+
@cache
184+
def _py_type_to_dtype(
185+
py_type: type[PythonLiteral], version: Version = Version.MAIN, /
186+
) -> NonNestedDType | None:
187+
"""SAFETY.
188+
189+
Cache size is bound by these dimensions:
190+
191+
n_valid_py_types = len(non_nested) + len([list, dict, tuple])
192+
maxsize = n_valid_py_types * len(Version)
193+
"""
194+
dtypes = version.dtypes
195+
non_nested: dict[type[PythonLiteral], type[NonNestedDType]] = {
196+
int: dtypes.Int64,
197+
float: dtypes.Float64,
198+
str: dtypes.String,
199+
bool: dtypes.Boolean,
200+
dt.datetime: dtypes.Datetime,
201+
dt.date: dtypes.Date,
202+
dt.time: dtypes.Time,
203+
dt.timedelta: dtypes.Duration,
204+
bytes: dtypes.Binary,
205+
Decimal: dtypes.Decimal,
206+
type(None): dtypes.Unknown,
207+
}
208+
if dtype := non_nested.get(py_type):
209+
return dtype()
210+
if not is_python_literal_type(py_type):
211+
raise literal_type_error(py_type)
212+
return None

narwhals/_plan/functions/literal.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33
from typing import TYPE_CHECKING
44

55
from narwhals._plan import _guards, expressions as ir
6-
from narwhals._plan.exceptions import list_literal_error
6+
from narwhals._plan.exceptions import literal_type_error
77

88
if TYPE_CHECKING:
99
from narwhals._plan.expr import Expr
1010
from narwhals._plan.series import Series
1111
from narwhals._plan.typing import NativeSeriesT
12-
from narwhals.typing import IntoDType, NonNestedLiteral
12+
from narwhals.typing import IntoDType, PythonLiteral
1313

1414

1515
def lit(
16-
value: NonNestedLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None
16+
value: PythonLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None
1717
) -> Expr:
1818
if _guards.is_series(value):
1919
return ir.lit_series(value).to_narwhals()
20-
if not _guards.is_non_nested_literal(value):
21-
raise list_literal_error(value)
20+
if not _guards.is_python_literal(value):
21+
raise literal_type_error(value)
2222
return ir.lit(value, dtype).to_narwhals()

narwhals/_plan/polars/namespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
ConcatMethod,
3636
IntoDType,
3737
IntoSchema,
38-
NonNestedLiteral,
38+
PythonLiteral,
3939
)
4040

4141
Incomplete: TypeAlias = Any
@@ -169,7 +169,7 @@ def col(self, node: ir.Column, frame: Incomplete, name: str) -> Expr:
169169
def len(self, node: ir.Len, frame: Incomplete, name: str) -> Expr:
170170
return self._expr.from_native(pl.len(), name, self.version)
171171

172-
def lit(self, node: ir.Lit[NonNestedLiteral], frame: Incomplete, name: str) -> Expr:
172+
def lit(self, node: ir.Lit[PythonLiteral], frame: Incomplete, name: str) -> Expr:
173173
return self._expr.from_python(
174174
node.value, name, dtype=node.dtype, version=self.version
175175
)

0 commit comments

Comments
 (0)