Skip to content

Commit b9b04d0

Browse files
feat: Allow nested structures in lit (#3424)
--------- Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com>
1 parent e2d00da commit b9b04d0

16 files changed

Lines changed: 321 additions & 58 deletions

File tree

narwhals/_arrow/namespace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from narwhals._arrow.typing import ChunkedArrayAny, Incomplete, ScalarAny
2727
from narwhals._utils import Version
28-
from narwhals.typing import IntoDType, NonNestedLiteral
28+
from narwhals.typing import IntoDType, PythonLiteral
2929

3030

3131
class ArrowNamespace(
@@ -64,7 +64,7 @@ def len(self) -> ArrowExpr:
6464
version=self._version,
6565
)
6666

67-
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
67+
def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> ArrowExpr:
6868
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
6969
arrow_series = ArrowSeries.from_iterable(
7070
data=[value], name="literal", context=self

narwhals/_dask/namespace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
combine_alias_output_names,
2323
combine_evaluate_output_names,
2424
)
25-
from narwhals._utils import Implementation, zip_strict
25+
from narwhals._utils import Implementation, is_nested_literal, zip_strict
2626

2727
if TYPE_CHECKING:
2828
from collections.abc import Iterable, Iterator
@@ -55,6 +55,10 @@ def __init__(self, *, version: Version) -> None:
5555
self._version = version
5656

5757
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
58+
if is_nested_literal(value):
59+
msg = f"Nested structures are not supported for Dask backend, found {type(value).__name__}"
60+
raise NotImplementedError(msg)
61+
5862
def func(df: DaskLazyFrame) -> list[dx.Series]:
5963
if dtype is not None:
6064
native_dtype = narwhals_to_native_dtype(dtype, self._version)

narwhals/_duckdb/namespace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
from narwhals._compliant.window import WindowInputs
3737
from narwhals._utils import Version
38-
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
38+
from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral
3939

4040
VARCHAR = duckdb_dtypes.VARCHAR
4141

@@ -130,8 +130,12 @@ def func(cols: Iterable[Expression]) -> Expression:
130130

131131
return self._expr._from_elementwise_horizontal_op(func, *exprs)
132132

133-
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr:
133+
def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> DuckDBExpr:
134134
def func(df: DuckDBLazyFrame) -> list[Expression]:
135+
if isinstance(value, dict) and not value:
136+
msg = "Cannot create an empty struct type for DuckDB backend"
137+
raise NotImplementedError(msg)
138+
135139
tz = DeferredTimeZone(df.native)
136140
if dtype is not None:
137141
target = narwhals_to_native_dtype(dtype, self._version, tz)

narwhals/_ibis/namespace.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,15 @@ def func(cols: Iterable[ir.Value]) -> ir.Value:
114114

115115
return self._expr._from_elementwise_horizontal_op(func, *exprs)
116116

117-
def lit(self, value: Any, dtype: IntoDType | None) -> IbisExpr:
117+
def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> IbisExpr:
118118
def func(_df: IbisLazyFrame) -> Sequence[ir.Value]:
119119
ibis_dtype = narwhals_to_native_dtype(dtype, self._version) if dtype else None
120-
return [lit(value, ibis_dtype)]
120+
if not isinstance(value, dict):
121+
return [lit(value, ibis_dtype)]
122+
if value:
123+
return [ibis.struct(value, type=ibis_dtype)]
124+
msg = "Cannot create an empty struct type for Ibis backend"
125+
raise NotImplementedError(msg)
121126

122127
return self._expr(
123128
func,

narwhals/_pandas_like/dataframe.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from narwhals._pandas_like.series import PANDAS_TO_NUMPY_DTYPE_MISSING, PandasLikeSeries
1111
from narwhals._pandas_like.utils import (
1212
align_and_extract_native,
13+
broadcast_series_to_index,
1314
get_dtype_backend,
1415
import_array_module,
1516
iter_dtype_backends,
@@ -307,8 +308,12 @@ def _with_native(self, df: Any, *, validate_column_names: bool = True) -> Self:
307308
def _extract_comparand(self, other: PandasLikeSeries) -> pd.Series[Any]:
308309
index = self.native.index
309310
if other._broadcast:
310-
s = other.native
311-
return type(s)(s.iloc[0], index=index, dtype=s.dtype, name=s.name)
311+
native = other.native
312+
is_nested = other.dtype.is_nested()
313+
return broadcast_series_to_index(
314+
native, index, is_nested=is_nested, series_class=type(native)
315+
)
316+
312317
if (len_other := len(other)) != (len_idx := len(index)):
313318
msg = f"Expected object of length {len_idx}, got: {len_other}."
314319
raise ShapeError(msg)

narwhals/_pandas_like/namespace.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from typing_extensions import TypeAlias
2626

2727
from narwhals._utils import Implementation, Version
28-
from narwhals.typing import IntoDType, NonNestedLiteral
28+
from narwhals.typing import IntoDType, PythonLiteral
2929

3030

3131
Incomplete: TypeAlias = Any
@@ -83,17 +83,46 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
8383
context=self,
8484
)
8585

86-
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> PandasLikeExpr:
86+
def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> PandasLikeExpr:
8787
def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries:
88-
pandas_series = self._series.from_iterable(
88+
if isinstance(value, (list, tuple, dict)):
89+
try:
90+
import pandas as pd # ignore-banned-import
91+
import pyarrow as pa # ignore-banned-import
92+
except ImportError as exc: # pragma: no cover
93+
msg = (
94+
"Nested structures require pyarrow to be installed for pandas backend. "
95+
"Please install pyarrow: pip install pyarrow"
96+
)
97+
raise ImportError(msg) from exc
98+
99+
from narwhals._arrow.utils import (
100+
narwhals_to_native_dtype as _to_arrow_dtype,
101+
)
102+
103+
array_value = list(value) if isinstance(value, tuple) else value
104+
pa_dtype = _to_arrow_dtype(dtype, self._version) if dtype else None
105+
pa_array = pa.array([array_value], type=pa_dtype) # type: ignore[arg-type, list-item]
106+
107+
# Use ArrowExtensionArray to avoid pandas unpacking the nested structure
108+
ns = self._implementation.to_native_namespace()
109+
pandas_series_native = ns.Series(
110+
pd.arrays.ArrowExtensionArray(pa_array), # type: ignore[attr-defined]
111+
name="literal",
112+
index=df._native_frame.index[0:1],
113+
)
114+
115+
return self._series.from_native(pandas_series_native, context=self)
116+
117+
pandas_like_series = self._series.from_iterable(
89118
data=[value],
90119
name="literal",
91120
index=df._native_frame.index[0:1],
92121
context=self,
93122
)
94123
if dtype:
95-
return pandas_series.cast(dtype)
96-
return pandas_series
124+
return pandas_like_series.cast(dtype)
125+
return pandas_like_series
97126

98127
return PandasLikeExpr(
99128
lambda df: [_lit_pandas_series(df)],

narwhals/_pandas_like/series.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from narwhals._pandas_like.series_struct import PandasLikeSeriesStructNamespace
1515
from narwhals._pandas_like.utils import (
1616
align_and_extract_native,
17+
broadcast_series_to_index,
1718
get_dtype_backend,
1819
import_array_module,
1920
narwhals_to_native_dtype,
@@ -211,8 +212,8 @@ def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
211212
reindexed = []
212213
for s in series:
213214
if s._broadcast:
214-
native = Series(
215-
s.native.iloc[0], index=idx, name=s.name, dtype=s.native.dtype
215+
native = broadcast_series_to_index(
216+
s.native, idx, is_nested=s.dtype.is_nested(), series_class=Series
216217
)
217218
compliant = s._with_native(native)
218219
elif s.native.index is not idx:

narwhals/_pandas_like/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,37 @@ class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]):
663663

664664
def make_group_by_kwargs(*, drop_null_keys: bool) -> dict[str, bool]:
665665
return {"sort": False, "as_index": True, "dropna": drop_null_keys, "observed": True}
666+
667+
668+
def broadcast_series_to_index(
669+
native: pd.Series[Any],
670+
index: Any,
671+
*,
672+
is_nested: bool,
673+
series_class: type[pd.Series[Any]],
674+
) -> pd.Series[Any]:
675+
"""Broadcast a scalar value from a (one element) Series to match a target index.
676+
677+
For nested (arrow-backed) types, we rely on
678+
[`pandas.array`](https://pandas.pydata.org/docs/reference/api/pandas.array.html).
679+
680+
Arguments:
681+
native: The native pandas-like Series containing the scalar value to broadcast.
682+
index: The target index to broadcast to.
683+
is_nested: Whether the Series has a nested (arrow-backed) dtype.
684+
series_class: Series class to use for constructing the result.
685+
686+
Returns:
687+
A new Series with the scalar value broadcast to match the target index.
688+
"""
689+
value = native.iloc[0]
690+
if is_nested:
691+
from narwhals._arrow.utils import repeat
692+
693+
# NOTE: Ignore typing because `pandas-stubs` are wrong
694+
# TODO(FBruzzesi): Should we pass the `copy=False` flag?
695+
pa_array = pd.array(repeat(value, len(index)), dtype=native.dtype) # type: ignore[arg-type]
696+
697+
return series_class(pa_array, index=index, name=native.name)
698+
699+
return series_class(value, index=index, dtype=native.dtype, name=native.name)

narwhals/_spark_like/namespace.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from narwhals._compliant.window import WindowInputs
2929
from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401
3030
from narwhals._utils import Implementation, Version
31-
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral, PythonLiteral
31+
from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral
3232

3333
# Adjust slight SQL vs PySpark differences
3434
FUNCTION_REMAPPINGS = {
@@ -91,9 +91,22 @@ def _when(
9191
def _coalesce(self, *exprs: Column) -> Column:
9292
return self._F.coalesce(*exprs)
9393

94-
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> SparkLikeExpr:
94+
def lit(self, value: PythonLiteral, dtype: IntoDType | None) -> SparkLikeExpr:
9595
def func(df: SparkLikeLazyFrame) -> list[Column]:
96-
column = df._F.lit(value)
96+
F = df._F
97+
98+
if isinstance(value, (list, tuple)):
99+
lit_values = [F.lit(v) for v in value]
100+
column = F.lit(F.array(lit_values))
101+
elif isinstance(value, dict):
102+
if (not self._implementation.is_pyspark()) and (len(value) == 0):
103+
msg = f"Cannot create an empty struct type for {self._implementation} backend"
104+
raise NotImplementedError(msg)
105+
lit_values = [F.lit(v).alias(k) for k, v in value.items()]
106+
column = F.struct(*lit_values)
107+
else:
108+
column = F.lit(value)
109+
97110
if dtype:
98111
native_dtype = narwhals_to_native_dtype(
99112
dtype, self._version, df._native_dtypes, df.native.sparkSession

narwhals/_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
FileSource,
123123
IntoSeriesT,
124124
MultiIndexSelector,
125+
NestedLiteral,
125126
SingleIndexSelector,
126127
SizedMultiBoolSelector,
127128
SizedMultiIndexSelector,
@@ -1371,6 +1372,10 @@ def is_sequence_of(obj: Any, tp: type[_T]) -> TypeIs[Sequence[_T]]:
13711372
)
13721373

13731374

1375+
def is_nested_literal(obj: Any) -> TypeIs[NestedLiteral]:
1376+
return isinstance(obj, (list, tuple, dict))
1377+
1378+
13741379
def validate_strict_and_pass_though(
13751380
strict: bool | None, # noqa: FBT001
13761381
pass_through: bool | None, # noqa: FBT001

0 commit comments

Comments
 (0)