Skip to content

Commit 0fde383

Browse files
authored
Use ColumnBase.create more in frame.py (rapidsai#21394)
Towards rapidsai#21229 Additionally preserves the original types more than just converting `from_pylibcudf` Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: rapidsai#21394
1 parent fe6797e commit 0fde383

1 file changed

Lines changed: 38 additions & 24 deletions

File tree

python/cudf/cudf/core/frame.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919
import cudf
2020
from cudf.api.extensions import no_default
21-
22-
# TODO: The `numpy` import is needed for typing purposes during doc builds
23-
# only, need to figure out why the `np` alias is insufficient then remove.
2421
from cudf.api.types import is_dtype_equal, is_scalar, is_string_dtype
2522
from cudf.core._compat import PANDAS_LT_300
2623
from cudf.core._internals import copying, sorting
@@ -36,6 +33,7 @@
3633
from cudf.core.dtype.validators import is_dtype_obj_numeric
3734
from cudf.core.mixins import BinaryOperand, Scannable
3835
from cudf.utils.dtypes import (
36+
dtype_from_pylibcudf_column,
3937
find_common_type,
4038
is_pandas_nullable_extension_dtype,
4139
)
@@ -1517,16 +1515,17 @@ def searchsorted(
15171515
)
15181516
]
15191517

1520-
outcol = ColumnBase.from_pylibcudf(
1521-
sorting.search_sorted(
1522-
sources,
1523-
values,
1524-
side,
1525-
ascending=itertools.repeat(ascending, times=len(sources)),
1526-
na_position=itertools.repeat(na_position, times=len(sources)),
1527-
)
1518+
plc_outcol = sorting.search_sorted(
1519+
sources,
1520+
values,
1521+
side,
1522+
ascending=itertools.repeat(ascending, times=len(sources)),
1523+
na_position=itertools.repeat(na_position, times=len(sources)),
1524+
)
1525+
outcol = ColumnBase.create(
1526+
plc.unary.cast(plc_outcol, plc.DataType(plc.TypeId.INT64)),
1527+
np.dtype("int64"),
15281528
)
1529-
outcol = outcol.astype(np.dtype("int64"))
15301529

15311530
# Return result as cupy array if the values is non-scalar
15321531
# If values is scalar, result is expected to be scalar.
@@ -1645,13 +1644,14 @@ def _get_sorted_inds(
16451644
)
16461645
else:
16471646
ascending_iter = ascending
1648-
return ColumnBase.from_pylibcudf(
1649-
sorting.order_by(
1650-
to_sort,
1651-
ascending_iter,
1652-
itertools.repeat(na_position, times=len(to_sort)),
1653-
stable=True,
1654-
)
1647+
plc_result = sorting.order_by(
1648+
to_sort,
1649+
ascending_iter,
1650+
itertools.repeat(na_position, times=len(to_sort)),
1651+
stable=True,
1652+
)
1653+
return ColumnBase.create(
1654+
plc_result, dtype_from_pylibcudf_column(plc_result)
16551655
)
16561656

16571657
@_performance_tracking
@@ -1668,7 +1668,12 @@ def _split(self, splits: list[int]) -> list[Self]:
16681668
return []
16691669
return [
16701670
self._from_columns_like_self(
1671-
[ColumnBase.from_pylibcudf(col) for col in split],
1671+
[
1672+
ColumnBase.create(col, dtype)
1673+
for col, (_, dtype) in zip(
1674+
split, self._dtypes, strict=True
1675+
)
1676+
],
16721677
self._column_names,
16731678
)
16741679
for split in copying.columns_split(self._columns, splits)
@@ -1680,9 +1685,14 @@ def _encode(self) -> tuple[Self, ColumnBase]:
16801685
plc.Table([col.plc_column for col in self._columns])
16811686
)
16821687
columns = [
1683-
ColumnBase.from_pylibcudf(col) for col in plc_table.columns()
1688+
ColumnBase.create(col, dtype)
1689+
for col, (_, dtype) in zip(
1690+
plc_table.columns(), self._dtypes, strict=True
1691+
)
16841692
]
1685-
indices = ColumnBase.from_pylibcudf(plc_column)
1693+
indices = ColumnBase.create(
1694+
plc_column, dtype_from_pylibcudf_column(plc_column)
1695+
)
16861696
keys = self._from_columns_like_self(columns)
16871697
return keys, indices
16881698

@@ -2142,8 +2152,12 @@ def _repeat(
21422152
else:
21432153
repeats_plc = repeats
21442154
return [
2145-
ColumnBase.from_pylibcudf(col)
2146-
for col in plc.filling.repeat(plc_table, repeats_plc).columns()
2155+
ColumnBase.create(col, reference_col.dtype)
2156+
for col, reference_col in zip(
2157+
plc.filling.repeat(plc_table, repeats_plc).columns(),
2158+
columns,
2159+
strict=True,
2160+
)
21472161
]
21482162

21492163
@_performance_tracking

0 commit comments

Comments
 (0)