diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index b4a7b15a02..eef8319b4a 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -26,11 +26,12 @@ module TypeError = struct * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list * SignatureMismatch.function_mismatch - | IllTypedVariadicODE of + | IllTypedVariadicDE of string * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list * SignatureMismatch.function_mismatch + * UnsizedType.t | ReturningFnExpectedNonReturningFound of string | ReturningFnExpectedNonFnFound of string | ReturningFnExpectedUndeclaredIdentFound of string * string option @@ -125,15 +126,11 @@ module TypeError = struct | IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) -> SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedVariadicODE (name, arg_tys, args, error) -> + | IllTypedVariadicDE (name, arg_tys, args, error, return_type) -> SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys - , ( [ ( ( UnsizedType.ReturnType - Stan_math_signatures.variadic_ode_fun_return_type - , args ) - , error ) ] - , false ) ) + , ([((UnsizedType.ReturnType return_type, args), error)], false) ) | NotIndexable (ut, nidcs) -> Fmt.pf ppf "Too many indexes, expression dimensions=%d, indexes found=%d." @@ -518,7 +515,24 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error = ) let illtyped_variadic_ode loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedVariadicODE (name, arg_tys, args, error)) + TypeError + ( loc + , TypeError.IllTypedVariadicDE + ( name + , arg_tys + , args + , error + , Stan_math_signatures.variadic_ode_fun_return_type ) ) + +let illtyped_variadic_dae loc name arg_tys args error = + TypeError + ( loc + , TypeError.IllTypedVariadicDE + ( name + , arg_tys + , args + , error + , Stan_math_signatures.variadic_dae_fun_return_type ) ) let returning_fn_expected_nonfn_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index f1769c60aa..8d5b6cdb13 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -66,6 +66,14 @@ val illtyped_variadic_ode : -> SignatureMismatch.function_mismatch -> t +val illtyped_variadic_dae : + Location_span.t + -> string + -> UnsizedType.t list + -> (UnsizedType.autodifftype * UnsizedType.t) list + -> SignatureMismatch.function_mismatch + -> t + val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index eda71a4dd5..20276b12f3 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -124,6 +124,7 @@ let verify_name_fresh_udf loc tenv name = (* variadic functions are currently not in math sigs *) || Stan_math_signatures.is_reduce_sum_fn name || Stan_math_signatures.is_variadic_ode_fn name + || Stan_math_signatures.is_variadic_dae_fn name then Semantic_error.ident_is_stanmath_name loc name |> error else if Utils.is_unnormalized_distribution name then Semantic_error.udf_is_unnormalized_fn loc name |> error @@ -525,11 +526,39 @@ let check_variadic_ode ~is_cond_dist loc id es = expected_args err |> error +let check_variadic_dae ~is_cond_dist loc id es = + let optional_tol_mandatory_args = + if Stan_math_signatures.is_variadic_dae_tol_fn id.name then + Stan_math_signatures.variadic_dae_tol_arg_types + else [] in + let mandatory_arg_types = + Stan_math_signatures.variadic_dae_mandatory_arg_types + @ optional_tol_mandatory_args in + match + SignatureMismatch.check_variadic_args false mandatory_arg_types + Stan_math_signatures.variadic_dae_mandatory_fun_args + Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types es) + with + | Ok promotions -> + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, SignatureMismatch.promote es promotions) ) + ~ad_level:(expr_ad_lub es) + ~type_:Stan_math_signatures.variadic_dae_return_type ~loc + | Error (expected_args, err) -> + Semantic_error.illtyped_variadic_dae loc id.name + (List.map ~f:type_of_expr_typed es) + expected_args err + |> error + let check_fn ~is_cond_dist loc tenv id es = if Stan_math_signatures.is_reduce_sum_fn id.name then check_reduce_sum ~is_cond_dist loc id es else if Stan_math_signatures.is_variadic_ode_fn id.name then check_variadic_ode ~is_cond_dist loc id es + else if Stan_math_signatures.is_variadic_dae_fn id.name then + check_variadic_dae ~is_cond_dist loc id es else check_fn ~is_cond_dist loc tenv id es let rec check_funapp loc cf tenv ~is_cond_dist id tes = diff --git a/src/middle/Stan_math_signatures.ml b/src/middle/Stan_math_signatures.ml index 98f4e23b86..8a9a0bd39d 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/middle/Stan_math_signatures.ml @@ -140,6 +140,23 @@ let variadic_ode_mandatory_fun_args = let variadic_ode_fun_return_type = UnsizedType.UVector let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector +let variadic_dae_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_dae_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) + (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) + (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] + +let variadic_dae_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_dae_fun_return_type = UnsizedType.UVector +let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector + let mk_declarative_sig (fnkinds, name, args, mem_pattern) = let is_glm = String.is_suffix ~suffix:"_glm" name in let sfxes = function @@ -204,6 +221,13 @@ let is_variadic_ode_nonadjoint_tol_fn f = is_variadic_ode_nonadjoint_fn f && String.is_suffix f ~suffix:ode_tolerances_suffix +let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"] +let dae_tolerances_suffix = "_tol" +let is_variadic_dae_fn f = Set.mem variadic_dae_fns f + +let is_variadic_dae_tol_fn f = + is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix + let distributions = [ ( full_lpmf , "beta_binomial" @@ -352,6 +376,7 @@ let stan_math_returntype (name : string) (args : fun_arg list) = match name with | x when is_reduce_sum_fn x -> Some (UnsizedType.ReturnType UReal) | x when is_variadic_ode_fn x -> Some (UnsizedType.ReturnType (UArray UVector)) + | x when is_variadic_dae_fn x -> Some (UnsizedType.ReturnType (UArray UVector)) | _ -> (* Return the least return type in case there are multiple options (due to implicit UInt-UReal conversion), where UInt false | x when is_variadic_ode_fn x -> false + | x when is_variadic_dae_fn x -> false | _ -> ( (* let printer intro s = Set.Poly.iter ~f:(printf intro) s in*) match List.length filteredmatches = 0 with diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index a0c972da77..772bf449b7 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -150,12 +150,15 @@ let map_rect_calls = Int.Table.create () let functor_suffix = "_functor__" let reduce_sum_functor_suffix = "_rsfunctor__" let variadic_ode_functor_suffix = "_odefunctor__" +let variadic_dae_functor_suffix = "_daefunctor__" let functor_suffix_select hof = match hof with | x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix | x when Stan_math_signatures.is_variadic_ode_fn x -> variadic_ode_functor_suffix + | x when Stan_math_signatures.is_variadic_dae_fn x -> + variadic_dae_functor_suffix | _ -> functor_suffix let constraint_to_string = function @@ -399,6 +402,18 @@ and gen_fun_app suffix ppf fname es mem_pattern , f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b :: rel_tol_q :: abs_tol_q :: max_num_steps :: num_checkpoints :: interpolation_polynomial :: solver_f :: solver_b :: msgs :: tl ) + | ( true + , x + , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl ) + when Stan_math_signatures.is_variadic_dae_fn x + && String.is_suffix fname + ~suffix:Stan_math_signatures.dae_tolerances_suffix -> + ( fname + , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps + :: msgs :: tl ) + | true, x, f :: yy0 :: yp0 :: t0 :: ts :: tl + when Stan_math_signatures.is_variadic_dae_fn x -> + (fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl) | ( true , "map_rect" , {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _} diff --git a/src/stan_math_backend/Stan_math_code_gen.ml b/src/stan_math_backend/Stan_math_code_gen.ml index 904ea769c2..2f206de84e 100644 --- a/src/stan_math_backend/Stan_math_code_gen.ml +++ b/src/stan_math_backend/Stan_math_code_gen.ml @@ -200,7 +200,8 @@ let mk_extra_args templates args = printing user defined distributions vs rngs vs regular functions. *) let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} - funs_used_in_reduce_sum funs_used_in_variadic_ode = + funs_used_in_reduce_sum funs_used_in_variadic_ode funs_used_in_variadic_dae + = let extra, extra_templates = match fdsuffix with | Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"]) @@ -239,6 +240,7 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} match variadic with | `ReduceSum -> List.split_n args 3 | `VariadicODE -> List.split_n args 2 + | `VariadicDAE -> List.split_n args 3 | `None -> (args, []) in let arg_strs = args @@ -256,7 +258,8 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} match variadic with | `None -> functor_suffix | `ReduceSum -> reduce_sum_functor_suffix - | `VariadicODE -> variadic_ode_functor_suffix in + | `VariadicODE -> variadic_ode_functor_suffix + | `VariadicDAE -> variadic_dae_functor_suffix in let pp_template_propto ppf () = match (fdsuffix, variadic) with | FnLpdf _, `ReduceSum -> pf ppf "template @ " @@ -287,6 +290,10 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _} (* Produces the variadic ode functors that has the pstream argument as the third and not last argument *) pp_functor ppf ([], fdargs, `VariadicODE) + else if String.Set.mem funs_used_in_variadic_dae fdname then + (* Produces the variadic DAE functors that has the pstream argument + as the fourth and not last argument *) + pp_functor ppf ([], fdargs, `VariadicDAE) (** Creates functions outside the model namespaces which only call the ones inside the namespaces *) @@ -943,6 +950,7 @@ let pp_prog ppf (p : Program.Typed.t) = pp_fun_def ppf fblock (is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p) (is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p) + (is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p) in let reduce_sum_struct_decls = String.Set.map diff --git a/test/integration/bad/variadic_dae/bad_abs_tol.stan b/test/integration/bad/variadic_dae/bad_abs_tol.stan new file mode 100644 index 0000000000..f4454e2849 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_abs_tol.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; + array[2] real a; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, a, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_initial_derivative.stan b/test/integration/bad/variadic_dae/bad_initial_derivative.stan new file mode 100644 index 0000000000..e886780457 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_initial_derivative.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + real yp0; + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} + diff --git a/test/integration/bad/variadic_dae/bad_initial_state.stan b/test/integration/bad/variadic_dae/bad_initial_state.stan new file mode 100644 index 0000000000..852e06c85b --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_initial_state.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + real yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} + diff --git a/test/integration/bad/variadic_dae/bad_initial_time.stan b/test/integration/bad/variadic_dae/bad_initial_time.stan new file mode 100644 index 0000000000..6af0074b6c --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_initial_time.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + vector[2] t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_max_num_steps.stan b/test/integration/bad/variadic_dae/bad_max_num_steps.stan new file mode 100644 index 0000000000..f72dd5b04f --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_max_num_steps.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; + array[2] real a; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, a, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_no_args.stan b/test/integration/bad/variadic_dae/bad_no_args.stan new file mode 100644 index 0000000000..2e5f7bf49c --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_no_args.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, array[] real yp) { + vector[3] res; + array[3] real p; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_non_matching_args.stan b/test/integration/bad/variadic_dae/bad_non_matching_args.stan new file mode 100644 index 0000000000..1d0ee61706 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_non_matching_args.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p, real x_r) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_rel_tol.stan b/test/integration/bad/variadic_dae/bad_rel_tol.stan new file mode 100644 index 0000000000..f5e53b24a5 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_rel_tol.stan @@ -0,0 +1,33 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; + array[2] real a; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, a, 0.01, 100, theta, x); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/bad_times_tol.stan b/test/integration/bad/variadic_dae/bad_times_tol.stan new file mode 100644 index 0000000000..d1e5ef0223 --- /dev/null +++ b/test/integration/bad/variadic_dae/bad_times_tol.stan @@ -0,0 +1,32 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] vector[3] ts; +} +parameters { + array[3] real theta; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta); +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); +} diff --git a/test/integration/bad/variadic_dae/dune b/test/integration/bad/variadic_dae/dune new file mode 100644 index 0000000000..856a7fccef --- /dev/null +++ b/test/integration/bad/variadic_dae/dune @@ -0,0 +1 @@ +(include ../dune) diff --git a/test/integration/bad/variadic_dae/stanc.expected b/test/integration/bad/variadic_dae/stanc.expected new file mode 100644 index 0000000000..b503289ab3 --- /dev/null +++ b/test/integration/bad/variadic_dae/stanc.expected @@ -0,0 +1,174 @@ + $ ../../../../../install/default/bin/stanc bad_abs_tol.stan +Semantic error in 'bad_abs_tol.stan', line 28, column 10 to column 66: + ------------------------------------------------- + 26: transformed parameters { + 27: array[4] vector[3] y_hat; + 28: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, a, 100, theta); + ^ + 29: } + 30: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, array[] real, int, + array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The 7th argument must be real but got array[] real + $ ../../../../../install/default/bin/stanc bad_initial_derivative.stan +Semantic error in 'bad_initial_derivative.stan', line 27, column 10 to column 70: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, real, real, array[] real, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The third argument must be vector but got real + $ ../../../../../install/default/bin/stanc bad_initial_state.stan +Semantic error in 'bad_initial_state.stan', line 27, column 10 to column 70: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, real, vector, real, array[] real, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The second argument must be vector but got real + $ ../../../../../install/default/bin/stanc bad_initial_time.stan +Semantic error in 'bad_initial_time.stan', line 27, column 10 to column 70: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, vector, array[] real, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The fourth argument must be real but got vector + $ ../../../../../install/default/bin/stanc bad_max_num_steps.stan +Semantic error in 'bad_max_num_steps.stan', line 28, column 10 to column 68: + ------------------------------------------------- + 26: transformed parameters { + 27: array[4] vector[3] y_hat; + 28: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, a, theta); + ^ + 29: } + 30: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, real, array[] real, + array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The 8th argument must be int but got array[] real + $ ../../../../../install/default/bin/stanc bad_no_args.stan +Semantic error in 'bad_no_args.stan', line 27, column 10 to column 63: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, real, int) +where F1 = (real, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int) => vector +where F2 = (real, vector, vector) => vector + The first argument must be + (real, vector, vector) => vector + but got + (real, vector, array[] real) => vector + These are not compatible because: + The types for the third argument are incompatible: one is + array[] real + but the other is + vector + $ ../../../../../install/default/bin/stanc bad_non_matching_args.stan +Semantic error in 'bad_non_matching_args.stan', line 27, column 10 to column 77: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, real, real, int, array[] real, + array[] real) +where F1 = (real, vector, vector, array[] real, real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real, real) => vector + The 10th argument must be real but got array[] real + $ ../../../../../install/default/bin/stanc bad_rel_tol.stan +Semantic error in 'bad_rel_tol.stan', line 28, column 10 to column 69: + ------------------------------------------------- + 26: transformed parameters { + 27: array[4] vector[3] y_hat; + 28: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, a, 0.01, 100, theta, x); + ^ + 29: } + 30: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] real, array[] real, real, int, + array[] real, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + Expected 9 arguments but found 10 arguments. + $ ../../../../../install/default/bin/stanc bad_times_tol.stan +Semantic error in 'bad_times_tol.stan', line 27, column 10 to column 69: + ------------------------------------------------- + 25: transformed parameters { + 26: array[4] vector[3] y_hat; + 27: y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta); + ^ + 28: } + 29: model { + ------------------------------------------------- + +Ill-typed arguments supplied to function 'dae_tol': +(, vector, vector, real, array[] vector, real, real, int, array[] real) +where F1 = (real, vector, vector, array[] real) => vector +Available signatures: +(, vector, vector, real, array[] real, data real, data real, data int, + array[] real) => vector + The 5th argument must be array[] real but got array[] vector diff --git a/test/integration/good/dae_good.stan b/test/integration/good/dae_good.stan new file mode 100644 index 0000000000..86124e69b1 --- /dev/null +++ b/test/integration/good/dae_good.stan @@ -0,0 +1,52 @@ +functions { + vector chem_dae(real t, vector yy, vector yp, + array[] real p, array[] real x) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + vector[3] yy0_var; + vector[3] yp0_var; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + { + y_hat = dae(chem_dae, yy0, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0_var, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0_var, yp0_var, t0, ts, 0.01, 0.01, 100, theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0_var, t0, ts, 0.01, 0.01, 100, theta, x); + } +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); // independent normal noise +} + diff --git a/test/integration/good/pretty.expected b/test/integration/good/pretty.expected index 75359b2f98..91e671ff6b 100644 --- a/test/integration/good/pretty.expected +++ b/test/integration/good/pretty.expected @@ -1384,6 +1384,62 @@ model { y ~ normal(0, 1); } + $ ../../../../install/default/bin/stanc --auto-format dae_good.stan +functions { + vector chem_dae(real t, vector yy, vector yp, array[] real p, + array[] real x) { + vector[3] res; + res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3]; + res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + + p[3] * yy[2] * yy[2]; + res[3] = yy[1] + yy[2] + yy[3] - 1.0; + return res; + } +} +data { + vector[3] yy0; + vector[3] yp0; + real t0; + array[1] real x; + array[4] vector[3] y; +} +transformed data { + array[4] real ts; +} +parameters { + array[3] real theta; + vector[3] yy0_var; + vector[3] yp0_var; + real sigma; +} +transformed parameters { + array[4] vector[3] y_hat; + { + y_hat = dae(chem_dae, yy0, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.01, 100, theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0_var, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0_var, yp0_var, t0, ts, 0.01, 0.01, 100, + theta, x); + } + { + y_hat = dae(chem_dae, yy0_var, yp0, t0, ts, theta, x); + } + { + y_hat = dae_tol(chem_dae, yy0, yp0_var, t0, ts, 0.01, 0.01, 100, theta, + x); + } +} +model { + for (t in 1 : 4) + y[t] ~ normal(y_hat[t], sigma); // independent normal noise +} + $ ../../../../install/default/bin/stanc --auto-format declarations.stan data { int a0; diff --git a/test/unit/Stan_math_code_gen_tests.ml b/test/unit/Stan_math_code_gen_tests.ml index 8c3a2fa993..7fb61e6a6d 100644 --- a/test/unit/Stan_math_code_gen_tests.ml +++ b/test/unit/Stan_math_code_gen_tests.ml @@ -8,7 +8,8 @@ let%expect_test "udf" = let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty in + let pp_fun_def_w_rs a b = + pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in { fdrt= None ; fdname= "sars" ; fdsuffix= FnPlain @@ -60,7 +61,8 @@ let%expect_test "udf-expressions" = let with_no_loc stmt = Stmt.Fixed.{pattern= stmt; meta= Locations.no_span_num} in let w e = Expr.{Fixed.pattern= e; meta= Typed.Meta.empty} in - let pp_fun_def_w_rs a b = pp_fun_def a b String.Set.empty String.Set.empty in + let pp_fun_def_w_rs a b = + pp_fun_def a b String.Set.empty String.Set.empty String.Set.empty in { fdrt= Some UMatrix ; fdname= "sars" ; fdsuffix= FnPlain