11open Core_kernel
22open Middle
3- open Frontend
4- open Ast
53
64let rec transpose = function
75 | [] :: _ -> []
@@ -30,15 +28,16 @@ let rec vect_to_mat l m =
3028 let hd, tl = List. split_n l m in
3129 hd :: vect_to_mat tl m
3230
33- let unwrap_num_exn m e =
34- let e = Ast_to_Mir. trans_expr e in
35- let m = Map.Poly. map m ~f: Ast_to_Mir. trans_expr in
31+ let eval_expr m e =
3632 let e = Mir_utils. subst_expr m e in
3733 let e = Partial_evaluator. eval_expr e in
3834 let rec strip_promotions (e : Middle.Expr.Typed.t ) =
3935 match e.pattern with Promotion (e , _ , _ ) -> strip_promotions e | _ -> e
4036 in
41- let e = strip_promotions e in
37+ strip_promotions e
38+
39+ let unwrap_num_exn m e =
40+ let e = eval_expr m e in
4241 match e.pattern with
4342 | Lit (_ , s ) -> Float. of_string s
4443 | _ ->
@@ -74,92 +73,51 @@ let rec repeat n e =
7473let rec repeat_th n f =
7574 match n with n when n < = 0 -> [] | m -> f () :: repeat_th (m - 1 ) f
7675
77- let wrap_int n =
78- { expr= IntNumeral (Int. to_string n)
79- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UInt } }
80-
81- let int_two = wrap_int 2
82-
83- let wrap_real r =
84- { expr= RealNumeral (Float. to_string r)
85- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UReal } }
86-
87- let wrap_row_vector l =
88- { expr= RowVectorExpr l
89- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= URowVector } }
90-
91- let wrap_vector l =
92- { expr= PostfixOp (wrap_row_vector l, Transpose )
93- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UVector } }
76+ let gen_bounded m gen e =
77+ match Expr.Helpers. try_unpack (eval_expr m e) with
78+ | Some unpacked_e -> List. map ~f: gen unpacked_e
79+ | None ->
80+ Common.FatalError. fatal_error_msg
81+ [% message " Bad bounded (upper OR lower) expr: " (e : Expr.Typed.t )]
9482
95- let gen_int m t = wrap_int (gen_num_int m t)
96- let gen_real m t = wrap_real (gen_num_real m t)
83+ let gen_ul_bounded m gen e1 e2 =
84+ let create_bounds l u =
85+ List. map2_exn ~f: (fun x y -> gen (Transformation. LowerUpper (x, y))) l u
86+ in
87+ let e1, e2 = (eval_expr m e1, eval_expr m e2) in
88+ match Expr.Helpers. (try_unpack e1, try_unpack e2) with
89+ | Some unpacked_e1 , Some unpacked_e2 -> create_bounds unpacked_e1 unpacked_e2
90+ | None , Some unpacked_e2 ->
91+ create_bounds
92+ (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
93+ unpacked_e2
94+ | Some unpacked_e1 , None ->
95+ create_bounds unpacked_e1
96+ (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
97+ | _ ->
98+ Common.FatalError. fatal_error_msg
99+ [% message
100+ " Bad bounded upper and lower expr: "
101+ (e1 : Expr.Typed.t )
102+ " and "
103+ (e2 : Expr.Typed.t )]
97104
98105let gen_row_vector m n t =
99- let extract_var e =
100- match e with {expr = Variable x ; _} -> Map. find_exn m x.name | _ -> e in
101- let gen_bounded t e =
102- match e with
103- | {expr= RowVectorExpr unpacked_e; _}
104- | {expr= ArrayExpr unpacked_e; _}
105- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e; _}, Transpose ); _} ->
106- wrap_row_vector (List. map ~f: (fun x -> gen_real m (t x)) unpacked_e)
107- | _ ->
108- Common.FatalError. fatal_error_msg
109- [% message
110- " Bad bounded (upper OR lower) expr: "
111- (e : (typed_expr_meta, fun_kind ) expr_with)] in
112- let gen_ul_bounded e1 e2 =
113- let create_bounds l u =
114- wrap_row_vector
115- (List. map2_exn
116- ~f: (fun x y -> gen_real m (Transformation. LowerUpper (x, y)))
117- l u ) in
118- match (e1, e2) with
119- | ( ( {expr= RowVectorExpr unpacked_e1 | ArrayExpr unpacked_e1; _}
120- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose ); _}
121- )
122- , ( {expr= RowVectorExpr unpacked_e2 | ArrayExpr unpacked_e2; _}
123- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose ); _}
124- ) ) ->
125- (* | {expr= ArrayExpr unpacked_e1; _}, {expr= ArrayExpr unpacked_e2; _} -> *)
126- create_bounds unpacked_e1 unpacked_e2
127- | ( ({expr= RealNumeral _; _} | {expr= IntNumeral _; _})
128- , ( {expr= RowVectorExpr unpacked_e2; _}
129- | {expr= ArrayExpr unpacked_e2; _}
130- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose ); _}
131- ) ) ->
132- create_bounds
133- (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
134- unpacked_e2
135- | ( ( {expr= RowVectorExpr unpacked_e1; _}
136- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose ); _}
137- | {expr= ArrayExpr unpacked_e1; _} )
138- , ({expr= RealNumeral _; _} | {expr = IntNumeral _ ; _} ) ) ->
139- create_bounds unpacked_e1
140- (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
141- | _ ->
142- Common.FatalError. fatal_error_msg
143- [% message
144- " Bad bounded upper and lower expr: "
145- (e1 : (typed_expr_meta, fun_kind ) expr_with)
146- " and "
147- (e2 : (typed_expr_meta, fun_kind ) expr_with)] in
148- match t with
149- | Transformation. Lower ({emeta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
150- gen_bounded (fun x -> Transformation. Lower x) (extract_var e)
151- | Transformation. Upper ({emeta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
152- gen_bounded (fun x -> Transformation. Upper x) (extract_var e)
106+ match (t : Expr.Typed.t Transformation.t ) with
107+ | Transformation. Lower ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
108+ gen_bounded m (fun x -> gen_num_real m (Transformation. Lower x)) e
109+ |> Expr.Helpers. row_vector
110+ | Transformation. Upper ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
111+ gen_bounded m (fun x -> gen_num_real m (Transformation. Upper x)) e
112+ |> Expr.Helpers. row_vector
153113 | Transformation. LowerUpper
154- ( ({emeta = {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
155- , ({emeta = {type_= UVector | URowVector ; _}; _} as e2) )
114+ ( ({meta = {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
115+ , ({meta = {type_= UVector | URowVector ; _}; _} as e2) )
156116 | Transformation. LowerUpper
157- ( ({emeta= {type_= UVector | URowVector ; _}; _} as e1)
158- , ({emeta= {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
159- gen_ul_bounded (extract_var e1) (extract_var e2)
160- | _ ->
161- { expr= RowVectorExpr (repeat_th n (fun _ -> gen_real m t))
162- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UMatrix } }
117+ ( ({meta= {type_= UVector | URowVector ; _}; _} as e1)
118+ , ({meta= {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
119+ gen_ul_bounded m (gen_num_real m) e1 e2 |> Expr.Helpers. row_vector
120+ | _ -> Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t))
163121
164122let gen_vector m n t =
165123 let gen_ordered n =
@@ -173,38 +131,36 @@ let gen_vector m n t =
173131 let l = repeat_th n (fun _ -> Random. float 1. ) in
174132 let sum = List. fold l ~init: 0. ~f: (fun accum elt -> accum +. elt) in
175133 let l = List. map l ~f: (fun x -> x /. sum) in
176- wrap_vector ( List. map ~f: wrap_real l)
134+ Expr.Helpers. vector l
177135 | Ordered ->
178136 let l = gen_ordered n in
179137 let halfmax =
180138 Option. value_exn (List. max_elt l ~compare: compare_float) /. 2. in
181139 let l = List. map l ~f: (fun x -> (x -. halfmax) /. halfmax) in
182- wrap_vector ( List. map ~f: wrap_real l)
140+ Expr.Helpers. vector l
183141 | PositiveOrdered ->
184142 let l = gen_ordered n in
185143 let max = Option. value_exn (List. max_elt l ~compare: compare_float) in
186144 let l = List. map l ~f: (fun x -> x /. max) in
187- wrap_vector ( List. map ~f: wrap_real l)
145+ Expr.Helpers. vector l
188146 | UnitVector ->
189147 let l = repeat_th n (fun _ -> Random. float 1. ) in
190148 let sum =
191149 Float. sqrt
192150 (List. fold l ~init: 0. ~f: (fun accum elt -> accum +. (elt ** 2. )))
193151 in
194152 let l = List. map l ~f: (fun x -> x /. sum) in
195- wrap_vector (List. map ~f: wrap_real l)
196- | _ -> {int_two with expr= PostfixOp (gen_row_vector m n t, Transpose )}
153+ Expr.Helpers. vector l
154+ | _ ->
155+ let v = Expr.Helpers. unary_op Transpose (gen_row_vector m n t) in
156+ {v with meta= {v.meta with type_= UVector }}
197157
198158let gen_cov_unwrapped n =
199159 let l = repeat_th (n * n) (fun _ -> Random. float 2. ) in
200160 let l_mat = vect_to_mat l n in
201161 matprod l_mat (transpose l_mat)
202162
203- let wrap_real_mat m =
204- let mat_wrapped =
205- List. map ~f: wrap_row_vector
206- (List. map ~f: (fun x -> List. map ~f: wrap_real x) m) in
207- {int_two with expr= RowVectorExpr mat_wrapped}
163+ let wrap_real_mat m = Expr.Helpers. matrix m
208164
209165let gen_diag_mat l =
210166 let n = List. length l in
@@ -257,57 +213,70 @@ let gen_corr_matrix n =
257213 wrap_real_mat (matprod corr_chol (transpose corr_chol))
258214
259215let gen_matrix mm m n t =
260- match t with
261- | Transformation. Covariance -> gen_cov_matrix m
216+ match (t : Expr.Typed.t Transformation.t ) with
217+ | Covariance -> gen_cov_matrix m
262218 | Correlation -> gen_corr_matrix m
263219 | CholeskyCov -> gen_cov_cholesky m n
264220 | CholeskyCorr -> gen_corr_cholesky m
221+ | Lower ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
222+ Expr.Helpers. matrix_from_rows
223+ (gen_bounded mm (fun x -> gen_row_vector mm n (Lower x)) e)
224+ | Upper ({meta = {type_ = UMatrix ; _} ; _} as e ) ->
225+ Expr.Helpers. matrix_from_rows
226+ (gen_bounded mm (fun x -> gen_row_vector mm n (Upper x)) e)
227+ | LowerUpper (({meta= {type_= UMatrix ; _}; _} as e1), e2)
228+ | LowerUpper (e1 , ({meta = {type_ = UMatrix ; _} ; _} as e2 )) ->
229+ Expr.Helpers. matrix_from_rows
230+ (gen_ul_bounded mm (gen_row_vector mm n) e1 e2)
265231 | _ ->
266- { int_two with
267- expr= RowVectorExpr (repeat_th m (fun () -> gen_row_vector mm n t)) }
268-
269- (* TODO: do some proper random generation of these special matrices *)
270-
271- let gen_array elt n _ = {int_two with expr= ArrayExpr (repeat_th n elt)}
272-
273- let rec generate_value m st t =
232+ Expr.Helpers. matrix_from_rows
233+ (repeat_th m (fun () -> gen_row_vector mm n t))
234+
235+ let rec gen_array m st n t =
236+ let elt () = generate_value m st t in
237+ match (t : Expr.Typed.t Transformation.t ) with
238+ | Lower ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
239+ Expr.Helpers. array_expr
240+ (gen_bounded m (fun x -> generate_value m st (Lower x)) e)
241+ | Upper ({meta = {type_ = UArray _ ; _} ; _} as e ) ->
242+ Expr.Helpers. array_expr
243+ (gen_bounded m (fun x -> generate_value m st (Upper x)) e)
244+ | LowerUpper (({meta= {type_= UArray _; _}; _} as e1), e2)
245+ | LowerUpper (e1 , ({meta = {type_ = UArray _ ; _} ; _} as e2 )) ->
246+ Expr.Helpers. array_expr (gen_ul_bounded m (generate_value m st) e1 e2)
247+ | _ -> Expr.Helpers. array_expr (repeat_th n elt)
248+
249+ and generate_value m st t =
274250 match st with
275- | SizedType. SInt -> gen_int m t
276- | SReal -> gen_real m t
251+ | SizedType. SInt -> Expr.Helpers. int (gen_num_int m t)
252+ | SReal -> Expr.Helpers. float (gen_num_real m t)
277253 | SComplex ->
278254 (* when serialzied, a complex number looks just like a 2-array of reals *)
279- generate_value m (SArray (SReal , wrap_int 2 )) t
255+ generate_value m (SArray (SReal , Expr.Helpers. int 2 )) t
280256 | SVector (_ , e ) -> gen_vector m (unwrap_int_exn m e) t
281257 | SRowVector (_ , e ) -> gen_row_vector m (unwrap_int_exn m e) t
282258 | SMatrix (_ , e1 , e2 ) ->
283259 gen_matrix m (unwrap_int_exn m e1) (unwrap_int_exn m e2) t
284- | SArray (st , e ) ->
285- let element () = generate_value m st t in
286- gen_array element (unwrap_int_exn m e) t
260+ | SArray (st , e ) -> gen_array m st (unwrap_int_exn m e) t
287261
288262let rec pp_value_json ppf e =
289- match e.expr with
290- | PostfixOp (e , Transpose) -> pp_value_json ppf e
291- | IntNumeral s | RealNumeral s -> Fmt. string ppf s
292- | ArrayExpr l | RowVectorExpr l ->
263+ match e.Expr.Fixed. pattern with
264+ | Lit ((Int | Real ), s ) -> Fmt. string ppf s
265+ | FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray ), l ) ->
293266 Fmt. (pf ppf " [@[<hov 1>%a@]]" (list ~sep: comma pp_value_json) l)
267+ | FunApp (StanLib (transpose, _, _), [e])
268+ when String. equal transpose (Operator. to_string Transpose ) ->
269+ pp_value_json ppf e
294270 | _ ->
295271 Common.FatalError. fatal_error_msg
296- [% message " Could not evaluate expression " (e : typed_expression )]
272+ [% message " Could not evaluate expression " (e : Expr.Typed.t )]
297273
298274let print_data_prog s =
299- let data = Ast. get_stmts s.datablock in
300275 let l, _ =
301- List. fold data ~init: ([] , Map.Poly. empty) ~f: (fun (l , m ) decl ->
302- match decl.stmt with
303- | VarDecl
304- { decl_type= Sized sizedtype
305- ; transformation
306- ; identifier= {name; _}
307- ; _ } ->
308- let value = generate_value m sizedtype transformation in
309- ((name, value) :: l, Map. set m ~key: name ~data: value)
310- | _ -> (l, m) ) in
276+ List. fold s ~init: ([] , Map.Poly. empty)
277+ ~f: (fun (l , m ) (sizedtype , transformation , name ) ->
278+ let value = generate_value m sizedtype transformation in
279+ ((name, value) :: l, Map. set m ~key: name ~data: value) ) in
311280 let pp ppf (id , value ) =
312281 Fmt. pf ppf {|@ [< hov 2 > " %s" :@ % a@ ]| } id pp_value_json value in
313282 Fmt. (str " {@ @[<hov>%a@]@ }" (list ~sep: comma pp) (List. rev l))
0 commit comments