Skip to content

Commit 40803f8

Browse files
authored
Merge pull request #2434 from stan-dev/feature/varmat-signatures3
varmat compatibility tests
2 parents f390d82 + 8d81e62 commit 40803f8

File tree

6 files changed

+532
-0
lines changed

6 files changed

+532
-0
lines changed

Jenkinsfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ pipeline {
328328
sh "python ./test/code_generator_test.py"
329329
sh "python ./test/signature_parser_test.py"
330330
sh "python ./test/statement_types_test.py"
331+
sh "python ./test/varmat_compatibility_summary_test.py"
332+
sh "python ./test/varmat_compatibility_test.py"
331333
withEnv(['PATH+TBB=./lib/tbb']) {
332334
sh "python ./test/expressions/test_expression_testing_framework.py"
333335
}

test/sig_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def reference_vector_argument(arg):
311311
overload_scalar = {
312312
"Prim": "double",
313313
"Rev": "stan::math::var",
314+
"RevVarmat": "stan::math::var",
314315
"Fwd": "stan::math::fvar<double>",
315316
"Mix": "stan::math::fvar<stan::math::var>",
316317
}

test/varmat_compatibility.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#!/usr/bin/python
2+
3+
import itertools
4+
import json
5+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
6+
import os
7+
import Queue
8+
import subprocess
9+
import re
10+
import sys
11+
import tempfile
12+
import threading
13+
14+
from sig_utils import make, handle_function_list, get_signatures
15+
from signature_parser import SignatureParser
16+
from code_generator import CodeGenerator
17+
18+
HERE = os.path.dirname(os.path.realpath(__file__))
19+
TEST_FOLDER = os.path.abspath(os.path.join(HERE, "..", "test"))
20+
sys.path.append(TEST_FOLDER)
21+
WORKING_FOLDER = "test/varmat-compatibility"
22+
23+
TEST_TEMPLATE = """
24+
static void {test_name}() {{
25+
{code}
26+
}}
27+
"""
28+
29+
def run_command(command):
30+
"""
31+
Runs given command and waits until it finishes executing.
32+
:param command: command to execute
33+
"""
34+
proc = subprocess.Popen(command, stdout = subprocess.PIPE, stderr = subprocess.PIPE)
35+
stdout, stderr = proc.communicate()
36+
37+
if proc.poll() == 0:
38+
return (True, stdout, stderr)
39+
else:
40+
return (False, stdout, stderr)
41+
42+
def build_signature(prefix, cpp_code, debug):
43+
"""
44+
Try to build the given cpp code
45+
46+
Return true if the code was successfully built
47+
48+
:param prefix: Prefix to give file names so easier to debug
49+
:param cpp_code: Code to build
50+
:param debug: If true, don't delete temporary files
51+
"""
52+
f = tempfile.NamedTemporaryFile("w", dir = WORKING_FOLDER, prefix = prefix + "_", suffix = "_test.cpp", delete = False)
53+
f.write("#include <test/expressions/expression_test_helpers.hpp>\n\n")
54+
f.write(cpp_code)
55+
f.close()
56+
57+
cpp_path = os.path.join(WORKING_FOLDER, os.path.basename(f.name))
58+
59+
object_path = cpp_path.replace(".cpp", ".o")
60+
dependency_path = cpp_path.replace(".cpp", ".d")
61+
stdout_path = cpp_path.replace(".cpp", ".stdout")
62+
stderr_path = cpp_path.replace(".cpp", ".stderr")
63+
64+
successful, stdout, stderr = run_command([make, object_path])
65+
66+
if successful or not debug:
67+
try:
68+
os.remove(cpp_path)
69+
except OSError:
70+
pass
71+
72+
try:
73+
os.remove(dependency_path)
74+
except OSError:
75+
pass
76+
77+
try:
78+
os.remove(object_path)
79+
except OSError:
80+
pass
81+
else:
82+
if debug:
83+
with open(stdout_path, "w") as stdout_f:
84+
stdout_f.write(stdout.decode("utf-8"))
85+
86+
with open(stderr_path, "w") as stderr_f:
87+
stderr_f.write(stderr.decode("utf-8"))
88+
89+
return successful
90+
91+
def main(functions_or_sigs, results_file, cores, debug):
92+
"""
93+
Attempt to build all the signatures in functions_or_sigs, or all the signatures
94+
associated with all the functions in functions_or_sigs, or if functions_or_sigs
95+
is empty every signature the stanc3 compiler exposes.
96+
97+
Results are written to a results json file. Individual signatures are classified
98+
as either compatible, incompatible, or irrelevant.
99+
100+
Compatible signatures can be compiled with varmat types in every argument that
101+
could possibly be a varmat (the matrix-like ones).
102+
103+
Incompatible signatures cannot all be built, and for irrelevant signatures it does
104+
not make sense to try to build them (there are no matrix arguments, or the function
105+
does not support reverse mode autodiff, etc).
106+
107+
Compilation is done in parallel using the number of specified cores.
108+
109+
:param functions_or_sigs: List of function names and/or signatures to benchmark
110+
:param results_file: File to use as a results cache
111+
:param cores: Number of cores to use for compiling
112+
:param debug: If true, don't delete temporary files
113+
"""
114+
all_signatures = get_signatures()
115+
functions, signatures = handle_function_list(functions_or_sigs)
116+
117+
requested_functions = set(functions)
118+
119+
compatible_signatures = set()
120+
incompatible_signatures = set()
121+
irrelevant_signatures = set()
122+
123+
# Read the arguments and figure out the exact list of signatures to test
124+
signatures_to_check = set()
125+
for signature in all_signatures:
126+
sp = SignatureParser(signature)
127+
128+
if len(requested_functions) > 0 and sp.function_name not in requested_functions:
129+
continue
130+
131+
signatures_to_check.add(signature)
132+
133+
work_queue = Queue.Queue()
134+
135+
# For each signature, generate cpp code to test
136+
for signature in signatures_to_check:
137+
sp = SignatureParser(signature)
138+
139+
if sp.is_high_order():
140+
work_queue.put((n, signature, None))
141+
continue
142+
143+
cpp_code = ""
144+
any_overload_uses_varmat = False
145+
146+
for m, overloads in enumerate(itertools.product(("Prim", "Rev", "RevVarmat"), repeat = sp.number_arguments())):
147+
cg = CodeGenerator()
148+
149+
arg_list_base = cg.build_arguments(sp, overloads, size = 1)
150+
151+
arg_list = []
152+
for overload, arg in zip(overloads, arg_list_base):
153+
if arg.is_reverse_mode() and arg.is_varmat_compatible() and overload.endswith("Varmat"):
154+
any_overload_uses_varmat = True
155+
arg = cg.to_var_value(arg)
156+
157+
arg_list.append(arg)
158+
159+
cg.function_call_assign("stan::math::" + sp.function_name, *arg_list)
160+
161+
cpp_code += TEST_TEMPLATE.format(
162+
test_name = sp.function_name + repr(m),
163+
code=cg.cpp(),
164+
)
165+
166+
if any_overload_uses_varmat:
167+
work_queue.put((work_queue.qsize(), signature, cpp_code))
168+
else:
169+
print("{0} ... Irrelevant".format(signature.strip()))
170+
irrelevant_signatures.add(signature)
171+
172+
output_lock = threading.Lock()
173+
174+
if not os.path.exists(WORKING_FOLDER):
175+
os.mkdir(WORKING_FOLDER)
176+
177+
work_queue_original_length = work_queue.qsize()
178+
179+
# Test if each cpp file builds and update the output file
180+
# This part is done in parallel
181+
def worker():
182+
while True:
183+
try:
184+
n, signature, cpp_code = work_queue.get(False)
185+
except Queue.Empty:
186+
return # If queue is empty, worker quits
187+
188+
# Use signature as filename prefix to make it easier to find
189+
prefix = re.sub('[^0-9a-zA-Z]+', '_', signature.strip())
190+
191+
# Test the signature
192+
successful = build_signature(prefix, cpp_code, debug)
193+
194+
# Acquire a lock to do I/O
195+
with output_lock:
196+
if successful:
197+
result_string = "Success"
198+
compatible_signatures.add(signature)
199+
else:
200+
result_string = "Fail"
201+
incompatible_signatures.add(signature)
202+
203+
print("Results of test {0} / {1}, {2} ... ".format(n, work_queue_original_length, signature.strip()) + result_string)
204+
205+
work_queue.task_done()
206+
207+
for i in range(cores):
208+
threading.Thread(target = worker).start()
209+
210+
work_queue.join()
211+
212+
with open(results_file, "w") as f:
213+
json.dump({ "compatible_signatures" : list(compatible_signatures),
214+
"incompatible_signatures" : list(incompatible_signatures),
215+
"irrelevant_signatures" : list(irrelevant_signatures)
216+
}, f, indent = 4, sort_keys = True)
217+
218+
219+
class FullErrorMsgParser(ArgumentParser):
220+
"""
221+
Modified ArgumentParser that prints full error message on any error.
222+
"""
223+
224+
def error(self, message):
225+
sys.stderr.write("error: %s\n" % message)
226+
self.print_help()
227+
sys.exit(2)
228+
229+
230+
def processCLIArgs():
231+
"""
232+
Define and process the command line interface to the benchmark.py script.
233+
"""
234+
parser = FullErrorMsgParser(
235+
description="Generate and run_command benchmarks.",
236+
formatter_class=ArgumentDefaultsHelpFormatter,
237+
)
238+
parser.add_argument(
239+
"--functions",
240+
nargs="+",
241+
type=str,
242+
default=[],
243+
help="Signatures and/or function names to benchmark.",
244+
)
245+
parser.add_argument(
246+
"-j",
247+
type=int,
248+
default=1,
249+
help="Number of parallel cores to use.",
250+
)
251+
parser.add_argument(
252+
"--debug",
253+
action="store_true",
254+
help="Keep cpp, stdout, and stderr for incompatible functions.",
255+
)
256+
parser.add_argument(
257+
"results_file",
258+
type=str,
259+
default=None,
260+
help="File to save results in.",
261+
)
262+
args = parser.parse_args()
263+
264+
main(functions_or_sigs=args.functions, results_file = args.results_file, cores = args.j, debug = args.debug)
265+
266+
if __name__ == "__main__":
267+
processCLIArgs()

0 commit comments

Comments
 (0)