@@ -431,13 +431,7 @@ def set_mask(self, mask: Buffer | None) -> Self:
431431 new_plc_column = self .plc_column .with_mask (new_mask , new_null_count )
432432 return cast (
433433 "Self" ,
434- (
435- type (self )
436- .from_pylibcudf (
437- new_plc_column ,
438- )
439- ._with_type_metadata (self .dtype )
440- ),
434+ ColumnBase .create (new_plc_column , self .dtype ),
441435 )
442436
443437 @property
@@ -617,12 +611,129 @@ def _wrap_buffers(col: plc.Column) -> plc.Column:
617611 validate = False ,
618612 )
619613
@staticmethod
def create(col: plc.Column, dtype: DtypeObj) -> ColumnBase:
    """
    Build a column from a pylibcudf.Column and an explicit cudf dtype.

    This is the primary factory for ColumnBase construction. The dtype is
    always required; nothing is inferred here. When the dtype has to be
    derived from the pylibcudf Column itself, use
    ``ColumnBase.from_pylibcudf``, which wraps the buffers, calls
    ``dtype_from_pylibcudf_column`` on the wrapped column and then
    delegates to this method.

    Parameters
    ----------
    col : plc.Column
        The pylibcudf Column to wrap.
    dtype : DtypeObj
        The cudf dtype the resulting column should carry.

    Returns
    -------
    ColumnBase
        An instance of the subclass selected for ``dtype``.
    """
    # Buffers are wrapped (recursively) before the column is validated.
    prepared_col = ColumnBase._wrap_buffers(col)

    # The dtype alone decides which subclass gets constructed.
    column_cls = ColumnBase._dispatch_subclass_from_dtype(dtype)

    # The subclass validator checks dtype/column compatibility (including
    # nested children) and returns the pair that should actually be stored.
    checked_col, checked_dtype = column_cls._validate_args(
        prepared_col, dtype
    )

    # validate=False: the arguments were validated on the line above.
    return column_cls._from_preprocessed(
        plc_column=checked_col,
        dtype=checked_dtype,
        validate=False,
    )
643+
@staticmethod
def _dispatch_subclass_from_dtype(dtype: DtypeObj) -> type[ColumnBase]:
    """
    Return the ColumnBase subclass used to construct a column of ``dtype``.

    Checks run in this order: pandas extension dtypes, temporal kinds,
    string dtypes, cuDF nested dtypes, decimal dtypes, numeric kinds.

    Parameters
    ----------
    dtype : DtypeObj
        The cudf dtype to dispatch on.

    Returns
    -------
    type[ColumnBase]
        The column class matching ``dtype``.

    Raises
    ------
    TypeError
        If no column class matches ``dtype``. This includes objects that do
        not expose a ``kind`` attribute and match none of the class checks.
    """
    # Special pandas extension types are matched by class, before ``kind``.
    if isinstance(dtype, pd.DatetimeTZDtype):
        return cudf.core.column.datetime.DatetimeTZColumn
    if isinstance(dtype, CategoricalDtype):
        return cudf.core.column.CategoricalColumn

    # Read ``kind`` once and tolerate its absence, so an object without it
    # falls through to the TypeError below instead of raising AttributeError
    # on the first kind comparison.
    kind = getattr(dtype, "kind", None)

    # Temporal types (by kind)
    if kind == "M":
        return cudf.core.column.DatetimeColumn
    if kind == "m":
        return cudf.core.column.TimeDeltaColumn

    # String types. A pd.ArrowDtype whose kind is "U" is already matched by
    # the kind comparison, so it needs no separate clause.
    if (
        dtype == CUDF_STRING_DTYPE
        or kind == "U"
        or isinstance(dtype, pd.StringDtype)
    ):
        return cudf.core.column.StringColumn

    # cuDF custom types
    # NOTE(review): IntervalDtype is tested before StructDtype; keep that
    # order if IntervalDtype derives from StructDtype - confirm.
    if isinstance(dtype, ListDtype):
        return cudf.core.column.ListColumn
    if isinstance(dtype, IntervalDtype):
        return cudf.core.column.IntervalColumn
    if isinstance(dtype, StructDtype):
        return cudf.core.column.StructColumn

    # Decimal types
    if isinstance(dtype, cudf.Decimal128Dtype):
        return cudf.core.column.Decimal128Column
    if isinstance(dtype, cudf.Decimal64Dtype):
        return cudf.core.column.Decimal64Column
    if isinstance(dtype, cudf.Decimal32Dtype):
        return cudf.core.column.Decimal32Column

    # Numerical types
    if kind is not None and kind in "iufb":
        return cudf.core.column.NumericalColumn

    raise TypeError(f"Unrecognized dtype: {dtype}")
694+
@staticmethod
def _validate_dtype_recursively(col: plc.Column, dtype: DtypeObj) -> None:
    """
    Check that ``dtype`` is compatible with ``col``.

    The check is delegated to the ``_validate_args`` method of the
    ColumnBase subclass selected for ``dtype``. Nested types (List, Struct,
    Interval) use this entry point for their children, and the subclass
    validators may call back into it for deeper levels.

    Parameters
    ----------
    col : plc.Column
        The pylibcudf Column to validate.
    dtype : DtypeObj
        The cudf dtype to validate against.

    Raises
    ------
    ValueError
        If the dtype is incompatible with the Column.
    TypeError
        If no ColumnBase subclass is registered for ``dtype``.
    """
    # An INT8 column whose entries are all null is how _wrap_buffers()
    # represents an EMPTY column, and its dtype metadata may be inaccurate
    # (e.g. an empty list [] has element_type=object but an INT8 child).
    # Such columns are accepted without further checks.
    is_int8 = col.type().id() == plc.TypeId.INT8
    if is_int8 and col.null_count() == col.size():
        return

    # Let the subclass selected for this dtype do the actual validation.
    validator_cls = ColumnBase._dispatch_subclass_from_dtype(dtype)
    validator_cls._validate_args(col, dtype)
730+
620731 @staticmethod
621732 def from_pylibcudf (col : plc .Column ) -> ColumnBase :
622733 """Create a Column from a pylibcudf.Column.
623734
624735 This function will generate a Column pointing to the provided pylibcudf
625- Column. It will directly access the data and mask buffers of the
736+ Column. It will directly access the data and mask buffers of the
626737 pylibcudf Column, so the newly created object is not tied to the
627738 lifetime of the original pylibcudf.Column.
628739
@@ -636,51 +747,17 @@ def from_pylibcudf(col: plc.Column) -> ColumnBase:
636747 pylibcudf.Column
637748 A new pylibcudf.Column referencing the same data.
638749 """
750+ # Wrap buffers first so that dtypes are compatible with dtype_from_pylibcudf_column
639751 wrapped = ColumnBase ._wrap_buffers (col )
640-
641752 dtype = dtype_from_pylibcudf_column (wrapped )
642-
643- cls : type [ColumnBase ]
644- if isinstance (dtype , pd .DatetimeTZDtype ):
645- cls = cudf .core .column .datetime .DatetimeTZColumn
646- elif dtype .kind == "M" :
647- cls = cudf .core .column .DatetimeColumn
648- elif dtype .kind == "m" :
649- cls = cudf .core .column .TimeDeltaColumn
650- elif (
651- dtype == CUDF_STRING_DTYPE
652- or dtype .kind == "U"
653- or isinstance (dtype , pd .StringDtype )
654- or (isinstance (dtype , pd .ArrowDtype ) and dtype .kind == "U" )
655- ):
656- cls = cudf .core .column .StringColumn
657- elif isinstance (dtype , ListDtype ):
658- cls = cudf .core .column .ListColumn
659- elif isinstance (dtype , IntervalDtype ):
660- cls = cudf .core .column .IntervalColumn
661- elif isinstance (dtype , StructDtype ):
662- cls = cudf .core .column .StructColumn
663- elif isinstance (dtype , cudf .Decimal64Dtype ):
664- cls = cudf .core .column .Decimal64Column
665- elif isinstance (dtype , cudf .Decimal32Dtype ):
666- cls = cudf .core .column .Decimal32Column
667- elif isinstance (dtype , cudf .Decimal128Dtype ):
668- cls = cudf .core .column .Decimal128Column
669- elif dtype .kind in "iufb" :
670- cls = cudf .core .column .NumericalColumn
671- else :
672- raise TypeError (f"Unrecognized dtype: { dtype } " )
673-
674- return cls ._from_preprocessed (
675- plc_column = wrapped ,
676- dtype = dtype ,
677- )
753+ return ColumnBase .create (wrapped , dtype )
678754
679755 @classmethod
680756 def _from_preprocessed (
681757 cls ,
682758 plc_column : plc .Column ,
683759 dtype : DtypeObj ,
760+ validate : bool = True ,
684761 ) -> Self :
685762 # TODO: This function bypasses some of the buffer copying/wrapping that would
686763 # be done in from_pylibcudf, so it is only ever safe to call this in situations
@@ -689,7 +766,8 @@ def _from_preprocessed(
689766 # in from_pylibcudf, but for now it is necessary for the various
690767 # _with_type_metadata calls.
691768 self = cls .__new__ (cls )
692- plc_column , dtype = self ._validate_args (plc_column , dtype )
769+ if validate :
770+ plc_column , dtype = self ._validate_args (plc_column , dtype )
693771 self .plc_column = plc_column
694772 self ._dtype = dtype
695773 self ._distinct_count = {}
@@ -921,9 +999,9 @@ def dropna(self) -> Self:
921999 if self .has_nulls ():
9221000 return cast (
9231001 "Self" ,
924- ColumnBase .from_pylibcudf (
925- stream_compaction .drop_nulls ([self ])[0 ]
926- ). _with_type_metadata ( self . dtype ) ,
1002+ ColumnBase .create (
1003+ stream_compaction .drop_nulls ([self ])[0 ], self . dtype
1004+ ),
9271005 )
9281006 else :
9291007 return self .copy ()
@@ -1120,6 +1198,11 @@ def copy(self, deep: bool = True) -> Self:
11201198 plc_col = self .plc_column
11211199 if deep :
11221200 plc_col = plc_col .copy ()
1201+ # For nested types (e.g., list<list<int>>), self.dtype may not accurately
1202+ # reflect the actual plc_column structure. Some operations (like groupby
1203+ # collect on a list column) create nested structures but don't update the
1204+ # stored dtype to reflect the new nesting level. Using _with_type_metadata()
1205+ # is more permissive and handles these cases.
11231206 return cast (
11241207 "Self" ,
11251208 (
@@ -1360,7 +1443,8 @@ def _scatter_by_column(
13601443 else :
13611444 return cast (
13621445 "Self" ,
1363- ColumnBase .from_pylibcudf (
1446+ type (self )
1447+ .from_pylibcudf (
13641448 copying .scatter (
13651449 cast ("list[plc.Scalar]" , [value ])
13661450 if isinstance (value , plc .Scalar )
@@ -1369,7 +1453,8 @@ def _scatter_by_column(
13691453 [self ],
13701454 bounds_check = bounds_check ,
13711455 )[0 ]
1372- )._with_type_metadata (self .dtype ),
1456+ )
1457+ ._with_type_metadata (self .dtype ),
13731458 )
13741459
13751460 def _check_scatter_key_length (
@@ -1468,10 +1553,9 @@ def fillna(
14681553 input_col .plc_column ,
14691554 plc_replace ,
14701555 )
1471- result = type (self ).from_pylibcudf (plc_column )
14721556 return cast (
14731557 "Self" ,
1474- result . _with_type_metadata ( self .dtype ),
1558+ ColumnBase . create ( plc_column , self .dtype ),
14751559 )
14761560
14771561 def is_valid (self ) -> ColumnBase :
@@ -1778,11 +1862,7 @@ def sort_values(
17781862 )
17791863 return cast (
17801864 "Self" ,
1781- (
1782- type (self )
1783- .from_pylibcudf (plc_table .columns ()[0 ])
1784- ._with_type_metadata (self .dtype )
1785- ),
1865+ ColumnBase .create (plc_table .columns ()[0 ], self .dtype ),
17861866 )
17871867
17881868 def distinct_count (self , dropna : bool = True ) -> int :
@@ -1902,9 +1982,9 @@ def apply_boolean_mask(self, mask: ColumnBase) -> ColumnBase:
19021982 if mask .dtype .kind != "b" :
19031983 raise ValueError ("boolean_mask is not boolean type." )
19041984
1905- return ColumnBase .from_pylibcudf (
1906- stream_compaction .apply_boolean_mask ([self ], mask )[0 ]
1907- ). _with_type_metadata ( self . dtype )
1985+ return ColumnBase .create (
1986+ stream_compaction .apply_boolean_mask ([self ], mask )[0 ], self . dtype
1987+ )
19081988
19091989 def argsort (
19101990 self ,
@@ -2029,9 +2109,10 @@ def unique(self) -> Self:
20292109 else :
20302110 return cast (
20312111 "Self" ,
2032- ColumnBase .from_pylibcudf (
2033- stream_compaction .drop_duplicates ([self ], keep = "first" )[0 ]
2034- )._with_type_metadata (self .dtype ),
2112+ ColumnBase .create (
2113+ stream_compaction .drop_duplicates ([self ], keep = "first" )[0 ],
2114+ self .dtype ,
2115+ ),
20352116 )
20362117
20372118 @staticmethod
@@ -2174,7 +2255,7 @@ def deserialize(cls, header: dict, frames: list) -> ColumnBase:
21742255 assert len (frames ) == 0 , (
21752256 f"{ len (frames )} frame(s) remaining after deserialization"
21762257 )
2177- return cls . from_pylibcudf (plc_column ). _with_type_metadata ( dtype )
2258+ return ColumnBase . create (plc_column , dtype )
21782259
21792260 def unary_operator (self , unaryop : str ) -> ColumnBase :
21802261 raise TypeError (
@@ -2369,9 +2450,7 @@ def split_by_offsets(
23692450 for col in cols :
23702451 yield cast (
23712452 "Self" ,
2372- type (self )
2373- .from_pylibcudf (col )
2374- ._with_type_metadata (self .dtype ),
2453+ ColumnBase .create (col , self .dtype ),
23752454 )
23762455
23772456 def one_hot_encode (self , categories : ColumnBase ) -> Generator [ColumnBase ]:
@@ -2386,20 +2465,23 @@ def one_hot_encode(self, categories: ColumnBase) -> Generator[ColumnBase]:
23862465 type (self ).from_pylibcudf (col ) for col in plc_table .columns ()
23872466 )
23882467
2468+ # TODO: Apart from the default _scan below, this method is currently only used
2469+ # once, in ExponentialMovingWindow. That suggests a potential refactoring
2470+ # opportunity to make EWM play better with the rest of our aggregation/reduction framework.
23892471 def scan (self , scan_op : str , inclusive : bool , ** kwargs : Any ) -> Self :
23902472 with self .access (mode = "read" , scope = "internal" ):
2391- return cast (
2392- "Self" ,
2393- type (self ).from_pylibcudf (
2394- plc .reduce .scan (
2395- self .plc_column ,
2396- aggregation .make_aggregation (scan_op , kwargs ).plc_obj ,
2397- plc .reduce .ScanType .INCLUSIVE
2398- if inclusive
2399- else plc .reduce .ScanType .EXCLUSIVE ,
2400- )
2401- ),
2473+ plc_result = plc .reduce .scan (
2474+ self .plc_column ,
2475+ aggregation .make_aggregation (scan_op , kwargs ).plc_obj ,
2476+ plc .reduce .ScanType .INCLUSIVE
2477+ if inclusive
2478+ else plc .reduce .ScanType .EXCLUSIVE ,
24022479 )
2480+ return cast ("Self" , ColumnBase .create (plc_result , self .dtype ))
2481+
def _scan(self, op: str) -> ColumnBase:
    """Default cumulative scan implementation for DataFrame.cum* methods.

    Parameters
    ----------
    op : str
        Name of the cumulative operation, e.g. ``"cumsum"``. A leading
        ``"cum"`` is stripped to obtain the scan operation passed to
        ``scan``.

    Returns
    -------
    ColumnBase
        The result of ``self.scan`` run as an inclusive scan.
    """
    # Strip only the leading prefix; str.replace would also delete any
    # later occurrence of "cum" inside the operation name.
    return self.scan(op.removeprefix("cum"), inclusive=True)
24032485
24042486 def reduce (self , reduction_op : str , ** kwargs : Any ) -> ScalarLike :
24052487 col_dtype = self ._reduction_result_dtype (reduction_op )
@@ -3428,8 +3510,9 @@ def concat_columns(objs: Sequence[ColumnBase]) -> ColumnBase:
34283510 with access_columns ( # type: ignore[assignment]
34293511 * objs_with_len , mode = "read" , scope = "internal"
34303512 ) as objs_with_len :
3431- return ColumnBase .from_pylibcudf (
3513+ return ColumnBase .create (
34323514 plc .concatenate .concatenate (
34333515 [col .plc_column for col in objs_with_len ]
3434- )
3435- )._with_type_metadata (objs_with_len [0 ].dtype )
3516+ ),
3517+ objs_with_len [0 ].dtype ,
3518+ )
0 commit comments