@@ -234,8 +234,7 @@ let gen_pp_sig fdargs fdrt extra_templates extra ppf (name, args, variadic) =
234234 let args, variadic_args =
235235 match variadic with
236236 | `ReduceSum -> List. split_n args 3
237- | `VariadicODE -> List. split_n args 2
238- | `VariadicDAE -> List. split_n args 3
237+ | `VariadicHOF x -> List. split_n args x
239238 | `None -> (args, [] ) in
240239 let arg_strs =
241240 args
@@ -255,8 +254,7 @@ let pp_fun_def ppf
255254 , (functors : (string, found_functor list) Hashtbl.t )
256255 , (forward_decls : (string * template_parameter list) Hash_set.t )
257256 , (funs_used_in_reduce_sum : String.Set.t )
258- , (funs_used_in_variadic_ode : String.Set.t )
259- , (funs_used_in_variadic_dae : String.Set.t ) ) =
257+ , (variadic_fns : int list String.Map.t ) ) =
260258 let extra, template_extra_params =
261259 match fdsuffix with
262260 | Fun_kind. FnTarget -> ([" lp__" ; " lp_accum__" ], [" T_lp__" ; " T_lp_accum__" ])
@@ -296,8 +294,7 @@ let pp_fun_def ppf
296294 match variadic_fun_type with
297295 | `None -> functor_suffix
298296 | `ReduceSum -> reduce_sum_functor_suffix
299- | `VariadicODE -> variadic_ode_functor_suffix
300- | `VariadicDAE -> variadic_dae_functor_suffix in
297+ | `VariadicHOF x -> variadic_functor_suffix x in
301298 let functor_name = fdname ^ suffix in
302299 let struct_template =
303300 match (fdsuffix, variadic_fun_type) with
@@ -341,14 +338,12 @@ let pp_fun_def ppf
341338 Common.FatalError. fatal_error_msg
342339 [% message
343340 " Ill-formed reduce_sum call!" (fdargs : Program.fun_arg_decl )]
344- else if String.Set. mem funs_used_in_variadic_ode fdname then
345- (* Produces the variadic ode functors that has the pstream argument
346- as the third and not last argument *)
347- register_functor ([] , fdargs, `VariadicODE )
348- else if String.Set. mem funs_used_in_variadic_dae fdname then
349- (* Produces the variadic DAE functors that has the pstream argument
350- as the fourth and not last argument *)
351- register_functor ([] , fdargs, `VariadicDAE )
341+ else if String.Map. mem variadic_fns fdname then
342+ (* Produces the variadic functors that has the pstream argument
343+ as not the last argument. For DAEs this is the 4th, for ODEs the 3rd *)
344+ List. iter
345+ (List. stable_dedup @@ String.Map. find_exn variadic_fns fdname)
346+ ~f: (fun i -> register_functor ([] , fdargs, `VariadicHOF i))
352347
353348let pp_standalone_fun_def namespace_fun ppf
354349 Program. {fdname; fdsuffix; fdargs; fdbody; fdrt; _} =
@@ -386,13 +381,12 @@ let pp_standalone_fun_def namespace_fun ppf
386381 , List. map ~f: (fun (_ , name , _ ) -> name) fdargs @ extra @ [" pstream__" ]
387382 )
388383
389- let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool )
390- (p : Program.Numbered.t ) =
384+ let is_fun_used_with_reduce_sum (p : Program.Numbered.t ) =
391385 let rec find_functors_expr accum Expr.Fixed. {pattern; _} =
392386 String.Set. union accum
393387 ( match pattern with
394388 | FunApp (StanLib (x, FnPlain , _), {pattern= Var f; _} :: _)
395- when variadic_fn_test x ->
389+ when Stan_math_signatures. is_reduce_sum_fn x ->
396390 String.Set. of_list [Utils. stdlib_distribution_name f]
397391 | x -> Expr.Fixed.Pattern. fold find_functors_expr accum x ) in
398392 let rec find_functors_stmt accum stmt =
@@ -401,25 +395,35 @@ let is_fun_used_with_variadic_fn (variadic_fn_test : string -> bool)
401395 in
402396 Program. fold find_functors_expr find_functors_stmt String.Set. empty p
403397
398+ let get_variadic_requirements (p : Program.Numbered.t ) =
399+ let rec find_functors_expr accum Expr.Fixed. {pattern; _} =
400+ match pattern with
401+ | FunApp (StanLib (x , FnPlain, _ ), {pattern = Var f ; _} :: _ ) -> (
402+ match
403+ Hashtbl. find Stan_math_signatures. stan_math_variadic_signatures x
404+ with
405+ | Some {required_fn_args; _} ->
406+ Map. add_multi accum
407+ ~key: (Utils. stdlib_distribution_name f)
408+ ~data: (List. length required_fn_args)
409+ | _ -> Expr.Fixed.Pattern. fold find_functors_expr accum pattern )
410+ | _ -> Expr.Fixed.Pattern. fold find_functors_expr accum pattern in
411+ let rec find_functors_stmt accum stmt =
412+ Stmt.Fixed. (
413+ Pattern. fold find_functors_expr find_functors_stmt accum stmt.pattern)
414+ in
415+ Program. fold find_functors_expr find_functors_stmt String.Map. empty p
416+
404417let collect_functors_functions (p : Program.Numbered.t ) =
405418 let (functors : (string, found_functor list) Hashtbl. t ) =
406419 String.Table. create () in
407420 let forward_decls = Hash_set.Poly. create () in
408- let reduce_sum_fns =
409- is_fun_used_with_variadic_fn Stan_math_signatures. is_reduce_sum_fn p in
410- let variadic_ode_fns =
411- is_fun_used_with_variadic_fn Stan_math_signatures. is_variadic_ode_fn p in
412- let variadic_dae_fns =
413- is_fun_used_with_variadic_fn Stan_math_signatures. is_variadic_dae_fn p in
421+ let reduce_sum_fns = is_fun_used_with_reduce_sum p in
422+ let variadic_fns = get_variadic_requirements p in
414423 let pp_fun_def_with_variadic_fn_list ppf fblock =
415424 (hovbox ~indent: 2 pp_fun_def)
416425 ppf
417- ( fblock
418- , functors
419- , forward_decls
420- , reduce_sum_fns
421- , variadic_ode_fns
422- , variadic_dae_fns ) in
426+ (fblock, functors, forward_decls, reduce_sum_fns, variadic_fns) in
423427 ( str " @[<v>%a@]"
424428 (list ~sep: cut pp_fun_def_with_variadic_fn_list)
425429 p.functions_block
@@ -465,8 +469,7 @@ let pp_fun_def_w_rs a b =
465469 , String.Table. create ()
466470 , Hash_set.Poly. create ()
467471 , String.Set. empty
468- , String.Set. empty
469- , String.Set. empty )
472+ , String.Map. empty )
470473
471474let % expect_test " udf" =
472475 let with_no_loc stmt =
0 commit comments