11open Core_kernel
22open Middle
3- open Frontend
4- open Ast
53
64let rec transpose = function
75 | [] :: _ -> []
@@ -30,15 +28,16 @@ let rec vect_to_mat l m =
3028 let hd, tl = List. split_n l m in
3129 hd :: vect_to_mat tl m
3230
33- let unwrap_num_exn m e =
34- let e = Ast_to_Mir. trans_expr e in
35- let m = Map.Poly. map m ~f: Ast_to_Mir. trans_expr in
31+ let eval_expr m e =
3632 let e = Mir_utils. subst_expr m e in
3733 let e = Partial_evaluator. eval_expr e in
3834 let rec strip_promotions (e : Middle.Expr.Typed.t ) =
3935 match e.pattern with Promotion (e , _ , _ ) -> strip_promotions e | _ -> e
4036 in
41- let e = strip_promotions e in
37+ strip_promotions e
38+
39+ let unwrap_num_exn m e =
40+ let e = eval_expr m e in
4241 match e.pattern with
4342 | Lit (_ , s ) -> Float. of_string s
4443 | _ ->
@@ -74,92 +73,56 @@ let rec repeat n e =
7473let rec repeat_th n f =
7574 match n with n when n < = 0 -> [] | m -> f () :: repeat_th (m - 1 ) f
7675
77- let wrap_int n =
78- { expr= IntNumeral (Int. to_string n)
79- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UInt } }
80-
81- let int_two = wrap_int 2
82-
83- let wrap_real r =
84- { expr= RealNumeral (Float. to_string r)
85- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UReal } }
86-
87- let wrap_row_vector l =
88- { expr= RowVectorExpr l
89- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= URowVector } }
90-
91- let wrap_vector l =
92- { expr= PostfixOp (wrap_row_vector l, Transpose )
93- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UVector } }
94-
95- let gen_int m t = wrap_int (gen_num_int m t)
96- let gen_real m t = wrap_real (gen_num_real m t)
97-
9876let gen_row_vector m n t =
99- let extract_var e =
100- match e with {expr = Variable x ; _} -> Map. find_exn m x.name | _ -> e in
10177 let gen_bounded t e =
102- match e with
103- | {expr= RowVectorExpr unpacked_e; _}
104- | {expr= ArrayExpr unpacked_e; _}
105- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e; _}, Transpose ); _} ->
106- wrap_row_vector (List. map ~f: (fun x -> gen_real m (t x)) unpacked_e)
107- | _ ->
78+ match Expr.Helpers. try_unpack e with
79+ | Some unpacked_e ->
80+ Expr.Helpers. row_vector
81+ (List. map ~f: (fun x -> gen_num_real m (t x)) unpacked_e)
82+ | None ->
10883 Common.FatalError. fatal_error_msg
10984 [% message
11085 " Bad bounded (upper OR lower) expr: "
111- (e : (typed_expr_meta, fun_kind ) expr_with )] in
86+ (e : Expr.Typed.Meta.t Expr.Fixed.t )] in
11287 let gen_ul_bounded e1 e2 =
11388 let create_bounds l u =
114- wrap_row_vector
89+ Expr.Helpers. row_vector
11590 (List. map2_exn
116- ~f: (fun x y -> gen_real m (Transformation. LowerUpper (x, y)))
91+ ~f: (fun x y -> gen_num_real m (Transformation. LowerUpper (x, y)))
11792 l u ) in
118- match (e1, e2) with
119- | ( ( {expr= RowVectorExpr unpacked_e1 | ArrayExpr unpacked_e1; _}
120- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose ); _}
121- )
122- , ( {expr= RowVectorExpr unpacked_e2 | ArrayExpr unpacked_e2; _}
123- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose ); _}
124- ) ) ->
125- (* | {expr= ArrayExpr unpacked_e1; _}, {expr= ArrayExpr unpacked_e2; _} -> *)
93+ match Expr.Helpers. (try_unpack e1, try_unpack e2) with
94+ | Some unpacked_e1 , Some unpacked_e2 ->
12695 create_bounds unpacked_e1 unpacked_e2
127- | ( ({expr= RealNumeral _; _} | {expr= IntNumeral _; _})
128- , ( {expr= RowVectorExpr unpacked_e2; _}
129- | {expr= ArrayExpr unpacked_e2; _}
130- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e2; _}, Transpose ); _}
131- ) ) ->
96+ | None , Some unpacked_e2 ->
13297 create_bounds
13398 (List. init (List. length unpacked_e2) ~f: (fun _ -> e1))
13499 unpacked_e2
135- | ( ( {expr= RowVectorExpr unpacked_e1; _}
136- | {expr= PostfixOp ({expr= RowVectorExpr unpacked_e1; _}, Transpose ); _}
137- | {expr= ArrayExpr unpacked_e1; _} )
138- , ({expr= RealNumeral _; _} | {expr = IntNumeral _ ; _} ) ) ->
100+ | Some unpacked_e1 , None ->
139101 create_bounds unpacked_e1
140102 (List. init (List. length unpacked_e1) ~f: (fun _ -> e2))
141103 | _ ->
142104 Common.FatalError. fatal_error_msg
143105 [% message
144106 " Bad bounded upper and lower expr: "
145- (e1 : (typed_expr_meta, fun_kind ) expr_with )
107+ (e1 : Expr.Typed.t )
146108 " and "
147- (e2 : (typed_expr_meta, fun_kind ) expr_with )] in
148- match t with
149- | Transformation. Lower ({emeta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
150- gen_bounded (fun x -> Transformation. Lower x) (extract_var e)
151- | Transformation. Upper ({emeta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
152- gen_bounded (fun x -> Transformation. Upper x) (extract_var e)
109+ (e2 : Expr.Typed.t )] in
110+ match (t : Expr.Typed.t Transformation.t ) with
111+ | Transformation. Lower ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
112+ gen_bounded (fun x -> Transformation. Lower x) (eval_expr m e)
113+ | Transformation. Upper ({meta = {type_ = UVector | URowVector ; _} ; _} as e ) ->
114+ gen_bounded (fun x -> Transformation. Upper x) (eval_expr m e)
153115 | Transformation. LowerUpper
154- ( ({emeta = {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
155- , ({emeta = {type_= UVector | URowVector ; _}; _} as e2) )
116+ ( ({meta = {type_= UVector | URowVector | UReal | UInt ; _}; _} as e1)
117+ , ({meta = {type_= UVector | URowVector ; _}; _} as e2) )
156118 | Transformation. LowerUpper
157- ( ({emeta = {type_= UVector | URowVector ; _}; _} as e1)
158- , ({emeta = {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
159- gen_ul_bounded (extract_var e1) (extract_var e2)
119+ ( ({meta = {type_= UVector | URowVector ; _}; _} as e1)
120+ , ({meta = {type_= UReal | UInt ; _ }; _ } as e2 ) ) ->
121+ gen_ul_bounded (eval_expr m e1) (eval_expr m e2)
160122 | _ ->
161- { expr= RowVectorExpr (repeat_th n (fun _ -> gen_real m t))
162- ; emeta= {loc= Location_span. empty; ad_level= DataOnly ; type_= UMatrix } }
123+ let e =
124+ Expr.Helpers. row_vector (repeat_th n (fun _ -> gen_num_real m t)) in
125+ {e with meta= {e.meta with type_= UMatrix }}
163126
164127let gen_vector m n t =
165128 let gen_ordered n =
@@ -173,38 +136,36 @@ let gen_vector m n t =
173136 let l = repeat_th n (fun _ -> Random. float 1. ) in
174137 let sum = List. fold l ~init: 0. ~f: (fun accum elt -> accum +. elt) in
175138 let l = List. map l ~f: (fun x -> x /. sum) in
176- wrap_vector ( List. map ~f: wrap_real l)
139+ Expr.Helpers. vector l
177140 | Ordered ->
178141 let l = gen_ordered n in
179142 let halfmax =
180143 Option. value_exn (List. max_elt l ~compare: compare_float) /. 2. in
181144 let l = List. map l ~f: (fun x -> (x -. halfmax) /. halfmax) in
182- wrap_vector ( List. map ~f: wrap_real l)
145+ Expr.Helpers. vector l
183146 | PositiveOrdered ->
184147 let l = gen_ordered n in
185148 let max = Option. value_exn (List. max_elt l ~compare: compare_float) in
186149 let l = List. map l ~f: (fun x -> x /. max) in
187- wrap_vector ( List. map ~f: wrap_real l)
150+ Expr.Helpers. vector l
188151 | UnitVector ->
189152 let l = repeat_th n (fun _ -> Random. float 1. ) in
190153 let sum =
191154 Float. sqrt
192155 (List. fold l ~init: 0. ~f: (fun accum elt -> accum +. (elt ** 2. )))
193156 in
194157 let l = List. map l ~f: (fun x -> x /. sum) in
195- wrap_vector (List. map ~f: wrap_real l)
196- | _ -> {int_two with expr= PostfixOp (gen_row_vector m n t, Transpose )}
158+ Expr.Helpers. vector l
159+ | _ ->
160+ let v = Expr.Helpers. unary_op Transpose (gen_row_vector m n t) in
161+ {v with meta= {v.meta with type_= UVector }}
197162
198163let gen_cov_unwrapped n =
199164 let l = repeat_th (n * n) (fun _ -> Random. float 2. ) in
200165 let l_mat = vect_to_mat l n in
201166 matprod l_mat (transpose l_mat)
202167
203- let wrap_real_mat m =
204- let mat_wrapped =
205- List. map ~f: wrap_row_vector
206- (List. map ~f: (fun x -> List. map ~f: wrap_real x) m) in
207- {int_two with expr= RowVectorExpr mat_wrapped}
168+ let wrap_real_mat m = Expr.Helpers. matrix m
208169
209170let gen_diag_mat l =
210171 let n = List. length l in
@@ -263,20 +224,18 @@ let gen_matrix mm m n t =
263224 | CholeskyCov -> gen_cov_cholesky m n
264225 | CholeskyCorr -> gen_corr_cholesky m
265226 | _ ->
266- { int_two with
267- expr= RowVectorExpr (repeat_th m (fun () -> gen_row_vector mm n t)) }
268-
269- (* TODO: do some proper random generation of these special matrices *)
227+ Expr.Helpers. matrix_from_rows
228+ (repeat_th m (fun () -> gen_row_vector mm n t))
270229
271- let gen_array elt n _ = {int_two with expr = ArrayExpr (repeat_th n elt)}
230+ let gen_array elt n _ = Expr.Helpers. array_expr (repeat_th n elt)
272231
273232let rec generate_value m st t =
274233 match st with
275- | SizedType. SInt -> gen_int m t
276- | SReal -> gen_real m t
234+ | SizedType. SInt -> Expr.Helpers. int (gen_num_int m t)
235+ | SReal -> Expr.Helpers. float (gen_num_real m t)
277236 | SComplex ->
278237 (* when serialzied, a complex number looks just like a 2-array of reals *)
279- generate_value m (SArray (SReal , wrap_int 2 )) t
238+ generate_value m (SArray (SReal , Expr.Helpers. int 2 )) t
280239 | SVector (_ , e ) -> gen_vector m (unwrap_int_exn m e) t
281240 | SRowVector (_ , e ) -> gen_row_vector m (unwrap_int_exn m e) t
282241 | SMatrix (_ , e1 , e2 ) ->
@@ -286,28 +245,23 @@ let rec generate_value m st t =
286245 gen_array element (unwrap_int_exn m e) t
287246
288247let rec pp_value_json ppf e =
289- match e.expr with
290- | PostfixOp (e , Transpose) -> pp_value_json ppf e
291- | IntNumeral s | RealNumeral s -> Fmt. string ppf s
292- | ArrayExpr l | RowVectorExpr l ->
248+ match e.Expr.Fixed. pattern with
249+ | Lit ((Int | Real ), s ) -> Fmt. string ppf s
250+ | FunApp (CompilerInternal (FnMakeRowVec | FnMakeArray ), l ) ->
293251 Fmt. (pf ppf " [@[<hov 1>%a@]]" (list ~sep: comma pp_value_json) l)
252+ | FunApp (StanLib (transpose, _, _), [e])
253+ when String. equal transpose (Operator. to_string Transpose ) ->
254+ pp_value_json ppf e
294255 | _ ->
295256 Common.FatalError. fatal_error_msg
296- [% message " Could not evaluate expression " (e : typed_expression )]
257+ [% message " Could not evaluate expression " (e : Expr.Typed.t )]
297258
298259let print_data_prog s =
299- let data = Ast. get_stmts s.datablock in
300260 let l, _ =
301- List. fold data ~init: ([] , Map.Poly. empty) ~f: (fun (l , m ) decl ->
302- match decl.stmt with
303- | VarDecl
304- { decl_type= Sized sizedtype
305- ; transformation
306- ; identifier= {name; _}
307- ; _ } ->
308- let value = generate_value m sizedtype transformation in
309- ((name, value) :: l, Map. set m ~key: name ~data: value)
310- | _ -> (l, m) ) in
261+ List. fold s ~init: ([] , Map.Poly. empty)
262+ ~f: (fun (l , m ) (sizedtype , transformation , name ) ->
263+ let value = generate_value m sizedtype transformation in
264+ ((name, value) :: l, Map. set m ~key: name ~data: value) ) in
311265 let pp ppf (id , value ) =
312266 Fmt. pf ppf {|@ [< hov 2 > " %s" :@ % a@ ]| } id pp_value_json value in
313267 Fmt. (str " {@ @[<hov>%a@]@ }" (list ~sep: comma pp) (List. rev l))
0 commit comments