11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2023-2025 , NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2023-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55
@@ -165,7 +165,7 @@ TEST_F(JitParseTest, PTXWithPragmaWithSpaces)
165165
166166 std::string expected = R"(
167167__device__ __inline__ void LongKernel(){
168- asm volatile ("{"); asm volatile (" $L__BB0_58: cvt.u32.u32 _ %0, [_ rd419 + 80];": : "r"(r1419));
168+ asm volatile ("{"); asm volatile (" $L__BB0_58: cvt.u32.u32 _ %0, [_ rd419 + 80];": : "r"( *reinterpret_cast<int *>(& r1419) ));
169169 /** $L__BB0_58:
170170 ld.param.u32 % r1419, [% rd419 + 80] */
171171
@@ -181,7 +181,7 @@ __device__ __inline__ void LongKernel(){
181181 asm volatile (" @ _ p394 bra $L__BB0_380;");
182182 /** @ % p394 bra $L__BB0_380 */
183183
184- asm volatile (" cvt.u8.u8 _ %0, [_ rd419 + 208];": : "h"( static_cast<short>(rs1369) ));
184+ asm volatile (" cvt.u8.u8 _ %0, [_ rd419 + 208];": : "h"( static_cast<short>( rs1369 ) ));
185185 /** ld.param.u8 % rs1369, [% rd419 + 208] */
186186
187187 asm volatile (" setp.eq.s16 _ p395, _ rs1369, 0;");
@@ -190,13 +190,13 @@ __device__ __inline__ void LongKernel(){
190190 asm volatile (" selp.b32 _ r1422, _ r1925, 0, _ p395;");
191191 /** selp.b32 % r1422, % r1925, 0, % p395 */
192192
193- asm volatile (" cvt.u32.u32 _ %0, [_ rd419 + 112];": : "r"(r1423));
193+ asm volatile (" cvt.u32.u32 _ %0, [_ rd419 + 112];": : "r"( *reinterpret_cast<int *>(& r1423) ));
194194 /** ld.param.u32 % r1423, [% rd419 + 112] */
195195
196196 asm volatile (" add.s32 _ r427, _ r1422, _ r1423;");
197197 /** add.s32 % r427, % r1422, % r1423 */
198198
199- asm volatile (" mov.u64 _ %0, [_ rd419 + 120];": : "l"(rd1249));
199+ asm volatile (" mov.u64 _ %0, [_ rd419 + 120];": : "l"( *reinterpret_cast<long long int *>(& rd1249) ));
200200 /** ld.param.u64 % rd1249, [% rd419 + 120] */
201201
202202 asm volatile (" cvta.to.global.u64 _ rd1250, _ rd1249;");
@@ -227,4 +227,126 @@ __device__ __inline__ void LongKernel(){
227227 EXPECT_TRUE (ptx_equal (cuda_source, expected));
228228}
229229
230+ // test that an ld.param instruction that doesn't contain the exact semantic type
231+ // is still parsed correctly. This is important because NVVM IR doesn't always use the exact
232+ // semantic type in the ld.param instruction. For example, it may use `ld.param.u8` to load a `char`
233+ // parameter, which is semantically correct but doesn't match the expected `ld.param.s8`.
234+ TEST_F (JitParseTest, PTXWithUntypedLdParam)
235+ {
236+ //
237+ // Generated from NUMBA using:
238+ //
239+ // """py
240+ //
241+ // from numba import cuda, float64
242+ // from numba.cuda import compile_ptx_for_current_device
243+ //
244+ // @cuda.jit(device=True)
245+ // def op(a, b, c):
246+ // return (a + b) *c
247+ //
248+ // ptx, _ = cuda.compile_ptx_for_current_device(op, (float64, float64, float64), device=True,
249+ // abi="c")
250+ //
251+ // print(ptx)
252+ //
253+ // """
254+ //
255+ auto raw_ptx = R"***(
256+ //
257+ // Generated by NVIDIA NVVM Compiler
258+ //
259+ // Compiler Build ID: CL-37061995
260+ // Cuda compilation tools, release 13.1, V13.1.115
261+ // Based on NVVM 7.0.1
262+ //
263+
264+
265+ .visible .func (.param .b32 func_retval0) mad(
266+ .param .b64 param_0,
267+ .param .b64 param_1,
268+ .param .b64 param_2,
269+ .param .b64 param_3
270+ )
271+ {
272+ .reg .b32 %r<2>;
273+ .reg .f64 %fd<6>;
274+ .reg .b64 %rd<2>;
275+
276+
277+ ld.param.u64 %rd1, [param_0];
278+ ld.param.f64 %fd1, [param_1];
279+ ld.param.f64 %fd2, [param_2];
280+ ld.param.f64 %fd3, [param_3];
281+ add.f64 %fd4, %fd1, %fd2;
282+ mul.f64 %fd5, %fd4, %fd3;
283+ st.f64 [%rd1], %fd5;
284+ mov.u32 %r1, 0;
285+ st.param.b32 [func_retval0+0], %r1;
286+ ret;
287+
288+ }
289+ )***" ;
290+
291+ std::string expected = R"***(
292+ __device__ __inline__ void GENERIC_OP(double * param_0,
293+ double param_1,
294+ double param_2,
295+ double param_3){
296+
297+ asm volatile ("{"); asm volatile (" .reg .b32 _r<2>;");
298+ /** .reg .b32 %r<2> */
299+
300+ asm volatile (" .reg .f64 _fd<6>;");
301+ /** .reg .f64 %fd<6> */
302+
303+ asm volatile (" .reg .b64 _rd<2>;");
304+ /** .reg .b64 %rd<2> */
305+
306+ asm volatile (" mov.u64 _rd1, %0;": : "l"( *reinterpret_cast<long long int *>(¶m_0) ));
307+ /** ld.param.u64 %rd1, [param_0] */
308+
309+ asm volatile (" mov.f64 _fd1, %0;": : "d"( *reinterpret_cast<double *>(¶m_1) ));
310+ /** ld.param.f64 %fd1, [param_1] */
311+
312+ asm volatile (" mov.f64 _fd2, %0;": : "d"( *reinterpret_cast<double *>(¶m_2) ));
313+ /** ld.param.f64 %fd2, [param_2] */
314+
315+ asm volatile (" mov.f64 _fd3, %0;": : "d"( *reinterpret_cast<double *>(¶m_3) ));
316+ /** ld.param.f64 %fd3, [param_3] */
317+
318+ asm volatile (" add.f64 _fd4, _fd1, _fd2;");
319+ /** add.f64 %fd4, %fd1, %fd2 */
320+
321+ asm volatile (" mul.f64 _fd5, _fd4, _fd3;");
322+ /** mul.f64 %fd5, %fd4, %fd3 */
323+
324+ asm volatile (" st.f64 [_rd1], _fd5;");
325+ /** st.f64 [%rd1], %fd5 */
326+
327+ asm volatile (" mov.u32 _r1, 0;");
328+ /** mov.u32 %r1, 0 */
329+
330+ asm volatile (" /** *** The way we parse the CUDA PTX assumes the function returns the return value through the first function parameter. Thus the `st.param.***` instructions are not processed. *** */");
331+ /** st.param.b32 [func_retval0+0], %r1 */
332+
333+ asm volatile (" bra RETTGT;");
334+ /** ret */
335+
336+
337+
338+ asm volatile ("RETTGT:}");}
339+ )***" ;
340+
341+ std::string cuda_source = cudf::jit::parse_single_function_ptx (raw_ptx,
342+ " GENERIC_OP" ,
343+ {
344+ {0 , " double *" },
345+ {1 , " double" },
346+ {2 , " double" },
347+ {3 , " double" },
348+ });
349+ EXPECT_TRUE (ptx_equal (cuda_source, expected));
350+ }
351+
230352CUDF_TEST_PROGRAM_MAIN ()
0 commit comments