Skip to content

Commit 0a99e5a

Browse files
vzhurba01rwgk
authored andcommitted
Apply review
1 parent c7c7892 commit 0a99e5a

1 file changed

Lines changed: 81 additions & 72 deletions

File tree

cuda_bindings/setup.py

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -76,49 +76,30 @@
7676
# - cuda_device_runtime_api.h
7777
}
7878

79-
# Assert that all headers exist
80-
header_dict = {}
81-
missing_headers = []
82-
include_path_list = [os.path.join(path, "include") for path in CUDA_HOME]
8379

84-
for library, header_list in required_headers.items():
85-
header_paths = []
86-
for header in header_list:
87-
path_candidate = [os.path.join(path, header) for path in include_path_list]
88-
for path in path_candidate:
89-
if os.path.exists(path):
90-
header_paths += [path]
91-
break
92-
if not os.path.exists(path):
93-
missing_headers += [header]
94-
95-
# Update dictionary with validated paths to headers
96-
header_dict[library] = header_paths
97-
98-
if missing_headers:
99-
error_message = "Couldn't find required headers: "
100-
error_message += ", ".join([header for header in missing_headers])
101-
raise RuntimeError(f'{error_message}\nIs CUDA_HOME setup correctly? (CUDA_HOME="{CUDA_HOME}")')
102-
103-
replace = {
104-
" __device_builtin__ ": " ",
105-
"CUDARTAPI ": " ",
106-
"typedef __device_builtin__ enum cudaError cudaError_t;": "typedef cudaError cudaError_t;",
107-
"typedef __device_builtin__ enum cudaOutputMode cudaOutputMode_t;": "typedef cudaOutputMode cudaOutputMode_t;",
108-
"typedef enum cudaError cudaError_t;": "typedef cudaError cudaError_t;",
109-
"typedef enum cudaOutputMode cudaOutputMode_t;": "typedef cudaOutputMode cudaOutputMode_t;",
110-
"typedef enum cudaDataType_t cudaDataType_t;": "",
111-
"typedef enum libraryPropertyType_t libraryPropertyType_t;": "",
112-
" enum ": " ",
113-
", enum ": ", ",
114-
"\\(enum ": "(",
115-
}
80+
def fetch_header_paths(required_headers, include_path_list):
81+
header_dict = {}
82+
missing_headers = []
83+
for library, header_list in required_headers.items():
84+
header_paths = []
85+
for header in header_list:
86+
path_candidate = [os.path.join(path, header) for path in include_path_list]
87+
for path in path_candidate:
88+
if os.path.exists(path):
89+
header_paths += [path]
90+
break
91+
else:
92+
missing_headers += [header]
93+
94+
# Update dictionary with validated paths to headers
95+
header_dict[library] = header_paths
11696

117-
found_types = []
118-
found_functions = []
119-
found_values = []
120-
found_struct = []
121-
struct_list = {}
97+
if missing_headers:
98+
error_message = "Couldn't find required headers: "
99+
error_message += ", ".join([header for header in missing_headers])
100+
raise RuntimeError(f'{error_message}\nIs CUDA_HOME setup correctly? (CUDA_HOME="{CUDA_HOME}")')
101+
102+
return header_dict
122103

123104

124105
class Struct:
@@ -149,38 +130,66 @@ def __repr__(self):
149130
return f"{self._name}: {self._member_names} with types {self._member_types}"
150131

151132

152-
print(f'Parsing headers in "{include_path_list}" (Caching = {PARSER_CACHING})')
153-
for library, header_paths in header_dict.items():
154-
print(f"Parsing {library} headers")
155-
parser = CParser(
156-
header_paths, cache="./cache_{}".format(library.split(".")[0]) if PARSER_CACHING else None, replace=replace
157-
)
133+
def parse_headers(header_dict):
134+
found_types = []
135+
found_functions = []
136+
found_values = []
137+
found_struct = []
138+
struct_list = {}
139+
140+
replace = {
141+
" __device_builtin__ ": " ",
142+
"CUDARTAPI ": " ",
143+
"typedef __device_builtin__ enum cudaError cudaError_t;": "typedef cudaError cudaError_t;",
144+
"typedef __device_builtin__ enum cudaOutputMode cudaOutputMode_t;": "typedef cudaOutputMode cudaOutputMode_t;",
145+
"typedef enum cudaError cudaError_t;": "typedef cudaError cudaError_t;",
146+
"typedef enum cudaOutputMode cudaOutputMode_t;": "typedef cudaOutputMode cudaOutputMode_t;",
147+
"typedef enum cudaDataType_t cudaDataType_t;": "",
148+
"typedef enum libraryPropertyType_t libraryPropertyType_t;": "",
149+
" enum ": " ",
150+
", enum ": ", ",
151+
"\\(enum ": "(",
152+
}
153+
154+
print(f'Parsing headers in "{include_path_list}" (Caching = {PARSER_CACHING})')
155+
for library, header_paths in header_dict.items():
156+
print(f"Parsing {library} headers")
157+
parser = CParser(
158+
header_paths, cache="./cache_{}".format(library.split(".")[0]) if PARSER_CACHING else None, replace=replace
159+
)
160+
161+
if library == "driver":
162+
CUDA_VERSION = parser.defs["macros"].get("CUDA_VERSION", "Unknown")
163+
print(f"Found CUDA_VERSION: {CUDA_VERSION}")
158164

159-
if library == "driver":
160-
CUDA_VERSION = parser.defs["macros"].get("CUDA_VERSION", "Unknown")
161-
print(f"Found CUDA_VERSION: {CUDA_VERSION}")
162-
163-
# Combine types with others since they sometimes get tangled
164-
found_types += {key for key in parser.defs["types"]}
165-
found_types += {key for key in parser.defs["structs"]}
166-
found_types += {key for key in parser.defs["unions"]}
167-
found_types += {key for key in parser.defs["enums"]}
168-
found_functions += {key for key in parser.defs["functions"]}
169-
found_values += {key for key in parser.defs["values"]}
170-
171-
for key, value in parser.defs["structs"].items():
172-
struct_list[key] = Struct(key, value["members"])
173-
for key, value in parser.defs["unions"].items():
174-
struct_list[key] = Struct(key, value["members"])
175-
176-
for key, value in struct_list.items():
177-
if key.startswith("anon_union") or key.startswith("anon_struct"):
178-
continue
179-
180-
found_struct += [key]
181-
discovered = value.discoverMembers(struct_list, key)
182-
if discovered:
183-
found_struct += discovered
165+
# Combine types with others since they sometimes get tangled
166+
found_types += {key for key in parser.defs["types"]}
167+
found_types += {key for key in parser.defs["structs"]}
168+
found_types += {key for key in parser.defs["unions"]}
169+
found_types += {key for key in parser.defs["enums"]}
170+
found_functions += {key for key in parser.defs["functions"]}
171+
found_values += {key for key in parser.defs["values"]}
172+
173+
for key, value in parser.defs["structs"].items():
174+
struct_list[key] = Struct(key, value["members"])
175+
for key, value in parser.defs["unions"].items():
176+
struct_list[key] = Struct(key, value["members"])
177+
178+
for key, value in struct_list.items():
179+
if key.startswith("anon_union") or key.startswith("anon_struct"):
180+
continue
181+
182+
found_struct += [key]
183+
discovered = value.discoverMembers(struct_list, key)
184+
if discovered:
185+
found_struct += discovered
186+
187+
return found_types, found_functions, found_values, found_struct, struct_list
188+
189+
190+
include_path_list = [os.path.join(path, "include") for path in CUDA_HOME]
191+
header_dict = fetch_header_paths(required_headers, include_path_list)
192+
found_types, found_functions, found_values, found_struct, struct_list = parse_headers(header_dict)
184193

185194
# ----------------------------------------------------------------------
186195
# Generate

0 commit comments

Comments
 (0)