|
| 1 | +#!/usr/bin/python |
| 2 | + |
| 3 | +import itertools |
| 4 | +import json |
| 5 | +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser |
| 6 | +import os |
| 7 | +import Queue |
| 8 | +import subprocess |
| 9 | +import re |
| 10 | +import sys |
| 11 | +import tempfile |
| 12 | +import threading |
| 13 | + |
| 14 | +from sig_utils import make, handle_function_list, get_signatures |
| 15 | +from signature_parser import SignatureParser |
| 16 | +from code_generator import CodeGenerator |
| 17 | + |
| 18 | +HERE = os.path.dirname(os.path.realpath(__file__)) |
| 19 | +TEST_FOLDER = os.path.abspath(os.path.join(HERE, "..", "test")) |
| 20 | +sys.path.append(TEST_FOLDER) |
| 21 | +WORKING_FOLDER = "test/varmat-compatibility" |
| 22 | + |
| 23 | +TEST_TEMPLATE = """ |
| 24 | +static void {test_name}() {{ |
| 25 | +{code} |
| 26 | +}} |
| 27 | +""" |
| 28 | + |
| 29 | +def run_command(command): |
| 30 | + """ |
| 31 | + Runs given command and waits until it finishes executing. |
| 32 | + :param command: command to execute |
| 33 | + """ |
| 34 | + proc = subprocess.Popen(command, stdout = subprocess.PIPE, stderr = subprocess.PIPE) |
| 35 | + stdout, stderr = proc.communicate() |
| 36 | + |
| 37 | + if proc.poll() == 0: |
| 38 | + return (True, stdout, stderr) |
| 39 | + else: |
| 40 | + return (False, stdout, stderr) |
| 41 | + |
| 42 | +def build_signature(prefix, cpp_code, debug): |
| 43 | + """ |
| 44 | + Try to build the given cpp code |
| 45 | +
|
| 46 | + Return true if the code was successfully built |
| 47 | +
|
| 48 | + :param prefix: Prefix to give file names so easier to debug |
| 49 | + :param cpp_code: Code to build |
| 50 | + :param debug: If true, don't delete temporary files |
| 51 | + """ |
| 52 | + f = tempfile.NamedTemporaryFile("w", dir = WORKING_FOLDER, prefix = prefix + "_", suffix = "_test.cpp", delete = False) |
| 53 | + f.write("#include <test/expressions/expression_test_helpers.hpp>\n\n") |
| 54 | + f.write(cpp_code) |
| 55 | + f.close() |
| 56 | + |
| 57 | + cpp_path = os.path.join(WORKING_FOLDER, os.path.basename(f.name)) |
| 58 | + |
| 59 | + object_path = cpp_path.replace(".cpp", ".o") |
| 60 | + dependency_path = cpp_path.replace(".cpp", ".d") |
| 61 | + stdout_path = cpp_path.replace(".cpp", ".stdout") |
| 62 | + stderr_path = cpp_path.replace(".cpp", ".stderr") |
| 63 | + |
| 64 | + successful, stdout, stderr = run_command([make, object_path]) |
| 65 | + |
| 66 | + if successful or not debug: |
| 67 | + try: |
| 68 | + os.remove(cpp_path) |
| 69 | + except OSError: |
| 70 | + pass |
| 71 | + |
| 72 | + try: |
| 73 | + os.remove(dependency_path) |
| 74 | + except OSError: |
| 75 | + pass |
| 76 | + |
| 77 | + try: |
| 78 | + os.remove(object_path) |
| 79 | + except OSError: |
| 80 | + pass |
| 81 | + else: |
| 82 | + if debug: |
| 83 | + with open(stdout_path, "w") as stdout_f: |
| 84 | + stdout_f.write(stdout.decode("utf-8")) |
| 85 | + |
| 86 | + with open(stderr_path, "w") as stderr_f: |
| 87 | + stderr_f.write(stderr.decode("utf-8")) |
| 88 | + |
| 89 | + return successful |
| 90 | + |
| 91 | +def main(functions_or_sigs, results_file, cores, debug): |
| 92 | + """ |
| 93 | + Attempt to build all the signatures in functions_or_sigs, or all the signatures |
| 94 | + associated with all the functions in functions_or_sigs, or if functions_or_sigs |
| 95 | + is empty every signature the stanc3 compiler exposes. |
| 96 | +
|
| 97 | + Results are written to a results json file. Individual signatures are classified |
| 98 | + as either compatible, incompatible, or irrelevant. |
| 99 | +
|
| 100 | + Compatible signatures can be compiled with varmat types in every argument that |
| 101 | + could possibly be a varmat (the matrix-like ones). |
| 102 | +
|
| 103 | + Incompatible signatures cannot all be built, and for irrelevant signatures it does |
| 104 | + not make sense to try to build them (there are no matrix arguments, or the function |
| 105 | + does not support reverse mode autodiff, etc). |
| 106 | +
|
| 107 | + Compilation is done in parallel using the number of specified cores. |
| 108 | +
|
| 109 | + :param functions_or_sigs: List of function names and/or signatures to benchmark |
| 110 | + :param results_file: File to use as a results cache |
| 111 | + :param cores: Number of cores to use for compiling |
| 112 | + :param debug: If true, don't delete temporary files |
| 113 | + """ |
| 114 | + all_signatures = get_signatures() |
| 115 | + functions, signatures = handle_function_list(functions_or_sigs) |
| 116 | + |
| 117 | + requested_functions = set(functions) |
| 118 | + |
| 119 | + compatible_signatures = set() |
| 120 | + incompatible_signatures = set() |
| 121 | + irrelevant_signatures = set() |
| 122 | + |
| 123 | + # Read the arguments and figure out the exact list of signatures to test |
| 124 | + signatures_to_check = set() |
| 125 | + for signature in all_signatures: |
| 126 | + sp = SignatureParser(signature) |
| 127 | + |
| 128 | + if len(requested_functions) > 0 and sp.function_name not in requested_functions: |
| 129 | + continue |
| 130 | + |
| 131 | + signatures_to_check.add(signature) |
| 132 | + |
| 133 | + work_queue = Queue.Queue() |
| 134 | + |
| 135 | + # For each signature, generate cpp code to test |
| 136 | + for signature in signatures_to_check: |
| 137 | + sp = SignatureParser(signature) |
| 138 | + |
| 139 | + if sp.is_high_order(): |
| 140 | + work_queue.put((n, signature, None)) |
| 141 | + continue |
| 142 | + |
| 143 | + cpp_code = "" |
| 144 | + any_overload_uses_varmat = False |
| 145 | + |
| 146 | + for m, overloads in enumerate(itertools.product(("Prim", "Rev", "RevVarmat"), repeat = sp.number_arguments())): |
| 147 | + cg = CodeGenerator() |
| 148 | + |
| 149 | + arg_list_base = cg.build_arguments(sp, overloads, size = 1) |
| 150 | + |
| 151 | + arg_list = [] |
| 152 | + for overload, arg in zip(overloads, arg_list_base): |
| 153 | + if arg.is_reverse_mode() and arg.is_varmat_compatible() and overload.endswith("Varmat"): |
| 154 | + any_overload_uses_varmat = True |
| 155 | + arg = cg.to_var_value(arg) |
| 156 | + |
| 157 | + arg_list.append(arg) |
| 158 | + |
| 159 | + cg.function_call_assign("stan::math::" + sp.function_name, *arg_list) |
| 160 | + |
| 161 | + cpp_code += TEST_TEMPLATE.format( |
| 162 | + test_name = sp.function_name + repr(m), |
| 163 | + code=cg.cpp(), |
| 164 | + ) |
| 165 | + |
| 166 | + if any_overload_uses_varmat: |
| 167 | + work_queue.put((work_queue.qsize(), signature, cpp_code)) |
| 168 | + else: |
| 169 | + print("{0} ... Irrelevant".format(signature.strip())) |
| 170 | + irrelevant_signatures.add(signature) |
| 171 | + |
| 172 | + output_lock = threading.Lock() |
| 173 | + |
| 174 | + if not os.path.exists(WORKING_FOLDER): |
| 175 | + os.mkdir(WORKING_FOLDER) |
| 176 | + |
| 177 | + work_queue_original_length = work_queue.qsize() |
| 178 | + |
| 179 | + # Test if each cpp file builds and update the output file |
| 180 | + # This part is done in parallel |
| 181 | + def worker(): |
| 182 | + while True: |
| 183 | + try: |
| 184 | + n, signature, cpp_code = work_queue.get(False) |
| 185 | + except Queue.Empty: |
| 186 | + return # If queue is empty, worker quits |
| 187 | + |
| 188 | + # Use signature as filename prefix to make it easier to find |
| 189 | + prefix = re.sub('[^0-9a-zA-Z]+', '_', signature.strip()) |
| 190 | + |
| 191 | + # Test the signature |
| 192 | + successful = build_signature(prefix, cpp_code, debug) |
| 193 | + |
| 194 | + # Acquire a lock to do I/O |
| 195 | + with output_lock: |
| 196 | + if successful: |
| 197 | + result_string = "Success" |
| 198 | + compatible_signatures.add(signature) |
| 199 | + else: |
| 200 | + result_string = "Fail" |
| 201 | + incompatible_signatures.add(signature) |
| 202 | + |
| 203 | + print("Results of test {0} / {1}, {2} ... ".format(n, work_queue_original_length, signature.strip()) + result_string) |
| 204 | + |
| 205 | + work_queue.task_done() |
| 206 | + |
| 207 | + for i in range(cores): |
| 208 | + threading.Thread(target = worker).start() |
| 209 | + |
| 210 | + work_queue.join() |
| 211 | + |
| 212 | + with open(results_file, "w") as f: |
| 213 | + json.dump({ "compatible_signatures" : list(compatible_signatures), |
| 214 | + "incompatible_signatures" : list(incompatible_signatures), |
| 215 | + "irrelevant_signatures" : list(irrelevant_signatures) |
| 216 | + }, f, indent = 4, sort_keys = True) |
| 217 | + |
| 218 | + |
| 219 | +class FullErrorMsgParser(ArgumentParser): |
| 220 | + """ |
| 221 | + Modified ArgumentParser that prints full error message on any error. |
| 222 | + """ |
| 223 | + |
| 224 | + def error(self, message): |
| 225 | + sys.stderr.write("error: %s\n" % message) |
| 226 | + self.print_help() |
| 227 | + sys.exit(2) |
| 228 | + |
| 229 | + |
| 230 | +def processCLIArgs(): |
| 231 | + """ |
| 232 | + Define and process the command line interface to the benchmark.py script. |
| 233 | + """ |
| 234 | + parser = FullErrorMsgParser( |
| 235 | + description="Generate and run_command benchmarks.", |
| 236 | + formatter_class=ArgumentDefaultsHelpFormatter, |
| 237 | + ) |
| 238 | + parser.add_argument( |
| 239 | + "--functions", |
| 240 | + nargs="+", |
| 241 | + type=str, |
| 242 | + default=[], |
| 243 | + help="Signatures and/or function names to benchmark.", |
| 244 | + ) |
| 245 | + parser.add_argument( |
| 246 | + "-j", |
| 247 | + type=int, |
| 248 | + default=1, |
| 249 | + help="Number of parallel cores to use.", |
| 250 | + ) |
| 251 | + parser.add_argument( |
| 252 | + "--debug", |
| 253 | + action="store_true", |
| 254 | + help="Keep cpp, stdout, and stderr for incompatible functions.", |
| 255 | + ) |
| 256 | + parser.add_argument( |
| 257 | + "results_file", |
| 258 | + type=str, |
| 259 | + default=None, |
| 260 | + help="File to save results in.", |
| 261 | + ) |
| 262 | + args = parser.parse_args() |
| 263 | + |
| 264 | + main(functions_or_sigs=args.functions, results_file = args.results_file, cores = args.j, debug = args.debug) |
| 265 | + |
| 266 | +if __name__ == "__main__": |
| 267 | + processCLIArgs() |
0 commit comments