Skip to content

Commit 17f651a

Browse files
authored
Merge pull request #1102 from stan-dev/assignment-promotion
Move assignment logic into promotion functionality
2 parents e7ed2eb + 225669f commit 17f651a

23 files changed

Lines changed: 2528 additions & 2014 deletions

src/analysis_and_optimization/Mir_utils.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ let fold_stmts ~take_expr ~take_stmt ~(init : 'c) (stmts : Stmt.Located.t List.t
2727

2828
let rec num_expr_value (v : Expr.Typed.t) : (float * string) option =
2929
match v with
30+
(* internal type promotions should be ignored *)
31+
| {pattern= Fixed.Pattern.Promotion (e, _, _); _} -> num_expr_value e
3032
| {pattern= Fixed.Pattern.Lit (Real, str); _}
3133
|{pattern= Fixed.Pattern.Lit (Int, str); _} ->
3234
Some (float_of_string str, str)

src/analysis_and_optimization/Monotone_framework.ml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ let constant_propagation_transfer ?(preserve_stability = false)
326326
We could do the same for matrix and array expressions if we wanted. *)
327327
| Assignment ((s, t, []), e) -> (
328328
match Partial_evaluator.try_eval_expr (subst_expr m e) with
329-
| {pattern= Lit (_, _); _} as e'
329+
| { pattern=
330+
Promotion ({pattern= Lit (_, _); _}, _, _) | Lit (_, _)
331+
; _ } as e'
330332
when not (preserve_stability && UnsizedType.is_autodiffable t)
331333
->
332334
Map.set m ~key:s ~data:e'

src/analysis_and_optimization/Partial_evaluator.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
9393
pattern=
9494
( match e.pattern with
9595
| Var _ | Lit (_, _) -> e.pattern
96-
| Promotion (expr, ut, ad) -> Promotion (eval_expr expr, ut, ad)
96+
| Promotion (expr, ut, ad) ->
97+
Promotion (eval_expr ~preserve_stability expr, ut, ad)
9798
| FunApp (kind, l) -> (
9899
let l = List.map ~f:(eval_expr ~preserve_stability) l in
99100
match kind with

src/frontend/Debug_data_generation.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ let unwrap_num_exn m e =
3434
let m = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
3535
let e = Analysis_and_optimization.Mir_utils.subst_expr m e in
3636
let e = Analysis_and_optimization.Partial_evaluator.eval_expr e in
37+
let rec strip_promotions (e : Middle.Expr.Typed.t) =
38+
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
39+
in
40+
let e = strip_promotions e in
3741
match e.pattern with
3842
| Lit (_, s) -> Float.of_string s
3943
| _ ->

src/frontend/Promotion.ml

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
open Core_kernel
2+
open Core_kernel.Poly
3+
module UnsizedType = Middle.UnsizedType
4+
5+
(** Type to represent promotions in the typechecker.
6+
This can be used to return information about how to promote
7+
expressions for use in [Ast.Promotion] *)
8+
type t =
9+
| NoPromotion
10+
| IntToReal
11+
| ToVar (* used in arrays, not functions *)
12+
| ToComplexVar (* used in arrays, not functions *)
13+
| IntToComplex
14+
| RealToComplex
15+
16+
let promote_inner (exp : Ast.typed_expression) prom =
17+
let emeta = exp.emeta in
18+
match prom with
19+
| ToVar ->
20+
Ast.
21+
{ expr= Ast.Promotion (exp, UReal, AutoDiffable)
22+
; emeta=
23+
{ emeta with
24+
type_= UnsizedType.promote_array emeta.type_ UReal
25+
; ad_level= AutoDiffable } }
26+
| ToComplexVar ->
27+
Ast.
28+
{ expr= Ast.Promotion (exp, UComplex, AutoDiffable)
29+
; emeta=
30+
{ emeta with
31+
type_= UnsizedType.promote_array emeta.type_ UComplex
32+
; ad_level= AutoDiffable } }
33+
| IntToReal when UnsizedType.is_int_type emeta.type_ ->
34+
Ast.
35+
{ expr= Ast.Promotion (exp, UReal, emeta.ad_level)
36+
; emeta= {emeta with type_= UnsizedType.promote_array emeta.type_ UReal}
37+
}
38+
| (IntToComplex | RealToComplex)
39+
when not (UnsizedType.is_complex_type emeta.type_) ->
40+
(* these two promotions are separated for cost, but are actually the same promotion *)
41+
{ expr= Promotion (exp, UComplex, emeta.ad_level)
42+
; emeta= {emeta with type_= UnsizedType.promote_array emeta.type_ UComplex}
43+
}
44+
| _ -> exp
45+
46+
let rec promote (exp : Ast.typed_expression) prom =
47+
(* promote arrays and rowvector literals at the lowest level to avoid unnecessary copies *)
48+
let open Ast in
49+
match exp.expr with
50+
| ArrayExpr es ->
51+
let pes = List.map ~f:(fun e -> promote e prom) es in
52+
let fst = List.hd_exn pes in
53+
let type_, ad_level = (fst.emeta.type_, fst.emeta.ad_level) in
54+
{ expr= ArrayExpr pes
55+
; emeta=
56+
{ exp.emeta with
57+
type_= UnsizedType.promote_array exp.emeta.type_ type_
58+
; ad_level } }
59+
| RowVectorExpr (_ :: _ as es) ->
60+
let pes = List.map ~f:(fun e -> promote e prom) es in
61+
let fst = List.hd_exn pes in
62+
let ad_level = fst.emeta.ad_level in
63+
{expr= RowVectorExpr pes; emeta= {exp.emeta with ad_level}}
64+
| _ -> promote_inner exp prom
65+
66+
let promote_list es promotions = List.map2_exn es promotions ~f:promote
67+
68+
(** Get the promotion needed to make the second type into the first.
69+
Types NEED to have previously been checked to be promotable
70+
*)
71+
let rec get_type_promotion_exn (ad, ty) (ad2, ty2) =
72+
match (ty, ty2) with
73+
| UnsizedType.(UReal, (UReal | UInt) | UVector, UVector | UMatrix, UMatrix)
74+
when ad <> ad2 ->
75+
ToVar
76+
| UComplex, (UReal | UInt | UComplex) when ad <> ad2 -> ToComplexVar
77+
| UReal, UInt -> IntToReal
78+
| UComplex, UInt -> IntToComplex
79+
| UComplex, UReal -> RealToComplex
80+
| UArray nt1, UArray nt2 -> get_type_promotion_exn (ad, nt1) (ad2, nt2)
81+
| t1, t2 when t1 = t2 -> NoPromotion
82+
| _, _ ->
83+
Common.FatalError.fatal_error_msg
84+
[%message
85+
"Tried to get promotion of mismatched types!"
86+
(ty : UnsizedType.t)
87+
(ty2 : UnsizedType.t)]
88+
89+
(** Calculate the "cost"/number of promotions performed.
90+
Used to disambiguate function signatures
91+
*)
92+
let promotion_cost p =
93+
match p with
94+
| NoPromotion | ToVar | ToComplexVar -> 0
95+
| RealToComplex | IntToReal -> 1
96+
| IntToComplex -> 2

src/frontend/SignatureMismatch.ml

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,6 @@ type signature_error =
7979
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
8080
* function_mismatch
8181

82-
type promotions =
83-
| None
84-
| IntToRealPromotion
85-
| IntToComplexPromotion
86-
| RealToComplexPromotion
87-
8882
type ('unique, 'error) generic_match_result =
8983
| UniqueMatch of 'unique
9084
| AmbiguousMatch of
@@ -95,7 +89,7 @@ type ('unique, 'error) generic_match_result =
9589
type match_result =
9690
( UnsizedType.returntype
9791
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
98-
* promotions list
92+
* Promotion.t list
9993
, signature_error list * bool )
10094
generic_match_result
10195

@@ -133,10 +127,10 @@ let rec compare_errors e1 e2 =
133127
let rec check_same_type depth t1 t2 =
134128
let wrap_func = Result.map_error ~f:(fun e -> TypeMismatch (t1, t2, Some e)) in
135129
match (t1, t2) with
136-
| t1, t2 when t1 = t2 -> Ok None
137-
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok IntToRealPromotion
138-
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok IntToComplexPromotion
139-
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok RealToComplexPromotion
130+
| t1, t2 when t1 = t2 -> Ok Promotion.NoPromotion
131+
| UnsizedType.(UReal, UInt) when depth < 1 -> Ok IntToReal
132+
| UnsizedType.(UComplex, UInt) when depth < 1 -> Ok IntToComplex
133+
| UnsizedType.(UComplex, UReal) when depth < 1 -> Ok RealToComplex
140134
(* Arrays: Try to recursively promote, but make sure the error is for these types,
141135
not the recursive call *)
142136
| UArray nt1, UArray nt2 ->
@@ -153,12 +147,12 @@ let rec check_same_type depth t1 t2 =
153147
Error (ReturnTypeMismatch (rt1, rt2)) |> wrap_func
154148
| UFun (l1, _, _, _), UFun (l2, _, _, _) -> (
155149
match check_compatible_arguments (depth + 1) l2 l1 with
156-
| Ok _ -> Ok None
150+
| Ok _ -> Ok NoPromotion
157151
| Error e -> Error (InputMismatch e) |> wrap_func )
158152
| t1, t2 -> Error (TypeMismatch (t1, t2, None))
159153

160154
and check_compatible_arguments depth typs args2 :
161-
(promotions list, function_mismatch) result =
155+
(Promotion.t list, function_mismatch) result =
162156
match List.zip typs args2 with
163157
| List.Or_unequal_lengths.Unequal_lengths ->
164158
Error (ArgNumMismatch (List.length typs, List.length args2))
@@ -173,6 +167,7 @@ and check_compatible_arguments depth typs args2 :
173167
else Error (ArgError (i + 1, DataOnlyError)) )
174168
|> Result.all
175169

170+
let check_of_same_type_mod_conv = check_same_type 0
176171
let check_compatible_arguments_mod_conv = check_compatible_arguments 0
177172
let max_n_errors = 5
178173

@@ -184,30 +179,9 @@ let extract_function_types f =
184179
Some (return, args, (fun x -> UserDefined x), mem)
185180
| _ -> None
186181

187-
let promote es promotions =
188-
List.map2_exn es promotions ~f:(fun (exp : Ast.typed_expression) prom ->
189-
let open UnsizedType in
190-
let emeta = exp.emeta in
191-
match prom with
192-
| IntToRealPromotion when is_int_type emeta.type_ ->
193-
Ast.
194-
{ expr= Ast.Promotion (exp, UReal, emeta.ad_level)
195-
; emeta= {emeta with type_= promote_array emeta.type_ UReal} }
196-
| (IntToComplexPromotion | RealToComplexPromotion)
197-
when not (is_complex_type emeta.type_) ->
198-
{ expr= Promotion (exp, UComplex, emeta.ad_level)
199-
; emeta= {emeta with type_= promote_array emeta.type_ UComplex} }
200-
| _ -> exp )
201-
202-
let promotion_cost p =
203-
match p with
204-
| None -> 0
205-
| RealToComplexPromotion | IntToRealPromotion -> 1
206-
| IntToComplexPromotion -> 2
207-
208182
let unique_minimum_promotion promotion_options =
209183
let size (_, p) =
210-
List.fold ~init:0 ~f:(fun acc p -> acc + promotion_cost p) p in
184+
List.fold ~init:0 ~f:(fun acc p -> acc + Promotion.promotion_cost p) p in
211185
let sizes = List.map ~f:size promotion_options in
212186
let min_promotion = List.min_elt ~compare:Int.compare sizes in
213187
let sizes_and_promotons = List.zip_exn sizes promotion_options in

src/frontend/SignatureMismatch.mli

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ type signature_error =
1919
(UnsizedType.returntype * (UnsizedType.autodifftype * UnsizedType.t) list)
2020
* function_mismatch
2121

22-
(** Indicate a promotion by the resulting type *)
23-
type promotions = private
24-
| None
25-
| IntToRealPromotion
26-
| IntToComplexPromotion
27-
| RealToComplexPromotion
28-
2922
type ('unique, 'error) generic_match_result =
3023
| UniqueMatch of 'unique
3124
| AmbiguousMatch of
@@ -37,23 +30,20 @@ type ('unique, 'error) generic_match_result =
3730
type match_result =
3831
( UnsizedType.returntype
3932
* (bool Middle.Fun_kind.suffix -> Ast.fun_kind)
40-
* promotions list
33+
* Promotion.t list
4134
, signature_error list * bool )
4235
generic_match_result
4336

37+
val check_of_same_type_mod_conv :
38+
UnsizedType.t -> UnsizedType.t -> (Promotion.t, type_mismatch) result
39+
4440
val check_compatible_arguments_mod_conv :
4541
(UnsizedType.autodifftype * UnsizedType.t) list
4642
-> (UnsizedType.autodifftype * UnsizedType.t) list
47-
-> (promotions list, function_mismatch) result
48-
49-
val promote :
50-
Ast.typed_expression list -> promotions list -> Ast.typed_expression list
51-
(** Given a list of expressions (arguments) and a list of [promotions],
52-
return a list of expressions which include the
53-
[Promotion] expression as appropiate *)
43+
-> (Promotion.t list, function_mismatch) result
5444

5545
val unique_minimum_promotion :
56-
('a * promotions list) list -> ('a * promotions list, 'a list option) result
46+
('a * Promotion.t list) list -> ('a * Promotion.t list, 'a list option) result
5747

5848
val matching_function :
5949
Environment.t
@@ -71,7 +61,7 @@ val check_variadic_args :
7161
-> (UnsizedType.autodifftype * UnsizedType.t) list
7262
-> UnsizedType.t
7363
-> (UnsizedType.autodifftype * UnsizedType.t) list
74-
-> ( UnsizedType.t * promotions list
64+
-> ( UnsizedType.t * Promotion.t list
7565
, (UnsizedType.autodifftype * UnsizedType.t) list * function_mismatch )
7666
result
7767
(** Check variadic function arguments.

0 commit comments

Comments
 (0)