From a37df2731291dc2077fee1805f36c3f5a772c5d3 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 17 Mar 2022 10:53:11 -0400 Subject: [PATCH 1/6] Start reorganization of function code gen --- src/stan_math_backend/Cpp_util.ml | 28 + src/stan_math_backend/Function_gen.ml | 523 +++++++++++++++++++ src/stan_math_backend/Function_gen.mli | 10 + src/stan_math_backend/Stan_math_code_gen.ml | 432 +-------------- src/stan_math_backend/Stan_math_code_gen.mli | 10 + test/unit/Stan_math_code_gen_tests.ml | 114 +--- 6 files changed, 576 insertions(+), 541 deletions(-) create mode 100644 src/stan_math_backend/Cpp_util.ml create mode 100644 src/stan_math_backend/Function_gen.ml create mode 100644 src/stan_math_backend/Function_gen.mli create mode 100644 src/stan_math_backend/Stan_math_code_gen.mli diff --git a/src/stan_math_backend/Cpp_util.ml b/src/stan_math_backend/Cpp_util.ml new file mode 100644 index 0000000000..f45d1d03d3 --- /dev/null +++ b/src/stan_math_backend/Cpp_util.ml @@ -0,0 +1,28 @@ +open Fmt +open Statement_gen +open Middle + +let pp_unused = fmt "(void) %s; // suppress unused var warning" + +(** Print the body of exception handling for functions *) +let pp_located ppf _ = + pf ppf + {|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|} + +(** [pp_located_error ppf (pp_body_block, body_block, err_msg)] surrounds [body_block] + with a C++ try-catch that will rethrow the error with the proper source location + from the [body_block] (required to be a [stmt_loc Block] variant). + @param ppf A pretty printer. + @param pp_body_block A pretty printer for the body block + @param body A C++ scoped body block surrounded by squiggly braces. + *) +let pp_located_error ppf (pp_body_block, body) = + pf ppf "@ try %a" pp_body_block body ; + string ppf " catch (const std::exception& e) " ; + pp_block ppf (pp_located, ()) + +(** [pp_located_error_b] automatically adds a Block wrapper *) +let pp_located_error_b ppf body_stmts = + pp_located_error ppf + ( pp_statement + , Stmt.Fixed.{pattern= Block body_stmts; meta= Locations.no_span_num} ) diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml new file mode 100644 index 0000000000..451ae5652f --- /dev/null +++ b/src/stan_math_backend/Function_gen.ml @@ -0,0 +1,523 @@ +(** Code generation for user defined functions and the relevant functors *) + +open Core_kernel +open Core_kernel.Poly +open Middle +open Fmt +open Expression_gen +open Statement_gen +open Cpp_util + +type template = + | Typename of string + | Require of string * string + | Bool of string + +(* and stmt = + | Struct of {templates: template list; name: string; body: stmt list} + | FunDef of + { name: string + ; templates: template list + ; args: string list + ; return: string + ; body: stmt list option } + | Generated of string + (** Placeholder for all other C++. + We should slowly move more and more statements + away from strings *) *) + +let pp_template ppf template = + match template with + | Typename t -> pf ppf "typename %s" t + | Require (r, t) -> pf ppf "%s<%s>*" r t + | Bool s -> pf ppf "bool %s" s + +let pp_template_defaults ppf template = + match template with + | Require _ -> pf ppf "%a = nullptr" pp_template template + | _ -> pp_template ppf template + +let pp_templates ~defaults ppf templates = + match templates with + | [] -> () + | _ -> + pf ppf "template <@[%a@]>@ " + (list ~sep:comma + (if defaults then pp_template_defaults else pp_template) ) + templates + +type found_functor = + { struct_template: template option + ; arg_templates: template list + ; signature: string + ; defn: string } + +(** Detect if argument requires C++ template *) +let arg_needs_template = function + | UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t + | _, _, t when UnsizedType.is_int_type t -> false + | _ -> true + +(** Print template arguments for C++ functions that need templates +@param args A pack of [Program.fun_arg_decl] containing functions to detect templates. +@return A list of arguments with template parameter names added. +*) +let maybe_templated_arg_types (args : Program.fun_arg_decl) = + List.mapi args ~f:(fun i a -> + match arg_needs_template a with + | true -> Some (sprintf "T%d__" i) + | false -> None ) + +let maybe_require_templates (names : string option list) + (args : Program.fun_arg_decl) = + let require_for_arg arg = + match trd3 arg with + | UnsizedType.URowVector -> "stan::require_row_vector_t" + | UVector -> "stan::require_col_vector_t" + | UMatrix -> "stan::require_eigen_matrix_dynamic_t" + (* NB: Not unwinding array types due to the way arrays of eigens are printed *) + | _ -> "stan::require_stan_scalar_t" in + List.map2_exn names args ~f:(fun name a -> + match name with + | Some t -> Some (Require (require_for_arg a, t)) + | None -> None ) + +let return_arg_types (args : Program.fun_arg_decl) = + List.mapi args ~f:(fun i ((_, _, ut) as a) -> + if UnsizedType.is_eigen_type ut && arg_needs_template a then + Some (sprintf "stan::value_type_t" i) + else if arg_needs_template a then Some (sprintf "T%d__" i) + else None ) + +let%expect_test "arg types templated correctly" = + [(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)] + |> maybe_templated_arg_types |> List.filter_opt |> String.concat ~sep:"," + |> print_endline ; + [%expect {| T0__ |}] + +(** Print the code for promoting stan real types +@param ppf A pretty printer +@param args A pack of arguments to detect whether they need to use the promotion rules. +*) +let pp_promoted_scalar ppf args = + match args with + | [] -> pf ppf "double" + | _ -> + let rec promote_args_chunked ppf args = + let go ppf tl = + match tl with [] -> () | _ -> pf ppf ",@ %a" promote_args_chunked tl + in + match args with + | [] -> pf ppf "double" + | hd :: tl -> + pf ppf "@[stan::promote_args_t<@[%a%a@]>@]" (list ~sep:comma string) + hd go tl in + promote_args_chunked ppf + List.(chunks_of ~length:5 (filter_opt (return_arg_types args))) + +(** Pretty-prints a function's return-type, taking into account templated argument + promotion.*) +let pp_returntype ppf arg_types rt = + let scalar = str "@[%a@]" pp_promoted_scalar arg_types in + match rt with + | Some ut when UnsizedType.is_int_type ut -> + pf ppf "%a@ " pp_unsizedtype_custom_scalar ("int", ut) + | Some ut -> pf ppf "%a@ " pp_unsizedtype_custom_scalar (scalar, ut) + | None -> pf ppf "void@ " + +let pp_eigen_arg_to_ref ppf arg_types = + let pp_ref ppf name = + pf ppf "@[const auto& %s = stan::math::to_ref(%s);@]" name + (name ^ "_arg__") in + pf ppf "@[%a@]@ " (list ~sep:cut pp_ref) + (List.filter_map + ~f:(fun (_, name, ut) -> + if UnsizedType.is_eigen_type ut then Some name else None ) + arg_types ) + +(** Print the type of an object. + @param ppf A pretty printer + @param custom_scalar_opt A string representing a types inner scalar value. + @param name The name of the object + @param ut The unsized type of the object + *) +let pp_arg ppf (custom_scalar_opt, (_, name, ut)) = + let scalar = + match custom_scalar_opt with + | Some scalar -> scalar + | None -> stantype_prim_str ut in + (* we add the _arg suffix for any Eigen types *) + pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) + name + +let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) = + let scalar = + match custom_scalar_opt with + | Some scalar -> scalar + | None -> stantype_prim_str ut in + (* we add the _arg suffix for any Eigen types *) + let opt_arg_suffix = + if UnsizedType.is_eigen_type ut then name ^ "_arg__" else name in + pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) + opt_arg_suffix + +let typename t = Typename t + +(** Construct an object with it's needed templates for function signatures. + @param fdargs A sexp list of strings representing C++ types. + *) +let get_templates_and_args exprs fdargs = + let argtypetemplates = maybe_templated_arg_types fdargs in + let requireargtemplates = maybe_require_templates argtypetemplates fdargs in + ( List.filter_opt argtypetemplates + , List.filter_opt requireargtemplates + , if not exprs then + List.map + ~f:(fun a -> str "%a" pp_arg a) + (List.zip_exn argtypetemplates fdargs) + else + List.map + ~f:(fun a -> str "%a" pp_arg_eigen_suffix a) + (List.zip_exn argtypetemplates fdargs) ) + +let mk_extra_args templates args = + List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args) + +(** Print the C++ function definition. + @param ppf A pretty printer + Refactor this please - one idea might be to have different functions for + printing user defined distributions vs rngs vs regular functions. +*) +let pp_fun_def ppf + ( Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} + , (functors : (string, found_functor list) Hashtbl.t) + , (forward_decls : (string * template list) Hash_set.t) + , (funs_used_in_reduce_sum : String.Set.t) + , (funs_used_in_variadic_ode : String.Set.t) + , (funs_used_in_variadic_dae : String.Set.t) ) = + let extra, extra_templates = + match fdsuffix with + | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) + | FnRng -> (["base_rng__"], ["RNG"]) + | FnLpdf _ | FnPlain -> ([], []) in + let pp_body ppf (Stmt.Fixed.{pattern; _} as fdbody) = + pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; + pf ppf "int current_statement__ = 0; @ " ; + if List.exists ~f:(fun (_, _, t) -> UnsizedType.is_eigen_type t) fdargs then + pp_eigen_arg_to_ref ppf fdargs ; + ( match fdsuffix with + | FnLpdf _ | FnTarget -> () + | FnPlain | FnRng -> + pf ppf "%s@ " "static constexpr bool propto__ = true;" ; + pf ppf "%s@ " "(void) propto__;" ) ; + pf ppf "%s@ " + "local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN());" ; + pf ppf "%a" pp_unused "DUMMY_VAR__" ; + let blocked_fdbody = + match pattern with + | SList stmts -> {fdbody with pattern= Block stmts} + | Block _ -> fdbody + | _ -> {fdbody with pattern= Block [fdbody]} in + pp_located_error ppf (pp_statement, blocked_fdbody) ; + pf ppf "@ " in + let get_templates exprs variadic = + let argtypetemplates, require_templates, args = + get_templates_and_args exprs fdargs in + let templates = + List.(map ~f:typename (argtypetemplates @ extra_templates)) + @ require_templates in + match (fdsuffix, variadic) with + | (FnLpdf _ | FnTarget), `None -> (Bool "propto__" :: templates, args) + | _ -> (templates, args) in + let pp_sig ppf (name, args, variadic) = + Format.open_vbox 2 ; + pp_returntype ppf fdargs fdrt ; + let args, variadic_args = + match variadic with + | `ReduceSum -> List.split_n args 3 + | `VariadicODE -> List.split_n args 2 + | `VariadicDAE -> List.split_n args 3 + | `None -> (args, []) in + let arg_strs = + args + @ mk_extra_args extra_templates extra + @ ["std::ostream* pstream__"] + @ variadic_args in + pf ppf "%s(@[%a@]) " name (list ~sep:comma string) arg_strs ; + Format.close_box () in + let templates, templated_args = get_templates true `None in + let signature = str "%a" pp_sig (fdname, templated_args, `None) in + (* We want to print the [* = nullptr] at most once, and preferrably on a forward decl *) + let defaults = + Option.is_none fdbody + || not (Hash_set.mem forward_decls (signature, templates)) in + pf ppf "%a%a" (pp_templates ~defaults) templates pp_sig + (fdname, templated_args, `None) ; + match fdbody with + | None -> + pf ppf ";@ " ; + Hash_set.add forward_decls (signature, templates) + | Some fdbody -> + pp_block ppf (pp_body, fdbody) ; + let register_functor (str_args, args, variadic) = + let suffix = + match variadic with + | `None -> functor_suffix + | `ReduceSum -> reduce_sum_functor_suffix + | `VariadicODE -> variadic_ode_functor_suffix + | `VariadicDAE -> variadic_dae_functor_suffix in + let functor_name = fdname ^ suffix in + let struct_template = + match (fdsuffix, variadic) with + | FnLpdf _, `ReduceSum -> Some (Bool "propto__") + | _ -> None in + let arg_templates, templated_args = get_templates false variadic in + let op_signature = + str "%a" pp_sig ("operator()", templated_args, variadic) in + let defn = + str "%a@ const@,{@. return %a;@.}@." pp_sig + ( functor_name + ^ (if struct_template <> None then "" else "") + ^ "::operator()" + , templated_args + , variadic ) + pp_call_str + ( ( match fdsuffix with + | FnLpdf _ | FnTarget -> fdname ^ "" + | _ -> fdname ) + , str_args + @ List.map ~f:(fun (_, name, _) -> name) args + @ extra @ ["pstream__"] ) in + Hashtbl.add_multi functors ~key:functor_name + ~data:{struct_template; arg_templates; signature= op_signature; defn} + in + register_functor ([], fdargs, `None) ; + if String.Set.mem funs_used_in_reduce_sum fdname then + (* Produces the reduce_sum functors that has the pstream argument + as the third and not last argument *) + match fdargs with + | (_, slice, _) :: (_, start, _) :: (_, end_, _) :: rest -> + register_functor + ([slice; start ^ " + 1"; end_ ^ " + 1"], rest, `ReduceSum) + | _ -> + Common.FatalError.fatal_error_msg + [%message + "Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl)] + else if String.Set.mem funs_used_in_variadic_ode fdname then + (* Produces the variadic ode functors that has the pstream argument + as the third and not last argument *) + register_functor ([], fdargs, `VariadicODE) + else if String.Set.mem funs_used_in_variadic_dae fdname then + (* Produces the variadic DAE functors that has the pstream argument + as the fourth and not last argument *) + register_functor ([], fdargs, `VariadicDAE) + +let pp_standalone_fun_def namespace_fun ppf + Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} = + let extra, extra_templates = + match fdsuffix with + | Fun_kind.FnTarget -> + (["lp__"; "lp_accum__"], ["double"; "stan::math::accumulator"]) + | FnRng -> (["base_rng__"], ["boost::ecuyer1988"]) + | FnLpdf _ | FnPlain -> ([], []) in + let args = + List.map + ~f:(fun (_, name, ut) -> + str "const %a& %s" pp_unsizedtype_custom_scalar + (stantype_prim_str ut, ut) + name ) + fdargs in + let pp_sig_standalone ppf _ = + let arg_strs = + args + @ mk_extra_args extra_templates extra + @ ["std::ostream* pstream__ = nullptr"] in + pf ppf "(@[%a@]) " (list ~sep:comma string) arg_strs in + let mark_function_comment = "// [[stan::function]]" in + let return_type = match fdrt with None -> "void" | _ -> "auto" in + let return_stmt = match fdrt with None -> "" | _ -> "return " in + match fdbody with + | None -> pf ppf ";@ " + | Some _ -> + pf ppf "@,%s@,%s %s%a @,{@, %s%s::%a;@,}@," mark_function_comment + return_type fdname pp_sig_standalone "" return_stmt namespace_fun + pp_call_str + ( ( match fdsuffix with + | FnLpdf _ | FnTarget -> fdname ^ "" + | FnRng | FnPlain -> fdname ) + , List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] + ) + +let is_fun_used_with_variadic_fn variadic_fn_test p = + let rec find_functors_expr accum Expr.Fixed.{pattern; _} = + String.Set.union accum + ( match pattern with + | FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _) + when variadic_fn_test x -> + String.Set.of_list [Utils.stdlib_distribution_name f] + | x -> Expr.Fixed.Pattern.fold find_functors_expr accum x ) in + let rec find_functors_stmt accum stmt = + Stmt.Fixed.( + Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern) + in + Program.fold find_functors_expr find_functors_stmt String.Set.empty p + +let collect_functors_functions p = + let (functors : (string, found_functor list) Hashtbl.t) = + String.Table.create () in + let forward_decls = Hash_set.Poly.create () in + let reduce_sum_fns = + is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in + let variadic_ode_fns = + is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in + let variadic_dae_fns = + is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in + let pp_fun_def_with_variadic_fn_list ppf fblock = + (hovbox ~indent:2 pp_fun_def) + ppf + ( fblock + , functors + , forward_decls + , reduce_sum_fns + , variadic_ode_fns + , variadic_dae_fns ) in + ( str "@[%a@]" + (list ~sep:cut pp_fun_def_with_variadic_fn_list) + p.functions_block + , functors ) + +let pp_functions_functors ppf p = + let fns_str, functors = collect_functors_functions p in + let pp_functor_decls ppf tbl = + Hashtbl.iteri tbl ~f:(fun ~key ~data -> + pf ppf "@[%astruct %s {@,%aconst;@]@,};@." + (option + (option (fun ppf t -> + pf ppf "template <%a>@ " pp_template_defaults t ) ) ) + (Option.map ~f:(fun x -> x.struct_template) (List.hd data)) + key + (list ~sep:(any "const;@,") (fun ppf (ts, sign) -> + pf ppf "%a@[%a@]" (pp_templates ~defaults:true) ts text sign ) + ) + (List.map + ~f:(fun {arg_templates; signature; _} -> (arg_templates, signature)) + data ) ) in + let pp_functors ppf tbl = + Hashtbl.iter tbl ~f:(fun data -> + List.iter data ~f:(fun {struct_template; defn; arg_templates; _} -> + pf ppf "%a%a%s@." + (option (fun ppf t -> pf ppf "template <%a>@ " pp_template t)) + struct_template + (pp_templates ~defaults:false) + arg_templates defn ) ) in + pf ppf "%a@ %s@ %a" pp_functor_decls functors fns_str pp_functors functors + +(* Testing code *) + +let pp_fun_def_w_rs a b = + pp_fun_def a + ( b + , String.Table.create () + , Hash_set.Poly.create () + , String.Set.empty + , String.Set.empty + , String.Set.empty ) + +let%expect_test "udf" = + let with_no_loc stmt = + Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in + let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in + { fdrt= None + ; fdname= "sars" + ; fdsuffix= FnPlain + ; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)] + ; fdbody= + Stmt.Fixed.Pattern.Return + (Some + ( w + @@ FunApp + ( StanLib ("add", FnPlain, AoS) + , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) + |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc + |> Some + ; fdloc= Location_span.empty } + |> str "@[%a" pp_fun_def_w_rs + |> print_endline ; + [%expect + {| + template * = nullptr, + stan::require_row_vector_t* = nullptr> + void + sars(const T0__& x_arg__, const T1__& y_arg__, std::ostream* pstream__) { + using local_scalar_t__ = + stan::promote_args_t, + stan::value_type_t>; + int current_statement__ = 0; + const auto& x = stan::math::to_ref(x_arg__); + const auto& y = stan::math::to_ref(y_arg__); + static constexpr bool propto__ = true; + (void) propto__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + (void) DUMMY_VAR__; // suppress unused var warning + try { + return stan::math::add(x, 1); + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + + } |}] + +let%expect_test "udf-expressions" = + let with_no_loc stmt = + Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in + let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in + { fdrt= Some UMatrix + ; fdname= "sars" + ; fdsuffix= FnPlain + ; fdargs= + [ (DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector) + ; (AutoDiffable, "z", URowVector); (AutoDiffable, "w", UArray UMatrix) ] + ; fdbody= + Stmt.Fixed.Pattern.Return + (Some + ( w + @@ FunApp + ( StanLib ("add", FnPlain, AoS) + , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) + |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc + |> Some + ; fdloc= Location_span.empty } + |> str "@[%a" pp_fun_def_w_rs + |> print_endline ; + [%expect + {| + template * = nullptr, + stan::require_row_vector_t* = nullptr, + stan::require_row_vector_t* = nullptr, + stan::require_stan_scalar_t* = nullptr> + Eigen::Matrix, stan::value_type_t, + stan::value_type_t, T3__>, -1, -1> + sars(const T0__& x_arg__, const T1__& y_arg__, const T2__& z_arg__, + const std::vector>& w, + std::ostream* pstream__) { + using local_scalar_t__ = + stan::promote_args_t, + stan::value_type_t, + stan::value_type_t, T3__>; + int current_statement__ = 0; + const auto& x = stan::math::to_ref(x_arg__); + const auto& y = stan::math::to_ref(y_arg__); + const auto& z = stan::math::to_ref(z_arg__); + static constexpr bool propto__ = true; + (void) propto__; + local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); + (void) DUMMY_VAR__; // suppress unused var warning + try { + return stan::math::add(x, 1); + } catch (const std::exception& e) { + stan::lang::rethrow_located(e, locations_array__[current_statement__]); + } + + } |}] diff --git a/src/stan_math_backend/Function_gen.mli b/src/stan_math_backend/Function_gen.mli new file mode 100644 index 0000000000..41872842e6 --- /dev/null +++ b/src/stan_math_backend/Function_gen.mli @@ -0,0 +1,10 @@ +open Middle + +val pp_functions_functors : Format.formatter -> Program.Numbered.t -> unit +(** Pretty-print all user defined functions. Creates functor structs as needed + *) + +val pp_standalone_fun_def : + string -> Format.formatter -> 'a Program.fun_def -> unit +(** Creates functions outside the model namespaces which only call the ones + inside the namespaces *) diff --git a/src/stan_math_backend/Stan_math_code_gen.ml b/src/stan_math_backend/Stan_math_code_gen.ml index ede0fe0690..5cd5e16827 100644 --- a/src/stan_math_backend/Stan_math_code_gen.ml +++ b/src/stan_math_backend/Stan_math_code_gen.ml @@ -22,38 +22,8 @@ open Middle open Fmt open Expression_gen open Statement_gen - -(* TODO: move to seperate file and code gen more like this with a structured type *) -type template = - | Typename of string - | Require of string * string - | Bool of string - -let pp_template ppf template = - match template with - | Typename t -> pf ppf "typename %s" t - | Require (r, t) -> pf ppf "%s<%s>*" r t - | Bool s -> pf ppf "bool %s" s - -let pp_template_defaults ppf template = - match template with - | Require _ -> pf ppf "%a = nullptr" pp_template template - | _ -> pp_template ppf template - -let pp_templates ~defaults ppf templates = - match templates with - | [] -> () - | _ -> - pf ppf "template <@[%a@]>@ " - (list ~sep:comma - (if defaults then pp_template_defaults else pp_template) ) - templates - -type found_functor = - { struct_template: template option - ; arg_templates: template list - ; signature: string - ; defn: string } +open Function_gen +open Cpp_util let standalone_functions = ref false @@ -69,8 +39,6 @@ let stanc_args_to_print = |> List.filter ~f:sans_model_and_hpp_paths |> String.concat ~sep:" " -let pp_unused = fmt "(void) %s; // suppress unused var warning" - (** Print name of model function. @param prog_name Name of the Stan program. @param fname Name of the function. @@ -80,335 +48,6 @@ let pp_function__ ppf (prog_name, fname) = {|@[static constexpr const char* function__ = "%s_namespace::%s";@,%a@]|} prog_name fname pp_unused "function__" -(** Print the body of exception handling for functions *) -let pp_located ppf _ = - pf ppf - {|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|} - -(** Detect if argument requires C++ template *) -let arg_needs_template = function - | UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t - | _, _, t when UnsizedType.is_int_type t -> false - | _ -> true - -(** Print template arguments for C++ functions that need templates - @param args A pack of [Program.fun_arg_decl] containing functions to detect templates. - @return A list of arguments with template parameter names added. - *) -let maybe_templated_arg_types (args : Program.fun_arg_decl) = - List.mapi args ~f:(fun i a -> - match arg_needs_template a with - | true -> Some (sprintf "T%d__" i) - | false -> None ) - -let maybe_require_templates (names : string option list) - (args : Program.fun_arg_decl) = - let require_for_arg arg = - match trd3 arg with - | UnsizedType.URowVector -> "stan::require_row_vector_t" - | UVector -> "stan::require_col_vector_t" - | UMatrix -> "stan::require_eigen_matrix_dynamic_t" - (* NB: Not unwinding array types due to the way arrays of eigens are printed *) - | _ -> "stan::require_stan_scalar_t" in - List.map2_exn names args ~f:(fun name a -> - match name with - | Some t -> Some (Require (require_for_arg a, t)) - | None -> None ) - -let return_arg_types (args : Program.fun_arg_decl) = - List.mapi args ~f:(fun i ((_, _, ut) as a) -> - if UnsizedType.is_eigen_type ut && arg_needs_template a then - Some (sprintf "stan::value_type_t" i) - else if arg_needs_template a then Some (sprintf "T%d__" i) - else None ) - -let%expect_test "arg types templated correctly" = - [(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)] - |> maybe_templated_arg_types |> List.filter_opt |> String.concat ~sep:"," - |> print_endline ; - [%expect {| T0__ |}] - -(** Print the code for promoting stan real types - @param ppf A pretty printer - @param args A pack of arguments to detect whether they need to use the promotion rules. - *) -let pp_promoted_scalar ppf args = - match args with - | [] -> pf ppf "double" - | _ -> - let rec promote_args_chunked ppf args = - let go ppf tl = - match tl with [] -> () | _ -> pf ppf ",@ %a" promote_args_chunked tl - in - match args with - | [] -> pf ppf "double" - | hd :: tl -> - pf ppf "@[stan::promote_args_t<@[%a%a@]>@]" (list ~sep:comma string) - hd go tl in - promote_args_chunked ppf - List.(chunks_of ~length:5 (filter_opt (return_arg_types args))) - -(** Pretty-prints a function's return-type, taking into account templated argument - promotion.*) -let pp_returntype ppf arg_types rt = - let scalar = str "@[%a@]" pp_promoted_scalar arg_types in - match rt with - | Some ut when UnsizedType.is_int_type ut -> - pf ppf "%a@ " pp_unsizedtype_custom_scalar ("int", ut) - | Some ut -> pf ppf "%a@ " pp_unsizedtype_custom_scalar (scalar, ut) - | None -> pf ppf "void@ " - -let pp_eigen_arg_to_ref ppf arg_types = - let pp_ref ppf name = - pf ppf "@[const auto& %s = stan::math::to_ref(%s);@]" name - (name ^ "_arg__") in - pf ppf "@[%a@]@ " (list ~sep:cut pp_ref) - (List.filter_map - ~f:(fun (_, name, ut) -> - if UnsizedType.is_eigen_type ut then Some name else None ) - arg_types ) - -(** [pp_located_error ppf (pp_body_block, body_block, err_msg)] surrounds [body_block] - with a C++ try-catch that will rethrow the error with the proper source location - from the [body_block] (required to be a [stmt_loc Block] variant). - @param ppf A pretty printer. - @param pp_body_block A pretty printer for the body block - @param body A C++ scoped body block surrounded by squiggly braces. - *) -let pp_located_error ppf (pp_body_block, body) = - pf ppf "@ try %a" pp_body_block body ; - string ppf " catch (const std::exception& e) " ; - pp_block ppf (pp_located, ()) - -(** Print the type of an object. - @param ppf A pretty printer - @param custom_scalar_opt A string representing a types inner scalar value. - @param name The name of the object - @param ut The unsized type of the object - *) -let pp_arg ppf (custom_scalar_opt, (_, name, ut)) = - let scalar = - match custom_scalar_opt with - | Some scalar -> scalar - | None -> stantype_prim_str ut in - (* we add the _arg suffix for any Eigen types *) - pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) - name - -let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) = - let scalar = - match custom_scalar_opt with - | Some scalar -> scalar - | None -> stantype_prim_str ut in - (* we add the _arg suffix for any Eigen types *) - let opt_arg_suffix = - if UnsizedType.is_eigen_type ut then name ^ "_arg__" else name in - pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) - opt_arg_suffix - -(** [pp_located_error_b] automatically adds a Block wrapper *) -let pp_located_error_b ppf body_stmts = - pp_located_error ppf - ( pp_statement - , Stmt.Fixed.{pattern= Block body_stmts; meta= Locations.no_span_num} ) - -let typename t = Typename t - -(** Construct an object with it's needed templates for function signatures. - @param fdargs A sexp list of strings representing C++ types. - *) -let get_templates_and_args exprs fdargs = - let argtypetemplates = maybe_templated_arg_types fdargs in - let requireargtemplates = maybe_require_templates argtypetemplates fdargs in - ( List.filter_opt argtypetemplates - , List.filter_opt requireargtemplates - , if not exprs then - List.map - ~f:(fun a -> str "%a" pp_arg a) - (List.zip_exn argtypetemplates fdargs) - else - List.map - ~f:(fun a -> str "%a" pp_arg_eigen_suffix a) - (List.zip_exn argtypetemplates fdargs) ) - -(** Print the C++ template parameter decleration before a function. - @param ppf A pretty printer. - *) -let pp_template_decorator ppf = function - | [] -> () - | templates -> - pf ppf "@[template <@[%a>@]@]@ " (list ~sep:comma string) templates - -let mk_extra_args templates args = - List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args) - -(** Print the C++ function definition. - @param ppf A pretty printer - Refactor this please - one idea might be to have different functions for - printing user defined distributions vs rngs vs regular functions. -*) -let pp_fun_def ppf - ( Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} - , (functors : (string, found_functor list) Hashtbl.t) - , (forward_decls : (string * template list) Hash_set.t) - , (funs_used_in_reduce_sum : String.Set.t) - , (funs_used_in_variadic_ode : String.Set.t) - , (funs_used_in_variadic_dae : String.Set.t) ) = - let extra, extra_templates = - match fdsuffix with - | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) - | FnRng -> (["base_rng__"], ["RNG"]) - | FnLpdf _ | FnPlain -> ([], []) in - let pp_body ppf (Stmt.Fixed.{pattern; _} as fdbody) = - pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; - pf ppf "int current_statement__ = 0; @ " ; - if List.exists ~f:(fun (_, _, t) -> UnsizedType.is_eigen_type t) fdargs then - pp_eigen_arg_to_ref ppf fdargs ; - ( match fdsuffix with - | FnLpdf _ | FnTarget -> () - | FnPlain | FnRng -> - pf ppf "%s@ " "static constexpr bool propto__ = true;" ; - pf ppf "%s@ " "(void) propto__;" ) ; - pf ppf "%s@ " - "local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN());" ; - pf ppf "%a" pp_unused "DUMMY_VAR__" ; - let blocked_fdbody = - match pattern with - | SList stmts -> {fdbody with pattern= Block stmts} - | Block _ -> fdbody - | _ -> {fdbody with pattern= Block [fdbody]} in - pp_located_error ppf (pp_statement, blocked_fdbody) ; - pf ppf "@ " in - let get_templates exprs variadic = - let argtypetemplates, require_templates, args = - get_templates_and_args exprs fdargs in - let templates = - List.(map ~f:typename (argtypetemplates @ extra_templates)) - @ require_templates in - match (fdsuffix, variadic) with - | (FnLpdf _ | FnTarget), `None -> (Bool "propto__" :: templates, args) - | _ -> (templates, args) in - let pp_sig ppf (name, args, variadic) = - Format.open_vbox 2 ; - pp_returntype ppf fdargs fdrt ; - let args, variadic_args = - match variadic with - | `ReduceSum -> List.split_n args 3 - | `VariadicODE -> List.split_n args 2 - | `VariadicDAE -> List.split_n args 3 - | `None -> (args, []) in - let arg_strs = - args - @ mk_extra_args extra_templates extra - @ ["std::ostream* pstream__"] - @ variadic_args in - pf ppf "%s(@[%a@]) " name (list ~sep:comma string) arg_strs ; - Format.close_box () in - let templates, templated_args = get_templates true `None in - let signature = str "%a" pp_sig (fdname, templated_args, `None) in - (* We want to print the [* = nullptr] at most once, and preferrably on a forward decl *) - let defaults = - Option.is_none fdbody - || not (Hash_set.mem forward_decls (signature, templates)) in - pf ppf "%a%a" (pp_templates ~defaults) templates pp_sig - (fdname, templated_args, `None) ; - match fdbody with - | None -> - pf ppf ";@ " ; - Hash_set.add forward_decls (signature, templates) - | Some fdbody -> - pp_block ppf (pp_body, fdbody) ; - let register_functor (str_args, args, variadic) = - let suffix = - match variadic with - | `None -> functor_suffix - | `ReduceSum -> reduce_sum_functor_suffix - | `VariadicODE -> variadic_ode_functor_suffix - | `VariadicDAE -> variadic_dae_functor_suffix in - let functor_name = fdname ^ suffix in - let struct_template = - match (fdsuffix, variadic) with - | FnLpdf _, `ReduceSum -> Some (Bool "propto__") - | _ -> None in - let arg_templates, templated_args = get_templates false variadic in - let op_signature = - str "%a" pp_sig ("operator()", templated_args, variadic) in - let defn = - str "%a@ const@,{@. return %a;@.}@." pp_sig - ( functor_name - ^ (if struct_template <> None then "" else "") - ^ "::operator()" - , templated_args - , variadic ) - pp_call_str - ( ( match fdsuffix with - | FnLpdf _ | FnTarget -> fdname ^ "" - | _ -> fdname ) - , str_args - @ List.map ~f:(fun (_, name, _) -> name) args - @ extra @ ["pstream__"] ) in - Hashtbl.add_multi functors ~key:functor_name - ~data:{struct_template; arg_templates; signature= op_signature; defn} - in - register_functor ([], fdargs, `None) ; - if String.Set.mem funs_used_in_reduce_sum fdname then - (* Produces the reduce_sum functors that has the pstream argument - as the third and not last argument *) - match fdargs with - | (_, slice, _) :: (_, start, _) :: (_, end_, _) :: rest -> - register_functor - ([slice; start ^ " + 1"; end_ ^ " + 1"], rest, `ReduceSum) - | _ -> - Common.FatalError.fatal_error_msg - [%message - "Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl)] - else if String.Set.mem funs_used_in_variadic_ode fdname then - (* Produces the variadic ode functors that has the pstream argument - as the third and not last argument *) - register_functor ([], fdargs, `VariadicODE) - else if String.Set.mem funs_used_in_variadic_dae fdname then - (* Produces the variadic DAE functors that has the pstream argument - as the fourth and not last argument *) - register_functor ([], fdargs, `VariadicDAE) - -(** Creates functions outside the model namespaces which only call the ones - inside the namespaces *) -let pp_standalone_fun_def namespace_fun ppf - Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} = - let extra, extra_templates = - match fdsuffix with - | Fun_kind.FnTarget -> - (["lp__"; "lp_accum__"], ["double"; "stan::math::accumulator"]) - | FnRng -> (["base_rng__"], ["boost::ecuyer1988"]) - | FnLpdf _ | FnPlain -> ([], []) in - let args = - List.map - ~f:(fun (_, name, ut) -> - str "const %a& %s" pp_unsizedtype_custom_scalar - (stantype_prim_str ut, ut) - name ) - fdargs in - let pp_sig_standalone ppf _ = - let arg_strs = - args - @ mk_extra_args extra_templates extra - @ ["std::ostream* pstream__ = nullptr"] in - pf ppf "(@[%a@]) " (list ~sep:comma string) arg_strs in - let mark_function_comment = "// [[stan::function]]" in - let return_type = match fdrt with None -> "void" | _ -> "auto" in - let return_stmt = match fdrt with None -> "" | _ -> "return " in - match fdbody with - | None -> pf ppf ";@ " - | Some _ -> - pf ppf "@,%s@,%s %s%a @,{@, %s%s::%a;@,}@," mark_function_comment - return_type fdname pp_sig_standalone "" return_stmt namespace_fun - pp_call_str - ( ( match fdsuffix with - | FnLpdf _ | FnTarget -> fdname ^ "" - | FnRng | FnPlain -> fdname ) - , List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] - ) - let version = "// Code generated by %%NAME%% %%VERSION%%" let includes = "#include " @@ -1004,75 +643,12 @@ let pp_register_map_rect_functors ppf p = (list ~sep:cut pp_register_functor) (List.sort ~compare (Hashtbl.to_alist map_rect_calls)) -let is_fun_used_with_variadic_fn variadic_fn_test p = - let rec find_functors_expr accum Expr.Fixed.{pattern; _} = - String.Set.union accum - ( match pattern with - | FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _) - when variadic_fn_test x -> - String.Set.of_list [Utils.stdlib_distribution_name f] - | x -> Expr.Fixed.Pattern.fold find_functors_expr accum x ) in - let rec find_functors_stmt accum stmt = - Stmt.Fixed.( - Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern) - in - Program.fold find_functors_expr find_functors_stmt String.Set.empty p - -let collect_functors_functions p = - let (functors : (string, found_functor list) Hashtbl.t) = - String.Table.create () in - let forward_decls = Hash_set.Poly.create () in - let reduce_sum_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in - let variadic_ode_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in - let variadic_dae_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in - let pp_fun_def_with_variadic_fn_list ppf fblock = - (hovbox ~indent:2 pp_fun_def) - ppf - ( fblock - , functors - , forward_decls - , reduce_sum_fns - , variadic_ode_fns - , variadic_dae_fns ) in - ( str "@[%a@]" - (list ~sep:cut pp_fun_def_with_variadic_fn_list) - p.functions_block - , functors ) - -(** Print the full C++ for the stan program. *) let pp_prog ppf (p : Program.Typed.t) = Hashtbl.clear Expression_gen.map_rect_calls ; (* First, do some transformations on the MIR itself before we begin printing it.*) let p, s = Locations.prepare_prog p in - let fns_str, functors = collect_functors_functions p in - let pp_functor_decls ppf tbl = - Hashtbl.iteri tbl ~f:(fun ~key ~data -> - pf ppf "@[%astruct %s {@,%aconst;@]@,};@." - (option - (option (fun ppf t -> - pf ppf "template <%a>@ " pp_template_defaults t ) ) ) - (Option.map ~f:(fun x -> x.struct_template) (List.hd data)) - key - (list ~sep:(any "const;@,") (fun ppf (ts, sign) -> - pf ppf "%a@[%a@]" (pp_templates ~defaults:true) ts text sign ) - ) - (List.map - ~f:(fun {arg_templates; signature; _} -> (arg_templates, signature)) - data ) ) in - let pp_functors ppf tbl = - Hashtbl.iter tbl ~f:(fun data -> - List.iter data ~f:(fun {struct_template; defn; arg_templates; _} -> - pf ppf "%a%a%s@." - (option (fun ppf t -> pf ppf "template <%a>@ " pp_template t)) - struct_template - (pp_templates ~defaults:false) - arg_templates defn ) ) in - pf ppf "@[@ %s@ %s@ namespace %s {@ %s@ %a@ %a@ %s@ %a@ %a@ }@ @]" version - includes (namespace p) usings Locations.pp_globals s pp_functor_decls - functors fns_str pp_functors functors + pf ppf "@[@ %s@ %s@ namespace %s {@ %s@ %a@ %a@ %a@ }@ @]" version includes + (namespace p) usings Locations.pp_globals s pp_functions_functors p (if !standalone_functions then fun _ _ -> () else pp_model) p ; if !standalone_functions then diff --git a/src/stan_math_backend/Stan_math_code_gen.mli b/src/stan_math_backend/Stan_math_code_gen.mli new file mode 100644 index 0000000000..03cb323745 --- /dev/null +++ b/src/stan_math_backend/Stan_math_code_gen.mli @@ -0,0 +1,10 @@ +val model_prefix : string +(** A string put in front of model names in the namespace. + Currently "model_" + *) + +val standalone_functions : bool ref +(** Flag to generate just function code, used in RStan *) + +val pp_prog : Format.formatter -> Middle.Program.Typed.t -> unit +(** Print the full C++ for the stan program. *) diff --git a/test/unit/Stan_math_code_gen_tests.ml b/test/unit/Stan_math_code_gen_tests.ml index 4227b3fc52..e89208e4cf 100644 --- a/test/unit/Stan_math_code_gen_tests.ml +++ b/test/unit/Stan_math_code_gen_tests.ml @@ -1,113 +1 @@ -open Middle -open Stan_math_backend -open Core_kernel -open Fmt -open Stan_math_code_gen - -let pp_fun_def_w_rs a b = - pp_fun_def a - ( b - , String.Table.create () - , Hash_set.Poly.create () - , String.Set.empty - , String.Set.empty - , String.Set.empty ) - -let%expect_test "udf" = - let with_no_loc stmt = - Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in - let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - { fdrt= None - ; fdname= "sars" - ; fdsuffix= FnPlain - ; fdargs= [(DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector)] - ; fdbody= - Stmt.Fixed.Pattern.Return - (Some - ( w - @@ FunApp - ( StanLib ("add", FnPlain, AoS) - , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) - |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc - |> Some - ; fdloc= Location_span.empty } - |> str "@[%a" pp_fun_def_w_rs - |> print_endline ; - [%expect - {| - template * = nullptr, - stan::require_row_vector_t* = nullptr> - void - sars(const T0__& x_arg__, const T1__& y_arg__, std::ostream* pstream__) { - using local_scalar_t__ = - stan::promote_args_t, - stan::value_type_t>; - int current_statement__ = 0; - const auto& x = stan::math::to_ref(x_arg__); - const auto& y = stan::math::to_ref(y_arg__); - static constexpr bool propto__ = true; - (void) propto__; - local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); - (void) DUMMY_VAR__; // suppress unused var warning - try { - return stan::math::add(x, 1); - } catch (const std::exception& e) { - stan::lang::rethrow_located(e, locations_array__[current_statement__]); - } - - } |}] - -let%expect_test "udf-expressions" = - let with_no_loc stmt = - Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in - let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - { fdrt= Some UMatrix - ; fdname= "sars" - ; fdsuffix= FnPlain - ; fdargs= - [ (DataOnly, "x", UMatrix); (AutoDiffable, "y", URowVector) - ; (AutoDiffable, "z", URowVector); (AutoDiffable, "w", UArray UMatrix) ] - ; fdbody= - Stmt.Fixed.Pattern.Return - (Some - ( w - @@ FunApp - ( StanLib ("add", FnPlain, AoS) - , [w @@ Var "x"; w @@ Lit (Int, "1")] ) ) ) - |> with_no_loc |> List.return |> Stmt.Fixed.Pattern.Block |> with_no_loc - |> Some - ; fdloc= Location_span.empty } - |> str "@[%a" pp_fun_def_w_rs - |> print_endline ; - [%expect - {| - template * = nullptr, - stan::require_row_vector_t* = nullptr, - stan::require_row_vector_t* = nullptr, - stan::require_stan_scalar_t* = nullptr> - Eigen::Matrix, stan::value_type_t, - stan::value_type_t, T3__>, -1, -1> - sars(const T0__& x_arg__, const T1__& y_arg__, const T2__& z_arg__, - const std::vector>& w, - std::ostream* pstream__) { - using local_scalar_t__ = - stan::promote_args_t, - stan::value_type_t, - stan::value_type_t, T3__>; - int current_statement__ = 0; - const auto& x = stan::math::to_ref(x_arg__); - const auto& y = stan::math::to_ref(y_arg__); - const auto& z = stan::math::to_ref(z_arg__); - static constexpr bool propto__ = true; - (void) propto__; - local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN()); - (void) DUMMY_VAR__; // suppress unused var warning - try { - return stan::math::add(x, 1); - } catch (const std::exception& e) { - stan::lang::rethrow_located(e, locations_array__[current_statement__]); - } - - } |}] +(* TODO proper unit tests or delete file *) From 5339ebbf9e492c948aabc3f6e064a72d167db9c5 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 17 Mar 2022 12:09:31 -0400 Subject: [PATCH 2/6] Delete unused code --- src/stan_math_backend/Expression_gen.ml | 2 -- src/stan_math_backend/Statement_gen.ml | 34 ------------------------- 2 files changed, 36 deletions(-) diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index 4fffb3a893..714083e26e 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -10,8 +10,6 @@ let stan_namespace_qualify f = if String.is_suffix ~suffix:"functor__" f || String.contains f ':' then f else "stan::math::" ^ f -let is_stan_math f = ends_with "__" f || starts_with "stan::math::" f - (* retun true if the type of the expression is integer, real, or complex (e.g. not a container) *) let is_scalar e = diff --git a/src/stan_math_backend/Statement_gen.ml b/src/stan_math_backend/Statement_gen.ml index 212c31c6ff..b5aca31278 100644 --- a/src/stan_math_backend/Statement_gen.ml +++ b/src/stan_math_backend/Statement_gen.ml @@ -15,31 +15,10 @@ let pp_profile ppf (pp_body, name, body) = name in pf ppf "{@;<1 2>@[%a@;@;%a@]@,}" profile name pp_body body -let rec contains_eigen (ut : UnsizedType.t) : bool = - match ut with - | UnsizedType.UArray t -> contains_eigen t - | UMatrix | URowVector | UVector -> true - | UInt | UReal | UComplex | UMathLibraryFunction | UFun _ -> false - -(*Fill only needs to happen for containers - * Note: This should probably be moved into its own function as data - * does not need to be filled as we are promised user input data has the correct - * dimensions. Transformed data must be filled as incorrect slices could lead - * to elements of objects in transform data not being set by the user. -*) -let pp_filler ppf (decl_id, st, nan_type, needs_filled) = - match (needs_filled, contains_eigen (SizedType.to_unsized st)) with - | true, true -> - pf ppf "@[stan::math::initialize_fill(%s, %s);@]@," decl_id - nan_type - | _ -> () - (*Pretty print a sized type*) let pp_st ppf (st, adtype) = pf ppf "%a" pp_unsizedtype_local (adtype, SizedType.to_unsized st) -let pp_ut ppf (ut, adtype) = pf ppf "%a" pp_unsizedtype_local (adtype, ut) - (*Get a string representing for the NaN type of the given type *) let nan_type (st, adtype) = match (adtype, st) with @@ -201,11 +180,6 @@ let pp_for_loop ppf (loopvar, lower, upper, pp_body, body) = loopvar pp_expr upper loopvar ; pf ppf " %a@]" pp_body body -let rec integer_el_type = function - | SizedType.SInt -> true - | SArray (st, _) -> integer_el_type st - | _ -> false - (** Print the private members of the model class Accounting for types that can be moved to OpenCL. @@ -228,14 +202,6 @@ let pp_data_decl ppf (vident, ut) = | _ -> pf ppf "%a %s;" pp_type (DataOnly, ut) vident ) | (true, _), _ -> pf ppf "%a %s;" pp_type (DataOnly, ut) vident -(** Create string representations for [vars__.emplace_back] *) -let pp_emplace_var ppf var = - match Expr.Typed.type_of var with - | UnsizedType.UComplex -> - pf ppf "@[vars__.emplace_back(%a.real());@]@," pp_expr var ; - pf ppf "@[vars__.emplace_back(%a.imag());@]" pp_expr var - | _ -> pf ppf "@[vars__.emplace_back(@,%a);@]" pp_expr var - (** Create strings representing maps of Eigen types*) let pp_map_decl ppf (vident, ut) = let scalar = local_scalar ut DataOnly in From b48a7d2d892abfc8c6c02c735c87a2aac0766fcd Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 17 Mar 2022 12:23:13 -0400 Subject: [PATCH 3/6] Move common functions into statement gen --- src/stan_math_backend/Cpp_util.ml | 28 --------------------- src/stan_math_backend/Function_gen.ml | 1 - src/stan_math_backend/Stan_math_code_gen.ml | 1 - src/stan_math_backend/Statement_gen.ml | 24 ++++++++++++++++++ 4 files changed, 24 insertions(+), 30 deletions(-) delete mode 100644 src/stan_math_backend/Cpp_util.ml diff --git a/src/stan_math_backend/Cpp_util.ml b/src/stan_math_backend/Cpp_util.ml deleted file mode 100644 index f45d1d03d3..0000000000 --- a/src/stan_math_backend/Cpp_util.ml +++ /dev/null @@ -1,28 +0,0 @@ -open Fmt -open Statement_gen -open Middle - -let pp_unused = fmt "(void) %s; // suppress unused var warning" - -(** Print the body of exception handling for functions *) -let pp_located ppf _ = - pf ppf - {|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|} - -(** [pp_located_error ppf (pp_body_block, body_block, err_msg)] surrounds [body_block] - with a C++ try-catch that will rethrow the error with the proper source location - from the [body_block] (required to be a [stmt_loc Block] variant). - @param ppf A pretty printer. - @param pp_body_block A pretty printer for the body block - @param body A C++ scoped body block surrounded by squiggly braces. - *) -let pp_located_error ppf (pp_body_block, body) = - pf ppf "@ try %a" pp_body_block body ; - string ppf " catch (const std::exception& e) " ; - pp_block ppf (pp_located, ()) - -(** [pp_located_error_b] automatically adds a Block wrapper *) -let pp_located_error_b ppf body_stmts = - pp_located_error ppf - ( pp_statement - , Stmt.Fixed.{pattern= Block body_stmts; meta= Locations.no_span_num} ) diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml index 451ae5652f..89f9bb52fb 100644 --- a/src/stan_math_backend/Function_gen.ml +++ b/src/stan_math_backend/Function_gen.ml @@ -6,7 +6,6 @@ open Middle open Fmt open Expression_gen open Statement_gen -open Cpp_util type template = | Typename of string diff --git a/src/stan_math_backend/Stan_math_code_gen.ml b/src/stan_math_backend/Stan_math_code_gen.ml index 5cd5e16827..2aabd43321 100644 --- a/src/stan_math_backend/Stan_math_code_gen.ml +++ b/src/stan_math_backend/Stan_math_code_gen.ml @@ -23,7 +23,6 @@ open Fmt open Expression_gen open Statement_gen open Function_gen -open Cpp_util let standalone_functions = ref false diff --git a/src/stan_math_backend/Statement_gen.ml b/src/stan_math_backend/Statement_gen.ml index b5aca31278..e7f5b36ffd 100644 --- a/src/stan_math_backend/Statement_gen.ml +++ b/src/stan_math_backend/Statement_gen.ml @@ -6,6 +6,12 @@ open Expression_gen let pp_call_str ppf (name, args) = pp_call ppf (name, string, args) let pp_block ppf (pp_body, body) = pf ppf "{@;<1 2>@[%a@]@,}" pp_body body +let pp_unused = fmt "(void) %s; // suppress unused var warning" + +(** Print the body of exception handling for functions *) +let pp_located ppf _ = + pf ppf + {|stan::lang::rethrow_located(e, locations_array__[current_statement__]);|} let pp_profile ppf (pp_body, name, body) = let profile ppf name = @@ -399,3 +405,21 @@ and pp_block_s ppf body = match body.pattern with | Block ls -> pp_block ppf (list ~sep:cut pp_statement, ls) | _ -> pp_block ppf (pp_statement, body) + +(** [pp_located_error ppf (pp_body_block, body_block, err_msg)] surrounds [body_block] + with a C++ try-catch that will rethrow the error with the proper source location + from the [body_block] (required to be a [stmt_loc Block] variant). + @param ppf A pretty printer. + @param pp_body_block A pretty printer for the body block + @param body A C++ scoped body block surrounded by squiggly braces. + *) +let pp_located_error ppf (pp_body_block, body) = + pf ppf "@ try %a" pp_body_block body ; + string ppf " catch (const std::exception& e) " ; + pp_block ppf (pp_located, ()) + +(** [pp_located_error_b] automatically adds a Block wrapper *) +let pp_located_error_b ppf body_stmts = + pp_located_error ppf + ( pp_statement + , Stmt.Fixed.{pattern= Block body_stmts; meta= Locations.no_span_num} ) From c1114b56bd997fb4f369e33e4bf6019469a8fd2f Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 22 Mar 2022 14:55:07 -0400 Subject: [PATCH 4/6] make names a little nicer, pull out some functions from the larger pretty printers --- src/stan_math_backend/Function_gen.ml | 349 ++++++++++++++------------ 1 file changed, 189 insertions(+), 160 deletions(-) diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml index 89f9bb52fb..652942a9b8 100644 --- a/src/stan_math_backend/Function_gen.ml +++ b/src/stan_math_backend/Function_gen.ml @@ -7,112 +7,110 @@ open Fmt open Expression_gen open Statement_gen -type template = +(** + * Typename: The name of a template typename + * Require: One of Stan's C++ template require giving a condition and the template names needing to satisfy that condition. + * Bool: A boolean template type + *) +type template_parameter = | Typename of string | Require of string * string | Bool of string -(* and stmt = - | Struct of {templates: template list; name: string; body: stmt list} - | FunDef of - { name: string - ; templates: template list - ; args: string list - ; return: string - ; body: stmt list option } - | Generated of string - (** Placeholder for all other C++. - We should slowly move more and more statements - away from strings *) *) - -let pp_template ppf template = - match template with - | Typename t -> pf ppf "typename %s" t - | Require (r, t) -> pf ppf "%s<%s>*" r t - | Bool s -> pf ppf "bool %s" s - -let pp_template_defaults ppf template = - match template with - | Require _ -> pf ppf "%a = nullptr" pp_template template - | _ -> pp_template ppf template - -let pp_templates ~defaults ppf templates = - match templates with +let pp_template_parameter ppf template_parameter = + match template_parameter with + | Typename param_name -> pf ppf "typename %s" param_name + | Require (requirement, param_name) -> pf ppf "%s<%s>*" requirement param_name + | Bool param_name -> pf ppf "bool %s" param_name + +let pp_template_parameter_defaults ppf template_parameter = + match template_parameter with + | Require _ -> pf ppf "%a = nullptr" pp_template_parameter template_parameter + | _ -> pp_template_parameter ppf template_parameter + +(** + * Pretty print a full C++ `template ` + *) +let pp_template ~defaults ppf template_parameters = + match template_parameters with | [] -> () | _ -> pf ppf "template <@[%a@]>@ " (list ~sep:comma - (if defaults then pp_template_defaults else pp_template) ) - templates + ( if defaults then pp_template_parameter_defaults + else pp_template_parameter ) ) + template_parameters type found_functor = - { struct_template: template option - ; arg_templates: template list + { struct_template: template_parameter option + ; arg_templates: template_parameter list ; signature: string ; defn: string } (** Detect if argument requires C++ template *) -let arg_needs_template = function - | UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t +let is_data_matrix_or_not_int_type = function + | UnsizedType.DataOnly, _, ut -> UnsizedType.is_eigen_type ut | _, _, t when UnsizedType.is_int_type t -> false | _ -> true (** Print template arguments for C++ functions that need templates -@param args A pack of [Program.fun_arg_decl] containing functions to detect templates. -@return A list of arguments with template parameter names added. -*) -let maybe_templated_arg_types (args : Program.fun_arg_decl) = - List.mapi args ~f:(fun i a -> - match arg_needs_template a with + * @param args A pack of [Program.fun_arg_decl] containing functions to detect templates. + * @return A list of arguments with template parameter names added. + *) +let template_parameter_names (args : Program.fun_arg_decl) = + List.mapi args ~f:(fun i arg -> + match is_data_matrix_or_not_int_type arg with | true -> Some (sprintf "T%d__" i) | false -> None ) -let maybe_require_templates (names : string option list) +let requires (_, _, ut) = + match ut with + | UnsizedType.URowVector -> "stan::require_row_vector_t" + | UVector -> "stan::require_col_vector_t" + | UMatrix -> "stan::require_eigen_matrix_dynamic_t" + (* NB: Not unwinding array types due to the way arrays of eigens are printed *) + | _ -> "stan::require_stan_scalar_t" + +let optional_require_templates (name_ops : string option list) (args : Program.fun_arg_decl) = - let require_for_arg arg = - match trd3 arg with - | UnsizedType.URowVector -> "stan::require_row_vector_t" - | UVector -> "stan::require_col_vector_t" - | UMatrix -> "stan::require_eigen_matrix_dynamic_t" - (* NB: Not unwinding array types due to the way arrays of eigens are printed *) - | _ -> "stan::require_stan_scalar_t" in - List.map2_exn names args ~f:(fun name a -> - match name with - | Some t -> Some (Require (require_for_arg a, t)) + List.map2_exn name_ops args ~f:(fun name_op fun_arg -> + match name_op with + | Some param_name -> Some (Require (requires fun_arg, param_name)) | None -> None ) -let return_arg_types (args : Program.fun_arg_decl) = - List.mapi args ~f:(fun i ((_, _, ut) as a) -> - if UnsizedType.is_eigen_type ut && arg_needs_template a then +let return_optional_arg_types (args : Program.fun_arg_decl) = + List.mapi args ~f:(fun i ((_, _, ut) as arg) -> + if UnsizedType.is_eigen_type ut && is_data_matrix_or_not_int_type arg then Some (sprintf "stan::value_type_t" i) - else if arg_needs_template a then Some (sprintf "T%d__" i) + else if is_data_matrix_or_not_int_type arg then Some (sprintf "T%d__" i) else None ) let%expect_test "arg types templated correctly" = [(AutoDiffable, "xreal", UReal); (DataOnly, "yint", UInt)] - |> maybe_templated_arg_types |> List.filter_opt |> String.concat ~sep:"," + |> template_parameter_names |> List.filter_opt |> String.concat ~sep:"," |> print_endline ; [%expect {| T0__ |}] (** Print the code for promoting stan real types -@param ppf A pretty printer -@param args A pack of arguments to detect whether they need to use the promotion rules. -*) + * @param ppf A pretty printer$ + * @param args A pack of arguments to detect whether they need to use the promotion rules. + *) let pp_promoted_scalar ppf args = match args with | [] -> pf ppf "double" | _ -> let rec promote_args_chunked ppf args = - let go ppf tl = - match tl with [] -> () | _ -> pf ppf ",@ %a" promote_args_chunked tl - in + let chunk_till_empty ppf list_tail = + match list_tail with + | [] -> () + | _ -> pf ppf ",@ %a" promote_args_chunked list_tail in match args with | [] -> pf ppf "double" - | hd :: tl -> + | hd :: list_tail -> pf ppf "@[stan::promote_args_t<@[%a%a@]>@]" (list ~sep:comma string) - hd go tl in + hd chunk_till_empty list_tail in promote_args_chunked ppf - List.(chunks_of ~length:5 (filter_opt (return_arg_types args))) + List.(chunks_of ~length:5 (filter_opt (return_optional_arg_types args))) (** Pretty-prints a function's return-type, taking into account templated argument promotion.*) @@ -135,10 +133,10 @@ let pp_eigen_arg_to_ref ppf arg_types = arg_types ) (** Print the type of an object. - @param ppf A pretty printer - @param custom_scalar_opt A string representing a types inner scalar value. - @param name The name of the object - @param ut The unsized type of the object + * @param ppf A pretty printer + * @param custom_scalar_opt A string representing a types inner scalar value. + * @param name The name of the object + * @param ut The unsized type of the object *) let pp_arg ppf (custom_scalar_opt, (_, name, ut)) = let scalar = @@ -160,126 +158,146 @@ let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) = pf ppf "const %a& %s" pp_unsizedtype_custom_scalar_eigen_exprs (scalar, ut) opt_arg_suffix -let typename t = Typename t +let typename parameter_name = Typename parameter_name (** Construct an object with it's needed templates for function signatures. - @param fdargs A sexp list of strings representing C++ types. - *) -let get_templates_and_args exprs fdargs = - let argtypetemplates = maybe_templated_arg_types fdargs in - let requireargtemplates = maybe_require_templates argtypetemplates fdargs in - ( List.filter_opt argtypetemplates - , List.filter_opt requireargtemplates - , if not exprs then + * @param is_possibly_eigen_expr if true, argument can possibly be an unevaluated eigen expression. + * @param fdargs A sexp list of strings representing C++ types. + *) +let templates_and_args (is_possibly_eigen_expr : bool) + (fdargs : Program.fun_arg_decl) : + string list * template_parameter list * string list = + let arg_type_templates = template_parameter_names fdargs in + let require_arg_templates = + optional_require_templates arg_type_templates fdargs in + ( List.filter_opt arg_type_templates + , List.filter_opt require_arg_templates + , if not is_possibly_eigen_expr then List.map ~f:(fun a -> str "%a" pp_arg a) - (List.zip_exn argtypetemplates fdargs) + (List.zip_exn arg_type_templates fdargs) else List.map ~f:(fun a -> str "%a" pp_arg_eigen_suffix a) - (List.zip_exn argtypetemplates fdargs) ) + (List.zip_exn arg_type_templates fdargs) ) let mk_extra_args templates args = List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args) +(** + * Prints boilerplate at start of function. Body of function wrapped in a `try` block. + *) +let pp_fun_body fdargs fdsuffix ppf (Stmt.Fixed.{pattern; _} as fdbody) = + pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; + pf ppf "int current_statement__ = 0; @ " ; + if List.exists ~f:(fun (_, _, ut) -> UnsizedType.is_eigen_type ut) fdargs then + pp_eigen_arg_to_ref ppf fdargs ; + ( match fdsuffix with + | Fun_kind.FnLpdf _ | FnTarget -> () + | FnPlain | FnRng -> + pf ppf "%s@ " "static constexpr bool propto__ = true;" ; + pf ppf "%s@ " "(void) propto__;" ) ; + pf ppf "%s@ " + "local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN());" ; + pf ppf "%a" pp_unused "DUMMY_VAR__" ; + let blocked_fdbody = + match pattern with + | SList stmts -> {fdbody with pattern= Block stmts} + | Block _ -> fdbody + | _ -> {fdbody with pattern= Block [fdbody]} in + pp_located_error ppf (pp_statement, blocked_fdbody) ; + pf ppf "@ " + +(** + * Functor to generate a pretty printer for the function signature. + *) +let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) = + Format.open_vbox 2 ; + pp_returntype ppf fdargs fdrt ; + let args, variadic_args = + match variadic with + | `ReduceSum -> List.split_n args 3 + | `VariadicODE -> List.split_n args 2 + | `VariadicDAE -> List.split_n args 3 + | `None -> (args, []) in + let arg_strs = + args + @ mk_extra_args extra_templates extra + @ ["std::ostream* pstream__"] + @ variadic_args in + pf ppf "%s(@[%a@]) " name (list ~sep:comma string) arg_strs ; + Format.close_box () + (** Print the C++ function definition. - @param ppf A pretty printer - Refactor this please - one idea might be to have different functions for - printing user defined distributions vs rngs vs regular functions. -*) + * @param ppf A pretty printer + * Refactor this please - one idea might be to have different functions for + * printing user defined distributions vs rngs vs regular functions. + *) let pp_fun_def ppf ( Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} , (functors : (string, found_functor list) Hashtbl.t) - , (forward_decls : (string * template list) Hash_set.t) + , (forward_decls : (string * template_parameter list) Hash_set.t) , (funs_used_in_reduce_sum : String.Set.t) , (funs_used_in_variadic_ode : String.Set.t) , (funs_used_in_variadic_dae : String.Set.t) ) = - let extra, extra_templates = + let extra, template_extra_params = match fdsuffix with | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) | FnRng -> (["base_rng__"], ["RNG"]) | FnLpdf _ | FnPlain -> ([], []) in - let pp_body ppf (Stmt.Fixed.{pattern; _} as fdbody) = - pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; - pf ppf "int current_statement__ = 0; @ " ; - if List.exists ~f:(fun (_, _, t) -> UnsizedType.is_eigen_type t) fdargs then - pp_eigen_arg_to_ref ppf fdargs ; - ( match fdsuffix with - | FnLpdf _ | FnTarget -> () - | FnPlain | FnRng -> - pf ppf "%s@ " "static constexpr bool propto__ = true;" ; - pf ppf "%s@ " "(void) propto__;" ) ; - pf ppf "%s@ " - "local_scalar_t__ DUMMY_VAR__(std::numeric_limits::quiet_NaN());" ; - pf ppf "%a" pp_unused "DUMMY_VAR__" ; - let blocked_fdbody = - match pattern with - | SList stmts -> {fdbody with pattern= Block stmts} - | Block _ -> fdbody - | _ -> {fdbody with pattern= Block [fdbody]} in - pp_located_error ppf (pp_statement, blocked_fdbody) ; - pf ppf "@ " in - let get_templates exprs variadic = - let argtypetemplates, require_templates, args = - get_templates_and_args exprs fdargs in - let templates = - List.(map ~f:typename (argtypetemplates @ extra_templates)) - @ require_templates in - match (fdsuffix, variadic) with - | (FnLpdf _ | FnTarget), `None -> (Bool "propto__" :: templates, args) - | _ -> (templates, args) in - let pp_sig ppf (name, args, variadic) = - Format.open_vbox 2 ; - pp_returntype ppf fdargs fdrt ; - let args, variadic_args = - match variadic with - | `ReduceSum -> List.split_n args 3 - | `VariadicODE -> List.split_n args 2 - | `VariadicDAE -> List.split_n args 3 - | `None -> (args, []) in - let arg_strs = - args - @ mk_extra_args extra_templates extra - @ ["std::ostream* pstream__"] - @ variadic_args in - pf ppf "%s(@[%a@]) " name (list ~sep:comma string) arg_strs ; - Format.close_box () in - let templates, templated_args = get_templates true `None in - let signature = str "%a" pp_sig (fdname, templated_args, `None) in + let template_parameter_and_arg_names is_possibly_eigen_expr variadic_fun_type + = + let template_param_names, template_require_checks, args = + templates_and_args is_possibly_eigen_expr fdargs in + let template_params = + List.(map ~f:typename (template_param_names @ template_extra_params)) + @ template_require_checks in + match (fdsuffix, variadic_fun_type) with + | (FnLpdf _ | FnTarget), `None -> (Bool "propto__" :: template_params, args) + | _ -> (template_params, args) in + let template_params, templated_args = + template_parameter_and_arg_names true `None in + let pp_fun_sig = gen_pp_sig fdargs fdrt template_extra_params extra in + let signature = str "%a" pp_fun_sig (fdname, templated_args, `None) in (* We want to print the [* = nullptr] at most once, and preferrably on a forward decl *) - let defaults = + let template_parameter_default_values = Option.is_none fdbody - || not (Hash_set.mem forward_decls (signature, templates)) in - pf ppf "%a%a" (pp_templates ~defaults) templates pp_sig + || not (Hash_set.mem forward_decls (signature, template_params)) in + pf ppf "%a%a" + (pp_template ~defaults:template_parameter_default_values) + template_params pp_fun_sig (fdname, templated_args, `None) ; match fdbody with | None -> pf ppf ";@ " ; - Hash_set.add forward_decls (signature, templates) + (* Side Effect: *) + Hash_set.add forward_decls (signature, template_params) | Some fdbody -> - pp_block ppf (pp_body, fdbody) ; - let register_functor (str_args, args, variadic) = + pp_block ppf (pp_fun_body fdargs fdsuffix, fdbody) ; + let register_functor (str_args, args, variadic_fun_type) = let suffix = - match variadic with + match variadic_fun_type with | `None -> functor_suffix | `ReduceSum -> reduce_sum_functor_suffix | `VariadicODE -> variadic_ode_functor_suffix | `VariadicDAE -> variadic_dae_functor_suffix in let functor_name = fdname ^ suffix in let struct_template = - match (fdsuffix, variadic) with + match (fdsuffix, variadic_fun_type) with | FnLpdf _, `ReduceSum -> Some (Bool "propto__") | _ -> None in - let arg_templates, templated_args = get_templates false variadic in + let arg_templates, templated_args = + template_parameter_and_arg_names false variadic_fun_type in let op_signature = - str "%a" pp_sig ("operator()", templated_args, variadic) in - let defn = - str "%a@ const@,{@. return %a;@.}@." pp_sig + str "%a" pp_fun_sig ("operator()", templated_args, variadic_fun_type) + in + let operator_paren_sig = + str "%a@ const@,{@. return %a;@.}@." pp_fun_sig ( functor_name ^ (if struct_template <> None then "" else "") ^ "::operator()" , templated_args - , variadic ) + , variadic_fun_type ) pp_call_str ( ( match fdsuffix with | FnLpdf _ | FnTarget -> fdname ^ "" @@ -287,9 +305,13 @@ let pp_fun_def ppf , str_args @ List.map ~f:(fun (_, name, _) -> name) args @ extra @ ["pstream__"] ) in + (* Side Effect: *) Hashtbl.add_multi functors ~key:functor_name - ~data:{struct_template; arg_templates; signature= op_signature; defn} - in + ~data: + { struct_template + ; arg_templates + ; signature= op_signature + ; defn= operator_paren_sig } in register_functor ([], fdargs, `None) ; if String.Set.mem funs_used_in_reduce_sum fdname then (* Produces the reduce_sum functors that has the pstream argument @@ -347,7 +369,8 @@ let pp_standalone_fun_def namespace_fun ppf , List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"] ) -let is_fun_used_with_variadic_fn variadic_fn_test p = +let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool) + (p : ('a, 'b) Program.t) = let rec find_functors_expr accum Expr.Fixed.{pattern; _} = String.Set.union accum ( match pattern with @@ -361,7 +384,7 @@ let is_fun_used_with_variadic_fn variadic_fn_test p = in Program.fold find_functors_expr find_functors_stmt String.Set.empty p -let collect_functors_functions p = +let collect_functors_functions (p : ('a, 'b) Program.t) = let (functors : (string, found_functor list) Hashtbl.t) = String.Table.create () in let forward_decls = Hash_set.Poly.create () in @@ -385,29 +408,35 @@ let collect_functors_functions p = p.functions_block , functors ) -let pp_functions_functors ppf p = +let pp_functions_functors ppf (p : ('a, 'b) Program.t) = let fns_str, functors = collect_functors_functions p in let pp_functor_decls ppf tbl = Hashtbl.iteri tbl ~f:(fun ~key ~data -> pf ppf "@[%astruct %s {@,%aconst;@]@,};@." (option - (option (fun ppf t -> - pf ppf "template <%a>@ " pp_template_defaults t ) ) ) - (Option.map ~f:(fun x -> x.struct_template) (List.hd data)) + (option (fun ppf template_param -> + pf ppf "template <%a>@ " pp_template_parameter_defaults + template_param ) ) ) + (Option.map + ~f:(fun functor_t -> functor_t.struct_template) + (List.hd data) ) key - (list ~sep:(any "const;@,") (fun ppf (ts, sign) -> - pf ppf "%a@[%a@]" (pp_templates ~defaults:true) ts text sign ) - ) + (list ~sep:(any "const;@,") (fun ppf (template_parameters, sign) -> + pf ppf "%a@[%a@]" + (pp_template ~defaults:true) + template_parameters text sign ) ) (List.map ~f:(fun {arg_templates; signature; _} -> (arg_templates, signature)) data ) ) in - let pp_functors ppf tbl = - Hashtbl.iter tbl ~f:(fun data -> + let pp_functors ppf functor_tbl = + Hashtbl.iter functor_tbl ~f:(fun data -> List.iter data ~f:(fun {struct_template; defn; arg_templates; _} -> pf ppf "%a%a%s@." - (option (fun ppf t -> pf ppf "template <%a>@ " pp_template t)) + (option (fun ppf template_param -> + pf ppf "template <%a>@ " pp_template_parameter template_param ) + ) struct_template - (pp_templates ~defaults:false) + (pp_template ~defaults:false) arg_templates defn ) ) in pf ppf "%a@ %s@ %a" pp_functor_decls functors fns_str pp_functors functors From d968240de35c59ad7c125ca2e73c4079f81e1244 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 22 Mar 2022 15:01:34 -0400 Subject: [PATCH 5/6] remove * from doc comments --- src/stan_math_backend/Function_gen.ml | 38 +++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml index 652942a9b8..2caac966a5 100644 --- a/src/stan_math_backend/Function_gen.ml +++ b/src/stan_math_backend/Function_gen.ml @@ -8,9 +8,9 @@ open Expression_gen open Statement_gen (** - * Typename: The name of a template typename - * Require: One of Stan's C++ template require giving a condition and the template names needing to satisfy that condition. - * Bool: A boolean template type + Typename: The name of a template typename + Require: One of Stan's C++ template require giving a condition and the template names needing to satisfy that condition. + Bool: A boolean template type *) type template_parameter = | Typename of string @@ -29,7 +29,7 @@ let pp_template_parameter_defaults ppf template_parameter = | _ -> pp_template_parameter ppf template_parameter (** - * Pretty print a full C++ `template ` + Pretty print a full C++ `template ` *) let pp_template ~defaults ppf template_parameters = match template_parameters with @@ -54,8 +54,8 @@ let is_data_matrix_or_not_int_type = function | _ -> true (** Print template arguments for C++ functions that need templates - * @param args A pack of [Program.fun_arg_decl] containing functions to detect templates. - * @return A list of arguments with template parameter names added. + @param args A pack of [Program.fun_arg_decl] containing functions to detect templates. + @return A list of arguments with template parameter names added. *) let template_parameter_names (args : Program.fun_arg_decl) = List.mapi args ~f:(fun i arg -> @@ -92,8 +92,8 @@ let%expect_test "arg types templated correctly" = [%expect {| T0__ |}] (** Print the code for promoting stan real types - * @param ppf A pretty printer$ - * @param args A pack of arguments to detect whether they need to use the promotion rules. + @param ppf A pretty printer$ + @param args A pack of arguments to detect whether they need to use the promotion rules. *) let pp_promoted_scalar ppf args = match args with @@ -133,10 +133,10 @@ let pp_eigen_arg_to_ref ppf arg_types = arg_types ) (** Print the type of an object. - * @param ppf A pretty printer - * @param custom_scalar_opt A string representing a types inner scalar value. - * @param name The name of the object - * @param ut The unsized type of the object + @param ppf A pretty printer + @param custom_scalar_opt A string representing a types inner scalar value. + @param name The name of the object + @param ut The unsized type of the object *) let pp_arg ppf (custom_scalar_opt, (_, name, ut)) = let scalar = @@ -161,8 +161,8 @@ let pp_arg_eigen_suffix ppf (custom_scalar_opt, (_, name, ut)) = let typename parameter_name = Typename parameter_name (** Construct an object with it's needed templates for function signatures. - * @param is_possibly_eigen_expr if true, argument can possibly be an unevaluated eigen expression. - * @param fdargs A sexp list of strings representing C++ types. + @param is_possibly_eigen_expr if true, argument can possibly be an unevaluated eigen expression. + @param fdargs A sexp list of strings representing C++ types. *) let templates_and_args (is_possibly_eigen_expr : bool) (fdargs : Program.fun_arg_decl) : @@ -185,7 +185,7 @@ let mk_extra_args templates args = List.map ~f:(fun (t, v) -> t ^ "& " ^ v) (List.zip_exn templates args) (** - * Prints boilerplate at start of function. Body of function wrapped in a `try` block. + Prints boilerplate at start of function. Body of function wrapped in a `try` block. *) let pp_fun_body fdargs fdsuffix ppf (Stmt.Fixed.{pattern; _} as fdbody) = pf ppf "@[using local_scalar_t__ =@ %a;@]@," pp_promoted_scalar fdargs ; @@ -209,7 +209,7 @@ let pp_fun_body fdargs fdsuffix ppf (Stmt.Fixed.{pattern; _} as fdbody) = pf ppf "@ " (** - * Functor to generate a pretty printer for the function signature. + Functor to generate a pretty printer for the function signature. *) let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) = Format.open_vbox 2 ; @@ -229,9 +229,9 @@ let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) = Format.close_box () (** Print the C++ function definition. - * @param ppf A pretty printer - * Refactor this please - one idea might be to have different functions for - * printing user defined distributions vs rngs vs regular functions. + @param ppf A pretty printer + Refactor this please - one idea might be to have different functions for + printing user defined distributions vs rngs vs regular functions. *) let pp_fun_def ppf ( Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} From 5eaca934dfeb875d96093d8ea2118e41cba56569 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 22 Mar 2022 15:03:36 -0400 Subject: [PATCH 6/6] change program type hinting to numbered --- src/stan_math_backend/Function_gen.ml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml index 2caac966a5..4fe4204e43 100644 --- a/src/stan_math_backend/Function_gen.ml +++ b/src/stan_math_backend/Function_gen.ml @@ -370,7 +370,7 @@ let pp_standalone_fun_def namespace_fun ppf ) let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool) - (p : ('a, 'b) Program.t) = + (p : Program.Numbered.t) = let rec find_functors_expr accum Expr.Fixed.{pattern; _} = String.Set.union accum ( match pattern with @@ -384,7 +384,7 @@ let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool) in Program.fold find_functors_expr find_functors_stmt String.Set.empty p -let collect_functors_functions (p : ('a, 'b) Program.t) = +let collect_functors_functions (p : Program.Numbered.t) = let (functors : (string, found_functor list) Hashtbl.t) = String.Table.create () in let forward_decls = Hash_set.Poly.create () in @@ -408,7 +408,7 @@ let collect_functors_functions (p : ('a, 'b) Program.t) = p.functions_block , functors ) -let pp_functions_functors ppf (p : ('a, 'b) Program.t) = +let pp_functions_functors ppf (p : Program.Numbered.t) = let fns_str, functors = collect_functors_functions p in let pp_functor_decls ppf tbl = Hashtbl.iteri tbl ~f:(fun ~key ~data ->