Skip to content

Commit d70996b

Browse files
authored
Add new ColumnBase.create API (rapidsai#21187)
The new API avoids the design inconsistencies inherent in the current `from_pylibcudf`->`_with_type_metadata` API. The old approach is highly error-prone, and missing `_with_type_metadata` calls have been responsible for a large number of bugs over the lifetime of the project. With the various recently completed refactorings, a ColumnBase subclass is strictly defined by an underlying pylibcudf Column and a dtype, which broke the (incorrect, but convenient) assumption of a 1-1 correspondence between pylibcudf data types and cudf ColumnBase types. Since the same type of pylibcudf data can underlie multiple cudf ColumnBase subclasses, and since the dtype is the primary way to distinguish between them, requiring both to be provided on construction is more correct and more robust. I'll be opening subsequent PRs to roll this change out to more of the package; this PR just includes a few examples to clarify the new API's usage. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: rapidsai#21187
1 parent 34a8752 commit d70996b

12 files changed

Lines changed: 249 additions & 150 deletions

File tree

python/cudf/cudf/core/column/categorical.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,9 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
200200
)
201201
# We'll compare self's decategorized values later for non-CategoricalColumn
202202
else:
203-
codes = column.as_column(
203+
other = column.as_column(
204204
self._encode(other), length=len(self), dtype=self.codes.dtype
205-
)
206-
other = codes._with_type_metadata(self.dtype)
205+
)._with_type_metadata(self.dtype)
207206
equality_ops = {"__eq__", "__ne__", "NULL_EQUALS", "NULL_NOT_EQUALS"}
208207
if not self.ordered and op not in equality_ops:
209208
raise TypeError(

python/cudf/cudf/core/column/column.py

Lines changed: 165 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -431,13 +431,7 @@ def set_mask(self, mask: Buffer | None) -> Self:
431431
new_plc_column = self.plc_column.with_mask(new_mask, new_null_count)
432432
return cast(
433433
"Self",
434-
(
435-
type(self)
436-
.from_pylibcudf(
437-
new_plc_column,
438-
)
439-
._with_type_metadata(self.dtype)
440-
),
434+
ColumnBase.create(new_plc_column, self.dtype),
441435
)
442436

443437
@property
@@ -617,12 +611,129 @@ def _wrap_buffers(col: plc.Column) -> plc.Column:
617611
validate=False,
618612
)
619613

614+
@staticmethod
615+
def create(col: plc.Column, dtype: DtypeObj) -> ColumnBase:
616+
"""
617+
Create a Column from a pylibcudf.Column with an explicit cudf dtype.
618+
619+
This is the primary factory for ColumnBase construction. It always requires
620+
an explicit dtype to ensure type safety. If you need to infer the dtype from
621+
the pylibcudf Column, use dtype_from_pylibcudf_column() first:
622+
623+
dtype = dtype_from_pylibcudf_column(plc_col)
624+
col = ColumnBase.create(plc_col, dtype)
625+
"""
626+
# Wrap buffers recursively
627+
wrapped = ColumnBase._wrap_buffers(col)
628+
629+
# Dispatch to the appropriate subclass based on dtype
630+
target_cls = ColumnBase._dispatch_subclass_from_dtype(dtype)
631+
632+
# Validate dtype compatibility with the column structure using the
633+
# target subclass's _validate_args method (includes recursive validation)
634+
wrapped, dtype = target_cls._validate_args(wrapped, dtype)
635+
636+
# Construct the instance using the subclass's _from_preprocessed method
637+
# Skip validation since we already validated above
638+
return target_cls._from_preprocessed(
639+
plc_column=wrapped,
640+
dtype=dtype,
641+
validate=False,
642+
)
643+
644+
@staticmethod
645+
def _dispatch_subclass_from_dtype(dtype: DtypeObj) -> type[ColumnBase]:
646+
"""
647+
Dispatch to the appropriate ColumnBase subclass based on dtype.
648+
649+
This function determines which ColumnBase subclass should be used
650+
to construct a column with the given dtype.
651+
"""
652+
# Special pandas extension types
653+
if isinstance(dtype, pd.DatetimeTZDtype):
654+
return cudf.core.column.datetime.DatetimeTZColumn
655+
if isinstance(dtype, CategoricalDtype):
656+
return cudf.core.column.CategoricalColumn
657+
658+
# Temporal types (by kind)
659+
if dtype.kind == "M":
660+
return cudf.core.column.DatetimeColumn
661+
if dtype.kind == "m":
662+
return cudf.core.column.TimeDeltaColumn
663+
664+
# String types
665+
if (
666+
dtype == CUDF_STRING_DTYPE
667+
or (hasattr(dtype, "kind") and dtype.kind == "U")
668+
or isinstance(dtype, pd.StringDtype)
669+
or (isinstance(dtype, pd.ArrowDtype) and dtype.kind == "U")
670+
):
671+
return cudf.core.column.StringColumn
672+
673+
# cuDF custom types
674+
if isinstance(dtype, ListDtype):
675+
return cudf.core.column.ListColumn
676+
if isinstance(dtype, IntervalDtype):
677+
return cudf.core.column.IntervalColumn
678+
if isinstance(dtype, StructDtype):
679+
return cudf.core.column.StructColumn
680+
681+
# Decimal types
682+
if isinstance(dtype, cudf.Decimal128Dtype):
683+
return cudf.core.column.Decimal128Column
684+
if isinstance(dtype, cudf.Decimal64Dtype):
685+
return cudf.core.column.Decimal64Column
686+
if isinstance(dtype, cudf.Decimal32Dtype):
687+
return cudf.core.column.Decimal32Column
688+
689+
# Numerical types
690+
if dtype.kind in "iufb":
691+
return cudf.core.column.NumericalColumn
692+
693+
raise TypeError(f"Unrecognized dtype: {dtype}")
694+
695+
@staticmethod
696+
def _validate_dtype_recursively(col: plc.Column, dtype: DtypeObj) -> None:
697+
"""
698+
Validate dtype compatibility by dispatching to the appropriate ColumnBase
699+
subclass's _validate_args method.
700+
701+
This method is used for recursive validation in nested types (List, Struct,
702+
Interval). It dispatches to the correct ColumnBase subclass based on dtype
703+
and calls its _validate_args method, which may recursively call this method
704+
for nested children.
705+
706+
Parameters
707+
----------
708+
col : plc.Column
709+
The pylibcudf Column to validate.
710+
dtype : DtypeObj
711+
The cudf dtype to validate against.
712+
713+
Raises
714+
------
715+
ValueError
716+
If the dtype is incompatible with the Column.
717+
"""
718+
# Skip validation for empty columns (INT8 with all nulls). These are created
719+
# by _wrap_buffers() from EMPTY columns and may have inaccurate dtype metadata.
720+
# For example, an empty list [] has element_type=object but child is INT8.
721+
if (
722+
col.type().id() == plc.TypeId.INT8
723+
and col.null_count() == col.size()
724+
):
725+
return
726+
727+
# Dispatch to the appropriate subclass and use its _validate_args
728+
target_cls = ColumnBase._dispatch_subclass_from_dtype(dtype)
729+
target_cls._validate_args(col, dtype)
730+
620731
@staticmethod
621732
def from_pylibcudf(col: plc.Column) -> ColumnBase:
622733
"""Create a Column from a pylibcudf.Column.
623734
624735
This function will generate a Column pointing to the provided pylibcudf
625-
Column. It will directly access the data and mask buffers of the
736+
Column. It will directly access the data and mask buffers of the
626737
pylibcudf Column, so the newly created object is not tied to the
627738
lifetime of the original pylibcudf.Column.
628739
@@ -636,51 +747,17 @@ def from_pylibcudf(col: plc.Column) -> ColumnBase:
636747
ColumnBase
638749
A new ColumnBase referencing the same data.
638749
"""
750+
# Wrap buffers first so that dtypes are compatible with dtype_from_pylibcudf_column
639751
wrapped = ColumnBase._wrap_buffers(col)
640-
641752
dtype = dtype_from_pylibcudf_column(wrapped)
642-
643-
cls: type[ColumnBase]
644-
if isinstance(dtype, pd.DatetimeTZDtype):
645-
cls = cudf.core.column.datetime.DatetimeTZColumn
646-
elif dtype.kind == "M":
647-
cls = cudf.core.column.DatetimeColumn
648-
elif dtype.kind == "m":
649-
cls = cudf.core.column.TimeDeltaColumn
650-
elif (
651-
dtype == CUDF_STRING_DTYPE
652-
or dtype.kind == "U"
653-
or isinstance(dtype, pd.StringDtype)
654-
or (isinstance(dtype, pd.ArrowDtype) and dtype.kind == "U")
655-
):
656-
cls = cudf.core.column.StringColumn
657-
elif isinstance(dtype, ListDtype):
658-
cls = cudf.core.column.ListColumn
659-
elif isinstance(dtype, IntervalDtype):
660-
cls = cudf.core.column.IntervalColumn
661-
elif isinstance(dtype, StructDtype):
662-
cls = cudf.core.column.StructColumn
663-
elif isinstance(dtype, cudf.Decimal64Dtype):
664-
cls = cudf.core.column.Decimal64Column
665-
elif isinstance(dtype, cudf.Decimal32Dtype):
666-
cls = cudf.core.column.Decimal32Column
667-
elif isinstance(dtype, cudf.Decimal128Dtype):
668-
cls = cudf.core.column.Decimal128Column
669-
elif dtype.kind in "iufb":
670-
cls = cudf.core.column.NumericalColumn
671-
else:
672-
raise TypeError(f"Unrecognized dtype: {dtype}")
673-
674-
return cls._from_preprocessed(
675-
plc_column=wrapped,
676-
dtype=dtype,
677-
)
753+
return ColumnBase.create(wrapped, dtype)
678754

679755
@classmethod
680756
def _from_preprocessed(
681757
cls,
682758
plc_column: plc.Column,
683759
dtype: DtypeObj,
760+
validate: bool = True,
684761
) -> Self:
685762
# TODO: This function bypasses some of the buffer copying/wrapping that would
686763
# be done in from_pylibcudf, so it is only ever safe to call this in situations
@@ -689,7 +766,8 @@ def _from_preprocessed(
689766
# in from_pylibcudf, but for now it is necessary for the various
690767
# _with_type_metadata calls.
691768
self = cls.__new__(cls)
692-
plc_column, dtype = self._validate_args(plc_column, dtype)
769+
if validate:
770+
plc_column, dtype = self._validate_args(plc_column, dtype)
693771
self.plc_column = plc_column
694772
self._dtype = dtype
695773
self._distinct_count = {}
@@ -921,9 +999,9 @@ def dropna(self) -> Self:
921999
if self.has_nulls():
9221000
return cast(
9231001
"Self",
924-
ColumnBase.from_pylibcudf(
925-
stream_compaction.drop_nulls([self])[0]
926-
)._with_type_metadata(self.dtype),
1002+
ColumnBase.create(
1003+
stream_compaction.drop_nulls([self])[0], self.dtype
1004+
),
9271005
)
9281006
else:
9291007
return self.copy()
@@ -1120,6 +1198,11 @@ def copy(self, deep: bool = True) -> Self:
11201198
plc_col = self.plc_column
11211199
if deep:
11221200
plc_col = plc_col.copy()
1201+
# For nested types (e.g., list<list<int>>), self.dtype may not accurately
1202+
# reflect the actual plc_column structure. Some operations (like groupby
1203+
# collect on a list column) create nested structures but don't update the
1204+
# stored dtype to reflect the new nesting level. Using _with_type_metadata()
1205+
# is more permissive and handles these cases.
11231206
return cast(
11241207
"Self",
11251208
(
@@ -1360,7 +1443,8 @@ def _scatter_by_column(
13601443
else:
13611444
return cast(
13621445
"Self",
1363-
ColumnBase.from_pylibcudf(
1446+
type(self)
1447+
.from_pylibcudf(
13641448
copying.scatter(
13651449
cast("list[plc.Scalar]", [value])
13661450
if isinstance(value, plc.Scalar)
@@ -1369,7 +1453,8 @@ def _scatter_by_column(
13691453
[self],
13701454
bounds_check=bounds_check,
13711455
)[0]
1372-
)._with_type_metadata(self.dtype),
1456+
)
1457+
._with_type_metadata(self.dtype),
13731458
)
13741459

13751460
def _check_scatter_key_length(
@@ -1468,10 +1553,9 @@ def fillna(
14681553
input_col.plc_column,
14691554
plc_replace,
14701555
)
1471-
result = type(self).from_pylibcudf(plc_column)
14721556
return cast(
14731557
"Self",
1474-
result._with_type_metadata(self.dtype),
1558+
ColumnBase.create(plc_column, self.dtype),
14751559
)
14761560

14771561
def is_valid(self) -> ColumnBase:
@@ -1778,11 +1862,7 @@ def sort_values(
17781862
)
17791863
return cast(
17801864
"Self",
1781-
(
1782-
type(self)
1783-
.from_pylibcudf(plc_table.columns()[0])
1784-
._with_type_metadata(self.dtype)
1785-
),
1865+
ColumnBase.create(plc_table.columns()[0], self.dtype),
17861866
)
17871867

17881868
def distinct_count(self, dropna: bool = True) -> int:
@@ -1902,9 +1982,9 @@ def apply_boolean_mask(self, mask: ColumnBase) -> ColumnBase:
19021982
if mask.dtype.kind != "b":
19031983
raise ValueError("boolean_mask is not boolean type.")
19041984

1905-
return ColumnBase.from_pylibcudf(
1906-
stream_compaction.apply_boolean_mask([self], mask)[0]
1907-
)._with_type_metadata(self.dtype)
1985+
return ColumnBase.create(
1986+
stream_compaction.apply_boolean_mask([self], mask)[0], self.dtype
1987+
)
19081988

19091989
def argsort(
19101990
self,
@@ -2029,9 +2109,10 @@ def unique(self) -> Self:
20292109
else:
20302110
return cast(
20312111
"Self",
2032-
ColumnBase.from_pylibcudf(
2033-
stream_compaction.drop_duplicates([self], keep="first")[0]
2034-
)._with_type_metadata(self.dtype),
2112+
ColumnBase.create(
2113+
stream_compaction.drop_duplicates([self], keep="first")[0],
2114+
self.dtype,
2115+
),
20352116
)
20362117

20372118
@staticmethod
@@ -2174,7 +2255,7 @@ def deserialize(cls, header: dict, frames: list) -> ColumnBase:
21742255
assert len(frames) == 0, (
21752256
f"{len(frames)} frame(s) remaining after deserialization"
21762257
)
2177-
return cls.from_pylibcudf(plc_column)._with_type_metadata(dtype)
2258+
return ColumnBase.create(plc_column, dtype)
21782259

21792260
def unary_operator(self, unaryop: str) -> ColumnBase:
21802261
raise TypeError(
@@ -2369,9 +2450,7 @@ def split_by_offsets(
23692450
for col in cols:
23702451
yield cast(
23712452
"Self",
2372-
type(self)
2373-
.from_pylibcudf(col)
2374-
._with_type_metadata(self.dtype),
2453+
ColumnBase.create(col, self.dtype),
23752454
)
23762455

23772456
def one_hot_encode(self, categories: ColumnBase) -> Generator[ColumnBase]:
@@ -2386,20 +2465,23 @@ def one_hot_encode(self, categories: ColumnBase) -> Generator[ColumnBase]:
23862465
type(self).from_pylibcudf(col) for col in plc_table.columns()
23872466
)
23882467

2468+
# TODO: Currently this method is only used once, in ExponentialMovingWindow. That
2469+
# suggests a potential refactoring opportunity to make EWM play better with the rest
2470+
# of our aggregation/reduction framework.
23892471
def scan(self, scan_op: str, inclusive: bool, **kwargs: Any) -> Self:
23902472
with self.access(mode="read", scope="internal"):
2391-
return cast(
2392-
"Self",
2393-
type(self).from_pylibcudf(
2394-
plc.reduce.scan(
2395-
self.plc_column,
2396-
aggregation.make_aggregation(scan_op, kwargs).plc_obj,
2397-
plc.reduce.ScanType.INCLUSIVE
2398-
if inclusive
2399-
else plc.reduce.ScanType.EXCLUSIVE,
2400-
)
2401-
),
2473+
plc_result = plc.reduce.scan(
2474+
self.plc_column,
2475+
aggregation.make_aggregation(scan_op, kwargs).plc_obj,
2476+
plc.reduce.ScanType.INCLUSIVE
2477+
if inclusive
2478+
else plc.reduce.ScanType.EXCLUSIVE,
24022479
)
2480+
return cast("Self", ColumnBase.create(plc_result, self.dtype))
2481+
2482+
def _scan(self, op: str) -> ColumnBase:
2483+
"""Default cumulative scan implementation for DataFrame.cum* methods."""
2484+
return self.scan(op.replace("cum", ""), inclusive=True)
24032485

24042486
def reduce(self, reduction_op: str, **kwargs: Any) -> ScalarLike:
24052487
col_dtype = self._reduction_result_dtype(reduction_op)
@@ -3428,8 +3510,9 @@ def concat_columns(objs: Sequence[ColumnBase]) -> ColumnBase:
34283510
with access_columns( # type: ignore[assignment]
34293511
*objs_with_len, mode="read", scope="internal"
34303512
) as objs_with_len:
3431-
return ColumnBase.from_pylibcudf(
3513+
return ColumnBase.create(
34323514
plc.concatenate.concatenate(
34333515
[col.plc_column for col in objs_with_len]
3434-
)
3435-
)._with_type_metadata(objs_with_len[0].dtype)
3516+
),
3517+
objs_with_len[0].dtype,
3518+
)

0 commit comments

Comments
 (0)