Skip to content

Commit 8c288de

Browse files
authored
Remove ._with_dtype_metadata in cudf classic groupby (rapidsai#21389)
Towards rapidsai#21229 The groupby results from pylibcudf needed to be preserved as `pylibcudf.Column`s before passed to another function that resolves the correct resulting type. The added test to `conftest-patch.py` _appears_ to a separate cudf bug arising. `test_groupby_agg_extension` is using an indexing API to compute the expected result, but it doesn't seem to preserving the original string variant correctly. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: rapidsai#21389
1 parent 0fde383 commit 8c288de

2 files changed

Lines changed: 37 additions & 44 deletions

File tree

python/cudf/cudf/core/groupby/groupby.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
CUDF_STRING_DTYPE,
5454
SIZE_TYPE_DTYPE,
5555
cudf_dtype_to_pa_type,
56+
dtype_from_pylibcudf_column,
5657
get_dtype_of_same_kind,
5758
)
5859
from cudf.utils.performance_tracking import _performance_tracking
@@ -850,24 +851,21 @@ def _groups(
850851
def _aggregate(
851852
self, values: tuple[ColumnBase, ...], aggregations
852853
) -> tuple[
853-
list[list[ColumnBase]],
854+
list[list[plc.Column]],
854855
list[ColumnBase],
855856
list[list[tuple[str, str]]],
856857
]:
857858
included_aggregations = []
858859
column_included = []
859860
requests = []
860-
# For any post-processing needed after pylibcudf aggregations
861-
adjustments = []
862-
result_columns: list[list[ColumnBase]] = []
861+
result_columns: list[list[plc.Column]] = []
863862

864863
for i, (col, aggs) in enumerate(
865864
zip(values, aggregations, strict=True)
866865
):
867866
valid_aggregations = get_valid_aggregation(col.dtype)
868867
included_aggregations_i = []
869868
col_aggregations = []
870-
adjustments_i = []
871869
for agg in aggs:
872870
str_agg = str(agg)
873871
if _is_unsupported_agg_for_type(col.dtype, str_agg):
@@ -881,12 +879,6 @@ def _aggregate(
881879
):
882880
included_aggregations_i.append((agg, agg_obj.kind))
883881
col_aggregations.append(agg_obj.plc_obj)
884-
if str_agg == "cumcount":
885-
# pandas 0-indexes cumulative count, see
886-
# https://github.com/rapidsai/cudf/issues/10237
887-
adjustments_i.append(lambda col: (col - 1))
888-
else:
889-
adjustments_i.append(lambda col: col)
890882
included_aggregations.append(included_aggregations_i)
891883
result_columns.append([])
892884
if col_aggregations:
@@ -896,7 +888,6 @@ def _aggregate(
896888
)
897889
)
898890
column_included.append(i)
899-
adjustments.append(adjustments_i)
900891

901892
if not requests and any(len(v) > 0 for v in aggregations):
902893
raise pd.errors.DataError(
@@ -911,19 +902,15 @@ def _aggregate(
911902
else plc_groupby.aggregate(requests)
912903
)
913904

914-
for i, result, adjustments_i in zip(
915-
column_included, results, adjustments, strict=True
916-
):
917-
result_columns[i] = [
918-
adj(ColumnBase.from_pylibcudf(col))
919-
for col, adj in zip(
920-
result.columns(), adjustments_i, strict=True
921-
)
922-
]
905+
for i, result in zip(column_included, results, strict=True):
906+
result_columns[i] = result.columns()
923907

924908
return (
925909
result_columns,
926-
[ColumnBase.from_pylibcudf(key) for key in keys.columns()],
910+
[
911+
ColumnBase.create(key, dtype_from_pylibcudf_column(key))
912+
for key in keys.columns()
913+
],
927914
included_aggregations,
928915
)
929916

@@ -1096,52 +1083,52 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
10961083
orig_dtypes,
10971084
strict=True,
10981085
):
1099-
for agg_tuple, col in zip(aggs, cols, strict=True):
1086+
for agg_tuple, plc_result in zip(aggs, cols, strict=True):
11001087
agg, agg_kind = agg_tuple
11011088
agg_name = agg.__name__ if callable(agg) else agg
11021089
if multilevel:
11031090
key = (col_name, agg_name)
11041091
else:
11051092
key = col_name
1093+
1094+
create_dtype = dtype_from_pylibcudf_column(plc_result)
1095+
cast_dtype = None
11061096
if agg in {list, "collect"}:
11071097
# Collect wraps the original dtype in ListDtype (e.g., int -> list<int>)
1108-
new_dtype = get_dtype_of_same_kind(
1098+
create_dtype = get_dtype_of_same_kind(
11091099
orig_dtype, ListDtype(orig_dtype)
11101100
)
1111-
col = ColumnBase.create(col.plc_column, new_dtype)
1112-
1113-
# Default: use column as-is
1114-
data[key] = col
1115-
11161101
# Override for specific aggregation types that need dtype adjustments
11171102
if agg_kind in {"COUNT", "SIZE", "ARGMIN", "ARGMAX"}:
1118-
data[key] = col.astype(
1119-
get_dtype_of_same_kind(orig_dtype, np.dtype(np.int64))
1103+
cast_dtype = get_dtype_of_same_kind(
1104+
orig_dtype, np.dtype(np.int64)
11201105
)
11211106
elif (
11221107
self.obj.empty
11231108
and (
11241109
isinstance(agg_name, str)
11251110
and agg_name in Reducible._SUPPORTED_REDUCTIONS
11261111
)
1127-
and len(col) == 0
1112+
and plc_result.size() == 0
11281113
and not isinstance(
1129-
col.dtype,
1114+
create_dtype,
11301115
(ListDtype, StructDtype, DecimalDtype),
11311116
)
11321117
):
1133-
data[key] = col.astype(orig_dtype)
1118+
cast_dtype = orig_dtype
11341119
elif agg not in {list, "collect"}:
1135-
# For non-collect aggregations, apply original dtype metadata
1136-
if isinstance(orig_dtype, DecimalDtype):
1137-
# `col` has a different precision than `orig_dtype`
1138-
# hence we only preserve the kind of the dtype
1139-
# and not the precision.
1140-
data[key] = col._with_type_metadata(
1141-
get_dtype_of_same_kind(orig_dtype, col.dtype)
1142-
)
1143-
else:
1144-
data[key] = col._with_type_metadata(orig_dtype)
1120+
create_dtype = get_dtype_of_same_kind(
1121+
orig_dtype, create_dtype
1122+
)
1123+
1124+
result_col = ColumnBase.create(plc_result, create_dtype)
1125+
if agg == "cumcount":
1126+
# pandas 0-indexes cumulative count, see
1127+
# https://github.com/rapidsai/cudf/issues/10237
1128+
result_col = result_col - 1
1129+
if cast_dtype is not None:
1130+
result_col = result_col.astype(cast_dtype)
1131+
data[key] = result_col
11451132
data = ColumnAccessor(data, multiindex=multilevel)
11461133
if not multilevel:
11471134
data = data.rename_levels({np.nan: None}, level=0)

python/cudf/cudf/pandas/scripts/conftest-patch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,12 @@ def set_copy_on_write_option():
20392039
"tests/extension/test_string.py::TestStringArray::test_container_shift[python-True-0-indices1-True]",
20402040
"tests/extension/test_string.py::TestStringArray::test_container_shift[python-True-2-indices2-False]",
20412041
"tests/extension/test_string.py::TestStringArray::test_container_shift[python-True-2-indices2-True]",
2042+
"tests/extension/test_string.py::TestStringArray::test_groupby_agg_extension[string=string[pyarrow]-True]",
2043+
"tests/extension/test_string.py::TestStringArray::test_groupby_agg_extension[string=string[pyarrow]-False]",
2044+
"tests/extension/test_string.py::TestStringArray::test_groupby_agg_extension[string=str[pyarrow]-True]",
2045+
"tests/extension/test_string.py::TestStringArray::test_groupby_agg_extension[string=str[pyarrow]-False]",
2046+
"tests/extension/test_string.py::TestStringArray::test_groupby_agg_extension[string=str[python]-True]",
2047+
"tests/extension/test_string.py::TestStringArray::test_groupby_agg_extension[string=str[python]-False]",
20422048
"tests/extension/test_string.py::TestStringArray::test_grouping_grouper[pyarrow-False]",
20432049
"tests/extension/test_string.py::TestStringArray::test_grouping_grouper[pyarrow-True]",
20442050
"tests/extension/test_string.py::TestStringArray::test_grouping_grouper[pyarrow_numpy-False]",

0 commit comments

Comments
 (0)