Skip to content

Commit cdec887

Browse files
committed
Update scikit-learn imports for version 1.8.
And remove upper bounds on dependencies, to match main tabpfn package.
1 parent 452708b commit cdec887

3 files changed

Lines changed: 1829 additions & 1106 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ name = "tabpfn-extensions"
77
version = "0.2.2"
88
dependencies = [
99
"torch>=2.1,<3",
10-
"pandas>=1.4.0,<3",
11-
"scikit-learn>=1.6.0,<1.7",
12-
"scipy>=1.11.1,<2",
10+
"pandas>=1.4.0",
11+
"scikit-learn>=1.6.0",
12+
"scipy>=1.11.1",
1313
"tabpfn>=6.0.5,<8",
1414
"tabpfn-common-utils[telemetry-interactive]>=0.2.0",
1515
]

src/tabpfn_extensions/misc/sklearn_compat.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def _raise_for_params(params, owner, method):
245245
f" details. Extra parameters passed are: {set(params)}",
246246
)
247247

248-
def _is_pandas_df(X):
248+
def is_pandas_df(X):
249249
"""Return True if the X is a pandas dataframe."""
250250
try:
251251
pd = sys.modules["pandas"]
@@ -260,7 +260,6 @@ def _is_pandas_df(X):
260260
)
261261
from sklearn.utils.validation import (
262262
_is_fitted, # noqa: F401
263-
_is_pandas_df, # noqa: F401
264263
)
265264

266265

@@ -861,3 +860,15 @@ def parametrize_with_checks(
861860
check_X_y, # noqa: F401
862861
validate_data, # noqa: F401
863862
)
863+
864+
865+
########################################################################################
866+
# Upgrading for scikit-learn 1.8
867+
########################################################################################
868+
869+
870+
if sklearn_version < parse_version("1.8"):
871+
if sklearn_version >= parse_version("1.4"):
872+
from sklearn.utils.validation import _is_pandas_df as is_pandas_df
873+
else:
874+
from sklearn.utils._dataframe import is_pandas_df # noqa: F401

0 commit comments

Comments
 (0)