Skip to content

Commit 3baff2c

Browse files
committed
Add conditional code for compatibility with sklearn < 1.2
1 parent a404e40 commit 3baff2c

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

khiops/sklearn/dataset.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
import sklearn
1718
from scipy import sparse as sp
1819
from sklearn.utils import check_array
1920
from sklearn.utils.validation import column_or_1d
@@ -430,6 +431,16 @@ def write_internal_data_table(dataframe, file_path_or_stream):
430431
)
431432

432433

434+
def _column_or_1d_with_dtype(y, dtype=None):
    """Call scikit-learn's `column_or_1d`, forwarding `dtype` only if supported.

    The `dtype` parameter of `column_or_1d` was introduced in scikit-learn
    1.2. With older versions the function is called without it; in that case
    a warning is emitted when `dtype` is string-like and every value of `y`
    is one of the strings 'True' or 'False'.

    Parameters
    ----------
    y : array-like
        Target values to validate and flatten.
    dtype : data type, optional
        Data type forwarded to `column_or_1d` when the installed
        scikit-learn accepts it.

    Returns
    -------
    The value returned by `column_or_1d` for `y`.
    """
    # Local imports: only stdlib modules, kept here so that the helper does
    # not depend on imports located elsewhere in the module.
    import re
    import warnings

    # Compare the version numerically. A plain string comparison is
    # lexicographic, so e.g. "1.10" < "1.2" would wrongly be true and the
    # 'dtype' argument would be dropped on recent releases.
    version_match = re.match(r"(\d+)\.(\d+)", sklearn.__version__)
    if version_match is None:
        # Unrecognized version string: assume a recent scikit-learn.
        supports_dtype = True
    else:
        major, minor = (int(part) for part in version_match.groups())
        supports_dtype = (major, minor) >= (1, 2)

    if supports_dtype:
        return column_or_1d(y, warn=True, dtype=dtype)

    # Legacy path (scikit-learn < 1.2): 'dtype' cannot be forwarded.
    # `np.isin` on `np.asarray(y)` is used instead of `y.isin` so that the
    # check also works for NumPy arrays and single/multi-column dataframes,
    # and always reduces to a single boolean.
    if pd.api.types.is_string_dtype(dtype) and bool(
        np.isin(np.asarray(y), ["True", "False"]).all()
    ):
        warnings.warn("'y' contains strings of 'True' and 'False'")
    return column_or_1d(y, warn=True)
442+
443+
433444
class Dataset:
434445
"""A representation of a dataset
435446
@@ -740,20 +751,20 @@ def _init_target_column(self, y):
740751
else:
741752
if hasattr(y, "dtype"):
742753
if isinstance(y.dtype, pd.CategoricalDtype):
743-
y_checked = column_or_1d(
744-
y, warn=True, dtype=y.dtype.categories.dtype
754+
y_checked = _column_or_1d_with_dtype(
755+
y, dtype=y.dtype.categories.dtype
745756
)
746757
else:
747-
y_checked = column_or_1d(y, warn=True, dtype=y.dtype)
758+
y_checked = _column_or_1d_with_dtype(y, dtype=y.dtype)
748759
elif hasattr(y, "dtypes"):
749760
if isinstance(y.dtypes[0], pd.CategoricalDtype):
750-
y_checked = column_or_1d(
751-
y, warn=True, dtype=y.dtypes[0].categories.dtype
761+
y_checked = _column_or_1d_with_dtype(
762+
y, dtype=y.dtypes[0].categories.dtype
752763
)
753764
else:
754-
y_checked = column_or_1d(y, warn=True)
765+
y_checked = _column_or_1d_with_dtype(y)
755766
else:
756-
y_checked = column_or_1d(y, warn=True)
767+
y_checked = _column_or_1d_with_dtype(y)
757768
# Check the target type coherence with those of X's tables
758769
if isinstance(
759770
self.main_table, (PandasTable, SparseTable, NumpyTable)

0 commit comments

Comments
 (0)