5353 CUDF_STRING_DTYPE ,
5454 SIZE_TYPE_DTYPE ,
5555 cudf_dtype_to_pa_type ,
56+ dtype_from_pylibcudf_column ,
5657 get_dtype_of_same_kind ,
5758)
5859from cudf .utils .performance_tracking import _performance_tracking
@@ -850,24 +851,21 @@ def _groups(
850851 def _aggregate (
851852 self , values : tuple [ColumnBase , ...], aggregations
852853 ) -> tuple [
853- list [list [ColumnBase ]],
854+ list [list [plc . Column ]],
854855 list [ColumnBase ],
855856 list [list [tuple [str , str ]]],
856857 ]:
857858 included_aggregations = []
858859 column_included = []
859860 requests = []
860- # For any post-processing needed after pylibcudf aggregations
861- adjustments = []
862- result_columns : list [list [ColumnBase ]] = []
861+ result_columns : list [list [plc .Column ]] = []
863862
864863 for i , (col , aggs ) in enumerate (
865864 zip (values , aggregations , strict = True )
866865 ):
867866 valid_aggregations = get_valid_aggregation (col .dtype )
868867 included_aggregations_i = []
869868 col_aggregations = []
870- adjustments_i = []
871869 for agg in aggs :
872870 str_agg = str (agg )
873871 if _is_unsupported_agg_for_type (col .dtype , str_agg ):
@@ -881,12 +879,6 @@ def _aggregate(
881879 ):
882880 included_aggregations_i .append ((agg , agg_obj .kind ))
883881 col_aggregations .append (agg_obj .plc_obj )
884- if str_agg == "cumcount" :
885- # pandas 0-indexes cumulative count, see
886- # https://github.com/rapidsai/cudf/issues/10237
887- adjustments_i .append (lambda col : (col - 1 ))
888- else :
889- adjustments_i .append (lambda col : col )
890882 included_aggregations .append (included_aggregations_i )
891883 result_columns .append ([])
892884 if col_aggregations :
@@ -896,7 +888,6 @@ def _aggregate(
896888 )
897889 )
898890 column_included .append (i )
899- adjustments .append (adjustments_i )
900891
901892 if not requests and any (len (v ) > 0 for v in aggregations ):
902893 raise pd .errors .DataError (
@@ -911,19 +902,15 @@ def _aggregate(
911902 else plc_groupby .aggregate (requests )
912903 )
913904
914- for i , result , adjustments_i in zip (
915- column_included , results , adjustments , strict = True
916- ):
917- result_columns [i ] = [
918- adj (ColumnBase .from_pylibcudf (col ))
919- for col , adj in zip (
920- result .columns (), adjustments_i , strict = True
921- )
922- ]
905+ for i , result in zip (column_included , results , strict = True ):
906+ result_columns [i ] = result .columns ()
923907
924908 return (
925909 result_columns ,
926- [ColumnBase .from_pylibcudf (key ) for key in keys .columns ()],
910+ [
911+ ColumnBase .create (key , dtype_from_pylibcudf_column (key ))
912+ for key in keys .columns ()
913+ ],
927914 included_aggregations ,
928915 )
929916
@@ -1096,52 +1083,52 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
10961083 orig_dtypes ,
10971084 strict = True ,
10981085 ):
1099- for agg_tuple , col in zip (aggs , cols , strict = True ):
1086+ for agg_tuple , plc_result in zip (aggs , cols , strict = True ):
11001087 agg , agg_kind = agg_tuple
11011088 agg_name = agg .__name__ if callable (agg ) else agg
11021089 if multilevel :
11031090 key = (col_name , agg_name )
11041091 else :
11051092 key = col_name
1093+
1094+ create_dtype = dtype_from_pylibcudf_column (plc_result )
1095+ cast_dtype = None
11061096 if agg in {list , "collect" }:
11071097 # Collect wraps the original dtype in ListDtype (e.g., int -> list<int>)
1108- new_dtype = get_dtype_of_same_kind (
1098+ create_dtype = get_dtype_of_same_kind (
11091099 orig_dtype , ListDtype (orig_dtype )
11101100 )
1111- col = ColumnBase .create (col .plc_column , new_dtype )
1112-
1113- # Default: use column as-is
1114- data [key ] = col
1115-
11161101 # Override for specific aggregation types that need dtype adjustments
11171102 if agg_kind in {"COUNT" , "SIZE" , "ARGMIN" , "ARGMAX" }:
1118- data [ key ] = col . astype (
1119- get_dtype_of_same_kind ( orig_dtype , np .dtype (np .int64 ) )
1103+ cast_dtype = get_dtype_of_same_kind (
1104+ orig_dtype , np .dtype (np .int64 )
11201105 )
11211106 elif (
11221107 self .obj .empty
11231108 and (
11241109 isinstance (agg_name , str )
11251110 and agg_name in Reducible ._SUPPORTED_REDUCTIONS
11261111 )
1127- and len ( col ) == 0
1112+ and plc_result . size ( ) == 0
11281113 and not isinstance (
1129- col . dtype ,
1114+ create_dtype ,
11301115 (ListDtype , StructDtype , DecimalDtype ),
11311116 )
11321117 ):
1133- data [ key ] = col . astype ( orig_dtype )
1118+ cast_dtype = orig_dtype
11341119 elif agg not in {list , "collect" }:
1135- # For non-collect aggregations, apply original dtype metadata
1136- if isinstance (orig_dtype , DecimalDtype ):
1137- # `col` has a different precision than `orig_dtype`
1138- # hence we only preserve the kind of the dtype
1139- # and not the precision.
1140- data [key ] = col ._with_type_metadata (
1141- get_dtype_of_same_kind (orig_dtype , col .dtype )
1142- )
1143- else :
1144- data [key ] = col ._with_type_metadata (orig_dtype )
1120+ create_dtype = get_dtype_of_same_kind (
1121+ orig_dtype , create_dtype
1122+ )
1123+
1124+ result_col = ColumnBase .create (plc_result , create_dtype )
1125+ if agg == "cumcount" :
1126+ # pandas 0-indexes cumulative count, see
1127+ # https://github.com/rapidsai/cudf/issues/10237
1128+ result_col = result_col - 1
1129+ if cast_dtype is not None :
1130+ result_col = result_col .astype (cast_dtype )
1131+ data [key ] = result_col
11451132 data = ColumnAccessor (data , multiindex = multilevel )
11461133 if not multilevel :
11471134 data = data .rename_levels ({np .nan : None }, level = 0 )
0 commit comments