@@ -73,56 +73,53 @@ let rec repeat n e =
7373let rec repeat_th n f =
7474 match n with n when n < = 0 -> [] | m -> f () :: repeat_th (m - 1 ) f
7575
76+ let gen_bounded m gen e =
77+ match Expr.Helpers. try_unpack (eval_expr m e) with
78+ | Some unpacked_e -> List. map ~f: gen unpacked_e
79+ | None ->
80+ Common.FatalError. fatal_error_msg
81+ [% message
82+ " Bad bounded (upper OR lower) expr: "
83+ (e : Expr.Typed.Meta.t Expr.Fixed.t )]
84+
85+ let gen_ul_bounded m gen e1 e2 =
86+ let create_bounds l u =
87+ List. map2_exn ~f: (fun x y -> gen (Transformation. LowerUpper (x, y))) l u
88+ in
89+ let e1, e2 = (eval_expr m e1, eval_expr m e2) in
90+ match Expr.Helpers. (try_unpack e1, try_unpack e2) with
91+ | Some unpacked_e1 , Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
92+ | None , Some unpacked_e2 ->
93+ create_bounds
94+ (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
95+ unpacked_e2
96+ | Some unpacked_e1 , None ->
97+ create_bounds unpacked_e1
98+ (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
99+ | _ ->
100+ Common.FatalError. fatal_error_msg
101+ [% message
102+ " Bad bounded upper and lower expr: "
103+ (e1 : Expr.Typed.t )
104+ " and "
105+ (e2 : Expr.Typed.t )]
106+
76107let gen_row_vector m n t =
77- let gen_bounded t e =
78- match Expr.Helpers. try_unpack e with
79- | Some unpacked_e ->
80- Expr.Helpers. row_vector
81- (List. map ~f: (fun x -> gen_num_real m (t x)) unpacked_e)
82- | None ->
83- Common.FatalError. fatal_error_msg
84- [% message
85- " Bad bounded (upper OR lower) expr: "
86- (e : Expr.Typed.Meta.t Expr.Fixed.t )] in
87- let gen_ul_bounded e1 e2 =
88- let create_bounds l u =
89- Expr.Helpers. row_vector
90- (List. map2_exn
91- ~f: (fun x y -> gen_num_real m (Transformation. LowerUpper (x, y)))
92- l u ) in
93- match Expr.Helpers. (try_unpack e1, try_unpack e2) with
94- | Some unpacked_e1 , Some unpacked_e2 ->
95- create_bounds unpacked_e1 unpacked_e2
96- | None , Some unpacked_e2 ->
97- create_bounds
98- (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
99- unpacked_e2
100- | Some unpacked_e1 , None ->
101- create_bounds unpacked_e1
102- (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
103- | _ ->
104- Common.FatalError. fatal_error_msg
105- [% message
106- " Bad bounded upper and lower expr: "
107- (e1 : Expr.Typed.t )
108- " and "
109- (e2 : Expr.Typed.t )] in
110108 match (t : Expr.Typed.t Transformation.t ) with
111109 | Transformation. Lower ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
112- gen_bounded (fun x -> Transformation. Lower x) (eval_expr m e)
110+ gen_bounded m (fun x -> gen_num_real m (Transformation. Lower x)) e
111+ |> Expr.Helpers. row_vector
113112 | Transformation. Upper ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
114- gen_bounded (fun x -> Transformation. Upper x) (eval_expr m e)
113+ gen_bounded m (fun x -> gen_num_real m (Transformation. Upper x)) e
114+ |> Expr.Helpers. row_vector
115115 | Transformation. LowerUpper
116116 ( ({meta= {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
117117 , ({meta= {type_= UVector | URowVector ; _}; _} as e2) )
118118 | Transformation. LowerUpper
119119 ( ({meta= {type_= UVector | URowVector ; _}; _} as e1)
120120 , ({meta= {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
121- gen_ul_bounded (eval_expr m e1) (eval_expr m e2)
122- | _ ->
123- let e =
124- Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t)) in
125- {e with meta= {e.meta with type_= UMatrix }}
121+ gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
122+ | _ -> Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t))
126123
127124let gen_vector m n t =
128125 let gen_ordered n =
@@ -218,18 +215,40 @@ let gen_corr_matrix n =
218215 wrap_real_mat (matprod corr_chol (transpose corr_chol))
219216
220217let gen_matrix mm m n t =
221- match t with
222- | Transformation. Covariance -> gen_cov_matrix m
218+ match (t : Expr.Typed.t Transformation.t ) with
219+ | Covariance -> gen_cov_matrix m
223220 | Correlation -> gen_corr_matrix m
224221 | CholeskyCov -> gen_cov_cholesky m n
225222 | CholeskyCorr -> gen_corr_cholesky m
223+ | Lower ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
224+ Expr.Helpers. matrix_from_rows
225+ (gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
226+ | Upper ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
227+ Expr.Helpers. matrix_from_rows
228+ (gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
229+ | LowerUpper (({meta= {type_= UMatrix ; _}; _} as e1), e2)
230+ | LowerUpper (e1 , ({meta = {type_ = UMatrix ; _} ; _} as e2 )) ->
231+ Expr.Helpers. matrix_from_rows
232+ (gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
226233 | _ ->
227234 Expr.Helpers. matrix_from_rows
228235 (repeat_th m (fun () -> gen_row_vector mm n t))
229236
230- let gen_array elt n _ = Expr.Helpers. array_expr (repeat_th n elt)
237+ let rec gen_array m st n t =
238+ let elt () = generate_value m st t in
239+ match (t : Expr.Typed.t Transformation.t ) with
240+ | Lower ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
241+ Expr.Helpers. array_expr
242+ (gen_bounded m (fun x -> generate_value m st (Lower x)) e)
243+ | Upper ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
244+ Expr.Helpers. array_expr
245+ (gen_bounded m (fun x -> generate_value m st (Upper x)) e)
246+ | LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
247+ | LowerUpper (e1 , ({meta = {type_ = UArray _ ; _} ; _} as e2 )) ->
248+ Expr.Helpers. array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
249+ | _ -> Expr.Helpers. array_expr (repeat_th n elt)
231250
232- let rec generate_value m st t =
251+ and generate_value m st t =
233252 match st with
234253 | SizedType. SInt -> Expr.Helpers. int (gen_num_int m t)
235254 | SReal -> Expr.Helpers. float (gen_num_real m t)
@@ -240,9 +259,7 @@ let rec generate_value m st t =
240259 | SRowVector (_ , e ) -> gen_row_vector m (unwrap_int_exn m e) t
241260 | SMatrix (_ , e1 , e2 ) ->
242261 gen_matrix m (unwrap_int_exn m e1) (unwrap_int_exn m e2) t
243- | SArray (st , e ) ->
244- let element () = generate_value m st t in
245- gen_array element (unwrap_int_exn m e) t
262+ | SArray (st , e ) -> gen_array m st (unwrap_int_exn m e) t
246263
247264let rec pp_value_json ppf e =
248265 match e.Expr.Fixed. pattern with
0 commit comments