Skip to content

Commit 7d221fe

Browse files
authored
Merge pull request #1262 from stan-dev/refactor/variadic-codegen-cleanups
Refactor/variadic codegen cleanups
2 parents 51c0851 + 5ea57bf commit 7d221fe

6 files changed

Lines changed: 1749 additions & 1801 deletions

File tree

src/middle/Stan_math_signatures.ml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2627,13 +2627,6 @@ let variadic_ode_nonadjoint_fns =
26272627

26282628
let ode_tolerances_suffix = "_tol"
26292629
let is_reduce_sum_fn f = Set.mem reduce_sum_functions f
2630-
2631-
let is_variadic_ode_fn f =
2632-
Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn
2633-
2634-
let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"]
2635-
let dae_tolerances_suffix = "_tol"
2636-
let is_variadic_dae_fn f = Set.mem variadic_dae_fns f
26372630
let variadic_dae_fun_return_type = UnsizedType.UVector
26382631
let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector
26392632

src/middle/Stan_math_signatures.mli

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,3 @@ val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list
7373
(* reduce_sum helpers *)
7474
val is_reduce_sum_fn : string -> bool
7575
val reduce_sum_slice_types : UnsizedType.t list
76-
77-
(** These are only used in code-gen, typing is done via [stan_math_variadic_signatures] *)
78-
79-
(* variadic ODE helpers *)
80-
val is_variadic_ode_fn : string -> bool
81-
val ode_tolerances_suffix : string
82-
val variadic_ode_adjoint_fn : string
83-
84-
(* variadic DAE helpers *)
85-
val is_variadic_dae_fn : string -> bool
86-
val dae_tolerances_suffix : string

src/stan_math_backend/Expression_gen.ml

Lines changed: 15 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,15 @@ let fn_renames =
135135
let map_rect_calls = Int.Table.create ()
136136
let functor_suffix = "_functor__"
137137
let reduce_sum_functor_suffix = "_rsfunctor__"
138-
let variadic_ode_functor_suffix = "_odefunctor__"
139-
let variadic_dae_functor_suffix = "_daefunctor__"
138+
let variadic_functor_suffix x = sprintf "_variadic%d_functor__" x
140139

141140
let functor_suffix_select hof =
142-
match hof with
143-
| x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix
144-
| x when Stan_math_signatures.is_variadic_ode_fn x ->
145-
variadic_ode_functor_suffix
146-
| x when Stan_math_signatures.is_variadic_dae_fn x ->
147-
variadic_dae_functor_suffix
148-
| _ -> functor_suffix
141+
match Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures hof with
142+
| Some {required_fn_args; _} ->
143+
variadic_functor_suffix (List.length required_fn_args)
144+
| None when Stan_math_signatures.is_reduce_sum_fn hof ->
145+
reduce_sum_functor_suffix
146+
| None -> functor_suffix
149147

150148
let constraint_to_string = function
151149
| Transformation.Ordered -> Some "ordered"
@@ -350,51 +348,14 @@ and gen_functionals fname suffix es mem_pattern =
350348
^ reduce_sum_functor_suffix in
351349
( Fmt.str "%s<%s%s>" fname normalized_dist_functor propto_template
352350
, grainsize :: container :: msgs :: tl )
353-
| x, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl
354-
when Stan_math_signatures.is_variadic_ode_fn x
355-
&& String.is_suffix fname
356-
~suffix:Stan_math_signatures.ode_tolerances_suffix
357-
&& not (Stan_math_signatures.variadic_ode_adjoint_fn = x) ->
358-
( fname
359-
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs
360-
:: tl )
361-
| x, f :: y0 :: t0 :: ts :: tl
362-
when Stan_math_signatures.is_variadic_ode_fn x
363-
&& not (Stan_math_signatures.variadic_ode_adjoint_fn = x) ->
364-
(fname, f :: y0 :: t0 :: ts :: msgs :: tl)
365-
| ( x
366-
, f
367-
:: y0
368-
:: t0
369-
:: ts
370-
:: rel_tol
371-
:: abs_tol
372-
:: rel_tol_b
373-
:: abs_tol_b
374-
:: rel_tol_q
375-
:: abs_tol_q
376-
:: max_num_steps
377-
:: num_checkpoints
378-
:: interpolation_polynomial
379-
:: solver_f :: solver_b :: tl )
380-
when Stan_math_signatures.variadic_ode_adjoint_fn = x ->
381-
( fname
382-
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b
383-
:: abs_tol_b :: rel_tol_q :: abs_tol_q :: max_num_steps
384-
:: num_checkpoints :: interpolation_polynomial :: solver_f
385-
:: solver_b :: msgs :: tl )
386-
| ( x
387-
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl
388-
)
389-
when Stan_math_signatures.is_variadic_dae_fn x
390-
&& String.is_suffix fname
391-
~suffix:Stan_math_signatures.dae_tolerances_suffix ->
392-
( fname
393-
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps
394-
:: msgs :: tl )
395-
| x, f :: yy0 :: yp0 :: t0 :: ts :: tl
396-
when Stan_math_signatures.is_variadic_dae_fn x ->
397-
(fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl)
351+
| _, _
352+
when Stan_math_signatures.is_stan_math_variadic_function_name fname ->
353+
let Stan_math_signatures.{control_args; _} =
354+
Hashtbl.find_exn
355+
Stan_math_signatures.stan_math_variadic_signatures fname in
356+
let hd, tl =
357+
List.split_n converted_es (List.length control_args + 1) in
358+
(fname, hd @ (msgs :: tl))
398359
| ( "map_rect"
399360
, {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _}
400361
:: tl ) ->

src/stan_math_backend/Function_gen.ml

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,7 @@ let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) =
234234
let args, variadic_args =
235235
match variadic with
236236
| `ReduceSum -> List.split_n args 3
237-
| `VariadicODE -> List.split_n args 2
238-
| `VariadicDAE -> List.split_n args 3
237+
| `VariadicHOF x -> List.split_n args x
239238
| `None -> (args, []) in
240239
let arg_strs =
241240
args
@@ -255,8 +254,7 @@ let pp_fun_def ppf
255254
, (functors : (string, found_functor list) Hashtbl.t)
256255
, (forward_decls : (string * template_parameter list) Hash_set.t)
257256
, (funs_used_in_reduce_sum : String.Set.t)
258-
, (funs_used_in_variadic_ode : String.Set.t)
259-
, (funs_used_in_variadic_dae : String.Set.t) ) =
257+
, (variadic_fns : int list String.Map.t) ) =
260258
let extra, template_extra_params =
261259
match fdsuffix with
262260
| Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"])
@@ -296,8 +294,7 @@ let pp_fun_def ppf
296294
match variadic_fun_type with
297295
| `None -> functor_suffix
298296
| `ReduceSum -> reduce_sum_functor_suffix
299-
| `VariadicODE -> variadic_ode_functor_suffix
300-
| `VariadicDAE -> variadic_dae_functor_suffix in
297+
| `VariadicHOF x -> variadic_functor_suffix x in
301298
let functor_name = fdname ^ suffix in
302299
let struct_template =
303300
match (fdsuffix, variadic_fun_type) with
@@ -341,14 +338,12 @@ let pp_fun_def ppf
341338
Common.FatalError.fatal_error_msg
342339
[%message
343340
"Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl)]
344-
else if String.Set.mem funs_used_in_variadic_ode fdname then
345-
(* Produces the variadic ode functors that has the pstream argument
346-
as the third and not last argument *)
347-
register_functor ([], fdargs, `VariadicODE)
348-
else if String.Set.mem funs_used_in_variadic_dae fdname then
349-
(* Produces the variadic DAE functors that has the pstream argument
350-
as the fourth and not last argument *)
351-
register_functor ([], fdargs, `VariadicDAE)
341+
else if String.Map.mem variadic_fns fdname then
342+
(* Produces the variadic functors that has the pstream argument
343+
as not the last argument. For DAEs this is the 4th, for ODEs the 3rd *)
344+
List.iter
345+
(List.stable_dedup @@ String.Map.find_exn variadic_fns fdname)
346+
~f:(fun i -> register_functor ([], fdargs, `VariadicHOF i))
352347

353348
let pp_standalone_fun_def namespace_fun ppf
354349
Program.{fdname; fdsuffix; fdargs; fdbody; fdrt; _} =
@@ -386,13 +381,12 @@ let pp_standalone_fun_def namespace_fun ppf
386381
, List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
387382
)
388383

389-
let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool)
390-
(p : Program.Numbered.t) =
384+
let is_fun_used_with_reduce_sum (p : Program.Numbered.t) =
391385
let rec find_functors_expr accum Expr.Fixed.{pattern; _} =
392386
String.Set.union accum
393387
( match pattern with
394388
| FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _)
395-
when variadic_fn_test x ->
389+
when Stan_math_signatures.is_reduce_sum_fn x ->
396390
String.Set.of_list [Utils.stdlib_distribution_name f]
397391
| x -> Expr.Fixed.Pattern.fold find_functors_expr accum x ) in
398392
let rec find_functors_stmt accum stmt =
@@ -401,25 +395,35 @@ let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool)
401395
in
402396
Program.fold find_functors_expr find_functors_stmt String.Set.empty p
403397

398+
let get_variadic_requirements (p : Program.Numbered.t) =
399+
let rec find_functors_expr accum Expr.Fixed.{pattern; _} =
400+
match pattern with
401+
| FunApp (StanLib (x, FnPlain, _), {pattern= Var f; _} :: _) -> (
402+
match
403+
Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures x
404+
with
405+
| Some {required_fn_args; _} ->
406+
Map.add_multi accum
407+
~key:(Utils.stdlib_distribution_name f)
408+
~data:(List.length required_fn_args)
409+
| _ -> Expr.Fixed.Pattern.fold find_functors_expr accum pattern )
410+
| _ -> Expr.Fixed.Pattern.fold find_functors_expr accum pattern in
411+
let rec find_functors_stmt accum stmt =
412+
Stmt.Fixed.(
413+
Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern)
414+
in
415+
Program.fold find_functors_expr find_functors_stmt String.Map.empty p
416+
404417
let collect_functors_functions (p : Program.Numbered.t) =
405418
let (functors : (string, found_functor list) Hashtbl.t) =
406419
String.Table.create () in
407420
let forward_decls = Hash_set.Poly.create () in
408-
let reduce_sum_fns =
409-
is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in
410-
let variadic_ode_fns =
411-
is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in
412-
let variadic_dae_fns =
413-
is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in
421+
let reduce_sum_fns = is_fun_used_with_reduce_sum p in
422+
let variadic_fns = get_variadic_requirements p in
414423
let pp_fun_def_with_variadic_fn_list ppf fblock =
415424
(hovbox ~indent:2 pp_fun_def)
416425
ppf
417-
( fblock
418-
, functors
419-
, forward_decls
420-
, reduce_sum_fns
421-
, variadic_ode_fns
422-
, variadic_dae_fns ) in
426+
(fblock, functors, forward_decls, reduce_sum_fns, variadic_fns) in
423427
( str "@[<v>%a@]"
424428
(list ~sep:cut pp_fun_def_with_variadic_fn_list)
425429
p.functions_block
@@ -465,8 +469,7 @@ let pp_fun_def_w_rs a b =
465469
, String.Table.create ()
466470
, Hash_set.Poly.create ()
467471
, String.Set.empty
468-
, String.Set.empty
469-
, String.Set.empty )
472+
, String.Map.empty )
470473

471474
let%expect_test "udf" =
472475
let with_no_loc stmt =

0 commit comments

Comments
 (0)