Skip to content

Commit e14f0ff

Browse files
committed
Fix #1995: Use StrEnum for enum-like strings
1 parent 64e2e6a commit e14f0ff

16 files changed

Lines changed: 231 additions & 110 deletions

cuda_core/cuda/core/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,17 @@ def _import_versioned_module():
4141
DeviceMemoryResourceOptions,
4242
GraphMemoryResource,
4343
LegacyPinnedMemoryResource,
44+
ManagedMemoryLocationType,
4445
ManagedMemoryResource,
4546
ManagedMemoryResourceOptions,
4647
MemoryResource,
4748
PinnedMemoryResource,
4849
PinnedMemoryResourceOptions,
50+
VirtualMemoryAccessType,
51+
VirtualMemoryAllocationType,
52+
VirtualMemoryGranularityType,
53+
VirtualMemoryHandleType,
54+
VirtualMemoryLocationType,
4955
VirtualMemoryResource,
5056
VirtualMemoryResourceOptions,
5157
)
@@ -54,7 +60,7 @@ def _import_versioned_module():
5460
args_viewable_as_strided_memory,
5561
)
5662
from cuda.core._module import Kernel, ObjectCode
57-
from cuda.core._program import Program, ProgramOptions
63+
from cuda.core._program import CodeType, CompilerBackend, PchStatus, Program, ProgramOptions, SourceType
5864
from cuda.core._stream import (
5965
LEGACY_DEFAULT_STREAM,
6066
PER_THREAD_DEFAULT_STREAM,
@@ -70,4 +76,5 @@ def _import_versioned_module():
7076
GraphCondition,
7177
GraphDebugPrintOptions,
7278
GraphDefinition,
79+
GraphMemoryType,
7380
)

cuda_core/cuda/core/_linker.pyx

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ cdef class Linker:
7070
def __init__(self, *object_codes: ObjectCode, options: "LinkerOptions" = None):
7171
Linker_init(self, object_codes, options)
7272

73-
def link(self, target_type) -> ObjectCode:
73+
def link(self, target_type: CodeType | str) -> ObjectCode:
7474
"""Link the provided object codes into a single output of the specified target type.
7575

7676
Parameters
7777
----------
78-
target_type : str
78+
target_type : CodeType | str
7979
The type of the target output. Must be either "cubin" or "ptx".
8080

8181
Returns
@@ -88,7 +88,7 @@ cdef class Linker:
8888
Ensure that input object codes were compiled with appropriate
8989
flags for linking (e.g., relocatable device code enabled).
9090
"""
91-
return Linker_link(self, target_type)
91+
return Linker_link(self, str(target_type))
9292

9393
def get_error_log(self) -> str:
9494
"""Get the error log generated by the linker.
@@ -168,9 +168,10 @@ cdef class Linker:
168168
return as_py(self._culink_handle)
169169

170170
@property
171-
def backend(self) -> str:
172-
"""Return this Linker instance's underlying backend."""
173-
return "nvJitLink" if self._use_nvjitlink else "driver"
171+
def backend(self) -> "CompilerBackend":
172+
"""Return this Linker instance's underlying :class:`CompilerBackend`."""
173+
from ._program import CompilerBackend
174+
return CompilerBackend.NVJITLINK if self._use_nvjitlink else CompilerBackend.DRIVER
174175

175176

176177
# =============================================================================

cuda_core/cuda/core/_memory/_managed_memory_resource.pyx

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@ from cuda.core._utils.cuda_utils cimport check_or_create_options # no-cython-li
1313
from cuda.core._utils.cuda_utils import CUDAError # no-cython-lint
1414

1515
from dataclasses import dataclass
16+
try:
17+
from enum import StrEnum
18+
except ImportError:
19+
from backports.strenum import StrEnum
1620
import threading
1721
import warnings
1822

19-
__all__ = ['ManagedMemoryResource', 'ManagedMemoryResourceOptions']
23+
__all__ = ['ManagedMemoryResource', 'ManagedMemoryResourceOptions', 'ManagedMemoryLocationType']
24+
25+
26+
class ManagedMemoryLocationType(StrEnum):
27+
DEVICE = "device"
28+
HOST = "host"
29+
HOST_NUMA = "host_numa"
2030

2131

2232
@dataclass
@@ -30,7 +40,7 @@ cdef class ManagedMemoryResourceOptions:
3040
meaning depends on ``preferred_location_type``.
3141
(Default to ``None``)
3242
33-
preferred_location_type : ``"device"`` | ``"host"`` | ``"host_numa"`` | None, optional
43+
preferred_location_type : ManagedMemoryLocationType | str | None, optional
3444
Controls how ``preferred_location`` is interpreted.
3545
3646
When set to ``None`` (the default), legacy behavior is used:
@@ -54,7 +64,7 @@ cdef class ManagedMemoryResourceOptions:
5464
(Default to ``None``)
5565
"""
5666
preferred_location: int | None = None
57-
preferred_location_type: str | None = None
67+
preferred_location_type: ManagedMemoryLocationType | None = None
5868

5969

6070
cdef class ManagedMemoryResource(_MemPool):
@@ -97,7 +107,7 @@ cdef class ManagedMemoryResource(_MemPool):
97107
return -1
98108

99109
@property
100-
def preferred_location(self) -> tuple | None:
110+
def preferred_location(self) -> tuple[MemoryLocationType, int | None] | None:
101111
"""The preferred location for managed memory allocations.
102112

103113
Returns ``None`` if no preferred location is set (driver decides),

cuda_core/cuda/core/_memory/_virtual_memory_resource.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

55
from __future__ import annotations
66

77
from dataclasses import dataclass, field
8-
from typing import TYPE_CHECKING, Iterable, Literal
8+
from typing import TYPE_CHECKING, Iterable
9+
10+
try:
11+
from enum import StrEnum
12+
except ImportError:
13+
from backports.strenum import StrEnum
914

1015
if TYPE_CHECKING:
1116
from cuda.core._stream import Stream
@@ -22,13 +27,44 @@
2227
)
2328
from cuda.core._utils.version import binding_version
2429

25-
__all__ = ["VirtualMemoryResource", "VirtualMemoryResourceOptions"]
30+
__all__ = [
31+
"VirtualMemoryAccessType",
32+
"VirtualMemoryAllocationType",
33+
"VirtualMemoryGranularityType",
34+
"VirtualMemoryHandleType",
35+
"VirtualMemoryLocationType",
36+
"VirtualMemoryResource",
37+
"VirtualMemoryResourceOptions",
38+
]
39+
40+
41+
class VirtualMemoryHandleType(StrEnum):
42+
POSIX_FD = "posix_fd"
43+
GENERIC = "generic"
44+
WIN32_KMT = "win32_kmt"
45+
FABRIC = "fabric"
46+
47+
48+
class VirtualMemoryLocationType(StrEnum):
49+
DEVICE = "device"
50+
HOST = "host"
51+
HOST_NUMA = "host_numa"
52+
HOST_NUMA_CURRENT = "host_numa_current"
53+
54+
55+
class VirtualMemoryGranularityType(StrEnum):
56+
MINIMUM = "minimum"
57+
RECOMMENDED = "recommended"
58+
59+
60+
class VirtualMemoryAccessType(StrEnum):
61+
READ_WRITE = "rw"
62+
READ = "r"
63+
2664

27-
VirtualMemoryHandleTypeT = Literal["posix_fd", "generic", "win32_kmt", "fabric"] | None
28-
VirtualMemoryLocationTypeT = Literal["device", "host", "host_numa", "host_numa_current"]
29-
VirtualMemoryGranularityT = Literal["minimum", "recommended"]
30-
VirtualMemoryAccessTypeT = Literal["rw", "r"] | None
31-
VirtualMemoryAllocationTypeT = Literal["pinned", "managed"]
65+
class VirtualMemoryAllocationType(StrEnum):
66+
PINNED = "pinned"
67+
MANAGED = "managed"
3268

3369

3470
@dataclass
@@ -38,69 +74,68 @@ class VirtualMemoryResourceOptions:
3874
3975
Attributes
4076
----------
41-
allocation_type: :obj:`~_memory.VirtualMemoryAllocationTypeT`
77+
allocation_type: :obj:`~_memory.VirtualMemoryAllocationType` | str
4278
Controls the type of allocation.
43-
location_type: :obj:`~_memory.VirtualMemoryLocationTypeT`
79+
location_type: :obj:`~_memory.VirtualMemoryLocationType` | str
4480
Controls the location of the allocation.
45-
handle_type: :obj:`~_memory.VirtualMemoryHandleTypeT`
81+
handle_type: :obj:`~_memory.VirtualMemoryHandleType` | str
4682
Export handle type for the physical allocation. Use
4783
``"posix_fd"`` on Linux if you plan to
4884
import/export the allocation (required for cuMemRetainAllocationHandle).
4985
Use `None` if you don't need an exportable handle.
5086
gpu_direct_rdma: bool
5187
Hint that the allocation should be GDR-capable (if supported).
52-
granularity: :obj:`~_memory.VirtualMemoryGranularityT`
88+
granularity: :obj:`~_memory.VirtualMemoryGranularityType`
5389
Controls granularity query and size rounding.
5490
addr_hint: int
5591
A (optional) virtual address hint to try to reserve at. Setting it to 0 lets the CUDA driver decide.
5692
addr_align: int
5793
Alignment for the VA reservation. If `None`, use the queried granularity.
5894
peers: Iterable[int]
5995
Extra device IDs that should be granted access in addition to ``device``.
60-
self_access: :obj:`~_memory.VirtualMemoryAccessTypeT`
96+
self_access: :obj:`~_memory.VirtualMemoryAccessType` | str
6197
Access flags for the owning device.
62-
peer_access: :obj:`~_memory.VirtualMemoryAccessTypeT`
98+
peer_access: :obj:`~_memory.VirtualMemoryAccessType` | str
6399
Access flags for peers.
64100
"""
65101

66-
# Human-friendly strings; normalized in __post_init__
67-
allocation_type: VirtualMemoryAllocationTypeT = "pinned"
68-
location_type: VirtualMemoryLocationTypeT = "device"
69-
handle_type: VirtualMemoryHandleTypeT = "posix_fd"
70-
granularity: VirtualMemoryGranularityT = "recommended"
102+
allocation_type: VirtualMemoryAllocationType = VirtualMemoryAllocationType.PINNED
103+
location_type: VirtualMemoryLocationType = VirtualMemoryLocationType.DEVICE
104+
handle_type: VirtualMemoryHandleType = VirtualMemoryHandleType.POSIX_FD
105+
granularity: VirtualMemoryGranularityType = VirtualMemoryGranularityType.RECOMMENDED
71106
gpu_direct_rdma: bool = False
72107
addr_hint: int | None = 0
73108
addr_align: int | None = None
74109
peers: Iterable[int] = field(default_factory=tuple)
75-
self_access: VirtualMemoryAccessTypeT = "rw"
76-
peer_access: VirtualMemoryAccessTypeT = "rw"
110+
self_access: VirtualMemoryAccessType = VirtualMemoryAccessType.READ_WRITE
111+
peer_access: VirtualMemoryAccessType = VirtualMemoryAccessType.READ_WRITE
77112

78113
_a = driver.CUmemAccess_flags
79114
_access_flags = {"rw": _a.CU_MEM_ACCESS_FLAGS_PROT_READWRITE, "r": _a.CU_MEM_ACCESS_FLAGS_PROT_READ, None: 0} # noqa: RUF012
80115
_h = driver.CUmemAllocationHandleType
81116
_handle_types = { # noqa: RUF012
82117
None: _h.CU_MEM_HANDLE_TYPE_NONE,
83-
"posix_fd": _h.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
84-
"win32_kmt": _h.CU_MEM_HANDLE_TYPE_WIN32_KMT,
85-
"fabric": _h.CU_MEM_HANDLE_TYPE_FABRIC,
118+
VirtualMemoryHandleType.POSIX_FD: _h.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
119+
VirtualMemoryHandleType.WIN32_KMT: _h.CU_MEM_HANDLE_TYPE_WIN32_KMT,
120+
VirtualMemoryHandleType.FABRIC: _h.CU_MEM_HANDLE_TYPE_FABRIC,
86121
}
87122
_g = driver.CUmemAllocationGranularity_flags
88123
_granularity = { # noqa: RUF012
89-
"recommended": _g.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
90-
"minimum": _g.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
124+
VirtualMemoryGranularityType.RECOMMENDED: _g.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
125+
VirtualMemoryGranularityType.MINIMUM: _g.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
91126
}
92127
_l = driver.CUmemLocationType
93128
_location_type = { # noqa: RUF012
94-
"device": _l.CU_MEM_LOCATION_TYPE_DEVICE,
95-
"host": _l.CU_MEM_LOCATION_TYPE_HOST,
96-
"host_numa": _l.CU_MEM_LOCATION_TYPE_HOST_NUMA,
97-
"host_numa_current": _l.CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
129+
VirtualMemoryLocationType.DEVICE: _l.CU_MEM_LOCATION_TYPE_DEVICE,
130+
VirtualMemoryLocationType.HOST: _l.CU_MEM_LOCATION_TYPE_HOST,
131+
VirtualMemoryLocationType.HOST_NUMA: _l.CU_MEM_LOCATION_TYPE_HOST_NUMA,
132+
VirtualMemoryLocationType.HOST_NUMA_CURRENT: _l.CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
98133
}
99134
_t = driver.CUmemAllocationType
100135
# CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not
101-
_allocation_type = {"pinned": _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012
136+
_allocation_type = {VirtualMemoryAllocationType.PINNED: _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012
102137
if binding_version() >= (13, 0, 0):
103-
_allocation_type["managed"] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED
138+
_allocation_type[VirtualMemoryAllocationType.MANAGED] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED
104139

105140
@staticmethod
106141
def _access_to_flags(spec: str):

cuda_core/cuda/core/_module.pyx

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from cuda.core._device import Device
1212
from cuda.core._launch_config cimport LaunchConfig
1313
from cuda.core._launch_config import LaunchConfig
1414
from cuda.core._stream cimport Stream
15+
from cuda.core._program import CodeType
1516
from cuda.core._resource_handles cimport (
1617
LibraryHandle,
1718
KernelHandle,
@@ -569,7 +570,7 @@ cdef class Kernel:
569570

570571
CodeTypeT = bytes | bytearray | str
571572

572-
cdef tuple _supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library")
573+
cdef tuple _supported_code_type = tuple(CodeType.__members__.values())
573574

574575
cdef class ObjectCode:
575576
"""Represent a compiled program to be loaded onto the device.
@@ -599,7 +600,7 @@ cdef class ObjectCode:
599600
# _h_library is assigned during _lazy_load_module
600601
self._h_library = LibraryHandle() # Empty handle
601602

602-
self._code_type = code_type
603+
self._code_type = str(code_type)
603604
self._module = module
604605
self._sym_map = {} if symbol_mapping is None else symbol_mapping
605606
self._name = name if name else ""
@@ -629,7 +630,7 @@ cdef class ObjectCode:
629630
should be mapped to the mangled names before trying to retrieve
630631
them (default to no mappings).
631632
"""
632-
return ObjectCode._init(module, "cubin", name=name, symbol_mapping=symbol_mapping)
633+
return ObjectCode._init(module, CodeType.CUBIN, name=name, symbol_mapping=symbol_mapping)
633634

634635
@staticmethod
635636
def from_ptx(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
@@ -647,7 +648,7 @@ cdef class ObjectCode:
647648
should be mapped to the mangled names before trying to retrieve
648649
them (default to no mappings).
649650
"""
650-
return ObjectCode._init(module, "ptx", name=name, symbol_mapping=symbol_mapping)
651+
return ObjectCode._init(module, CodeType.PTX, name=name, symbol_mapping=symbol_mapping)
651652

652653
@staticmethod
653654
def from_ltoir(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
@@ -665,7 +666,7 @@ cdef class ObjectCode:
665666
should be mapped to the mangled names before trying to retrieve
666667
them (default to no mappings).
667668
"""
668-
return ObjectCode._init(module, "ltoir", name=name, symbol_mapping=symbol_mapping)
669+
return ObjectCode._init(module, CodeType.LTOIR, name=name, symbol_mapping=symbol_mapping)
669670

670671
@staticmethod
671672
def from_fatbin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
@@ -683,7 +684,7 @@ cdef class ObjectCode:
683684
should be mapped to the mangled names before trying to retrieve
684685
them (default to no mappings).
685686
"""
686-
return ObjectCode._init(module, "fatbin", name=name, symbol_mapping=symbol_mapping)
687+
return ObjectCode._init(module, CodeType.FATBIN, name=name, symbol_mapping=symbol_mapping)
687688

688689
@staticmethod
689690
def from_object(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
@@ -701,7 +702,7 @@ cdef class ObjectCode:
701702
should be mapped to the mangled names before trying to retrieve
702703
them (default to no mappings).
703704
"""
704-
return ObjectCode._init(module, "object", name=name, symbol_mapping=symbol_mapping)
705+
return ObjectCode._init(module, CodeType.OBJECT, name=name, symbol_mapping=symbol_mapping)
705706

706707
@staticmethod
707708
def from_library(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
@@ -719,7 +720,7 @@ cdef class ObjectCode:
719720
should be mapped to the mangled names before trying to retrieve
720721
them (default to no mappings).
721722
"""
722-
return ObjectCode._init(module, "library", name=name, symbol_mapping=symbol_mapping)
723+
return ObjectCode._init(module, CodeType.LIBRARY, name=name, symbol_mapping=symbol_mapping)
723724

724725
# TODO: do we want to unload in a finalizer? Probably not..
725726

@@ -758,7 +759,7 @@ cdef class ObjectCode:
758759

759760
"""
760761
self._lazy_load_module()
761-
supported_code_types = ("cubin", "ptx", "fatbin")
762+
supported_code_types = (CodeType.CUBIN, CodeType.PTX, CodeType.FATBIN)
762763
if self._code_type not in supported_code_types:
763764
raise RuntimeError(f'Unsupported code type "{self._code_type}" ({supported_code_types=})')
764765
try:

0 commit comments

Comments
 (0)