Skip to content

Commit 07d3bba

Browse files
committed
fix debug data ganeration for constrained arrays and matrices
1 parent 6321198 commit 07d3bba

3 files changed

Lines changed: 74 additions & 49 deletions

File tree

src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -73,56 +73,53 @@ let rec repeat n e =
7373
let rec repeat_th n f =
7474
match n with n when n <= 0 -> [] | m -> f () :: repeat_th (m - 1) f
7575

76+
let gen_bounded m gen e =
77+
match Expr.Helpers.try_unpack (eval_expr m e) with
78+
| Some unpacked_e -> List.map ~f:gen unpacked_e
79+
| None ->
80+
Common.FatalError.fatal_error_msg
81+
[%message
82+
"Bad bounded (upper OR lower) expr: "
83+
(e : Expr.Typed.Meta.t Expr.Fixed.t)]
84+
85+
let gen_ul_bounded m gen e1 e2 =
86+
let create_bounds l u =
87+
List.map2_exn ~f:(fun x y -> gen (Transformation.LowerUpper (x, y))) l u
88+
in
89+
let e1, e2 = (eval_expr m e1, eval_expr m e2) in
90+
match Expr.Helpers.(try_unpack e1, try_unpack e2) with
91+
| Some unpacked_e1, Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
92+
| None, Some unpacked_e2 ->
93+
create_bounds
94+
(List.init (List.length unpacked_e2) ~f:(fun _ -> e1))
95+
unpacked_e2
96+
| Some unpacked_e1, None ->
97+
create_bounds unpacked_e1
98+
(List.init (List.length unpacked_e1) ~f:(fun _ -> e2))
99+
| _ ->
100+
Common.FatalError.fatal_error_msg
101+
[%message
102+
"Bad bounded upper and lower expr: "
103+
(e1 : Expr.Typed.t)
104+
" and "
105+
(e2 : Expr.Typed.t)]
106+
76107
let gen_row_vector m n t =
77-
let gen_bounded t e =
78-
match Expr.Helpers.try_unpack e with
79-
| Some unpacked_e ->
80-
Expr.Helpers.row_vector
81-
(List.map ~f:(fun x -> gen_num_real m (t x)) unpacked_e)
82-
| None ->
83-
Common.FatalError.fatal_error_msg
84-
[%message
85-
"Bad bounded (upper OR lower) expr: "
86-
(e : Expr.Typed.Meta.t Expr.Fixed.t)] in
87-
let gen_ul_bounded e1 e2 =
88-
let create_bounds l u =
89-
Expr.Helpers.row_vector
90-
(List.map2_exn
91-
~f:(fun x y -> gen_num_real m (Transformation.LowerUpper (x, y)))
92-
l u ) in
93-
match Expr.Helpers.(try_unpack e1, try_unpack e2) with
94-
| Some unpacked_e1, Some unpacked_e2 ->
95-
create_bounds unpacked_e1 unpacked_e2
96-
| None, Some unpacked_e2 ->
97-
create_bounds
98-
(List.init (List.length unpacked_e2) ~f:(fun _ -> e1))
99-
unpacked_e2
100-
| Some unpacked_e1, None ->
101-
create_bounds unpacked_e1
102-
(List.init (List.length unpacked_e1) ~f:(fun _ -> e2))
103-
| _ ->
104-
Common.FatalError.fatal_error_msg
105-
[%message
106-
"Bad bounded upper and lower expr: "
107-
(e1 : Expr.Typed.t)
108-
" and "
109-
(e2 : Expr.Typed.t)] in
110108
match (t : Expr.Typed.t Transformation.t) with
111109
| Transformation.Lower ({meta= {type_= UVector | URowVector; _}; _} as e) ->
112-
gen_bounded (fun x -> Transformation.Lower x) (eval_expr m e)
110+
gen_bounded m (fun x -> gen_num_real m (Transformation.Lower x)) e
111+
|> Expr.Helpers.row_vector
113112
| Transformation.Upper ({meta= {type_= UVector | URowVector; _}; _} as e) ->
114-
gen_bounded (fun x -> Transformation.Upper x) (eval_expr m e)
113+
gen_bounded m (fun x -> gen_num_real m (Transformation.Upper x)) e
114+
|> Expr.Helpers.row_vector
115115
| Transformation.LowerUpper
116116
( ({meta= {type_= UVector | URowVector | UReal | UInt; _}; _} as e1)
117117
, ({meta= {type_= UVector | URowVector; _}; _} as e2) )
118118
|Transformation.LowerUpper
119119
( ({meta= {type_= UVector | URowVector; _}; _} as e1)
120120
, ({meta= {type_= UReal | UInt; _}; _} as e2) ) ->
121-
gen_ul_bounded (eval_expr m e1) (eval_expr m e2)
122-
| _ ->
123-
let e =
124-
Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t)) in
125-
{e with meta= {e.meta with type_= UMatrix}}
121+
gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers.row_vector
122+
| _ -> Expr.Helpers.row_vector (repeat_th n (fun _ -> gen_num_real m t))
126123

127124
let gen_vector m n t =
128125
let gen_ordered n =
@@ -218,18 +215,40 @@ let gen_corr_matrix n =
218215
wrap_real_mat (matprod corr_chol (transpose corr_chol))
219216

220217
let gen_matrix mm m n t =
221-
match t with
222-
| Transformation.Covariance -> gen_cov_matrix m
218+
match (t : Expr.Typed.t Transformation.t) with
219+
| Covariance -> gen_cov_matrix m
223220
| Correlation -> gen_corr_matrix m
224221
| CholeskyCov -> gen_cov_cholesky m n
225222
| CholeskyCorr -> gen_corr_cholesky m
223+
| Lower ({meta= {type_= UMatrix; _}; _} as e) ->
224+
Expr.Helpers.matrix_from_rows
225+
(gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
226+
| Upper ({meta= {type_= UMatrix; _}; _} as e) ->
227+
Expr.Helpers.matrix_from_rows
228+
(gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
229+
| LowerUpper (({meta= {type_= UMatrix; _}; _} as e1), e2)
230+
|LowerUpper (e1, ({meta= {type_= UMatrix; _}; _} as e2)) ->
231+
Expr.Helpers.matrix_from_rows
232+
(gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
226233
| _ ->
227234
Expr.Helpers.matrix_from_rows
228235
(repeat_th m (fun () -> gen_row_vector mm n t))
229236

230-
let gen_array elt n _ = Expr.Helpers.array_expr (repeat_th n elt)
237+
let rec gen_array m st n t =
238+
let elt () = generate_value m st t in
239+
match (t : Expr.Typed.t Transformation.t) with
240+
| Lower ({meta= {type_= UArray _; _}; _} as e) ->
241+
Expr.Helpers.array_expr
242+
(gen_bounded m (fun x -> generate_value m st (Lower x)) e)
243+
| Upper ({meta= {type_= UArray _; _}; _} as e) ->
244+
Expr.Helpers.array_expr
245+
(gen_bounded m (fun x -> generate_value m st (Upper x)) e)
246+
| LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
247+
|LowerUpper (e1, ({meta= {type_= UArray _; _}; _} as e2)) ->
248+
Expr.Helpers.array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
249+
| _ -> Expr.Helpers.array_expr (repeat_th n elt)
231250

232-
let rec generate_value m st t =
251+
and generate_value m st t =
233252
match st with
234253
| SizedType.SInt -> Expr.Helpers.int (gen_num_int m t)
235254
| SReal -> Expr.Helpers.float (gen_num_real m t)
@@ -240,9 +259,7 @@ let rec generate_value m st t =
240259
| SRowVector (_, e) -> gen_row_vector m (unwrap_int_exn m e) t
241260
| SMatrix (_, e1, e2) ->
242261
gen_matrix m (unwrap_int_exn m e1) (unwrap_int_exn m e2) t
243-
| SArray (st, e) ->
244-
let element () = generate_value m st t in
245-
gen_array element (unwrap_int_exn m e) t
262+
| SArray (st, e) -> gen_array m st (unwrap_int_exn m e) t
246263

247264
let rec pp_value_json ppf e =
248265
match e.Expr.Fixed.pattern with

src/middle/Expr.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ module Helpers = struct
210210
; pattern= FunApp (CompilerInternal FnMakeArray, l) }
211211

212212
let try_unpack e =
213-
(* FIXME: what about matrices? *)
214213
match e.Fixed.pattern with
215214
| FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray), l) -> Some l
216215
| FunApp

test/unit/Debug_data_generation_tests.ml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ let%expect_test "whole program data generation check" =
341341
row_vector<lower=[1,2,3,4,5],upper=[2,3,4,5,6]>[5] y_row_lu_given_bound;
342342
row_vector<lower=x_row_vect,upper=x_row_vect_up>[5] y_row_lu_vector_bound;
343343
row_vector<lower=[1,2,3,4,5],upper=x_row_vect_up>[5] y_row_lu_vector_bound_mixed;
344+
345+
matrix<upper=2.0>[2,2] upper_matrix;
346+
matrix<lower=upper_matrix, upper=5>[2,2] lower_upper_matrix;
344347
}
345348
|}
346349
in
@@ -426,7 +429,13 @@ let%expect_test "whole program data generation check" =
426429
\n 19.572734365483463, 12.348137644499662],\
427430
\n\"y_row_lu_vector_bound_mixed\":\
428431
\n [23.68117004588095, 21.859463377110153, 16.93731450392735,\
429-
\n 19.994339965018462, 16.065041740638126]\
432+
\n 19.994339965018462, 16.065041740638126],\
433+
\n\"upper_matrix\":\
434+
\n [[1.0341237330520263, 1.5089128377972285],\
435+
\n [-2.334114450641863, 1.2726909686498331]],\
436+
\n\"lower_upper_matrix\":\
437+
\n [[2.700035909998304, 3.3701108882979232],\
438+
\n [4.412857964089933, 1.5133491301453186]]\
430439
\n}" |}]
431440

432441
let%expect_test "whole program data generation check" =

0 commit comments

Comments
 (0)