Skip to content

Commit ab791f7

Browse files
authored
Merge pull request #1259 from stan-dev/refactor/encapsulate-variadics
Refactor: encapsulate variadic functions in typechecking
2 parents d333bd5 + f42a2fa commit ab791f7

9 files changed

Lines changed: 208 additions & 236 deletions

File tree

docs/exposing_new_functions.mld

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,15 @@ For example, the following line defines the signature [add(real, matrix) => matr
130130
Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length
131131
list of arguments, are {b NOT} added to this list.
132132

133-
These are instead treated as special cases in the [Typechecker] module in the frontend folder. It
134-
it best to consult an existing example of how these are done before proceeding.
133+
"Nice" variadic functions are added to the hashtable [Stan_math_signatures.stan_math_variadic_signatures].
134+
This is probably sufficient for most variadic functions, e.g. all the ODE solvers and DAE solvers are done
135+
via this method.
136+
[reduce_sum] is not "nice", since it is both variadic and {e polymorphic}, requiring certain arguments to have the same
137+
(but {e not predetermined}) type. Therefore, [reduce_sum] is treated as special case in the [Typechecker]
138+
module in the frontend folder.
139+
140+
Note that higher-order functions also usually require changes to the C++ code generation to work properly.
141+
It is best to consult an existing example of how these are done before proceeding.
135142

136143
{1 Testing}
137144

src/analysis_and_optimization/Memory_patterns.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,8 @@ let query_stan_math_mem_pattern_support (name : string)
113113
(args : (UnsizedType.autodifftype * UnsizedType.t) list) =
114114
let open Stan_math_signatures in
115115
match name with
116+
| x when is_stan_math_variadic_function_name x -> false
116117
| x when is_reduce_sum_fn x -> false
117-
| x when is_variadic_ode_fn x -> false
118-
| x when is_variadic_dae_fn x -> false
119118
| _ ->
120119
let name =
121120
string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name)

src/frontend/Semantic_error.ml

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ module TypeError = struct
2626
* UnsizedType.t list
2727
* (UnsizedType.autodifftype * UnsizedType.t) list
2828
* SignatureMismatch.function_mismatch
29-
| IllTypedVariadicDE of
29+
| IllTypedVariadic of
3030
string
3131
* UnsizedType.t list
3232
* (UnsizedType.autodifftype * UnsizedType.t) list
@@ -131,7 +131,7 @@ module TypeError = struct
131131
| IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) ->
132132
SignatureMismatch.pp_signature_mismatch ppf
133133
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
134-
| IllTypedVariadicDE (name, arg_tys, args, error, return_type) ->
134+
| IllTypedVariadic (name, arg_tys, args, error, return_type) ->
135135
SignatureMismatch.pp_signature_mismatch ppf
136136
( name
137137
, arg_tys
@@ -550,25 +550,8 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error =
550550
, TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error)
551551
)
552552

553-
let illtyped_variadic_ode loc name arg_tys args error =
554-
TypeError
555-
( loc
556-
, TypeError.IllTypedVariadicDE
557-
( name
558-
, arg_tys
559-
, args
560-
, error
561-
, Stan_math_signatures.variadic_ode_fun_return_type ) )
562-
563-
let illtyped_variadic_dae loc name arg_tys args error =
564-
TypeError
565-
( loc
566-
, TypeError.IllTypedVariadicDE
567-
( name
568-
, arg_tys
569-
, args
570-
, error
571-
, Stan_math_signatures.variadic_dae_fun_return_type ) )
553+
let illtyped_variadic loc name arg_tys args fn_rt error =
554+
TypeError (loc, TypeError.IllTypedVariadic (name, arg_tys, args, error, fn_rt))
572555

573556
let ambiguous_function_promotion loc name arg_tys signatures =
574557
TypeError

src/frontend/Semantic_error.mli

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,6 @@ val illtyped_reduce_sum_generic :
5858
-> SignatureMismatch.function_mismatch
5959
-> t
6060

61-
val illtyped_variadic_ode :
62-
Location_span.t
63-
-> string
64-
-> UnsizedType.t list
65-
-> (UnsizedType.autodifftype * UnsizedType.t) list
66-
-> SignatureMismatch.function_mismatch
67-
-> t
68-
6961
val ambiguous_function_promotion :
7062
Location_span.t
7163
-> string
@@ -74,11 +66,12 @@ val ambiguous_function_promotion :
7466
list
7567
-> t
7668

77-
val illtyped_variadic_dae :
69+
val illtyped_variadic :
7870
Location_span.t
7971
-> string
8072
-> UnsizedType.t list
8173
-> (UnsizedType.autodifftype * UnsizedType.t) list
74+
-> UnsizedType.t
8275
-> SignatureMismatch.function_mismatch
8376
-> t
8477

src/frontend/SignatureMismatch.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ let matching_function env name args =
248248
let matching_stanlib_function =
249249
matching_function Environment.stan_math_environment
250250

251-
let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
251+
let check_variadic_args ~allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
252252
fun_return args =
253253
let minimal_func_type =
254254
UnsizedType.UFun (mandatory_fun_arg_tys, ReturnType fun_return, FnPlain, AoS)

src/frontend/SignatureMismatch.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ val matching_stanlib_function :
6262
*)
6363

6464
val check_variadic_args :
65-
bool
65+
allow_lpdf:bool
6666
-> (UnsizedType.autodifftype * UnsizedType.t) list
6767
-> (UnsizedType.autodifftype * UnsizedType.t) list
6868
-> UnsizedType.t

src/frontend/Typechecker.ml

Lines changed: 44 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ let verify_name_fresh_udf loc tenv name =
122122
(* variadic functions are currently not in math sigs and aren't
123123
overloadable due to their separate typechecking *)
124124
Stan_math_signatures.is_reduce_sum_fn name
125-
|| Stan_math_signatures.is_variadic_ode_fn name
126-
|| Stan_math_signatures.is_variadic_dae_fn name
125+
|| Stan_math_signatures.is_stan_math_variadic_function_name name
127126
then Semantic_error.ident_is_stanmath_name loc name |> error
128127
else if Utils.is_unnormalized_distribution name then
129128
Semantic_error.udf_is_unnormalized_fn loc name |> error
@@ -191,14 +190,13 @@ let match_to_rt_option = function
191190
| _ -> None
192191

193192
let stan_math_return_type name arg_tys =
194-
match name with
195-
| x when Stan_math_signatures.is_reduce_sum_fn x ->
193+
match
194+
Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures name
195+
with
196+
| Some {return_type; _} -> Some (UnsizedType.ReturnType return_type)
197+
| None when Stan_math_signatures.is_reduce_sum_fn name ->
196198
Some (UnsizedType.ReturnType UReal)
197-
| x when Stan_math_signatures.is_variadic_ode_fn x ->
198-
Some (UnsizedType.ReturnType (UArray UVector))
199-
| x when Stan_math_signatures.is_variadic_dae_fn x ->
200-
Some (UnsizedType.ReturnType (UArray UVector))
201-
| _ ->
199+
| None ->
202200
SignatureMismatch.matching_stanlib_function name arg_tys
203201
|> match_to_rt_option
204202

@@ -571,30 +569,25 @@ let make_function_variable cf loc id = function
571569

572570
let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list)
573571
=
574-
if Stan_math_signatures.is_reduce_sum_fn id.name then
572+
if Stan_math_signatures.is_stan_math_variadic_function_name id.name then
573+
check_variadic ~is_cond_dist loc cf tenv id tes
574+
else if Stan_math_signatures.is_reduce_sum_fn id.name then
575575
check_reduce_sum ~is_cond_dist loc cf tenv id tes
576-
else if Stan_math_signatures.is_variadic_ode_fn id.name then
577-
check_variadic_ode ~is_cond_dist loc cf tenv id tes
578-
else if Stan_math_signatures.is_variadic_dae_fn id.name then
579-
check_variadic_dae ~is_cond_dist loc cf tenv id tes
580576
else check_normal_fn ~is_cond_dist loc tenv id tes
581577

578+
(** Reduce sum is a special case, even compared to the other
579+
variadic functions, because it is polymorphic in the type of the
580+
first argument. The first, fourth, and fifth arguments must agree,
581+
which is too complicated to be captured declaratively. *)
582582
and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
583583
let basic_mismatch () =
584584
let mandatory_args =
585585
UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in
586586
let mandatory_fun_args =
587587
UnsizedType.
588588
[(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in
589-
SignatureMismatch.check_variadic_args true mandatory_args mandatory_fun_args
590-
UReal (get_arg_types tes) in
591-
let fail () =
592-
let expected_args, err =
593-
basic_mismatch () |> Result.error |> Option.value_exn in
594-
Semantic_error.illtyped_reduce_sum_generic loc id.name
595-
(List.map ~f:type_of_expr_typed tes)
596-
expected_args err
597-
|> error in
589+
SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
590+
mandatory_fun_args UReal (get_arg_types tes) in
598591
let matching remaining_es fn =
599592
match fn with
600593
| Env.
@@ -611,7 +604,7 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
611604
let arg_types =
612605
(calculate_autodifftype cf Functions ftype, ftype)
613606
:: get_arg_types remaining_es in
614-
SignatureMismatch.check_variadic_args true mandatory_args
607+
SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
615608
mandatory_fun_args UReal arg_types
616609
| _ -> basic_mismatch () in
617610
match tes with
@@ -633,81 +626,25 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
633626
(List.map ~f:type_of_expr_typed tes)
634627
expected_args err
635628
|> error )
636-
| _ -> fail ()
637-
638-
and check_variadic_ode ~is_cond_dist loc cf tenv id tes =
639-
let optional_tol_mandatory_args =
640-
if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then
641-
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
642-
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then
643-
Stan_math_signatures.variadic_ode_tol_arg_types
644-
else [] in
645-
let mandatory_arg_types =
646-
Stan_math_signatures.variadic_ode_mandatory_arg_types
647-
@ optional_tol_mandatory_args in
648-
let fail () =
649-
let expected_args, err =
650-
SignatureMismatch.check_variadic_args false mandatory_arg_types
651-
Stan_math_signatures.variadic_ode_mandatory_fun_args
652-
Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types tes)
653-
|> Result.error |> Option.value_exn in
654-
Semantic_error.illtyped_variadic_ode loc id.name
655-
(List.map ~f:type_of_expr_typed tes)
656-
expected_args err
657-
|> error in
658-
let matching remaining_es Env.{type_= ftype; _} =
659-
let arg_types =
660-
(calculate_autodifftype cf Functions ftype, ftype)
661-
:: get_arg_types remaining_es in
662-
SignatureMismatch.check_variadic_args false mandatory_arg_types
663-
Stan_math_signatures.variadic_ode_mandatory_fun_args
664-
Stan_math_signatures.variadic_ode_fun_return_type arg_types in
665-
match tes with
666-
| {expr= Variable fname; _} :: remaining_es -> (
667-
match find_matching_first_order_fn tenv (matching remaining_es) fname with
668-
| SignatureMismatch.UniqueMatch (ftype, promotions) ->
669-
let tes = make_function_variable cf loc fname ftype :: remaining_es in
670-
mk_typed_expression
671-
~expr:
672-
(mk_fun_app ~is_cond_dist
673-
(StanLib FnPlain, id, Promotion.promote_list tes promotions) )
674-
~ad_level:(expr_ad_lub tes)
675-
~type_:Stan_math_signatures.variadic_ode_return_type ~loc
676-
| AmbiguousMatch ps ->
677-
Semantic_error.ambiguous_function_promotion loc fname.name None ps
678-
|> error
679-
| SignatureErrors (expected_args, err) ->
680-
Semantic_error.illtyped_variadic_ode loc id.name
681-
(List.map ~f:type_of_expr_typed tes)
682-
expected_args err
683-
|> error )
684-
| _ -> fail ()
685-
686-
and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
687-
let optional_tol_mandatory_args =
688-
if Stan_math_signatures.is_variadic_dae_tol_fn id.name then
689-
Stan_math_signatures.variadic_dae_tol_arg_types
690-
else [] in
691-
let mandatory_arg_types =
692-
Stan_math_signatures.variadic_dae_mandatory_arg_types
693-
@ optional_tol_mandatory_args in
694-
let fail () =
695-
let expected_args, err =
696-
SignatureMismatch.check_variadic_args false mandatory_arg_types
697-
Stan_math_signatures.variadic_dae_mandatory_fun_args
698-
Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types tes)
699-
|> Result.error |> Option.value_exn in
700-
Semantic_error.illtyped_variadic_dae loc id.name
701-
(List.map ~f:type_of_expr_typed tes)
702-
expected_args err
703-
|> error in
629+
| _ ->
630+
let expected_args, err =
631+
basic_mismatch () |> Result.error |> Option.value_exn in
632+
Semantic_error.illtyped_reduce_sum_generic loc id.name
633+
(List.map ~f:type_of_expr_typed tes)
634+
expected_args err
635+
|> error
636+
637+
and check_variadic ~is_cond_dist loc cf tenv id tes =
638+
let Stan_math_signatures.
639+
{control_args; required_fn_args; required_fn_rt; return_type} =
640+
Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures id.name
641+
in
704642
let matching remaining_es Env.{type_= ftype; _} =
705643
let arg_types =
706644
(calculate_autodifftype cf Functions ftype, ftype)
707645
:: get_arg_types remaining_es in
708-
SignatureMismatch.check_variadic_args false mandatory_arg_types
709-
Stan_math_signatures.variadic_dae_mandatory_fun_args
710-
Stan_math_signatures.variadic_dae_fun_return_type arg_types in
646+
SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args
647+
required_fn_args required_fn_rt arg_types in
711648
match tes with
712649
| {expr= Variable fname; _} :: remaining_es -> (
713650
match find_matching_first_order_fn tenv (matching remaining_es) fname with
@@ -717,17 +654,24 @@ and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
717654
~expr:
718655
(mk_fun_app ~is_cond_dist
719656
(StanLib FnPlain, id, Promotion.promote_list tes promotions) )
720-
~ad_level:(expr_ad_lub tes)
721-
~type_:Stan_math_signatures.variadic_dae_return_type ~loc
657+
~ad_level:(expr_ad_lub tes) ~type_:return_type ~loc
722658
| AmbiguousMatch ps ->
723659
Semantic_error.ambiguous_function_promotion loc fname.name None ps
724660
|> error
725661
| SignatureErrors (expected_args, err) ->
726-
Semantic_error.illtyped_variadic_dae loc id.name
662+
Semantic_error.illtyped_variadic loc id.name
727663
(List.map ~f:type_of_expr_typed tes)
728-
expected_args err
664+
expected_args required_fn_rt err
729665
|> error )
730-
| _ -> fail ()
666+
| _ ->
667+
let expected_args, err =
668+
SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args
669+
required_fn_args required_fn_rt (get_arg_types tes)
670+
|> Result.error |> Option.value_exn in
671+
Semantic_error.illtyped_variadic loc id.name
672+
(List.map ~f:type_of_expr_typed tes)
673+
expected_args required_fn_rt err
674+
|> error
731675

732676
and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) =
733677
let name_check =

0 commit comments

Comments
 (0)