@@ -79,6 +79,26 @@ type signature_error =
7979 (UnsizedType .returntype * (UnsizedType .autodifftype * UnsizedType .t ) list )
8080 * function_mismatch
8181
82+ type promotions =
83+ | None
84+ | IntToRealPromotion
85+ | IntToComplexPromotion
86+ | RealToComplexPromotion
87+
88+ type ('unique, 'error) generic_match_result =
89+ | UniqueMatch of 'unique
90+ | AmbiguousMatch of
91+ (UnsizedType .returntype * (UnsizedType .autodifftype * UnsizedType .t ) list )
92+ list
93+ | SignatureErrors of 'error
94+
95+ type match_result =
96+ ( UnsizedType .returntype
97+ * (bool Middle.Fun_kind .suffix -> Ast .fun_kind )
98+ * promotions list
99+ , signature_error list * bool )
100+ generic_match_result
101+
82102let rec compare_types t1 t2 =
83103 match (t1, t2) with
84104 | UnsizedType. (UArray t1 , UArray t2 ) -> compare_types t1 t2
@@ -110,14 +130,13 @@ let rec compare_errors e1 e2 =
110130 | SuffixMismatch _ , _ | _ , InputMismatch _ -> - 1
111131 | InputMismatch _ , _ | _ , SuffixMismatch _ -> 1 ) )
112132
113- type promotions = None | RealPromotion | ComplexPromotion
114-
115133let rec check_same_type depth t1 t2 =
116134 let wrap_func = Result. map_error ~f: (fun e -> TypeMismatch (t1, t2, Some e)) in
117135 match (t1, t2) with
118136 | t1 , t2 when t1 = t2 -> Ok None
119- | UnsizedType. (UReal, UInt) when depth < 1 -> Ok RealPromotion
120- | UnsizedType. (UComplex , (UInt | UReal )) when depth < 1 -> Ok ComplexPromotion
137+ | UnsizedType. (UReal, UInt) when depth < 1 -> Ok IntToRealPromotion
138+ | UnsizedType. (UComplex, UInt) when depth < 1 -> Ok IntToComplexPromotion
139+ | UnsizedType. (UComplex, UReal) when depth < 1 -> Ok RealToComplexPromotion
121140 (* Arrays: Try to recursively promote, but make sure the error is for these types,
122141 not the recursive call *)
123142 | UArray nt1 , UArray nt2 ->
@@ -170,34 +189,67 @@ let promote es promotions =
170189 let open UnsizedType in
171190 let emeta = exp.emeta in
172191 match prom with
173- | RealPromotion when is_int_type emeta.type_ ->
192+ | IntToRealPromotion when is_int_type emeta.type_ ->
174193 Ast.
175194 { expr= Ast. Promotion (exp, UReal , emeta.ad_level)
176195 ; emeta= {emeta with type_= promote_array emeta.type_ UReal } }
177- | ComplexPromotion when not (is_complex_type emeta.type_) ->
196+ | (IntToComplexPromotion | RealToComplexPromotion )
197+ when not (is_complex_type emeta.type_) ->
178198 { expr= Promotion (exp, UComplex , emeta.ad_level)
179199 ; emeta= {emeta with type_= promote_array emeta.type_ UComplex } }
180200 | _ -> exp )
181201
182- let returntype env name args =
202+ let promotion_cost p =
203+ match p with
204+ | None -> 0
205+ | RealToComplexPromotion | IntToRealPromotion -> 1
206+ | IntToComplexPromotion -> 2
207+
208+ let unique_minimum_promotion promotion_options =
209+ let size (_ , p ) =
210+ List. fold ~init: 0 ~f: (fun acc p -> acc + promotion_cost p) p in
211+ let sizes = List. map ~f: size promotion_options in
212+ let min_promotion = List. min_elt ~compare: Int. compare sizes in
213+ let sizes_and_promotons = List. zip_exn sizes promotion_options in
214+ match min_promotion with
215+ | Some min_depth -> (
216+ match
217+ List. filter_map
218+ ~f: (fun (depth , promotion ) ->
219+ if depth = min_depth then Some promotion else None )
220+ sizes_and_promotons
221+ with
222+ | [ans] -> Ok ans
223+ | _ :: _ as lst -> Error (Some (List. map ~f: fst lst))
224+ | [] -> Error None )
225+ | None -> Error None
226+
227+ let matching_function env name args =
183228 (* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
184229 let name = Utils. stdlib_distribution_name name in
185- Environment. find env name
186- |> List. filter_map ~f: extract_function_types
187- |> List. sort ~compare: (fun (x , _ , _ , _ ) (y , _ , _ , _ ) ->
188- UnsizedType. compare_returntype x y )
189- (* Check the least return type first in case there are multiple options (due to implicit UInt-UReal conversion), where UInt<UReal *)
190- |> List. fold_until ~init: []
191- ~f: (fun errors (rt , tys , funkind_constructor , _ ) ->
192- match check_compatible_arguments 0 tys args with
193- | Ok p -> Stop (Ok (rt, funkind_constructor, p))
194- | Error e -> Continue (((rt, tys), e) :: errors) )
195- ~finish: (fun errors ->
196- let errors =
197- List. sort errors ~compare: (fun (_ , e1 ) (_ , e2 ) ->
198- compare_errors e1 e2 ) in
199- let errors, omitted = List. split_n errors max_n_errors in
200- Error (errors, not (List. is_empty omitted)) )
230+ let function_types =
231+ Environment. find env name
232+ |> List. filter_map ~f: extract_function_types
233+ |> List. sort ~compare: (fun (ret1 , _ , _ , _ ) (ret2 , _ , _ , _ ) ->
234+ UnsizedType. compare_returntype ret1 ret2 ) in
235+ let matches, errors =
236+ List. partition_map function_types
237+ ~f: (fun (rt , tys , funkind_constructor , _ ) ->
238+ match check_compatible_arguments 0 tys args with
239+ | Ok p -> Either. First (((rt, tys), funkind_constructor), p)
240+ | Error e -> Second ((rt, tys), e) ) in
241+ match unique_minimum_promotion matches with
242+ | Ok (((rt , _ ), funkind_constructor ), p ) ->
243+ UniqueMatch (rt, funkind_constructor, p)
244+ | Error (Some e ) ->
245+ AmbiguousMatch (List. map ~f: fst e)
246+ (* return the return types and argument types of ambiguous matches *)
247+ | Error None ->
248+ let errors =
249+ List. sort errors ~compare: (fun (_ , e1 ) (_ , e2 ) -> compare_errors e1 e2)
250+ in
251+ let errors, omitted = List. split_n errors max_n_errors in
252+ SignatureErrors (errors, not (List. is_empty omitted))
201253
202254let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
203255 fun_return args =
@@ -231,6 +283,7 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
231283 ((UnsizedType. AutoDiffable , func_type) :: mandatory_arg_tys)
232284 @ variadic_arg_tys in
233285 check_compatible_arguments 0 expected_args args
286+ |> Result. map ~f: (fun x -> (func_type, x))
234287 |> Result. map_error ~f: (fun x -> (expected_args, x)) )
235288 else wrap_func_error (SuffixMismatch (FnPlain , suffix))
236289 | (_ , x ) :: _ -> TypeMismatch (minimal_func_type, x, None ) |> wrap_err
0 commit comments