22# see `LiteralExpr.value`
33from __future__ import annotations
44
5+ import datetime as dt
6+ from decimal import Decimal
7+ from functools import cache
58from typing import TYPE_CHECKING , Any , Generic , final
69
710from narwhals ._plan import common
811from narwhals ._plan ._dispatch import DispatcherOptions
912from narwhals ._plan ._dtype import ResolveDType
1013from narwhals ._plan ._expr_ir import ExprIR
14+ from narwhals ._plan ._guards import is_python_literal_type
15+ from narwhals ._plan .exceptions import literal_type_error
1116from narwhals ._plan .typing import (
1217 LiteralT_co ,
1318 NativeSeriesT ,
1419 NativeSeriesT_co ,
15- NonNestedLiteralT ,
16- NonNestedLiteralT_co ,
20+ PythonLiteralT ,
21+ PythonLiteralT_co ,
1722)
23+ from narwhals ._utils import Version
24+ from narwhals .dtypes import Field , List , Struct , Unknown
1825
1926if TYPE_CHECKING :
2027 from collections .abc import Iterator
2128
2229 from narwhals ._plan .series import Series
23- from narwhals ._utils import Version
2430 from narwhals .dtypes import DType
25- from narwhals .typing import IntoDType
31+ from narwhals .typing import IntoDType , NonNestedDType , PythonLiteral
2632
2733__all__ = ["Lit" , "LitSeries" , "lit" , "lit_series" ]
2834
@@ -58,7 +64,7 @@ def iter_output_name(self) -> Iterator[ExprIR]:
5864
5965
6066@final
61- class Lit (LiteralExpr [NonNestedLiteralT_co ], dispatch = namespaced ()):
67+ class Lit (LiteralExpr [PythonLiteralT_co ], dispatch = namespaced ()):
6268 """An expression representing a scalar literal value.
6369
6470 >>> import narwhals._plan as nw
@@ -78,6 +84,12 @@ def __repr__(self) -> str:
7884 v = self .value
7985 return f"lit({ 'null' if v is None else f'{ type (v ).__name__ } : { v !s} ' } )"
8086
87+ @property
88+ def __immutable_values__ (self ) -> Iterator [Any ]:
89+ dtype = self .dtype
90+ value : Any = self .value
91+ yield from (id (value ) if dtype .is_nested () else value , dtype )
92+
8193
8294@final
8395class LitSeries (LiteralExpr ["Series[NativeSeriesT_co]" ], dispatch = namespaced ()):
@@ -128,14 +140,73 @@ def __immutable_values__(self) -> Iterator[Any]:
128140 yield from (self .name , self .dtype , id (self .value ))
129141
130142
131- def lit (
132- value : NonNestedLiteralT , dtype : IntoDType | None = None
133- ) -> Lit [NonNestedLiteralT ]:
134- if dtype is None :
135- dtype = common .py_to_narwhals_dtype (value )
136- else :
137- dtype = common .into_dtype (dtype )
143+ lit_series = LitSeries .from_series
144+
145+
146+ def lit (value : PythonLiteralT , dtype : IntoDType | None = None ) -> Lit [PythonLiteralT ]:
147+ dtype = (
148+ _py_value_to_dtype (value , Version .MAIN , allow_null = True )
149+ if dtype is None
150+ else common .into_dtype (dtype )
151+ )
138152 return Lit (value = value , dtype = dtype )
139153
140154
141- lit_series = LitSeries .from_series
155+ def _py_value_to_dtype (
156+ obj : PythonLiteral , version : Version = Version .MAIN , * , allow_null : bool
157+ ) -> DType :
158+ # NOTE: Surely mypy must have fixed `_lru_cache_wrapper` hashable in a new version?
159+ if dtype := _py_type_to_dtype (type (obj ), version ): # type: ignore[arg-type]
160+ if allow_null or not isinstance (dtype , Unknown ):
161+ return dtype
162+ msg = "Nested dtypes containing nulls are not yet supported"
163+ raise TypeError (msg )
164+ if not isinstance (obj , (list , dict , tuple )):
165+ # Just a type narrowing issue
166+ msg = f"Expected unreachable, got { obj !r} "
167+ raise NotImplementedError (msg )
168+
169+ if not obj :
170+ msg = "Cannot infer dtype for empty nested structure. Please provide an explicit dtype parameter."
171+ raise TypeError (msg )
172+ if not isinstance (obj , dict ):
173+ first_value = next ((el for el in obj if el is not None ), None )
174+ return List (_py_value_to_dtype (first_value , version , allow_null = False ))
175+ return Struct (
176+ [
177+ Field (k , _py_value_to_dtype (v , version , allow_null = False ))
178+ for k , v in obj .items ()
179+ ]
180+ )
181+
182+
183+ @cache
184+ def _py_type_to_dtype (
185+ py_type : type [PythonLiteral ], version : Version = Version .MAIN , /
186+ ) -> NonNestedDType | None :
187+ """SAFETY.
188+
189+ Cache size is bound by these dimensions:
190+
191+ n_valid_py_types = len(non_nested) + len([list, dict, tuple])
192+ maxsize = n_valid_py_types * len(Version)
193+ """
194+ dtypes = version .dtypes
195+ non_nested : dict [type [PythonLiteral ], type [NonNestedDType ]] = {
196+ int : dtypes .Int64 ,
197+ float : dtypes .Float64 ,
198+ str : dtypes .String ,
199+ bool : dtypes .Boolean ,
200+ dt .datetime : dtypes .Datetime ,
201+ dt .date : dtypes .Date ,
202+ dt .time : dtypes .Time ,
203+ dt .timedelta : dtypes .Duration ,
204+ bytes : dtypes .Binary ,
205+ Decimal : dtypes .Decimal ,
206+ type (None ): dtypes .Unknown ,
207+ }
208+ if dtype := non_nested .get (py_type ):
209+ return dtype ()
210+ if not is_python_literal_type (py_type ):
211+ raise literal_type_error (py_type )
212+ return None
0 commit comments