Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2627,13 +2627,6 @@ let variadic_ode_nonadjoint_fns =

let ode_tolerances_suffix = "_tol"
let is_reduce_sum_fn f = Set.mem reduce_sum_functions f

let is_variadic_ode_fn f =
Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn

let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"]
let dae_tolerances_suffix = "_tol"
let is_variadic_dae_fn f = Set.mem variadic_dae_fns f
let variadic_dae_fun_return_type = UnsizedType.UVector
let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector

Expand Down
11 changes: 0 additions & 11 deletions src/middle/Stan_math_signatures.mli
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,3 @@ val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list
(* reduce_sum helpers *)
val is_reduce_sum_fn : string -> bool
val reduce_sum_slice_types : UnsizedType.t list

(** These are only used in code-gen, typing is done via [stan_math_variadic_signatures] *)

(* variadic ODE helpers *)
val is_variadic_ode_fn : string -> bool
val ode_tolerances_suffix : string
val variadic_ode_adjoint_fn : string

(* variadic DAE helpers *)
val is_variadic_dae_fn : string -> bool
val dae_tolerances_suffix : string
69 changes: 15 additions & 54 deletions src/stan_math_backend/Expression_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,15 @@ let fn_renames =
let map_rect_calls = Int.Table.create ()
let functor_suffix = "_functor__"
let reduce_sum_functor_suffix = "_rsfunctor__"
let variadic_ode_functor_suffix = "_odefunctor__"
let variadic_dae_functor_suffix = "_daefunctor__"
let variadic_functor_suffix x = sprintf "_variadic%d_functor__" x

let functor_suffix_select hof =
match hof with
| x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix
| x when Stan_math_signatures.is_variadic_ode_fn x ->
variadic_ode_functor_suffix
| x when Stan_math_signatures.is_variadic_dae_fn x ->
variadic_dae_functor_suffix
| _ -> functor_suffix
match Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures hof with
| Some {required_fn_args; _} ->
variadic_functor_suffix (List.length required_fn_args)
| None when Stan_math_signatures.is_reduce_sum_fn hof ->
reduce_sum_functor_suffix
| None -> functor_suffix

let constraint_to_string = function
| Transformation.Ordered -> Some "ordered"
Expand Down Expand Up @@ -350,51 +348,14 @@ and gen_functionals fname suffix es mem_pattern =
^ reduce_sum_functor_suffix in
( Fmt.str "%s<%s%s>" fname normalized_dist_functor propto_template
, grainsize :: container :: msgs :: tl )
| x, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl
when Stan_math_signatures.is_variadic_ode_fn x
&& String.is_suffix fname
~suffix:Stan_math_signatures.ode_tolerances_suffix
&& not (Stan_math_signatures.variadic_ode_adjoint_fn = x) ->
( fname
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs
:: tl )
| x, f :: y0 :: t0 :: ts :: tl
when Stan_math_signatures.is_variadic_ode_fn x
&& not (Stan_math_signatures.variadic_ode_adjoint_fn = x) ->
(fname, f :: y0 :: t0 :: ts :: msgs :: tl)
| ( x
, f
:: y0
:: t0
:: ts
:: rel_tol
:: abs_tol
:: rel_tol_b
:: abs_tol_b
:: rel_tol_q
:: abs_tol_q
:: max_num_steps
:: num_checkpoints
:: interpolation_polynomial
:: solver_f :: solver_b :: tl )
when Stan_math_signatures.variadic_ode_adjoint_fn = x ->
( fname
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b
:: abs_tol_b :: rel_tol_q :: abs_tol_q :: max_num_steps
:: num_checkpoints :: interpolation_polynomial :: solver_f
:: solver_b :: msgs :: tl )
| ( x
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl
)
when Stan_math_signatures.is_variadic_dae_fn x
&& String.is_suffix fname
~suffix:Stan_math_signatures.dae_tolerances_suffix ->
( fname
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps
:: msgs :: tl )
| x, f :: yy0 :: yp0 :: t0 :: ts :: tl
when Stan_math_signatures.is_variadic_dae_fn x ->
(fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl)
| _, _
when Stan_math_signatures.is_stan_math_variadic_function_name fname ->
let Stan_math_signatures.{control_args; _} =
Hashtbl.find_exn
Stan_math_signatures.stan_math_variadic_signatures fname in
let hd, tl =
List.split_n converted_es (List.length control_args + 1) in
(fname, hd @ (msgs :: tl))
| ( "map_rect"
, {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _}
:: tl ) ->
Expand Down
65 changes: 34 additions & 31 deletions src/stan_math_backend/Function_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,7 @@ let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) =
let args, variadic_args =
match variadic with
| `ReduceSum -> List.split_n args 3
| `VariadicODE -> List.split_n args 2
| `VariadicDAE -> List.split_n args 3
| `VariadicHOF x -> List.split_n args x
| `None -> (args, []) in
let arg_strs =
args
Expand All @@ -255,8 +254,7 @@ let pp_fun_def ppf
, (functors : (string, found_functor list) Hashtbl.t)
, (forward_decls : (string * template_parameter list) Hash_set.t)
, (funs_used_in_reduce_sum : String.Set.t)
, (funs_used_in_variadic_ode : String.Set.t)
, (funs_used_in_variadic_dae : String.Set.t) ) =
, (variadic_fns : int list String.Map.t) ) =
let extra, template_extra_params =
match fdsuffix with
| Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"])
Expand Down Expand Up @@ -296,8 +294,7 @@ let pp_fun_def ppf
match variadic_fun_type with
| `None -> functor_suffix
| `ReduceSum -> reduce_sum_functor_suffix
| `VariadicODE -> variadic_ode_functor_suffix
| `VariadicDAE -> variadic_dae_functor_suffix in
| `VariadicHOF x -> variadic_functor_suffix x in
let functor_name = fdname ^ suffix in
let struct_template =
match (fdsuffix, variadic_fun_type) with
Expand Down Expand Up @@ -341,14 +338,12 @@ let pp_fun_def ppf
Common.FatalError.fatal_error_msg
[%message
"Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl)]
else if String.Set.mem funs_used_in_variadic_ode fdname then
(* Produces the variadic ode functors that has the pstream argument
as the third and not last argument *)
register_functor ([], fdargs, `VariadicODE)
else if String.Set.mem funs_used_in_variadic_dae fdname then
(* Produces the variadic DAE functors that has the pstream argument
as the fourth and not last argument *)
register_functor ([], fdargs, `VariadicDAE)
else if String.Map.mem variadic_fns fdname then
(* Produces the variadic functors that has the pstream argument
as not the last argument. For DAEs this is the 4th, for ODEs the 3rd *)
List.iter
(List.stable_dedup @@ String.Map.find_exn variadic_fns fdname)
~f:(fun i -> register_functor ([], fdargs, `VariadicHOF i))

let pp_standalone_fun_def namespace_fun ppf
Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} =
Expand Down Expand Up @@ -386,13 +381,12 @@ let pp_standalone_fun_def namespace_fun ppf
, List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
)

let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool)
(p : Program.Numbered.t) =
let is_fun_used_with_reduce_sum (p : Program.Numbered.t) =
let rec find_functors_expr accum Expr.Fixed.{pattern; _} =
String.Set.union accum
( match pattern with
| FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _)
when variadic_fn_test x ->
when Stan_math_signatures.is_reduce_sum_fn x ->
String.Set.of_list [Utils.stdlib_distribution_name f]
| x -> Expr.Fixed.Pattern.fold find_functors_expr accum x ) in
let rec find_functors_stmt accum stmt =
Expand All @@ -401,25 +395,35 @@ let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool)
in
Program.fold find_functors_expr find_functors_stmt String.Set.empty p

let get_variadic_requirements (p : Program.Numbered.t) =
let rec find_functors_expr accum Expr.Fixed.{pattern; _} =
match pattern with
| FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _) -> (
match
Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures x
with
| Some {required_fn_args; _} ->
Map.add_multi accum
~key:(Utils.stdlib_distribution_name f)
~data:(List.length required_fn_args)
| _ -> Expr.Fixed.Pattern.fold find_functors_expr accum pattern )
| _ -> Expr.Fixed.Pattern.fold find_functors_expr accum pattern in
let rec find_functors_stmt accum stmt =
Stmt.Fixed.(
Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern)
in
Program.fold find_functors_expr find_functors_stmt String.Map.empty p

let collect_functors_functions (p : Program.Numbered.t) =
let (functors : (string, found_functor list) Hashtbl.t) =
String.Table.create () in
let forward_decls = Hash_set.Poly.create () in
let reduce_sum_fns =
is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in
let variadic_ode_fns =
is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in
let variadic_dae_fns =
is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in
let reduce_sum_fns = is_fun_used_with_reduce_sum p in
let variadic_fns = get_variadic_requirements p in
let pp_fun_def_with_variadic_fn_list ppf fblock =
(hovbox ~indent:2 pp_fun_def)
ppf
( fblock
, functors
, forward_decls
, reduce_sum_fns
, variadic_ode_fns
, variadic_dae_fns ) in
(fblock, functors, forward_decls, reduce_sum_fns, variadic_fns) in
( str "@[<v>%a@]"
(list ~sep:cut pp_fun_def_with_variadic_fn_list)
p.functions_block
Expand Down Expand Up @@ -465,8 +469,7 @@ let pp_fun_def_w_rs a b =
, String.Table.create ()
, Hash_set.Poly.create ()
, String.Set.empty
, String.Set.empty
, String.Set.empty )
, String.Map.empty )

let%expect_test "udf" =
let with_no_loc stmt =
Expand Down
Loading