|
| 1 | +import collections |
| 2 | +import numbers |
| 3 | +import os |
| 4 | +import statement_types |
| 5 | +from sig_utils import parse_array, non_differentiable_args, special_arg_values |
| 6 | + |
| 7 | +class CodeGenerator: |
| 8 | + """ |
| 9 | + This class generates C++ to test Stan functions |
| 10 | + """ |
| 11 | + def __init__(self): |
| 12 | + self.name_counter = 0 |
| 13 | + self.code_list = [] |
| 14 | + |
| 15 | + def _add_statement(self, statement): |
| 16 | + """ |
| 17 | + Add a statement to the code generator |
| 18 | + |
| 19 | + :param statement: An object of type statement_types.CppStatement |
| 20 | + """ |
| 21 | + if not isinstance(statement, statement_types.CppStatement): |
| 22 | + raise TypeError("Argument to FunctionGenerator._add_statement must be an instance of an object that inherits from CppStatement") |
| 23 | + |
| 24 | + self.code_list.append(statement) |
| 25 | + return statement |
| 26 | + |
| 27 | + def _get_next_name_suffix(self): |
| 28 | + """Get the next available """ |
| 29 | + self.name_counter += 1 |
| 30 | + return repr(self.name_counter - 1) |
| 31 | + |
| 32 | + def cpp(self): |
| 33 | + """Generate and return the c++ code corresponding to the list of statements in the code generator""" |
| 34 | + return os.linesep.join(statement.cpp() for statement in self.code_list) |
| 35 | + |
| 36 | + def build_arguments(self, signature_parser, arg_overloads, size): |
| 37 | + """ |
| 38 | + Generate argument variables for each of the arguments in the given signature_parser |
| 39 | + with the given overloads in arg_overloads and with the given size |
| 40 | +
|
| 41 | + :param signature_parser: An instance of SignatureParser |
| 42 | + :param arg_overloads: A list of argument overloads (Prim/Fwd/Rev/etc.) as strings |
| 43 | + :param size: Size of matrix-like arguments. This is not used for array arguments (which will effectively all be size 1) |
| 44 | + """ |
| 45 | + arg_list = [] |
| 46 | + for n, (overload, stan_arg) in enumerate(zip(arg_overloads, signature_parser.stan_args)): |
| 47 | + suffix = self._get_next_name_suffix() |
| 48 | + |
| 49 | + number_nested_arrays, inner_type = parse_array(stan_arg) |
| 50 | + |
| 51 | + # Check if argument is differentiable |
| 52 | + if inner_type == "int" or n in non_differentiable_args.get(signature_parser.function_name, []): |
| 53 | + overload = "Prim" |
| 54 | + |
| 55 | + # By default the variable value is None and a default will be substituted |
| 56 | + value = None |
| 57 | + |
| 58 | + # Check for special arguments (constrained variables or types) |
| 59 | + try: |
| 60 | + special_arg = special_arg_values[signature_parser.function_name][n] |
| 61 | + if isinstance(special_arg, str): |
| 62 | + inner_type = special_arg |
| 63 | + elif special_arg is not None: |
| 64 | + value = special_arg |
| 65 | + except KeyError: |
| 66 | + pass |
| 67 | + |
| 68 | + # The first case here is used for the array initializers in sig_utils.special_arg_values |
| 69 | + # Everything else uses the second case |
| 70 | + if number_nested_arrays > 0 and isinstance(value, collections.Iterable): |
| 71 | + arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size = 1, value = value) |
| 72 | + else: |
| 73 | + if inner_type == "int": |
| 74 | + arg = statement_types.IntVariable("int" + suffix, value) |
| 75 | + elif inner_type == "real": |
| 76 | + arg = statement_types.RealVariable(overload, "real" + suffix, value) |
| 77 | + elif inner_type in ("vector", "row_vector", "matrix"): |
| 78 | + arg = statement_types.MatrixVariable(overload, "matrix" + suffix, inner_type, size, value) |
| 79 | + elif inner_type == "rng": |
| 80 | + arg = statement_types.RngVariable("rng" + suffix) |
| 81 | + elif inner_type == "ostream_ptr": |
| 82 | + arg = statement_types.OStreamVariable("ostream" + suffix) |
| 83 | + elif inner_type == "scalar_return_type": |
| 84 | + arg = statement_types.ReturnTypeTVariable("ret_type" + suffix, *arg_list) |
| 85 | + elif inner_type == "simplex": |
| 86 | + arg = statement_types.SimplexVariable(overload, "simplex" + suffix, size, value) |
| 87 | + elif inner_type == "positive_definite_matrix": |
| 88 | + arg = statement_types.PositiveDefiniteMatrixVariable(overload, "positive_definite_matrix" + suffix, size, value) |
| 89 | + elif inner_type == "(vector, vector, data array[] real, data array[] int) => vector": |
| 90 | + arg = statement_types.AlgebraSolverFunctorVariable("functor" + suffix) |
| 91 | + elif inner_type == "(real, vector, ostream_ptr, vector) => vector": |
| 92 | + arg = statement_types.OdeFunctorVariable("functor" + suffix) |
| 93 | + else: |
| 94 | + raise Exception("Inner type " + inner_type + " not supported") |
| 95 | + |
| 96 | + if number_nested_arrays > 0: |
| 97 | + self._add_statement(arg) |
| 98 | + arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size = 1, value = arg) |
| 99 | + |
| 100 | + arg_list.append(self._add_statement(arg)) |
| 101 | + |
| 102 | + if signature_parser.is_rng(): |
| 103 | + arg_list.append(self._add_statement(statement_types.RngVariable("rng" + self._get_next_name_suffix()))) |
| 104 | + |
| 105 | + return arg_list |
| 106 | + |
| 107 | + def add(self, arg1, arg2): |
| 108 | + """ |
| 109 | + Generate code for arg1 + arg2 |
| 110 | + |
| 111 | + :param arg1: First argument |
| 112 | + :param arg1: Second argument |
| 113 | + """ |
| 114 | + return self._add_statement(statement_types.FunctionCall("stan::math::add", "sum_of_sums" + self._get_next_name_suffix(), arg1, arg2)) |
| 115 | + |
| 116 | + def convert_to_expression(self, arg, size = None): |
| 117 | + """ |
| 118 | + Generate code to convert arg to an expression type of given size. If size is None, use the argument size |
| 119 | + |
| 120 | + :param arg: Argument to convert to expression |
| 121 | + """ |
| 122 | + return self._add_statement(statement_types.ExpressionVariable(arg.name + "_expr" + self._get_next_name_suffix(), arg, size)) |
| 123 | + |
| 124 | + def expect_adj_eq(self, arg1, arg2): |
| 125 | + """ |
| 126 | + Generate code that checks that the adjoints of arg1 and arg2 are equal |
| 127 | + |
| 128 | + :param arg1: First argument |
| 129 | + :param arg2: Second argument |
| 130 | + """ |
| 131 | + return self._add_statement(statement_types.FunctionCall("stan::test::expect_adj_eq", None, arg1, arg2)) |
| 132 | + |
| 133 | + def expect_eq(self, arg1, arg2): |
| 134 | + """ |
| 135 | + Generate code that checks that values of arg1 and arg2 are equal |
| 136 | + |
| 137 | + :param arg1: First argument |
| 138 | + :param arg2: Second argument |
| 139 | + """ |
| 140 | + return self._add_statement(statement_types.FunctionCall("EXPECT_STAN_EQ", None, arg1, arg2)) |
| 141 | + |
| 142 | + def expect_leq_one(self, arg): |
| 143 | + """ |
| 144 | + Generate code to check that arg is less than or equal to one |
| 145 | + |
| 146 | + :param arg: Argument to check |
| 147 | + """ |
| 148 | + one = self._add_statement(statement_types.IntVariable("int" + self._get_next_name_suffix(), 1)) |
| 149 | + return self._add_statement(statement_types.FunctionCall("EXPECT_LE", None, arg, one)) |
| 150 | + |
| 151 | + def function_call_assign(self, cpp_function_name, *args): |
| 152 | + """ |
| 153 | + Generate code to call the c++ function given by cpp_function_name with given args and assign the result to another variable |
| 154 | + |
| 155 | + :param cpp_function_name: c++ function name to call |
| 156 | + :param args: list of arguments to pass to function |
| 157 | + """ |
| 158 | + return self._add_statement(statement_types.FunctionCall(cpp_function_name, "result" + self._get_next_name_suffix(), *args)) |
| 159 | + |
| 160 | + def grad(self, arg): |
| 161 | + """ |
| 162 | + Generate code to call stan::test::grad(arg) (equivalent of arg.grad()) |
| 163 | + |
| 164 | + :param arg: Argument to call grad on |
| 165 | + """ |
| 166 | + return self._add_statement(statement_types.FunctionCall("stan::test::grad", None, arg)) |
| 167 | + |
| 168 | + def recover_memory(self): |
| 169 | + """Generate code to call stan::math::recover_memory()""" |
| 170 | + return self._add_statement(statement_types.FunctionCall("stan::math::recover_memory", None)) |
| 171 | + |
| 172 | + def recursive_sum(self, arg): |
| 173 | + """ |
| 174 | + Generate code that repeatedly sums arg until all that is left is a scalar |
| 175 | + |
| 176 | + :param arg: Argument to sum |
| 177 | + """ |
| 178 | + return self._add_statement(statement_types.FunctionCall("stan::test::recursive_sum", "summed_result" + self._get_next_name_suffix(), arg)) |
| 179 | + |
| 180 | + def to_var_value(self, arg): |
| 181 | + """ |
| 182 | + Generate code to convert arg to a varmat |
| 183 | + |
| 184 | + :param arg: Argument to convert to varmat |
| 185 | + """ |
| 186 | + return self._add_statement(statement_types.FunctionCall("stan::math::to_var_value", arg.name + "_varmat" + self._get_next_name_suffix(), arg)) |
0 commit comments