Skip to content

Commit afb355e

Browse files
authored
Merge pull request #1128 from nhuurre/datagen-mir
Debug_data_generation fixes
2 parents 05ac92f + cebdc5a commit afb355e

16 files changed

Lines changed: 226 additions & 196 deletions
Lines changed: 98 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
open Core_kernel
22
open Middle
3-
open Frontend
4-
open Ast
53

64
let rec transpose = function
75
| [] :: _ -> []
@@ -30,15 +28,16 @@ let rec vect_to_mat l m =
3028
let hd, tl = List.split_n l m in
3129
hd :: vect_to_mat tl m
3230

33-
let unwrap_num_exn m e =
34-
let e = Ast_to_Mir.trans_expr e in
35-
let m = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
31+
let eval_expr m e =
3632
let e = Mir_utils.subst_expr m e in
3733
let e = Partial_evaluator.eval_expr e in
3834
let rec strip_promotions (e : Middle.Expr.Typed.t) =
3935
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
4036
in
41-
let e = strip_promotions e in
37+
strip_promotions e
38+
39+
let unwrap_num_exn m e =
40+
let e = eval_expr m e in
4241
match e.pattern with
4342
| Lit (_, s) -> Float.of_string s
4443
| _ ->
@@ -74,92 +73,51 @@ let rec repeat n e =
7473
let rec repeat_th n f =
7574
match n with n when n <= 0 -> [] | m -> f () :: repeat_th (m - 1) f
7675

77-
let wrap_int n =
78-
{ expr= IntNumeral (Int.to_string n)
79-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UInt} }
80-
81-
let int_two = wrap_int 2
82-
83-
let wrap_real r =
84-
{ expr= RealNumeral (Float.to_string r)
85-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UReal} }
86-
87-
let wrap_row_vector l =
88-
{ expr= RowVectorExpr l
89-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= URowVector} }
90-
91-
let wrap_vector l =
92-
{ expr= PostfixOp (wrap_row_vector l, Transpose)
93-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UVector} }
76+
let gen_bounded m gen e =
77+
match Expr.Helpers.try_unpack (eval_expr m e) with
78+
| Some unpacked_e -> List.map ~f:gen unpacked_e
79+
| None ->
80+
Common.FatalError.fatal_error_msg
81+
[%message "Bad bounded (upper OR lower) expr: " (e : Expr.Typed.t)]
9482

95-
let gen_int m t = wrap_int (gen_num_int m t)
96-
let gen_real m t = wrap_real (gen_num_real m t)
83+
let gen_ul_bounded m gen e1 e2 =
84+
let create_bounds l u =
85+
List.map2_exn ~f:(fun x y -> gen (Transformation.LowerUpper (x, y))) l u
86+
in
87+
let e1, e2 = (eval_expr m e1, eval_expr m e2) in
88+
match Expr.Helpers.(try_unpack e1, try_unpack e2) with
89+
| Some unpacked_e1, Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
90+
| None, Some unpacked_e2 ->
91+
create_bounds
92+
(List.init (List.length unpacked_e2) ~f:(fun _ -> e1))
93+
unpacked_e2
94+
| Some unpacked_e1, None ->
95+
create_bounds unpacked_e1
96+
(List.init (List.length unpacked_e1) ~f:(fun _ -> e2))
97+
| _ ->
98+
Common.FatalError.fatal_error_msg
99+
[%message
100+
"Bad bounded upper and lower expr: "
101+
(e1 : Expr.Typed.t)
102+
" and "
103+
(e2 : Expr.Typed.t)]
97104

98105
let gen_row_vector m n t =
99-
let extract_var e =
100-
match e with {expr= Variable x; _} -> Map.find_exn m x.name | _ -> e in
101-
let gen_bounded t e =
102-
match e with
103-
| {expr= RowVectorExpr unpacked_e; _}
104-
|{expr= ArrayExpr unpacked_e; _}
105-
|{expr= PostfixOp ({expr= RowVectorExpr unpacked_e; _}, Transpose); _} ->
106-
wrap_row_vector (List.map ~f:(fun x -> gen_real m (t x)) unpacked_e)
107-
| _ ->
108-
Common.FatalError.fatal_error_msg
109-
[%message
110-
"Bad bounded (upper OR lower) expr: "
111-
(e : (typed_expr_meta, fun_kind) expr_with)] in
112-
let gen_ul_bounded e1 e2 =
113-
let create_bounds l u =
114-
wrap_row_vector
115-
(List.map2_exn
116-
~f:(fun x y -> gen_real m (Transformation.LowerUpper (x, y)))
117-
l u ) in
118-
match (e1, e2) with
119-
| ( ( {expr= RowVectorExpr unpacked_e1 | ArrayExpr unpacked_e1; _}
120-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose); _}
121-
)
122-
, ( {expr= RowVectorExpr unpacked_e2 | ArrayExpr unpacked_e2; _}
123-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose); _}
124-
) ) ->
125-
(* | {expr= ArrayExpr unpacked_e1; _}, {expr= ArrayExpr unpacked_e2; _} -> *)
126-
create_bounds unpacked_e1 unpacked_e2
127-
| ( ({expr= RealNumeral _; _} | {expr= IntNumeral _; _})
128-
, ( {expr= RowVectorExpr unpacked_e2; _}
129-
| {expr= ArrayExpr unpacked_e2; _}
130-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose); _}
131-
) ) ->
132-
create_bounds
133-
(List.init (List.length unpacked_e2) ~f:(fun _ -> e1))
134-
unpacked_e2
135-
| ( ( {expr= RowVectorExpr unpacked_e1; _}
136-
| {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose); _}
137-
| {expr= ArrayExpr unpacked_e1; _} )
138-
, ({expr= RealNumeral _; _} | {expr= IntNumeral _; _}) ) ->
139-
create_bounds unpacked_e1
140-
(List.init (List.length unpacked_e1) ~f:(fun _ -> e2))
141-
| _ ->
142-
Common.FatalError.fatal_error_msg
143-
[%message
144-
"Bad bounded upper and lower expr: "
145-
(e1 : (typed_expr_meta, fun_kind) expr_with)
146-
" and "
147-
(e2 : (typed_expr_meta, fun_kind) expr_with)] in
148-
match t with
149-
| Transformation.Lower ({emeta= {type_= UVector | URowVector; _}; _} as e) ->
150-
gen_bounded (fun x -> Transformation.Lower x) (extract_var e)
151-
| Transformation.Upper ({emeta= {type_= UVector | URowVector; _}; _} as e) ->
152-
gen_bounded (fun x -> Transformation.Upper x) (extract_var e)
106+
match (t : Expr.Typed.t Transformation.t) with
107+
| Transformation.Lower ({meta= {type_= UVector | URowVector; _}; _} as e) ->
108+
gen_bounded m (fun x -> gen_num_real m (Transformation.Lower x)) e
109+
|> Expr.Helpers.row_vector
110+
| Transformation.Upper ({meta= {type_= UVector | URowVector; _}; _} as e) ->
111+
gen_bounded m (fun x -> gen_num_real m (Transformation.Upper x)) e
112+
|> Expr.Helpers.row_vector
153113
| Transformation.LowerUpper
154-
( ({emeta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
155-
, ({emeta= {type_= UVector | URowVector; _}; _} as e2) )
114+
( ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
115+
, ({meta= {type_= UVector | URowVector; _}; _} as e2) )
156116
|Transformation.LowerUpper
157-
( ({emeta= {type_= UVector | URowVector; _}; _} as e1)
158-
, ({emeta= {type_= UReal | UInt; _}; _} as e2) ) ->
159-
gen_ul_bounded (extract_var e1) (extract_var e2)
160-
| _ ->
161-
{ expr= RowVectorExpr (repeat_th n (fun _ -> gen_real m t))
162-
; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UMatrix} }
117+
( ({meta= {type_= UVector | URowVector; _}; _} as e1)
118+
, ({meta= {type_= UReal | UInt; _}; _} as e2) ) ->
119+
gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers.row_vector
120+
| _ -> Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t))
163121

164122
let gen_vector m n t =
165123
let gen_ordered n =
@@ -173,38 +131,36 @@ let gen_vector m n t =
173131
let l = repeat_th n (fun _ -> Random.float 1.) in
174132
let sum = List.fold l ~init:0. ~f:(fun accum elt -> accum +. elt) in
175133
let l = List.map l ~f:(fun x -> x /. sum) in
176-
wrap_vector (List.map ~f:wrap_real l)
134+
Expr.Helpers.vector l
177135
| Ordered ->
178136
let l = gen_ordered n in
179137
let halfmax =
180138
Option.value_exn (List.max_elt l ~compare:compare_float) /. 2. in
181139
let l = List.map l ~f:(fun x -> (x -. halfmax) /. halfmax) in
182-
wrap_vector (List.map ~f:wrap_real l)
140+
Expr.Helpers.vector l
183141
| PositiveOrdered ->
184142
let l = gen_ordered n in
185143
let max = Option.value_exn (List.max_elt l ~compare:compare_float) in
186144
let l = List.map l ~f:(fun x -> x /. max) in
187-
wrap_vector (List.map ~f:wrap_real l)
145+
Expr.Helpers.vector l
188146
| UnitVector ->
189147
let l = repeat_th n (fun _ -> Random.float 1.) in
190148
let sum =
191149
Float.sqrt
192150
(List.fold l ~init:0. ~f:(fun accum elt -> accum +. (elt ** 2.)))
193151
in
194152
let l = List.map l ~f:(fun x -> x /. sum) in
195-
wrap_vector (List.map ~f:wrap_real l)
196-
| _ -> {int_two with expr= PostfixOp (gen_row_vector m n t, Transpose)}
153+
Expr.Helpers.vector l
154+
| _ ->
155+
let v = Expr.Helpers.unary_op Transpose (gen_row_vector m n t) in
156+
{v with meta= {v.meta with type_= UVector}}
197157

198158
let gen_cov_unwrapped n =
199159
let l = repeat_th (n * n) (fun _ -> Random.float 2.) in
200160
let l_mat = vect_to_mat l n in
201161
matprod l_mat (transpose l_mat)
202162

203-
let wrap_real_mat m =
204-
let mat_wrapped =
205-
List.map ~f:wrap_row_vector
206-
(List.map ~f:(fun x -> List.map ~f:wrap_real x) m) in
207-
{int_two with expr= RowVectorExpr mat_wrapped}
163+
let wrap_real_mat m = Expr.Helpers.matrix m
208164

209165
let gen_diag_mat l =
210166
let n = List.length l in
@@ -257,57 +213,70 @@ let gen_corr_matrix n =
257213
wrap_real_mat (matprod corr_chol (transpose corr_chol))
258214

259215
let gen_matrix mm m n t =
260-
match t with
261-
| Transformation.Covariance -> gen_cov_matrix m
216+
match (t : Expr.Typed.t Transformation.t) with
217+
| Covariance -> gen_cov_matrix m
262218
| Correlation -> gen_corr_matrix m
263219
| CholeskyCov -> gen_cov_cholesky m n
264220
| CholeskyCorr -> gen_corr_cholesky m
221+
| Lower ({meta= {type_= UMatrix; _}; _} as e) ->
222+
Expr.Helpers.matrix_from_rows
223+
(gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
224+
| Upper ({meta= {type_= UMatrix; _}; _} as e) ->
225+
Expr.Helpers.matrix_from_rows
226+
(gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
227+
| LowerUpper (({meta= {type_= UMatrix; _}; _} as e1), e2)
228+
|LowerUpper (e1, ({meta= {type_= UMatrix; _}; _} as e2)) ->
229+
Expr.Helpers.matrix_from_rows
230+
(gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
265231
| _ ->
266-
{ int_two with
267-
expr= RowVectorExpr (repeat_th m (fun () -> gen_row_vector mm n t)) }
268-
269-
(* TODO: do some proper random generation of these special matrices *)
270-
271-
let gen_array elt n _ = {int_two with expr= ArrayExpr (repeat_th n elt)}
272-
273-
let rec generate_value m st t =
232+
Expr.Helpers.matrix_from_rows
233+
(repeat_th m (fun () -> gen_row_vector mm n t))
234+
235+
let rec gen_array m st n t =
236+
let elt () = generate_value m st t in
237+
match (t : Expr.Typed.t Transformation.t) with
238+
| Lower ({meta= {type_= UArray _; _}; _} as e) ->
239+
Expr.Helpers.array_expr
240+
(gen_bounded m (fun x -> generate_value m st (Lower x)) e)
241+
| Upper ({meta= {type_= UArray _; _}; _} as e) ->
242+
Expr.Helpers.array_expr
243+
(gen_bounded m (fun x -> generate_value m st (Upper x)) e)
244+
| LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
245+
|LowerUpper (e1, ({meta= {type_= UArray _; _}; _} as e2)) ->
246+
Expr.Helpers.array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
247+
| _ -> Expr.Helpers.array_expr (repeat_th n elt)
248+
249+
and generate_value m st t =
274250
match st with
275-
| SizedType.SInt -> gen_int m t
276-
| SReal -> gen_real m t
251+
| SizedType.SInt -> Expr.Helpers.int (gen_num_int m t)
252+
| SReal -> Expr.Helpers.float (gen_num_real m t)
277253
| SComplex ->
278254
(* when serialzied, a complex number looks just like a 2-array of reals *)
279-
generate_value m (SArray (SReal, wrap_int 2)) t
255+
generate_value m (SArray (SReal, Expr.Helpers.int 2)) t
280256
| SVector (_, e) -> gen_vector m (unwrap_int_exn m e) t
281257
| SRowVector (_, e) -> gen_row_vector m (unwrap_int_exn m e) t
282258
| SMatrix (_, e1, e2) ->
283259
gen_matrix m (unwrap_int_exn m e1) (unwrap_int_exn m e2) t
284-
| SArray (st, e) ->
285-
let element () = generate_value m st t in
286-
gen_array element (unwrap_int_exn m e) t
260+
| SArray (st, e) -> gen_array m st (unwrap_int_exn m e) t
287261

288262
let rec pp_value_json ppf e =
289-
match e.expr with
290-
| PostfixOp (e, Transpose) -> pp_value_json ppf e
291-
| IntNumeral s | RealNumeral s -> Fmt.string ppf s
292-
| ArrayExpr l | RowVectorExpr l ->
263+
match e.Expr.Fixed.pattern with
264+
| Lit ((Int | Real), s) -> Fmt.string ppf s
265+
| FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray), l) ->
293266
Fmt.(pf ppf "[@[<hov 1>%a@]]" (list ~sep:comma pp_value_json) l)
267+
| FunApp (StanLib (transpose, _, _), [e])
268+
when String.equal transpose (Operator.to_string Transpose) ->
269+
pp_value_json ppf e
294270
| _ ->
295271
Common.FatalError.fatal_error_msg
296-
[%message "Could not evaluate expression " (e : typed_expression)]
272+
[%message "Could not evaluate expression " (e : Expr.Typed.t)]
297273

298274
let print_data_prog s =
299-
let data = Ast.get_stmts s.datablock in
300275
let l, _ =
301-
List.fold data ~init:([], Map.Poly.empty) ~f:(fun (l, m) decl ->
302-
match decl.stmt with
303-
| VarDecl
304-
{ decl_type= Sized sizedtype
305-
; transformation
306-
; identifier= {name; _}
307-
; _ } ->
308-
let value = generate_value m sizedtype transformation in
309-
((name, value) :: l, Map.set m ~key:name ~data:value)
310-
| _ -> (l, m) ) in
276+
List.fold s ~init:([], Map.Poly.empty)
277+
~f:(fun (l, m) (sizedtype, transformation, name) ->
278+
let value = generate_value m sizedtype transformation in
279+
((name, value) :: l, Map.set m ~key:name ~data:value) ) in
311280
let pp ppf (id, value) =
312281
Fmt.pf ppf {|@[<hov 2>"%s":@ %a@]|} id pp_value_json value in
313282
Fmt.(str "{@ @[<hov>%a@]@ }" (list ~sep:comma pp) (List.rev l))
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
val print_data_prog : Frontend.Ast.typed_program -> string
1+
open Middle
2+
3+
val print_data_prog :
4+
(Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list
5+
-> string

src/analysis_and_optimization/Dependence_analysis.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,7 @@ let mir_uninitialized_variables (mir : Program.Typed.t) :
203203
(Set.Poly.union arg_vars globals)
204204
fdbody ) ) ) ]
205205

206-
let build_dep_info_map (mir : Program.Typed.t)
207-
(stmt : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t) :
206+
let build_dep_info_map (mir : Program.Typed.t) (stmt : Stmt.Located.t) :
208207
( label
209208
, (Expr.Typed.t, label) Stmt.Fixed.Pattern.t * node_dep_info )
210209
Map.Poly.t =

src/analysis_and_optimization/Dependence_analysis.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ val node_vars_dependencies :
6868

6969
val build_dep_info_map :
7070
Program.Typed.t
71-
-> (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t
71+
-> Stmt.Located.t
7272
-> ( label
7373
, (Expr.Typed.t, label) Stmt.Fixed.Pattern.t * node_dep_info )
7474
Map.Poly.t

0 commit comments

Comments
 (0)