1818
1919import cudf
2020from cudf .api .extensions import no_default
21-
22- # TODO: The `numpy` import is needed for typing purposes during doc builds
23- # only, need to figure out why the `np` alias is insufficient then remove.
2421from cudf .api .types import is_dtype_equal , is_scalar , is_string_dtype
2522from cudf .core ._compat import PANDAS_LT_300
2623from cudf .core ._internals import copying , sorting
3633from cudf .core .dtype .validators import is_dtype_obj_numeric
3734from cudf .core .mixins import BinaryOperand , Scannable
3835from cudf .utils .dtypes import (
36+ dtype_from_pylibcudf_column ,
3937 find_common_type ,
4038 is_pandas_nullable_extension_dtype ,
4139)
@@ -1517,16 +1515,17 @@ def searchsorted(
15171515 )
15181516 ]
15191517
1520- outcol = ColumnBase .from_pylibcudf (
1521- sorting .search_sorted (
1522- sources ,
1523- values ,
1524- side ,
1525- ascending = itertools .repeat (ascending , times = len (sources )),
1526- na_position = itertools .repeat (na_position , times = len (sources )),
1527- )
1518+ plc_outcol = sorting .search_sorted (
1519+ sources ,
1520+ values ,
1521+ side ,
1522+ ascending = itertools .repeat (ascending , times = len (sources )),
1523+ na_position = itertools .repeat (na_position , times = len (sources )),
1524+ )
1525+ outcol = ColumnBase .create (
1526+ plc .unary .cast (plc_outcol , plc .DataType (plc .TypeId .INT64 )),
1527+ np .dtype ("int64" ),
15281528 )
1529- outcol = outcol .astype (np .dtype ("int64" ))
15301529
15311530 # Return the result as a cupy array if `values` is non-scalar.
15321531 # If `values` is scalar, the result is expected to be scalar.
@@ -1645,13 +1644,14 @@ def _get_sorted_inds(
16451644 )
16461645 else :
16471646 ascending_iter = ascending
1648- return ColumnBase .from_pylibcudf (
1649- sorting .order_by (
1650- to_sort ,
1651- ascending_iter ,
1652- itertools .repeat (na_position , times = len (to_sort )),
1653- stable = True ,
1654- )
1647+ plc_result = sorting .order_by (
1648+ to_sort ,
1649+ ascending_iter ,
1650+ itertools .repeat (na_position , times = len (to_sort )),
1651+ stable = True ,
1652+ )
1653+ return ColumnBase .create (
1654+ plc_result , dtype_from_pylibcudf_column (plc_result )
16551655 )
16561656
16571657 @_performance_tracking
@@ -1668,7 +1668,12 @@ def _split(self, splits: list[int]) -> list[Self]:
16681668 return []
16691669 return [
16701670 self ._from_columns_like_self (
1671- [ColumnBase .from_pylibcudf (col ) for col in split ],
1671+ [
1672+ ColumnBase .create (col , dtype )
1673+ for col , (_ , dtype ) in zip (
1674+ split , self ._dtypes , strict = True
1675+ )
1676+ ],
16721677 self ._column_names ,
16731678 )
16741679 for split in copying .columns_split (self ._columns , splits )
@@ -1680,9 +1685,14 @@ def _encode(self) -> tuple[Self, ColumnBase]:
16801685 plc .Table ([col .plc_column for col in self ._columns ])
16811686 )
16821687 columns = [
1683- ColumnBase .from_pylibcudf (col ) for col in plc_table .columns ()
1688+ ColumnBase .create (col , dtype )
1689+ for col , (_ , dtype ) in zip (
1690+ plc_table .columns (), self ._dtypes , strict = True
1691+ )
16841692 ]
1685- indices = ColumnBase .from_pylibcudf (plc_column )
1693+ indices = ColumnBase .create (
1694+ plc_column , dtype_from_pylibcudf_column (plc_column )
1695+ )
16861696 keys = self ._from_columns_like_self (columns )
16871697 return keys , indices
16881698
@@ -2142,8 +2152,12 @@ def _repeat(
21422152 else :
21432153 repeats_plc = repeats
21442154 return [
2145- ColumnBase .from_pylibcudf (col )
2146- for col in plc .filling .repeat (plc_table , repeats_plc ).columns ()
2155+ ColumnBase .create (col , reference_col .dtype )
2156+ for col , reference_col in zip (
2157+ plc .filling .repeat (plc_table , repeats_plc ).columns (),
2158+ columns ,
2159+ strict = True ,
2160+ )
21472161 ]
21482162
21492163 @_performance_tracking
0 commit comments