@@ -60,20 +60,20 @@ def test_object_code_init_disabled():
6060
6161
@pytest.fixture(scope="function")
def get_saxpy_kernel_cubin(init_cuda):
    """Compile SAXPY_KERNEL to a cubin and return ``(kernel, object_code)``.

    The kernel handed back is the ``saxpy<float>`` instantiation; the object
    code is compiled with both the ``saxpy<float>`` and ``saxpy<double>``
    name expressions.
    """
    program = Program(SAXPY_KERNEL, code_type="c++")
    instantiations = ("saxpy<float>", "saxpy<double>")
    object_code = program.compile("cubin", name_expressions=instantiations)
    # single-precision kernel, paired with the module it came from
    single_precision_kernel = object_code.get_kernel("saxpy<float>")
    return single_precision_kernel, object_code
7372
7473
7574@pytest .fixture (scope = "function" )
7675def get_saxpy_kernel_ptx (init_cuda ):
76+ # prepare program
7777 prog = Program (SAXPY_KERNEL , code_type = "c++" )
7878 mod = prog .compile (
7979 "ptx" ,
@@ -84,12 +84,10 @@ def get_saxpy_kernel_ptx(init_cuda):
8484
8585
@pytest.fixture(scope="function")
def get_saxpy_kernel_ltoir(init_cuda):
    """Compile SAXPY_KERNEL to LTOIR with link-time optimization enabled.

    Returns the object produced by ``Program.compile("ltoir", ...)`` for the
    ``saxpy<float>`` and ``saxpy<double>`` name expressions.
    """
    lto_options = ProgramOptions(link_time_optimization=True)
    program = Program(SAXPY_KERNEL, code_type="c++", options=lto_options)
    return program.compile("ltoir", name_expressions=("saxpy<float>", "saxpy<double>"))
9492
9593
@@ -129,8 +127,8 @@ def test_get_kernel(init_cuda):
129127 ("cluster_scheduling_policy_preference" , int ),
130128 ],
131129)
132- def test_read_only_kernel_attributes (get_saxpy_kernel , attr , expected_type ):
133- kernel , _ = get_saxpy_kernel
130+ def test_read_only_kernel_attributes (get_saxpy_kernel_cubin , attr , expected_type ):
131+ kernel , _ = get_saxpy_kernel_cubin
134132 method = getattr (kernel .attributes , attr )
135133 # get the value without providing a device ordinal
136134 value = method ()
@@ -142,16 +140,6 @@ def test_read_only_kernel_attributes(get_saxpy_kernel, attr, expected_type):
142140 assert isinstance (value , expected_type ), f"Expected { attr } to be of type { expected_type } , but got { type (value )} "
143141
144142
145- def test_object_code_load_cubin (get_saxpy_kernel ):
146- _ , mod = get_saxpy_kernel
147- cubin = mod ._module
148- sym_map = mod ._sym_map
149- assert isinstance (cubin , bytes )
150- mod = ObjectCode .from_cubin (cubin , symbol_mapping = sym_map )
151- assert mod .code == cubin
152- mod .get_kernel ("saxpy<double>" ) # force loading
153-
154-
155143def test_object_code_load_ptx (get_saxpy_kernel_ptx ):
156144 ptx , mod = get_saxpy_kernel_ptx
157145 sym_map = mod ._sym_map
@@ -162,8 +150,32 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx):
162150 mod_obj .get_kernel ("saxpy<double>" ) # force loading
163151
164152
165- def test_object_code_load_cubin_from_file (get_saxpy_kernel , tmp_path ):
166- _ , mod = get_saxpy_kernel
def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path):
    """Load PTX by file path through ObjectCode.from_ptx and check what it records."""
    ptx_bytes, compiled = get_saxpy_kernel_ptx
    symbols = compiled._sym_map
    assert isinstance(ptx_bytes, bytes)

    # write the PTX to disk so it can be loaded by path rather than by bytes
    ptx_file = tmp_path / "test.ptx"
    ptx_file.write_bytes(ptx_bytes)
    ptx_path = str(ptx_file)

    loaded = ObjectCode.from_ptx(ptx_path, symbol_mapping=symbols)
    assert loaded.code == ptx_path
    assert loaded.code_type == "ptx"

    if Program._can_load_generated_ptx():
        loaded.get_kernel("saxpy<double>")  # force loading
    else:
        pytest.skip("PTX version too new for current driver")
165+
166+
def test_object_code_load_cubin(get_saxpy_kernel_cubin):
    """Rebuild an ObjectCode from the compiled cubin bytes and load a kernel from it."""
    _kernel, compiled = get_saxpy_kernel_cubin
    cubin_bytes, symbols = compiled._module, compiled._sym_map
    assert isinstance(cubin_bytes, bytes)

    reloaded = ObjectCode.from_cubin(cubin_bytes, symbol_mapping=symbols)
    assert reloaded.code == cubin_bytes
    reloaded.get_kernel("saxpy<double>")  # force loading
175+
176+
177+ def test_object_code_load_cubin_from_file (get_saxpy_kernel_cubin , tmp_path ):
178+ _ , mod = get_saxpy_kernel_cubin
167179 cubin = mod ._module
168180 sym_map = mod ._sym_map
169181 assert isinstance (cubin , bytes )
@@ -174,13 +186,42 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
174186 mod .get_kernel ("saxpy<double>" ) # force loading
175187
176188
def test_object_code_handle(get_saxpy_kernel_cubin):
    """The object code returned by the cubin fixture exposes a non-None handle."""
    _kernel, object_code = get_saxpy_kernel_cubin
    handle = object_code.handle
    assert handle is not None
180192
181193
182- def test_saxpy_arguments (get_saxpy_kernel , cuda12_4_prerequisite_check ):
183- krn , _ = get_saxpy_kernel
def test_object_code_load_ltoir(get_saxpy_kernel_ltoir):
    """Wrap in-memory LTOIR with ObjectCode.from_ltoir and check it cannot yield kernels."""
    compiled = get_saxpy_kernel_ltoir
    ltoir_bytes, symbols = compiled._module, compiled._sym_map
    assert isinstance(ltoir_bytes, bytes)

    loaded = ObjectCode.from_ltoir(ltoir_bytes, symbol_mapping=symbols)
    assert loaded.code == ltoir_bytes
    assert loaded.code_type == "ltoir"
    # ltoir doesn't support kernel retrieval directly as it's used for linking
    assert loaded._handle is None

    # get_kernel is expected to reject the unsupported code type
    expected_message = r'Unsupported code type "ltoir"'
    with pytest.raises(RuntimeError, match=expected_message):
        loaded.get_kernel("saxpy<float>")
207+
208+
def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path):
    """Load LTOIR by file path through ObjectCode.from_ltoir and check what it records."""
    compiled = get_saxpy_kernel_ltoir
    ltoir_bytes, symbols = compiled._module, compiled._sym_map
    assert isinstance(ltoir_bytes, bytes)

    # write the LTOIR to disk so it can be referenced by path
    ltoir_file = tmp_path / "test.ltoir"
    ltoir_file.write_bytes(ltoir_bytes)
    ltoir_path = str(ltoir_file)

    loaded = ObjectCode.from_ltoir(ltoir_path, symbol_mapping=symbols)
    assert loaded.code == ltoir_path
    assert loaded.code_type == "ltoir"
    # ltoir doesn't support kernel retrieval directly as it's used for linking
    assert loaded._handle is None
221+
222+
223+ def test_saxpy_arguments (get_saxpy_kernel_cubin , cuda12_4_prerequisite_check ):
224+ krn , _ = get_saxpy_kernel_cubin
184225
185226 if cuda12_4_prerequisite_check :
186227 assert krn .num_arguments == 5
@@ -258,8 +299,8 @@ def test_num_args_error_handling(deinit_all_contexts_function, cuda12_4_prerequi
258299
259300@pytest .mark .parametrize ("block_size" , [32 , 64 , 96 , 120 , 128 , 256 ])
260301@pytest .mark .parametrize ("smem_size_per_block" , [0 , 32 , 4096 ])
261- def test_occupancy_max_active_block_per_multiprocessor (get_saxpy_kernel , block_size , smem_size_per_block ):
262- kernel , _ = get_saxpy_kernel
302+ def test_occupancy_max_active_block_per_multiprocessor (get_saxpy_kernel_cubin , block_size , smem_size_per_block ):
303+ kernel , _ = get_saxpy_kernel_cubin
263304 dev_props = Device ().properties
264305 assert block_size <= dev_props .max_threads_per_block
265306 assert smem_size_per_block <= dev_props .max_shared_memory_per_block
@@ -275,9 +316,9 @@ def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_s
275316
276317@pytest .mark .parametrize ("block_size_limit" , [32 , 64 , 96 , 120 , 128 , 256 , 0 ])
277318@pytest .mark .parametrize ("smem_size_per_block" , [0 , 32 , 4096 ])
278- def test_occupancy_max_potential_block_size_constant (get_saxpy_kernel , block_size_limit , smem_size_per_block ):
319+ def test_occupancy_max_potential_block_size_constant (get_saxpy_kernel_cubin , block_size_limit , smem_size_per_block ):
279320 """Tests use case when shared memory needed is independent on the block size"""
280- kernel , _ = get_saxpy_kernel
321+ kernel , _ = get_saxpy_kernel_cubin
281322 dev_props = Device ().properties
282323 assert block_size_limit <= dev_props .max_threads_per_block
283324 assert smem_size_per_block <= dev_props .max_shared_memory_per_block
@@ -302,9 +343,9 @@ def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel, block_siz
302343
303344@pytest .mark .skipif (numba is None , reason = "Test requires numba to be installed" )
304345@pytest .mark .parametrize ("block_size_limit" , [32 , 64 , 96 , 120 , 128 , 277 , 0 ])
305- def test_occupancy_max_potential_block_size_b2dsize (get_saxpy_kernel , block_size_limit ):
346+ def test_occupancy_max_potential_block_size_b2dsize (get_saxpy_kernel_cubin , block_size_limit ):
306347 """Tests use case when shared memory needed depends on the block size"""
307- kernel , _ = get_saxpy_kernel
348+ kernel , _ = get_saxpy_kernel_cubin
308349
309350 def shared_memory_needed (block_size : numba .intc ) -> numba .size_t :
310351 "Size of dynamic shared memory needed by kernel of this block size"
@@ -329,8 +370,8 @@ def shared_memory_needed(block_size: numba.intc) -> numba.size_t:
329370
330371
331372@pytest .mark .parametrize ("num_blocks_per_sm, block_size" , [(4 , 32 ), (2 , 64 ), (2 , 96 ), (3 , 120 ), (2 , 128 ), (1 , 256 )])
332- def test_occupancy_available_dynamic_shared_memory_per_block (get_saxpy_kernel , num_blocks_per_sm , block_size ):
333- kernel , _ = get_saxpy_kernel
373+ def test_occupancy_available_dynamic_shared_memory_per_block (get_saxpy_kernel_cubin , num_blocks_per_sm , block_size ):
374+ kernel , _ = get_saxpy_kernel_cubin
334375 dev_props = Device ().properties
335376 assert block_size <= dev_props .max_threads_per_block
336377 assert num_blocks_per_sm * block_size <= dev_props .max_threads_per_multiprocessor
@@ -340,8 +381,8 @@ def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, n
340381
341382
342383@pytest .mark .parametrize ("cluster" , [None , 2 ])
343- def test_occupancy_max_active_clusters (get_saxpy_kernel , cluster ):
344- kernel , _ = get_saxpy_kernel
384+ def test_occupancy_max_active_clusters (get_saxpy_kernel_cubin , cluster ):
385+ kernel , _ = get_saxpy_kernel_cubin
345386 dev = Device ()
346387 if dev .compute_capability < (9 , 0 ):
347388 pytest .skip ("Device with compute capability 90 or higher is required for cluster support" )
@@ -355,8 +396,8 @@ def test_occupancy_max_active_clusters(get_saxpy_kernel, cluster):
355396 assert max_active_clusters >= 0
356397
357398
358- def test_occupancy_max_potential_cluster_size (get_saxpy_kernel ):
359- kernel , _ = get_saxpy_kernel
399+ def test_occupancy_max_potential_cluster_size (get_saxpy_kernel_cubin ):
400+ kernel , _ = get_saxpy_kernel_cubin
360401 dev = Device ()
361402 if dev .compute_capability < (9 , 0 ):
362403 pytest .skip ("Device with compute capability 90 or higher is required for cluster support" )
@@ -370,11 +411,11 @@ def test_occupancy_max_potential_cluster_size(get_saxpy_kernel):
370411 assert max_potential_cluster_size >= 0
371412
372413
def test_module_serialization_roundtrip(get_saxpy_kernel_cubin):
    """Pickle and unpickle the cubin object code, then compare code, symbol map and code type."""
    _kernel, original = get_saxpy_kernel_cubin
    payload = pickle.dumps(original)
    restored = pickle.loads(payload)  # noqa: S403, S301

    assert isinstance(restored, ObjectCode)
    assert original.code == restored.code
    assert original._sym_map == restored._sym_map
    assert original.code_type == restored.code_type
0 commit comments