2323 from backports .strenum import StrEnum
2424
2525# Each entry is:
26- # (mapping_dict, cuda_binding_enum , binding_unmapped, str_enum_unmapped)
26+ # (cuda_binding_enum, str_enum, mapping_dict , binding_unmapped, str_enum_unmapped)
2727#
28+ # cuda_binding_enum: the cuda.bindings NVML enum class
29+ # str_enum: the cuda_core StrEnum wrapper class, or None if the mapping does
30+ # not use a StrEnum (e.g. maps to plain str or tuple)
31+ # mapping_dict: the dict that maps between the two enum types
2832# binding_unmapped: cuda_binding_enum member names intentionally absent from the mapping
2933# (sentinels, deprecated aliases, etc.)
3034# str_enum_unmapped: StrEnum member names intentionally absent from the mapping
3135# (fallback sentinels returned by the wrapper via .get(value, default))
3236#
3337# The first test checks that every member of cuda_binding_enum whose name is NOT
3438# in binding_unmapped appears as either a key or a value of the mapping dict,
35- # and conversely that every StrEnum member not in str_enum_unmapped also
39+ # and conversely that every str_enum member not in str_enum_unmapped also
3640# appears.
3741_CASES = []
3842
4549 _CASES .extend (
4650 [
4751 (
48- _device ._ADDRESSING_MODE_MAPPING ,
4952 nvml .DeviceAddressingModeType ,
53+ _device .AddressingMode ,
54+ _device ._ADDRESSING_MODE_MAPPING ,
5055 # NONE means "no special addressing mode is active"; not a valid target
5156 {"DEVICE_ADDRESSING_MODE_NONE" },
5257 set (),
5358 ),
5459 (
55- _device ._BRAND_TYPE_MAPPING ,
5660 nvml .BrandType ,
61+ None , # maps to plain str, not a StrEnum
62+ _device ._BRAND_TYPE_MAPPING ,
5763 # COUNT is a sentinel, not a real brand
5864 {"BRAND_COUNT" },
59- set (), # maps to str, not StrEnum; reverse check is a no-op
65+ set (),
6066 ),
6167 (
62- _device ._GPU_P2P_STATUS_MAPPING ,
6368 nvml .GpuP2PStatus ,
69+ _device .GpuP2PStatus ,
70+ _device ._GPU_P2P_STATUS_MAPPING ,
6471 # Both the typo'd (SUPPORED) and corrected (SUPPORTED) spellings
6572 # share the same integer value; the mapping covers both via aliases
6673 set (),
6774 set (),
6875 ),
6976 (
70- _device ._CLOCKS_EVENT_REASONS_MAPPING ,
7177 nvml .ClocksEventReasons ,
78+ _device .ClocksEventReasons ,
79+ _device ._CLOCKS_EVENT_REASONS_MAPPING ,
7280 set (),
7381 set (),
7482 ),
7583 (
76- _device ._EVENT_TYPE_MAPPING ,
7784 nvml .EventType ,
85+ _device .EventType ,
86+ _device ._EVENT_TYPE_MAPPING ,
7887 set (),
7988 set (),
8089 ),
8190 (
82- _device ._FAN_CONTROL_POLICY_MAPPING ,
8391 nvml .FanControlPolicy ,
92+ _device .FanControlPolicy ,
93+ _device ._FAN_CONTROL_POLICY_MAPPING ,
8494 set (),
8595 set (),
8696 ),
8797 (
88- _device ._COOLER_CONTROL_MAPPING ,
8998 nvml .CoolerControl ,
99+ _device .CoolerControl ,
100+ _device ._COOLER_CONTROL_MAPPING ,
90101 # NONE means no signal; COUNT is a sentinel
91102 {"THERMAL_COOLER_SIGNAL_NONE" , "THERMAL_COOLER_SIGNAL_COUNT" },
92103 set (),
93104 ),
94105 (
95- _device ._COOLER_TARGET_MAPPING ,
96106 nvml .CoolerTarget ,
107+ _device .CoolerTarget ,
108+ _device ._COOLER_TARGET_MAPPING ,
97109 # GPU_RELATED is a composite bitmask (GPU | MEMORY | POWER_SUPPLY);
98110 # the wrapper expands it into individual targets instead of mapping
99111 # it as a single entry
100112 {"THERMAL_GPU_RELATED" },
101113 set (),
102114 ),
103115 (
104- _device ._THERMAL_CONTROLLER_MAPPING ,
105116 nvml .ThermalController ,
117+ _device .ThermalController ,
118+ _device ._THERMAL_CONTROLLER_MAPPING ,
106119 # NONE and UNKNOWN are both handled by the .get() fallback that
107120 # returns ThermalController.UNKNOWN when the value is not in the mapping
108121 {"NONE" , "UNKNOWN" },
109122 # UNKNOWN is the default returned by .get() for unrecognised controllers
110123 {"UNKNOWN" },
111124 ),
112125 (
113- _device ._THERMAL_TARGET_MAPPING ,
114126 nvml .ThermalTarget ,
127+ _device .ThermalTarget ,
128+ _device ._THERMAL_TARGET_MAPPING ,
115129 # UNKNOWN is a fallback sentinel; handled by .get()
116130 {"UNKNOWN" },
117131 set (),
118132 ),
119133 (
120- _device ._NVLINK_VERSION_MAPPING ,
121134 nvml .NvlinkVersion ,
135+ None , # maps to tuple, not a StrEnum
136+ _device ._NVLINK_VERSION_MAPPING ,
122137 # VERSION_INVALID is a sentinel for "no NvLink present"
123138 {"VERSION_INVALID" },
124- set (), # maps to tuple, not StrEnum; reverse check is a no-op
139+ set (),
125140 ),
126141 (
127- _system_events ._SYSTEM_EVENT_TYPE_MAPPING ,
128142 nvml .SystemEventType ,
143+ _system_events .SystemEventType ,
144+ _system_events ._SYSTEM_EVENT_TYPE_MAPPING ,
129145 set (),
130146 set (),
131147 ),
132148 (
133- _device ._AFFINITY_SCOPE_MAPPING ,
134149 nvml .AffinityScope ,
150+ _device .AffinityScope ,
151+ _device ._AFFINITY_SCOPE_MAPPING ,
135152 set (),
136153 set (),
137154 ),
138155 (
139- _device ._GPU_P2P_CAPS_INDEX_MAPPING ,
140156 nvml .GpuP2PCapsIndex ,
157+ _device .GpuP2PCapsIndex ,
158+ _device ._GPU_P2P_CAPS_INDEX_MAPPING ,
141159 # UNKNOWN is returned by the driver when an index is unrecognised;
142160 # it is not a capability the caller selects
143161 {"P2P_CAPS_INDEX_UNKNOWN" },
144162 # UNKNOWN is a driver-side fallback, not a caller-selectable index
145163 {"UNKNOWN" },
146164 ),
147165 (
148- _device ._GPU_TOPOLOGY_LEVEL_MAPPING ,
149166 nvml .GpuTopologyLevel ,
167+ _device .GpuTopologyLevel ,
168+ _device ._GPU_TOPOLOGY_LEVEL_MAPPING ,
150169 set (),
151170 set (),
152171 ),
153172 (
154- _device ._CLOCK_ID_MAPPING ,
155173 nvml .ClockId ,
174+ _device .ClockId ,
175+ _device ._CLOCK_ID_MAPPING ,
156176 # APP_CLOCK_TARGET and APP_CLOCK_DEFAULT are deprecated; COUNT is a sentinel
157177 {"APP_CLOCK_TARGET" , "APP_CLOCK_DEFAULT" , "COUNT" },
158178 set (),
159179 ),
160180 (
161- _device ._CLOCK_TYPE_MAPPING ,
162181 nvml .ClockType ,
182+ _device .ClockType ,
183+ _device ._CLOCK_TYPE_MAPPING ,
163184 # COUNT is a sentinel
164185 {"CLOCK_COUNT" },
165186 set (),
166187 ),
167188 (
168- _device ._INFOROM_OBJECT_MAPPING ,
169189 nvml .InforomObject ,
190+ _device .InforomObject ,
191+ _device ._INFOROM_OBJECT_MAPPING ,
170192 # COUNT is a sentinel
171193 {"INFOROM_COUNT" },
172194 set (),
173195 ),
174196 (
175- _device ._TEMPERATURE_THRESHOLD_MAPPING ,
176197 nvml .TemperatureThresholds ,
198+ _device .TemperatureThresholds ,
199+ _device ._TEMPERATURE_THRESHOLD_MAPPING ,
177200 # COUNT is a sentinel
178201 {"TEMPERATURE_THRESHOLD_COUNT" },
179202 set (),
189212
190213
191214@pytest .mark .parametrize (
192- "mapping, binding , binding_unmapped, str_enum_unmapped" ,
215+ "binding, str_enum, mapping , binding_unmapped, str_enum_unmapped" ,
193216 _CASES ,
194- ids = [x [1 ].__name__ for x in _CASES ],
217+ ids = [x [0 ].__name__ for x in _CASES ],
195218)
196- def test_wrapper_covers_all_binding_members (mapping , binding , binding_unmapped , str_enum_unmapped ):
219+ def test_wrapper_covers_all_binding_members (binding , str_enum , mapping , binding_unmapped , str_enum_unmapped ):
197220 """Every cuda_binding enum member must appear in the wrapper mapping (or be allow-listed).
198221
199222 Also checks the reverse: every StrEnum wrapper member must appear in the
@@ -207,13 +230,11 @@ def test_wrapper_covers_all_binding_members(mapping, binding, binding_unmapped,
207230 assert not missing , f"{ binding .__name__ } has members not covered by the wrapper mapping: { missing } "
208231
209232 # Reverse check: every StrEnum member must also appear in the mapping.
210- str_enum_instances = [m for m in (* mapping .keys (), * mapping .values ()) if isinstance (m , StrEnum )]
211- if str_enum_instances :
212- str_enum_cls = type (str_enum_instances [0 ])
213- required_str = set (str_enum_cls .__members__ ) - str_enum_unmapped
214- covered_str = {m .name for m in str_enum_instances }
233+ if str_enum is not None :
234+ required_str = set (str_enum .__members__ ) - str_enum_unmapped
235+ covered_str = {m .name for m in (* mapping .keys (), * mapping .values ()) if isinstance (m , str_enum )}
215236 missing_str = required_str - covered_str
216- assert not missing_str , f"{ str_enum_cls .__name__ } has members not covered by the wrapper mapping: { missing_str } "
237+ assert not missing_str , f"{ str_enum .__name__ } has members not covered by the wrapper mapping: { missing_str } "
217238
218239
219240def test_all_str_enums_in_cases ():
@@ -245,12 +266,7 @@ def discover_str_enums() -> set[type]:
245266 found .add (obj )
246267 return found
247268
248- covered = {
249- type (obj )
250- for mapping , _ , _ , _ in _CASES
251- for obj in (* mapping .keys (), * mapping .values ())
252- if isinstance (obj , StrEnum )
253- }
269+ covered = {x [1 ] for x in _CASES if x [1 ] is not None }
254270 uncovered = discover_str_enums () - covered - _UNBOUND_STR_ENUMS
255271 assert not uncovered , (
256272 f"StrEnum subclasses in cuda.core not covered by _CASES: "
0 commit comments