Skip to content

Commit 9ca2f40

Browse files
switch the order fo comparisons to use sequence instead of list / tuple
1 parent 446b5af commit 9ca2f40

2 files changed

Lines changed: 32 additions & 30 deletions

File tree

cuda_core/cuda/core/experimental/_program.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
_handle_boolean_option,
1414
check_or_create_options,
1515
handle_return,
16-
is_list_or_tuple,
17-
is_nested_list_or_tuple,
16+
is_nested_sequence,
17+
is_sequence,
1818
)
1919

2020

@@ -247,11 +247,11 @@ def __post_init__(self):
247247
self._formatted_options.append(f"--dopt={'on' if self.device_code_optimize else 'off'}")
248248
if self.ptxas_options is not None:
249249
self._formatted_options.append("--ptxas-options")
250-
if is_list_or_tuple(self.ptxas_options):
250+
if isinstance(self.ptxas_options, str):
251+
self._formatted_options.append(self.ptxas_options)
252+
elif is_sequence(self.ptxas_options):
251253
for option in self.ptxas_options:
252254
self._formatted_options.append(option)
253-
else:
254-
self._formatted_options.append(self.ptxas_options)
255255
if self.max_register_count is not None:
256256
self._formatted_options.append(f"--maxrregcount={self.max_register_count}")
257257
if self.ftz is not None:
@@ -271,7 +271,9 @@ def __post_init__(self):
271271
if self.gen_opt_lto is not None and self.gen_opt_lto:
272272
self._formatted_options.append("--gen-opt-lto")
273273
if self.define_macro is not None:
274-
if is_nested_list_or_tuple(self.define_macro):
274+
if isinstance(self.define_macro, str):
275+
self._formatted_options.append(f"--define-macro={self.define_macro}")
276+
if is_nested_sequence(self.define_macro):
275277
for macro in self.define_macro:
276278
if isinstance(macro, tuple):
277279
assert len(macro) == 2
@@ -281,27 +283,26 @@ def __post_init__(self):
281283
elif isinstance(self.define_macro, tuple):
282284
assert len(self.define_macro) == 2
283285
self._formatted_options.append("--define-macro=MY_MACRO=999")
284-
else:
285-
self._formatted_options.append(f"--define-macro={self.define_macro}")
286286

287287
if self.undefine_macro is not None:
288-
if is_list_or_tuple(self.undefine_macro):
288+
if isinstance(self.undefine_macro, str):
289+
self._formatted_options.append(f"--undefine-macro={self.undefine_macro}")
290+
elif is_sequence(self.undefine_macro):
289291
for macro in self.undefine_macro:
290292
self._formatted_options.append(f"--undefine-macro={macro}")
291-
else:
292-
self._formatted_options.append(f"--undefine-macro={self.undefine_macro}")
293293
if self.include_path is not None:
294-
if is_list_or_tuple(self.include_path):
294+
if isinstance(self.include_path, str):
295+
self._formatted_options.append(f"--include-path={self.include_path}")
296+
elif is_sequence(self.include_path):
295297
for path in self.include_path:
296298
self._formatted_options.append(f"--include-path={path}")
297-
else:
298-
self._formatted_options.append(f"--include-path={self.include_path}")
299299
if self.pre_include is not None:
300-
if is_list_or_tuple(self.pre_include):
300+
if isinstance(self.pre_include, str):
301+
self._formatted_options.append(f"--pre-include={self.pre_include}")
302+
elif is_sequence(self.pre_include):
301303
for header in self.pre_include:
302304
self._formatted_options.append(f"--pre-include={header}")
303-
else:
304-
self._formatted_options.append(f"--pre-include={self.pre_include}")
305+
305306
if self.no_source_include is not None and self.no_source_include:
306307
self._formatted_options.append("--no-source-include")
307308
if self.std is not None:
@@ -327,23 +328,23 @@ def __post_init__(self):
327328
if self.no_display_error_number is not None and self.no_display_error_number:
328329
self._formatted_options.append("--no-display-error-number")
329330
if self.diag_error is not None:
330-
if is_list_or_tuple(self.diag_error):
331+
if isinstance(self.diag_error, int):
332+
self._formatted_options.append(f"--diag-error={self.diag_error}")
333+
elif is_sequence(self.diag_error):
331334
for error in self.diag_error:
332335
self._formatted_options.append(f"--diag-error={error}")
333-
else:
334-
self._formatted_options.append(f"--diag-error={self.diag_error}")
335336
if self.diag_suppress is not None:
336-
if is_list_or_tuple(self.diag_suppress):
337+
if isinstance(self.diag_suppress, int):
338+
self._formatted_options.append(f"--diag-suppress={self.diag_suppress}")
339+
elif is_sequence(self.diag_suppress):
337340
for suppress in self.diag_suppress:
338341
self._formatted_options.append(f"--diag-suppress={suppress}")
339-
else:
340-
self._formatted_options.append(f"--diag-suppress={self.diag_suppress}")
341342
if self.diag_warn is not None:
342-
if is_list_or_tuple(self.diag_warn):
343+
if isinstance(self.diag_warn, int):
344+
self._formatted_options.append(f"--diag-warn={self.diag_warn}")
345+
elif is_sequence(self.diag_warn):
343346
for warn in self.diag_warn:
344347
self._formatted_options.append(f"--diag-warn={warn}")
345-
else:
346-
self._formatted_options.append(f"--diag-warn={self.diag_warn}")
347348
if self.brief_diagnostics is not None:
348349
self._formatted_options.append(f"--brief-diagnostics={_handle_boolean_option(self.brief_diagnostics)}")
349350
if self.time is not None:

cuda_core/cuda/core/experimental/_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import functools
66
from collections import namedtuple
7+
from collections.abc import Sequence
78
from typing import Callable, Dict
89

910
from cuda import cuda, cudart, nvrtc
@@ -143,15 +144,15 @@ def get_device_from_ctx(ctx_handle) -> int:
143144
return device_id
144145

145146

146-
def is_list_or_tuple(obj):
147+
def is_sequence(obj):
147148
"""
148149
Check if the given object is a sequence (list or tuple).
149150
"""
150-
return isinstance(obj, (list, tuple))
151+
return isinstance(obj, Sequence)
151152

152153

153-
def is_nested_list_or_tuple(obj):
154+
def is_nested_sequence(obj):
154155
"""
155156
Check if the given object is a nested sequence (list or tuple with atleast one list or tuple element).
156157
"""
157-
return is_list_or_tuple(obj) and any(is_list_or_tuple(elem) for elem in obj)
158+
return is_sequence(obj) and any(is_sequence(elem) for elem in obj)

0 commit comments

Comments
 (0)