Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ module TypeError = struct
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* SignatureMismatch.function_mismatch
| IllTypedVariadicODE of
| IllTypedVariadicDE of
string
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* SignatureMismatch.function_mismatch
* UnsizedType.t
| ReturningFnExpectedNonReturningFound of string
| ReturningFnExpectedNonFnFound of string
| ReturningFnExpectedUndeclaredIdentFound of string * string option
Expand Down Expand Up @@ -125,15 +126,11 @@ module TypeError = struct
| IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) ->
SignatureMismatch.pp_signature_mismatch ppf
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
| IllTypedVariadicODE (name, arg_tys, args, error) ->
| IllTypedVariadicDE (name, arg_tys, args, error, return_type) ->
SignatureMismatch.pp_signature_mismatch ppf
( name
, arg_tys
, ( [ ( ( UnsizedType.ReturnType
Stan_math_signatures.variadic_ode_fun_return_type
, args )
, error ) ]
, false ) )
, ([((UnsizedType.ReturnType return_type, args), error)], false) )
| NotIndexable (ut, nidcs) ->
Fmt.pf ppf
"Too many indexes, expression dimensions=%d, indexes found=%d."
Expand Down Expand Up @@ -518,7 +515,24 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error =
)

let illtyped_variadic_ode loc name arg_tys args error =
TypeError (loc, TypeError.IllTypedVariadicODE (name, arg_tys, args, error))
TypeError
( loc
, TypeError.IllTypedVariadicDE
( name
, arg_tys
, args
, error
, Stan_math_signatures.variadic_ode_fun_return_type ) )

let illtyped_variadic_dae loc name arg_tys args error =
TypeError
( loc
, TypeError.IllTypedVariadicDE
( name
, arg_tys
, args
, error
, Stan_math_signatures.variadic_dae_fun_return_type ) )

let returning_fn_expected_nonfn_found loc name =
TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name)
Expand Down
8 changes: 8 additions & 0 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ val illtyped_variadic_ode :
-> SignatureMismatch.function_mismatch
-> t

val illtyped_variadic_dae :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> SignatureMismatch.function_mismatch
-> t

val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t
val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t

Expand Down
29 changes: 29 additions & 0 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ let verify_name_fresh_udf loc tenv name =
(* variadic functions are currently not in math sigs *)
|| Stan_math_signatures.is_reduce_sum_fn name
|| Stan_math_signatures.is_variadic_ode_fn name
|| Stan_math_signatures.is_variadic_dae_fn name
then Semantic_error.ident_is_stanmath_name loc name |> error
else if Utils.is_unnormalized_distribution name then
Semantic_error.udf_is_unnormalized_fn loc name |> error
Expand Down Expand Up @@ -525,11 +526,39 @@ let check_variadic_ode ~is_cond_dist loc id es =
expected_args err
|> error

let check_variadic_dae ~is_cond_dist loc id es =
let optional_tol_mandatory_args =
if Stan_math_signatures.is_variadic_dae_tol_fn id.name then
Stan_math_signatures.variadic_dae_tol_arg_types
else [] in
let mandatory_arg_types =
Stan_math_signatures.variadic_dae_mandatory_arg_types
@ optional_tol_mandatory_args in
match
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_dae_mandatory_fun_args
Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types es)
with
| Ok promotions ->
mk_typed_expression
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, SignatureMismatch.promote es promotions) )
~ad_level:(expr_ad_lub es)
~type_:Stan_math_signatures.variadic_dae_return_type ~loc
| Error (expected_args, err) ->
Semantic_error.illtyped_variadic_dae loc id.name
(List.map ~f:type_of_expr_typed es)
expected_args err
|> error

let check_fn ~is_cond_dist loc tenv id es =
if Stan_math_signatures.is_reduce_sum_fn id.name then
check_reduce_sum ~is_cond_dist loc id es
else if Stan_math_signatures.is_variadic_ode_fn id.name then
check_variadic_ode ~is_cond_dist loc id es
else if Stan_math_signatures.is_variadic_dae_fn id.name then
check_variadic_dae ~is_cond_dist loc id es
else check_fn ~is_cond_dist loc tenv id es

let rec check_funapp loc cf tenv ~is_cond_dist id tes =
Expand Down
26 changes: 26 additions & 0 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,23 @@ let variadic_ode_mandatory_fun_args =
let variadic_ode_fun_return_type = UnsizedType.UVector
let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector

let variadic_dae_tol_arg_types =
[ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal)
; (DataOnly, UInt) ]

let variadic_dae_mandatory_arg_types =
[ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *)
(UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *)
(AutoDiffable, UReal); (AutoDiffable, UArray UReal) ]

let variadic_dae_mandatory_fun_args =
[ (UnsizedType.AutoDiffable, UnsizedType.UReal)
; (UnsizedType.AutoDiffable, UnsizedType.UVector)
; (UnsizedType.AutoDiffable, UnsizedType.UVector) ]

let variadic_dae_fun_return_type = UnsizedType.UVector
let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector

let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
let is_glm = String.is_suffix ~suffix:"_glm" name in
let sfxes = function
Expand Down Expand Up @@ -204,6 +221,13 @@ let is_variadic_ode_nonadjoint_tol_fn f =
is_variadic_ode_nonadjoint_fn f
&& String.is_suffix f ~suffix:ode_tolerances_suffix

let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"]
let dae_tolerances_suffix = "_tol"
let is_variadic_dae_fn f = Set.mem variadic_dae_fns f

let is_variadic_dae_tol_fn f =
is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix

let distributions =
[ ( full_lpmf
, "beta_binomial"
Expand Down Expand Up @@ -352,6 +376,7 @@ let stan_math_returntype (name : string) (args : fun_arg list) =
match name with
| x when is_reduce_sum_fn x -> Some (UnsizedType.ReturnType UReal)
| x when is_variadic_ode_fn x -> Some (UnsizedType.ReturnType (UArray UVector))
| x when is_variadic_dae_fn x -> Some (UnsizedType.ReturnType (UArray UVector))
| _ ->
(* Return the least return type in case there are multiple options
(due to implicit UInt-UReal conversion), where UInt<UReal
Expand Down Expand Up @@ -517,6 +542,7 @@ let query_stan_math_mem_pattern_support (name : string) (args : fun_arg list) =
match name with
| x when is_reduce_sum_fn x -> false
| x when is_variadic_ode_fn x -> false
| x when is_variadic_dae_fn x -> false
| _ -> (
(* let printer intro s = Set.Poly.iter ~f:(printf intro) s in*)
match List.length filteredmatches = 0 with
Expand Down
15 changes: 15 additions & 0 deletions src/stan_math_backend/Expression_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,15 @@ let map_rect_calls = Int.Table.create ()
let functor_suffix = "_functor__"
let reduce_sum_functor_suffix = "_rsfunctor__"
let variadic_ode_functor_suffix = "_odefunctor__"
let variadic_dae_functor_suffix = "_daefunctor__"

let functor_suffix_select hof =
match hof with
| x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix
| x when Stan_math_signatures.is_variadic_ode_fn x ->
variadic_ode_functor_suffix
| x when Stan_math_signatures.is_variadic_dae_fn x ->
variadic_dae_functor_suffix
| _ -> functor_suffix

let constraint_to_string = function
Expand Down Expand Up @@ -399,6 +402,18 @@ and gen_fun_app suffix ppf fname es mem_pattern
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b
:: rel_tol_q :: abs_tol_q :: max_num_steps :: num_checkpoints
:: interpolation_polynomial :: solver_f :: solver_b :: msgs :: tl )
| ( true
, x
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl )
when Stan_math_signatures.is_variadic_dae_fn x
&& String.is_suffix fname
~suffix:Stan_math_signatures.dae_tolerances_suffix ->
( fname
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps
:: msgs :: tl )
| true, x, f :: yy0 :: yp0 :: t0 :: ts :: tl
when Stan_math_signatures.is_variadic_dae_fn x ->
(fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl)
| ( true
, "map_rect"
, {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _}
Expand Down
12 changes: 10 additions & 2 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ let mk_extra_args templates args =
printing user defined distributions vs rngs vs regular functions.
*)
let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
funs_used_in_reduce_sum funs_used_in_variadic_ode =
funs_used_in_reduce_sum funs_used_in_variadic_ode funs_used_in_variadic_dae
=
let extra, extra_templates =
match fdsuffix with
| Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"])
Expand Down Expand Up @@ -239,6 +240,7 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
match variadic with
| `ReduceSum -> List.split_n args 3
| `VariadicODE -> List.split_n args 2
| `VariadicDAE -> List.split_n args 3
| `None -> (args, []) in
let arg_strs =
args
Expand All @@ -256,7 +258,8 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
match variadic with
| `None -> functor_suffix
| `ReduceSum -> reduce_sum_functor_suffix
| `VariadicODE -> variadic_ode_functor_suffix in
| `VariadicODE -> variadic_ode_functor_suffix
| `VariadicDAE -> variadic_dae_functor_suffix in
let pp_template_propto ppf () =
match (fdsuffix, variadic) with
| FnLpdf _, `ReduceSum -> pf ppf "template <bool propto__>@ "
Expand Down Expand Up @@ -287,6 +290,10 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
(* Produces the variadic ode functors that has the pstream argument
as the third and not last argument *)
pp_functor ppf ([], fdargs, `VariadicODE)
else if String.Set.mem funs_used_in_variadic_dae fdname then
(* Produces the variadic DAE functors that has the pstream argument
as the fourth and not last argument *)
pp_functor ppf ([], fdargs, `VariadicDAE)

(** Creates functions outside the model namespaces which only call the ones
inside the namespaces *)
Expand Down Expand Up @@ -943,6 +950,7 @@ let pp_prog ppf (p : Program.Typed.t) =
pp_fun_def ppf fblock
(is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p)
(is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p)
(is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p)
in
let reduce_sum_struct_decls =
String.Set.map
Expand Down
33 changes: 33 additions & 0 deletions test/integration/bad/variadic_dae/bad_abs_tol.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
functions {
vector chem_dae(real t, vector yy, vector yp,
array[] real p) {
vector[3] res;
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
return res;
}
}
data {
vector[3] yy0;
vector[3] yp0;
real t0;
array[1] real x;
array[4] vector[3] y;
}
transformed data {
array[4] real ts;
array[2] real a;
}
parameters {
array[3] real theta;
real<lower=0> sigma;
}
transformed parameters {
array[4] vector[3] y_hat;
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, a, 100, theta);
}
model {
for (t in 1 : 4)
y[t] ~ normal(y_hat[t], sigma);
}
33 changes: 33 additions & 0 deletions test/integration/bad/variadic_dae/bad_initial_derivative.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
functions {
vector chem_dae(real t, vector yy, vector yp,
array[] real p) {
vector[3] res;
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
return res;
}
}
data {
vector[3] yy0;
real t0;
array[1] real x;
array[4] vector[3] y;
}
transformed data {
array[4] real ts;
}
parameters {
real yp0;
array[3] real theta;
real<lower=0> sigma;
}
transformed parameters {
array[4] vector[3] y_hat;
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta);
}
model {
for (t in 1 : 4)
y[t] ~ normal(y_hat[t], sigma);
}

33 changes: 33 additions & 0 deletions test/integration/bad/variadic_dae/bad_initial_state.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
functions {
vector chem_dae(real t, vector yy, vector yp,
array[] real p) {
vector[3] res;
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
return res;
}
}
data {
real yy0;
vector[3] yp0;
real t0;
array[1] real x;
array[4] vector[3] y;
}
transformed data {
array[4] real ts;
}
parameters {
array[3] real theta;
real<lower=0> sigma;
}
transformed parameters {
array[4] vector[3] y_hat;
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta);
}
model {
for (t in 1 : 4)
y[t] ~ normal(y_hat[t], sigma);
}

32 changes: 32 additions & 0 deletions test/integration/bad/variadic_dae/bad_initial_time.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
functions {
vector chem_dae(real t, vector yy, vector yp,
array[] real p) {
vector[3] res;
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
return res;
}
}
data {
vector[3] yy0;
vector[3] yp0;
vector[2] t0;
array[1] real x;
array[4] vector[3] y;
}
transformed data {
array[4] real ts;
}
parameters {
array[3] real theta;
real<lower=0> sigma;
}
transformed parameters {
array[4] vector[3] y_hat;
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta);
}
model {
for (t in 1 : 4)
y[t] ~ normal(y_hat[t], sigma);
}
Loading