Skip to content

Commit 618e8a8

Browse files
committed
Clean up test
1 parent 7a7d74e commit 618e8a8

1 file changed

Lines changed: 54 additions & 38 deletions

File tree

cuda_core/tests/test_enum_coverage.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,20 @@
2323
from backports.strenum import StrEnum
2424

2525
# Each entry is:
26-
# (mapping_dict, cuda_binding_enum, binding_unmapped, str_enum_unmapped)
26+
# (cuda_binding_enum, str_enum, mapping_dict, binding_unmapped, str_enum_unmapped)
2727
#
28+
# cuda_binding_enum: the cuda.bindings NVML enum class
29+
# str_enum: the cuda_core StrEnum wrapper class, or None if the mapping does
30+
# not use a StrEnum (e.g. maps to plain str or tuple)
31+
# mapping_dict: the dict that maps between the two enum types
2832
# binding_unmapped: cuda_binding_enum member names intentionally absent from the mapping
2933
# (sentinels, deprecated aliases, etc.)
3034
# str_enum_unmapped: StrEnum member names intentionally absent from the mapping
3135
# (fallback sentinels returned by the wrapper via .get(value, default))
3236
#
3337
# The first test checks that every member of cuda_binding_enum whose name is NOT
3438
# in binding_unmapped appears as either a key or a value of the mapping dict,
35-
# and conversely that every StrEnum member not in str_enum_unmapped also
39+
# and conversely that every str_enum member not in str_enum_unmapped also
3640
# appears.
3741
_CASES = []
3842

@@ -45,135 +49,154 @@
4549
_CASES.extend(
4650
[
4751
(
48-
_device._ADDRESSING_MODE_MAPPING,
4952
nvml.DeviceAddressingModeType,
53+
_device.AddressingMode,
54+
_device._ADDRESSING_MODE_MAPPING,
5055
# NONE means "no special addressing mode is active"; not a valid target
5156
{"DEVICE_ADDRESSING_MODE_NONE"},
5257
set(),
5358
),
5459
(
55-
_device._BRAND_TYPE_MAPPING,
5660
nvml.BrandType,
61+
None, # maps to plain str, not a StrEnum
62+
_device._BRAND_TYPE_MAPPING,
5763
# COUNT is a sentinel, not a real brand
5864
{"BRAND_COUNT"},
59-
set(), # maps to str, not StrEnum; reverse check is a no-op
65+
set(),
6066
),
6167
(
62-
_device._GPU_P2P_STATUS_MAPPING,
6368
nvml.GpuP2PStatus,
69+
_device.GpuP2PStatus,
70+
_device._GPU_P2P_STATUS_MAPPING,
6471
# Both the typo'd (SUPPORED) and corrected (SUPPORTED) spellings
6572
# share the same integer value; the mapping covers both via aliases
6673
set(),
6774
set(),
6875
),
6976
(
70-
_device._CLOCKS_EVENT_REASONS_MAPPING,
7177
nvml.ClocksEventReasons,
78+
_device.ClocksEventReasons,
79+
_device._CLOCKS_EVENT_REASONS_MAPPING,
7280
set(),
7381
set(),
7482
),
7583
(
76-
_device._EVENT_TYPE_MAPPING,
7784
nvml.EventType,
85+
_device.EventType,
86+
_device._EVENT_TYPE_MAPPING,
7887
set(),
7988
set(),
8089
),
8190
(
82-
_device._FAN_CONTROL_POLICY_MAPPING,
8391
nvml.FanControlPolicy,
92+
_device.FanControlPolicy,
93+
_device._FAN_CONTROL_POLICY_MAPPING,
8494
set(),
8595
set(),
8696
),
8797
(
88-
_device._COOLER_CONTROL_MAPPING,
8998
nvml.CoolerControl,
99+
_device.CoolerControl,
100+
_device._COOLER_CONTROL_MAPPING,
90101
# NONE means no signal; COUNT is a sentinel
91102
{"THERMAL_COOLER_SIGNAL_NONE", "THERMAL_COOLER_SIGNAL_COUNT"},
92103
set(),
93104
),
94105
(
95-
_device._COOLER_TARGET_MAPPING,
96106
nvml.CoolerTarget,
107+
_device.CoolerTarget,
108+
_device._COOLER_TARGET_MAPPING,
97109
# GPU_RELATED is a composite bitmask (GPU | MEMORY | POWER_SUPPLY);
98110
# the wrapper expands it into individual targets instead of mapping
99111
# it as a single entry
100112
{"THERMAL_GPU_RELATED"},
101113
set(),
102114
),
103115
(
104-
_device._THERMAL_CONTROLLER_MAPPING,
105116
nvml.ThermalController,
117+
_device.ThermalController,
118+
_device._THERMAL_CONTROLLER_MAPPING,
106119
# NONE and UNKNOWN are both handled by the .get() fallback that
107120
# returns ThermalController.UNKNOWN when the value is not in the mapping
108121
{"NONE", "UNKNOWN"},
109122
# UNKNOWN is the default returned by .get() for unrecognised controllers
110123
{"UNKNOWN"},
111124
),
112125
(
113-
_device._THERMAL_TARGET_MAPPING,
114126
nvml.ThermalTarget,
127+
_device.ThermalTarget,
128+
_device._THERMAL_TARGET_MAPPING,
115129
# UNKNOWN is a fallback sentinel; handled by .get()
116130
{"UNKNOWN"},
117131
set(),
118132
),
119133
(
120-
_device._NVLINK_VERSION_MAPPING,
121134
nvml.NvlinkVersion,
135+
None, # maps to tuple, not a StrEnum
136+
_device._NVLINK_VERSION_MAPPING,
122137
# VERSION_INVALID is a sentinel for "no NvLink present"
123138
{"VERSION_INVALID"},
124-
set(), # maps to tuple, not StrEnum; reverse check is a no-op
139+
set(),
125140
),
126141
(
127-
_system_events._SYSTEM_EVENT_TYPE_MAPPING,
128142
nvml.SystemEventType,
143+
_system_events.SystemEventType,
144+
_system_events._SYSTEM_EVENT_TYPE_MAPPING,
129145
set(),
130146
set(),
131147
),
132148
(
133-
_device._AFFINITY_SCOPE_MAPPING,
134149
nvml.AffinityScope,
150+
_device.AffinityScope,
151+
_device._AFFINITY_SCOPE_MAPPING,
135152
set(),
136153
set(),
137154
),
138155
(
139-
_device._GPU_P2P_CAPS_INDEX_MAPPING,
140156
nvml.GpuP2PCapsIndex,
157+
_device.GpuP2PCapsIndex,
158+
_device._GPU_P2P_CAPS_INDEX_MAPPING,
141159
# UNKNOWN is returned by the driver when an index is unrecognised;
142160
# it is not a capability the caller selects
143161
{"P2P_CAPS_INDEX_UNKNOWN"},
144162
# UNKNOWN is a driver-side fallback, not a caller-selectable index
145163
{"UNKNOWN"},
146164
),
147165
(
148-
_device._GPU_TOPOLOGY_LEVEL_MAPPING,
149166
nvml.GpuTopologyLevel,
167+
_device.GpuTopologyLevel,
168+
_device._GPU_TOPOLOGY_LEVEL_MAPPING,
150169
set(),
151170
set(),
152171
),
153172
(
154-
_device._CLOCK_ID_MAPPING,
155173
nvml.ClockId,
174+
_device.ClockId,
175+
_device._CLOCK_ID_MAPPING,
156176
# APP_CLOCK_TARGET and APP_CLOCK_DEFAULT are deprecated; COUNT is a sentinel
157177
{"APP_CLOCK_TARGET", "APP_CLOCK_DEFAULT", "COUNT"},
158178
set(),
159179
),
160180
(
161-
_device._CLOCK_TYPE_MAPPING,
162181
nvml.ClockType,
182+
_device.ClockType,
183+
_device._CLOCK_TYPE_MAPPING,
163184
# COUNT is a sentinel
164185
{"CLOCK_COUNT"},
165186
set(),
166187
),
167188
(
168-
_device._INFOROM_OBJECT_MAPPING,
169189
nvml.InforomObject,
190+
_device.InforomObject,
191+
_device._INFOROM_OBJECT_MAPPING,
170192
# COUNT is a sentinel
171193
{"INFOROM_COUNT"},
172194
set(),
173195
),
174196
(
175-
_device._TEMPERATURE_THRESHOLD_MAPPING,
176197
nvml.TemperatureThresholds,
198+
_device.TemperatureThresholds,
199+
_device._TEMPERATURE_THRESHOLD_MAPPING,
177200
# COUNT is a sentinel
178201
{"TEMPERATURE_THRESHOLD_COUNT"},
179202
set(),
@@ -189,11 +212,11 @@
189212

190213

191214
@pytest.mark.parametrize(
192-
"mapping, binding, binding_unmapped, str_enum_unmapped",
215+
"binding, str_enum, mapping, binding_unmapped, str_enum_unmapped",
193216
_CASES,
194-
ids=[x[1].__name__ for x in _CASES],
217+
ids=[x[0].__name__ for x in _CASES],
195218
)
196-
def test_wrapper_covers_all_binding_members(mapping, binding, binding_unmapped, str_enum_unmapped):
219+
def test_wrapper_covers_all_binding_members(binding, str_enum, mapping, binding_unmapped, str_enum_unmapped):
197220
"""Every cuda_binding enum member must appear in the wrapper mapping (or be allow-listed).
198221
199222
Also checks the reverse: every StrEnum wrapper member must appear in the
@@ -207,13 +230,11 @@ def test_wrapper_covers_all_binding_members(mapping, binding, binding_unmapped,
207230
assert not missing, f"{binding.__name__} has members not covered by the wrapper mapping: {missing}"
208231

209232
# Reverse check: every StrEnum member must also appear in the mapping.
210-
str_enum_instances = [m for m in (*mapping.keys(), *mapping.values()) if isinstance(m, StrEnum)]
211-
if str_enum_instances:
212-
str_enum_cls = type(str_enum_instances[0])
213-
required_str = set(str_enum_cls.__members__) - str_enum_unmapped
214-
covered_str = {m.name for m in str_enum_instances}
233+
if str_enum is not None:
234+
required_str = set(str_enum.__members__) - str_enum_unmapped
235+
covered_str = {m.name for m in (*mapping.keys(), *mapping.values()) if isinstance(m, str_enum)}
215236
missing_str = required_str - covered_str
216-
assert not missing_str, f"{str_enum_cls.__name__} has members not covered by the wrapper mapping: {missing_str}"
237+
assert not missing_str, f"{str_enum.__name__} has members not covered by the wrapper mapping: {missing_str}"
217238

218239

219240
def test_all_str_enums_in_cases():
@@ -245,12 +266,7 @@ def discover_str_enums() -> set[type]:
245266
found.add(obj)
246267
return found
247268

248-
covered = {
249-
type(obj)
250-
for mapping, _, _, _ in _CASES
251-
for obj in (*mapping.keys(), *mapping.values())
252-
if isinstance(obj, StrEnum)
253-
}
269+
covered = {x[1] for x in _CASES if x[1] is not None}
254270
uncovered = discover_str_enums() - covered - _UNBOUND_STR_ENUMS
255271
assert not uncovered, (
256272
f"StrEnum subclasses in cuda.core not covered by _CASES: "

0 commit comments

Comments
 (0)