|
14 | 14 |
|
15 | 15 | import numpy as np |
16 | 16 | import pandas as pd |
| 17 | +import sklearn |
17 | 18 | from scipy import sparse as sp |
18 | 19 | from sklearn.utils import check_array |
19 | 20 | from sklearn.utils.validation import column_or_1d |
@@ -430,6 +431,16 @@ def write_internal_data_table(dataframe, file_path_or_stream): |
430 | 431 | ) |
431 | 432 |
|
432 | 433 |
|
| 434 | +def _column_or_1d_with_dtype(y, dtype=None): |
| 435 | + # 'dtype' has been introduced on `column_or_1d' since Scikit-learn 1.2; |
| 436 | + if sklearn.__version__ < "1.2": |
| 437 | + if pd.api.types.is_string_dtype(dtype) and y.isin(["True", "False"]).all(): |
| 438 | + warnings.warn("'y' contains strings of 'True' and 'False'") |
| 439 | + return column_or_1d(y, warn=True) |
| 440 | + else: |
| 441 | + return column_or_1d(y, warn=True, dtype=dtype) |
| 442 | + |
| 443 | + |
433 | 444 | class Dataset: |
434 | 445 | """A representation of a dataset |
435 | 446 |
|
@@ -740,20 +751,20 @@ def _init_target_column(self, y): |
740 | 751 | else: |
741 | 752 | if hasattr(y, "dtype"): |
742 | 753 | if isinstance(y.dtype, pd.CategoricalDtype): |
743 | | - y_checked = column_or_1d( |
744 | | - y, warn=True, dtype=y.dtype.categories.dtype |
| 754 | + y_checked = _column_or_1d_with_dtype( |
| 755 | + y, dtype=y.dtype.categories.dtype |
745 | 756 | ) |
746 | 757 | else: |
747 | | - y_checked = column_or_1d(y, warn=True, dtype=y.dtype) |
| 758 | + y_checked = _column_or_1d_with_dtype(y, dtype=y.dtype) |
748 | 759 | elif hasattr(y, "dtypes"): |
749 | 760 | if isinstance(y.dtypes[0], pd.CategoricalDtype): |
750 | | - y_checked = column_or_1d( |
751 | | - y, warn=True, dtype=y.dtypes[0].categories.dtype |
| 761 | + y_checked = _column_or_1d_with_dtype( |
| 762 | + y, dtype=y.dtypes[0].categories.dtype |
752 | 763 | ) |
753 | 764 | else: |
754 | | - y_checked = column_or_1d(y, warn=True) |
| 765 | + y_checked = _column_or_1d_with_dtype(y) |
755 | 766 | else: |
756 | | - y_checked = column_or_1d(y, warn=True) |
| 767 | + y_checked = _column_or_1d_with_dtype(y) |
757 | 768 | # Check the target type coherence with those of X's tables |
758 | 769 | if isinstance( |
759 | 770 | self.main_table, (PandasTable, SparseTable, NumpyTable) |
|
0 commit comments