Skip to content

Commit 3665e41

Browse files
committed
minor fixes
1 parent 45d0b24 commit 3665e41

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

cuda_core/tests/test_module.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def get_saxpy_kernel_cubin(init_cuda):
6767
"cubin",
6868
name_expressions=("saxpy<float>", "saxpy<double>"),
6969
)
70-
7170
# run in single precision
7271
return mod.get_kernel("saxpy<float>"), mod
7372

@@ -154,9 +153,9 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx):
154153
def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path):
155154
ptx, mod = get_saxpy_kernel_ptx
156155
sym_map = mod._sym_map
157-
assert isinstance(ptx, str)
156+
assert isinstance(ptx, bytes)
158157
ptx_file = tmp_path / "test.ptx"
159-
ptx_file.write_text(ptx)
158+
ptx_file.write_bytes(ptx)
160159
mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map)
161160
assert mod_obj.code == str(ptx_file)
162161
assert mod_obj._code_type == "ptx"
@@ -187,8 +186,8 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path):
187186
mod.get_kernel("saxpy<double>") # force loading
188187

189188

190-
def test_object_code_handle(get_saxpy_object_code):
191-
mod = get_saxpy_object_code
189+
def test_object_code_handle(get_saxpy_kernel_cubin):
190+
_, mod = get_saxpy_kernel_cubin
192191
assert mod.handle is not None
193192

194193

0 commit comments

Comments
 (0)