1- # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22#
33# SPDX-License-Identifier: Apache-2.0
44
55from __future__ import annotations
66
77from dataclasses import dataclass , field
8- from typing import TYPE_CHECKING , Iterable , Literal
8+ from typing import TYPE_CHECKING , Iterable
9+
10+ try :
11+ from enum import StrEnum
12+ except ImportError :
13+ from backports .strenum import StrEnum
914
1015if TYPE_CHECKING :
1116 from cuda .core ._stream import Stream
2227)
2328from cuda .core ._utils .version import binding_version
2429
25- __all__ = ["VirtualMemoryResource" , "VirtualMemoryResourceOptions" ]
30+ __all__ = [
31+ "VirtualMemoryAccessType" ,
32+ "VirtualMemoryAllocationType" ,
33+ "VirtualMemoryGranularityType" ,
34+ "VirtualMemoryHandleType" ,
35+ "VirtualMemoryLocationType" ,
36+ "VirtualMemoryResource" ,
37+ "VirtualMemoryResourceOptions" ,
38+ ]
39+
40+
41+ class VirtualMemoryHandleType (StrEnum ):
42+ POSIX_FD = "posix_fd"
43+ GENERIC = "generic"
44+ WIN32_KMT = "win32_kmt"
45+ FABRIC = "fabric"
46+
47+
48+ class VirtualMemoryLocationType (StrEnum ):
49+ DEVICE = "device"
50+ HOST = "host"
51+ HOST_NUMA = "host_numa"
52+ HOST_NUMA_CURRENT = "host_numa_current"
53+
54+
55+ class VirtualMemoryGranularityType (StrEnum ):
56+ MINIMUM = "minimum"
57+ RECOMMENDED = "recommended"
58+
59+
60+ class VirtualMemoryAccessType (StrEnum ):
61+ READ_WRITE = "rw"
62+ READ = "r"
63+
2664
27- VirtualMemoryHandleTypeT = Literal ["posix_fd" , "generic" , "win32_kmt" , "fabric" ] | None
28- VirtualMemoryLocationTypeT = Literal ["device" , "host" , "host_numa" , "host_numa_current" ]
29- VirtualMemoryGranularityT = Literal ["minimum" , "recommended" ]
30- VirtualMemoryAccessTypeT = Literal ["rw" , "r" ] | None
31- VirtualMemoryAllocationTypeT = Literal ["pinned" , "managed" ]
65+ class VirtualMemoryAllocationType (StrEnum ):
66+ PINNED = "pinned"
67+ MANAGED = "managed"
3268
3369
3470@dataclass
@@ -38,69 +74,68 @@ class VirtualMemoryResourceOptions:
3874
3975 Attributes
4076 ----------
41- allocation_type: :obj:`~_memory.VirtualMemoryAllocationTypeT`
77+ allocation_type: :obj:`~_memory.VirtualMemoryAllocationType` | str
4278 Controls the type of allocation.
43- location_type: :obj:`~_memory.VirtualMemoryLocationTypeT`
79+ location_type: :obj:`~_memory.VirtualMemoryLocationType` | str
4480 Controls the location of the allocation.
45- handle_type: :obj:`~_memory.VirtualMemoryHandleTypeT`
81+ handle_type: :obj:`~_memory.VirtualMemoryHandleType` | str
4682 Export handle type for the physical allocation. Use
4783 ``"posix_fd"`` on Linux if you plan to
4884 import/export the allocation (required for cuMemRetainAllocationHandle).
4985 Use `None` if you don't need an exportable handle.
5086 gpu_direct_rdma: bool
5187 Hint that the allocation should be GDR-capable (if supported).
52- granularity: :obj:`~_memory.VirtualMemoryGranularityT `
88+ granularity: :obj:`~_memory.VirtualMemoryGranularityType `
5389 Controls granularity query and size rounding.
5490 addr_hint: int
5591 A (optional) virtual address hint to try to reserve at. Setting it to 0 lets the CUDA driver decide.
5692 addr_align: int
5793 Alignment for the VA reservation. If `None`, use the queried granularity.
5894 peers: Iterable[int]
5995 Extra device IDs that should be granted access in addition to ``device``.
60- self_access: :obj:`~_memory.VirtualMemoryAccessTypeT`
96+ self_access: :obj:`~_memory.VirtualMemoryAccessType` | str
6197 Access flags for the owning device.
62- peer_access: :obj:`~_memory.VirtualMemoryAccessTypeT`
98+ peer_access: :obj:`~_memory.VirtualMemoryAccessType` | str
6399 Access flags for peers.
64100 """
65101
66- # Human-friendly strings; normalized in __post_init__
67- allocation_type : VirtualMemoryAllocationTypeT = "pinned"
68- location_type : VirtualMemoryLocationTypeT = "device"
69- handle_type : VirtualMemoryHandleTypeT = "posix_fd"
70- granularity : VirtualMemoryGranularityT = "recommended"
102+ allocation_type : VirtualMemoryAllocationType = VirtualMemoryAllocationType .PINNED
103+ location_type : VirtualMemoryLocationType = VirtualMemoryLocationType .DEVICE
104+ handle_type : VirtualMemoryHandleType = VirtualMemoryHandleType .POSIX_FD
105+ granularity : VirtualMemoryGranularityType = VirtualMemoryGranularityType .RECOMMENDED
71106 gpu_direct_rdma : bool = False
72107 addr_hint : int | None = 0
73108 addr_align : int | None = None
74109 peers : Iterable [int ] = field (default_factory = tuple )
75- self_access : VirtualMemoryAccessTypeT = "rw"
76- peer_access : VirtualMemoryAccessTypeT = "rw"
110+ self_access : VirtualMemoryAccessType = VirtualMemoryAccessType . READ_WRITE
111+ peer_access : VirtualMemoryAccessType = VirtualMemoryAccessType . READ_WRITE
77112
78113 _a = driver .CUmemAccess_flags
79114 _access_flags = {"rw" : _a .CU_MEM_ACCESS_FLAGS_PROT_READWRITE , "r" : _a .CU_MEM_ACCESS_FLAGS_PROT_READ , None : 0 } # noqa: RUF012
80115 _h = driver .CUmemAllocationHandleType
81116 _handle_types = { # noqa: RUF012
82117 None : _h .CU_MEM_HANDLE_TYPE_NONE ,
83- "posix_fd" : _h .CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR ,
84- "win32_kmt" : _h .CU_MEM_HANDLE_TYPE_WIN32_KMT ,
85- "fabric" : _h .CU_MEM_HANDLE_TYPE_FABRIC ,
118+ VirtualMemoryHandleType . POSIX_FD : _h .CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR ,
119+ VirtualMemoryHandleType . WIN32_KMT : _h .CU_MEM_HANDLE_TYPE_WIN32_KMT ,
120+ VirtualMemoryHandleType . FABRIC : _h .CU_MEM_HANDLE_TYPE_FABRIC ,
86121 }
87122 _g = driver .CUmemAllocationGranularity_flags
88123 _granularity = { # noqa: RUF012
89- "recommended" : _g .CU_MEM_ALLOC_GRANULARITY_RECOMMENDED ,
90- "minimum" : _g .CU_MEM_ALLOC_GRANULARITY_MINIMUM ,
124+ VirtualMemoryGranularityType . RECOMMENDED : _g .CU_MEM_ALLOC_GRANULARITY_RECOMMENDED ,
125+ VirtualMemoryGranularityType . MINIMUM : _g .CU_MEM_ALLOC_GRANULARITY_MINIMUM ,
91126 }
92127 _l = driver .CUmemLocationType
93128 _location_type = { # noqa: RUF012
94- "device" : _l .CU_MEM_LOCATION_TYPE_DEVICE ,
95- "host" : _l .CU_MEM_LOCATION_TYPE_HOST ,
96- "host_numa" : _l .CU_MEM_LOCATION_TYPE_HOST_NUMA ,
97- "host_numa_current" : _l .CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT ,
129+ VirtualMemoryLocationType . DEVICE : _l .CU_MEM_LOCATION_TYPE_DEVICE ,
130+ VirtualMemoryLocationType . HOST : _l .CU_MEM_LOCATION_TYPE_HOST ,
131+ VirtualMemoryLocationType . HOST_NUMA : _l .CU_MEM_LOCATION_TYPE_HOST_NUMA ,
132+ VirtualMemoryLocationType . HOST_NUMA_CURRENT : _l .CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT ,
98133 }
99134 _t = driver .CUmemAllocationType
100135 # CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not
101- _allocation_type = {"pinned" : _t .CU_MEM_ALLOCATION_TYPE_PINNED } # noqa: RUF012
136+ _allocation_type = {VirtualMemoryAllocationType . PINNED : _t .CU_MEM_ALLOCATION_TYPE_PINNED } # noqa: RUF012
102137 if binding_version () >= (13 , 0 , 0 ):
103- _allocation_type ["managed" ] = _t .CU_MEM_ALLOCATION_TYPE_MANAGED
138+ _allocation_type [VirtualMemoryAllocationType . MANAGED ] = _t .CU_MEM_ALLOCATION_TYPE_MANAGED
104139
105140 @staticmethod
106141 def _access_to_flags (spec : str ):
0 commit comments