1919#include < string>
2020#include < unordered_map>
2121
22- // This file is compiled into a separate library that is dynamically loaded with LD_PRELOAD at
23- // runtime to libcudf to override some stream-related symbols in libcudf. The goal of such a library
24- // is to verify if the stream/stream pool is being correctly forwarded between API calls.
22+ // This file is compiled into a separate library that is statically linked to tests to
23+ // override some stream-related symbols in libcudf. The goal of such a library is to
24+ // verify if the stream/stream pool is being correctly forwarded between API calls.
2525//
2626// We control whether to override cudf::test::get_default_stream or
2727// cudf::get_default_stream with a compile-time flag. The behaviour of tests
4444
4545namespace cudf {
4646
47- #ifdef STREAM_MODE_TESTING
48- namespace test {
47+ namespace detail {
48+
49+ #if defined(CUDF_USE_PER_THREAD_DEFAULT_STREAM)
50+ rmm::cuda_stream_view const default_stream_value{rmm::cuda_stream_per_thread};
51+ #else
52+ rmm::cuda_stream_view const default_stream_value{};
4953#endif
5054
55+ } // namespace detail
56+
5157rmm::cuda_stream_view const get_default_stream ()
5258{
59+ #ifdef STREAM_MODE_TESTING
60+ static auto const default_stream = []() {
61+ if (std::getenv (" CUDF_PER_THREAD_STREAM" ) != nullptr ) {
62+ return rmm::cuda_stream_per_thread;
63+ } else {
64+ return detail::default_stream_value;
65+ }
66+ }();
67+ return default_stream;
68+ #else
5369 static rmm::cuda_stream stream{};
5470 return stream;
71+ #endif
5572}
5673
74+ namespace test {
75+
76+ rmm::cuda_stream_view const get_default_stream ()
77+ {
5778#ifdef STREAM_MODE_TESTING
58- } // namespace test
79+ static rmm::cuda_stream stream{};
80+ return stream;
81+ #else
82+ return cudf::get_default_stream ();
5983#endif
84+ }
85+
86+ } // namespace test
6087
6188#ifdef STREAM_MODE_TESTING
6289namespace detail {
@@ -119,15 +146,6 @@ void check_stream_and_error(cudaStream_t stream)
119146 }
120147}
121148
122- /* *
123- * @brief Container for CUDA APIs that have been overloaded using DEFINE_OVERLOAD.
124- *
125- * This variable must be initialized before everything else.
126- *
127- * @see find_originals for a description of the priorities
128- */
129- __attribute__ ((init_priority(1001 ))) std::unordered_map<std::string, void*> originals;
130-
131149/* *
132150 * @brief Macro for generating functions to override existing CUDA functions.
133151 *
@@ -145,15 +163,12 @@ __attribute__((init_priority(1001))) std::unordered_map<std::string, void*> orig
145163 * @param signature The function signature (must include names, not just types).
146164 * @parameter arguments The function arguments (names only, no types).
147165 */
148- #define DEFINE_OVERLOAD (function, signature, arguments ) \
149- using function##_t = cudaError_t (*)(signature); \
150- \
151- cudaError_t function (signature) \
152- { \
153- check_stream_and_error (stream); \
154- return ((function##_t)originals[#function])(arguments); \
155- } \
156- __attribute__ ((constructor(1002 ))) void queue_##function() { originals[#function] = nullptr ; }
166+ #define DEFINE_OVERLOAD (function, signature, arguments ) \
167+ extern " C" cudaError_t cudf_##function(signature) \
168+ { \
169+ check_stream_and_error (stream); \
170+ return function (arguments); \
171+ }
157172
158173/* *
159174 * @brief Helper macro to define macro arguments that contain a comma.
@@ -177,6 +192,8 @@ __attribute__((init_priority(1001))) std::unordered_map<std::string, void*> orig
177192 - https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__INTEROP.html#group__CUDART__INTEROP
178193 - https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH
179194 - https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__HIGHLEVEL.html#group__CUDART__HIGHLEVEL
195+
196+ This list must be kept in sync with cpp/cmake/fixup_target_for_stream_testing.py
180197 */
181198// clang-format on
182199
@@ -335,18 +352,3 @@ DEFINE_OVERLOAD(cudaMallocAsync,
335352DEFINE_OVERLOAD (cudaMallocFromPoolAsync,
336353 ARG (void ** ptr, size_t size, cudaMemPool_t memPool, cudaStream_t stream),
337354 ARG(ptr, size, memPool, stream));
338-
339- /* *
340- * @brief Function to collect all the original CUDA symbols corresponding to overloaded functions.
341- *
342- * Note on priorities:
343- * - `originals` must be initialized first, so it is 1001.
344- * - The function names must be added to originals next in the macro, so those are 1002.
345- * - Finally, this function actually finds the original symbols so it is 1003.
346- */
347- __attribute__ ((constructor(1003 ))) void find_originals()
348- {
349- for (auto it : originals) {
350- originals[it.first ] = dlsym (RTLD_NEXT, it.first .data ());
351- }
352- }
0 commit comments