@@ -67,7 +67,6 @@ def get_saxpy_kernel_cubin(init_cuda):
6767 "cubin" ,
6868 name_expressions = ("saxpy<float>" , "saxpy<double>" ),
6969 )
70-
7170 # run in single precision
7271 return mod .get_kernel ("saxpy<float>" ), mod
7372
@@ -154,9 +153,9 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx):
154153def test_object_code_load_ptx_from_file (get_saxpy_kernel_ptx , tmp_path ):
155154 ptx , mod = get_saxpy_kernel_ptx
156155 sym_map = mod ._sym_map
157- assert isinstance (ptx , str )
156+ assert isinstance (ptx , bytes )
158157 ptx_file = tmp_path / "test.ptx"
159- ptx_file .write_text (ptx )
158+ ptx_file .write_bytes (ptx )
160159 mod_obj = ObjectCode .from_ptx (str (ptx_file ), symbol_mapping = sym_map )
161160 assert mod_obj .code == str (ptx_file )
162161 assert mod_obj ._code_type == "ptx"
@@ -187,8 +186,8 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path):
187186 mod .get_kernel ("saxpy<double>" ) # force loading
188187
189188
190- def test_object_code_handle (get_saxpy_object_code ):
191- mod = get_saxpy_object_code
189+ def test_object_code_handle (get_saxpy_kernel_cubin ):
190+ _ , mod = get_saxpy_kernel_cubin
192191 assert mod .handle is not None
193192
194193
0 commit comments