11from __future__ import annotations
22
3+ from typing import TYPE_CHECKING
4+
5+ if TYPE_CHECKING :
6+ from collections .abc import Sequence
7+
8+ from narwhals .typing import IntoDType
9+
310import random
11+ import warnings
12+ from typing import Any
413
514import narwhals as nw
615import numpy as np
@@ -60,7 +69,7 @@ def check_data_emptiness(data: nw.DataFrame) -> None:
6069
def is_string_dtype(series: nw.Series) -> bool:
    """Return True when the dtype of ``series`` is ``nw.String`` or ``nw.Object``."""
    string_like_dtypes = {nw.String, nw.Object}
    return series.dtype in string_like_dtypes
6473
6574
6675def is_numeric_dtype (series : nw .Series ) -> bool :
@@ -78,22 +87,22 @@ def is_datetime_dtype(series: nw.Series) -> bool:
7887 return series .dtype == nw .Datetime
7988
8089
def select_string_columns(data: nw.DataFrame) -> list[str | int]:
    """Return the columns of ``data`` that ``is_string_dtype`` accepts.

    Columns are visited in ``data.columns`` order and that order is kept in the result.
    """
    selected: list[str | int] = []
    for name in data.columns:
        if is_string_dtype(data[name]):
            selected.append(name)
    return selected
8493
8594
def select_numeric_columns(data: nw.DataFrame) -> list[str | int]:
    """Return the columns of ``data`` that ``is_numeric_dtype`` accepts.

    Columns are visited in ``data.columns`` order and that order is kept in the result.
    """
    selected: list[str | int] = []
    for name in data.columns:
        if is_numeric_dtype(data[name]):
            selected.append(name)
    return selected
8998
9099
def select_datetime_columns(data: nw.DataFrame) -> list[str | int]:
    """Return the columns of ``data`` that ``is_datetime_dtype`` accepts.

    Columns are visited in ``data.columns`` order and that order is kept in the result.
    """
    selected: list[str | int] = []
    for name in data.columns:
        if is_datetime_dtype(data[name]):
            selected.append(name)
    return selected
94103
95104
def select_numeric_or_datetime_columns(data: nw.DataFrame) -> list[str | int]:
    """Return the columns of ``data`` whose dtype is numeric or datetime.

    A column is kept when ``is_numeric_dtype`` or ``is_datetime_dtype`` accepts it.
    Columns are visited in ``data.columns`` order and that order is kept in the result.
    """
    selected: list[str | int] = []
    for name in data.columns:
        # Extract the column once; both dtype checks inspect the same series.
        column = data[name]
        if is_numeric_dtype(column) or is_datetime_dtype(column):
            selected.append(name)
    return selected
99108
@@ -106,3 +115,32 @@ def create_empty_boolean_mask(data: nw.DataFrame) -> nw.DataFrame:
106115 dict .fromkeys (data .columns , mask_values ),
107116 backend = nw .get_native_namespace (data ),
108117 )
118+
119+
def cast_series_like(series: nw.Series, like: nw.Series, column: int | str) -> nw.Series:
    """Return ``series`` cast to the dtype of ``like``, or ``series`` itself if that fails.

    When both dtypes already compare equal, ``series`` is returned untouched. If the
    cast raises, a warning mentioning ``column`` is emitted and the uncast ``series``
    is returned instead of propagating the error.
    """
    target_dtype: IntoDType = like.dtype
    if series.dtype == target_dtype:
        return series

    try:
        converted = series.cast(target_dtype)
    except Exception as exc:  # noqa: BLE001
        # Best effort: a failed cast must not abort the caller, so warn and fall back.
        warnings.warn(
            f"Failed to cast column { column } to dtype { like .dtype } : { exc } . Keeping inferred dtype.",
            stacklevel=2,
        )
        return series
    return converted
132+
133+
def _values_to_list(values: Sequence[Any] | np.ndarray) -> list[Any]:
    """Return ``values`` as a plain Python list suitable for ``nw.new_series``.

    NumPy arrays are converted with ``ndarray.tolist()``; any other input is
    passed to ``list()``.
    """
    return values.tolist() if isinstance(values, np.ndarray) else list(values)
139+
140+
def new_series_like(data: nw.DataFrame, column: int | str, values: Sequence[Any] | np.ndarray) -> nw.Series:
    """Build a series for ``column`` from ``values`` and cast it to the existing column's dtype.

    The series name comes from ``get_column_str(data, column)`` and the reference
    series from ``get_column(data, column)``. The new series is created on the
    native namespace of ``data``, then passed through ``cast_series_like``, which
    keeps the inferred dtype if the cast fails.
    """
    name = get_column_str(data, column)
    reference = get_column(data, column)
    payload = _values_to_list(values)
    backend = nw.get_native_namespace(data)
    fresh = nw.new_series(name, payload, backend=backend)
    return cast_series_like(fresh, reference, column)
0 commit comments