4949#include " matx/core/operator_utils.h"
5050#include " matx/core/type_utils_both.h"
5151#include " matx/transforms/cccl_iterators.h"
52+ #include " matx/transforms/host_algorithms.h"
5253
5354
5455namespace matx {
@@ -2260,7 +2261,7 @@ void argsort_impl(OutputTensor &idx_out, const InputOperator &a,
22602261template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
22612262void argsort_impl (OutputTensor &idx_out, const InputOperator &a,
22622263 const SortDirection_t dir,
2263- [[maybe_unused]] const HostExecutor<MODE> &exec)
2264+ const HostExecutor<MODE> &exec)
22642265{
22652266 MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_API)
22662267
@@ -2271,25 +2272,23 @@ void argsort_impl(OutputTensor &idx_out, const InputOperator &a,
22712272
22722273 if constexpr (RANK == 1 ) {
22732274 if (dir == SORT_DIR_ASC) {
2274- std::sort (
2275- lout, lout + idx_out.Size (0 ),
2276- [&a](index_t i, index_t j) { return a (i) < a (j); });
2275+ detail::host_sort (exec, lout, lout + idx_out.Size (0 ),
2276+ [&a](index_t i, index_t j) { return a (i) < a (j); });
22772277 }
22782278 else {
2279- std::sort (
2280- lout, lout + idx_out.Size (0 ),
2281- [&a](index_t i, index_t j) { return a (i) > a (j); });
2279+ detail::host_sort (exec, lout, lout + idx_out.Size (0 ),
2280+ [&a](index_t i, index_t j) { return a (i) > a (j); });
22822281 }
22832282 }
22842283 else if constexpr (RANK == 2 ) {
22852284 for (index_t b = 0 ; b < lout.Size (0 ); b++) {
22862285 if (dir == SORT_DIR_ASC) {
2287- std::sort ( lout + b*a.Size (1 ), lout + (b+1 )*a.Size (1 ),
2288- [&a, b](index_t i, index_t j) { return a (b,i) < a (b,j); });
2286+ detail::host_sort (exec, lout + b*a.Size (1 ), lout + (b+1 )*a.Size (1 ),
2287+ [&a, b](index_t i, index_t j) { return a (b,i) < a (b,j); });
22892288 }
22902289 else {
2291- std::sort ( lout + b*a.Size (1 ), lout + (b+1 )*a.Size (1 ),
2292- [&a, b](index_t i, index_t j) { return a (b,i) > a (b,j); });
2290+ detail::host_sort (exec, lout + b*a.Size (1 ), lout + (b+1 )*a.Size (1 ),
2291+ [&a, b](index_t i, index_t j) { return a (b,i) > a (b,j); });
22932292 }
22942293 }
22952294 }
@@ -2301,7 +2300,7 @@ void argsort_impl(OutputTensor &idx_out, const InputOperator &a,
23012300template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
23022301void sort_impl (OutputTensor &a_out, const InputOperator &a,
23032302 const SortDirection_t dir,
2304- [[maybe_unused]] const HostExecutor<MODE> &exec)
2303+ const HostExecutor<MODE> &exec)
23052304{
23062305 MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_API)
23072306
@@ -2312,33 +2311,37 @@ void sort_impl(OutputTensor &a_out, const InputOperator &a,
23122311
23132312 if constexpr (InputOperator::Rank () == 1 ) {
23142313 if (dir == SORT_DIR_ASC) {
2315- std::partial_sort_copy ( lin,
2316- lin + a.Size (0 ),
2317- lout,
2318- lout + a_out.Size (0 ));
2314+ detail::host_sort_copy (exec,
2315+ lin,
2316+ lin + a.Size (0 ),
2317+ lout,
2318+ lout + a_out.Size (0 ));
23192319 }
23202320 else {
2321- std::partial_sort_copy ( lin,
2322- lin + a.Size (0 ),
2323- lout,
2324- lout + a_out.Size (0 ),
2325- std::greater<typename InputOperator::value_type>());
2321+ detail::host_sort_copy (exec,
2322+ lin,
2323+ lin + a.Size (0 ),
2324+ lout,
2325+ lout + a_out.Size (0 ),
2326+ std::greater<typename InputOperator::value_type>());
23262327 }
23272328 }
23282329 else {
23292330 for (index_t b = 0 ; b < lout.Size (0 ); b++) {
23302331 if (dir == SORT_DIR_ASC) {
2331- std::partial_sort_copy ( lin + b*a.Size (1 ),
2332- lin + (b+1 )*a.Size (1 ),
2333- lout + b*a.Size (1 ),
2334- lout + (b+1 )*a.Size (1 ));
2332+ detail::host_sort_copy (exec,
2333+ lin + b*a.Size (1 ),
2334+ lin + (b+1 )*a.Size (1 ),
2335+ lout + b*a.Size (1 ),
2336+ lout + (b+1 )*a.Size (1 ));
23352337 }
23362338 else {
2337- std::partial_sort_copy ( lin + b*a.Size (1 ),
2338- lin + (b+1 )*a.Size (1 ),
2339- lout + b*a.Size (1 ),
2340- lout + (b+1 )*a.Size (1 ),
2341- std::greater<typename InputOperator::value_type>());
2339+ detail::host_sort_copy (exec,
2340+ lin + b*a.Size (1 ),
2341+ lin + (b+1 )*a.Size (1 ),
2342+ lout + b*a.Size (1 ),
2343+ lout + (b+1 )*a.Size (1 ),
2344+ std::greater<typename InputOperator::value_type>());
23422345 }
23432346 }
23442347 }
@@ -2399,7 +2402,7 @@ void cumsum_impl(OutputTensor &a_out, const InputOperator &a,
23992402
24002403template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
24012404void cumsum_impl (OutputTensor &a_out, const InputOperator &a,
2402- [[maybe_unused]] const HostExecutor<MODE> &exec)
2405+ const HostExecutor<MODE> &exec)
24032406{
24042407#ifdef __CUDACC__
24052408 MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_API)
@@ -2409,15 +2412,17 @@ void cumsum_impl(OutputTensor &a_out, const InputOperator &a,
24092412 auto lout = matx::RandomOperatorOutputIterator{out_base};
24102413
24112414 if constexpr (OutputTensor::Rank () == 1 ) {
2412- std::partial_sum ( lin,
2413- lin + a.Size (0 ),
2414- lout);
2415+ detail::host_inclusive_scan (exec,
2416+ lin,
2417+ lin + a.Size (0 ),
2418+ lout);
24152419 }
24162420 else if constexpr (InputOperator::Rank () == 2 ) {
24172421 for (index_t b = 0 ; b < a.Size (0 ); b++) {
2418- std::partial_sum ( lin + b * a.Size (1 ),
2419- lin + (b+1 ) * a.Size (1 ),
2420- lout + b * a.Size (1 ));
2422+ detail::host_inclusive_scan (exec,
2423+ lin + b * a.Size (1 ),
2424+ lin + (b+1 ) * a.Size (1 ),
2425+ lout + b * a.Size (1 ));
24212426 }
24222427 }
24232428 else {
@@ -2834,7 +2839,7 @@ void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperato
28342839 * Single thread executor
28352840 */
28362841template <typename CountTensor, typename OutputTensor, typename InputOperator, ThreadsMode MODE>
2837- void unique_impl (OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] const HostExecutor<MODE> &exec)
2842+ void unique_impl (OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, const HostExecutor<MODE> &exec)
28382843{
28392844#ifdef __CUDACC__
28402845 static_assert (CountTensor::Rank () == 0 , " Num found output tensor rank must be 0" );
@@ -2845,8 +2850,8 @@ void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperato
28452850 return ;
28462851 }
28472852
2848- std::partial_sort_copy ( cbegin (a), cend (a), begin (a_out), end (a_out));
2849- auto last = std::unique ( begin (a_out), end (a_out));
2853+ detail::host_sort_copy (exec, cbegin (a), cend (a), begin (a_out), end (a_out));
2854+ auto last = detail::host_unique (exec, begin (a_out), end (a_out));
28502855 num_found () = static_cast <int >(last - begin (a_out));
28512856#endif
28522857}
0 commit comments