Skip to content

Commit 232fffd

Browse files
authored
Merge pull request rapidsai#21776 from rapidsai/release/26.04
Forward-merge release/26.04 into main
2 parents 95652fc + d1e8731 commit 232fffd

4 files changed

Lines changed: 162 additions & 24 deletions

File tree

.github/workflows/pr.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,6 @@ jobs:
371371
with:
372372
build_type: pull-request
373373
script: "ci/test_python_cudf.sh"
374-
# Skip failing tests on RTX PRO 6000 (Blackwell). xref: https://github.com/rapidsai/cudf/issues/21357
375-
matrix_filter: map(select(.GPU != "rtxpro6000"))
376374
conda-python-other-tests:
377375
# Tests for dask_cudf, custreamz, cudf_kafka are separated for CI parallelism
378376
needs: [conda-python-build-noarch, changed-files]
@@ -459,8 +457,6 @@ jobs:
459457
with:
460458
build_type: pull-request
461459
script: ci/test_wheel_cudf.sh
462-
# Skip failing tests on RTX PRO 6000 (Blackwell). xref: https://github.com/rapidsai/cudf/issues/21357
463-
matrix_filter: map(select(.GPU != "rtxpro6000"))
464460
wheel-build-cudf-polars:
465461
needs: wheel-build-pylibcudf
466462
secrets: inherit

.github/workflows/test.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ jobs:
7070
date: ${{ inputs.date }}
7171
sha: ${{ inputs.sha }}
7272
script: "ci/test_python_cudf.sh"
73-
# Skip failing tests on RTX PRO 6000 (Blackwell). xref: https://github.com/rapidsai/cudf/issues/21357
74-
matrix_filter: map(select(.GPU != "rtxpro6000"))
7573
conda-python-other-tests:
7674
# Tests for dask_cudf, custreamz, cudf_kafka are separated for CI parallelism
7775
secrets: inherit
@@ -115,8 +113,6 @@ jobs:
115113
date: ${{ inputs.date }}
116114
sha: ${{ inputs.sha }}
117115
script: ci/test_wheel_cudf.sh
118-
# Skip failing tests on RTX PRO 6000 (Blackwell). xref: https://github.com/rapidsai/cudf/issues/21357
119-
matrix_filter: map(select(.GPU != "rtxpro6000"))
120116
wheel-tests-dask-cudf:
121117
secrets: inherit
122118
uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@main

cpp/src/jit/parser.cpp

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -10,6 +10,7 @@
1010
#include <thrust/iterator/counting_iterator.h>
1111

1212
#include <algorithm>
13+
#include <format>
1314
#include <numeric>
1415
#include <set>
1516
#include <string>
@@ -101,7 +102,7 @@ std::string ptx_parser::register_type_to_cpp_type(std::string const& register_ty
101102
else if (register_type == ".f16x2")
102103
return "half2";
103104
else if (register_type == ".b64" || register_type == ".s64" || register_type == ".u64")
104-
return "long int";
105+
return "long long int";
105106
else if (register_type == ".f32")
106107
return "float";
107108
else if (register_type == ".f64")
@@ -110,6 +111,24 @@ std::string ptx_parser::register_type_to_cpp_type(std::string const& register_ty
110111
return "x_cpptype";
111112
}
112113

114+
int32_t get_register_size(std::string_view register_type)
115+
{
116+
if (register_type == ".b8" || register_type == ".s8" || register_type == ".u8") {
117+
return 8;
118+
} else if (register_type == ".b16" || register_type == ".s16" || register_type == ".u16" ||
119+
register_type == ".f16") {
120+
return 16;
121+
} else if (register_type == ".b32" || register_type == ".s32" || register_type == ".u32" ||
122+
register_type == ".f32" || register_type == ".f16x2") {
123+
return 32;
124+
} else if (register_type == ".b64" || register_type == ".s64" || register_type == ".u64" ||
125+
register_type == ".f64") {
126+
return 64;
127+
} else {
128+
CUDF_FAIL("Unknown register type: " + std::string(register_type));
129+
}
130+
}
131+
113132
std::string ptx_parser::parse_instruction(std::string const& src)
114133
{
115134
// I am assuming for an instruction statement the starting phrase is an
@@ -127,10 +146,8 @@ std::string ptx_parser::parse_instruction(std::string const& src)
127146
bool is_instruction = true;
128147
bool is_pragma_instruction = false;
129148
bool is_param_loading_instruction = false;
130-
std::string constraint;
131149
std::string register_type;
132150
bool blank = true;
133-
std::string cpp_typename;
134151
while (stop < length) {
135152
while (start < length && (is_white(src[start]) || src[start] == ',' || src[start] == '{' ||
136153
src[start] == '}')) { // running to the first non-white character.
@@ -169,8 +186,7 @@ std::string ptx_parser::parse_instruction(std::string const& src)
169186
is_param_loading_instruction = true;
170187
register_type = std::string(piece, 8, stop - 8);
171188
// This is the ld.param sentence
172-
cpp_typename = register_type_to_cpp_type(register_type);
173-
if (cpp_typename == "int" || cpp_typename == "short int" || cpp_typename == "char") {
189+
if (get_register_size(register_type) < 64) {
174190
// The trick to support `ld` statement whose destination reg. wider than
175191
// the instruction width, e.g.
176192
//
@@ -186,7 +202,6 @@ std::string ptx_parser::parse_instruction(std::string const& src)
186202
} else {
187203
output += " mov" + register_type;
188204
}
189-
constraint = register_type_to_contraint(register_type);
190205
} else if (piece.find("st.param") != std::string::npos) {
191206
return "asm volatile (\"" + output +
192207
"/** *** The way we parse the CUDA PTX assumes the function returns the return "
@@ -207,11 +222,20 @@ std::string ptx_parser::parse_instruction(std::string const& src)
207222
if (piece_count == 2 && is_param_loading_instruction) {
208223
// This is the source of the parameter loading instruction
209224
output += " %0";
210-
if (cpp_typename == "char") {
211-
suffix = ": : \"" + constraint + "\"( static_cast<short>(" +
212-
remove_nonalphanumeric(piece) + "))";
225+
226+
auto constraint = register_type_to_contraint(register_type);
227+
auto cpp_typename = register_type_to_cpp_type(register_type);
228+
229+
// there's no 8-bit register size constraint in PTX, so we widen the argument to 16-bits
230+
if (get_register_size(register_type) == 8) {
231+
suffix = std::format(
232+
": : \"{}\"( static_cast<short>( {} ) )", constraint, remove_nonalphanumeric(piece));
213233
} else {
214-
suffix = ": : \"" + constraint + "\"(" + remove_nonalphanumeric(piece) + ")";
234+
// normal case
235+
suffix = std::format(": : \"{}\"( *reinterpret_cast<{} *>(&{}) )",
236+
constraint,
237+
cpp_typename,
238+
remove_nonalphanumeric(piece));
215239
}
216240
} else if (is_pragma_instruction) {
217241
// quote any string

cpp/tests/jit/parse_ptx_function.cpp

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -165,7 +165,7 @@ TEST_F(JitParseTest, PTXWithPragmaWithSpaces)
165165

166166
std::string expected = R"(
167167
__device__ __inline__ void LongKernel(){
168-
asm volatile ("{"); asm volatile (" $L__BB0_58: cvt.u32.u32 _ %0, [_ rd419 + 80];": : "r"(r1419));
168+
asm volatile ("{"); asm volatile (" $L__BB0_58: cvt.u32.u32 _ %0, [_ rd419 + 80];": : "r"( *reinterpret_cast<int *>(&r1419) ));
169169
/** $L__BB0_58:
170170
ld.param.u32 % r1419, [% rd419 + 80] */
171171
@@ -181,7 +181,7 @@ __device__ __inline__ void LongKernel(){
181181
asm volatile (" @ _ p394 bra $L__BB0_380;");
182182
/** @ % p394 bra $L__BB0_380 */
183183
184-
asm volatile (" cvt.u8.u8 _ %0, [_ rd419 + 208];": : "h"( static_cast<short>(rs1369)));
184+
asm volatile (" cvt.u8.u8 _ %0, [_ rd419 + 208];": : "h"( static_cast<short>( rs1369 ) ));
185185
/** ld.param.u8 % rs1369, [% rd419 + 208] */
186186
187187
asm volatile (" setp.eq.s16 _ p395, _ rs1369, 0;");
@@ -190,13 +190,13 @@ __device__ __inline__ void LongKernel(){
190190
asm volatile (" selp.b32 _ r1422, _ r1925, 0, _ p395;");
191191
/** selp.b32 % r1422, % r1925, 0, % p395 */
192192
193-
asm volatile (" cvt.u32.u32 _ %0, [_ rd419 + 112];": : "r"(r1423));
193+
asm volatile (" cvt.u32.u32 _ %0, [_ rd419 + 112];": : "r"( *reinterpret_cast<int *>(&r1423) ));
194194
/** ld.param.u32 % r1423, [% rd419 + 112] */
195195
196196
asm volatile (" add.s32 _ r427, _ r1422, _ r1423;");
197197
/** add.s32 % r427, % r1422, % r1423 */
198198
199-
asm volatile (" mov.u64 _ %0, [_ rd419 + 120];": : "l"(rd1249));
199+
asm volatile (" mov.u64 _ %0, [_ rd419 + 120];": : "l"( *reinterpret_cast<long long int *>(&rd1249) ));
200200
/** ld.param.u64 % rd1249, [% rd419 + 120] */
201201
202202
asm volatile (" cvta.to.global.u64 _ rd1250, _ rd1249;");
@@ -227,4 +227,126 @@ __device__ __inline__ void LongKernel(){
227227
EXPECT_TRUE(ptx_equal(cuda_source, expected));
228228
}
229229

230+
// test that an ld.param instruction that doesn't contain the exact semantic type
231+
// is still parsed correctly. This is important because NVVM IR doesn't always use the exact
232+
// semantic type in the ld.param instruction. For example, it may use `ld.param.u8` to load a `char`
233+
// parameter, which is semantically correct but doesn't match the expected `ld.param.s8`.
234+
TEST_F(JitParseTest, PTXWithUntypedLdParam)
235+
{
236+
//
237+
// Generated from NUMBA using:
238+
//
239+
// """py
240+
//
241+
// from numba import cuda, float64
242+
// from numba.cuda import compile_ptx_for_current_device
243+
//
244+
// @cuda.jit(device=True)
245+
// def op(a, b, c):
246+
// return (a + b) *c
247+
//
248+
// ptx, _ = cuda.compile_ptx_for_current_device(op, (float64, float64, float64), device=True,
249+
// abi="c")
250+
//
251+
// print(ptx)
252+
//
253+
// """
254+
//
255+
auto raw_ptx = R"***(
256+
//
257+
// Generated by NVIDIA NVVM Compiler
258+
//
259+
// Compiler Build ID: CL-37061995
260+
// Cuda compilation tools, release 13.1, V13.1.115
261+
// Based on NVVM 7.0.1
262+
//
263+
264+
265+
.visible .func (.param .b32 func_retval0) mad(
266+
.param .b64 param_0,
267+
.param .b64 param_1,
268+
.param .b64 param_2,
269+
.param .b64 param_3
270+
)
271+
{
272+
.reg .b32 %r<2>;
273+
.reg .f64 %fd<6>;
274+
.reg .b64 %rd<2>;
275+
276+
277+
ld.param.u64 %rd1, [param_0];
278+
ld.param.f64 %fd1, [param_1];
279+
ld.param.f64 %fd2, [param_2];
280+
ld.param.f64 %fd3, [param_3];
281+
add.f64 %fd4, %fd1, %fd2;
282+
mul.f64 %fd5, %fd4, %fd3;
283+
st.f64 [%rd1], %fd5;
284+
mov.u32 %r1, 0;
285+
st.param.b32 [func_retval0+0], %r1;
286+
ret;
287+
288+
}
289+
)***";
290+
291+
std::string expected = R"***(
292+
__device__ __inline__ void GENERIC_OP(double * param_0,
293+
double param_1,
294+
double param_2,
295+
double param_3){
296+
297+
asm volatile ("{"); asm volatile (" .reg .b32 _r<2>;");
298+
/** .reg .b32 %r<2> */
299+
300+
asm volatile (" .reg .f64 _fd<6>;");
301+
/** .reg .f64 %fd<6> */
302+
303+
asm volatile (" .reg .b64 _rd<2>;");
304+
/** .reg .b64 %rd<2> */
305+
306+
asm volatile (" mov.u64 _rd1, %0;": : "l"( *reinterpret_cast<long long int *>(&param_0) ));
307+
/** ld.param.u64 %rd1, [param_0] */
308+
309+
asm volatile (" mov.f64 _fd1, %0;": : "d"( *reinterpret_cast<double *>(&param_1) ));
310+
/** ld.param.f64 %fd1, [param_1] */
311+
312+
asm volatile (" mov.f64 _fd2, %0;": : "d"( *reinterpret_cast<double *>(&param_2) ));
313+
/** ld.param.f64 %fd2, [param_2] */
314+
315+
asm volatile (" mov.f64 _fd3, %0;": : "d"( *reinterpret_cast<double *>(&param_3) ));
316+
/** ld.param.f64 %fd3, [param_3] */
317+
318+
asm volatile (" add.f64 _fd4, _fd1, _fd2;");
319+
/** add.f64 %fd4, %fd1, %fd2 */
320+
321+
asm volatile (" mul.f64 _fd5, _fd4, _fd3;");
322+
/** mul.f64 %fd5, %fd4, %fd3 */
323+
324+
asm volatile (" st.f64 [_rd1], _fd5;");
325+
/** st.f64 [%rd1], %fd5 */
326+
327+
asm volatile (" mov.u32 _r1, 0;");
328+
/** mov.u32 %r1, 0 */
329+
330+
asm volatile (" /** *** The way we parse the CUDA PTX assumes the function returns the return value through the first function parameter. Thus the `st.param.***` instructions are not processed. *** */");
331+
/** st.param.b32 [func_retval0+0], %r1 */
332+
333+
asm volatile (" bra RETTGT;");
334+
/** ret */
335+
336+
337+
338+
asm volatile ("RETTGT:}");}
339+
)***";
340+
341+
std::string cuda_source = cudf::jit::parse_single_function_ptx(raw_ptx,
342+
"GENERIC_OP",
343+
{
344+
{0, "double *"},
345+
{1, "double"},
346+
{2, "double"},
347+
{3, "double"},
348+
});
349+
EXPECT_TRUE(ptx_equal(cuda_source, expected));
350+
}
351+
230352
CUDF_TEST_PROGRAM_MAIN()

0 commit comments

Comments
 (0)