Skip to content

Commit 6321198

Browse files
committed
migrate Debug_data_generation to MIR-only
1 parent 29dda9c commit 6321198

9 files changed

Lines changed: 145 additions & 111 deletions

File tree

src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 57 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
open Core_kernel
22
open Middle
3-
open Frontend
4-
open Ast
53

64
let rec transpose = function
75
| [] :: _ -> []
@@ -30,15 +28,16 @@ let rec vect_to_mat l m =
3028
let hd, tl = List.split_n l m in
3129
hd :: vect_to_mat tl m
3230

33-
let unwrap_num_exn m e =
34-
let e = Ast_to_Mir.trans_expr e in
35-
let m = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
31+
let eval_expr m e =
3632
let e = Mir_utils.subst_expr m e in
3733
let e = Partial_evaluator.eval_expr e in
3834
let rec strip_promotions (e : Middle.Expr.Typed.t) =
3935
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
4036
in
41-
let e = strip_promotions e in
37+
strip_promotions e
38+
39+
let unwrap_num_exn m e =
40+
let e = eval_expr m e in
4241
match e.pattern with
4342
| Lit (_, s) -> Float.of_string s
4443
| _ ->
@@ -74,92 +73,56 @@ let rec repeat n e =
7473
let rec repeat_th n f =
7574
match n with n when n <= 0 -> [] | m -> f () :: repeat_th (m - 1) f
7675

77-
let wrap_int n =
78-
{ expr= IntNumeral (Int.to_string n)
79-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UInt} }
80-
81-
let int_two = wrap_int 2
82-
83-
let wrap_real r =
84-
{ expr= RealNumeral (Float.to_string r)
85-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UReal} }
86-
87-
let wrap_row_vector l =
88-
{ expr= RowVectorExpr l
89-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= URowVector} }
90-
91-
let wrap_vector l =
92-
{ expr= PostfixOp (wrap_row_vector l, Transpose)
93-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UVector} }
94-
95-
let gen_int m t = wrap_int (gen_num_int m t)
96-
let gen_real m t = wrap_real (gen_num_real m t)
97-
9876
let gen_row_vector m n t =
99-
let extract_var e =
100-
match e with {expr= Variable x; _} -> Map.find_exn m x.name | _ -> e in
10177
let gen_bounded t e =
102-
match e with
103-
| {expr= RowVectorExpr unpacked_e; _}
104-
|{expr= ArrayExpr unpacked_e; _}
105-
|{expr= PostfixOp ({expr= RowVectorExpr unpacked_e; _}, Transpose); _} ->
106-
wrap_row_vector (List.map ~f:(fun x -> gen_real m (t x)) unpacked_e)
107-
| _ ->
78+
match Expr.Helpers.try_unpack e with
79+
| Some unpacked_e ->
80+
Expr.Helpers.row_vector
81+
(List.map ~f:(fun x -> gen_num_real m (t x)) unpacked_e)
82+
| None ->
10883
Common.FatalError.fatal_error_msg
10984
[%message
11085
"Bad bounded (upper OR lower) expr: "
111-
(e : (typed_expr_meta, fun_kind) expr_with)] in
86+
(e : Expr.Typed.Meta.t Expr.Fixed.t)] in
11287
let gen_ul_bounded e1 e2 =
11388
let create_bounds l u =
114-
wrap_row_vector
89+
Expr.Helpers.row_vector
11590
(List.map2_exn
116-
~f:(fun x y -> gen_real m (Transformation.LowerUpper (x, y)))
91+
~f:(fun x y -> gen_num_real m (Transformation.LowerUpper (x, y)))
11792
l u ) in
118-
match (e1, e2) with
119-
| ( ( {expr= RowVectorExpr unpacked_e1 | ArrayExpr unpacked_e1; _}
120-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose); _}
121-
)
122-
, ( {expr= RowVectorExpr unpacked_e2 | ArrayExpr unpacked_e2; _}
123-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose); _}
124-
) ) ->
125-
(* | {expr= ArrayExpr unpacked_e1; _}, {expr= ArrayExpr unpacked_e2; _} -> *)
93+
match Expr.Helpers.(try_unpack e1, try_unpack e2) with
94+
| Some unpacked_e1, Some unpacked_e2 ->
12695
create_bounds unpacked_e1 unpacked_e2
127-
| ( ({expr= RealNumeral _; _} | {expr= IntNumeral _; _})
128-
, ( {expr= RowVectorExpr unpacked_e2; _}
129-
| {expr= ArrayExpr unpacked_e2; _}
130-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose); _}
131-
) ) ->
96+
| None, Some unpacked_e2 ->
13297
create_bounds
13398
(List.init (List.length unpacked_e2) ~f:(fun _ -> e1))
13499
unpacked_e2
135-
| ( ( {expr= RowVectorExpr unpacked_e1; _}
136-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose); _}
137-
| {expr= ArrayExpr unpacked_e1; _} )
138-
, ({expr= RealNumeral _; _} | {expr= IntNumeral _; _}) ) ->
100+
| Some unpacked_e1, None ->
139101
create_bounds unpacked_e1
140102
(List.init (List.length unpacked_e1) ~f:(fun _ -> e2))
141103
| _ ->
142104
Common.FatalError.fatal_error_msg
143105
[%message
144106
"Bad bounded upper and lower expr: "
145-
(e1 : (typed_expr_meta, fun_kind) expr_with)
107+
(e1 : Expr.Typed.t)
146108
" and "
147-
(e2 : (typed_expr_meta, fun_kind) expr_with)] in
148-
match t with
149-
| Transformation.Lower ({emeta= {type_= UVector | URowVector; _}; _} as e) ->
150-
gen_bounded (fun x -> Transformation.Lower x) (extract_var e)
151-
| Transformation.Upper ({emeta= {type_= UVector | URowVector; _}; _} as e) ->
152-
gen_bounded (fun x -> Transformation.Upper x) (extract_var e)
109+
(e2 : Expr.Typed.t)] in
110+
match (t : Expr.Typed.t Transformation.t) with
111+
| Transformation.Lower ({meta= {type_= UVector | URowVector; _}; _} as e) ->
112+
gen_bounded (fun x -> Transformation.Lower x) (eval_expr m e)
113+
| Transformation.Upper ({meta= {type_= UVector | URowVector; _}; _} as e) ->
114+
gen_bounded (fun x -> Transformation.Upper x) (eval_expr m e)
153115
| Transformation.LowerUpper
154-
( ({emeta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
155-
, ({emeta= {type_= UVector | URowVector; _}; _} as e2) )
116+
( ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
117+
, ({meta= {type_= UVector | URowVector; _}; _} as e2) )
156118
|Transformation.LowerUpper
157-
( ({emeta= {type_= UVector | URowVector; _}; _} as e1)
158-
, ({emeta= {type_= UReal | UInt; _}; _} as e2) ) ->
159-
gen_ul_bounded (extract_var e1) (extract_var e2)
119+
( ({meta= {type_= UVector | URowVector; _}; _} as e1)
120+
, ({meta= {type_= UReal | UInt; _}; _} as e2) ) ->
121+
gen_ul_bounded (eval_expr m e1) (eval_expr m e2)
160122
| _ ->
161-
{ expr= RowVectorExpr (repeat_th n (fun _ -> gen_real m t))
162-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UMatrix} }
123+
let e =
124+
Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t)) in
125+
{e with meta= {e.meta with type_= UMatrix}}
163126

164127
let gen_vector m n t =
165128
let gen_ordered n =
@@ -173,38 +136,36 @@ let gen_vector m n t =
173136
let l = repeat_th n (fun _ -> Random.float 1.) in
174137
let sum = List.fold l ~init:0. ~f:(fun accum elt -> accum +. elt) in
175138
let l = List.map l ~f:(fun x -> x /. sum) in
176-
wrap_vector (List.map ~f:wrap_real l)
139+
Expr.Helpers.vector l
177140
| Ordered ->
178141
let l = gen_ordered n in
179142
let halfmax =
180143
Option.value_exn (List.max_elt l ~compare:compare_float) /. 2. in
181144
let l = List.map l ~f:(fun x -> (x -. halfmax) /. halfmax) in
182-
wrap_vector (List.map ~f:wrap_real l)
145+
Expr.Helpers.vector l
183146
| PositiveOrdered ->
184147
let l = gen_ordered n in
185148
let max = Option.value_exn (List.max_elt l ~compare:compare_float) in
186149
let l = List.map l ~f:(fun x -> x /. max) in
187-
wrap_vector (List.map ~f:wrap_real l)
150+
Expr.Helpers.vector l
188151
| UnitVector ->
189152
let l = repeat_th n (fun _ -> Random.float 1.) in
190153
let sum =
191154
Float.sqrt
192155
(List.fold l ~init:0. ~f:(fun accum elt -> accum +. (elt ** 2.)))
193156
in
194157
let l = List.map l ~f:(fun x -> x /. sum) in
195-
wrap_vector (List.map ~f:wrap_real l)
196-
| _ -> {int_two with expr= PostfixOp (gen_row_vector m n t, Transpose)}
158+
Expr.Helpers.vector l
159+
| _ ->
160+
let v = Expr.Helpers.unary_op Transpose (gen_row_vector m n t) in
161+
{v with meta= {v.meta with type_= UVector}}
197162

198163
let gen_cov_unwrapped n =
199164
let l = repeat_th (n * n) (fun _ -> Random.float 2.) in
200165
let l_mat = vect_to_mat l n in
201166
matprod l_mat (transpose l_mat)
202167

203-
let wrap_real_mat m =
204-
let mat_wrapped =
205-
List.map ~f:wrap_row_vector
206-
(List.map ~f:(fun x -> List.map ~f:wrap_real x) m) in
207-
{int_two with expr= RowVectorExpr mat_wrapped}
168+
let wrap_real_mat m = Expr.Helpers.matrix m
208169

209170
let gen_diag_mat l =
210171
let n = List.length l in
@@ -263,20 +224,18 @@ let gen_matrix mm m n t =
263224
| CholeskyCov -> gen_cov_cholesky m n
264225
| CholeskyCorr -> gen_corr_cholesky m
265226
| _ ->
266-
{ int_two with
267-
expr= RowVectorExpr (repeat_th m (fun () -> gen_row_vector mm n t)) }
268-
269-
(* TODO: do some proper random generation of these special matrices *)
227+
Expr.Helpers.matrix_from_rows
228+
(repeat_th m (fun () -> gen_row_vector mm n t))
270229

271-
let gen_array elt n _ = {int_two with expr= ArrayExpr (repeat_th n elt)}
230+
let gen_array elt n _ = Expr.Helpers.array_expr (repeat_th n elt)
272231

273232
let rec generate_value m st t =
274233
match st with
275-
| SizedType.SInt -> gen_int m t
276-
| SReal -> gen_real m t
234+
| SizedType.SInt -> Expr.Helpers.int (gen_num_int m t)
235+
| SReal -> Expr.Helpers.float (gen_num_real m t)
277236
| SComplex ->
278237
(* when serialzied, a complex number looks just like a 2-array of reals *)
279-
generate_value m (SArray (SReal, wrap_int 2)) t
238+
generate_value m (SArray (SReal, Expr.Helpers.int 2)) t
280239
| SVector (_, e) -> gen_vector m (unwrap_int_exn m e) t
281240
| SRowVector (_, e) -> gen_row_vector m (unwrap_int_exn m e) t
282241
| SMatrix (_, e1, e2) ->
@@ -286,28 +245,23 @@ let rec generate_value m st t =
286245
gen_array element (unwrap_int_exn m e) t
287246

288247
let rec pp_value_json ppf e =
289-
match e.expr with
290-
| PostfixOp (e, Transpose) -> pp_value_json ppf e
291-
| IntNumeral s | RealNumeral s -> Fmt.string ppf s
292-
| ArrayExpr l | RowVectorExpr l ->
248+
match e.Expr.Fixed.pattern with
249+
| Lit ((Int | Real), s) -> Fmt.string ppf s
250+
| FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray), l) ->
293251
Fmt.(pf ppf "[@[<hov 1>%a@]]" (list ~sep:comma pp_value_json) l)
252+
| FunApp (StanLib (transpose, _, _), [e])
253+
when String.equal transpose (Operator.to_string Transpose) ->
254+
pp_value_json ppf e
294255
| _ ->
295256
Common.FatalError.fatal_error_msg
296-
[%message "Could not evaluate expression " (e : typed_expression)]
257+
[%message "Could not evaluate expression " (e : Expr.Typed.t)]
297258

298259
let print_data_prog s =
299-
let data = Ast.get_stmts s.datablock in
300260
let l, _ =
301-
List.fold data ~init:([], Map.Poly.empty) ~f:(fun (l, m) decl ->
302-
match decl.stmt with
303-
| VarDecl
304-
{ decl_type= Sized sizedtype
305-
; transformation
306-
; identifier= {name; _}
307-
; _ } ->
308-
let value = generate_value m sizedtype transformation in
309-
((name, value) :: l, Map.set m ~key:name ~data:value)
310-
| _ -> (l, m) ) in
261+
List.fold s ~init:([], Map.Poly.empty)
262+
~f:(fun (l, m) (sizedtype, transformation, name) ->
263+
let value = generate_value m sizedtype transformation in
264+
((name, value) :: l, Map.set m ~key:name ~data:value) ) in
311265
let pp ppf (id, value) =
312266
Fmt.pf ppf {|@[<hov 2>"%s":@ %a@]|} id pp_value_json value in
313267
Fmt.(str "{@ @[<hov>%a@]@ }" (list ~sep:comma pp) (List.rev l))
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
val print_data_prog : Frontend.Ast.typed_program -> string
1+
open Middle
2+
3+
val print_data_prog :
4+
(Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list
5+
-> string

src/frontend/Ast_to_Mir.ml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,22 @@ let migrate_checks_to_end_of_block stmts =
679679
let checks, not_checks = List.partition_tf ~f:stmt_contains_check stmts in
680680
not_checks @ checks
681681

682+
let trans_data (p : Ast.typed_program) =
683+
let data = Ast.get_stmts p.datablock in
684+
List.filter_map data ~f:(function
685+
| { stmt=
686+
VarDecl
687+
{ decl_type= Sized sizedtype
688+
; transformation
689+
; identifier= {name; _}
690+
; _ }
691+
; _ } ->
692+
Some
693+
( SizedType.map trans_expr sizedtype
694+
, Transformation.map trans_expr transformation
695+
, name )
696+
| _ -> None )
697+
682698
let trans_prog filename (p : Ast.typed_program) : Program.Typed.t =
683699
let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = p in
684700
let map f list_op =

src/frontend/Ast_to_Mir.mli

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
(** Translate from the AST to the MIR *)
2+
open Middle
23

3-
val trans_prog : string -> Ast.typed_program -> Middle.Program.Typed.t
4-
val trans_expr : Ast.typed_expression -> Middle.Expr.Typed.t
4+
val trans_data :
5+
Ast.typed_program
6+
-> ( Expr.Typed.Meta.t Expr.Fixed.t SizedType.t
7+
* Expr.Typed.Meta.t Expr.Fixed.t Transformation.t
8+
* string )
9+
list
10+
11+
val trans_prog : string -> Ast.typed_program -> Program.Typed.t

src/middle/Expr.ml

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
open Core_kernel
2-
open Core_kernel.Poly
32
open Common
43
open Helpers
54

@@ -163,13 +162,18 @@ module Helpers = struct
163162
let int i = {Fixed.meta= Typed.Meta.empty; pattern= Lit (Int, string_of_int i)}
164163

165164
let float i =
166-
{Fixed.meta= Typed.Meta.empty; pattern= Lit (Real, string_of_float i)}
165+
{ Fixed.meta= {Typed.Meta.empty with type_= UReal}
166+
; pattern= Lit (Real, Float.to_string i) }
167167

168168
let str i = {Fixed.meta= Typed.Meta.empty; pattern= Lit (Str, i)}
169169
let variable v = {Fixed.meta= Typed.Meta.empty; pattern= Var v}
170170
let zero = int 0
171171
let one = int 1
172172

173+
let unary_op op e =
174+
{ Fixed.meta= Typed.Meta.empty
175+
; pattern= FunApp (StanLib (Operator.to_string op, FnPlain, AoS), [e]) }
176+
173177
let binop e1 op e2 =
174178
{ Fixed.meta= Typed.Meta.empty
175179
; pattern= FunApp (StanLib (Operator.to_string op, FnPlain, AoS), [e1; e2])
@@ -181,6 +185,41 @@ module Helpers = struct
181185
| head :: rest ->
182186
List.fold ~init:head ~f:(fun accum next -> binop accum op next) rest
183187

188+
let row_vector l =
189+
{ Fixed.meta= {Typed.Meta.empty with type_= URowVector}
190+
; pattern= FunApp (CompilerInternal FnMakeRowVec, List.map ~f:float l) }
191+
192+
let vector l =
193+
let v = unary_op Transpose (row_vector l) in
194+
{v with meta= {Typed.Meta.empty with type_= UVector}}
195+
196+
let matrix l =
197+
{ Fixed.meta= {Typed.Meta.empty with type_= UMatrix}
198+
; pattern= FunApp (CompilerInternal FnMakeRowVec, List.map ~f:row_vector l)
199+
}
200+
201+
let matrix_from_rows l =
202+
{ Fixed.meta= {Typed.Meta.empty with type_= UMatrix}
203+
; pattern= FunApp (CompilerInternal FnMakeRowVec, l) }
204+
205+
let array_expr l =
206+
let type_ =
207+
List.hd l |> Option.value_map ~f:Typed.type_of ~default:UnsizedType.UReal
208+
in
209+
{ Fixed.meta= {Typed.Meta.empty with type_= UArray type_}
210+
; pattern= FunApp (CompilerInternal FnMakeArray, l) }
211+
212+
let try_unpack e =
213+
(* FIXME: what about matrices? *)
214+
match e.Fixed.pattern with
215+
| FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray), l) -> Some l
216+
| FunApp
217+
( StanLib (transpose, FnPlain, _)
218+
, [{pattern= FunApp (CompilerInternal FnMakeRowVec, l); _}] )
219+
when String.equal transpose (Operator.to_string Transpose) ->
220+
Some l
221+
| _ -> None
222+
184223
let loop_bottom = one
185224

186225
let internal_funapp fn args meta =
@@ -197,7 +236,9 @@ module Helpers = struct
197236

198237
let%test "expr contains fn" =
199238
internal_funapp FnReadData [] ()
200-
|> contains_fn_kind (fun kind -> kind = CompilerInternal FnReadData)
239+
|> contains_fn_kind (function
240+
| CompilerInternal FnReadData -> true
241+
| _ -> false )
201242

202243
let rec infer_type_of_indexed ut indices =
203244
match (ut, indices) with

0 commit comments

Comments
 (0)