Skip to content

Commit 440da10

Browse files
committed
to_frame lib agnostic
1 parent ae0204c commit 440da10

2 files changed

Lines changed: 64 additions & 99 deletions

File tree

great_tables/_formats_vals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
43
from pathlib import Path
4+
from typing import TYPE_CHECKING, Any
55

66
from typing_extensions import TypeAlias
77

8-
from .gt import GT, _get_column_of_values
98
from ._gt_data import GTData
109
from ._tbl_data import SeriesLike, to_frame
10+
from .gt import GT, _get_column_of_values
1111

1212
if TYPE_CHECKING:
1313
from ._formats import DateStyle, TimeStyle

great_tables/_tbl_data.py

Lines changed: 62 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import re
4-
import warnings
4+
import sys
55
from functools import singledispatch
6-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Union
77

88
from typing_extensions import TypeAlias
99

@@ -761,7 +761,6 @@ def _(df: PyArrowTable, x: Any) -> bool:
761761
return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0]
762762

763763

764-
@singledispatch
765764
def validate_frame(df: DataFrameLike) -> DataFrameLike:
766765
"""Raises an error if a DataFrame is not supported by Great Tables.
767766
@@ -771,101 +770,67 @@ def validate_frame(df: DataFrameLike) -> DataFrameLike:
771770
raise NotImplementedError(f"Unsupported type: {type(df)}")
772771

773772

774-
@validate_frame.register
775-
def _(df: PdDataFrame) -> PdDataFrame:
776-
import pandas as pd
777-
778-
# case 1: multi-index columns ----
779-
if isinstance(df.columns, pd.MultiIndex):
780-
raise ValueError(
781-
"pandas DataFrames with MultiIndex columns are not supported."
782-
" Please use .columns.droplevel() to remove extra column levels,"
783-
" or combine the levels into a single name per column."
784-
)
785-
786-
# case 2: duplicate column names ----
787-
dupes = df.columns[df.columns.duplicated()]
788-
if len(dupes):
789-
raise ValueError(
790-
f"Column names must be unique. Detected duplicate columns:\n\n {list(dupes)}"
791-
)
792-
793-
non_str_cols = [(ii, el) for ii, el in enumerate(df.columns) if not isinstance(el, str)]
794-
795-
if non_str_cols:
796-
_col_msg = "\n".join(f" * Position {ii}: {col}" for ii, col in non_str_cols[:3])
797-
warnings.warn(
798-
"pandas DataFrame contains non-string column names. Coercing to strings. "
799-
"Here are the first few non-string columns:\n\n"
800-
f"{_col_msg}",
801-
category=UserWarning,
802-
)
803-
new_df = df.copy()
804-
new_df.columns = [str(el) for el in df.columns]
805-
return new_df
806-
807-
return df
808-
809-
810-
@validate_frame.register
811-
def _(df: PlDataFrame) -> PlDataFrame:
812-
return df
813-
814-
815-
@validate_frame.register
816-
def _(df: PyArrowTable) -> PyArrowTable:
817-
warnings.warn("PyArrow Table support is currently experimental.")
818-
819-
if len(set(df.column_names)) != len(df.column_names):
820-
raise ValueError("Column names must be unique.")
821-
822-
return df
823-
824-
825-
# to_frame ----
826-
827-
828-
@singledispatch
829-
def to_frame(ser: "list[Any] | SeriesLike", name: Optional[str] = None) -> DataFrameLike:
830-
# TODO: remove pandas. currently, we support converting a list to a pd.DataFrame
831-
# in order to support backwards compatibility in the vals.fmt_* functions.
832-
833-
try:
834-
import pandas as pd
835-
except ImportError:
836-
_raise_pandas_required(
837-
"Passing a plain list of values currently requires the library pandas. "
838-
"You can avoid this error by passing a polars Series."
839-
)
840-
841-
if not isinstance(ser, list):
842-
raise NotImplementedError(f"Unsupported type: {type(ser)}")
843-
844-
if not name:
845-
raise ValueError("name must be specified, when converting a list to a DataFrame.")
846-
847-
return pd.DataFrame({name: ser})
848-
849-
850-
@to_frame.register
851-
def _(ser: PdSeries, name: Optional[str] = None) -> PdDataFrame:
852-
return ser.to_frame(name)
853-
854-
855-
@to_frame.register
856-
def _(ser: PlSeries, name: Optional[str] = None) -> PlDataFrame:
857-
return ser.to_frame(name)
858-
859-
860-
@to_frame.register
861-
def _(ser: PyArrowArray, name: Optional[str] = None) -> PyArrowTable:
862-
import pyarrow as pa
773+
# @validate_frame.register
774+
# def _(df: PdDataFrame) -> PdDataFrame:
775+
# import pandas as pd
776+
777+
# # case 1: multi-index columns ----
778+
# if isinstance(df.columns, pd.MultiIndex):
779+
# raise ValueError(
780+
# "pandas DataFrames with MultiIndex columns are not supported."
781+
# " Please use .columns.droplevel() to remove extra column levels,"
782+
# " or combine the levels into a single name per column."
783+
# )
784+
785+
# # case 2: duplicate column names ----
786+
# dupes = df.columns[df.columns.duplicated()]
787+
# if len(dupes):
788+
# raise ValueError(
789+
# f"Column names must be unique. Detected duplicate columns:\n\n {list(dupes)}"
790+
# )
791+
792+
# non_str_cols = [(ii, el) for ii, el in enumerate(df.columns) if not isinstance(el, str)]
793+
794+
# if non_str_cols:
795+
# _col_msg = "\n".join(f" * Position {ii}: {col}" for ii, col in non_str_cols[:3])
796+
# warnings.warn(
797+
# "pandas DataFrame contains non-string column names. Coercing to strings. "
798+
# "Here are the first few non-string columns:\n\n"
799+
# f"{_col_msg}",
800+
# category=UserWarning,
801+
# )
802+
# new_df = df.copy()
803+
# new_df.columns = [str(el) for el in df.columns]
804+
# return new_df
805+
806+
# return df
807+
808+
809+
def to_frame(ser: list[Any], name: str) -> "DataFrameLike":
    """Wrap a list of values in a single-column frame from an available library.

    The candidate libraries are considered in a fixed priority order: pandas,
    polars, then pyarrow. For each candidate, a module that is already loaded
    (present in ``sys.modules``) is reused; otherwise an import is attempted.
    The first candidate that is available is used to build the frame.

    Parameters
    ----------
    ser
        The values that make up the single column.
    name
        The name given to the column.

    Returns
    -------
    DataFrameLike
        A pandas ``DataFrame``, a polars ``DataFrame``, or a pyarrow ``Table``
        with one column called ``name`` holding the values of ``ser``.

    Raises
    ------
    ImportError
        If none of pandas, polars, or pyarrow can be imported.
    """
    # TODO: ser can probably be more broad than a list

    # Function-scope import, so this block does not depend on a file-level
    # `import importlib`.
    import importlib

    frame_instantiating_libs: tuple[str, ...] = ("pandas", "polars", "pyarrow")

    for lib_name in frame_instantiating_libs:
        # Reuse an already-loaded module. A `None` entry in sys.modules marks
        # an import that has been explicitly blocked, so it is treated the
        # same as "not loaded" and falls through to import_module, which
        # raises ImportError for such entries.
        module = sys.modules.get(lib_name)
        if module is None:
            try:
                module = importlib.import_module(lib_name)
            except ImportError:
                continue
        break
    else:
        raise ImportError("None of pandas, polars, or pyarrow could be imported")

    # After a `break`, lib_name is necessarily one of the three candidates.
    # pyarrow builds tables through Table.from_arrays; pandas and polars share
    # the same dict-based DataFrame constructor.
    if lib_name == "pyarrow":
        return module.Table.from_arrays([ser], names=[name])

    return module.DataFrame({name: ser})

0 commit comments

Comments
 (0)