Skip to content

Commit 2f7cb8b

Browse files
davebayermiscco
andauthored
[libcu++] Use stream's context in PSTL (#9219)
* [libcu++] Use stream's context in PSTL * Address review comments * Actually use the right name * Morning coffee * fixes * fix --------- Co-authored-by: Michael Schellenberger Costa <miscco@nvidia.com>
1 parent fbb9e24 commit 2f7cb8b

26 files changed

Lines changed: 164 additions & 54 deletions

libcudacxx/include/cuda/std/__pstl/cuda/adjacent_difference.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ _CCCL_DIAG_POP
4444
# include <cuda/std/__execution/policy.h>
4545
# include <cuda/std/__iterator/iterator_traits.h>
4646
# include <cuda/std/__numeric/adjacent_difference.h>
47+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
4748
# include <cuda/std/__pstl/cuda/temporary_storage.h>
4849
# include <cuda/std/__pstl/dispatch.h>
4950
# include <cuda/std/__type_traits/always_false.h>
@@ -66,7 +67,10 @@ struct __pstl_dispatch<__pstl_algorithm::__adjacent_difference, __execution_back
6667
_OutputIterator __result,
6768
_BinaryOp __binary_op)
6869
{
69-
auto __count = ::cuda::std::distance(__first, __last);
70+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
71+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
72+
73+
const auto __count = ::cuda::std::distance(__first, __last);
7074

7175
// We pass the policy as an environment to DeviceAdjacentDifference
7276
_CCCL_TRY_CUDA_API(
@@ -78,8 +82,6 @@ struct __pstl_dispatch<__pstl_algorithm::__adjacent_difference, __execution_back
7882
::cuda::std::move(__binary_op),
7983
__policy);
8084

81-
// Get the stream for synchronization after the algorithm is run
82-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8385
__stream.sync();
8486

8587
return __result + __count;

libcudacxx/include/cuda/std/__pstl/cuda/copy_if.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ _CCCL_DIAG_POP
4747
# include <cuda/std/__iterator/incrementable_traits.h>
4848
# include <cuda/std/__iterator/iterator_traits.h>
4949
# include <cuda/std/__iterator/next.h>
50+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5051
# include <cuda/std/__pstl/cuda/temporary_storage.h>
5152
# include <cuda/std/__pstl/dispatch.h>
5253
# include <cuda/std/__type_traits/always_false.h>
@@ -69,11 +70,12 @@ struct __pstl_dispatch<__pstl_algorithm::__copy_if, __execution_backend::__cuda>
6970
_OutputIterator __result,
7071
_UnaryPredicate __pred)
7172
{
73+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
74+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
75+
7276
using _OffsetType = iter_difference_t<_InputIterator>;
7377
_OffsetType __ret;
7478

75-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
76-
7779
// Determine temporary device storage requirements
7880
void* __temp_storage = nullptr;
7981
size_t __num_bytes = 0;

libcudacxx/include/cuda/std/__pstl/cuda/copy_n.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ _CCCL_DIAG_POP
5151
# include <cuda/std/__iterator/incrementable_traits.h>
5252
# include <cuda/std/__iterator/iterator_traits.h>
5353
# include <cuda/std/__memory/pointer_traits.h>
54+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5455
# include <cuda/std/__pstl/dispatch.h>
5556
# include <cuda/std/__type_traits/always_false.h>
5657
# include <cuda/std/__type_traits/is_same.h>
@@ -72,6 +73,9 @@ struct __pstl_dispatch<__pstl_algorithm::__copy_n, __execution_backend::__cuda>
7273
[[nodiscard]] _CCCL_HOST_API static _OutputIterator __par_impl(
7374
const _Policy& __policy, _InputIterator __first, _Size __count, _OutputIterator __result, _UnaryPred __pred)
7475
{
76+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
77+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
78+
7579
// We pass the policy as an environment to DeviceTransform
7680
_CCCL_TRY_CUDA_API(
7781
CUB_NS_QUALIFIER::DeviceTransform::TransformIf,
@@ -83,8 +87,6 @@ struct __pstl_dispatch<__pstl_algorithm::__copy_n, __execution_backend::__cuda>
8387
identity{},
8488
__policy);
8589

86-
// Get the stream for synchronization after the algorithm is run
87-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8890
__stream.sync();
8991

9092
return __result + __count;
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#ifndef _CUDA_STD___PSTL_CUDA_ENSURE_CURRENT_CONTEXT_H
12+
#define _CUDA_STD___PSTL_CUDA_ENSURE_CURRENT_CONTEXT_H
13+
14+
#include <cuda/std/detail/__config>
15+
16+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
17+
# pragma GCC system_header
18+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
19+
# pragma clang system_header
20+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
21+
# pragma system_header
22+
#endif // no system header
23+
24+
#if _CCCL_HAS_BACKEND_CUDA()
25+
26+
# include <cuda/__device/device_ref.h>
27+
# include <cuda/__runtime/api_wrapper.h>
28+
# include <cuda/__runtime/ensure_current_context.h>
29+
# include <cuda/__stream/get_stream.h>
30+
# include <cuda/std/__type_traits/is_callable.h>
31+
32+
# include <cuda/std/__cccl/prologue.h>
33+
34+
_CCCL_BEGIN_NAMESPACE_CUDA_STD_EXECUTION
35+
36+
template <class _Policy>
37+
[[nodiscard]] _CCCL_HOST_API __ensure_current_context __pstl_ensure_current_ctx_for(const _Policy& __policy)
38+
{
39+
if constexpr (__is_callable_v<get_stream_t, const _Policy&>)
40+
{
41+
return __ensure_current_context{get_stream(__policy)};
42+
}
43+
else
44+
{
45+
int __curr_device{};
46+
_CCCL_TRY_CUDA_API(::cudaGetDevice, "Failed to get current device", &__curr_device);
47+
return __ensure_current_context{device_ref{__curr_device}};
48+
}
49+
}
50+
51+
_CCCL_END_NAMESPACE_CUDA_STD_EXECUTION
52+
53+
# include <cuda/std/__cccl/epilogue.h>
54+
55+
#endif /// _CCCL_HAS_BACKEND_CUDA()
56+
57+
#endif // _CUDA_STD___PSTL_CUDA_ENSURE_CURRENT_CONTEXT_H

libcudacxx/include/cuda/std/__pstl/cuda/exclusive_scan.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ _CCCL_DIAG_POP
4747
# include <cuda/std/__iterator/distance.h>
4848
# include <cuda/std/__iterator/iterator_traits.h>
4949
# include <cuda/std/__numeric/exclusive_scan.h>
50+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5051
# include <cuda/std/__pstl/cuda/temporary_storage.h>
5152
# include <cuda/std/__pstl/dispatch.h>
5253
# include <cuda/std/__type_traits/always_false.h>
@@ -70,6 +71,9 @@ struct __pstl_dispatch<__pstl_algorithm::__exclusive_scan, __execution_backend::
7071
_BinaryOp __binary_op,
7172
_Tp __init)
7273
{
74+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
75+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
76+
7377
// We pass the policy as an environment to DeviceScan
7478
_CCCL_TRY_CUDA_API(
7579
CUB_NS_QUALIFIER::DeviceScan::ExclusiveScan,
@@ -81,8 +85,6 @@ struct __pstl_dispatch<__pstl_algorithm::__exclusive_scan, __execution_backend::
8185
__count,
8286
__policy);
8387

84-
// Get the stream for synchronization after the algorithm is run
85-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8688
__stream.sync();
8789

8890
return __result + iter_difference_t<_OutputIterator>(__count);

libcudacxx/include/cuda/std/__pstl/cuda/find_if.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ _CCCL_DIAG_POP
4848
# include <cuda/std/__iterator/distance.h>
4949
# include <cuda/std/__iterator/iterator_traits.h>
5050
# include <cuda/std/__memory/addressof.h>
51+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5152
# include <cuda/std/__pstl/cuda/temporary_storage.h>
5253
# include <cuda/std/__pstl/dispatch.h>
5354
# include <cuda/std/__type_traits/always_false.h>
@@ -67,6 +68,9 @@ struct __pstl_dispatch<__pstl_algorithm::__find_if, __execution_backend::__cuda>
6768
[[nodiscard]] _CCCL_HOST_API static _Iter
6869
__par_impl([[maybe_unused]] const _Policy& __policy, _Iter __first, _Iter __last, _UnaryOp __pred)
6970
{
71+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
72+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
73+
7074
const auto __num_items = ::cuda::std::distance(__first, __last);
7175
using _OffsetType = remove_cvref_t<decltype(__num_items)>;
7276
_OffsetType __ret;
@@ -84,9 +88,6 @@ struct __pstl_dispatch<__pstl_algorithm::__find_if, __execution_backend::__cuda>
8488
__pred,
8589
__num_items);
8690

87-
// Allocate memory for result
88-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
89-
9091
{
9192
__temporary_storage<_OffsetType> __storage{__policy, __num_bytes, 1};
9293

libcudacxx/include/cuda/std/__pstl/cuda/for_each_n.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ _CCCL_DIAG_POP
4848
# include <cuda/std/__host_stdlib/new>
4949
# include <cuda/std/__iterator/incrementable_traits.h>
5050
# include <cuda/std/__iterator/iterator_traits.h>
51+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5152
# include <cuda/std/__pstl/dispatch.h>
5253
# include <cuda/std/__type_traits/always_false.h>
5354
# include <cuda/std/__utility/convert_to_integral.h>
@@ -68,6 +69,9 @@ struct __pstl_dispatch<__pstl_algorithm::__for_each_n, __execution_backend::__cu
6869
[[nodiscard]] _CCCL_HOST_API static _Iter
6970
__par_impl([[maybe_unused]] const _Policy& __policy, _Iter __first, _Size __orig_n, _Fn __func)
7071
{
72+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
73+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
74+
7175
const auto __count = ::cuda::std::__convert_to_integral(__orig_n);
7276

7377
// We pass the policy as an environment to DeviceFor
@@ -79,8 +83,6 @@ struct __pstl_dispatch<__pstl_algorithm::__for_each_n, __execution_backend::__cu
7983
::cuda::std::move(__func),
8084
__policy);
8185

82-
// Get the stream for synchronization after the algorithm is run
83-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8486
__stream.sync();
8587

8688
return __first + static_cast<iter_difference_t<_Iter>>(__count);

libcudacxx/include/cuda/std/__pstl/cuda/generate_n.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ _CCCL_DIAG_POP
4747
# include <cuda/std/__host_stdlib/stdexcept>
4848
# include <cuda/std/__iterator/distance.h>
4949
# include <cuda/std/__iterator/iterator_traits.h>
50+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5051
# include <cuda/std/__pstl/dispatch.h>
5152
# include <cuda/std/__type_traits/always_false.h>
5253
# include <cuda/std/__utility/move.h>
@@ -67,6 +68,9 @@ struct __pstl_dispatch<__pstl_algorithm::__generate_n, __execution_backend::__cu
6768
[[nodiscard]] _CCCL_HOST_API static _OutputIterator
6869
__par_impl(const _Policy& __policy, _OutputIterator __result, const int64_t __count, _UnaryOp __func)
6970
{
71+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
72+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
73+
7074
// We pass the policy as an environment to DeviceTransform
7175
_CCCL_TRY_CUDA_API(
7276
CUB_NS_QUALIFIER::DeviceTransform::Generate,
@@ -76,8 +80,6 @@ struct __pstl_dispatch<__pstl_algorithm::__generate_n, __execution_backend::__cu
7680
::cuda::std::move(__func),
7781
__policy);
7882

79-
// Get the stream for synchronization after the algorithm is run
80-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8183
__stream.sync();
8284

8385
return __result + __count;

libcudacxx/include/cuda/std/__pstl/cuda/inclusive_scan.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ _CCCL_DIAG_POP
4747
# include <cuda/std/__iterator/distance.h>
4848
# include <cuda/std/__iterator/iterator_traits.h>
4949
# include <cuda/std/__numeric/inclusive_scan.h>
50+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5051
# include <cuda/std/__pstl/cuda/temporary_storage.h>
5152
# include <cuda/std/__pstl/dispatch.h>
5253
# include <cuda/std/__type_traits/always_false.h>
@@ -70,6 +71,9 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
7071
_BinaryOp __binary_op,
7172
_Tp __init)
7273
{
74+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
75+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
76+
7377
// We pass the policy as an environment to DeviceScan
7478
_CCCL_TRY_CUDA_API(
7579
CUB_NS_QUALIFIER::DeviceScan::InclusiveScanInit,
@@ -81,8 +85,6 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
8185
__count,
8286
__policy);
8387

84-
// Get the stream for synchronization after the algorithm is run
85-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8688
__stream.sync();
8789

8890
return __result + iter_difference_t<_OutputIterator>(__count);
@@ -96,6 +98,9 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
9698
_OutputIterator __result,
9799
_BinaryOp __binary_op)
98100
{
101+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
102+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
103+
99104
_OutputIterator __ret = __result + iter_difference_t<_OutputIterator>(__count);
100105

101106
// We pass the policy as an environment to DeviceScan
@@ -108,8 +113,6 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
108113
__count,
109114
__policy);
110115

111-
// Get the stream for synchronization after the algorithm is run
112-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
113116
__stream.sync();
114117
return __ret;
115118
}

libcudacxx/include/cuda/std/__pstl/cuda/max_element.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ _CCCL_DIAG_POP
4747
# include <cuda/std/__iterator/distance.h>
4848
# include <cuda/std/__iterator/iterator_traits.h>
4949
# include <cuda/std/__memory/addressof.h>
50+
# include <cuda/std/__pstl/cuda/ensure_current_context.h>
5051
# include <cuda/std/__pstl/cuda/temporary_storage.h>
5152
# include <cuda/std/__pstl/dispatch.h>
5253
# include <cuda/std/__type_traits/always_false.h>
@@ -66,9 +67,11 @@ struct __pstl_dispatch<__pstl_algorithm::__max_element, __execution_backend::__c
6667
[[nodiscard]] _CCCL_HOST_API static _InputIterator __par_impl(
6768
[[maybe_unused]] const _Policy& __policy, _InputIterator __first, _InputIterator __last, _BinaryPred __pred)
6869
{
70+
const auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
71+
const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for(__policy);
72+
6973
size_t __ret = 0ull;
7074
const auto __count = static_cast<int64_t>(::cuda::std::distance(__first, __last));
71-
auto __stream = ::cuda::__call_or(::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
7275

7376
// Determine temporary device storage requirements for max_element
7477
size_t __num_bytes = 0;

0 commit comments

Comments
 (0)