1313
1414import numpy as np
1515import pandas as pd
16- import sklearn
16+ from pandas . core . dtypes . common import is_numeric_dtype , is_string_dtype
1717from scipy import sparse as sp
1818from sklearn .utils import check_array
1919from sklearn .utils .validation import column_or_1d
3333# pylint --disable=all --enable=invalid-names dataset.py
3434# pylint: disable=invalid-name
3535
36+ # Set a special pandas option to force the new string data type (`StringDType`)
37+ # even for version 2.0 which is still required for python 3.10.
38+ # This new string data type does not map any longer to the corresponding numpy one
39+ # and will break the code unless a special care is taken
40+ pd .options .future .infer_string = True
41+
3642
3743def check_dataset_spec (ds_spec ):
3844 """Checks that a dataset spec is valid
@@ -393,16 +399,19 @@ def write_internal_data_table(dataframe, file_path_or_stream):
393399
394400
395401def _column_or_1d_with_dtype (y , dtype = None ):
396- # 'dtype' has been introduced on `column_or_1d' since Scikit-learn 1.2;
397- if sklearn .__version__ < "1.2" :
398- if pd .api .types .is_string_dtype (dtype ) and y .isin (["True" , "False" ]).all ():
399- warnings .warn (
400- "'y' stores strings restricted to 'True'/'False' values: "
401- "The predict method may return a bool vector."
402- )
403- return column_or_1d (y , warn = True )
404- else :
405- return column_or_1d (y , warn = True , dtype = dtype )
402+ """Checks the data is of the provided `dtype`.
403+ If a problem is detected a warning is printed or an error raised,
404+ otherwise the pandas object is transformed into a numpy.array
405+ """
406+
407+ # Since pandas 3.0 (and even in 2.0 if the option is activated)
408+ # a new StringDType is used to handle strings.
409+ # It does not match any longer the one recognized by numpy.
410+ # We need to force the translation to the numpy dtype
411+ # whenever a pandas string is detected (`is_string_dtype` returns `True`).
412+ if is_string_dtype (dtype ):
413+ dtype = np .dtype (str )
414+ return column_or_1d (y , warn = True , dtype = dtype )
406415
407416
408417class Dataset :
@@ -965,21 +974,23 @@ def __init__(self, name, dataframe, key=None):
965974
966975 # Initialize feature columns and verify their types
967976 self .column_ids = self .data_source .columns .values
968- if not np .issubdtype (self .column_ids .dtype , np .integer ):
969- if np .issubdtype (self .column_ids .dtype , object ):
970- for i , column_id in enumerate (self .column_ids ):
971- if not isinstance (column_id , str ):
972- raise TypeError (
973- f"Dataframe column ids must be either all integers or "
974- f"all strings. Column id at index { i } ('{ column_id } ') is"
975- f" of type '{ type (column_id ).__name__ } '"
976- )
977- else :
978- raise TypeError (
979- f"Dataframe column ids must be either all integers or "
980- f"all strings. The column index has dtype "
981- f"'{ self .column_ids .dtype } '"
982- )
977+ # Ensure the feature columns are either all string
978+ # or all numeric but not a mix of both.
979+ # Warning : the new pandas string data type (`StringDType`)
980+ # - by default in pandas 3.0 or forced in pandas 2.0 -
981+ # cannot be evaluated by `np.issubdtype`, any attempt will raise an error.
982+ if not is_numeric_dtype (self .column_ids ) and not is_string_dtype (
983+ self .column_ids
984+ ):
985+ previous_type = None
986+ for i , column_id in enumerate (self .column_ids ):
987+ if previous_type is not None and type (column_id ) != previous_type :
988+ raise TypeError (
989+ f"Dataframe column ids must be either all integers or "
990+ f"all strings. Column id at index { i } ('{ column_id } ') is"
991+ f" of type '{ type (column_id ).__name__ } '"
992+ )
993+ previous_type = type (column_id )
983994
984995 # Initialize Khiops types
985996 self .khiops_types = {}
@@ -988,7 +999,8 @@ def __init__(self, name, dataframe, key=None):
988999 column_numpy_type = column .dtype
9891000 column_max_size = None
9901001 if isinstance (column_numpy_type , pd .StringDtype ):
991- column_max_size = column .str .len ().max ()
1002+ # Warning pandas.Series.str.len() returns a float64
1003+ column_max_size = int (column .str .len ().max ())
9921004 self .khiops_types [column_id ] = get_khiops_type (
9931005 column_numpy_type , column_max_size
9941006 )
@@ -1161,7 +1173,7 @@ def __init__(self, name, matrix, key=None):
11611173 raise TypeError (
11621174 type_error_message ("matrix" , matrix , "scipy.sparse.spmatrix" )
11631175 )
1164- if not np . issubdtype (matrix .dtype , np . number ):
1176+ if not is_numeric_dtype (matrix .dtype ):
11651177 raise TypeError (
11661178 type_error_message ("'matrix' dtype" , matrix .dtype , "numeric" )
11671179 )
0 commit comments