1616#include < cudf/detail/copy.hpp>
1717#include < cudf/detail/nvtx/ranges.hpp>
1818#include < cudf/detail/utilities/device_operators.cuh>
19+ #include < cudf/dictionary/dictionary_column_view.hpp>
1920#include < cudf/types.hpp>
2021#include < cudf/utilities/error.hpp>
2122#include < cudf/utilities/memory_resource.hpp>
@@ -55,22 +56,23 @@ struct DeviceRolling {
5556
5657 // operations we do support
5758 template <typename T = InputType, aggregation::Kind O = op>
58- explicit DeviceRolling (size_type _min_periods, std:: enable_if_t < is_supported<T, O>()>* = nullptr )
59- : min_periods(_min_periods )
59+ requires ( is_supported<T, O>())
60+ explicit DeviceRolling (size_type min_periods) : min_periods(min_periods )
6061 {
6162 }
6263
6364 // operations we don't support
6465 template <typename T = InputType, aggregation::Kind O = op>
65- explicit DeviceRolling (size_type _min_periods, std:: enable_if_t <! is_supported<T, O>()>* = nullptr )
66- : min_periods(_min_periods )
66+ requires ( not is_supported<T, O>())
67+ explicit DeviceRolling (size_type min_periods) : min_periods(min_periods )
6768 {
6869 CUDF_FAIL (" Invalid aggregation/type pair" );
6970 }
7071
7172 // perform the windowing operation
72- template <typename OutputType, bool has_nulls >
73+ template <typename OutputType>
7374 bool __device__ operator ()(column_device_view const & input,
75+ bool has_nulls,
7476 column_device_view const &,
7577 mutable_column_device_view& output,
7678 size_type start_index,
@@ -84,7 +86,7 @@ struct DeviceRolling {
8486 OutputType val = AggOp::template identity<OutputType>();
8587
8688 for (size_type j = start_index; j < end_index; j++) {
87- if (!has_nulls || input.is_valid (j)) {
89+ if (!has_nulls || input.is_valid_nocheck (j)) {
8890 OutputType element = input.element <device_storage_type_t <InputType>>(j);
8991 val = agg_op (element, val);
9092 count++;
@@ -139,8 +141,9 @@ struct DeviceRollingArgMinMaxString : DeviceRollingArgMinMaxBase<cudf::string_vi
139141 }
140142 using DeviceRollingArgMinMaxBase<cudf::string_view, op>::min_periods;
141143
142- template <typename OutputType, bool has_nulls >
144+ template <typename OutputType>
143145 bool __device__ operator ()(column_device_view const & input,
146+ bool has_nulls,
144147 column_device_view const &,
145148 mutable_column_device_view& output,
146149 size_type start_index,
@@ -158,7 +161,7 @@ struct DeviceRollingArgMinMaxString : DeviceRollingArgMinMaxBase<cudf::string_vi
158161 OutputType val_index = default_output;
159162
160163 for (size_type j = start_index; j < end_index; j++) {
161- if (!has_nulls || input.is_valid (j)) {
164+ if (!has_nulls || input.is_valid_nocheck (j)) {
162165 InputType element = input.element <InputType>(j);
163166 val = agg_op (element, val);
164167 if (val.data () == element.data ()) { val_index = j; }
@@ -188,8 +191,9 @@ struct DeviceRollingArgMinMaxStruct : DeviceRollingArgMinMaxBase<cudf::struct_vi
188191 using DeviceRollingArgMinMaxBase<cudf::struct_view, op>::min_periods;
189192 Comparator comp;
190193
191- template <typename OutputType, bool has_nulls >
194+ template <typename OutputType>
192195 bool __device__ operator ()(column_device_view const & input,
196+ bool has_nulls,
193197 column_device_view const &,
194198 mutable_column_device_view& output,
195199 size_type start_index,
@@ -232,11 +236,11 @@ struct DeviceRollingArgMinMaxDictionary : DeviceRollingArgMinMaxBase<cudf::dicti
232236 }
233237 using DeviceRollingArgMinMaxBase<cudf::dictionary32, op>::min_periods;
234238
235- template <bool has_nulls>
236239 struct keys_dispatch_fn {
237240 template <typename T>
238241 requires (cudf::is_relationally_comparable<T, T>() and not cudf::is_dictionary<T>())
239242 size_type __device__ operator ()(column_device_view const & dict,
243+ bool has_nulls,
240244 size_type start_index,
241245 size_type end_index,
242246 size_type current_index)
@@ -249,7 +253,7 @@ struct DeviceRollingArgMinMaxDictionary : DeviceRollingArgMinMaxBase<cudf::dicti
249253 auto val = AggOp::template identity<T>();
250254 auto index = size_type{-1 };
251255 for (size_type j = start_index; j < end_index; j++) {
252- if (!has_nulls || dict.is_valid (j)) {
256+ if (!has_nulls || dict.is_valid_nocheck (j)) {
253257 auto element = keys.element <T>(dict.element <dictionary32>(j).value ());
254258 val = agg_op (element, val);
255259 if (val == element) { index = j; }
@@ -260,26 +264,29 @@ struct DeviceRollingArgMinMaxDictionary : DeviceRollingArgMinMaxBase<cudf::dicti
260264 }
261265 template <typename T>
262266 requires (!cudf::is_relationally_comparable<T, T>() or cudf::is_dictionary<T>())
263- size_type __device__ operator ()(column_device_view const &, size_type, size_type, size_type)
267+ size_type __device__
268+ operator ()(column_device_view const &, bool , size_type, size_type, size_type)
264269 {
265270 CUDF_UNREACHABLE (" invalid dictionary key" );
266271 }
267272
268273 size_type min_periods;
269274 };
270275
271- template <typename OutputType, bool has_nulls >
276+ template <typename OutputType>
272277 bool __device__ operator ()(column_device_view const & input,
278+ bool has_nulls,
273279 column_device_view const &,
274280 mutable_column_device_view& output,
275281 size_type start_index,
276282 size_type end_index,
277283 size_type current_index)
278284 {
279- auto keys_type = input.child (1 ).type ();
285+ auto keys_type = input.child (cudf::dictionary_column_view::keys_column_index ).type ();
280286 auto index = type_dispatcher<dispatch_storage_type>(keys_type,
281- keys_dispatch_fn<has_nulls> {min_periods},
287+ keys_dispatch_fn{min_periods},
282288 input,
289+ has_nulls,
283290 start_index,
284291 end_index,
285292 current_index);
@@ -306,8 +313,9 @@ struct DeviceRollingCountValid {
306313
307314 DeviceRollingCountValid (size_type _min_periods) : min_periods(_min_periods) {}
308315
309- template <typename OutputType, bool has_nulls >
316+ template <typename OutputType>
310317 bool __device__ operator ()(column_device_view const & input,
318+ bool has_nulls,
311319 column_device_view const &,
312320 mutable_column_device_view& output,
313321 size_type start_index,
@@ -350,8 +358,9 @@ struct DeviceRollingCountAll {
350358
351359 DeviceRollingCountAll (size_type _min_periods) : min_periods(_min_periods) {}
352360
353- template <typename OutputType, bool has_nulls >
361+ template <typename OutputType>
354362 bool __device__ operator ()(column_device_view const &,
363+ bool ,
355364 column_device_view const &,
356365 mutable_column_device_view& output,
357366 size_type start_index,
@@ -387,8 +396,9 @@ struct DeviceRollingVariance {
387396 {
388397 }
389398
390- template <typename OutputType, bool has_nulls >
399+ template <typename OutputType>
391400 bool __device__ operator ()(column_device_view const & input,
401+ bool has_nulls,
392402 column_device_view const &,
393403 mutable_column_device_view& output,
394404 size_type start_index,
@@ -463,8 +473,9 @@ struct DeviceRollingRowNumber {
463473
464474 DeviceRollingRowNumber (size_type _min_periods) : min_periods(_min_periods) {}
465475
466- template <typename OutputType, bool has_nulls >
476+ template <typename OutputType>
467477 bool __device__ operator ()(column_device_view const &,
478+ bool ,
468479 column_device_view const &,
469480 mutable_column_device_view& output,
470481 size_type start_index,
@@ -530,8 +541,9 @@ struct DeviceRollingLead {
530541 CUDF_FAIL (" Invalid aggregation/type pair" );
531542 }
532543
533- template <typename OutputType, bool has_nulls >
544+ template <typename OutputType>
534545 bool __device__ operator ()(column_device_view const & input,
546+ bool has_nulls,
535547 column_device_view const & default_outputs,
536548 mutable_column_device_view& output,
537549 size_type,
@@ -552,7 +564,7 @@ struct DeviceRollingLead {
552564
553565 // Not an invalid row.
554566 auto index = current_index + row_offset;
555- auto is_null = input.is_null (index);
567+ auto is_null = has_nulls && input.is_null_nocheck (index);
556568 if (!is_null) {
557569 if constexpr (cudf::is_dictionary<InputType>()) {
558570 output.element <OutputType>(current_index) = input.element <dictionary32>(index).value ();
@@ -594,8 +606,9 @@ struct DeviceRollingLag {
594606 CUDF_FAIL (" Invalid aggregation/type pair" );
595607 }
596608
597- template <typename OutputType, bool has_nulls >
609+ template <typename OutputType>
598610 bool __device__ operator ()(column_device_view const & input,
611+ bool has_nulls,
599612 column_device_view const & default_outputs,
600613 mutable_column_device_view& output,
601614 size_type start_index,
@@ -616,7 +629,7 @@ struct DeviceRollingLag {
616629
617630 // Not an invalid row.
618631 auto index = current_index - row_offset;
619- auto is_null = input.is_null (index);
632+ auto is_null = has_nulls && input.is_null_nocheck (index);
620633 if (!is_null) {
621634 if constexpr (cudf::is_dictionary<InputType>()) {
622635 output.element <OutputType>(current_index) = input.element <dictionary32>(index).value ();
@@ -735,11 +748,9 @@ struct create_rolling_operator<InputType, aggregation::Kind::LAG> {
735748};
736749
737750template <typename InputType, aggregation::Kind k>
738- struct create_rolling_operator <
739- InputType,
740- k,
741- typename std::enable_if_t <std::is_same_v<InputType, cudf::string_view> &&
742- (k == aggregation::Kind::ARGMIN || k == aggregation::Kind::ARGMAX )>> {
751+ requires (std::is_same_v<InputType, cudf::string_view> &&
752+ (k == aggregation::Kind::ARGMIN || k == aggregation::Kind::ARGMAX ))
753+ struct create_rolling_operator <InputType, k> {
743754 auto operator ()(size_type min_periods, rolling_aggregation const &)
744755 {
745756 return DeviceRollingArgMinMaxString<k>{min_periods};
@@ -757,11 +768,9 @@ struct create_rolling_operator<InputType, k> {
757768};
758769
759770template <typename InputType, aggregation::Kind k>
760- struct create_rolling_operator <
761- InputType,
762- k,
763- typename std::enable_if_t <std::is_same_v<InputType, cudf::struct_view> &&
764- (k == aggregation::Kind::ARGMIN || k == aggregation::Kind::ARGMAX )>> {
771+ requires (std::is_same_v<InputType, cudf::struct_view> &&
772+ (k == aggregation::Kind::ARGMIN || k == aggregation::Kind::ARGMAX ))
773+ struct create_rolling_operator <InputType, k> {
765774 template <typename Comparator>
766775 auto operator ()(size_type min_periods, Comparator const & comp)
767776 {
0 commit comments