Skip to content

Commit fe3a964

Browse files
committed
feat: preserve dtypes across pandas and polars
When adopting narwhals, it is important that dtypes are preserved across backends: a value that is NaN in pandas should be NaN in polars, datetimes should remain datetimes, etc. This is achieved by (1) using narwhals' dtype support as much as possible and (2) explicitly casting dtypes after transforming values when necessary.
1 parent ad03b94 commit fe3a964

14 files changed

Lines changed: 123 additions & 79 deletions

tab_err/_utils.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import Sequence
7+
8+
from narwhals.typing import IntoDType
9+
310
import random
11+
import warnings
12+
from typing import Any
413

514
import narwhals as nw
615
import numpy as np
@@ -60,7 +69,7 @@ def check_data_emptiness(data: nw.DataFrame) -> None:
6069

6170
def is_string_dtype(series: nw.Series) -> bool:
6271
"""Check if a series has a string dtype."""
63-
return series.dtype == nw.String or series.dtype == nw.Object
72+
return series.dtype in {nw.String, nw.Object}
6473

6574

6675
def is_numeric_dtype(series: nw.Series) -> bool:
@@ -78,22 +87,22 @@ def is_datetime_dtype(series: nw.Series) -> bool:
7887
return series.dtype == nw.Datetime
7988

8089

81-
def select_string_columns(data: nw.DataFrame) -> list[str]:
90+
def select_string_columns(data: nw.DataFrame) -> list[str | int]:
8291
"""Select columns with string dtype."""
8392
return [col for col in data.columns if is_string_dtype(data[col])]
8493

8594

86-
def select_numeric_columns(data: nw.DataFrame) -> list[str]:
95+
def select_numeric_columns(data: nw.DataFrame) -> list[str | int]:
8796
"""Select columns with numeric dtype."""
8897
return [col for col in data.columns if is_numeric_dtype(data[col])]
8998

9099

91-
def select_datetime_columns(data: nw.DataFrame) -> list[str]:
100+
def select_datetime_columns(data: nw.DataFrame) -> list[str | int]:
92101
"""Select columns with datetime dtype."""
93102
return [col for col in data.columns if is_datetime_dtype(data[col])]
94103

95104

96-
def select_numeric_or_datetime_columns(data: nw.DataFrame) -> list[str]:
105+
def select_numeric_or_datetime_columns(data: nw.DataFrame) -> list[str | int]:
97106
"""Select columns with numeric or datetime dtype."""
98107
return [col for col in data.columns if is_numeric_dtype(data[col]) or is_datetime_dtype(data[col])]
99108

@@ -106,3 +115,32 @@ def create_empty_boolean_mask(data: nw.DataFrame) -> nw.DataFrame:
106115
dict.fromkeys(data.columns, mask_values),
107116
backend=nw.get_native_namespace(data),
108117
)
118+
119+
120+
def cast_series_like(series: nw.Series, like: nw.Series, column: int | str) -> nw.Series:
121+
"""Cast series to the dtype of 'like' when possible, otherwise keep original."""
122+
if series.dtype == like.dtype:
123+
return series
124+
dtype: IntoDType = like.dtype
125+
126+
try:
127+
return series.cast(dtype)
128+
except Exception as exc: # noqa: BLE001
129+
msg = f"Failed to cast column {column} to dtype {like.dtype}: {exc}. Keeping inferred dtype."
130+
warnings.warn(msg, stacklevel=2)
131+
return series
132+
133+
134+
def _values_to_list(values: Sequence[Any] | np.ndarray) -> list[Any]:
135+
"""Normalize values into a list for nw.new_series."""
136+
if isinstance(values, np.ndarray):
137+
return values.tolist()
138+
return list(values)
139+
140+
141+
def new_series_like(data: nw.DataFrame, column: int | str, values: Sequence[Any] | np.ndarray) -> nw.Series:
142+
"""Create a new series for 'column' and cast it back to the original dtype."""
143+
col_name = get_column_str(data, column)
144+
original = get_column(data, column)
145+
series = nw.new_series(col_name, _values_to_list(values), backend=nw.get_native_namespace(data))
146+
return cast_series_like(series, original, column)

tab_err/api/high_level.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from tab_err import ErrorMechanism, ErrorType, error_mechanism, error_type
99
from tab_err._error_model import ErrorModel
10-
from tab_err._utils import check_data_emptiness, check_error_rate, create_empty_boolean_mask, seed_randomness_and_get_generator
10+
from tab_err._utils import check_data_emptiness, check_error_rate, seed_randomness_and_get_generator
1111
from tab_err.api import MidLevelConfig, mid_level
1212

1313
if TYPE_CHECKING:
@@ -124,7 +124,7 @@ def _build_column_mechanism_dictionary(
124124
msg = "Possible conflict in error mechanisms to apply. Set at least on of: error_mechanisms_to_exclude or error_mechanisms_to_include to None."
125125
raise ValueError(msg)
126126

127-
columns_mechanisms = {}
127+
columns_mechanisms: dict[int | str, list[ErrorMechanism]] = {}
128128

129129
if error_mechanisms_to_include is not None and error_mechanisms_to_exclude is None: # Include specified
130130
if not all(issubclass(type(cls), ErrorMechanism) for cls in error_mechanisms_to_include): # Check input
@@ -177,7 +177,7 @@ def _build_column_number_of_models_dictionary(
177177
Returns:
178178
dict[int | str, int]: A dictionary mapping from column names to the number of error models to apply to that column.
179179
"""
180-
column_num_models = {}
180+
column_num_models: dict[int | str, int] = {}
181181

182182
for column in data.columns:
183183
column_num_models[column] = len(column_types[column]) * len(column_mechanisms[column])
@@ -202,7 +202,7 @@ def create_errors( # noqa: PLR0913
202202
"""Creates errors in a given DataFrame, at a rate of *approximately* max_error_rate.
203203
204204
Args:
205-
data (IntoDataFrame): The DataFrame to create errors in. Supports pandas, Polars, and other narwhals-compatible backends.
205+
data (IntoDataFrame): The DataFrame to create errors in. Supports pandas, Polars, and (experimental) other narwhals-compatible backends.
206206
error_rate (float): The maximum error rate to be introduced to each column in the DataFrame.
207207
n_error_models_per_column (int, optional): The number of valid error models to apply to each column. Defaults to 1.
208208
error_types_to_include (list[ErrorType] | None, optional): A list of the error types to be included when building error models. Defaults to None.
@@ -231,7 +231,6 @@ def create_errors( # noqa: PLR0913
231231

232232
# Set Up Data
233233
data_copy = data_nw.clone()
234-
error_mask = create_empty_boolean_mask(data_nw)
235234

236235
# Build Dictionaries
237236
col_type = _build_column_type_dictionary(
@@ -272,5 +271,5 @@ def create_errors( # noqa: PLR0913
272271
raise ValueError(msg)
273272

274273
# Create Errors & Return (mid_level handles native conversion)
275-
dirty_data, error_mask = mid_level.create_errors(nw.to_native(data_copy), config)
276-
return dirty_data, error_mask
274+
dirty_data_native, error_mask_native = mid_level.create_errors(nw.to_native(data_copy), config)
275+
return dirty_data_native, error_mask_native

tab_err/error_type/_add_delta.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import warnings
4+
from typing import TYPE_CHECKING
45

5-
import narwhals as nw
6+
if TYPE_CHECKING:
7+
import narwhals as nw
68
import numpy as np
79

8-
from tab_err._utils import get_column, get_column_str, is_datetime_dtype, is_numeric_dtype, select_numeric_or_datetime_columns
10+
from tab_err._utils import get_column, is_datetime_dtype, is_numeric_dtype, new_series_like, select_numeric_or_datetime_columns
911

1012
from ._error_type import ErrorType
1113

@@ -39,7 +41,6 @@ def _apply(self: AddDelta, data: nw.DataFrame, error_mask: nw.DataFrame, column:
3941
Returns:
4042
nw.Series: The data column, 'column', after AddDelta errors at the locations specified by 'error_mask' are introduced.
4143
"""
42-
col_name = get_column_str(data, column)
4344
series = get_column(data, column)
4445
series_mask = get_column(error_mask, column)
4546
was_datetime = False
@@ -62,13 +63,14 @@ def _apply(self: AddDelta, data: nw.DataFrame, error_mask: nw.DataFrame, column:
6263
mean_val = np.nanmean(data_arr)
6364
std_val = np.nanstd(data_arr)
6465
random_choice = self._random_generator.choice(data_arr[~np.isnan(data_arr)])
65-
self.config.add_delta_value = (random_choice - mean_val) / std_val if std_val != 0 else 0
66+
delta_value = (random_choice - mean_val) / std_val if std_val != 0 else 0
67+
else:
68+
delta_value = self.config.add_delta_value
6669

67-
# Apply delta where mask is True
68-
data_arr[mask_arr] += self.config.add_delta_value
70+
data_arr[mask_arr] += delta_value
6971

7072
if was_datetime:
7173
# Convert back to datetime (from seconds)
7274
data_arr = (data_arr * 10**9).astype("int64").astype("datetime64[ns]")
7375

74-
return nw.new_series(col_name, data_arr.tolist(), backend=nw.get_native_namespace(data))
76+
return new_series_like(data, column, data_arr)

tab_err/error_type/_category_swap.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import random
44

55
import narwhals as nw
6-
import numpy as np
76

8-
from tab_err._utils import get_column, get_column_str
7+
from tab_err._utils import get_column, new_series_like
98

109
from ._error_type import ErrorType
1110

@@ -40,7 +39,7 @@ def _check_type(data: nw.DataFrame, column: int | str) -> None:
4039

4140
def _get_valid_columns(self: CategorySwap, data: nw.DataFrame) -> list[str | int]:
4241
"""Checks which columns are categorical and returns the indices of those with two or more categories."""
43-
valid_columns = []
42+
valid_columns: list[str | int] = []
4443
for col_name in data.columns:
4544
series = get_column(data, col_name)
4645

@@ -65,7 +64,6 @@ def _apply(self: CategorySwap, data: nw.DataFrame, error_mask: nw.DataFrame, col
6564
Returns:
6665
nw.Series: The data column, 'column', after CategorySwap errors at the locations specified by 'error_mask' are introduced.
6766
"""
68-
col_name = get_column_str(data, column)
6967
series = get_column(data, column)
7068
series_mask = get_column(error_mask, column)
7169

@@ -84,16 +82,16 @@ def sample_label(old_label: str) -> str:
8482

8583
elif self.config.mislabel_weighing == "frequency":
8684
# Calculate frequency weights
87-
value_counts = {}
85+
value_counts: dict[str, int] = {}
8886
for val in data_arr:
8987
if val not in value_counts:
9088
value_counts[val] = 0
9189
value_counts[val] += 1
9290

9391
def sample_label(old_label: str) -> str:
9492
choices = [x for x in categories if x != old_label]
95-
weights = [value_counts.get(x, 1) for x in choices]
96-
total = sum(weights)
93+
weights: list[float] = [float(value_counts.get(x, 1)) for x in choices]
94+
total = float(sum(weights))
9795
weights = [w / total for w in weights]
9896
return random.choices(choices, weights=weights, k=1)[0]
9997
else:
@@ -105,4 +103,4 @@ def sample_label(old_label: str) -> str:
105103
if mask_arr[i]:
106104
data_arr[i] = sample_label(data_arr[i])
107105

108-
return nw.new_series(col_name, data_arr.tolist(), backend=nw.get_native_namespace(data))
106+
return new_series_like(data, column, data_arr)

tab_err/error_type/_error_type.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
from abc import ABC, abstractmethod
44
from typing import TYPE_CHECKING, Any
55

6-
import narwhals as nw
7-
86
from tab_err._utils import seed_randomness_and_get_generator
97

108
from ._config import ErrorTypeConfig
119

1210
if TYPE_CHECKING:
11+
import narwhals as nw
1312
import numpy as np
1413

1514

tab_err/error_type/_extraneous.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
import string
44
import warnings
5+
from typing import TYPE_CHECKING
56

6-
import narwhals as nw
7-
8-
from tab_err._utils import get_column, get_column_str, select_string_columns
7+
from tab_err._utils import get_column, new_series_like, select_string_columns
98

109
from ._error_type import ErrorType
1110

11+
if TYPE_CHECKING:
12+
import narwhals as nw
1213

1314
class Extraneous(ErrorType):
1415
"""Adds Extraneous strings around the values in a column."""
@@ -45,7 +46,6 @@ def _apply(self: Extraneous, data: nw.DataFrame, error_mask: nw.DataFrame, colum
4546
Returns:
4647
nw.Series: The data column, 'column', after Extraneous errors at the locations specified by 'error_mask' are introduced.
4748
"""
48-
col_name = get_column_str(data, column)
4949
series = get_column(data, column)
5050
series_mask = get_column(error_mask, column)
5151

@@ -69,4 +69,4 @@ def _apply(self: Extraneous, data: nw.DataFrame, error_mask: nw.DataFrame, colum
6969
val = data_arr[i]
7070
data_arr[i] = self.config.extraneous_value_template.format(value=val)
7171

72-
return nw.new_series(col_name, data_arr.tolist(), backend=nw.get_native_namespace(data))
72+
return new_series_like(data, column, data_arr)

tab_err/error_type/_missing.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3-
import narwhals as nw
4-
import numpy as np
3+
from typing import TYPE_CHECKING, Union, cast
54

6-
from tab_err._utils import get_column, get_column_str, is_string_dtype, select_string_columns
5+
from tab_err._utils import get_column, is_string_dtype, new_series_like, select_string_columns
76

87
from ._error_type import ErrorType
98

9+
if TYPE_CHECKING:
10+
import narwhals as nw
1011

1112
class MissingValue(ErrorType):
1213
"""Insert missing values into a column.
@@ -22,7 +23,9 @@ def _check_type(data: nw.DataFrame, column: int | str) -> None:
2223

2324
def _get_valid_columns(self: MissingValue, data: nw.DataFrame) -> list[str | int]:
2425
"""If the config missing value is None, returns all columns. Otherwise, only the columns with string type."""
25-
return data.columns if self.config.missing_value is None else select_string_columns(data)
26+
if self.config.missing_value is None:
27+
return cast("list[Union[str, int]]", list(data.columns))
28+
return select_string_columns(data)
2629

2730
def _apply(self: MissingValue, data: nw.DataFrame, error_mask: nw.DataFrame, column: int | str) -> nw.Series:
2831
"""Applies the MissingValue ErrorType to a column of data.
@@ -35,7 +38,6 @@ def _apply(self: MissingValue, data: nw.DataFrame, error_mask: nw.DataFrame, col
3538
Returns:
3639
nw.Series: The data column, 'column', after MissingValue errors at the locations specified by 'error_mask' are introduced.
3740
"""
38-
col_name = get_column_str(data, column)
3941
series = get_column(data, column)
4042
series_mask = get_column(error_mask, column)
4143

@@ -57,4 +59,4 @@ def _apply(self: MissingValue, data: nw.DataFrame, error_mask: nw.DataFrame, col
5759
data_arr[mask_arr] = missing_val
5860

5961
# Create new series with the modified data
60-
return nw.new_series(col_name, data_arr.tolist(), backend=nw.get_native_namespace(data))
62+
return new_series_like(data, column, data_arr)

tab_err/error_type/_mistype.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import narwhals as nw
4-
import numpy as np
54

65
from tab_err._utils import get_column, get_column_str
76

@@ -60,9 +59,7 @@ def _apply(self: Mistype, data: nw.DataFrame, error_mask: nw.DataFrame, column:
6059
target_dtype = "object"
6160
elif current_dtype == nw.Int64:
6261
target_dtype = "float64"
63-
elif current_dtype == nw.Float64:
64-
target_dtype = "int64"
65-
elif current_dtype == nw.Boolean:
62+
elif current_dtype in {nw.Float64, nw.Boolean}:
6663
target_dtype = "int64"
6764
elif current_dtype.is_integer():
6865
target_dtype = "float64"

tab_err/error_type/_mojibake.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

33
import random
4+
from typing import TYPE_CHECKING
45

5-
import narwhals as nw
6-
7-
from tab_err._utils import get_column, get_column_str, is_string_dtype, select_string_columns
6+
from tab_err._utils import get_column, is_string_dtype, new_series_like, select_string_columns
87

98
from ._error_type import ErrorType
109

10+
if TYPE_CHECKING:
11+
import narwhals as nw
1112

1213
class Mojibake(ErrorType):
1314
"""Inserts mojibake into a column containing strings."""
@@ -53,7 +54,6 @@ def _apply(self: Mojibake, data: nw.DataFrame, error_mask: nw.DataFrame, column:
5354
"iso-8859-2": top10 - {"iso-8859-2", "windows-1250", "iso-8859-1", "windows-1252"},
5455
}
5556

56-
col_name = get_column_str(data, column)
5757
series = get_column(data, column)
5858
series_mask = get_column(error_mask, column)
5959

@@ -77,4 +77,4 @@ def _apply(self: Mojibake, data: nw.DataFrame, error_mask: nw.DataFrame, column:
7777
if val is not None and isinstance(val, str):
7878
data_arr[i] = val.encode(encoding_sender, errors="ignore").decode(encoding_receiver, errors="ignore")
7979

80-
return nw.new_series(col_name, data_arr.tolist(), backend=nw.get_native_namespace(data))
80+
return new_series_like(data, column, data_arr)

0 commit comments

Comments
 (0)