-
-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathDebug_data_generation.ml
More file actions
312 lines (277 loc) · 11.4 KB
/
Debug_data_generation.ml
File metadata and controls
312 lines (277 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
open Core_kernel
open Middle
open Ast
(** [transpose rows] swaps rows and columns of a list-of-lists matrix.

    Recursion stops as soon as the first row is exhausted, so both the empty
    matrix [[]] and a matrix whose rows are empty transpose to [[]].
    Raises (via [List.hd_exn] / [List.tl_exn]) if a later row is shorter
    than the first one. *)
let rec transpose = function
  (* [[]] must be a base case too: otherwise it maps to [[] :: transpose []]
     and recurses forever. *)
  | [] | [] :: _ -> []
  | rows ->
      let heads = List.map ~f:List.hd_exn rows in
      let tails = List.map ~f:List.tl_exn rows in
      heads :: transpose tails
(** [dotproduct xs ys] is the sum of the pairwise products of [xs] and [ys],
    accumulated left to right starting from [0.]. Raises [Invalid_argument]
    (from [List.fold2_exn]) when the lists differ in length. *)
let dotproduct xs ys =
  let mul_add total a b = total +. (a *. b) in
  List.fold2_exn xs ys ~init:0. ~f:mul_add
(** [matprod x y] multiplies two row-major float matrices.

    Every row of [x] must have exactly as many entries as [y] has rows (the
    contracted dimension). Otherwise a fatal "dim. mismatch" error is
    raised. *)
let matprod x y =
  let inner_dim = List.length y in
  (* Compare x's column count with y's row count. Comparing the row count of
     x with the column count of y is only right for square matrices. *)
  if not (List.for_all x ~f:(fun row -> List.length row = inner_dim)) then
    Common.FatalError.fatal_error_msg
      [%message "Matrix multiplication dim. mismatch"]
  else
    let y_T = transpose y in
    List.map ~f:(fun row -> List.map ~f:(dotproduct row) y_T) x
(** [vect_to_mat l m] reshapes [l] into consecutive rows of [m] elements
    each (row-major order).

    An empty [l] yields the empty matrix [[]]. Raises a fatal error when [m]
    is not positive or does not divide the length of [l]. *)
let vect_to_mat l m =
  let len = List.length l in
  if len = 0 then []
  else if m <= 0 || len % m <> 0 then
    (* Guard [m <= 0] first: [len % 0] would raise [Division_by_zero]. *)
    Common.FatalError.fatal_error_msg
      [%message "The length has to be a whole multiple of the partition size"]
  else
    (* The length is validated once, so the helper only peels off rows and
       stops on the empty remainder instead of looping on it. *)
    let rec chunk rest =
      match rest with
      | [] -> []
      | _ ->
          let row, remaining = List.split_n rest m in
          row :: chunk remaining in
    chunk l
(** [unwrap_num_exn m e] evaluates the expression [e] to a float.

    The steps are:
    - translate [e] to MIR;
    - substitute every variable bound in [m] (previously generated data
      values, also translated);
    - partially evaluate the result;
    - strip any type promotions wrapped around it.

    A fatal error is raised if what remains is not a literal. *)
let unwrap_num_exn m e =
  let rec strip_promotions (e : Middle.Expr.Typed.t) =
    match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
  in
  let translated = Ast_to_Mir.trans_expr e in
  let substitution = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
  let reduced =
    translated
    |> Analysis_and_optimization.Mir_utils.subst_expr substitution
    |> Analysis_and_optimization.Partial_evaluator.eval_expr
    |> strip_promotions in
  match reduced.pattern with
  | Lit (_, s) -> Float.of_string s
  | _ ->
      Common.FatalError.fatal_error_msg
        [%message "Cannot convert size to number."]
(** Like [unwrap_num_exn], then converted to an int with [Int.of_float]. *)
let unwrap_int_exn m e = unwrap_num_exn m e |> Int.of_float
(** [gen_num_int m t] draws a uniform random integer in [\[low, up\]]
    (inclusive), where the range comes from the transformation [t]:
    - [Lower e]: [\[e, e + 4\]];
    - [Upper e]: [\[e - 4, e\]];
    - [LowerUpper (e1, e2)]: [\[e1, e2\]];
    - anything else: [\[2, 6\]].

    A lower bound of 0 is raised to 1 when the range extends above 1
    (presumably to avoid zero-sized dimensions; TODO confirm). The ranges
    [\[0, 0\]] and [\[0, 1\]] are left alone. *)
let gen_num_int m t =
  let def_low, diff = (2, 4) in
  let low, up =
    match t with
    | Transformation.Lower e ->
        let lb = unwrap_int_exn m e in
        (lb, lb + diff)
    | Upper e ->
        let ub = unwrap_int_exn m e in
        (ub - diff, ub)
    | LowerUpper (e1, e2) -> (unwrap_int_exn m e1, unwrap_int_exn m e2)
    | _ -> (def_low, def_low + diff) in
  (* Only bump when there is room above the new lower bound. The old
     [up <> 1] test also bumped [\[0, 0\]] to the empty range [\[1, 0\]],
     which made [Random.int 0] raise. *)
  let low = if low = 0 && up > 1 then low + 1 else low in
  Random.int (up - low + 1) + low
(** [gen_num_real m t] draws a random float with [Random.float_range] over a
    range derived from the transformation [t]:
    - [Lower e]: from [e] to [e + 5];
    - [Upper e]: from [e - 5] to [e];
    - [LowerUpper (e1, e2)]: from [e1] to [e2];
    - anything else: from 2 to 7. *)
let gen_num_real m t =
  let default_low = 2. and width = 5. in
  let low, up =
    match t with
    | Transformation.Lower e ->
        let lb = unwrap_num_exn m e in
        (lb, lb +. width)
    | Upper e ->
        let ub = unwrap_num_exn m e in
        (ub -. width, ub)
    | LowerUpper (e1, e2) -> (unwrap_num_exn m e1, unwrap_num_exn m e2)
    | _ -> (default_low, default_low +. width) in
  Random.float_range low up
(** [repeat n e] is a list of [n] copies of [e]; empty when [n <= 0]. *)
let rec repeat n e = if n <= 0 then [] else e :: repeat (n - 1) e
(** [repeat_th n f] is a list of [n] results of calling [f ()] (one call per
    element); empty when [n <= 0]. Because [f] is typically a random
    generator, every element is drawn independently. OCaml leaves the
    evaluation order of the operands of [::] unspecified, so the order in
    which the calls happen is compiler-dependent. *)
let rec repeat_th n f = if n <= 0 then [] else f () :: repeat_th (n - 1) f
(** Integer-literal expression for [n] with an empty location, [DataOnly]
    ad-level and type [UInt]. *)
let wrap_int n =
  let literal = Int.to_string n in
  { emeta= {type_= UInt; ad_level= DataOnly; loc= Location_span.empty}
  ; expr= IntNumeral literal }
(* The integer literal 2. Elsewhere in this file it also serves as a template
   ([{int_two with expr= ...}]) to borrow an [emeta] with an empty location
   and [DataOnly] ad-level. Note that the borrowed meta keeps [type_= UInt]
   even when the new [expr] is an array or a matrix. *)
let int_two = wrap_int 2
(** Real-literal expression for [r] (rendered with [Float.to_string]) with an
    empty location, [DataOnly] ad-level and type [UReal]. *)
let wrap_real r =
  let literal = Float.to_string r in
  { emeta= {type_= UReal; ad_level= DataOnly; loc= Location_span.empty}
  ; expr= RealNumeral literal }
(** Row-vector literal expression whose elements are the expressions
    [elements], typed [URowVector]. *)
let wrap_row_vector elements =
  { emeta= {type_= URowVector; ad_level= DataOnly; loc= Location_span.empty}
  ; expr= RowVectorExpr elements }
(** Column-vector literal: the transpose of the row vector built from
    [elements], typed [UVector]. *)
let wrap_vector elements =
  let row = wrap_row_vector elements in
  { emeta= {type_= UVector; ad_level= DataOnly; loc= Location_span.empty}
  ; expr= PostfixOp (row, Transpose) }
(** Random integer-literal expression honouring the bounds in [t]. *)
let gen_int m t = gen_num_int m t |> wrap_int
(** Random real-literal expression honouring the bounds in [t]. *)
let gen_real m t = gen_num_real m t |> wrap_real
(** [gen_row_vector m n t] builds a random row-vector literal of length [n]
    satisfying the transformation [t].

    [m] maps previously generated data variables to their values; a bound
    that is a bare variable reference is replaced by that value.
    - [Lower]/[Upper] with a vector-typed bound: one draw per element of the
      bound, each bounded by that element.
    - [LowerUpper] where at least one side is vector-typed: element-wise
      bounds. A scalar int/real literal on one side is broadcast against the
      vector on the other side.
    - Any other [t]: [n] independent draws from [gen_real m t].

    A vector bound that is not a (transposed) row-vector or array literal is
    a fatal error. *)
let gen_row_vector m n t =
  (* Replace a bare data-variable reference by its generated value. *)
  let extract_var e =
    match e with {expr= Variable x; _} -> Map.find_exn m x.name | _ -> e in
  (* Elements of a row-vector, array, or transposed row-vector literal. *)
  let as_elements e =
    match e with
    | {expr= RowVectorExpr l | ArrayExpr l; _}
     |{expr= PostfixOp ({expr= RowVectorExpr l; _}, Transpose); _} ->
        Some l
    | _ -> None in
  let gen_bounded t e =
    match as_elements e with
    | Some l -> wrap_row_vector (List.map ~f:(fun x -> gen_real m (t x)) l)
    | None ->
        Common.FatalError.fatal_error_msg
          [%message
            "Bad bounded (upper OR lower) expr: "
              (e : (typed_expr_meta, fun_kind) expr_with)] in
  let gen_ul_bounded e1 e2 =
    let create_bounds l u =
      wrap_row_vector
        (List.map2_exn
           ~f:(fun x y -> gen_real m (Transformation.LowerUpper (x, y)))
           l u ) in
    match (e1, as_elements e1, e2, as_elements e2) with
    | _, Some lows, _, Some ups -> create_bounds lows ups
    | {expr= RealNumeral _ | IntNumeral _; _}, None, _, Some ups ->
        (* scalar lower bound broadcast across the vector upper bound *)
        create_bounds (List.map ups ~f:(fun _ -> e1)) ups
    | _, Some lows, {expr= RealNumeral _ | IntNumeral _; _}, None ->
        (* scalar upper bound broadcast across the vector lower bound *)
        create_bounds lows (List.map lows ~f:(fun _ -> e2))
    | _ ->
        Common.FatalError.fatal_error_msg
          [%message
            "Bad bounded upper and lower expr: "
              (e1 : (typed_expr_meta, fun_kind) expr_with)
              " and "
              (e2 : (typed_expr_meta, fun_kind) expr_with)] in
  match t with
  | Transformation.Lower ({emeta= {type_= UVector | URowVector; _}; _} as e) ->
      gen_bounded (fun x -> Transformation.Lower x) (extract_var e)
  | Transformation.Upper ({emeta= {type_= UVector | URowVector; _}; _} as e) ->
      gen_bounded (fun x -> Transformation.Upper x) (extract_var e)
  | Transformation.LowerUpper
      ( ({emeta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
      , ({emeta= {type_= UVector | URowVector; _}; _} as e2) )
   |Transformation.LowerUpper
      ( ({emeta= {type_= UVector | URowVector; _}; _} as e1)
      , ({emeta= {type_= UReal | UInt; _}; _} as e2) ) ->
      gen_ul_bounded (extract_var e1) (extract_var e2)
  | _ ->
      (* Use [wrap_row_vector] so the value is typed [URowVector]; it was
         previously tagged [UMatrix]. *)
      wrap_row_vector (repeat_th n (fun _ -> gen_real m t))
(** [gen_vector m n t] builds a random column-vector literal of length [n]
    satisfying the constraint [t].
    - [Simplex]: non-negative entries divided by their sum.
    - [Ordered]: strictly increasing entries, rescaled so the maximum is 1.
    - [PositiveOrdered]: strictly increasing non-negative entries divided by
      their maximum.
    - [UnitVector]: entries divided by their Euclidean norm.
    - Anything else: the transpose of [gen_row_vector m n t].

    [n = 0] yields an empty vector for every constraint. *)
let gen_vector m n t =
  (* [n] increasing floats: the first draw, then running sums of [exp] of the
     later draws (draws are >= 0, so each step adds at least 1). *)
  let gen_ordered n =
    match repeat_th n (fun _ -> Random.float 1.) with
    | [] -> []
    | first :: rest ->
        let _, rev_sums =
          List.fold rest ~init:(first, [first]) ~f:(fun (prev, acc) elt ->
              let next = Float.exp elt +. prev in
              (next, next :: acc) ) in
        (* The fold accumulates newest (largest) first; reverse so the
           result is ascending, as ordered vectors require. *)
        List.rev rev_sums in
  (* Apply [f top] to every element, where [top] is the maximum. The empty
     list stays empty instead of failing on a missing maximum. *)
  let rescale_by_max l ~f =
    match List.max_elt l ~compare:compare_float with
    | None -> []
    | Some top -> List.map l ~f:(f top) in
  let to_vector l = wrap_vector (List.map ~f:wrap_real l) in
  match t with
  | Transformation.Simplex ->
      let l = repeat_th n (fun _ -> Random.float 1.) in
      let sum = List.fold l ~init:0. ~f:(fun accum elt -> accum +. elt) in
      to_vector (List.map l ~f:(fun x -> x /. sum))
  | Ordered ->
      gen_ordered n
      |> rescale_by_max ~f:(fun top x ->
             let halfmax = top /. 2. in
             (x -. halfmax) /. halfmax )
      |> to_vector
  | PositiveOrdered ->
      gen_ordered n |> rescale_by_max ~f:(fun top x -> x /. top) |> to_vector
  | UnitVector ->
      let l = repeat_th n (fun _ -> Random.float 1.) in
      let norm =
        Float.sqrt
          (List.fold l ~init:0. ~f:(fun accum elt -> accum +. (elt ** 2.)))
      in
      to_vector (List.map l ~f:(fun x -> x /. norm))
  | _ ->
      (* Typed [UVector], like [wrap_vector], rather than borrowing the
         [UInt] meta of [int_two]. *)
      { expr= PostfixOp (gen_row_vector m n t, Transpose)
      ; emeta= {loc= Location_span.empty; ad_level= DataOnly; type_= UVector}
      }
(** [gen_cov_unwrapped n] is [A * A^T] for a random [n x n] matrix [A] whose
    entries come from [Random.float 2.]. The result is symmetric positive
    semi-definite by construction. *)
let gen_cov_unwrapped n =
  let factor = vect_to_mat (repeat_th (n * n) (fun _ -> Random.float 2.)) n in
  transpose factor |> matprod factor
(** Matrix literal from a list of float rows: each row becomes a row-vector
    literal of real literals. The outer meta is borrowed from [int_two]. *)
let wrap_real_mat m =
  let rows =
    List.map m ~f:(fun row -> wrap_row_vector (List.map row ~f:wrap_real))
  in
  {int_two with expr= RowVectorExpr rows}
(** [gen_diag_mat l] is the square diagonal matrix (list of rows) whose
    diagonal is [l] and whose off-diagonal entries are [0.]. *)
let gen_diag_mat l =
  let n = List.length l in
  List.mapi l ~f:(fun i d -> List.init n ~f:(fun j -> if j = i then d else 0.))
(** [fill_lower_triangular m] replaces the first [i] entries of each row [i]
    (0-based) with draws from [Random.float 2.]. The diagonal and everything
    to its right are kept. Applied to a diagonal matrix, this gives a random
    lower-triangular matrix. *)
let fill_lower_triangular m =
  List.mapi m ~f:(fun i row ->
      let kept = List.drop row i in
      List.init i ~f:(fun _ -> Random.float 2.) @ kept )
(** [pad_mat mm m n] appends [m - n] random rows (each [n] draws from
    [Random.float 2.]) below the rows of [mm] and wraps the result as a
    matrix literal. *)
let pad_mat mm m n =
  let random_row () = List.init n ~f:(fun _ -> Random.float 2.) in
  let extra_rows = List.init (m - n) ~f:(fun _ -> random_row ()) in
  wrap_real_mat (mm @ extra_rows)
(** Random [m x n] Cholesky-factor-of-covariance literal. It starts from an
    [n x n] lower-triangular matrix with a random diagonal; when [m > n],
    [m - n] random rows are appended below. *)
let gen_cov_cholesky m n =
  let lower =
    List.init ~f:(fun _ -> Random.float 2.) n
    |> gen_diag_mat
    |> fill_lower_triangular in
  if m > n then pad_mat lower m n else wrap_real_mat lower
(** Random [n x n] lower-triangular matrix (list of rows) whose rows are
    scaled to unit Euclidean norm. It is built from a random diagonal filled
    below with draws from [Random.float 2.]. *)
let gen_corr_cholesky_unwrapped n =
  let lower =
    List.init ~f:(fun _ -> Random.float 2.) n
    |> gen_diag_mat
    |> fill_lower_triangular in
  let normalize row =
    let norm =
      Float.sqrt (List.fold row ~init:0. ~f:(fun acc v -> acc +. (v *. v)))
    in
    List.map row ~f:(fun v -> v /. norm) in
  List.map lower ~f:normalize
(** Matrix literal for a random [n x n] Cholesky factor of a correlation
    matrix. *)
let gen_corr_cholesky n = gen_corr_cholesky_unwrapped n |> wrap_real_mat
(* let gen_identity_matrix m n =
let id_mat = gen_diag_mat (List.init ~f:(fun _ -> 1.) n) in
if m <= n then wrap_real_mat id_mat else pad_mat id_mat m n *)
(** Matrix literal for a random [n x n] covariance matrix. *)
let gen_cov_matrix n = wrap_real_mat (gen_cov_unwrapped n)
(** Matrix literal for a random [n x n] correlation matrix [L * L^T], where
    [L] comes from [gen_corr_cholesky_unwrapped]. *)
let gen_corr_matrix n =
  let factor = gen_corr_cholesky_unwrapped n in
  transpose factor |> matprod factor |> wrap_real_mat
(** [gen_matrix mm m n t] builds a random [m x n] matrix literal for the
    constraint [t].
    - Covariance and correlation constraints use size [m] only.
    - Cholesky-of-covariance uses [m] and [n]; Cholesky-of-correlation uses
      [m].
    - Anything else stacks [m] rows from [gen_row_vector mm n t].

    [mm] holds previously generated data values. *)
let gen_matrix mm m n t =
  let generic_rows () =
    { int_two with
      expr= RowVectorExpr (repeat_th m (fun () -> gen_row_vector mm n t)) }
  in
  match t with
  | Transformation.Covariance -> gen_cov_matrix m
  | Correlation -> gen_corr_matrix m
  | CholeskyCov -> gen_cov_cholesky m n
  | CholeskyCorr -> gen_corr_cholesky m
  | _ -> generic_rows ()
(* TODO: do some proper random generation of these special matrices *)
(** Array literal of [len] elements, each produced by a separate call to
    [gen_elt]; the transformation argument is ignored. *)
let gen_array gen_elt len _ =
  let elements = repeat_th len gen_elt in
  {int_two with expr= ArrayExpr elements}
(** [generate_value m st t] generates a random literal for sized type [st]
    under the transformation [t]. Size expressions are evaluated against the
    previously generated values in [m]. *)
let rec generate_value m st t =
  let size e = unwrap_int_exn m e in
  match st with
  | SizedType.SInt -> gen_int m t
  | SReal -> gen_real m t
  | SComplex ->
      (* when serialized, a complex number looks just like a 2-array of reals *)
      generate_value m (SArray (SReal, wrap_int 2)) t
  | SVector (_, len) -> gen_vector m (size len) t
  | SRowVector (_, len) -> gen_row_vector m (size len) t
  | SMatrix (_, rows, cols) -> gen_matrix m (size rows) (size cols) t
  | SArray (elt_st, len) ->
      gen_array (fun () -> generate_value m elt_st t) (size len) t
(** [pp_value_json ppf e] prints a generated value as JSON.
    - Int and real literals are printed as numbers.
    - Arrays and (row) vectors become JSON arrays; transposes are looked
      through.
    - Any other expression is a fatal error. *)
let rec pp_value_json ppf e =
  match e.expr with
  | PostfixOp (e, Transpose) -> pp_value_json ppf e
  | IntNumeral s -> Fmt.string ppf s
  | RealNumeral s ->
      (* [wrap_real] uses [Float.to_string], which renders integral floats
         (e.g. the zero entries of Cholesky factors) with a bare trailing
         dot like "0.". JSON requires a digit after the decimal point, so
         complete the literal. *)
      if String.is_suffix s ~suffix:"." then Fmt.pf ppf "%s0" s
      else Fmt.string ppf s
  | ArrayExpr l | RowVectorExpr l ->
      Fmt.(pf ppf "[@[<hov 1>%a@]]" (list ~sep:comma pp_value_json) l)
  | _ ->
      Common.FatalError.fatal_error_msg
        [%message "Could not evaluate expression " (e : typed_expression)]
(** [print_data_prog s] generates random values for every sized declaration
    in the data block of [s] and renders them as one JSON object, in
    declaration order. Sizes and bounds of a declaration may refer to values
    generated for earlier declarations. *)
let print_data_prog s =
  let pp_entry ppf (id, value) =
    Fmt.pf ppf {|@[<hov 2>"%s":@ %a@]|} id pp_value_json value in
  let rev_entries, _env =
    List.fold (Ast.get_stmts s.datablock) ~init:([], Map.Poly.empty)
      ~f:(fun (entries, env) decl ->
        match decl.stmt with
        | VarDecl
            { decl_type= Sized sizedtype
            ; transformation
            ; identifier= {name; _}
            ; _ } ->
            let value = generate_value env sizedtype transformation in
            ((name, value) :: entries, Map.set env ~key:name ~data:value)
        | _ -> (entries, env) ) in
  Fmt.(
    str "{@ @[<hov>%a@]@ }" (list ~sep:comma pp_entry) (List.rev rev_entries))