Skip to content

Commit 6c2bc5e

Browse files
authored
Merge pull request #1092 from stan-dev/add_dae_support
Add DAE signature support
2 parents 989f5ee + 92c86ca commit 6c2bc5e

20 files changed

Lines changed: 690 additions & 12 deletions

src/frontend/Semantic_error.ml

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ module TypeError = struct
2626
* UnsizedType.t list
2727
* (UnsizedType.autodifftype * UnsizedType.t) list
2828
* SignatureMismatch.function_mismatch
29-
| IllTypedVariadicODE of
29+
| IllTypedVariadicDE of
3030
string
3131
* UnsizedType.t list
3232
* (UnsizedType.autodifftype * UnsizedType.t) list
3333
* SignatureMismatch.function_mismatch
34+
* UnsizedType.t
3435
| ReturningFnExpectedNonReturningFound of string
3536
| ReturningFnExpectedNonFnFound of string
3637
| ReturningFnExpectedUndeclaredIdentFound of string * string option
@@ -125,15 +126,11 @@ module TypeError = struct
125126
| IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) ->
126127
SignatureMismatch.pp_signature_mismatch ppf
127128
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
128-
| IllTypedVariadicODE (name, arg_tys, args, error) ->
129+
| IllTypedVariadicDE (name, arg_tys, args, error, return_type) ->
129130
SignatureMismatch.pp_signature_mismatch ppf
130131
( name
131132
, arg_tys
132-
, ( [ ( ( UnsizedType.ReturnType
133-
Stan_math_signatures.variadic_ode_fun_return_type
134-
, args )
135-
, error ) ]
136-
, false ) )
133+
, ([((UnsizedType.ReturnType return_type, args), error)], false) )
137134
| NotIndexable (ut, nidcs) ->
138135
Fmt.pf ppf
139136
"Too many indexes, expression dimensions=%d, indexes found=%d."
@@ -518,7 +515,24 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error =
518515
)
519516

520517
let illtyped_variadic_ode loc name arg_tys args error =
521-
TypeError (loc, TypeError.IllTypedVariadicODE (name, arg_tys, args, error))
518+
TypeError
519+
( loc
520+
, TypeError.IllTypedVariadicDE
521+
( name
522+
, arg_tys
523+
, args
524+
, error
525+
, Stan_math_signatures.variadic_ode_fun_return_type ) )
526+
527+
let illtyped_variadic_dae loc name arg_tys args error =
528+
TypeError
529+
( loc
530+
, TypeError.IllTypedVariadicDE
531+
( name
532+
, arg_tys
533+
, args
534+
, error
535+
, Stan_math_signatures.variadic_dae_fun_return_type ) )
522536

523537
let returning_fn_expected_nonfn_found loc name =
524538
TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name)

src/frontend/Semantic_error.mli

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ val illtyped_variadic_ode :
6666
-> SignatureMismatch.function_mismatch
6767
-> t
6868

69+
val illtyped_variadic_dae :
70+
Location_span.t
71+
-> string
72+
-> UnsizedType.t list
73+
-> (UnsizedType.autodifftype * UnsizedType.t) list
74+
-> SignatureMismatch.function_mismatch
75+
-> t
76+
6977
val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t
7078
val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t
7179

src/frontend/Typechecker.ml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ let verify_name_fresh_udf loc tenv name =
124124
(* variadic functions are currently not in math sigs *)
125125
|| Stan_math_signatures.is_reduce_sum_fn name
126126
|| Stan_math_signatures.is_variadic_ode_fn name
127+
|| Stan_math_signatures.is_variadic_dae_fn name
127128
then Semantic_error.ident_is_stanmath_name loc name |> error
128129
else if Utils.is_unnormalized_distribution name then
129130
Semantic_error.udf_is_unnormalized_fn loc name |> error
@@ -525,11 +526,39 @@ let check_variadic_ode ~is_cond_dist loc id es =
525526
expected_args err
526527
|> error
527528

529+
let check_variadic_dae ~is_cond_dist loc id es =
530+
let optional_tol_mandatory_args =
531+
if Stan_math_signatures.is_variadic_dae_tol_fn id.name then
532+
Stan_math_signatures.variadic_dae_tol_arg_types
533+
else [] in
534+
let mandatory_arg_types =
535+
Stan_math_signatures.variadic_dae_mandatory_arg_types
536+
@ optional_tol_mandatory_args in
537+
match
538+
SignatureMismatch.check_variadic_args false mandatory_arg_types
539+
Stan_math_signatures.variadic_dae_mandatory_fun_args
540+
Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types es)
541+
with
542+
| Ok promotions ->
543+
mk_typed_expression
544+
~expr:
545+
(mk_fun_app ~is_cond_dist
546+
(StanLib FnPlain, id, SignatureMismatch.promote es promotions) )
547+
~ad_level:(expr_ad_lub es)
548+
~type_:Stan_math_signatures.variadic_dae_return_type ~loc
549+
| Error (expected_args, err) ->
550+
Semantic_error.illtyped_variadic_dae loc id.name
551+
(List.map ~f:type_of_expr_typed es)
552+
expected_args err
553+
|> error
554+
528555
let check_fn ~is_cond_dist loc tenv id es =
529556
if Stan_math_signatures.is_reduce_sum_fn id.name then
530557
check_reduce_sum ~is_cond_dist loc id es
531558
else if Stan_math_signatures.is_variadic_ode_fn id.name then
532559
check_variadic_ode ~is_cond_dist loc id es
560+
else if Stan_math_signatures.is_variadic_dae_fn id.name then
561+
check_variadic_dae ~is_cond_dist loc id es
533562
else check_fn ~is_cond_dist loc tenv id es
534563

535564
let rec check_funapp loc cf tenv ~is_cond_dist id tes =

src/middle/Stan_math_signatures.ml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,23 @@ let variadic_ode_mandatory_fun_args =
142142
let variadic_ode_fun_return_type = UnsizedType.UVector
143143
let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector
144144

145+
let variadic_dae_tol_arg_types =
146+
[ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal)
147+
; (DataOnly, UInt) ]
148+
149+
let variadic_dae_mandatory_arg_types =
150+
[ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *)
151+
(UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *)
152+
(AutoDiffable, UReal); (AutoDiffable, UArray UReal) ]
153+
154+
let variadic_dae_mandatory_fun_args =
155+
[ (UnsizedType.AutoDiffable, UnsizedType.UReal)
156+
; (UnsizedType.AutoDiffable, UnsizedType.UVector)
157+
; (UnsizedType.AutoDiffable, UnsizedType.UVector) ]
158+
159+
let variadic_dae_fun_return_type = UnsizedType.UVector
160+
let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector
161+
145162
let mk_declarative_sig (fnkinds, name, args, mem_pattern) =
146163
let is_glm = String.is_suffix ~suffix:"_glm" name in
147164
let sfxes = function
@@ -206,6 +223,13 @@ let is_variadic_ode_nonadjoint_tol_fn f =
206223
is_variadic_ode_nonadjoint_fn f
207224
&& String.is_suffix f ~suffix:ode_tolerances_suffix
208225

226+
let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"]
227+
let dae_tolerances_suffix = "_tol"
228+
let is_variadic_dae_fn f = Set.mem variadic_dae_fns f
229+
230+
let is_variadic_dae_tol_fn f =
231+
is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix
232+
209233
let distributions =
210234
[ ( full_lpmf
211235
, "beta_binomial"
@@ -355,6 +379,7 @@ let stan_math_returntype (name : string) (args : fun_arg list) =
355379
match name with
356380
| x when is_reduce_sum_fn x -> Some (UnsizedType.ReturnType UReal)
357381
| x when is_variadic_ode_fn x -> Some (UnsizedType.ReturnType (UArray UVector))
382+
| x when is_variadic_dae_fn x -> Some (UnsizedType.ReturnType (UArray UVector))
358383
| _ ->
359384
(* Return the least return type in case there are multiple options
360385
(due to implicit UInt-UReal conversion), where UInt<UReal
@@ -520,6 +545,7 @@ let query_stan_math_mem_pattern_support (name : string) (args : fun_arg list) =
520545
match name with
521546
| x when is_reduce_sum_fn x -> false
522547
| x when is_variadic_ode_fn x -> false
548+
| x when is_variadic_dae_fn x -> false
523549
| _ -> (
524550
(* let printer intro s = Set.Poly.iter ~f:(printf intro) s in*)
525551
match List.length filteredmatches = 0 with

src/stan_math_backend/Expression_gen.ml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,15 @@ let map_rect_calls = Int.Table.create ()
150150
let functor_suffix = "_functor__"
151151
let reduce_sum_functor_suffix = "_rsfunctor__"
152152
let variadic_ode_functor_suffix = "_odefunctor__"
153+
let variadic_dae_functor_suffix = "_daefunctor__"
153154

154155
let functor_suffix_select hof =
155156
match hof with
156157
| x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix
157158
| x when Stan_math_signatures.is_variadic_ode_fn x ->
158159
variadic_ode_functor_suffix
160+
| x when Stan_math_signatures.is_variadic_dae_fn x ->
161+
variadic_dae_functor_suffix
159162
| _ -> functor_suffix
160163

161164
let constraint_to_string = function
@@ -399,6 +402,18 @@ and gen_fun_app suffix ppf fname es mem_pattern
399402
, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b
400403
:: rel_tol_q :: abs_tol_q :: max_num_steps :: num_checkpoints
401404
:: interpolation_polynomial :: solver_f :: solver_b :: msgs :: tl )
405+
| ( true
406+
, x
407+
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl )
408+
when Stan_math_signatures.is_variadic_dae_fn x
409+
&& String.is_suffix fname
410+
~suffix:Stan_math_signatures.dae_tolerances_suffix ->
411+
( fname
412+
, f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps
413+
:: msgs :: tl )
414+
| true, x, f :: yy0 :: yp0 :: t0 :: ts :: tl
415+
when Stan_math_signatures.is_variadic_dae_fn x ->
416+
(fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl)
402417
| ( true
403418
, "map_rect"
404419
, {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _}

src/stan_math_backend/Stan_math_code_gen.ml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ let mk_extra_args templates args =
200200
printing user defined distributions vs rngs vs regular functions.
201201
*)
202202
let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
203-
funs_used_in_reduce_sum funs_used_in_variadic_ode =
203+
funs_used_in_reduce_sum funs_used_in_variadic_ode funs_used_in_variadic_dae
204+
=
204205
let extra, extra_templates =
205206
match fdsuffix with
206207
| Fun_kind.FnTarget -> (["lp__"; "lp_accum__"], ["T_lp__"; "T_lp_accum__"])
@@ -239,6 +240,7 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
239240
match variadic with
240241
| `ReduceSum -> List.split_n args 3
241242
| `VariadicODE -> List.split_n args 2
243+
| `VariadicDAE -> List.split_n args 3
242244
| `None -> (args, []) in
243245
let arg_strs =
244246
args
@@ -256,7 +258,8 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
256258
match variadic with
257259
| `None -> functor_suffix
258260
| `ReduceSum -> reduce_sum_functor_suffix
259-
| `VariadicODE -> variadic_ode_functor_suffix in
261+
| `VariadicODE -> variadic_ode_functor_suffix
262+
| `VariadicDAE -> variadic_dae_functor_suffix in
260263
let pp_template_propto ppf () =
261264
match (fdsuffix, variadic) with
262265
| FnLpdf _, `ReduceSum -> pf ppf "template <bool propto__>@ "
@@ -287,6 +290,10 @@ let pp_fun_def ppf Program.{fdrt; fdname; fdsuffix; fdargs; fdbody; _}
287290
(* Produces the variadic ode functors that has the pstream argument
288291
as the third and not last argument *)
289292
pp_functor ppf ([], fdargs, `VariadicODE)
293+
else if String.Set.mem funs_used_in_variadic_dae fdname then
294+
(* Produces the variadic DAE functors that has the pstream argument
295+
as the fourth and not last argument *)
296+
pp_functor ppf ([], fdargs, `VariadicDAE)
290297

291298
(** Creates functions outside the model namespaces which only call the ones
292299
inside the namespaces *)
@@ -943,6 +950,7 @@ let pp_prog ppf (p : Program.Typed.t) =
943950
pp_fun_def ppf fblock
944951
(is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p)
945952
(is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p)
953+
(is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p)
946954
in
947955
let reduce_sum_struct_decls =
948956
String.Set.map
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
functions {
2+
vector chem_dae(real t, vector yy, vector yp,
3+
array[] real p) {
4+
vector[3] res;
5+
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
6+
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
7+
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
8+
return res;
9+
}
10+
}
11+
data {
12+
vector[3] yy0;
13+
vector[3] yp0;
14+
real t0;
15+
array[1] real x;
16+
array[4] vector[3] y;
17+
}
18+
transformed data {
19+
array[4] real ts;
20+
array[2] real a;
21+
}
22+
parameters {
23+
array[3] real theta;
24+
real<lower=0> sigma;
25+
}
26+
transformed parameters {
27+
array[4] vector[3] y_hat;
28+
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, a, 100, theta);
29+
}
30+
model {
31+
for (t in 1 : 4)
32+
y[t] ~ normal(y_hat[t], sigma);
33+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
functions {
2+
vector chem_dae(real t, vector yy, vector yp,
3+
array[] real p) {
4+
vector[3] res;
5+
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
6+
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
7+
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
8+
return res;
9+
}
10+
}
11+
data {
12+
vector[3] yy0;
13+
real t0;
14+
array[1] real x;
15+
array[4] vector[3] y;
16+
}
17+
transformed data {
18+
array[4] real ts;
19+
}
20+
parameters {
21+
real yp0;
22+
array[3] real theta;
23+
real<lower=0> sigma;
24+
}
25+
transformed parameters {
26+
array[4] vector[3] y_hat;
27+
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta);
28+
}
29+
model {
30+
for (t in 1 : 4)
31+
y[t] ~ normal(y_hat[t], sigma);
32+
}
33+
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
functions {
2+
vector chem_dae(real t, vector yy, vector yp,
3+
array[] real p) {
4+
vector[3] res;
5+
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
6+
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
7+
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
8+
return res;
9+
}
10+
}
11+
data {
12+
real yy0;
13+
vector[3] yp0;
14+
real t0;
15+
array[1] real x;
16+
array[4] vector[3] y;
17+
}
18+
transformed data {
19+
array[4] real ts;
20+
}
21+
parameters {
22+
array[3] real theta;
23+
real<lower=0> sigma;
24+
}
25+
transformed parameters {
26+
array[4] vector[3] y_hat;
27+
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta);
28+
}
29+
model {
30+
for (t in 1 : 4)
31+
y[t] ~ normal(y_hat[t], sigma);
32+
}
33+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
functions {
2+
vector chem_dae(real t, vector yy, vector yp,
3+
array[] real p) {
4+
vector[3] res;
5+
res[1] = yp[1] + p[1] * yy[1] - p[2] * yy[2] * yy[3];
6+
res[2] = yp[2] - p[1] * yy[1] + p[2] * yy[2] * yy[3] + p[3] * yy[2] * yy[2];
7+
res[3] = yy[1] + yy[2] + yy[3] - 1.0;
8+
return res;
9+
}
10+
}
11+
data {
12+
vector[3] yy0;
13+
vector[3] yp0;
14+
vector[2] t0;
15+
array[1] real x;
16+
array[4] vector[3] y;
17+
}
18+
transformed data {
19+
array[4] real ts;
20+
}
21+
parameters {
22+
array[3] real theta;
23+
real<lower=0> sigma;
24+
}
25+
transformed parameters {
26+
array[4] vector[3] y_hat;
27+
y_hat = dae_tol(chem_dae, yy0, yp0, t0, ts, 0.01, 0.001, 100, theta);
28+
}
29+
model {
30+
for (t in 1 : 4)
31+
y[t] ~ normal(y_hat[t], sigma);
32+
}

0 commit comments

Comments
 (0)