Skip to content

Commit 577692a

Browse files
authored
Merge pull request #1027 from stan-dev/function_overloading
Function overloading
2 parents 5f128a5 + 67a2ddf commit 577692a

64 files changed

Lines changed: 14789 additions & 11594 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/analysis_and_optimization/Optimize.ml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,10 @@ let create_function_inline_map adt l =
423423
definitions), because that would replace the function call with a Skip.
424424
*)
425425
let f (accum, visited) Program.{fdname; fdargs; fdbody; fdrt; _} =
426-
if Set.mem visited fdname then (accum, visited)
426+
(* If we see a function more than once,
427+
remove it to prevent inlining of overloaded functions
428+
*)
429+
if Set.mem visited fdname then (Map.remove accum fdname, visited)
427430
else
428431
let accum' =
429432
match fdbody with
@@ -1121,20 +1124,20 @@ let optimize_ad_levels (mir : Program.Typed.t) =
11211124
(**
11221125
* Deduces whether types can be Structures of Arrays (SoA/fast) or
11231126
* Arrays of Structs (AoS/slow). See the docs in
1124-
* Mem_pattern.query_demote_stmt/exprs* functions for
1127+
* Mem_pattern.query_demote_stmt/exprs* functions for
11251128
* details on the rules surrounding when demotion from
11261129
* SoA -> AoS needs to happen.
11271130
*
11281131
* This first does a simple iter over
11291132
* the log_prob portion of the MIR, finding the names of all matrices
11301133
* (and arrays of matrices) where either the Stan math function
1131-
* does not support SoA or the object is single cell accesed within a
1134+
* does not support SoA or the object is single cell accesed within a
11321135
* For or While loop. These are the initial variables
11331136
* given to the monotone framework. Then log_prob has all matrix like objects
1134-
* and the functions that use them to SoA. After that the
1137+
* and the functions that use them to SoA. After that the
11351138
* Monotone framework is used to deduce assignment paths of AoS <-> SoA
1136-
* and vice versa which need to be demoted to AoS as well as updating
1137-
* functions and objects after these assignment passes that then
1139+
* and vice versa which need to be demoted to AoS as well as updating
1140+
* functions and objects after these assignment passes that then
11381141
* also need to be AoS.
11391142
*
11401143
* @param mir: The program's whole MIR.
@@ -1151,7 +1154,7 @@ let optimize_soa (mir : Program.Typed.t) =
11511154
~f:(Mem_pattern.query_initial_demotable_stmt false)
11521155
mir.log_prob in
11531156
(*
1154-
let print_set s =
1157+
let print_set s =
11551158
Set.Poly.iter ~f:print_endline s in
11561159
let () = print_set initial_variables in
11571160
*)

src/frontend/Semantic_error.ml

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ module TypeError = struct
3232
* (UnsizedType.autodifftype * UnsizedType.t) list
3333
* SignatureMismatch.function_mismatch
3434
* UnsizedType.t
35+
| AmbiguousFunctionPromotion of
36+
string
37+
* UnsizedType.t list option
38+
* ( UnsizedType.returntype
39+
* (UnsizedType.autodifftype * UnsizedType.t) list )
40+
list
3541
| ReturningFnExpectedNonReturningFound of string
3642
| ReturningFnExpectedNonFnFound of string
3743
| ReturningFnExpectedUndeclaredIdentFound of string * string option
@@ -131,6 +137,26 @@ module TypeError = struct
131137
( name
132138
, arg_tys
133139
, ([((UnsizedType.ReturnType return_type, args), error)], false) )
140+
| AmbiguousFunctionPromotion (name, arg_tys, signatures) ->
141+
let pp_sig ppf (rt, args) =
142+
Fmt.pf ppf "@[<hov>(@[<hov>%a@]) => %a@]"
143+
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
144+
args UnsizedType.pp_returntype rt in
145+
Fmt.pf ppf
146+
"No unique minimum promotion found for function '%s'.@ Overloaded \
147+
functions must not have multiple equally valid promotion paths.@ %a \
148+
function has several:@ @[<v>%a@]@ Consider defining a new signature \
149+
for the exact types needed or@ re-thinking existing definitions."
150+
name
151+
(Fmt.option
152+
~none:(fun ppf () -> Fmt.pf ppf "This")
153+
(fun ppf tys ->
154+
Fmt.pf ppf "For args @[(%a)@], this"
155+
(Fmt.list ~sep:Fmt.comma UnsizedType.pp)
156+
tys ) )
157+
arg_tys
158+
(Fmt.list ~sep:Fmt.cut pp_sig)
159+
signatures
134160
| NotIndexable (ut, nidcs) ->
135161
Fmt.pf ppf
136162
"Too many indexes, expression dimensions=%d, indexes found=%d."
@@ -238,7 +264,9 @@ module IdentifierError = struct
238264

239265
let pp ppf = function
240266
| IsStanMathName name ->
241-
Fmt.pf ppf "Identifier '%s' clashes with Stan Math library function."
267+
Fmt.pf ppf
268+
"Identifier '%s' clashes with a non-overloadable Stan Math library \
269+
function."
242270
name
243271
| InUse name -> Fmt.pf ppf "Identifier '%s' is already in use." name
244272
| IsModelName name ->
@@ -332,7 +360,9 @@ module StatementError = struct
332360
| NonIntBounds
333361
| ComplexTransform
334362
| TransformedParamsInt
335-
| MismatchFunDefDecl of string * UnsizedType.t option
363+
| FuncOverloadRtOnly of
364+
string * UnsizedType.returntype * UnsizedType.returntype
365+
| FuncDeclRedefined of string * UnsizedType.t * bool
336366
| FunDeclExists of string
337367
| FunDeclNoDefn
338368
| FunDeclNeedsBlock
@@ -401,14 +431,16 @@ module StatementError = struct
401431
Fmt.pf ppf "Complex types do not support transformations."
402432
| TransformedParamsInt ->
403433
Fmt.pf ppf "(Transformed) Parameters cannot be integers."
404-
| MismatchFunDefDecl (name, Some ut) ->
405-
Fmt.pf ppf "Function '%s' has already been declared to have type %a"
406-
name UnsizedType.pp ut
407-
| MismatchFunDefDecl (name, None) ->
408-
Fmt.pf ppf
409-
"Function '%s' has already been declared but type cannot be \
410-
determined."
411-
name
434+
| FuncOverloadRtOnly (name, _, rt') ->
435+
Fmt.pf ppf
436+
"Function '%s' cannot be overloaded by return type only. Previously \
437+
used return type %a"
438+
name UnsizedType.pp_returntype rt'
439+
| FuncDeclRedefined (name, ut, stan_math) ->
440+
Fmt.pf ppf "Function '%s' %s signature %a" name
441+
( if stan_math then "is already declared in the Stan Math library with"
442+
else "has already been declared to for" )
443+
UnsizedType.pp ut
412444
| FunDeclExists name ->
413445
Fmt.pf ppf
414446
"Function '%s' has already been declared. A definition is expected."
@@ -534,6 +566,10 @@ let illtyped_variadic_dae loc name arg_tys args error =
534566
, error
535567
, Stan_math_signatures.variadic_dae_fun_return_type ) )
536568

569+
let ambiguous_function_promotion loc name arg_tys signatures =
570+
TypeError
571+
(loc, TypeError.AmbiguousFunctionPromotion (name, arg_tys, signatures))
572+
537573
let returning_fn_expected_nonfn_found loc name =
538574
TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name)
539575

@@ -664,8 +700,11 @@ let complex_transform loc = StatementError (loc, StatementError.ComplexTransform
664700
let transformed_params_int loc =
665701
StatementError (loc, StatementError.TransformedParamsInt)
666702

667-
let mismatched_fn_def_decl loc name ut_opt =
668-
StatementError (loc, StatementError.MismatchFunDefDecl (name, ut_opt))
703+
let fn_overload_rt_only loc name rt1 rt2 =
704+
StatementError (loc, StatementError.FuncOverloadRtOnly (name, rt1, rt2))
705+
706+
let fn_decl_redefined loc name ~stan_math ut =
707+
StatementError (loc, StatementError.FuncDeclRedefined (name, ut, stan_math))
669708

670709
let fn_decl_exists loc name =
671710
StatementError (loc, StatementError.FunDeclExists name)

src/frontend/Semantic_error.mli

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ val illtyped_variadic_ode :
6666
-> SignatureMismatch.function_mismatch
6767
-> t
6868

69+
val ambiguous_function_promotion :
70+
Location_span.t
71+
-> string
72+
-> UnsizedType.t list option
73+
-> (UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
74+
list
75+
-> t
76+
6977
val illtyped_variadic_dae :
7078
Location_span.t
7179
-> string
@@ -127,8 +135,15 @@ val non_int_bounds : Location_span.t -> t
127135
val complex_transform : Location_span.t -> t
128136
val transformed_params_int : Location_span.t -> t
129137

130-
val mismatched_fn_def_decl :
131-
Location_span.t -> string -> UnsizedType.t option -> t
138+
val fn_overload_rt_only :
139+
Location_span.t
140+
-> string
141+
-> UnsizedType.returntype
142+
-> UnsizedType.returntype
143+
-> t
144+
145+
val fn_decl_redefined :
146+
Location_span.t -> string -> stan_math:bool -> UnsizedType.t -> t
132147

133148
val fn_decl_exists : Location_span.t -> string -> t
134149
val fn_decl_without_def : Location_span.t -> t

src/frontend/SignatureMismatch.ml

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,26 @@ type signature_error =
7979
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
8080
* function_mismatch
8181

82+
type promotions =
83+
| None
84+
| IntToRealPromotion
85+
| IntToComplexPromotion
86+
| RealToComplexPromotion
87+
88+
type ('unique, 'error) generic_match_result =
89+
| UniqueMatch of 'unique
90+
| AmbiguousMatch of
91+
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
92+
list
93+
| SignatureErrors of 'error
94+
95+
type match_result =
96+
( UnsizedType.returntype
97+
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
98+
* promotions list
99+
, signature_error list * bool )
100+
generic_match_result
101+
82102
let rec compare_types t1 t2 =
83103
match (t1, t2) with
84104
| UnsizedType.(UArray t1, UArray t2) -> compare_types t1 t2
@@ -110,14 +130,13 @@ let rec compare_errors e1 e2 =
110130
| SuffixMismatch _, _ | _, InputMismatch _ -> -1
111131
| InputMismatch _, _ | _, SuffixMismatch _ -> 1 ) )
112132

113-
type promotions = None | RealPromotion | ComplexPromotion
114-
115133
let rec check_same_type depth t1 t2 =
116134
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
117135
match (t1, t2) with
118136
| t1, t2 when t1 = t2 -> Ok None
119-
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok RealPromotion
120-
| UnsizedType.(UComplex, (UInt | UReal)) when depth < 1 -> Ok ComplexPromotion
137+
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok IntToRealPromotion
138+
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok IntToComplexPromotion
139+
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok RealToComplexPromotion
121140
(* Arrays: Try to recursively promote, but make sure the error is for these types,
122141
not the recursive call *)
123142
| UArray nt1, UArray nt2 ->
@@ -170,34 +189,67 @@ let promote es promotions =
170189
let open UnsizedType in
171190
let emeta = exp.emeta in
172191
match prom with
173-
| RealPromotion when is_int_type emeta.type_ ->
192+
| IntToRealPromotion when is_int_type emeta.type_ ->
174193
Ast.
175194
{ expr= Ast.Promotion (exp, UReal, emeta.ad_level)
176195
; emeta= {emeta with type_= promote_array emeta.type_ UReal} }
177-
| ComplexPromotion when not (is_complex_type emeta.type_) ->
196+
| (IntToComplexPromotion | RealToComplexPromotion)
197+
when not (is_complex_type emeta.type_) ->
178198
{ expr= Promotion (exp, UComplex, emeta.ad_level)
179199
; emeta= {emeta with type_= promote_array emeta.type_ UComplex} }
180200
| _ -> exp )
181201

182-
let returntype env name args =
202+
let promotion_cost p =
203+
match p with
204+
| None -> 0
205+
| RealToComplexPromotion | IntToRealPromotion -> 1
206+
| IntToComplexPromotion -> 2
207+
208+
let unique_minimum_promotion promotion_options =
209+
let size (_, p) =
210+
List.fold ~init:0 ~f:(fun acc p -> acc + promotion_cost p) p in
211+
let sizes = List.map ~f:size promotion_options in
212+
let min_promotion = List.min_elt ~compare:Int.compare sizes in
213+
let sizes_and_promotons = List.zip_exn sizes promotion_options in
214+
match min_promotion with
215+
| Some min_depth -> (
216+
match
217+
List.filter_map
218+
~f:(fun (depth, promotion) ->
219+
if depth = min_depth then Some promotion else None )
220+
sizes_and_promotons
221+
with
222+
| [ans] -> Ok ans
223+
| _ :: _ as lst -> Error (Some (List.map ~f:fst lst))
224+
| [] -> Error None )
225+
| None -> Error None
226+
227+
let matching_function env name args =
183228
(* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
184229
let name = Utils.stdlib_distribution_name name in
185-
Environment.find env name
186-
|> List.filter_map ~f:extract_function_types
187-
|> List.sort ~compare:(fun (x, _, _, _) (y, _, _, _) ->
188-
UnsizedType.compare_returntype x y )
189-
(* Check the least return type first in case there are multiple options (due to implicit UInt-UReal conversion), where UInt<UReal *)
190-
|> List.fold_until ~init:[]
191-
~f:(fun errors (rt, tys, funkind_constructor, _) ->
192-
match check_compatible_arguments 0 tys args with
193-
| Ok p -> Stop (Ok (rt, funkind_constructor, p))
194-
| Error e -> Continue (((rt, tys), e) :: errors) )
195-
~finish:(fun errors ->
196-
let errors =
197-
List.sort errors ~compare:(fun (_, e1) (_, e2) ->
198-
compare_errors e1 e2 ) in
199-
let errors, omitted = List.split_n errors max_n_errors in
200-
Error (errors, not (List.is_empty omitted)) )
230+
let function_types =
231+
Environment.find env name
232+
|> List.filter_map ~f:extract_function_types
233+
|> List.sort ~compare:(fun (ret1, _, _, _) (ret2, _, _, _) ->
234+
UnsizedType.compare_returntype ret1 ret2 ) in
235+
let matches, errors =
236+
List.partition_map function_types
237+
~f:(fun (rt, tys, funkind_constructor, _) ->
238+
match check_compatible_arguments 0 tys args with
239+
| Ok p -> Either.First (((rt, tys), funkind_constructor), p)
240+
| Error e -> Second ((rt, tys), e) ) in
241+
match unique_minimum_promotion matches with
242+
| Ok (((rt, _), funkind_constructor), p) ->
243+
UniqueMatch (rt, funkind_constructor, p)
244+
| Error (Some e) ->
245+
AmbiguousMatch (List.map ~f:fst e)
246+
(* return the return types and argument types of ambiguous matches *)
247+
| Error None ->
248+
let errors =
249+
List.sort errors ~compare:(fun (_, e1) (_, e2) -> compare_errors e1 e2)
250+
in
251+
let errors, omitted = List.split_n errors max_n_errors in
252+
SignatureErrors (errors, not (List.is_empty omitted))
201253

202254
let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
203255
fun_return args =
@@ -231,6 +283,7 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
231283
((UnsizedType.AutoDiffable, func_type) :: mandatory_arg_tys)
232284
@ variadic_arg_tys in
233285
check_compatible_arguments 0 expected_args args
286+
|> Result.map ~f:(fun x -> (func_type, x))
234287
|> Result.map_error ~f:(fun x -> (expected_args, x)) )
235288
else wrap_func_error (SuffixMismatch (FnPlain, suffix))
236289
| (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err

0 commit comments

Comments
 (0)