Skip to content

Commit 51669a6

Browse files
authored
Merge pull request #2419 from stan-dev/feature/varmat-signatures
Update expression tests code generation
2 parents 80f8578 + 7318a0b commit 51669a6

13 files changed

Lines changed: 1082 additions & 281 deletions

Jenkinsfile

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def runTestsWin(String testPath, boolean buildLibs = true, boolean jumbo = false
3030
}
3131
}
3232

33+
3334
def deleteDirWin() {
3435
bat "attrib -r -s /s /d"
3536
deleteDir()
@@ -317,6 +318,34 @@ pipeline {
317318
echo "Distribution tests failed. Check out dist.log.zip artifact for test logs."
318319
}
319320
}
321+
}
322+
stage('Expressions test') {
323+
agent any
324+
steps {
325+
unstash 'MathSetup'
326+
script {
327+
sh "echo O=0 > make/local"
328+
sh "python ./test/code_generator_test.py"
329+
sh "python ./test/signature_parser_test.py"
330+
sh "python ./test/statement_types_test.py"
331+
withEnv(['PATH+TBB=./lib/tbb']) {
332+
sh "python ./test/expressions/test_expression_testing_framework.py"
333+
}
334+
withEnv(['PATH+TBB=./lib/tbb']) {
335+
try { sh "./runTests.py -j${env.PARALLEL} test/expressions" }
336+
finally { junit 'test/**/*.xml' }
337+
}
338+
sh "make clean-all"
339+
sh "echo STAN_THREADS=true >> make/local"
340+
withEnv(['PATH+TBB=./lib/tbb']) {
341+
try {
342+
sh "./runTests.py -j${env.PARALLEL} test/expressions --only-functions reduce_sum map_rect"
343+
}
344+
finally { junit 'test/**/*.xml' }
345+
}
346+
}
347+
}
348+
post { always { deleteDir() } }
320349
}
321350
stage('Threading tests') {
322351
agent any

runTests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,9 @@ def handleExpressionTests(tests, only_functions, n_test_files):
304304
HERE = os.path.dirname(os.path.realpath(__file__))
305305
sys.path.append(os.path.join(HERE, "test"))
306306
sys.path.append(os.path.join(HERE, "test/expressions"))
307-
import generateExpressionTests
307+
import generate_expression_tests
308308

309-
generateExpressionTests.main(only_functions, n_test_files)
309+
generate_expression_tests.main(only_functions, n_test_files)
310310
for i in range(n_test_files):
311311
tests.append("test/expressions/tests%d_test.cpp" % i)
312312
elif only_functions:

stan/math/prim/fun/rank.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ inline int rank(const C& v, int s) {
2121
check_range("rank", "v", v.size(), s);
2222
--s; // adjust for indexing by one
2323
return apply_vector_unary<C>::reduce(v, [s](const auto& vec) {
24-
return (vec.array() < vec.coeff(s)).template cast<int>().sum();
24+
const auto& vec_ref = to_ref(vec);
25+
26+
return (vec_ref.array() < vec_ref.coeff(s)).template cast<int>().sum();
2527
});
2628
}
2729

test/code_generator.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import collections
2+
import numbers
3+
import os
4+
import statement_types
5+
from sig_utils import parse_array, non_differentiable_args, special_arg_values
6+
7+
class CodeGenerator:
8+
"""
9+
This class generates C++ to test Stan functions
10+
"""
11+
def __init__(self):
12+
self.name_counter = 0
13+
self.code_list = []
14+
15+
def _add_statement(self, statement):
16+
"""
17+
Add a statement to the code generator
18+
19+
:param statement: An object of type statement_types.CppStatement
20+
"""
21+
if not isinstance(statement, statement_types.CppStatement):
22+
raise TypeError("Argument to FunctionGenerator._add_statement must be an instance of an object that inherits from CppStatement")
23+
24+
self.code_list.append(statement)
25+
return statement
26+
27+
def _get_next_name_suffix(self):
28+
"""Get the next available """
29+
self.name_counter += 1
30+
return repr(self.name_counter - 1)
31+
32+
def cpp(self):
33+
"""Generate and return the c++ code corresponding to the list of statements in the code generator"""
34+
return os.linesep.join(statement.cpp() for statement in self.code_list)
35+
36+
def build_arguments(self, signature_parser, arg_overloads, size):
37+
"""
38+
Generate argument variables for each of the arguments in the given signature_parser
39+
with the given overloads in arg_overloads and with the given size
40+
41+
:param signature_parser: An instance of SignatureParser
42+
:param arg_overloads: A list of argument overloads (Prim/Fwd/Rev/etc.) as strings
43+
:param size: Size of matrix-like arguments. This is not used for array arguments (which will effectively all be size 1)
44+
"""
45+
arg_list = []
46+
for n, (overload, stan_arg) in enumerate(zip(arg_overloads, signature_parser.stan_args)):
47+
suffix = self._get_next_name_suffix()
48+
49+
number_nested_arrays, inner_type = parse_array(stan_arg)
50+
51+
# Check if argument is differentiable
52+
if inner_type == "int" or n in non_differentiable_args.get(signature_parser.function_name, []):
53+
overload = "Prim"
54+
55+
# By default the variable value is None and a default will be substituted
56+
value = None
57+
58+
# Check for special arguments (constrained variables or types)
59+
try:
60+
special_arg = special_arg_values[signature_parser.function_name][n]
61+
if isinstance(special_arg, str):
62+
inner_type = special_arg
63+
elif special_arg is not None:
64+
value = special_arg
65+
except KeyError:
66+
pass
67+
68+
# The first case here is used for the array initializers in sig_utils.special_arg_values
69+
# Everything else uses the second case
70+
if number_nested_arrays > 0 and isinstance(value, collections.Iterable):
71+
arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size = 1, value = value)
72+
else:
73+
if inner_type == "int":
74+
arg = statement_types.IntVariable("int" + suffix, value)
75+
elif inner_type == "real":
76+
arg = statement_types.RealVariable(overload, "real" + suffix, value)
77+
elif inner_type in ("vector", "row_vector", "matrix"):
78+
arg = statement_types.MatrixVariable(overload, "matrix" + suffix, inner_type, size, value)
79+
elif inner_type == "rng":
80+
arg = statement_types.RngVariable("rng" + suffix)
81+
elif inner_type == "ostream_ptr":
82+
arg = statement_types.OStreamVariable("ostream" + suffix)
83+
elif inner_type == "scalar_return_type":
84+
arg = statement_types.ReturnTypeTVariable("ret_type" + suffix, *arg_list)
85+
elif inner_type == "simplex":
86+
arg = statement_types.SimplexVariable(overload, "simplex" + suffix, size, value)
87+
elif inner_type == "positive_definite_matrix":
88+
arg = statement_types.PositiveDefiniteMatrixVariable(overload, "positive_definite_matrix" + suffix, size, value)
89+
elif inner_type == "(vector, vector, data array[] real, data array[] int) => vector":
90+
arg = statement_types.AlgebraSolverFunctorVariable("functor" + suffix)
91+
elif inner_type == "(real, vector, ostream_ptr, vector) => vector":
92+
arg = statement_types.OdeFunctorVariable("functor" + suffix)
93+
else:
94+
raise Exception("Inner type " + inner_type + " not supported")
95+
96+
if number_nested_arrays > 0:
97+
self._add_statement(arg)
98+
arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size = 1, value = arg)
99+
100+
arg_list.append(self._add_statement(arg))
101+
102+
if signature_parser.is_rng():
103+
arg_list.append(self._add_statement(statement_types.RngVariable("rng" + self._get_next_name_suffix())))
104+
105+
return arg_list
106+
107+
def add(self, arg1, arg2):
108+
"""
109+
Generate code for arg1 + arg2
110+
111+
:param arg1: First argument
112+
:param arg1: Second argument
113+
"""
114+
return self._add_statement(statement_types.FunctionCall("stan::math::add", "sum_of_sums" + self._get_next_name_suffix(), arg1, arg2))
115+
116+
def convert_to_expression(self, arg, size = None):
117+
"""
118+
Generate code to convert arg to an expression type of given size. If size is None, use the argument size
119+
120+
:param arg: Argument to convert to expression
121+
"""
122+
return self._add_statement(statement_types.ExpressionVariable(arg.name + "_expr" + self._get_next_name_suffix(), arg, size))
123+
124+
def expect_adj_eq(self, arg1, arg2):
125+
"""
126+
Generate code that checks that the adjoints of arg1 and arg2 are equal
127+
128+
:param arg1: First argument
129+
:param arg2: Second argument
130+
"""
131+
return self._add_statement(statement_types.FunctionCall("stan::test::expect_adj_eq", None, arg1, arg2))
132+
133+
def expect_eq(self, arg1, arg2):
134+
"""
135+
Generate code that checks that values of arg1 and arg2 are equal
136+
137+
:param arg1: First argument
138+
:param arg2: Second argument
139+
"""
140+
return self._add_statement(statement_types.FunctionCall("EXPECT_STAN_EQ", None, arg1, arg2))
141+
142+
def expect_leq_one(self, arg):
143+
"""
144+
Generate code to check that arg is less than or equal to one
145+
146+
:param arg: Argument to check
147+
"""
148+
one = self._add_statement(statement_types.IntVariable("int" + self._get_next_name_suffix(), 1))
149+
return self._add_statement(statement_types.FunctionCall("EXPECT_LE", None, arg, one))
150+
151+
def function_call_assign(self, cpp_function_name, *args):
152+
"""
153+
Generate code to call the c++ function given by cpp_function_name with given args and assign the result to another variable
154+
155+
:param cpp_function_name: c++ function name to call
156+
:param args: list of arguments to pass to function
157+
"""
158+
return self._add_statement(statement_types.FunctionCall(cpp_function_name, "result" + self._get_next_name_suffix(), *args))
159+
160+
def grad(self, arg):
161+
"""
162+
Generate code to call stan::test::grad(arg) (equivalent of arg.grad())
163+
164+
:param arg: Argument to call grad on
165+
"""
166+
return self._add_statement(statement_types.FunctionCall("stan::test::grad", None, arg))
167+
168+
def recover_memory(self):
169+
"""Generate code to call stan::math::recover_memory()"""
170+
return self._add_statement(statement_types.FunctionCall("stan::math::recover_memory", None))
171+
172+
def recursive_sum(self, arg):
173+
"""
174+
Generate code that repeatedly sums arg until all that is left is a scalar
175+
176+
:param arg: Argument to sum
177+
"""
178+
return self._add_statement(statement_types.FunctionCall("stan::test::recursive_sum", "summed_result" + self._get_next_name_suffix(), arg))
179+
180+
def to_var_value(self, arg):
181+
"""
182+
Generate code to convert arg to a varmat
183+
184+
:param arg: Argument to convert to varmat
185+
"""
186+
return self._add_statement(statement_types.FunctionCall("stan::math::to_var_value", arg.name + "_varmat" + self._get_next_name_suffix(), arg))

test/code_generator_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from signature_parser import SignatureParser
2+
from code_generator import CodeGenerator
3+
from statement_types import IntVariable, RealVariable, MatrixVariable
4+
import unittest
5+
6+
class CodeGeneratorTest(unittest.TestCase):
7+
def setUp(self):
8+
self.add = SignatureParser("add(real, vector) => vector")
9+
self.int_var = IntVariable("myint", 5)
10+
self.real_var1 = RealVariable("Rev", "myreal1", 0.5)
11+
self.real_var2 = RealVariable("Rev", "myreal2", 0.5)
12+
self.matrix_var = MatrixVariable("Rev", "mymatrix", "matrix", 2, 0.5)
13+
self.cg = CodeGenerator()
14+
15+
def test_prim_prim(self):
16+
self.cg.build_arguments(self.add, ["Prim", "Prim"], 1)
17+
self.assertEqual(self.cg.cpp(), """double real0 = 0.4;
18+
auto matrix1 = stan::test::make_arg<Eigen::Matrix<double, Eigen::Dynamic, 1>>(0.4, 1);""")
19+
20+
def test_prim_rev(self):
21+
self.cg.build_arguments(self.add, ["Prim", "Rev"], 1)
22+
self.assertEqual(self.cg.cpp(), """double real0 = 0.4;
23+
auto matrix1 = stan::test::make_arg<Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>>(0.4, 1);""")
24+
25+
def test_rev_rev(self):
26+
self.cg.build_arguments(self.add, ["Rev", "Rev"], 1)
27+
self.assertEqual(self.cg.cpp(), """stan::math::var real0 = 0.4;
28+
auto matrix1 = stan::test::make_arg<Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>>(0.4, 1);""")
29+
30+
def test_size(self):
31+
self.cg.build_arguments(self.add, ["Rev", "Rev"], 2)
32+
self.assertEqual(self.cg.cpp(), """stan::math::var real0 = 0.4;
33+
auto matrix1 = stan::test::make_arg<Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>>(0.4, 2);""")
34+
35+
def test_add(self):
36+
self.cg.add(self.real_var1, self.real_var2)
37+
self.assertEqual(self.cg.cpp(), "auto sum_of_sums0 = stan::math::eval(stan::math::add(myreal1,myreal2));")
38+
39+
def test_convert_to_expression(self):
40+
self.cg.convert_to_expression(self.matrix_var)
41+
self.assertEqual(self.cg.cpp(), """int mymatrix_expr0_counter = 0;
42+
stan::test::counterOp<stan::math::var> mymatrix_expr0_counter_op(&mymatrix_expr0_counter);
43+
auto mymatrix_expr0 = mymatrix.block(0,0,2,2).unaryExpr(mymatrix_expr0_counter_op);""")
44+
45+
def test_expect_adj_eq(self):
46+
self.cg.expect_adj_eq(self.real_var1, self.real_var2)
47+
self.assertEqual(self.cg.cpp(), "stan::test::expect_adj_eq(myreal1,myreal2);")
48+
49+
def test_expect_eq(self):
50+
self.cg.expect_eq(self.real_var1, self.real_var2)
51+
self.assertEqual(self.cg.cpp(), "EXPECT_STAN_EQ(myreal1,myreal2);")
52+
53+
def test_expect_leq_one(self):
54+
self.cg.expect_leq_one(self.int_var)
55+
self.assertEqual(self.cg.cpp(), """int int0 = 1;
56+
EXPECT_LE(myint,int0);""")
57+
58+
def test_function_call_assign(self):
59+
self.cg.function_call_assign("stan::math::add", self.real_var1, self.real_var2)
60+
self.assertEqual(self.cg.cpp(), "auto result0 = stan::math::eval(stan::math::add(myreal1,myreal2));")
61+
62+
def test_grad(self):
63+
self.cg.grad(self.real_var1)
64+
self.assertEqual(self.cg.cpp(), "stan::test::grad(myreal1);")
65+
66+
def test_recover_memory(self):
67+
self.cg.recover_memory()
68+
self.assertEqual(self.cg.cpp(), "stan::math::recover_memory();")
69+
70+
def test_recursive_sum(self):
71+
self.cg.recursive_sum(self.real_var1)
72+
self.assertEqual(self.cg.cpp(), "auto summed_result0 = stan::math::eval(stan::test::recursive_sum(myreal1));")
73+
74+
def test_to_var_value(self):
75+
self.cg.to_var_value(self.matrix_var)
76+
self.assertEqual(self.cg.cpp(), "auto mymatrix_varmat0 = stan::math::eval(stan::math::to_var_value(mymatrix));")
77+
78+
if __name__ == '__main__':
79+
unittest.main()

test/expressions/expression_test_helpers.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,15 @@ void expect_eq(const std::vector<T>& a, const std::vector<T>& b,
135135
}
136136

137137
template <typename T, require_not_st_var<T>* = nullptr>
138-
void expect_adj_eq(const T& a, const T& b, const char* msg) {}
138+
void expect_adj_eq(const T& a, const T& b, const char* msg = "expect_ad_eq") {}
139139

140-
void expect_adj_eq(math::var a, math::var b, const char* msg) {
140+
void expect_adj_eq(math::var a, math::var b, const char* msg = "expect_ad_eq") {
141141
EXPECT_EQ(a.adj(), b.adj()) << msg;
142142
}
143143

144144
template <typename T1, typename T2, require_all_eigen_t<T1, T2>* = nullptr,
145145
require_vt_same<T1, T2>* = nullptr>
146-
void expect_adj_eq(const T1& a, const T2& b, const char* msg) {
146+
void expect_adj_eq(const T1& a, const T2& b, const char* msg = "expect_ad_eq") {
147147
EXPECT_EQ(a.rows(), b.rows()) << msg;
148148
EXPECT_EQ(a.cols(), b.cols()) << msg;
149149
const auto& a_ref = math::to_ref(a);
@@ -157,13 +157,15 @@ void expect_adj_eq(const T1& a, const T2& b, const char* msg) {
157157

158158
template <typename T>
159159
void expect_adj_eq(const std::vector<T>& a, const std::vector<T>& b,
160-
const char* msg) {
160+
const char* msg = "expect_ad_eq") {
161161
EXPECT_EQ(a.size(), b.size()) << msg;
162162
for (int i = 0; i < a.size(); i++) {
163163
expect_adj_eq(a[i], b[i], msg);
164164
}
165165
}
166166

167+
void grad(stan::math::var& a) { a.grad(); }
168+
167169
#define TO_STRING_(x) #x
168170
#define TO_STRING(x) TO_STRING_(x)
169171
#define EXPECT_STAN_EQ(a, b) \

0 commit comments

Comments
 (0)