Skip to content

Commit 9ef4821

Browse files
committed
nits
1 parent 459cbdb commit 9ef4821

3 files changed

Lines changed: 9 additions & 12 deletions

File tree

cuda_core/cuda/core/experimental/_program.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
1+
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

@@ -403,6 +403,8 @@ def close(self):
403403
def __init__(self, code, code_type, options: ProgramOptions = None):
404404
self._mnff = Program._MembersNeededForFinalize(self, None)
405405

406+
self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
407+
406408
if code_type not in self._supported_code_type:
407409
raise NotImplementedError
408410

@@ -416,8 +418,6 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
416418
else:
417419
raise NotImplementedError
418420

419-
self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
420-
421421
def close(self):
422422
"""Destroy this program."""
423423
self._mnff.close()

cuda_core/examples/vector_add.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
s = dev.create_stream()
2727

2828
# prepare program
29-
program_options = ProgramOptions(std="c++17", arch="sm_" + "".join(f"{i}" for i in dev.compute_capability))
29+
arch = "".join(f"{i}" for i in dev.compute_capability)
30+
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
3031
prog = Program(code, code_type="c++", options=program_options)
3132
mod = prog.compile("cubin", name_expressions=("vector_add<float>",))
3233

cuda_core/tests/test_linker.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,10 @@ def test_linker_init(compile_ptx_functions, options):
8686

8787

8888
def test_linker_init_invalid_arch(compile_ptx_functions):
89-
if culink_backend:
90-
with pytest.raises(AttributeError):
91-
options = LinkerOptions(arch="99", ptx=True)
92-
Linker(*compile_ptx_functions, options=options)
93-
else:
94-
with pytest.raises(nvjitlink.nvJitLinkError):
95-
options = LinkerOptions(arch="99", ptx=True)
96-
Linker(*compile_ptx_functions, options=options)
89+
err = AttributeError if culink_backend else nvjitlink.nvJitLinkError
90+
with pytest.raises(err):
91+
options = LinkerOptions(arch="99", ptx=True)
92+
Linker(*compile_ptx_functions, options=options)
9793

9894

9995
@pytest.mark.skipif(culink_backend, reason="culink does not support ptx option")

0 commit comments

Comments
 (0)