Skip to content
13 changes: 7 additions & 6 deletions cuda_core/cuda/core/_linker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ from cuda.core._utils.cuda_utils import (
driver,
is_sequence,
)
from cuda.core.typing import CompilerBackend

ctypedef const char* const_char_ptr
ctypedef void* void_ptr
Expand Down Expand Up @@ -70,12 +71,12 @@ cdef class Linker:
def __init__(self, *object_codes: ObjectCode, options: "LinkerOptions" = None):
Linker_init(self, object_codes, options)

def link(self, target_type) -> ObjectCode:
def link(self, target_type: ObjectCodeFormat | str) -> ObjectCode:
"""Link the provided object codes into a single output of the specified target type.

Parameters
----------
target_type : str
target_type : ObjectCodeFormat | str
The type of the target output. Must be either "cubin" or "ptx".

Returns
Expand All @@ -88,7 +89,7 @@ cdef class Linker:
Ensure that input object codes were compiled with appropriate
flags for linking (e.g., relocatable device code enabled).
"""
return Linker_link(self, target_type)
return Linker_link(self, str(target_type))

def get_error_log(self) -> str:
"""Get the error log generated by the linker.
Expand Down Expand Up @@ -168,9 +169,9 @@ cdef class Linker:
return as_py(self._culink_handle)

@property
def backend(self) -> str:
"""Return this Linker instance's underlying backend."""
return "nvJitLink" if self._use_nvjitlink else "driver"
def backend(self) -> CompilerBackend:
"""Return this Linker instance's underlying :class:`CompilerBackend`."""
return CompilerBackend.NVJITLINK if self._use_nvjitlink else CompilerBackend.DRIVER


# =============================================================================
Expand Down
12 changes: 7 additions & 5 deletions cuda_core/cuda/core/_memory/_managed_memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ from dataclasses import dataclass
import threading
import warnings

from cuda.core.typing import ManagedMemoryLocationType

__all__ = ['ManagedMemoryResource', 'ManagedMemoryResourceOptions']


Expand All @@ -30,7 +32,7 @@ cdef class ManagedMemoryResourceOptions:
meaning depends on ``preferred_location_type``.
(Default to ``None``)

preferred_location_type : ``"device"`` | ``"host"`` | ``"host_numa"`` | None, optional
preferred_location_type : ManagedMemoryLocationType | str | None, optional
Controls how ``preferred_location`` is interpreted.

When set to ``None`` (the default), legacy behavior is used:
Expand All @@ -54,7 +56,7 @@ cdef class ManagedMemoryResourceOptions:
(Default to ``None``)
"""
preferred_location: int | None = None
preferred_location_type: str | None = None
preferred_location_type: ManagedMemoryLocationType | str | None = None


cdef class ManagedMemoryResource(_MemPool):
Expand Down Expand Up @@ -97,7 +99,7 @@ cdef class ManagedMemoryResource(_MemPool):
return -1

@property
def preferred_location(self) -> tuple | None:
def preferred_location(self) -> tuple[ManagedMemoryLocationType, int | None] | None:
"""The preferred location for managed memory allocations.

Returns ``None`` if no preferred location is set (driver decides),
Expand All @@ -108,8 +110,8 @@ cdef class ManagedMemoryResource(_MemPool):
if self._pref_loc_type is None:
return None
if self._pref_loc_type == "host":
return ("host", None)
return (self._pref_loc_type, self._pref_loc_id)
return (ManagedMemoryLocationType.HOST, None)
return (ManagedMemoryLocationType(self._pref_loc_type), self._pref_loc_id)

@property
def is_device_accessible(self) -> bool:
Expand Down
70 changes: 37 additions & 33 deletions cuda_core/cuda/core/_memory/_virtual_memory_resource.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Iterable, Literal
from typing import TYPE_CHECKING, Iterable

if TYPE_CHECKING:
from cuda.core._stream import Stream
Expand All @@ -21,15 +21,16 @@
_check_driver_error as raise_if_driver_error,
)
from cuda.core._utils.version import binding_version
from cuda.core.typing import (
VirtualMemoryAccessType,
VirtualMemoryAllocationType,
VirtualMemoryGranularityType,
VirtualMemoryHandleType,
VirtualMemoryLocationType,
)

__all__ = ["VirtualMemoryResource", "VirtualMemoryResourceOptions"]

VirtualMemoryHandleTypeT = Literal["posix_fd", "generic", "win32_kmt", "fabric"] | None
VirtualMemoryLocationTypeT = Literal["device", "host", "host_numa", "host_numa_current"]
VirtualMemoryGranularityT = Literal["minimum", "recommended"]
VirtualMemoryAccessTypeT = Literal["rw", "r"] | None
VirtualMemoryAllocationTypeT = Literal["pinned", "managed"]


@dataclass
class VirtualMemoryResourceOptions:
Expand All @@ -38,69 +39,72 @@ class VirtualMemoryResourceOptions:

Attributes
----------
allocation_type: :obj:`~_memory.VirtualMemoryAllocationTypeT`
allocation_type: :obj:`~_memory.VirtualMemoryAllocationType` | str
Controls the type of allocation.
location_type: :obj:`~_memory.VirtualMemoryLocationTypeT`
location_type: :obj:`~_memory.VirtualMemoryLocationType` | str
Controls the location of the allocation.
handle_type: :obj:`~_memory.VirtualMemoryHandleTypeT`
handle_type: :obj:`~_memory.VirtualMemoryHandleType` | None | str
Export handle type for the physical allocation. Use
``"posix_fd"`` on Linux if you plan to
import/export the allocation (required for cuMemRetainAllocationHandle).
Use `None` if you don't need an exportable handle.
gpu_direct_rdma: bool
Hint that the allocation should be GDR-capable (if supported).
granularity: :obj:`~_memory.VirtualMemoryGranularityT`
granularity: :obj:`~_memory.VirtualMemoryGranularityType` | str
Controls granularity query and size rounding.
addr_hint: int
A (optional) virtual address hint to try to reserve at. Setting it to 0 lets the CUDA driver decide.
addr_align: int
Alignment for the VA reservation. If `None`, use the queried granularity.
peers: Iterable[int]
Extra device IDs that should be granted access in addition to ``device``.
self_access: :obj:`~_memory.VirtualMemoryAccessTypeT`
self_access: :obj:`~_memory.VirtualMemoryAccessType` | None | str
Access flags for the owning device.
peer_access: :obj:`~_memory.VirtualMemoryAccessTypeT`
peer_access: :obj:`~_memory.VirtualMemoryAccessType` | None | str
Access flags for peers.
"""

# Human-friendly strings; normalized in __post_init__
# Per the Attributes section of the class docstring, each option may be given
# either as the enum member or as its plain-string spelling. handle_type,
# self_access and peer_access additionally accept None (the handle-type and
# access-flag lookup tables both define a None entry).
allocation_type: VirtualMemoryAllocationTypeT = "pinned"
location_type: VirtualMemoryLocationTypeT = "device"
handle_type: VirtualMemoryHandleTypeT = "posix_fd"
granularity: VirtualMemoryGranularityT = "recommended"
allocation_type: VirtualMemoryAllocationType | str = VirtualMemoryAllocationType.PINNED
location_type: VirtualMemoryLocationType | str = VirtualMemoryLocationType.DEVICE
handle_type: VirtualMemoryHandleType | str | None = VirtualMemoryHandleType.POSIX_FD
granularity: VirtualMemoryGranularityType | str = VirtualMemoryGranularityType.RECOMMENDED
gpu_direct_rdma: bool = False
addr_hint: int | None = 0
addr_align: int | None = None
peers: Iterable[int] = field(default_factory=tuple)
self_access: VirtualMemoryAccessTypeT = "rw"
peer_access: VirtualMemoryAccessTypeT = "rw"
self_access: VirtualMemoryAccessType | str | None = VirtualMemoryAccessType.READ_WRITE
peer_access: VirtualMemoryAccessType | str | None = VirtualMemoryAccessType.READ_WRITE

_a = driver.CUmemAccess_flags
_access_flags = {"rw": _a.CU_MEM_ACCESS_FLAGS_PROT_READWRITE, "r": _a.CU_MEM_ACCESS_FLAGS_PROT_READ, None: 0} # noqa: RUF012
_access_flags = { # noqa: RUF012
VirtualMemoryAccessType.READ_WRITE: _a.CU_MEM_ACCESS_FLAGS_PROT_READWRITE,
VirtualMemoryAccessType.READ: _a.CU_MEM_ACCESS_FLAGS_PROT_READ,
None: 0,
}
_h = driver.CUmemAllocationHandleType
_handle_types = { # noqa: RUF012
None: _h.CU_MEM_HANDLE_TYPE_NONE,
"posix_fd": _h.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
"win32_kmt": _h.CU_MEM_HANDLE_TYPE_WIN32_KMT,
"fabric": _h.CU_MEM_HANDLE_TYPE_FABRIC,
VirtualMemoryHandleType.POSIX_FD: _h.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
VirtualMemoryHandleType.WIN32_KMT: _h.CU_MEM_HANDLE_TYPE_WIN32_KMT,
VirtualMemoryHandleType.FABRIC: _h.CU_MEM_HANDLE_TYPE_FABRIC,
}
_g = driver.CUmemAllocationGranularity_flags
_granularity = { # noqa: RUF012
"recommended": _g.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
"minimum": _g.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
VirtualMemoryGranularityType.RECOMMENDED: _g.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
VirtualMemoryGranularityType.MINIMUM: _g.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
}
_l = driver.CUmemLocationType
_location_type = { # noqa: RUF012
"device": _l.CU_MEM_LOCATION_TYPE_DEVICE,
"host": _l.CU_MEM_LOCATION_TYPE_HOST,
"host_numa": _l.CU_MEM_LOCATION_TYPE_HOST_NUMA,
"host_numa_current": _l.CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
VirtualMemoryLocationType.DEVICE: _l.CU_MEM_LOCATION_TYPE_DEVICE,
VirtualMemoryLocationType.HOST: _l.CU_MEM_LOCATION_TYPE_HOST,
VirtualMemoryLocationType.HOST_NUMA: _l.CU_MEM_LOCATION_TYPE_HOST_NUMA,
VirtualMemoryLocationType.HOST_NUMA_CURRENT: _l.CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
}
_t = driver.CUmemAllocationType
# CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not
_allocation_type = {"pinned": _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012
_allocation_type = {VirtualMemoryAllocationType.PINNED: _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012
if binding_version() >= (13, 0, 0):
_allocation_type["managed"] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED
_allocation_type[VirtualMemoryAllocationType.MANAGED] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED

@staticmethod
def _access_to_flags(spec: str):
Expand Down
19 changes: 10 additions & 9 deletions cuda_core/cuda/core/_module.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from cuda.core._device import Device
from cuda.core._launch_config cimport LaunchConfig
from cuda.core._launch_config import LaunchConfig
from cuda.core._stream cimport Stream
from cuda.core._program import ObjectCodeFormat
from cuda.core._resource_handles cimport (
LibraryHandle,
KernelHandle,
Expand Down Expand Up @@ -569,7 +570,7 @@ cdef class Kernel:

CodeTypeT = bytes | bytearray | str

cdef tuple _supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library")
cdef tuple _supported_code_type = tuple(ObjectCodeFormat.__members__.values())

cdef class ObjectCode:
"""Represent a compiled program to be loaded onto the device.
Expand Down Expand Up @@ -599,7 +600,7 @@ cdef class ObjectCode:
# _h_library is assigned during _lazy_load_module
self._h_library = LibraryHandle() # Empty handle

self._code_type = code_type
self._code_type = str(code_type)
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping
self._name = name if name else ""
Expand Down Expand Up @@ -629,7 +630,7 @@ cdef class ObjectCode:
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "cubin", name=name, symbol_mapping=symbol_mapping)
return ObjectCode._init(module, ObjectCodeFormat.CUBIN, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_ptx(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
Expand All @@ -647,7 +648,7 @@ cdef class ObjectCode:
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "ptx", name=name, symbol_mapping=symbol_mapping)
return ObjectCode._init(module, ObjectCodeFormat.PTX, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_ltoir(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
Expand All @@ -665,7 +666,7 @@ cdef class ObjectCode:
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "ltoir", name=name, symbol_mapping=symbol_mapping)
return ObjectCode._init(module, ObjectCodeFormat.LTOIR, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_fatbin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
Expand All @@ -683,7 +684,7 @@ cdef class ObjectCode:
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "fatbin", name=name, symbol_mapping=symbol_mapping)
return ObjectCode._init(module, ObjectCodeFormat.FATBIN, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_object(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
Expand All @@ -701,7 +702,7 @@ cdef class ObjectCode:
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "object", name=name, symbol_mapping=symbol_mapping)
return ObjectCode._init(module, ObjectCodeFormat.OBJECT, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_library(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
Expand All @@ -719,7 +720,7 @@ cdef class ObjectCode:
should be mapped to the mangled names before trying to retrieve
them (default to no mappings).
"""
return ObjectCode._init(module, "library", name=name, symbol_mapping=symbol_mapping)
return ObjectCode._init(module, ObjectCodeFormat.LIBRARY, name=name, symbol_mapping=symbol_mapping)

# TODO: do we want to unload in a finalizer? Probably not..

Expand Down Expand Up @@ -758,7 +759,7 @@ cdef class ObjectCode:

"""
self._lazy_load_module()
supported_code_types = ("cubin", "ptx", "fatbin")
supported_code_types = (ObjectCodeFormat.CUBIN, ObjectCodeFormat.PTX, ObjectCodeFormat.FATBIN)
if self._code_type not in supported_code_types:
raise RuntimeError(f'Unsupported code type "{self._code_type}" ({supported_code_types=})')
try:
Expand Down
Loading
Loading