diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index 4fffb3a893..714083e26e 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -10,8 +10,6 @@ let stan_namespace_qualify f = if String.is_suffix ~suffix:"functor__" f || String.contains f ':' then f else "stan::math::" ^ f -let is_stan_math f = ends_with "__" f || starts_with "stan::math::" f - (* retun true if the type of the expression is integer, real, or complex (e.g. not a container) *) let is_scalar e = diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml new file mode 100644 index 0000000000..4fe4204e43 --- /dev/null +++ b/src/stan_math_backend/Function_gen.ml @@ -0,0 +1,551 @@ +(** Code generation for user defined functions and the relevant functors *) + +open Core_kernel +open Core_kernel.Poly +open Middle +open Fmt +open Expression_gen +open Statement_gen + +(** + Typename: The name of a template typename + Require: One of Stan's C++ template require giving a condition and the template names needing to satisfy that condition. + Bool: A boolean template type + *) +type template_parameter = + | Typename of string + | Require of string * string + | Bool of string + +let pp_template_parameter ppf template_parameter = + match template_parameter with + | Typename param_name -> pf ppf "typename %s" param_name + | Require (requirement, param_name) -> pf ppf "%s<%s>*" requirement param_name + | Bool param_name -> pf ppf "bool %s" param_name + +let pp_template_parameter_defaults ppf template_parameter = + match template_parameter with + | Require _ -> pf ppf "%a = nullptr" pp_template_parameter template_parameter + | _ -> pp_template_parameter ppf template_parameter + +(** + Pretty print a full C++ `template ` + *) +let pp_template ~defaults ppf template_parameters = + match template_parameters with + | [] -> () + | _ -> + pf ppf "template <@[%a@]>@ " + (list ~sep:comma + ( if defaults then pp_template_parameter_defaults + else pp_template_parameter ) ) + template_parameters + +type found_functor = + { struct_template: template_parameter option + ; arg_templates: template_parameter list + ; signature: string + ; defn: string } + +(** Detect if argument requires C++ template *) +let is_data_matrix_or_not_int_type = function + | UnsizedType.DataOnly, _, ut -> UnsizedType.is_eigen_type ut + | _, _, t when UnsizedType.is_int_type t -> false + | _ -> true + +(** Print template arguments for C++ functions that need templates + @param args A pack of [Program.fun_arg_decl] containing functions to detect templates. + @return A list of arguments with template parameter names added. + *) +let template_parameter_names (args : Program.fun_arg_decl) = + List.mapi args ~f:(fun i arg -> + match is_data_matrix_or_not_int_type arg with + | true -> Some (sprintf "T%d__" i) + | false -> None ) + +let requires (_, _, ut) = + match ut with + | UnsizedType.URowVector -> "stan::require_row_vector_t" + | UVector -> "stan::require_col_vector_t" + | UMatrix -> "stan::require_eigen_matrix_dynamic_t" + (* NB: Not unwinding array types due to the way arrays of eigens are printed *) + | _ -> "stan::require_stan_scalar_t" + +let optional_require_templates (name_ops : string option list) + (args : Program.fun_arg_decl) = + List.map2_exn name_ops args ~f:(fun name_op fun_arg -> + match name_op with + | Some param_name -> Some (Require (requires fun_arg, param_name)) + | None -> None ) + +let return_optional_arg_types (args : Program.fun_arg_decl) = + List.mapi args ~f:(fun i ((_, _, ut) as arg) -> + if UnsizedType.is_eigen_type ut && is_data_matrix_or_not_int_type arg then + Some (sprintf "stan::value_type_t" i) + else if is_data_matrix_or_not_int_type arg then Some (sprintf "T%d__" i) + else None ) + +let%expect_test "arg types templated correctly" = + [(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)] + |> template_parameter_names |> List.filter_opt |> String.concat ~sep:"," + |> print_endline ; + [%expect {| T0__ |}] + +(** Print the code for promoting stan real types + @param ppf A pretty printer$ + @param args A pack of arguments to detect whether they need to use the promotion rules. + *) +let pp_promoted_scalar ppf args = + match args with + | [] -> pf ppf "double" + | _ -> + let rec promote_args_chunked ppf args = + let chunk_till_empty ppf list_tail = + match list_tail with + | [] -> () + | _ -> pf ppf ",@ %a" promote_args_chunked list_tail in + match args with + | [] -> pf ppf "double" + | hd :: list_tail -> + pf ppf "@[stan::promote_args_t<@[%a%a@]>@]" (list ~sep:comma string) + hd chunk_till_empty list_tail in + promote_args_chunked ppf + List.(chunks_of ~length:5 (filter_opt (return_optional_arg_types args))) + +(** Pretty-prints a function's return-type, taking into account templated argument + promotion.*) +let pp_returntype ppf arg_types rt = + let scalar = str "@[%a@]" pp_promoted_scalar arg_types in + match rt with + | Some ut when UnsizedType.is_int_type ut -> + pf ppf "%a@ " pp_unsizedtype_custom_scalar ("int", ut) + | Some ut -> pf ppf "%a@ " pp_unsizedtype_custom_scalar (scalar, ut) + | None -> pf ppf "void@ " + +let pp_eigen_arg_to_ref ppf arg_types = + let pp_ref ppf name = + pf ppf "@[const auto& %s = stan::math::to_ref(%s);@]" name + (name ^ "_arg__") in + pf ppf "@[%a@]@ " (list ~sep:cut pp_ref) + (List.filter_map + ~f:(fun (_, name, ut) -> + if UnsizedType.is_eigen_type ut then Some name else None ) + arg_types ) + +(** Print the type of an object. + @param ppf A pretty printer + @param custom_scalar_opt A string representing a types inner scalar value. + @param name The name of the object + @param ut The unsized type of the object + *) +let pp_arg ppf (custom_scalar_opt, (_, name, ut)) = + let scalar = + match custom_scalar_opt with + | Some scalar -> scalar + | None -> stantype_prim_str ut in + (* we add the _arg suffix for any Eigen types *) + pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) + name + +let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) = + let scalar = + match custom_scalar_opt with + | Some scalar -> scalar + | None -> stantype_prim_str ut in + (* we add the _arg suffix for any Eigen types *) + let opt_arg_suffix = + if UnsizedType.is_eigen_type ut then name ^ "_arg__" else name in + pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) + opt_arg_suffix + +let typename parameter_name = Typename parameter_name + +(** Construct an object with it's needed templates for function signatures. + @param is_possibly_eigen_expr if true, argument can possibly be an unevaluated eigen expression. + @param fdargs A sexp list of strings representing C++ types. + *) +let templates_and_args (is_possibly_eigen_expr : bool) + (fdargs : Program.fun_arg_decl) : + string list * template_parameter list * string list = + let arg_type_templates = template_parameter_names fdargs in + let require_arg_templates = + optional_require_templates arg_type_templates fdargs in + ( List.filter_opt arg_type_templates + , List.filter_opt require_arg_templates + , if not is_possibly_eigen_expr then + List.map + ~f:(fun a -> str "%a" pp_arg a) + (List.zip_exn arg_type_templates fdargs) + else + List.map + ~f:(fun a -> str "%a" pp_arg_eigen_suffix a) + (List.zip_exn arg_type_templates fdargs) ) + +let mk_extra_args templates args = + List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args) + +(** + Prints boilerplate at start of function. Body of function wrapped in a `try` block. + *) +let pp_fun_body fdargs fdsuffix ppf (Stmt.Fixed.{pattern; _} as fdbody) = + pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; + pf ppf "int current_statement__ = 0; @ " ; + if List.exists ~f:(fun (_, _, ut) -> UnsizedType.is_eigen_type ut) fdargs then + pp_eigen_arg_to_ref ppf fdargs ; + ( match fdsuffix with + | Fun_kind.FnLpdf _ | FnTarget -> () + | FnPlain | FnRng -> + pf ppf "%s@ " "static constexpr bool propto__ = true;" ; + pf ppf "%s@ " "(void) propto__;" ) ; + pf ppf "%s@ " + "local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN());" ; + pf ppf "%a" pp_unused "DUMMY_VAR__" ; + let blocked_fdbody = + match pattern with + | SList stmts -> {fdbody with pattern= Block stmts} + | Block _ -> fdbody + | _ -> {fdbody with pattern= Block [fdbody]} in + pp_located_error ppf (pp_statement, blocked_fdbody) ; + pf ppf "@ " + +(** + Functor to generate a pretty printer for the function signature. + *) +let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) = + Format.open_vbox 2 ; + pp_returntype ppf fdargs fdrt ; + let args, variadic_args = + match variadic with + | `ReduceSum -> List.split_n args 3 + | `VariadicODE -> List.split_n args 2 + | `VariadicDAE -> List.split_n args 3 + | `None -> (args, []) in + let arg_strs = + args + @ mk_extra_args extra_templates extra + @ ["std::ostream* pstream__"] + @ variadic_args in + pf ppf "%s(@[%a@]) " name (list ~sep:comma string) arg_strs ; + Format.close_box () + +(** Print the C++ function definition. + @param ppf A pretty printer + Refactor this please - one idea might be to have different functions for + printing user defined distributions vs rngs vs regular functions. + *) +let pp_fun_def ppf + ( Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} + , (functors : (string, found_functor list) Hashtbl.t) + , (forward_decls : (string * template_parameter list) Hash_set.t) + , (funs_used_in_reduce_sum : String.Set.t) + , (funs_used_in_variadic_ode : String.Set.t) + , (funs_used_in_variadic_dae : String.Set.t) ) = + let extra, template_extra_params = + match fdsuffix with + | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) + | FnRng -> (["base_rng__"], ["RNG"]) + | FnLpdf _ | FnPlain -> ([], []) in + let template_parameter_and_arg_names is_possibly_eigen_expr variadic_fun_type + = + let template_param_names, template_require_checks, args = + templates_and_args is_possibly_eigen_expr fdargs in + let template_params = + List.(map ~f:typename (template_param_names @ template_extra_params)) + @ template_require_checks in + match (fdsuffix, variadic_fun_type) with + | (FnLpdf _ | FnTarget), `None -> (Bool "propto__" :: template_params, args) + | _ -> (template_params, args) in + let template_params, templated_args = + template_parameter_and_arg_names true `None in + let pp_fun_sig = gen_pp_sig fdargs fdrt template_extra_params extra in + let signature = str "%a" pp_fun_sig (fdname, templated_args, `None) in + (* We want to print the [* = nullptr] at most once, and preferrably on a forward decl *) + let template_parameter_default_values = + Option.is_none fdbody + || not (Hash_set.mem forward_decls (signature, template_params)) in + pf ppf "%a%a" + (pp_template ~defaults:template_parameter_default_values) + template_params pp_fun_sig + (fdname, templated_args, `None) ; + match fdbody with + | None -> + pf ppf ";@ " ; + (* Side Effect: *) + Hash_set.add forward_decls (signature, template_params) + | Some fdbody -> + pp_block ppf (pp_fun_body fdargs fdsuffix, fdbody) ; + let register_functor (str_args, args, variadic_fun_type) = + let suffix = + match variadic_fun_type with + | `None -> functor_suffix + | `ReduceSum -> reduce_sum_functor_suffix + | `VariadicODE -> variadic_ode_functor_suffix + | `VariadicDAE -> variadic_dae_functor_suffix in + let functor_name = fdname ^ suffix in + let struct_template = + match (fdsuffix, variadic_fun_type) with + | FnLpdf _, `ReduceSum -> Some (Bool "propto__") + | _ -> None in + let arg_templates, templated_args = + template_parameter_and_arg_names false variadic_fun_type in + let op_signature = + str "%a" pp_fun_sig ("operator()", templated_args, variadic_fun_type) + in + let operator_paren_sig = + str "%a@ const@,{@. return %a;@.}@." pp_fun_sig + ( functor_name + ^ (if struct_template <> None then "" else "") + ^ "::operator()" + , templated_args + , variadic_fun_type ) + pp_call_str + ( ( match fdsuffix with + | FnLpdf _ | FnTarget -> fdname ^ "" + | _ -> fdname ) + , str_args + @ List.map ~f:(fun (_, name, _) -> name) args + @ extra @ ["pstream__"] ) in + (* Side Effect: *) + Hashtbl.add_multi functors ~key:functor_name + ~data: + { struct_template + ; arg_templates + ; signature= op_signature + ; defn= operator_paren_sig } in + register_functor ([], fdargs, `None) ; + if String.Set.mem funs_used_in_reduce_sum fdname then + (* Produces the reduce_sum functors that has the pstream argument + as the third and not last argument *) + match fdargs with + | (_, slice, _) :: (_, start, _) :: (_, end_, _) :: rest -> + register_functor + ([slice; start ^ " + 1"; end_ ^ " + 1"], rest, `ReduceSum) + | _ -> + Common.FatalError.fatal_error_msg + [%message + "Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl)] + else if String.Set.mem funs_used_in_variadic_ode fdname then + (* Produces the variadic ode functors that has the pstream argument + as the third and not last argument *) + register_functor ([], fdargs, `VariadicODE) + else if String.Set.mem funs_used_in_variadic_dae fdname then + (* Produces the variadic DAE functors that has the pstream argument + as the fourth and not last argument *) + register_functor ([], fdargs, `VariadicDAE) + +let pp_standalone_fun_def namespace_fun ppf + Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} = + let extra, extra_templates = + match fdsuffix with + | Fun_kind.FnTarget -> + (["lp__"; "lp_accum__"], ["double"; "stan::math::accumulator"]) + | FnRng -> (["base_rng__"], ["boost::ecuyer1988"]) + | FnLpdf _ | FnPlain -> ([], []) in + let args = + List.map + ~f:(fun (_, name, ut) -> + str "const %a& %s" pp_unsizedtype_custom_scalar + (stantype_prim_str ut, ut) + name ) + fdargs in + let pp_sig_standalone ppf _ = + let arg_strs = + args + @ mk_extra_args extra_templates extra + @ ["std::ostream* pstream__ = nullptr"] in + pf ppf "(@[%a@]) " (list ~sep:comma string) arg_strs in + let mark_function_comment = "// [[stan::function]]" in + let return_type = match fdrt with None -> "void" | _ -> "auto" in + let return_stmt = match fdrt with None -> "" | _ -> "return " in + match fdbody with + | None -> pf ppf ";@ " + | Some _ -> + pf ppf "@,%s@,%s %s%a @,{@, %s%s::%a;@,}@," mark_function_comment + return_type fdname pp_sig_standalone "" return_stmt namespace_fun + pp_call_str + ( ( match fdsuffix with + | FnLpdf _ | FnTarget -> fdname ^ "" + | FnRng | FnPlain -> fdname ) + , List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] + ) + +let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool) + (p : Program.Numbered.t) = + let rec find_functors_expr accum Expr.Fixed.{pattern; _} = + String.Set.union accum + ( match pattern with + | FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _) + when variadic_fn_test x -> + String.Set.of_list [Utils.stdlib_distribution_name f] + | x -> Expr.Fixed.Pattern.fold find_functors_expr accum x ) in + let rec find_functors_stmt accum stmt = + Stmt.Fixed.( + Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern) + in + Program.fold find_functors_expr find_functors_stmt String.Set.empty p + +let collect_functors_functions (p : Program.Numbered.t) = + let (functors : (string, found_functor list) Hashtbl.t) = + String.Table.create () in + let forward_decls = Hash_set.Poly.create () in + let reduce_sum_fns = + is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in + let variadic_ode_fns = + is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in + let variadic_dae_fns = + is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in + let pp_fun_def_with_variadic_fn_list ppf fblock = + (hovbox ~indent:2 pp_fun_def) + ppf + ( fblock + , functors + , forward_decls + , reduce_sum_fns + , variadic_ode_fns + , variadic_dae_fns ) in + ( str "@[%a@]" + (list ~sep:cut pp_fun_def_with_variadic_fn_list) + p.functions_block + , functors ) + +let pp_functions_functors ppf (p : Program.Numbered.t) = + let fns_str, functors = collect_functors_functions p in + let pp_functor_decls ppf tbl = + Hashtbl.iteri tbl ~f:(fun ~key ~data -> + pf ppf "@[%astruct %s {@,%aconst;@]@,};@." + (option + (option (fun ppf template_param -> + pf ppf "template <%a>@ " pp_template_parameter_defaults + template_param ) ) ) + (Option.map + ~f:(fun functor_t -> functor_t.struct_template) + (List.hd data) ) + key + (list ~sep:(any "const;@,") (fun ppf (template_parameters, sign) -> + pf ppf "%a@[%a@]" + (pp_template ~defaults:true) + template_parameters text sign ) ) + (List.map + ~f:(fun {arg_templates; signature; _} -> (arg_templates, signature)) + data ) ) in + let pp_functors ppf functor_tbl = + Hashtbl.iter functor_tbl ~f:(fun data -> + List.iter data ~f:(fun {struct_template; defn; arg_templates; _} -> + pf ppf "%a%a%s@." + (option (fun ppf template_param -> + pf ppf "template <%a>@ " pp_template_parameter template_param ) + ) + struct_template + (pp_template ~defaults:false) + arg_templates defn ) ) in + pf ppf "%a@ %s@ %a" pp_functor_decls functors fns_str pp_functors functors + +(* Testing code *) + +let pp_fun_def_w_rs a b = + pp_fun_def a + ( b + , String.Table.create () + , Hash_set.Poly.create () + , String.Set.empty + , String.Set.empty + , String.Set.empty ) + +let%expect_test "udf" = + let with_no_loc stmt = + Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in + let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in + { fdrt= None + ; fdname= "sars" + ; fdsuffix= FnPlain + ; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)] + ; fdbody= + Stmt.Fixed.Pattern.Return + (Some + ( w + @@ FunApp + ( StanLib ("add", FnPlain, AoS) + , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) + |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc + |> Some + ; fdloc= Location_span.empty } + |> str "@[%a" pp_fun_def_w_rs + |> print_endline ; + [%expect + {| + template * = nullptr, + stan::require_row_vector_t* = nullptr> + void + sars(const T0__& x_arg__, const T1__& y_arg__, std::ostream* pstream__) { + using local_scalar_t__ = + stan::promote_args_t, + stan::value_type_t>; + int current_statement__ = 0; + const auto& x = stan::math::to_ref(x_arg__); + const auto& y = stan::math::to_ref(y_arg__); + static constexpr bool propto__ = true; + (void) propto__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + (void) DUMMY_VAR__; // suppress unused var warning + try { + return stan::math::add(x, 1); + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + + } |}] + +let%expect_test "udf-expressions" = + let with_no_loc stmt = + Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in + let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in + { fdrt= Some UMatrix + ; fdname= "sars" + ; fdsuffix= FnPlain + ; fdargs= + [ (DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector) + ; (AutoDiffable, "z", URowVector); (AutoDiffable, "w", UArray UMatrix) ] + ; fdbody= + Stmt.Fixed.Pattern.Return + (Some + ( w + @@ FunApp + ( StanLib ("add", FnPlain, AoS) + , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) + |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc + |> Some + ; fdloc= Location_span.empty } + |> str "@[%a" pp_fun_def_w_rs + |> print_endline ; + [%expect + {| + template * = nullptr, + stan::require_row_vector_t* = nullptr, + stan::require_row_vector_t* = nullptr, + stan::require_stan_scalar_t* = nullptr> + Eigen::Matrix, stan::value_type_t, + stan::value_type_t, T3__>, -1, -1> + sars(const T0__& x_arg__, const T1__& y_arg__, const T2__& z_arg__, + const std::vector>& w, + std::ostream* pstream__) { + using local_scalar_t__ = + stan::promote_args_t, + stan::value_type_t, + stan::value_type_t, T3__>; + int current_statement__ = 0; + const auto& x = stan::math::to_ref(x_arg__); + const auto& y = stan::math::to_ref(y_arg__); + const auto& z = stan::math::to_ref(z_arg__); + static constexpr bool propto__ = true; + (void) propto__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + (void) DUMMY_VAR__; // suppress unused var warning + try { + return stan::math::add(x, 1); + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + + } |}] diff --git a/src/stan_math_backend/Function_gen.mli b/src/stan_math_backend/Function_gen.mli new file mode 100644 index 0000000000..41872842e6 --- /dev/null +++ b/src/stan_math_backend/Function_gen.mli @@ -0,0 +1,10 @@ +open Middle + +val pp_functions_functors : Format.formatter -> Program.Numbered.t -> unit +(** Pretty-print all user defined functions. Creates functor structs as needed + *) + +val pp_standalone_fun_def : + string -> Format.formatter -> 'a Program.fun_def -> unit +(** Creates functions outside the model namespaces which only call the ones + inside the namespaces *) diff --git a/src/stan_math_backend/Stan_math_code_gen.ml b/src/stan_math_backend/Stan_math_code_gen.ml index ede0fe0690..2aabd43321 100644 --- a/src/stan_math_backend/Stan_math_code_gen.ml +++ b/src/stan_math_backend/Stan_math_code_gen.ml @@ -22,38 +22,7 @@ open Middle open Fmt open Expression_gen open Statement_gen - -(* TODO: move to seperate file and code gen more like this with a structured type *) -type template = - | Typename of string - | Require of string * string - | Bool of string - -let pp_template ppf template = - match template with - | Typename t -> pf ppf "typename %s" t - | Require (r, t) -> pf ppf "%s<%s>*" r t - | Bool s -> pf ppf "bool %s" s - -let pp_template_defaults ppf template = - match template with - | Require _ -> pf ppf "%a = nullptr" pp_template template - | _ -> pp_template ppf template - -let pp_templates ~defaults ppf templates = - match templates with - | [] -> () - | _ -> - pf ppf "template <@[%a@]>@ " - (list ~sep:comma - (if defaults then pp_template_defaults else pp_template) ) - templates - -type found_functor = - { struct_template: template option - ; arg_templates: template list - ; signature: string - ; defn: string } +open Function_gen let standalone_functions = ref false @@ -69,8 +38,6 @@ let stanc_args_to_print = |> List.filter ~f:sans_model_and_hpp_paths |> String.concat ~sep:" " -let pp_unused = fmt "(void) %s; // suppress unused var warning" - (** Print name of model function. @param prog_name Name of the Stan program. @param fname Name of the function. @@ -80,335 +47,6 @@ let pp_function__ ppf (prog_name, fname) = {|@[static constexpr const char* function__ = "%s_namespace::%s";@,%a@]|} prog_name fname pp_unused "function__" -(** Print the body of exception handling for functions *) -let pp_located ppf _ = - pf ppf - {|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|} - -(** Detect if argument requires C++ template *) -let arg_needs_template = function - | UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t - | _, _, t when UnsizedType.is_int_type t -> false - | _ -> true - -(** Print template arguments for C++ functions that need templates - @param args A pack of [Program.fun_arg_decl] containing functions to detect templates. - @return A list of arguments with template parameter names added. - *) -let maybe_templated_arg_types (args : Program.fun_arg_decl) = - List.mapi args ~f:(fun i a -> - match arg_needs_template a with - | true -> Some (sprintf "T%d__" i) - | false -> None ) - -let maybe_require_templates (names : string option list) - (args : Program.fun_arg_decl) = - let require_for_arg arg = - match trd3 arg with - | UnsizedType.URowVector -> "stan::require_row_vector_t" - | UVector -> "stan::require_col_vector_t" - | UMatrix -> "stan::require_eigen_matrix_dynamic_t" - (* NB: Not unwinding array types due to the way arrays of eigens are printed *) - | _ -> "stan::require_stan_scalar_t" in - List.map2_exn names args ~f:(fun name a -> - match name with - | Some t -> Some (Require (require_for_arg a, t)) - | None -> None ) - -let return_arg_types (args : Program.fun_arg_decl) = - List.mapi args ~f:(fun i ((_, _, ut) as a) -> - if UnsizedType.is_eigen_type ut && arg_needs_template a then - Some (sprintf "stan::value_type_t" i) - else if arg_needs_template a then Some (sprintf "T%d__" i) - else None ) - -let%expect_test "arg types templated correctly" = - [(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)] - |> maybe_templated_arg_types |> List.filter_opt |> String.concat ~sep:"," - |> print_endline ; - [%expect {| T0__ |}] - -(** Print the code for promoting stan real types - @param ppf A pretty printer - @param args A pack of arguments to detect whether they need to use the promotion rules. - *) -let pp_promoted_scalar ppf args = - match args with - | [] -> pf ppf "double" - | _ -> - let rec promote_args_chunked ppf args = - let go ppf tl = - match tl with [] -> () | _ -> pf ppf ",@ %a" promote_args_chunked tl - in - match args with - | [] -> pf ppf "double" - | hd :: tl -> - pf ppf "@[stan::promote_args_t<@[%a%a@]>@]" (list ~sep:comma string) - hd go tl in - promote_args_chunked ppf - List.(chunks_of ~length:5 (filter_opt (return_arg_types args))) - -(** Pretty-prints a function's return-type, taking into account templated argument - promotion.*) -let pp_returntype ppf arg_types rt = - let scalar = str "@[%a@]" pp_promoted_scalar arg_types in - match rt with - | Some ut when UnsizedType.is_int_type ut -> - pf ppf "%a@ " pp_unsizedtype_custom_scalar ("int", ut) - | Some ut -> pf ppf "%a@ " pp_unsizedtype_custom_scalar (scalar, ut) - | None -> pf ppf "void@ " - -let pp_eigen_arg_to_ref ppf arg_types = - let pp_ref ppf name = - pf ppf "@[const auto& %s = stan::math::to_ref(%s);@]" name - (name ^ "_arg__") in - pf ppf "@[%a@]@ " (list ~sep:cut pp_ref) - (List.filter_map - ~f:(fun (_, name, ut) -> - if UnsizedType.is_eigen_type ut then Some name else None ) - arg_types ) - -(** [pp_located_error ppf (pp_body_block, body_block, err_msg)] surrounds [body_block] - with a C++ try-catch that will rethrow the error with the proper source location - from the [body_block] (required to be a [stmt_loc Block] variant). - @param ppf A pretty printer. - @param pp_body_block A pretty printer for the body block - @param body A C++ scoped body block surrounded by squiggly braces. - *) -let pp_located_error ppf (pp_body_block, body) = - pf ppf "@ try %a" pp_body_block body ; - string ppf " catch (const std::exception& e) " ; - pp_block ppf (pp_located, ()) - -(** Print the type of an object. - @param ppf A pretty printer - @param custom_scalar_opt A string representing a types inner scalar value. - @param name The name of the object - @param ut The unsized type of the object - *) -let pp_arg ppf (custom_scalar_opt, (_, name, ut)) = - let scalar = - match custom_scalar_opt with - | Some scalar -> scalar - | None -> stantype_prim_str ut in - (* we add the _arg suffix for any Eigen types *) - pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) - name - -let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) = - let scalar = - match custom_scalar_opt with - | Some scalar -> scalar - | None -> stantype_prim_str ut in - (* we add the _arg suffix for any Eigen types *) - let opt_arg_suffix = - if UnsizedType.is_eigen_type ut then name ^ "_arg__" else name in - pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) - opt_arg_suffix - -(** [pp_located_error_b] automatically adds a Block wrapper *) -let pp_located_error_b ppf body_stmts = - pp_located_error ppf - ( pp_statement - , Stmt.Fixed.{pattern= Block body_stmts; meta= Locations.no_span_num} ) - -let typename t = Typename t - -(** Construct an object with it's needed templates for function signatures. - @param fdargs A sexp list of strings representing C++ types. - *) -let get_templates_and_args exprs fdargs = - let argtypetemplates = maybe_templated_arg_types fdargs in - let requireargtemplates = maybe_require_templates argtypetemplates fdargs in - ( List.filter_opt argtypetemplates - , List.filter_opt requireargtemplates - , if not exprs then - List.map - ~f:(fun a -> str "%a" pp_arg a) - (List.zip_exn argtypetemplates fdargs) - else - List.map - ~f:(fun a -> str "%a" pp_arg_eigen_suffix a) - (List.zip_exn argtypetemplates fdargs) ) - -(** Print the C++ template parameter decleration before a function. - @param ppf A pretty printer. - *) -let pp_template_decorator ppf = function - | [] -> () - | templates -> - pf ppf "@[template <@[%a>@]@]@ " (list ~sep:comma string) templates - -let mk_extra_args templates args = - List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args) - -(** Print the C++ function definition. - @param ppf A pretty printer - Refactor this please - one idea might be to have different functions for - printing user defined distributions vs rngs vs regular functions. -*) -let pp_fun_def ppf - ( Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} - , (functors : (string, found_functor list) Hashtbl.t) - , (forward_decls : (string * template list) Hash_set.t) - , (funs_used_in_reduce_sum : String.Set.t) - , (funs_used_in_variadic_ode : String.Set.t) - , (funs_used_in_variadic_dae : String.Set.t) ) = - let extra, extra_templates = - match fdsuffix with - | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) - | FnRng -> (["base_rng__"], ["RNG"]) - | FnLpdf _ | FnPlain -> ([], []) in - let pp_body ppf (Stmt.Fixed.{pattern; _} as fdbody) = - pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; - pf ppf "int current_statement__ = 0; @ " ; - if List.exists ~f:(fun (_, _, t) -> UnsizedType.is_eigen_type t) fdargs then - pp_eigen_arg_to_ref ppf fdargs ; - ( match fdsuffix with - | FnLpdf _ | FnTarget -> () - | FnPlain | FnRng -> - pf ppf "%s@ " "static constexpr bool propto__ = true;" ; - pf ppf "%s@ " "(void) propto__;" ) ; - pf ppf "%s@ " - "local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN());" ; - pf ppf "%a" pp_unused "DUMMY_VAR__" ; - let blocked_fdbody = - match pattern with - | SList stmts -> {fdbody with pattern= Block stmts} - | Block _ -> fdbody - | _ -> {fdbody with pattern= Block [fdbody]} in - pp_located_error ppf (pp_statement, blocked_fdbody) ; - pf ppf "@ " in - let get_templates exprs variadic = - let argtypetemplates, require_templates, args = - get_templates_and_args exprs fdargs in - let templates = - List.(map ~f:typename (argtypetemplates @ extra_templates)) - @ require_templates in - match (fdsuffix, variadic) with - | (FnLpdf _ | FnTarget), `None -> (Bool "propto__" :: templates, args) - | _ -> (templates, args) in - let pp_sig ppf (name, args, variadic) = - Format.open_vbox 2 ; - pp_returntype ppf fdargs fdrt ; - let args, variadic_args = - match variadic with - | `ReduceSum -> List.split_n args 3 - | `VariadicODE -> List.split_n args 2 - | `VariadicDAE -> List.split_n args 3 - | `None -> (args, []) in - let arg_strs = - args - @ mk_extra_args extra_templates extra - @ ["std::ostream* pstream__"] - @ variadic_args in - pf ppf "%s(@[%a@]) " name (list ~sep:comma string) arg_strs ; - Format.close_box () in - let templates, templated_args = get_templates true `None in - let signature = str "%a" pp_sig (fdname, templated_args, `None) in - (* We want to print the [* = nullptr] at most once, and preferrably on a forward decl *) - let defaults = - Option.is_none fdbody - || not (Hash_set.mem forward_decls (signature, templates)) in - pf ppf "%a%a" (pp_templates ~defaults) templates pp_sig - (fdname, templated_args, `None) ; - match fdbody with - | None -> - pf ppf ";@ " ; - Hash_set.add forward_decls (signature, templates) - | Some fdbody -> - pp_block ppf (pp_body, fdbody) ; - let register_functor (str_args, args, variadic) = - let suffix = - match variadic with - | `None -> functor_suffix - | `ReduceSum -> reduce_sum_functor_suffix - | `VariadicODE -> variadic_ode_functor_suffix - | `VariadicDAE -> variadic_dae_functor_suffix in - let functor_name = fdname ^ suffix in - let struct_template = - match (fdsuffix, variadic) with - | FnLpdf _, `ReduceSum -> Some (Bool "propto__") - | _ -> None in - let arg_templates, templated_args = get_templates false variadic in - let op_signature = - str "%a" pp_sig ("operator()", templated_args, variadic) in - let defn = - str "%a@ const@,{@. return %a;@.}@." pp_sig - ( functor_name - ^ (if struct_template <> None then "" else "") - ^ "::operator()" - , templated_args - , variadic ) - pp_call_str - ( ( match fdsuffix with - | FnLpdf _ | FnTarget -> fdname ^ "" - | _ -> fdname ) - , str_args - @ List.map ~f:(fun (_, name, _) -> name) args - @ extra @ ["pstream__"] ) in - Hashtbl.add_multi functors ~key:functor_name - ~data:{struct_template; arg_templates; signature= op_signature; defn} - in - register_functor ([], fdargs, `None) ; - if String.Set.mem funs_used_in_reduce_sum fdname then - (* Produces the reduce_sum functors that has the pstream argument - as the third and not last argument *) - match fdargs with - | (_, slice, _) :: (_, start, _) :: (_, end_, _) :: rest -> - register_functor - ([slice; start ^ " + 1"; end_ ^ " + 1"], rest, `ReduceSum) - | _ -> - Common.FatalError.fatal_error_msg - [%message - "Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl)] - else if String.Set.mem funs_used_in_variadic_ode fdname then - (* Produces the variadic ode functors that has the pstream argument - as the third and not last argument *) - register_functor ([], fdargs, `VariadicODE) - else if String.Set.mem funs_used_in_variadic_dae fdname then - (* Produces the variadic DAE functors that has the pstream argument - as the fourth and not last argument *) - register_functor ([], fdargs, `VariadicDAE) - -(** Creates functions outside the model namespaces which only call the ones - inside the namespaces *) -let pp_standalone_fun_def namespace_fun ppf - Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} = - let extra, extra_templates = - match fdsuffix with - | Fun_kind.FnTarget -> - (["lp__"; "lp_accum__"], ["double"; "stan::math::accumulator"]) - | FnRng -> (["base_rng__"], ["boost::ecuyer1988"]) - | FnLpdf _ | FnPlain -> ([], []) in - let args = - List.map - ~f:(fun (_, name, ut) -> - str "const %a& %s" pp_unsizedtype_custom_scalar - (stantype_prim_str ut, ut) - name ) - fdargs in - let pp_sig_standalone ppf _ = - let arg_strs = - args - @ mk_extra_args extra_templates extra - @ ["std::ostream* pstream__ = nullptr"] in - pf ppf "(@[%a@]) " (list ~sep:comma string) arg_strs in - let mark_function_comment = "// [[stan::function]]" in - let return_type = match fdrt with None -> "void" | _ -> "auto" in - let return_stmt = match fdrt with None -> "" | _ -> "return " in - match fdbody with - | None -> pf ppf ";@ " - | Some _ -> - pf ppf "@,%s@,%s %s%a @,{@, %s%s::%a;@,}@," mark_function_comment - return_type fdname pp_sig_standalone "" return_stmt namespace_fun - pp_call_str - ( ( match fdsuffix with - | FnLpdf _ | FnTarget -> fdname ^ "" - | FnRng | FnPlain -> fdname ) - , List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] - ) - let version = "// Code generated by %%NAME%% %%VERSION%%" let includes = "#include " @@ -1004,75 +642,12 @@ let pp_register_map_rect_functors ppf p = (list ~sep:cut pp_register_functor) (List.sort ~compare (Hashtbl.to_alist map_rect_calls)) -let is_fun_used_with_variadic_fn variadic_fn_test p = - let rec find_functors_expr accum Expr.Fixed.{pattern; _} = - String.Set.union accum - ( match pattern with - | FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _) - when variadic_fn_test x -> - String.Set.of_list [Utils.stdlib_distribution_name f] - | x -> Expr.Fixed.Pattern.fold find_functors_expr accum x ) in - let rec find_functors_stmt accum stmt = - Stmt.Fixed.( - Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern) - in - Program.fold find_functors_expr find_functors_stmt String.Set.empty p - -let collect_functors_functions p = - let (functors : (string, found_functor list) Hashtbl.t) = - String.Table.create () in - let forward_decls = Hash_set.Poly.create () in - let reduce_sum_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in - let variadic_ode_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in - let variadic_dae_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in - let pp_fun_def_with_variadic_fn_list ppf fblock = - (hovbox ~indent:2 pp_fun_def) - ppf - ( fblock - , functors - , forward_decls - , reduce_sum_fns - , variadic_ode_fns - , variadic_dae_fns ) in - ( str "@[%a@]" - (list ~sep:cut pp_fun_def_with_variadic_fn_list) - p.functions_block - , functors ) - -(** Print the full C++ for the stan program. *) let pp_prog ppf (p : Program.Typed.t) = Hashtbl.clear Expression_gen.map_rect_calls ; (* First, do some transformations on the MIR itself before we begin printing it.*) let p, s = Locations.prepare_prog p in - let fns_str, functors = collect_functors_functions p in - let pp_functor_decls ppf tbl = - Hashtbl.iteri tbl ~f:(fun ~key ~data -> - pf ppf "@[%astruct %s {@,%aconst;@]@,};@." - (option - (option (fun ppf t -> - pf ppf "template <%a>@ " pp_template_defaults t ) ) ) - (Option.map ~f:(fun x -> x.struct_template) (List.hd data)) - key - (list ~sep:(any "const;@,") (fun ppf (ts, sign) -> - pf ppf "%a@[%a@]" (pp_templates ~defaults:true) ts text sign ) - ) - (List.map - ~f:(fun {arg_templates; signature; _} -> (arg_templates, signature)) - data ) ) in - let pp_functors ppf tbl = - Hashtbl.iter tbl ~f:(fun data -> - List.iter data ~f:(fun {struct_template; defn; arg_templates; _} -> - pf ppf "%a%a%s@." - (option (fun ppf t -> pf ppf "template <%a>@ " pp_template t)) - struct_template - (pp_templates ~defaults:false) - arg_templates defn ) ) in - pf ppf "@[@ %s@ %s@ namespace %s {@ %s@ %a@ %a@ %s@ %a@ %a@ }@ @]" version - includes (namespace p) usings Locations.pp_globals s pp_functor_decls - functors fns_str pp_functors functors + pf ppf "@[@ %s@ %s@ namespace %s {@ %s@ %a@ %a@ %a@ }@ @]" version includes + (namespace p) usings Locations.pp_globals s pp_functions_functors p (if !standalone_functions then fun _ _ -> () else pp_model) p ; if !standalone_functions then diff --git a/src/stan_math_backend/Stan_math_code_gen.mli b/src/stan_math_backend/Stan_math_code_gen.mli new file mode 100644 index 0000000000..03cb323745 --- /dev/null +++ b/src/stan_math_backend/Stan_math_code_gen.mli @@ -0,0 +1,10 @@ +val model_prefix : string +(** A string put in front of model names in the namespace. + Currently "model_" + *) + +val standalone_functions : bool ref +(** Flag to generate just function code, used in RStan *) + +val pp_prog : Format.formatter -> Middle.Program.Typed.t -> unit +(** Print the full C++ for the stan program. *) diff --git a/src/stan_math_backend/Statement_gen.ml b/src/stan_math_backend/Statement_gen.ml index 212c31c6ff..e7f5b36ffd 100644 --- a/src/stan_math_backend/Statement_gen.ml +++ b/src/stan_math_backend/Statement_gen.ml @@ -6,6 +6,12 @@ open Expression_gen let pp_call_str ppf (name, args) = pp_call ppf (name, string, args) let pp_block ppf (pp_body, body) = pf ppf "{@;<1 2>@[%a@]@,}" pp_body body +let pp_unused = fmt "(void) %s; // suppress unused var warning" + +(** Print the body of exception handling for functions *) +let pp_located ppf _ = + pf ppf + {|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|} let pp_profile ppf (pp_body, name, body) = let profile ppf name = @@ -15,31 +21,10 @@ let pp_profile ppf (pp_body, name, body) = name in pf ppf "{@;<1 2>@[%a@;@;%a@]@,}" profile name pp_body body -let rec contains_eigen (ut : UnsizedType.t) : bool = - match ut with - | UnsizedType.UArray t -> contains_eigen t - | UMatrix | URowVector | UVector -> true - | UInt | UReal | UComplex | UMathLibraryFunction | UFun _ -> false - -(*Fill only needs to happen for containers - * Note: This should probably be moved into its own function as data - * does not need to be filled as we are promised user input data has the correct - * dimensions. Transformed data must be filled as incorrect slices could lead - * to elements of objects in transform data not being set by the user. -*) -let pp_filler ppf (decl_id, st, nan_type, needs_filled) = - match (needs_filled, contains_eigen (SizedType.to_unsized st)) with - | true, true -> - pf ppf "@[stan::math::initialize_fill(%s, %s);@]@," decl_id - nan_type - | _ -> () - (*Pretty print a sized type*) let pp_st ppf (st, adtype) = pf ppf "%a" pp_unsizedtype_local (adtype, SizedType.to_unsized st) -let pp_ut ppf (ut, adtype) = pf ppf "%a" pp_unsizedtype_local (adtype, ut) - (*Get a string representing for the NaN type of the given type *) let nan_type (st, adtype) = match (adtype, st) with @@ -201,11 +186,6 @@ let pp_for_loop ppf (loopvar, lower, upper, pp_body, body) = loopvar pp_expr upper loopvar ; pf ppf " %a@]" pp_body body -let rec integer_el_type = function - | SizedType.SInt -> true - | SArray (st, _) -> integer_el_type st - | _ -> false - (** Print the private members of the model class Accounting for types that can be moved to OpenCL. @@ -228,14 +208,6 @@ let pp_data_decl ppf (vident, ut) = | _ -> pf ppf "%a %s;" pp_type (DataOnly, ut) vident ) | (true, _), _ -> pf ppf "%a %s;" pp_type (DataOnly, ut) vident -(** Create string representations for [vars__.emplace_back] *) -let pp_emplace_var ppf var = - match Expr.Typed.type_of var with - | UnsizedType.UComplex -> - pf ppf "@[vars__.emplace_back(%a.real());@]@," pp_expr var ; - pf ppf "@[vars__.emplace_back(%a.imag());@]" pp_expr var - | _ -> pf ppf "@[vars__.emplace_back(@,%a);@]" pp_expr var - (** Create strings representing maps of Eigen types*) let pp_map_decl ppf (vident, ut) = let scalar = local_scalar ut DataOnly in @@ -433,3 +405,21 @@ and pp_block_s ppf body = match body.pattern with | Block ls -> pp_block ppf (list ~sep:cut pp_statement, ls) | _ -> pp_block ppf (pp_statement, body) + +(** [pp_located_error ppf (pp_body_block, body_block, err_msg)] surrounds [body_block] + with a C++ try-catch that will rethrow the error with the proper source location + from the [body_block] (required to be a [stmt_loc Block] variant). + @param ppf A pretty printer. + @param pp_body_block A pretty printer for the body block + @param body A C++ scoped body block surrounded by squiggly braces. + *) +let pp_located_error ppf (pp_body_block, body) = + pf ppf "@ try %a" pp_body_block body ; + string ppf " catch (const std::exception& e) " ; + pp_block ppf (pp_located, ()) + +(** [pp_located_error_b] automatically adds a Block wrapper *) +let pp_located_error_b ppf body_stmts = + pp_located_error ppf + ( pp_statement + , Stmt.Fixed.{pattern= Block body_stmts; meta= Locations.no_span_num} ) diff --git a/test/unit/Stan_math_code_gen_tests.ml b/test/unit/Stan_math_code_gen_tests.ml index 4227b3fc52..e89208e4cf 100644 --- a/test/unit/Stan_math_code_gen_tests.ml +++ b/test/unit/Stan_math_code_gen_tests.ml @@ -1,113 +1 @@ -open Middle -open Stan_math_backend -open Core_kernel -open Fmt -open Stan_math_code_gen - -let pp_fun_def_w_rs a b = - pp_fun_def a - ( b - , String.Table.create () - , Hash_set.Poly.create () - , String.Set.empty - , String.Set.empty - , String.Set.empty ) - -let%expect_test "udf" = - let with_no_loc stmt = - Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in - let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - { fdrt= None - ; fdname= "sars" - ; fdsuffix= FnPlain - ; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)] - ; fdbody= - Stmt.Fixed.Pattern.Return - (Some - ( w - @@ FunApp - ( StanLib ("add", FnPlain, AoS) - , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) - |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc - |> Some - ; fdloc= Location_span.empty } - |> str "@[%a" pp_fun_def_w_rs - |> print_endline ; - [%expect - {| - template * = nullptr, - stan::require_row_vector_t* = nullptr> - void - sars(const T0__& x_arg__, const T1__& y_arg__, std::ostream* pstream__) { - using local_scalar_t__ = - stan::promote_args_t, - stan::value_type_t>; - int current_statement__ = 0; - const auto& x = stan::math::to_ref(x_arg__); - const auto& y = stan::math::to_ref(y_arg__); - static constexpr bool propto__ = true; - (void) propto__; - local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); - (void) DUMMY_VAR__; // suppress unused var warning - try { - return stan::math::add(x, 1); - } catch (const std::exception& e) { - stan::lang::rethrow_located(e, locations_array__[current_statement__]); - } - - } |}] - -let%expect_test "udf-expressions" = - let with_no_loc stmt = - Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in - let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - { fdrt= Some UMatrix - ; fdname= "sars" - ; fdsuffix= FnPlain - ; fdargs= - [ (DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector) - ; (AutoDiffable, "z", URowVector); (AutoDiffable, "w", UArray UMatrix) ] - ; fdbody= - Stmt.Fixed.Pattern.Return - (Some - ( w - @@ FunApp - ( StanLib ("add", FnPlain, AoS) - , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) - |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc - |> Some - ; fdloc= Location_span.empty } - |> str "@[%a" pp_fun_def_w_rs - |> print_endline ; - [%expect - {| - template * = nullptr, - stan::require_row_vector_t* = nullptr, - stan::require_row_vector_t* = nullptr, - stan::require_stan_scalar_t* = nullptr> - Eigen::Matrix, stan::value_type_t, - stan::value_type_t, T3__>, -1, -1> - sars(const T0__& x_arg__, const T1__& y_arg__, const T2__& z_arg__, - const std::vector>& w, - std::ostream* pstream__) { - using local_scalar_t__ = - stan::promote_args_t, - stan::value_type_t, - stan::value_type_t, T3__>; - int current_statement__ = 0; - const auto& x = stan::math::to_ref(x_arg__); - const auto& y = stan::math::to_ref(y_arg__); - const auto& z = stan::math::to_ref(z_arg__); - static constexpr bool propto__ = true; - (void) propto__; - local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); - (void) DUMMY_VAR__; // suppress unused var warning - try { - return stan::math::add(x, 1); - } catch (const std::exception& e) { - stan::lang::rethrow_located(e, locations_array__[current_statement__]); - } - - } |}] +(* TODO proper unit tests or delete file *)