Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ name = "tabpfn-extensions"
version = "0.2.2"
dependencies = [
"torch>=2.1,<3",
"pandas>=1.4.0,<3",
"scikit-learn>=1.6.0,<1.7",
"scipy>=1.11.1,<2",
"pandas>=1.4.0",
"scikit-learn>=1.6.0",
"scipy>=1.11.1",
"tabpfn>=6.0.5,<8",
"tabpfn-common-utils[telemetry-interactive]>=0.2.0",
]
Expand Down
15 changes: 13 additions & 2 deletions src/tabpfn_extensions/misc/sklearn_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _raise_for_params(params, owner, method):
f" details. Extra parameters passed are: {set(params)}",
)

def _is_pandas_df(X):
def is_pandas_df(X):
"""Return True if the X is a pandas dataframe."""
try:
pd = sys.modules["pandas"]
Expand All @@ -260,7 +260,6 @@ def _is_pandas_df(X):
)
from sklearn.utils.validation import (
_is_fitted, # noqa: F401
_is_pandas_df, # noqa: F401
)


Expand Down Expand Up @@ -861,3 +860,15 @@ def parametrize_with_checks(
check_X_y, # noqa: F401
validate_data, # noqa: F401
)


########################################################################################
# Upgrading for scikit-learn 1.8
########################################################################################


if sklearn_version < parse_version("1.8"):
if sklearn_version >= parse_version("1.4"):
from sklearn.utils.validation import _is_pandas_df as is_pandas_df
else:
from sklearn.utils._dataframe import is_pandas_df # noqa: F401
Comment on lines +870 to +874
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The nested if structure can be simplified to a more readable if/elif block. Additionally, the import on line 872 is missing a # noqa: F401 comment, which is required by the project's linting rules (Ruff) for unused imports intended for export, as seen elsewhere in this file.

Suggested change
if sklearn_version < parse_version("1.8"):
if sklearn_version >= parse_version("1.4"):
from sklearn.utils.validation import _is_pandas_df as is_pandas_df
else:
from sklearn.utils._dataframe import is_pandas_df # noqa: F401
if sklearn_version >= parse_version("1.8"):
from sklearn.utils._dataframe import is_pandas_df # noqa: F401
elif sklearn_version >= parse_version("1.4"):
from sklearn.utils.validation import _is_pandas_df as is_pandas_df # noqa: F401

Loading
Loading