Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ let fold_stmts ~take_expr ~take_stmt ~(init : 'c) (stmts : Stmt.Located.t List.t

let rec num_expr_value (v : Expr.Typed.t) : (float * string) option =
match v with
(* internal type promotions should be ignored *)
| {pattern= Fixed.Pattern.Promotion (e, _, _); _} -> num_expr_value e
| {pattern= Fixed.Pattern.Lit (Real, str); _}
|{pattern= Fixed.Pattern.Lit (Int, str); _} ->
Some (float_of_string str, str)
Expand Down
4 changes: 3 additions & 1 deletion src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,9 @@ let constant_propagation_transfer ?(preserve_stability = false)
We could do the same for matrix and array expressions if we wanted. *)
| Assignment ((s, t, []), e) -> (
match Partial_evaluator.try_eval_expr (subst_expr m e) with
| {pattern= Lit (_, _); _} as e'
| { pattern=
Promotion ({pattern= Lit (_, _); _}, _, _) | Lit (_, _)
; _ } as e'
when not (preserve_stability && UnsizedType.is_autodiffable t)
->
Map.set m ~key:s ~data:e'
Expand Down
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
pattern=
( match e.pattern with
| Var _ | Lit (_, _) -> e.pattern
| Promotion (expr, ut, ad) -> Promotion (eval_expr expr, ut, ad)
| Promotion (expr, ut, ad) ->
Promotion (eval_expr ~preserve_stability expr, ut, ad)
| FunApp (kind, l) -> (
let l = List.map ~f:(eval_expr ~preserve_stability) l in
match kind with
Expand Down
4 changes: 4 additions & 0 deletions src/frontend/Debug_data_generation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ let unwrap_num_exn m e =
let m = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
let e = Analysis_and_optimization.Mir_utils.subst_expr m e in
let e = Analysis_and_optimization.Partial_evaluator.eval_expr e in
let rec strip_promotions (e : Middle.Expr.Typed.t) =
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
in
let e = strip_promotions e in
match e.pattern with
| Lit (_, s) -> Float.of_string s
| _ ->
Expand Down
96 changes: 96 additions & 0 deletions src/frontend/Promotion.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
open Core_kernel
open Core_kernel.Poly
module UnsizedType = Middle.UnsizedType

(** Type to represent promotions in the typechecker.
This can be used to return information about how to promote
expressions for use in [Ast.Promotion] *)
type t =
| NoPromotion
| IntToReal
| ToVar (* used in arrays, not functions *)
| ToComplexVar (* used in arrays, not functions *)
Comment thread
rok-cesnovar marked this conversation as resolved.
| IntToComplex
| RealToComplex

let promote_inner (exp : Ast.typed_expression) prom =
let emeta = exp.emeta in
match prom with
| ToVar ->
Ast.
{ expr= Ast.Promotion (exp, UReal, AutoDiffable)
; emeta=
{ emeta with
type_= UnsizedType.promote_array emeta.type_ UReal
; ad_level= AutoDiffable } }
| ToComplexVar ->
Ast.
{ expr= Ast.Promotion (exp, UComplex, AutoDiffable)
; emeta=
{ emeta with
type_= UnsizedType.promote_array emeta.type_ UComplex
; ad_level= AutoDiffable } }
| IntToReal when UnsizedType.is_int_type emeta.type_ ->
Ast.
{ expr= Ast.Promotion (exp, UReal, emeta.ad_level)
; emeta= {emeta with type_= UnsizedType.promote_array emeta.type_ UReal}
}
| (IntToComplex | RealToComplex)
when not (UnsizedType.is_complex_type emeta.type_) ->
(* these two promotions are separated for cost, but are actually the same promotion *)
{ expr= Promotion (exp, UComplex, emeta.ad_level)
; emeta= {emeta with type_= UnsizedType.promote_array emeta.type_ UComplex}
}
| _ -> exp

let rec promote (exp : Ast.typed_expression) prom =
(* promote arrays and rowvector literals at the lowest level to avoid unnecessary copies *)
let open Ast in
match exp.expr with
| ArrayExpr es ->
let pes = List.map ~f:(fun e -> promote e prom) es in
let fst = List.hd_exn pes in
let type_, ad_level = (fst.emeta.type_, fst.emeta.ad_level) in
{ expr= ArrayExpr pes
; emeta=
{ exp.emeta with
type_= UnsizedType.promote_array exp.emeta.type_ type_
; ad_level } }
| RowVectorExpr (_ :: _ as es) ->
let pes = List.map ~f:(fun e -> promote e prom) es in
let fst = List.hd_exn pes in
let ad_level = fst.emeta.ad_level in
{expr= RowVectorExpr pes; emeta= {exp.emeta with ad_level}}
| _ -> promote_inner exp prom

let promote_list es promotions = List.map2_exn es promotions ~f:promote

(** Get the promotion needed to make the second type into the first.
Types NEED to have previously been checked to be promotable
*)
let rec get_type_promotion_exn (ad, ty) (ad2, ty2) =
match (ty, ty2) with
| UnsizedType.(UReal, (UReal | UInt) | UVector, UVector | UMatrix, UMatrix)
when ad <> ad2 ->
ToVar
| UComplex, (UReal | UInt | UComplex) when ad <> ad2 -> ToComplexVar
| UReal, UInt -> IntToReal
| UComplex, UInt -> IntToComplex
| UComplex, UReal -> RealToComplex
| UArray nt1, UArray nt2 -> get_type_promotion_exn (ad, nt1) (ad2, nt2)
| t1, t2 when t1 = t2 -> NoPromotion
| _, _ ->
Common.FatalError.fatal_error_msg
[%message
"Tried to get promotion of mismatched types!"
(ty : UnsizedType.t)
(ty2 : UnsizedType.t)]

(** Calculate the "cost"/number of promotions performed.
Used to disambiguate function signatures
*)
let promotion_cost p =
match p with
| NoPromotion | ToVar | ToComplexVar -> 0
| RealToComplex | IntToReal -> 1
| IntToComplex -> 2
44 changes: 9 additions & 35 deletions src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ type signature_error =
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
* function_mismatch

type promotions =
| None
| IntToRealPromotion
| IntToComplexPromotion
| RealToComplexPromotion

type ('unique, 'error) generic_match_result =
| UniqueMatch of 'unique
| AmbiguousMatch of
Expand All @@ -95,7 +89,7 @@ type ('unique, 'error) generic_match_result =
type match_result =
( UnsizedType.returntype
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
* promotions list
* Promotion.t list
, signature_error list * bool )
generic_match_result

Expand Down Expand Up @@ -133,10 +127,10 @@ let rec compare_errors e1 e2 =
let rec check_same_type depth t1 t2 =
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
match (t1, t2) with
| t1, t2 when t1 = t2 -> Ok None
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok IntToRealPromotion
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok IntToComplexPromotion
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok RealToComplexPromotion
| t1, t2 when t1 = t2 -> Ok Promotion.NoPromotion
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok IntToReal
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok IntToComplex
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok RealToComplex
(* Arrays: Try to recursively promote, but make sure the error is for these types,
not the recursive call *)
| UArray nt1, UArray nt2 ->
Expand All @@ -153,12 +147,12 @@ let rec check_same_type depth t1 t2 =
Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
| UFun (l1, _, _, _), UFun (l2, _, _, _) -> (
match check_compatible_arguments (depth + 1) l2 l1 with
| Ok _ -> Ok None
| Ok _ -> Ok NoPromotion
| Error e -> Error (InputMismatch e) |> wrap_func )
| t1, t2 -> Error (TypeMismatch (t1, t2, None))

and check_compatible_arguments depth typs args2 :
(promotions list, function_mismatch) result =
(Promotion.t list, function_mismatch) result =
match List.zip typs args2 with
| List.Or_unequal_lengths.Unequal_lengths ->
Error (ArgNumMismatch (List.length typs, List.length args2))
Expand All @@ -173,6 +167,7 @@ and check_compatible_arguments depth typs args2 :
else Error (ArgError (i + 1, DataOnlyError)) )
|> Result.all

let check_of_same_type_mod_conv = check_same_type 0
let check_compatible_arguments_mod_conv = check_compatible_arguments 0
let max_n_errors = 5

Expand All @@ -184,30 +179,9 @@ let extract_function_types f =
Some (return, args, (fun x -> UserDefined x), mem)
| _ -> None

let promote es promotions =
List.map2_exn es promotions ~f:(fun (exp : Ast.typed_expression) prom ->
let open UnsizedType in
let emeta = exp.emeta in
match prom with
| IntToRealPromotion when is_int_type emeta.type_ ->
Ast.
{ expr= Ast.Promotion (exp, UReal, emeta.ad_level)
; emeta= {emeta with type_= promote_array emeta.type_ UReal} }
| (IntToComplexPromotion | RealToComplexPromotion)
when not (is_complex_type emeta.type_) ->
{ expr= Promotion (exp, UComplex, emeta.ad_level)
; emeta= {emeta with type_= promote_array emeta.type_ UComplex} }
| _ -> exp )

let promotion_cost p =
match p with
| None -> 0
| RealToComplexPromotion | IntToRealPromotion -> 1
| IntToComplexPromotion -> 2

let unique_minimum_promotion promotion_options =
let size (_, p) =
List.fold ~init:0 ~f:(fun acc p -> acc + promotion_cost p) p in
List.fold ~init:0 ~f:(fun acc p -> acc + Promotion.promotion_cost p) p in
let sizes = List.map ~f:size promotion_options in
let min_promotion = List.min_elt ~compare:Int.compare sizes in
let sizes_and_promotons = List.zip_exn sizes promotion_options in
Expand Down
24 changes: 7 additions & 17 deletions src/frontend/SignatureMismatch.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ type signature_error =
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
* function_mismatch

(** Indicate a promotion by the resulting type *)
type promotions = private
| None
| IntToRealPromotion
| IntToComplexPromotion
| RealToComplexPromotion

type ('unique, 'error) generic_match_result =
| UniqueMatch of 'unique
| AmbiguousMatch of
Expand All @@ -37,23 +30,20 @@ type ('unique, 'error) generic_match_result =
type match_result =
( UnsizedType.returntype
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
* promotions list
* Promotion.t list
, signature_error list * bool )
generic_match_result

val check_of_same_type_mod_conv :
UnsizedType.t -> UnsizedType.t -> (Promotion.t, type_mismatch) result

val check_compatible_arguments_mod_conv :
(UnsizedType.autodifftype * UnsizedType.t) list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> (promotions list, function_mismatch) result

val promote :
Ast.typed_expression list -> promotions list -> Ast.typed_expression list
(** Given a list of expressions (arguments) and a list of [promotions],
return a list of expressions which include the
[Promotion] expression as appropiate *)
-> (Promotion.t list, function_mismatch) result

val unique_minimum_promotion :
('a * promotions list) list -> ('a * promotions list, 'a list option) result
('a * Promotion.t list) list -> ('a * Promotion.t list, 'a list option) result

val matching_function :
Environment.t
Expand All @@ -71,7 +61,7 @@ val check_variadic_args :
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> UnsizedType.t
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> ( UnsizedType.t * promotions list
-> ( UnsizedType.t * Promotion.t list
, (UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch )
result
(** Check variadic function arguments.
Expand Down
Loading