From 49ec69d95f4870767b1cf33c2ddbe69998925018 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 11 Jan 2022 13:16:03 -0800 Subject: [PATCH 1/6] Add DAE signature support & (good/bad) integration tests --- src/frontend/Semantic_error.ml | 17 ++ src/frontend/Semantic_error.mli | 8 + src/frontend/Typechecker.ml | 27 +++ src/middle/Stan_math_signatures.ml | 31 ++++ src/stan_math_backend/Expression_gen.ml | 15 +- src/stan_math_backend/Stan_math_code_gen.ml | 11 +- .../bad/variadic_dae/bad_abs_tol.stan | 33 ++++ .../variadic_dae/bad_initial_derivative.stan | 33 ++++ .../bad/variadic_dae/bad_initial_state.stan | 33 ++++ .../bad/variadic_dae/bad_initial_time.stan | 32 ++++ .../bad/variadic_dae/bad_max_num_steps.stan | 33 ++++ .../bad/variadic_dae/bad_no_args.stan | 32 ++++ .../variadic_dae/bad_non_matching_args.stan | 32 ++++ .../bad/variadic_dae/bad_rel_tol.stan | 33 ++++ .../bad/variadic_dae/bad_times_tol.stan | 32 ++++ test/integration/bad/variadic_dae/dune | 1 + .../bad/variadic_dae/stanc.expected | 174 ++++++++++++++++++ test/integration/good/dae_good.stan | 52 ++++++ test/integration/good/pretty.expected | 56 ++++++ test/unit/Stan_math_code_gen_tests.ml | 4 +- 20 files changed, 683 insertions(+), 6 deletions(-) create mode 100644 test/integration/bad/variadic_dae/bad_abs_tol.stan create mode 100644 test/integration/bad/variadic_dae/bad_initial_derivative.stan create mode 100644 test/integration/bad/variadic_dae/bad_initial_state.stan create mode 100644 test/integration/bad/variadic_dae/bad_initial_time.stan create mode 100644 test/integration/bad/variadic_dae/bad_max_num_steps.stan create mode 100644 test/integration/bad/variadic_dae/bad_no_args.stan create mode 100644 test/integration/bad/variadic_dae/bad_non_matching_args.stan create mode 100644 test/integration/bad/variadic_dae/bad_rel_tol.stan create mode 100644 test/integration/bad/variadic_dae/bad_times_tol.stan create mode 100644 test/integration/bad/variadic_dae/dune create mode 100644 test/integration/bad/variadic_dae/stanc.expected create mode 100644 test/integration/good/dae_good.stan diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index b4a7b15a02..a90784ca50 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -31,6 +31,11 @@ module TypeError = struct * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list * SignatureMismatch.function_mismatch + | IllTypedVariadicDAE of + string + * UnsizedType.t list + * (UnsizedType.autodifftype * UnsizedType.t) list + * SignatureMismatch.function_mismatch | ReturningFnExpectedNonReturningFound of string | ReturningFnExpectedNonFnFound of string | ReturningFnExpectedUndeclaredIdentFound of string * string option @@ -134,6 +139,15 @@ module TypeError = struct , args ) , error ) ] , false ) ) + | IllTypedVariadicDAE (name, arg_tys, args, error) -> + SignatureMismatch.pp_signature_mismatch ppf + ( name + , arg_tys + , ( [ ( ( UnsizedType.ReturnType + Stan_math_signatures.variadic_dae_fun_return_type + , args ) + , error ) ] + , false ) ) | NotIndexable (ut, nidcs) -> Fmt.pf ppf "Too many indexes, expression dimensions=%d, indexes found=%d." @@ -520,6 +534,9 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error = let illtyped_variadic_ode loc name arg_tys args error = TypeError (loc, TypeError.IllTypedVariadicODE (name, arg_tys, args, error)) +let illtyped_variadic_dae loc name arg_tys args error = + TypeError (loc, TypeError.IllTypedVariadicDAE (name, arg_tys, args, error)) + let returning_fn_expected_nonfn_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index f1769c60aa..8d5b6cdb13 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -66,6 +66,14 @@ val illtyped_variadic_ode : -> SignatureMismatch.function_mismatch -> t +val illtyped_variadic_dae : + Location_span.t + -> string + -> UnsizedType.t list + -> (UnsizedType.autodifftype * UnsizedType.t) list + -> SignatureMismatch.function_mismatch + -> t + val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index d5df6ee411..83d81f346c 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -124,6 +124,7 @@ let verify_name_fresh_udf loc tenv name = (* variadic functions are currently not in math sigs *) || Stan_math_signatures.is_reduce_sum_fn name || Stan_math_signatures.is_variadic_ode_fn name + || Stan_math_signatures.is_variadic_dae_fn name then Semantic_error.ident_is_stanmath_name loc name |> error else if Utils.is_unnormalized_distribution name then Semantic_error.udf_is_unnormalized_fn loc name |> error @@ -509,11 +510,37 @@ let check_variadic_ode ~is_cond_dist loc id es = expected_args err |> error +let check_variadic_dae ~is_cond_dist loc id es = + let optional_tol_mandatory_args = + if Stan_math_signatures.is_variadic_dae_tol_fn id.name then + Stan_math_signatures.variadic_dae_tol_arg_types + else [] in + let mandatory_arg_types = + Stan_math_signatures.variadic_dae_mandatory_arg_types + @ optional_tol_mandatory_args in + match + SignatureMismatch.check_variadic_args false mandatory_arg_types + Stan_math_signatures.variadic_dae_mandatory_fun_args + Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types es) + with + | None -> + mk_typed_expression + ~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es)) + ~ad_level:(expr_ad_lub es) + ~type_:Stan_math_signatures.variadic_dae_return_type ~loc + | Some (expected_args, err) -> + Semantic_error.illtyped_variadic_dae loc id.name + (List.map ~f:type_of_expr_typed es) + expected_args err + |> error + let check_fn ~is_cond_dist loc tenv id es = if Stan_math_signatures.is_reduce_sum_fn id.name then check_reduce_sum ~is_cond_dist loc id es else if Stan_math_signatures.is_variadic_ode_fn id.name then check_variadic_ode ~is_cond_dist loc id es + else if Stan_math_signatures.is_variadic_dae_fn id.name then + check_variadic_dae ~is_cond_dist loc id es else check_fn ~is_cond_dist loc tenv id es let rec check_funapp loc cf tenv ~is_cond_dist id tes = diff --git a/src/middle/Stan_math_signatures.ml b/src/middle/Stan_math_signatures.ml index 45215c6b4d..0c10020129 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/middle/Stan_math_signatures.ml @@ -140,6 +140,26 @@ let variadic_ode_mandatory_fun_args = let variadic_ode_fun_return_type = UnsizedType.UVector let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector +(* Variadic DAE *) +let variadic_dae_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_dae_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) + (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) + (AutoDiffable, UReal); (* t0 *) + (AutoDiffable, UArray UReal) ] (* ts *) + +let variadic_dae_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_dae_fun_return_type = UnsizedType.UVector +let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector +(* end of Variadic DAE *) + let mk_declarative_sig (fnkinds, name, args, mem_pattern) = let is_glm = String.is_suffix ~suffix:"_glm" name in let sfxes = function @@ -204,6 +224,15 @@ let is_variadic_ode_nonadjoint_tol_fn f = is_variadic_ode_nonadjoint_fn f && String.is_suffix f ~suffix:ode_tolerances_suffix +(* dae *) +let variadic_dae_fns = String.Set.of_list + [ "dae_tol"; "dae" ] +let dae_tolerances_suffix = "_tol" +let is_variadic_dae_fn f = Set.mem variadic_dae_fns f +let is_variadic_dae_tol_fn f = + is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix +(* end of dae *) + let distributions = [ ( full_lpmf , "beta_binomial" @@ -351,6 +380,7 @@ let stan_math_returntype (name : string) (args : fun_arg list) = match name with | x when is_reduce_sum_fn x -> Some (UnsizedType.ReturnType UReal) | x when is_variadic_ode_fn x -> Some (UnsizedType.ReturnType (UArray UVector)) + | x when is_variadic_dae_fn x -> Some (UnsizedType.ReturnType (UArray UVector)) | _ -> if List.length filteredmatches = 0 then None (* Return the least return type in case there are multiple options (due to implicit UInt-UReal conversion), where UInt false | x when is_variadic_ode_fn x -> false + | x when is_variadic_dae_fn x -> false | _ -> ( (* let printer intro s = Set.Poly.iter ~f:(printf intro) s in*) match List.length filteredmatches = 0 with diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index a185b52252..10093cee57 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -141,12 +141,13 @@ let map_rect_calls = Int.Table.create () let functor_suffix = "_functor__" let reduce_sum_functor_suffix = "_rsfunctor__" let variadic_ode_functor_suffix = "_odefunctor__" +let variadic_dae_functor_suffix = "_daefunctor__" let functor_suffix_select hof = match hof with | x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix - | x when Stan_math_signatures.is_variadic_ode_fn x -> - variadic_ode_functor_suffix + | x when Stan_math_signatures.is_variadic_ode_fn x -> variadic_ode_functor_suffix + | x when Stan_math_signatures.is_variadic_dae_fn x -> variadic_dae_functor_suffix | _ -> functor_suffix let constraint_to_string = function @@ -371,6 +372,16 @@ and gen_fun_app suffix ppf fname es mem_pattern = , f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b :: rel_tol_q :: abs_tol_q :: max_num_steps :: num_checkpoints :: interpolation_polynomial :: solver_f :: solver_b :: msgs :: tl ) + | true, x, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl + when Stan_math_signatures.is_variadic_dae_fn x + && String.is_suffix fname + ~suffix:Stan_math_signatures.dae_tolerances_suffix -> + ( fname + , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs :: tl + ) + | true, x, f :: yy0 :: yp0 :: t0 :: ts :: tl + when Stan_math_signatures.is_variadic_dae_fn x -> + (fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl) | ( true , "map_rect" , {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _} diff --git a/src/stan_math_backend/Stan_math_code_gen.ml b/src/stan_math_backend/Stan_math_code_gen.ml index 904ea769c2..7c8260ae90 100644 --- a/src/stan_math_backend/Stan_math_code_gen.ml +++ b/src/stan_math_backend/Stan_math_code_gen.ml @@ -200,7 +200,7 @@ let mk_extra_args templates args = printing user defined distributions vs rngs vs regular functions. *) let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} - funs_used_in_reduce_sum funs_used_in_variadic_ode = + funs_used_in_reduce_sum funs_used_in_variadic_ode funs_used_in_variadic_dae = let extra, extra_templates = match fdsuffix with | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) @@ -239,6 +239,7 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} match variadic with | `ReduceSum -> List.split_n args 3 | `VariadicODE -> List.split_n args 2 + | `VariadicDAE -> List.split_n args 3 | `None -> (args, []) in let arg_strs = args @@ -256,7 +257,8 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} match variadic with | `None -> functor_suffix | `ReduceSum -> reduce_sum_functor_suffix - | `VariadicODE -> variadic_ode_functor_suffix in + | `VariadicODE -> variadic_ode_functor_suffix + | `VariadicDAE -> variadic_dae_functor_suffix in let pp_template_propto ppf () = match (fdsuffix, variadic) with | FnLpdf _, `ReduceSum -> pf ppf "template @ " @@ -287,6 +289,10 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} (* Produces the variadic ode functors that has the pstream argument as the third and not last argument *) pp_functor ppf ([], fdargs, `VariadicODE) + else if String.Set.mem funs_used_in_variadic_dae fdname then + (* Produces the variadic DAE functors that has the pstream argument + as the fourth and not last argument *) + pp_functor ppf ([], fdargs, `VariadicDAE) (** Creates functions outside the model namespaces which only call the ones inside the namespaces *) @@ -943,6 +949,7 @@ let pp_prog ppf (p : Program.Typed.t) = pp_fun_def ppf fblock (is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p) (is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p) + (is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p) in let reduce_sum_struct_decls = String.Set.map diff --git a/test/integration/bad/variadic_dae/bad_abs_tol.stan b/test/integration/bad/variadic_dae/bad_abs_tol.stan new file mode 100644 index 0000000000..f4454e2849 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_abs_tol.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; + array[2] real a; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, a, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_initial_derivative.stan b/test/integration/bad/variadic_dae/bad_initial_derivative.stan new file mode 100644 index 0000000000..e886780457 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_initial_derivative.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + real yp0; + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} + diff --git a/test/integration/bad/variadic_dae/bad_initial_state.stan b/test/integration/bad/variadic_dae/bad_initial_state.stan new file mode 100644 index 0000000000..852e06c85b --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_initial_state.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + real yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} + diff --git a/test/integration/bad/variadic_dae/bad_initial_time.stan b/test/integration/bad/variadic_dae/bad_initial_time.stan new file mode 100644 index 0000000000..6af0074b6c --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_initial_time.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + vector[2] t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_max_num_steps.stan b/test/integration/bad/variadic_dae/bad_max_num_steps.stan new file mode 100644 index 0000000000..f72dd5b04f --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_max_num_steps.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; + array[2] real a; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, a, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_no_args.stan b/test/integration/bad/variadic_dae/bad_no_args.stan new file mode 100644 index 0000000000..2e5f7bf49c --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_no_args.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, array[] real yp) { + vector[3] res; + array[3] real p; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_non_matching_args.stan b/test/integration/bad/variadic_dae/bad_non_matching_args.stan new file mode 100644 index 0000000000..1d0ee61706 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_non_matching_args.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p, real x_r) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_rel_tol.stan b/test/integration/bad/variadic_dae/bad_rel_tol.stan new file mode 100644 index 0000000000..f5e53b24a5 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_rel_tol.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; + array[2] real a; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, a, 0.01, 100, theta, x); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_times_tol.stan b/test/integration/bad/variadic_dae/bad_times_tol.stan new file mode 100644 index 0000000000..d1e5ef0223 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_times_tol.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] vector[3] ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/dune b/test/integration/bad/variadic_dae/dune new file mode 100644 index 0000000000..856a7fccef --- /dev/null +++ b/test/integration/bad/variadic_dae/dune @@ -0,0 +1 @@ +(include ../dune) diff --git a/test/integration/bad/variadic_dae/stanc.expected b/test/integration/bad/variadic_dae/stanc.expected new file mode 100644 index 0000000000..b503289ab3 --- /dev/null +++ b/test/integration/bad/variadic_dae/stanc.expected @@ -0,0 +1,174 @@ + $ ../../../../../install/default/bin/stanc bad_abs_tol.stan +Semantic error in 'bad_abs_tol.stan', line 28, column 10 to column 66: + ------------------------------------------------- + 26: transformed parameters { + 27: array[4] vector[3] y_hat; + 28: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, a, 100, theta); + ^ + 29: } + 30: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, array[] real, int, + array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The 7th argument must be real but got array[] real + $ ../../../../../install/default/bin/stanc bad_initial_derivative.stan +Semantic error in 'bad_initial_derivative.stan', line 27, column 10 to column 70: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, real, real, array[] real, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The third argument must be vector but got real + $ ../../../../../install/default/bin/stanc bad_initial_state.stan +Semantic error in 'bad_initial_state.stan', line 27, column 10 to column 70: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, real, vector, real, array[] real, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The second argument must be vector but got real + $ ../../../../../install/default/bin/stanc bad_initial_time.stan +Semantic error in 'bad_initial_time.stan', line 27, column 10 to column 70: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, vector, array[] real, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The fourth argument must be real but got vector + $ ../../../../../install/default/bin/stanc bad_max_num_steps.stan +Semantic error in 'bad_max_num_steps.stan', line 28, column 10 to column 68: + ------------------------------------------------- + 26: transformed parameters { + 27: array[4] vector[3] y_hat; + 28: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, a, theta); + ^ + 29: } + 30: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, real, array[] real, + array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The 8th argument must be int but got array[] real + $ ../../../../../install/default/bin/stanc bad_no_args.stan +Semantic error in 'bad_no_args.stan', line 27, column 10 to column 63: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, real, int) +where F1 = (real, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int) => vector +where F2 = (real, vector, vector) => vector + The first argument must be + (real, vector, vector) => vector + but got + (real, vector, array[] real) => vector + These are not compatible because: + The types for the third argument are incompatible: one is + array[] real + but the other is + vector + $ ../../../../../install/default/bin/stanc bad_non_matching_args.stan +Semantic error in 'bad_non_matching_args.stan', line 27, column 10 to column 77: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, real, int, array[] real, + array[] real) +where F1 = (real, vector, vector, array[] real, real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real, real) => vector + The 10th argument must be real but got array[] real + $ ../../../../../install/default/bin/stanc bad_rel_tol.stan +Semantic error in 'bad_rel_tol.stan', line 28, column 10 to column 69: + ------------------------------------------------- + 26: transformed parameters { + 27: array[4] vector[3] y_hat; + 28: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, a, 0.01, 100, theta, x); + ^ + 29: } + 30: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, array[] real, real, int, + array[] real, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + Expected 9 arguments but found 10 arguments. + $ ../../../../../install/default/bin/stanc bad_times_tol.stan +Semantic error in 'bad_times_tol.stan', line 27, column 10 to column 69: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] vector, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The 5th argument must be array[] real but got array[] vector diff --git a/test/integration/good/dae_good.stan b/test/integration/good/dae_good.stan new file mode 100644 index 0000000000..86124e69b1 --- /dev/null +++ b/test/integration/good/dae_good.stan @@ -0,0 +1,52 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p, array[] real x) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + vector[3] yy0_var; + vector[3] yp0_var; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + { + y_hat = dae(chem_dae, yy0, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0_var, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0_var, yp0_var, t0, ts, 0.01, 0.01, 100, theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0_var, t0, ts, 0.01, 0.01, 100, theta, x); + } +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); // independent normal noise +} + diff --git a/test/integration/good/pretty.expected b/test/integration/good/pretty.expected index 75359b2f98..91e671ff6b 100644 --- a/test/integration/good/pretty.expected +++ b/test/integration/good/pretty.expected @@ -1384,6 +1384,62 @@ model { y ~ normal(0, 1); } + $ ../../../../install/default/bin/stanc --auto-format dae_good.stan +functions { + vector chem_dae(real t, vector yy, vector yp, array[] real p, + array[] real x) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + vector[3] yy0_var; + vector[3] yp0_var; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + { + y_hat = dae(chem_dae, yy0, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0_var, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0_var, yp0_var, t0, ts, 0.01, 0.01, 100, + theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0_var, t0, ts, 0.01, 0.01, 100, theta, + x); + } +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); // independent normal noise +} + $ ../../../../install/default/bin/stanc --auto-format declarations.stan data { int a0; diff --git a/test/unit/Stan_math_code_gen_tests.ml b/test/unit/Stan_math_code_gen_tests.ml index 8c3a2fa993..bcaac97388 100644 --- a/test/unit/Stan_math_code_gen_tests.ml +++ b/test/unit/Stan_math_code_gen_tests.ml @@ -8,7 +8,7 @@ let%expect_test "udf" = let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty in + let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in { fdrt= None ; fdname= "sars" ; fdsuffix= FnPlain @@ -60,7 +60,7 @@ let%expect_test "udf-expressions" = let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty in + let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in { fdrt= Some UMatrix ; fdname= "sars" ; fdsuffix= FnPlain From b356cb70714ad978ef1d48943029ae184b837fe1 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 11 Jan 2022 15:50:26 -0800 Subject: [PATCH 2/6] formating --- src/middle/Stan_math_signatures.ml | 7 +++---- src/stan_math_backend/Expression_gen.ml | 14 +++++++++----- src/stan_math_backend/Stan_math_code_gen.ml | 3 ++- test/unit/Stan_math_code_gen_tests.ml | 6 ++++-- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/middle/Stan_math_signatures.ml b/src/middle/Stan_math_signatures.ml index 0c10020129..75589df56c 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/middle/Stan_math_signatures.ml @@ -148,8 +148,7 @@ let variadic_dae_tol_arg_types = let variadic_dae_mandatory_arg_types = [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) - (AutoDiffable, UReal); (* t0 *) - (AutoDiffable, UArray UReal) ] (* ts *) + (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] let variadic_dae_mandatory_fun_args = [ (UnsizedType.AutoDiffable, UnsizedType.UReal) @@ -225,10 +224,10 @@ let is_variadic_ode_nonadjoint_tol_fn f = && String.is_suffix f ~suffix:ode_tolerances_suffix (* dae *) -let variadic_dae_fns = String.Set.of_list - [ "dae_tol"; "dae" ] +let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"] let dae_tolerances_suffix = "_tol" let is_variadic_dae_fn f = Set.mem variadic_dae_fns f + let is_variadic_dae_tol_fn f = is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix (* end of dae *) diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index 10093cee57..5e7d4f9e87 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -146,8 +146,10 @@ let variadic_dae_functor_suffix = "_daefunctor__" let functor_suffix_select hof = match hof with | x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix - | x when Stan_math_signatures.is_variadic_ode_fn x -> variadic_ode_functor_suffix - | x when Stan_math_signatures.is_variadic_dae_fn x -> variadic_dae_functor_suffix + | x when Stan_math_signatures.is_variadic_ode_fn x -> + variadic_ode_functor_suffix + | x when Stan_math_signatures.is_variadic_dae_fn x -> + variadic_dae_functor_suffix | _ -> functor_suffix let constraint_to_string = function @@ -372,13 +374,15 @@ and gen_fun_app suffix ppf fname es mem_pattern = , f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b :: rel_tol_q :: abs_tol_q :: max_num_steps :: num_checkpoints :: interpolation_polynomial :: solver_f :: solver_b :: msgs :: tl ) - | true, x, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl + | ( true + , x + , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl ) when Stan_math_signatures.is_variadic_dae_fn x && String.is_suffix fname ~suffix:Stan_math_signatures.dae_tolerances_suffix -> ( fname - , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs :: tl - ) + , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps + :: msgs :: tl ) | true, x, f :: yy0 :: yp0 :: t0 :: ts :: tl when Stan_math_signatures.is_variadic_dae_fn x -> (fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl) diff --git a/src/stan_math_backend/Stan_math_code_gen.ml b/src/stan_math_backend/Stan_math_code_gen.ml index 7c8260ae90..2f206de84e 100644 --- a/src/stan_math_backend/Stan_math_code_gen.ml +++ b/src/stan_math_backend/Stan_math_code_gen.ml @@ -200,7 +200,8 @@ let mk_extra_args templates args = printing user defined distributions vs rngs vs regular functions. *) let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} - funs_used_in_reduce_sum funs_used_in_variadic_ode funs_used_in_variadic_dae = + funs_used_in_reduce_sum funs_used_in_variadic_ode funs_used_in_variadic_dae + = let extra, extra_templates = match fdsuffix with | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) diff --git a/test/unit/Stan_math_code_gen_tests.ml b/test/unit/Stan_math_code_gen_tests.ml index bcaac97388..7fb61e6a6d 100644 --- a/test/unit/Stan_math_code_gen_tests.ml +++ b/test/unit/Stan_math_code_gen_tests.ml @@ -8,7 +8,8 @@ let%expect_test "udf" = let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in + let pp_fun_def_w_rs a b = + pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in { fdrt= None ; fdname= "sars" ; fdsuffix= FnPlain @@ -60,7 +61,8 @@ let%expect_test "udf-expressions" = let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in + let pp_fun_def_w_rs a b = + pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in { fdrt= Some UMatrix ; fdname= "sars" ; fdsuffix= FnPlain From 021a1bbc1e58c2032da3e8233dc547cbfea6b1d6 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 13 Jan 2022 11:41:32 -0800 Subject: [PATCH 3/6] simplify illtyped_variadic checker --- src/frontend/Semantic_error.ml | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index a90784ca50..87b2baa7b5 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -26,16 +26,12 @@ module TypeError = struct * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list * SignatureMismatch.function_mismatch - | IllTypedVariadicODE of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedVariadicDAE of + | IllTypedVariadicDE of string * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list * SignatureMismatch.function_mismatch + * UnsizedType.t | ReturningFnExpectedNonReturningFound of string | ReturningFnExpectedNonFnFound of string | ReturningFnExpectedUndeclaredIdentFound of string * string option @@ -130,21 +126,11 @@ module TypeError = struct | IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) -> SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedVariadicODE (name, arg_tys, args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - ( name - , arg_tys - , ( [ ( ( UnsizedType.ReturnType - Stan_math_signatures.variadic_ode_fun_return_type - , args ) - , error ) ] - , false ) ) - | IllTypedVariadicDAE (name, arg_tys, args, error) -> + | IllTypedVariadicDE (name, arg_tys, args, error, return_type) -> SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys - , ( [ ( ( UnsizedType.ReturnType - Stan_math_signatures.variadic_dae_fun_return_type + , ( [ ( ( UnsizedType.ReturnType return_type , args ) , error ) ] , false ) ) @@ -532,10 +518,12 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error = ) let illtyped_variadic_ode loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedVariadicODE (name, arg_tys, args, error)) + TypeError (loc, TypeError.IllTypedVariadicDE (name, arg_tys, args, + error, Stan_math_signatures.variadic_ode_fun_return_type)) let illtyped_variadic_dae loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedVariadicDAE (name, arg_tys, args, error)) + TypeError (loc, TypeError.IllTypedVariadicDE (name, arg_tys, args, + error, Stan_math_signatures.variadic_dae_fun_return_type)) let returning_fn_expected_nonfn_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name) From 2f04bd7c1bf742e3b2fc9d73634b2a37596e2103 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 13 Jan 2022 12:28:08 -0800 Subject: [PATCH 4/6] remove comments --- src/middle/Stan_math_signatures.ml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/middle/Stan_math_signatures.ml b/src/middle/Stan_math_signatures.ml index 75589df56c..e65f374585 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/middle/Stan_math_signatures.ml @@ -140,7 +140,6 @@ let variadic_ode_mandatory_fun_args = let variadic_ode_fun_return_type = UnsizedType.UVector let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector -(* Variadic DAE *) let variadic_dae_tol_arg_types = [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) ; (DataOnly, UInt) ] @@ -157,7 +156,6 @@ let variadic_dae_mandatory_fun_args = let variadic_dae_fun_return_type = UnsizedType.UVector let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector -(* end of Variadic DAE *) let mk_declarative_sig (fnkinds, name, args, mem_pattern) = let is_glm = String.is_suffix ~suffix:"_glm" name in @@ -223,14 +221,12 @@ let is_variadic_ode_nonadjoint_tol_fn f = is_variadic_ode_nonadjoint_fn f && String.is_suffix f ~suffix:ode_tolerances_suffix -(* dae *) let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"] let dae_tolerances_suffix = "_tol" let is_variadic_dae_fn f = Set.mem variadic_dae_fns f let is_variadic_dae_tol_fn f = is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix -(* end of dae *) let distributions = [ ( full_lpmf From cb5821aba5219ac3999c9ba9314e7bdc086a2c8d Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Fri, 14 Jan 2022 12:29:03 -0800 Subject: [PATCH 5/6] update to new pattern match of variadic checks --- src/frontend/Typechecker.ml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index 20bd3ff9cb..20276b12f3 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -539,12 +539,14 @@ let check_variadic_dae ~is_cond_dist loc id es = Stan_math_signatures.variadic_dae_mandatory_fun_args Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types es) with - | None -> + | Ok promotions -> mk_typed_expression - ~expr:(mk_fun_app ~is_cond_dist (StanLib FnPlain, id, es)) + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, SignatureMismatch.promote es promotions) ) ~ad_level:(expr_ad_lub es) ~type_:Stan_math_signatures.variadic_dae_return_type ~loc - | Some (expected_args, err) -> + | Error (expected_args, err) -> Semantic_error.illtyped_variadic_dae loc id.name (List.map ~f:type_of_expr_typed es) expected_args err From 92c86ca1221fa2b9afaae5d502fd509759d34296 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Fri, 14 Jan 2022 12:44:33 -0800 Subject: [PATCH 6/6] format --- src/frontend/Semantic_error.ml | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 87b2baa7b5..eef8319b4a 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -130,10 +130,7 @@ module TypeError = struct SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys - , ( [ ( ( UnsizedType.ReturnType return_type - , args ) - , error ) ] - , false ) ) + , ([((UnsizedType.ReturnType return_type, args), error)], false) ) | NotIndexable (ut, nidcs) -> Fmt.pf ppf "Too many indexes, expression dimensions=%d, indexes found=%d." @@ -518,12 +515,24 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error = ) let illtyped_variadic_ode loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedVariadicDE (name, arg_tys, args, - error, Stan_math_signatures.variadic_ode_fun_return_type)) + TypeError + ( loc + , TypeError.IllTypedVariadicDE + ( name + , arg_tys + , args + , error + , Stan_math_signatures.variadic_ode_fun_return_type ) ) let illtyped_variadic_dae loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedVariadicDE (name, arg_tys, args, - error, Stan_math_signatures.variadic_dae_fun_return_type)) + TypeError + ( loc + , TypeError.IllTypedVariadicDE + ( name + , arg_tys + , args + , error + , Stan_math_signatures.variadic_dae_fun_return_type ) ) let returning_fn_expected_nonfn_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name)