Skip to content

Commit 62af705

Browse files
authored
Merge pull request #1599 from stan-dev/laplace-signature-tweaks
Update laplace signatures to have hessian_block_size before covariance function
2 parents 6014db9 + 5ec9488 commit 62af705

59 files changed

Lines changed: 689 additions & 417 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/frontend/Semantic_error.ml

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ module TypeError = struct
4949
| IllTypedLaplaceMarginal of string * bool * UnsizedType.argumentlist
5050
| LaplaceCompatibilityIssue of string
5151
| IlltypedLaplaceTooMany of string * int
52+
| IlltypedLaplaceHessianBlockSize of
53+
string * (UnsizedType.autodifftype * UnsizedType.t) option
5254
| IlltypedLaplaceTolArgs of string * SignatureMismatch.function_mismatch
5355
| AmbiguousFunctionPromotion of
5456
string
@@ -80,14 +82,38 @@ module TypeError = struct
8082
| 1 -> "first element of the control parameter tuple (initial guess)"
8183
| 2 -> "second element of the control parameter tuple (tolerance)"
8284
| 3 -> "third element of the control parameter tuple (max_num_steps)"
83-
| 4 -> "fourth element of the control parameter tuple (hessian_block_size)"
84-
| 5 -> "fifth element of the control parameter tuple (solver)"
85-
| 6 ->
86-
"sixth element of the control parameter tuple (max_steps_line_search)"
87-
| 7 -> "seventh element of the control parameter tuple (allow_fallthrough)"
85+
| 4 -> "fourth element of the control parameter tuple (solver)"
86+
| 5 ->
87+
"fifth element of the control parameter tuple (max_steps_line_search)"
88+
| 6 -> "sixth element of the control parameter tuple (allow_fallthrough)"
8889
| n ->
8990
Fmt.str "%a element of the control parameter tuple" (Fmt.ordinal ()) n
9091

92+
let generic_laplace_usage info ppf (name, supplied) =
93+
let req = Stan_math_signatures.laplace_helper_param_types name in
94+
let is_helper = not @@ List.is_empty req in
95+
let pp_lik_args ppf =
96+
if is_helper then Fmt.(list ~sep:comma UnsizedType.pp_fun_arg) ppf req
97+
else Fmt.pf ppf "(vector, T_l%t) => real,@ tuple(T_l%t)" ellipsis ellipsis
98+
in
99+
let pp_laplace_tols ppf =
100+
if String.is_substring ~substring:"_tol" name then
101+
Fmt.pf ppf ", %a"
102+
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
103+
Stan_math_signatures.laplace_tolerance_argument_types in
104+
let pp_supplied_tys ppf =
105+
if List.is_empty supplied then Fmt.nop ppf ()
106+
else
107+
Fmt.pf ppf "@ However, we received the types:@ @[<hov 2>(%a)@]"
108+
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
109+
supplied in
110+
Fmt.pf ppf
111+
"@[<v>Ill-typed arguments supplied to function %a.@ The valid signature \
112+
of this function is@ @[<hov 2>%s(%t,@ data int,@ (T_k%t) => matrix,@ \
113+
tuple(T_k%t)%t)@]%t@ @[%a@]@]"
114+
quoted name name pp_lik_args ellipsis ellipsis pp_laplace_tols
115+
pp_supplied_tys info ()
116+
91117
let rec expected_types : UnsizedType.t Common.Nonempty_list.t Fmt.t =
92118
let ust = expected_style UnsizedType.pp in
93119
fun ppf l ->
@@ -193,39 +219,24 @@ module TypeError = struct
193219
details
194220
Fmt.(list ~sep:comma (expected_style UnsizedType.pp_fun_arg))
195221
expected
196-
| IllTypedLaplaceMarginal (name, early, supplied) ->
197-
let req = Stan_math_signatures.laplace_helper_param_types name in
198-
let is_helper = not @@ List.is_empty req in
199-
let info =
200-
if early then
222+
| IllTypedLaplaceMarginal (name, true, supplied) ->
223+
let info ppf () =
224+
Fmt.text ppf
201225
"We were unable to start more in-depth checking. Please ensure you \
202226
are passing enough arguments and that the first argument is a \
203-
function."
204-
else
205-
let n = if is_helper then List.length req else 2 in
206-
Fmt.str
207-
"Typechecking failed after checking the first %d arguments. \
208-
Please ensure you are passing enough arguments and that the %a \
209-
is a function."
210-
n (Fmt.ordinal ()) (n + 1) in
211-
let pp_lik_args ppf =
212-
if is_helper then Fmt.(list ~sep:comma UnsizedType.pp_fun_arg) ppf req
213-
else
214-
Fmt.pf ppf "(vector, T_l%t) => real,@ tuple(T_l%t)" ellipsis
215-
ellipsis in
216-
let pp_laplace_tols ppf =
217-
if String.is_substring ~substring:"_tol" name then
218-
Fmt.pf ppf ", %a"
219-
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
220-
Stan_math_signatures.laplace_tolerance_argument_types in
221-
Fmt.pf ppf
222-
"@[<v>Ill-typed arguments supplied to function %a.@ The valid \
223-
signature of this function is@ @[<hov 2>%s(%t,@ vector,@ (T_k%t) => \
224-
matrix,@ tuple(T_k%t)%t)@]@ However, we received the types:@ @[<hov \
225-
2>(%a)@]@ @[%a@]@]"
226-
quoted name name pp_lik_args ellipsis ellipsis pp_laplace_tols
227-
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
228-
supplied Fmt.text info
227+
function." in
228+
generic_laplace_usage info ppf (name, supplied)
229+
| IllTypedLaplaceMarginal (name, false, supplied) ->
230+
let req = Stan_math_signatures.laplace_helper_param_types name in
231+
let is_helper = not @@ List.is_empty req in
232+
let info ppf () =
233+
let n = (if is_helper then List.length req else 2) + 1 in
234+
Fmt.pf ppf
235+
"Typechecking failed after checking the first %d arguments.@ \
236+
Please ensure you are passing enough arguments and that the %a is \
237+
a function."
238+
n (Fmt.ordinal ()) (n + 1) in
239+
generic_laplace_usage info ppf (name, supplied)
229240
| LaplaceCompatibilityIssue banned_function ->
230241
Fmt.pf ppf
231242
"The function %a, called by this likelihood function,@ does not \
@@ -239,6 +250,28 @@ module TypeError = struct
239250
"Only a single tuple of control parameters is expected."
240251
else if n_args = 1 then "Did you mean to call the _tol version?"
241252
else "Did you mean to call the _tol version with a tuple of these?")
253+
| IlltypedLaplaceHessianBlockSize (name, None) ->
254+
let info ppf () =
255+
Fmt.pf ppf
256+
"@[<hov>Missing the hessian block size (data-only %a) and \
257+
remaining arguments.@]"
258+
(expected_style UnsizedType.pp)
259+
UInt in
260+
generic_laplace_usage info ppf (name, [])
261+
| IlltypedLaplaceHessianBlockSize (name, Some (DataOnly, ty)) ->
262+
Fmt.pf ppf
263+
"@[<hov>The hessian block size argument to %a must be a data-only \
264+
%a.%a@]"
265+
quoted name
266+
(expected_style UnsizedType.pp)
267+
UInt found_type ty
268+
| IlltypedLaplaceHessianBlockSize (name, Some (_, ty)) ->
269+
Fmt.pf ppf
270+
"@[<hov>The hessian block size argument to %a must be a data-only \
271+
%a.%a@ %a@]"
272+
quoted name
273+
(expected_style UnsizedType.pp)
274+
UInt found_type ty SignatureMismatch.data_only_msg ()
242275
| IlltypedLaplaceTolArgs (name, ArgNumMismatch (_, 0)) ->
243276
Fmt.pf ppf
244277
"Missing control parameter tuple at the end of the call to %a.@ \
@@ -777,6 +810,9 @@ let laplace_compatibility loc banned_function =
777810
let illtyped_laplace_extra_args loc name args =
778811
(loc, TypeError (TypeError.IlltypedLaplaceTooMany (name, args)))
779812

813+
let illtyped_laplace_hessian_block_size_arg loc name arg_ty =
814+
(loc, TypeError (TypeError.IlltypedLaplaceHessianBlockSize (name, arg_ty)))
815+
780816
let illtyped_laplace_tolerance_args loc name mismatch =
781817
(loc, TypeError (TypeError.IlltypedLaplaceTolArgs (name, mismatch)))
782818

src/frontend/Semantic_error.mli

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ val illtyped_laplace_generic :
9595
val laplace_compatibility : Location_span.t -> string -> t
9696
val illtyped_laplace_extra_args : Location_span.t -> string -> int -> t
9797

98+
val illtyped_laplace_hessian_block_size_arg :
99+
Location_span.t
100+
-> string
101+
-> (UnsizedType.autodifftype * UnsizedType.t) option
102+
-> t
103+
98104
val illtyped_laplace_tolerance_args :
99105
Location_span.t -> string -> SignatureMismatch.function_mismatch -> t
100106

src/frontend/Typechecker.ml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,23 @@ and check_laplace_fn ~is_cond_dist loc cf tenv id tes =
864864
(UnsizedType.ReturnType UReal) in
865865
([lik_fun; lik_tupl], tes)
866866
| _ -> generic_failure ~early:true () in
867+
(* check hessian block size *)
868+
let hbs_arg, rest =
869+
let loc =
870+
match List.last lik_args with
871+
| Some e -> {e.emeta.loc with begin_loc= e.emeta.loc.end_loc}
872+
| None -> loc in
873+
match rest with
874+
| hbs :: rest ->
875+
let hbs_ty = arg_type hbs in
876+
if hbs_ty <> UnsizedType.(DataOnly, UInt) then
877+
Semantic_error.illtyped_laplace_hessian_block_size_arg hbs.emeta.loc
878+
id.name (Some hbs_ty)
879+
|> error
880+
else (hbs, rest)
881+
| _ ->
882+
Semantic_error.illtyped_laplace_hessian_block_size_arg loc id.name None
883+
|> error in
867884
(* Check the remaining arguments: initial guess, covariance, and tolerances *)
868885
match rest with
869886
| {expr= Variable cov_fun; _} :: cov_tupl :: control_args ->
@@ -875,7 +892,8 @@ and check_laplace_fn ~is_cond_dist loc cf tenv id tes =
875892
probably require two more calls to
876893
[check_function_callable_with_tuple] *)
877894
verify_laplace_control_args loc id control_args;
878-
let args = lik_args @ (cov_fun_type :: cov_tupl :: control_args) in
895+
let args =
896+
lik_args @ (hbs_arg :: cov_fun_type :: cov_tupl :: control_args) in
879897
let return_type =
880898
if String.is_suffix id.name ~suffix:"_rng" then UnsizedType.UVector
881899
else UnsizedType.UReal in

src/middle/UnsizedType.ml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ let rec wind_array_type = function
6565
| typ, 0 -> typ
6666
| typ, n -> wind_array_type (UArray typ, n - 1)
6767

68+
let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false
69+
6870
let rec pp ppf = function
6971
| UInt -> Fmt.string ppf "int"
7072
| UReal -> Fmt.string ppf "real"
@@ -91,7 +93,10 @@ let rec pp ppf = function
9193

9294
and pp_fun_arg ppf (ad_ty, unsized_ty) =
9395
let open Fmt in
94-
let pp_data = if' (equal_autodifftype ad_ty DataOnly) (any "data ") in
96+
let pp_data =
97+
if'
98+
(equal_autodifftype ad_ty DataOnly && not (is_fun_type unsized_ty))
99+
(any "data ") in
95100
(pp_data ++ pp) ppf unsized_ty
96101

97102
and pp_returntype ppf = function
@@ -236,8 +241,6 @@ let is_eigen_type ut =
236241
true
237242
| _ -> false
238243

239-
let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false
240-
241244
(** Detect if type contains an integer *)
242245
let rec contains_int ut =
243246
match ut with

src/stan_math_signatures/Generate.ml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,17 @@ let () =
12081208
, ReturnType UReal
12091209
, [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix]
12101210
, AoS );
1211+
List.iter [UnsizedType.UInt; UVector] ~f:(fun t ->
1212+
add_unqualified
1213+
( "generate_laplace_options"
1214+
, ReturnType
1215+
(UTuple
1216+
[ UVector (* theta_0 *); UReal (* tolerance *)
1217+
; UInt (* max_num_steps *); UInt (* solver *)
1218+
; UInt (* max_steps_line_search *); UInt (* allow_fallthrough *)
1219+
])
1220+
, [t]
1221+
, AoS ));
12111222
add_unqualified
12121223
("gp_dot_prod_cov", ReturnType UMatrix, [UArray UReal; UReal], AoS);
12131224
add_unqualified

src/stan_math_signatures/Stan_math_signatures.ml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ let laplace_helper_param_types name =
197197
let laplace_tolerance_argument_types =
198198
UnsizedType.
199199
[ (AutoDiffable, UVector) (* theta_0 *); (DataOnly, UReal) (* tolerance *)
200-
; (DataOnly, UInt) (* max_num_steps *)
201-
; (DataOnly, UInt) (* hessian_block_size *); (DataOnly, UInt) (* solver *)
200+
; (DataOnly, UInt) (* max_num_steps *); (DataOnly, UInt) (* solver *)
202201
; (DataOnly, UInt) (* max_steps_line_search *)
203202
; (DataOnly, UInt) (* allow_fallthrough *) ]
204203

@@ -208,7 +207,7 @@ let is_special_function_name name =
208207
|| is_embedded_laplace_fn name
209208

210209
let disallowed_second_order =
211-
[ "algebra_solver"; "algebra_solver_newton"; "integrate_ode"
210+
[ "algebra_solver"; "algebra_solver_newton"; "integrate_1d"; "integrate_ode"
212211
; "integrate_ode_adams"; "integrate_ode_bdf"; "integrate_ode_rk45"; "map_rect"
213212
; "hmm_marginal"; "hmm_hidden_state_prob" ]
214213
|> String.Set.of_list

test/integration/bad/embedded_laplace/autodiff_incompatibility1.stan

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,6 @@ parameters {
5353
}
5454
model {
5555

56-
target += laplace_marginal(ll_function, (eta, log_ye, y),
56+
target += laplace_marginal(ll_function, (eta, log_ye, y), 1,
5757
K_function, (x, n_obs, alpha, rho));
5858
}

test/integration/bad/embedded_laplace/autodiff_incompatibility2.stan

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,5 @@ parameters {
5757

5858
generated quantities {
5959
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
60-
K_function, (x, n_obs, alpha, rho));
60+
1, K_function, (x, n_obs, alpha, rho));
6161
}

test/integration/bad/embedded_laplace/autodiff_incompatibility3.stan

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,5 @@ parameters {
4949

5050
generated quantities {
5151
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
52-
K_function, (x, n_obs, alpha, rho));
52+
1, K_function, (x, n_obs, alpha, rho));
5353
}

test/integration/bad/embedded_laplace/autodiff_incompatibility4.stan

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ functions {
1616
array[] int y) {
1717
// observed count
1818
return neg_binomial_2_lpmf(y | exp(log_ye + theta), eta) +
19-
// integrate 1d is itself allowed, actually
19+
// integrate 1d SHOULD be allowed,
2020
// see https://github.com/stan-dev/math/pull/2929
21+
// but there is a bug: https://github.com/stan-dev/math/issues/3280
2122
integrate_1d(integrand, 0, 1, y, y, y);
2223
}
2324

@@ -48,5 +49,5 @@ parameters {
4849

4950
generated quantities {
5051
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
51-
K_function, (x, n_obs, alpha, rho));
52+
1, K_function, (x, n_obs, alpha, rho));
5253
}

0 commit comments

Comments
 (0)