@@ -47,6 +47,7 @@ _CCCL_DIAG_POP
4747# include < cuda/std/__iterator/distance.h>
4848# include < cuda/std/__iterator/iterator_traits.h>
4949# include < cuda/std/__numeric/inclusive_scan.h>
50+ # include < cuda/std/__pstl/cuda/ensure_current_context.h>
5051# include < cuda/std/__pstl/cuda/temporary_storage.h>
5152# include < cuda/std/__pstl/dispatch.h>
5253# include < cuda/std/__type_traits/always_false.h>
@@ -70,6 +71,9 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
7071 _BinaryOp __binary_op,
7172 _Tp __init)
7273 {
74+ const auto __stream = ::cuda::__call_or (::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
75+ const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for (__policy);
76+
7377 // We pass the policy as an environment to DeviceScan
7478 _CCCL_TRY_CUDA_API (
7579 CUB_NS_QUALIFIER::DeviceScan::InclusiveScanInit,
@@ -81,8 +85,6 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
8185 __count,
8286 __policy);
8387
84- // Get the stream for synchronization after the algorithm is run
85- auto __stream = ::cuda::__call_or (::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
8688 __stream.sync ();
8789
8890 return __result + iter_difference_t <_OutputIterator>(__count);
@@ -96,6 +98,9 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
9698 _OutputIterator __result,
9799 _BinaryOp __binary_op)
98100 {
101+ const auto __stream = ::cuda::__call_or (::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
102+ const auto __ctx = ::cuda::std::execution::__pstl_ensure_current_ctx_for (__policy);
103+
99104 _OutputIterator __ret = __result + iter_difference_t <_OutputIterator>(__count);
100105
101106 // We pass the policy as an environment to DeviceScan
@@ -108,8 +113,6 @@ struct __pstl_dispatch<__pstl_algorithm::__inclusive_scan, __execution_backend::
108113 __count,
109114 __policy);
110115
111- // Get the stream for synchronization after the algorithm is run
112- auto __stream = ::cuda::__call_or (::cuda::get_stream, ::cuda::stream_ref{cudaStream_t{}}, __policy);
113116 __stream.sync ();
114117 return __ret;
115118 }
0 commit comments