@@ -60,24 +60,24 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
6060
6161 if (order == GGML_SORT_ORDER_ASC) {
6262 if (nrows == 1 ) {
63- DeviceRadixSort::SortPairs (nullptr , temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
63+ CUDA_CHECK ( DeviceRadixSort::SortPairs (nullptr , temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
6464 temp_indices, dst, // values (indices)
65- ncols, 0 , sizeof (float ) * 8 , stream);
65+ ncols, 0 , sizeof (float ) * 8 , stream)) ;
6666 } else {
67- DeviceSegmentedSort::SortPairs (nullptr , temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
67+ CUDA_CHECK ( DeviceSegmentedSort::SortPairs (nullptr , temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
6868 temp_indices, dst, // values (indices)
6969 ncols * nrows, nrows, // num items, num segments
70- offset_iterator, offset_iterator + 1 , stream);
70+ offset_iterator, offset_iterator + 1 , stream)) ;
7171 }
7272 } else {
7373 if (nrows == 1 ) {
74- DeviceRadixSort::SortPairsDescending (nullptr , temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
74+ CUDA_CHECK ( DeviceRadixSort::SortPairsDescending (nullptr , temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
7575 temp_indices, dst, // values (indices)
76- ncols, 0 , sizeof (float ) * 8 , stream);
76+ ncols, 0 , sizeof (float ) * 8 , stream)) ;
7777 } else {
78- DeviceSegmentedSort::SortPairsDescending (nullptr , temp_storage_bytes, temp_keys, temp_keys, temp_indices,
78+ CUDA_CHECK ( DeviceSegmentedSort::SortPairsDescending (nullptr , temp_storage_bytes, temp_keys, temp_keys, temp_indices,
7979 dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1 ,
80- stream);
80+ stream)) ;
8181 }
8282 }
8383
@@ -86,22 +86,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
8686
8787 if (order == GGML_SORT_ORDER_ASC) {
8888 if (nrows == 1 ) {
89- DeviceRadixSort::SortPairs (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
89+ CUDA_CHECK ( DeviceRadixSort::SortPairs (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
9090 temp_indices, dst, // values (indices)
91- ncols, 0 , sizeof (float ) * 8 , stream);
91+ ncols, 0 , sizeof (float ) * 8 , stream)) ;
9292 } else {
93- DeviceSegmentedSort::SortPairs (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
94- ncols * nrows, nrows, offset_iterator, offset_iterator + 1 , stream);
93+ CUDA_CHECK ( DeviceSegmentedSort::SortPairs (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
94+ ncols * nrows, nrows, offset_iterator, offset_iterator + 1 , stream)) ;
9595 }
9696 } else {
9797 if (nrows == 1 ) {
98- DeviceRadixSort::SortPairsDescending (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
98+ CUDA_CHECK ( DeviceRadixSort::SortPairsDescending (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
9999 temp_indices, dst, // values (indices)
100- ncols, 0 , sizeof (float ) * 8 , stream);
100+ ncols, 0 , sizeof (float ) * 8 , stream)) ;
101101 } else {
102- DeviceSegmentedSort::SortPairsDescending (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
102+ CUDA_CHECK ( DeviceSegmentedSort::SortPairsDescending (d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
103103 temp_indices, dst, ncols * nrows, nrows, offset_iterator,
104- offset_iterator + 1 , stream);
104+ offset_iterator + 1 , stream)) ;
105105 }
106106 }
107107}
0 commit comments