diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 812f66b882956..d99af64ce90c0 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -765,6 +765,18 @@ def axes(self) -> List: """ return [self.index, self.columns] + @property + def _constructor(self) -> Any: + # Return the class of the current instance to support subclassing + return self.__class__ + + @property + def _constructor_sliced(self) -> Any: + # Return the Series class for operations that reduce dimensionality + from pyspark.pandas.series import Series + + return Series + @with_ansi_mode_context def _reduce_for_stat_function( self, @@ -913,7 +925,7 @@ def _apply_series_op( internal = self._internal.with_new_columns(applied) if should_resolve: internal = internal.resolved_copy - return DataFrame(internal) + return self._constructor(internal) # Arithmetic Operators @with_ansi_mode_context @@ -970,7 +982,7 @@ def apply_op( ) column_labels.append(label) internal = self._internal.with_new_columns(applied, column_labels=column_labels) - return DataFrame(internal) + return self._constructor(internal) else: return self._apply_series_op(lambda psser: getattr(psser, op)(other)) @@ -1748,7 +1760,7 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D sdf = sdf.select(*index_col_names, *numeric_col_names) - return DataFrame( + return self._constructor( InternalFrame( spark_frame=sdf, index_spark_columns=[ @@ -2925,7 +2937,7 @@ def transpose(self) -> "DataFrame": max_compute_count ) ) - return DataFrame(pdf.transpose()) + return self._constructor(pdf.transpose()) # Explode the data to be pairs. # @@ -3008,7 +3020,7 @@ def transpose(self) -> "DataFrame": column_label_names=self._internal.index_names, ) - return DataFrame(internal) + return self._constructor(internal) T = property(transpose) @@ -3500,7 +3512,7 @@ def transform( ) internal = self._internal.with_new_columns(applied, data_fields=data_fields) - return DataFrame(internal) + return self._constructor(internal) else: return self._apply_series_op( lambda psser: psser.pandas_on_spark.transform_batch(func, *args, **kwargs) @@ -3711,7 +3723,7 @@ class locomotion index_names=index_names, index_fields=index_fields, ).resolved_copy - return DataFrame(internal) + return self._constructor(internal) def between_time( self, @@ -3815,7 +3827,7 @@ def pandas_between_time( # type: ignore[no-untyped-def] with option_context("compute.default_index_type", "distributed"): psdf = psdf.pandas_on_spark.apply_batch(pandas_between_time) - return DataFrame( + return self._constructor( self._internal.copy( spark_frame=psdf._internal.spark_frame, index_spark_columns=psdf._internal.data_spark_columns[:1], @@ -3894,7 +3906,7 @@ def pandas_at_time( # type: ignore[no-untyped-def] with option_context("compute.default_index_type", "distributed"): psdf = psdf.pandas_on_spark.apply_batch(pandas_at_time) - return DataFrame( + return self._constructor( self._internal.copy( spark_frame=psdf._internal.spark_frame, index_spark_columns=psdf._internal.data_spark_columns[:1], @@ -4098,7 +4110,7 @@ def where( .alias(psdf._internal.spark_column_name_for(label)) for label in self._internal.column_labels ] - return DataFrame( + return self._constructor( psdf._internal.with_new_columns( data_spark_columns, column_labels=self._internal.column_labels, # TODO: dtypes? @@ -4358,7 +4370,7 @@ def set_index( self._update_internal_frame(internal) return None else: - return DataFrame(internal) + return self._constructor(internal) def reset_index( self, @@ -4636,7 +4648,7 @@ def rename(index: int) -> Label: self._update_internal_frame(internal) return None else: - return DataFrame(internal) + return self._constructor(internal) def isnull(self) -> "DataFrame": """ @@ -5778,7 +5790,7 @@ def _assign(self, kwargs: Any) -> "DataFrame": column_labels=column_labels, data_fields=data_fields, ) - return DataFrame(internal) + return self._constructor(internal) @staticmethod def from_records( @@ -5957,7 +5969,7 @@ def copy(self, deep: bool = True) -> "DataFrame": 0 1 3 5 7 1 2 4 6 8 """ - return DataFrame(self._internal) + return self._constructor(self._internal) def dropna( self, @@ -6105,7 +6117,7 @@ def dropna( self._update_internal_frame(internal) return None else: - return DataFrame(internal) + return self._constructor(internal) else: assert axis == 1 @@ -6649,12 +6661,12 @@ def head(self, n: int = 5) -> "DataFrame": if n < 0: n = len(self) + n if n <= 0: - return DataFrame(self._internal.with_filter(F.lit(False))) + return self._constructor(self._internal.with_filter(F.lit(False))) else: sdf = self._internal.resolved_copy.spark_frame if get_option("compute.ordered_head"): sdf = sdf.orderBy(NATURAL_ORDER_COLUMN_NAME) - return DataFrame(self._internal.with_new_sdf(sdf.limit(n))) + return self._constructor(self._internal.with_new_sdf(sdf.limit(n))) def last(self, offset: Union[str, DateOffset]) -> "DataFrame": """ @@ -7228,7 +7240,7 @@ def pivot( return df else: internal = df._internal.copy(index_names=self._internal.index_names) - return DataFrame(internal) + return self._constructor(internal) @property def columns(self) -> pd.Index: @@ -7492,7 +7504,7 @@ def select_dtypes( if should_include: column_labels.append(label) - return DataFrame( + return self._constructor( self._internal.with_new_columns([self._psser_for(label) for label in column_labels]) ) @@ -7603,7 +7615,7 @@ def droplevel( index_names=list(index_names), index_fields=list(index_fields), ) - return DataFrame(internal) + return self._constructor(internal) else: psdf = self.copy() psdf.columns = psdf.columns.droplevel(level) @@ -7780,7 +7792,7 @@ def drop( raise KeyError(missing) if len(drop_column_labels) == 0: - return DataFrame(internal) + return self._constructor(internal) keep_columns_and_labels = [ (column, label) @@ -7798,7 +7810,7 @@ def drop( internal = internal.with_new_columns( [self._psser_for(label) for label in labels] ) - return DataFrame(internal) + return self._constructor(internal) def _prepare_sort_by_scols(self, by: Union[Name, List[Name]]) -> List[PySparkColumn]: if is_name_like_value(by): @@ -7851,7 +7863,7 @@ def _sort( elif keep != "first": raise ValueError('keep must be either "first", "last" or "all".') sdf = self._internal.resolved_copy.spark_frame.sort(*by, natural_order_scol) - return DataFrame(self._internal.with_new_sdf(sdf)) + return self._constructor(self._internal.with_new_sdf(sdf)) def sort_values( self, @@ -8190,7 +8202,7 @@ def swaplevel( assert axis == 1 internal = self._swaplevel_columns(i, j) - return DataFrame(internal) + return self._constructor(internal) def swapaxes(self, i: Axis, j: Axis, copy: bool = True) -> "DataFrame": """ @@ -8644,7 +8656,7 @@ def isin(self, values: Union[List, Dict]) -> "DataFrame": else: raise TypeError("Values should be iterable, Series, DataFrame or dict.") - return DataFrame( + return self._constructor( self._internal.with_new_columns( data_spark_columns, data_fields=[ @@ -9017,7 +9029,7 @@ def right_scol_for(label: Label) -> PySparkColumn: column_labels=column_labels, data_spark_columns=[scol_for(selected_columns, col) for col in data_columns], ) - return DataFrame(internal) + return self._constructor(internal) def join( self, @@ -9231,7 +9243,7 @@ def combine_first(self, other: "DataFrame") -> "DataFrame": data_fields=None, # TODO: dtype? column_label_names=self._internal.column_label_names, ) - return DataFrame(internal) + return self._constructor(internal) def update( self, @@ -9522,11 +9534,11 @@ def cov(self, min_periods: Optional[int] = None, ddof: int = 1) -> "DataFrame": cov = np.zeros([num_cols, num_cols]) if num_cols == 0: - return DataFrame() + return self._constructor() if len(psdf) < min_periods: cov.fill(np.nan) - return DataFrame(cov, columns=psdf.columns, index=psdf.columns) + return self._constructor(cov, columns=psdf.columns, index=psdf.columns) data_cols = psdf._internal.data_spark_column_names cov_scols = [] @@ -9614,7 +9626,7 @@ def cov(self, min_periods: Optional[int] = None, ddof: int = 1) -> "DataFrame": # b None cov(b, b) cov(b, c) # c cov(a, c) cov(b, c) cov(c, c) cov = cov + cov.T - np.diag(np.diag(cov)) - return DataFrame(cov, columns=psdf.columns, index=psdf.columns) + return self._constructor(cov, columns=psdf.columns, index=psdf.columns) def sample( self, @@ -9717,9 +9729,9 @@ def sample( withReplacement=replace, fraction=frac, seed=random_state ) if ignore_index: - return DataFrame(sdf.drop(*self._internal.index_spark_column_names)) + return self._constructor(sdf.drop(*self._internal.index_spark_column_names)) else: - return DataFrame(self._internal.with_new_sdf(sdf)) + return self._constructor(self._internal.with_new_sdf(sdf)) def astype(self, dtype: Union[str, Dtype, Dict[Name, Union[str, Dtype]]]) -> "DataFrame": """ @@ -9791,7 +9803,7 @@ def astype(self, dtype: Union[str, Dtype, Dict[Name, Union[str, Dtype]]]) -> "Da else: for col_name, col in self.items(): applied.append(col.astype(dtype=cast(Union[str, Dtype], dtype))) - return DataFrame(self._internal.with_new_columns(applied)) + return self._constructor(self._internal.with_new_columns(applied)) def add_prefix(self, prefix: str) -> "DataFrame": """ @@ -10081,7 +10093,7 @@ def describe(self, percentiles: Optional[List[float]] = None) -> "DataFrame": data = dict() for psser in psser_string: data[psser.name] = [0, 0, np.nan, np.nan] - return DataFrame(data, index=["count", "unique", "top", "freq"]) + return self._constructor(data, index=["count", "unique", "top", "freq"]) if len(exprs_string) == 1: # Fast path for single column (e.g. Series.describe): avoid unpivot overhead. @@ -10517,7 +10529,7 @@ def _reindex_index( if isinstance(index, ps.Index): if nlevels != index.nlevels: - return DataFrame(index._internal.with_new_columns([])).reindex( + return self._constructor(index._internal.with_new_columns([])).reindex( columns=self.columns, fill_value=fill_value ) @@ -10596,7 +10608,7 @@ def _reindex_index( ], data_fields=data_fields, ) - return DataFrame(internal) + return self._constructor(internal) def _reindex_columns( self, columns: Optional[Union[pd.Index, Sequence[Any]]], fill_value: Optional[Any] @@ -10635,7 +10647,7 @@ def _reindex_columns( else: internal = self._internal.with_new_columns(scols_or_pssers, column_labels=labels) - return DataFrame(internal) + return self._constructor(internal) def reindex_like(self, other: "DataFrame", copy: bool = True) -> "DataFrame": """ @@ -10946,7 +10958,7 @@ def melt( ) exploded_df = sdf.withColumn("pairs", pairs).select(columns) - return DataFrame( + return self._constructor( InternalFrame( spark_frame=exploded_df, index_spark_columns=None, @@ -11067,7 +11079,7 @@ def stack(self) -> DataFrameOrSeries: from pyspark.pandas.series import first_series if len(self._internal.column_labels) == 0: - return DataFrame( + return self._constructor( self._internal.copy( column_label_names=self._internal.column_label_names[:-1] ).with_filter(F.lit(False)) @@ -11243,7 +11255,7 @@ def unstack(self) -> DataFrameOrSeries: ] ), ) - return DataFrame(internal) + return self._constructor(internal) # TODO: Codes here are similar with melt. Should we deduplicate? column_labels = self._internal.column_labels @@ -11765,7 +11777,7 @@ def rank( pdf = psdf.head(limit + 1)._to_internal_pandas() if len(pdf) <= limit: pdf_rank = pdf.rank(method=method, ascending=ascending, axis=1, numeric_only=False) - return DataFrame(InternalFrame.from_pandas(pdf_rank)) + return self._constructor(InternalFrame.from_pandas(pdf_rank)) column_label_strings = [ name_like_string(label) for label in psdf._internal.column_labels @@ -11797,7 +11809,7 @@ def rank_axis_1(*cols: pd.Series) -> pd.DataFrame: for label in psdf._internal.column_labels ] internal = psdf._internal.with_new_columns(new_data_columns, data_fields=data_fields) - return DataFrame(internal) + return self._constructor(internal) def filter( self, @@ -11948,7 +11960,7 @@ def filter( col = index_scol.contains(like) else: col = col | index_scol.contains(like) - return DataFrame(self._internal.with_filter(col)) + return self._constructor(self._internal.with_filter(col)) else: column_labels = self._internal.column_labels output_labels = [label for label in column_labels if any(like in i for i in label)] @@ -11961,7 +11973,7 @@ def filter( col = index_scol.rlike(regex) else: col = col | index_scol.rlike(regex) - return DataFrame(self._internal.with_filter(col)) + return self._constructor(self._internal.with_filter(col)) else: column_labels = self._internal.column_labels matcher = re.compile(regex) @@ -12382,7 +12394,7 @@ def gen_names( self._update_internal_frame(internal) return None else: - return DataFrame(internal) + return self._constructor(internal) def keys(self) -> pd.Index: """ @@ -12989,7 +13001,7 @@ def quantile(psser: "Series") -> PySparkColumn: column_labels.append(label) if len(percentile_cols) == 0: - return DataFrame(index=qq) + return self._constructor(index=qq) sdf = self._internal.spark_frame.select(percentile_cols) # Here, after select percentile cols, a spark_frame looks like below: @@ -13028,7 +13040,7 @@ def quantile(psser: "Series") -> PySparkColumn: column_labels=column_labels, data_spark_columns=[scol_for(sdf, col) for col in percentile_col_names], ) - return DataFrame(internal) + return self._constructor(internal) else: return self._reduce_for_stat_function( quantile, name="quantile", numeric_only=numeric_only @@ -13136,7 +13148,7 @@ def query(self, expr: str, inplace: bool = False) -> Optional["DataFrame"]: self._update_internal_frame(internal) return None else: - return DataFrame(internal) + return self._constructor(internal) def take(self, indices: List[int], axis: Axis = 0, **kwargs: Any) -> "DataFrame": """ @@ -13549,7 +13561,7 @@ def mode(self, axis: Axis = 0, numeric_only: bool = False, dropna: bool = True) column_labels=mode_labels, data_spark_columns=[scol_for(sdf, col) for col in mode_col_names], ) - return DataFrame(internal) + return self._constructor(internal) def tail(self, n: int = 5) -> "DataFrame": """ @@ -13632,7 +13644,7 @@ def tail(self, n: int = 5) -> "DataFrame": rows = sdf.tail(n) new_sdf = default_session().createDataFrame(rows, sdf.schema) - return DataFrame(self._internal.with_new_sdf(new_sdf)) + return self._constructor(self._internal.with_new_sdf(new_sdf)) def align( self, @@ -14296,7 +14308,7 @@ def apply_op( for label in this._internal.column_labels ] internal = this._internal.with_new_columns(applied) - return DataFrame(internal) + return self._constructor(internal) def __class_getitem__(cls, params: Any) -> object: # See https://github.com/python/typing/issues/193 diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 6496137824de2..913f3f44d2c6a 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -485,6 +485,16 @@ def __init__( # type: ignore[no-untyped-def] self._col_label = anchor._internal.column_labels[0] object.__setattr__(anchor, "_psseries", {self._column_label: self}) + @property + def _constructor(self) -> Any: + return self.__class__ + + @property + def _constructor_expanddim(self) -> Any: + from pyspark.pandas.frame import DataFrame + + return DataFrame + @property def _psdf(self) -> DataFrame: return self._anchor @@ -1585,7 +1595,7 @@ def to_frame(self, name: Optional[Name] = None) -> DataFrame: renamed = self.rename(DEFAULT_SERIES_NAME) else: renamed = self - return DataFrame(renamed._internal) + return self._constructor_expanddim(renamed._internal) to_dataframe = to_frame @@ -2301,7 +2311,7 @@ def _fillna( ) scol = F.when(cond, func(scol, True).over(window)).otherwise(scol) - return DataFrame( + return self._constructor_expanddim( self._psdf._internal.with_new_spark_column( self._column_label, scol.alias(name_like_string(self.name)), # TODO: dtype? @@ -2422,7 +2432,7 @@ def _interpolate( .otherwise(scol) ) - return DataFrame( + return self._constructor_expanddim( self._psdf._internal.with_new_spark_column(self._column_label, scol) # TODO: dtype? )._psser_for(self._column_label) @@ -2767,7 +2777,7 @@ def _drop( self._update_anchor(DataFrame(dropped_internal)) return None else: - return DataFrame(dropped_internal) + return self._constructor_expanddim(dropped_internal) elif columns is not None: return self._psdf else: @@ -2796,7 +2806,8 @@ def head(self, n: int = 5) -> "Series": 1 bee Name: animal, dtype: object """ - return first_series(self.to_frame().head(n)).rename(self.name) + res = first_series(self.to_frame().head(n)).rename(self.name) + return self._constructor(data=res) def last(self, offset: Union[str, DateOffset]) -> "Series": """ @@ -4058,7 +4069,7 @@ def transform( applied.append(self.apply(f, args=args, **kwargs).rename(f.__name__)) internal = self._internal.with_new_columns(applied) - return DataFrame(internal) + return self._constructor_expanddim(internal) else: return self.apply(func, args=args, **kwargs) @@ -6178,7 +6189,7 @@ def unstack(self, level: int = -1) -> DataFrame: for field in internal.data_fields ] ) - return DataFrame(internal) + return self._constructor_expanddim(internal) def item(self) -> Scalar: """ @@ -6811,7 +6822,7 @@ def compare( data_fields=[this_field, that_field], column_label_names=[None], ) - return DataFrame(internal) + return self._constructor_expanddim(internal) # TODO(SPARK-40553): 1, support array-like 'value'; 2, add parameter 'sorter' def searchsorted(self, value: Any, side: str = "left") -> int: diff --git a/python/pyspark/pandas/tests/test_extension.py b/python/pyspark/pandas/tests/test_extension.py index af34540815516..d6f7ad50d351c 100644 --- a/python/pyspark/pandas/tests/test_extension.py +++ b/python/pyspark/pandas/tests/test_extension.py @@ -138,6 +138,41 @@ def __init__(self, data): with self.assertRaises(AttributeError): ps.Series([1, 2], dtype=object).bad + def test_extension_properties(self): + # Define subclasses mimicking the original pandas subclass implementation + class SubclassedSeries(ps.Series): + @property + def _constructor(self): + return SubclassedSeries + + @property + def _constructor_expanddim(self): + return SubclassedDataFrame + + class SubclassedDataFrame(ps.DataFrame): + @property + def _constructor(self): + return SubclassedDataFrame + + @property + def _constructor_sliced(self): + return SubclassedSeries + + # Test DataFrame extension properties + sub_psdf = SubclassedDataFrame(self.psdf._internal) + result_df = sub_psdf.head(2) + + self.assertIsInstance(result_df, SubclassedDataFrame) + self.assertEqual(result_df.shape, (2, 2)) + + # Test Series extension properties + # Pass the PySpark Series directly instead of its _internal frame + sub_psser = SubclassedSeries(self.psdf["a"]) + result_ser = sub_psser.head(2) + + self.assertIsInstance(result_ser, SubclassedSeries) + self.assertEqual(len(result_ser), 2) + class ExtensionTests( ExtensionTestsMixin,