Skip to content

Commit 65b9c6a

Browse files
authored
Merge pull request #1046 from WardBrian/lpdf_typing
User-defined densities allowed over nested and complex values.
2 parents 88d5f33 + 43b8cb7 commit 65b9c6a

6 files changed

Lines changed: 33 additions & 20 deletions

File tree

src/frontend/Deprecation_analysis.ml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,8 @@ let update_suffix name type_ =
8282
let open String in
8383
if is_suffix ~suffix:"_cdf_log" name then drop_suffix name 8 ^ "_lcdf"
8484
else if is_suffix ~suffix:"_ccdf_log" name then drop_suffix name 9 ^ "_lccdf"
85-
else if Middle.UnsizedType.is_real_type type_ then
86-
drop_suffix name 4 ^ "_lpdf"
87-
else drop_suffix name 4 ^ "_lpmf"
85+
else if Middle.UnsizedType.is_int_type type_ then drop_suffix name 4 ^ "_lpmf"
86+
else drop_suffix name 4 ^ "_lpdf"
8887

8988
let find_udf_log_suffix = function
9089
| { stmt=

src/frontend/Typechecker.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ let context block =
5959
let calculate_autodifftype cf origin ut =
6060
match origin with
6161
| Env.(Param | TParam | Model | Functions)
62-
when not (UnsizedType.contains_int ut || cf.current_block = GQuant) ->
62+
when not (UnsizedType.is_int_type ut || cf.current_block = GQuant) ->
6363
UnsizedType.AutoDiffable
6464
| _ -> DataOnly
6565

@@ -1137,7 +1137,7 @@ and verify_transformed_param_ty loc cf is_global unsized_ty =
11371137
if
11381138
is_global
11391139
&& (cf.current_block = Param || cf.current_block = TParam)
1140-
&& UnsizedType.contains_int unsized_ty
1140+
&& UnsizedType.is_int_type unsized_ty
11411141
then Semantic_error.transformed_params_int loc |> error
11421142

11431143
and check_sizedtype cf tenv sizedty =
@@ -1268,7 +1268,7 @@ and verify_pdf_fundef_first_arg_ty loc id arg_tys =
12681268
if String.is_suffix id.name ~suffix:"_lpdf" then
12691269
let rt = List.hd arg_tys |> Option.map ~f:snd in
12701270
match rt with
1271-
| Some rt when UnsizedType.is_real_type rt -> ()
1271+
| Some rt when not (UnsizedType.is_int_type rt) -> ()
12721272
| _ -> Semantic_error.prob_density_non_real_variate loc rt |> error
12731273

12741274
and verify_pmf_fundef_first_arg_ty loc id arg_tys =

src/middle/UnsizedType.ml

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,9 @@ let rec common_type = function
130130
| _, _ -> None
131131

132132
(* -- Helpers -- *)
133-
let is_real_type = function
134-
| UReal | UVector | URowVector | UMatrix
135-
|UArray UReal
136-
|UArray UVector
137-
|UArray URowVector
138-
|UArray UMatrix ->
139-
true
133+
let rec is_real_type = function
134+
| UReal | UVector | URowVector | UMatrix -> true
135+
| UArray x -> is_real_type x
140136
| _ -> false
141137

142138
let rec is_autodiffable = function
@@ -145,17 +141,15 @@ let rec is_autodiffable = function
145141
| _ -> false
146142

147143
let is_scalar_type = function UReal | UInt -> true | _ -> false
148-
let is_int_type = function UInt | UArray UInt -> true | _ -> false
144+
145+
let rec is_int_type ut =
146+
match ut with UInt -> true | UArray ut -> is_int_type ut | _ -> false
149147

150148
let is_eigen_type ut =
151149
match ut with UVector | URowVector | UMatrix -> true | _ -> false
152150

153151
let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false
154152

155-
(** Detect if type contains an integer *)
156-
let rec contains_int ut =
157-
match ut with UInt -> true | UArray ut -> contains_int ut | _ -> false
158-
159153
let rec is_indexing_matrix = function
160154
| UArray t, _ :: idcs -> is_indexing_matrix (t, idcs)
161155
| UMatrix, [] -> false

src/stan_math_backend/Stan_math_code_gen.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ let pp_located ppf _ =
5656
(** Detect if argument requires C++ template *)
5757
let arg_needs_template = function
5858
| UnsizedType.DataOnly, _, t -> UnsizedType.is_eigen_type t
59-
| _, _, t when UnsizedType.contains_int t -> false
59+
| _, _, t when UnsizedType.is_int_type t -> false
6060
| _ -> true
6161

6262
(** Print template arguments for C++ functions that need templates
@@ -107,7 +107,7 @@ let pp_promoted_scalar ppf args =
107107
let pp_returntype ppf arg_types rt =
108108
let scalar = str "%a" pp_promoted_scalar arg_types in
109109
match rt with
110-
| Some ut when UnsizedType.contains_int ut ->
110+
| Some ut when UnsizedType.is_int_type ut ->
111111
pf ppf "%a@," pp_unsizedtype_custom_scalar ("int", ut)
112112
| Some ut -> pf ppf "%a@," pp_unsizedtype_custom_scalar (scalar, ut)
113113
| None -> pf ppf "void@,"

test/integration/good/fun-defs-lpdf.stan

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,20 @@ functions {
22
real bar_baz_lpdf(real a, real b) {
33
return a / b;
44
}
5+
real foo_bar_lpdf(array[,,,] real x){
6+
return 1.0;
7+
}
8+
real baz_foo_lpdf(complex z, real a){
9+
return get_imag(z) * a;
10+
}
511
}
612
parameters {
713
real y;
14+
complex z;
15+
array[1,1,1,1] real arr;
816
}
917
model {
1018
y ~ bar_baz(3.2);
19+
z ~ baz_foo(1.5);
20+
arr ~ foo_bar();
1121
}

test/integration/good/pretty.expected

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,12 +3097,22 @@ functions {
30973097
real bar_baz_lpdf(real a, real b) {
30983098
return a / b;
30993099
}
3100+
real foo_bar_lpdf(array[,,,] real x) {
3101+
return 1.0;
3102+
}
3103+
real baz_foo_lpdf(complex z, real a) {
3104+
return get_imag(z) * a;
3105+
}
31003106
}
31013107
parameters {
31023108
real y;
3109+
complex z;
3110+
array[1, 1, 1, 1] real arr;
31033111
}
31043112
model {
31053113
y ~ bar_baz(3.2);
3114+
z ~ baz_foo(1.5);
3115+
arr ~ foo_bar();
31063116
}
31073117

31083118
$ ../../../../install/default/bin/stanc --auto-format fun-return-typ1.stan

0 commit comments

Comments
 (0)