Skip to content

Commit 2bfbfc6

Browse files
committed
Cleanup per review
1 parent ebb54ee commit 2bfbfc6

10 files changed

Lines changed: 208 additions & 188 deletions

File tree

src/frontend/Semantic_error.ml

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,17 @@ module TypeError = struct
137137
( name
138138
, arg_tys
139139
, ([((UnsizedType.ReturnType return_type, args), error)], false) )
140-
| AmbiguousFunctionPromotion (s, arg_tys, signatures) ->
140+
| AmbiguousFunctionPromotion (name, arg_tys, signatures) ->
141141
let pp_sig ppf (rt, args) =
142142
Fmt.pf ppf "@[<hov>(@[<hov>%a@]) => %a@]"
143143
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
144144
args UnsizedType.pp_returntype rt in
145145
Fmt.pf ppf
146-
"No unique minimum promotion found for function \"%s\".@ Overloaded \
146+
"No unique minimum promotion found for function '%s'.@ Overloaded \
147147
functions must not have multiple equally valid promotion paths.@ %a \
148148
function has several:@ @[<v>%a@]@ Consider defining a new signature \
149149
for the exact types needed or@ re-thinking existing definitions."
150-
s
150+
name
151151
(Fmt.option
152152
~none:(fun ppf () -> Fmt.pf ppf "This")
153153
(fun ppf tys ->
@@ -362,7 +362,7 @@ module StatementError = struct
362362
| TransformedParamsInt
363363
| FuncOverloadRtOnly of
364364
string * UnsizedType.returntype * UnsizedType.returntype
365-
| FuncDeclRedefined of string * UnsizedType.t
365+
| FuncDeclRedefined of string * UnsizedType.t * bool
366366
| FunDeclExists of string
367367
| FunDeclNoDefn
368368
| FunDeclNeedsBlock
@@ -436,9 +436,11 @@ module StatementError = struct
436436
"Function '%s' cannot be overloaded by return type only. Previously \
437437
used return type %a"
438438
name UnsizedType.pp_returntype rt'
439-
| FuncDeclRedefined (name, ut) ->
440-
Fmt.pf ppf "Function '%s' has already been declared to for signature %a"
441-
name UnsizedType.pp ut
439+
| FuncDeclRedefined (name, ut, stan_math) ->
440+
Fmt.pf ppf "Function '%s' %s signature %a" name
441+
( if stan_math then "is already declared in the Stan Math library with"
442+
else "has already been declared to for" )
443+
UnsizedType.pp ut
442444
| FunDeclExists name ->
443445
Fmt.pf ppf
444446
"Function '%s' has already been declared. A definition is expected."
@@ -701,8 +703,8 @@ let transformed_params_int loc =
701703
let fn_overload_rt_only loc name rt1 rt2 =
702704
StatementError (loc, StatementError.FuncOverloadRtOnly (name, rt1, rt2))
703705

704-
let fn_decl_redefined loc name ut =
705-
StatementError (loc, StatementError.FuncDeclRedefined (name, ut))
706+
let fn_decl_redefined loc name ~stan_math ut =
707+
StatementError (loc, StatementError.FuncDeclRedefined (name, ut, stan_math))
706708

707709
let fn_decl_exists loc name =
708710
StatementError (loc, StatementError.FunDeclExists name)

src/frontend/Semantic_error.mli

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@ val fn_overload_rt_only :
142142
-> UnsizedType.returntype
143143
-> t
144144

145-
val fn_decl_redefined : Location_span.t -> string -> UnsizedType.t -> t
145+
val fn_decl_redefined :
146+
Location_span.t -> string -> stan_math:bool -> UnsizedType.t -> t
147+
146148
val fn_decl_exists : Location_span.t -> string -> t
147149
val fn_decl_without_def : Location_span.t -> t
148150
val fn_decl_needs_block : Location_span.t -> t

src/frontend/SignatureMismatch.ml

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,19 @@ let promotion_cost p =
205205
| RealToComplexPromotion | IntToRealPromotion -> 1
206206
| IntToComplexPromotion -> 2
207207

208-
let unique_minimum_promotion ps =
208+
let unique_minimum_promotion promotion_options =
209209
let size (_, p) =
210210
List.fold ~init:0 ~f:(fun acc p -> acc + promotion_cost p) p in
211-
let sizes = List.map ~f:size ps in
212-
let min_promotions = List.min_elt ~compare:Int.compare sizes in
213-
let ps = List.zip_exn sizes ps in
214-
match min_promotions with
215-
| Some n -> (
211+
let sizes = List.map ~f:size promotion_options in
212+
let min_promotion = List.min_elt ~compare:Int.compare sizes in
213+
let sizes_and_promotons = List.zip_exn sizes promotion_options in
214+
match min_promotion with
215+
| Some min_depth -> (
216216
match
217-
List.filter_map ~f:(fun (x, y) -> if x = n then Some y else None) ps
217+
List.filter_map
218+
~f:(fun (depth, promotion) ->
219+
if depth = min_depth then Some promotion else None )
220+
sizes_and_promotons
218221
with
219222
| [ans] -> Ok ans
220223
| _ :: _ as lst -> Error (Some (List.map ~f:fst lst))
@@ -227,8 +230,8 @@ let matching_function env name args =
227230
let function_types =
228231
Environment.find env name
229232
|> List.filter_map ~f:extract_function_types
230-
|> List.sort ~compare:(fun (x, _, _, _) (y, _, _, _) ->
231-
UnsizedType.compare_returntype x y ) in
233+
|> List.sort ~compare:(fun (ret1, _, _, _) (ret2, _, _, _) ->
234+
UnsizedType.compare_returntype ret1 ret2 ) in
232235
let matches, errors =
233236
List.partition_map function_types
234237
~f:(fun (rt, tys, funkind_constructor, _) ->
@@ -238,7 +241,9 @@ let matching_function env name args =
238241
match unique_minimum_promotion matches with
239242
| Ok (((rt, _), funkind_constructor), p) ->
240243
UniqueMatch (rt, funkind_constructor, p)
241-
| Error (Some e) -> AmbiguousMatch (List.map ~f:fst e)
244+
| Error (Some e) ->
245+
AmbiguousMatch (List.map ~f:fst e)
246+
(* return the return types and argument types of ambiguous matches *)
242247
| Error None ->
243248
let errors =
244249
List.sort errors ~compare:(fun (_, e1) (_, e2) -> compare_errors e1 e2)

src/frontend/SignatureMismatch.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ val matching_function :
6060
-> string
6161
-> (UnsizedType.autodifftype * UnsizedType.t) list
6262
-> match_result
63+
(** Searches for a function of the given name which can
64+
support the required argument types.
65+
Requires a unique minimum option under type promotion
66+
*)
6367

6468
val check_variadic_args :
6569
bool

src/frontend/Typechecker.ml

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -434,35 +434,34 @@ let check_normal_fn ~is_cond_dist loc tenv id es =
434434
(Env.nearest_ident tenv id.name) )
435435
|> error
436436
| _ (* a function *) -> (
437-
let
438-
(* NB: At present, [SignatureMismatch.matching_function] cannot handle overloaded function types.
439-
This is not needed until UDFs can be higher-order, as it is special cased for
440-
variadic functions
441-
*)
442-
open
443-
SignatureMismatch in
444-
match matching_function tenv id.name (get_arg_types es) with
445-
| UniqueMatch (Void, _, _) ->
446-
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
447-
|> error
448-
| UniqueMatch (ReturnType ut, fnk, promotions) ->
449-
mk_typed_expression
450-
~expr:
451-
(mk_fun_app ~is_cond_dist
452-
( fnk (Fun_kind.suffix_from_name id.name)
453-
, id
454-
, SignatureMismatch.promote es promotions ) )
455-
~ad_level:(expr_ad_lub es) ~type_:ut ~loc
456-
| AmbiguousMatch sigs ->
457-
Semantic_error.ambiguous_function_promotion loc id.name
458-
(Some (List.map ~f:type_of_expr_typed es))
459-
sigs
460-
|> error
461-
| SignatureErrors (l, b) ->
462-
es
463-
|> List.map ~f:(fun e -> e.emeta.type_)
464-
|> Semantic_error.illtyped_fn_app loc id.name (l, b)
465-
|> error )
437+
(* NB: At present, [SignatureMismatch.matching_function] cannot handle overloaded function types.
438+
This is not needed until UDFs can be higher-order, as it is special cased for
439+
variadic functions
440+
*)
441+
match
442+
SignatureMismatch.matching_function tenv id.name (get_arg_types es)
443+
with
444+
| UniqueMatch (Void, _, _) ->
445+
Semantic_error.returning_fn_expected_nonreturning_found loc id.name
446+
|> error
447+
| UniqueMatch (ReturnType ut, fnk, promotions) ->
448+
mk_typed_expression
449+
~expr:
450+
(mk_fun_app ~is_cond_dist
451+
( fnk (Fun_kind.suffix_from_name id.name)
452+
, id
453+
, SignatureMismatch.promote es promotions ) )
454+
~ad_level:(expr_ad_lub es) ~type_:ut ~loc
455+
| AmbiguousMatch sigs ->
456+
Semantic_error.ambiguous_function_promotion loc id.name
457+
(Some (List.map ~f:type_of_expr_typed es))
458+
sigs
459+
|> error
460+
| SignatureErrors (l, b) ->
461+
es
462+
|> List.map ~f:(fun e -> e.emeta.type_)
463+
|> Semantic_error.illtyped_fn_app loc id.name (l, b)
464+
|> error )
466465

467466
(** Given a constraint function [matches], find any signature which exists
468467
Returns the first [Ok] if any exist, or else [Error]
@@ -1444,8 +1443,9 @@ and verify_unique_signature tenv loc id arg_tys rt =
14441443
| [] -> ()
14451444
| {type_= UFun (_, rt', _, _); _} :: _ when rt <> rt' ->
14461445
Semantic_error.fn_overload_rt_only loc id.name rt rt' |> error
1447-
| _ ->
1446+
| {kind; _} :: _ ->
14481447
Semantic_error.fn_decl_redefined loc id.name
1448+
~stan_math:(kind = `StanMath)
14491449
(UnsizedType.UFun (arg_tys, rt, Fun_kind.suffix_from_name id.name, AoS))
14501450
|> error
14511451

0 commit comments

Comments
 (0)