11open Core_kernel
22open Core_kernel.Poly
33open Middle
4- open Middle.Expr
54
65(* *
76 * Return a Var expression of the name for each type
@@ -26,8 +25,7 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
2625 * Return a set of all types containing autodiffable Eigen matrices
2726 * in an expression.
2827 *)
29- let query_var_eigen_names (expr : Typed.Meta.t Expr.Fixed.t ) : string Set.Poly. t
30- =
28+ let query_var_eigen_names (expr : Expr.Typed.t ) : string Set.Poly.t =
3129 let get_expr_eigen_names
3230 (Dataflow_types. VVar s , Expr.Typed.Meta. {adlevel; type_; _} ) =
3331 if
@@ -52,7 +50,7 @@ let is_nonzero_subset ~set ~subset =
5250 *)
5351let rec count_single_idx_exprs (acc : int ) Expr.Fixed. {pattern; _} : int =
5452 match pattern with
55- | Expr.Fixed.Pattern. FunApp (_ , (exprs : Typed.Meta.t Expr.Fixed .t list )) ->
53+ | Expr.Fixed.Pattern. FunApp (_ , (exprs : Expr.Typed .t list )) ->
5654 List. fold_left ~init: acc ~f: count_single_idx_exprs exprs
5755 | TernaryIf (predicate , texpr , fexpr ) ->
5856 acc
@@ -79,8 +77,7 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
7977 * for a Single index. All and Between cannot be Single cell access
8078 * and so pass acc along.
8179 *)
82- and count_single_idx (acc : int ) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t )
83- =
80+ and count_single_idx (acc : int ) (idx : Expr.Typed.t Index.t ) =
8481 match idx with
8582 | Index. All | Between _ | Upfrom _ | MultiIndex _ -> acc
8683 | Single _ -> acc + 1
@@ -96,7 +93,7 @@ and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
9693 * either at the top level or within the `Index` types of the list.
9794 *)
9895let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t )
99- (index : Typed.Meta.t Expr.Fixed .t Index.t list ) =
96+ (index : Expr.Typed .t Index.t list ) =
10097 match in_loop with
10198 | false -> false
10299 | true -> (
@@ -155,7 +152,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
155152 let query_expr (accum : string Set.Poly.t ) =
156153 query_initial_demotable_expr in_loop ~acc: accum in
157154 match pattern with
158- | FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed. t list )) ->
155+ | FunApp (kind , (exprs : Expr.Typed.t list )) ->
159156 query_initial_demotable_funs in_loop acc kind exprs
160157 | Indexed ((Expr.Fixed. {meta = {type_; _} ; _} as expr ), indexed ) ->
161158 let index_set =
@@ -202,8 +199,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
202199 * exprs The expression list passed to the functions.
203200 *)
204201and query_initial_demotable_funs (in_loop : bool ) (acc : string Set.Poly.t )
205- (kind : 'a Fun_kind.t ) (exprs : Typed.Meta.t Expr.Fixed.t list ) :
206- string Set.Poly. t =
202+ (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list ) : string Set.Poly.t =
207203 let query_expr accum = query_initial_demotable_expr in_loop ~acc: accum in
208204 let top_level_eigen_names =
209205 Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
@@ -236,10 +232,10 @@ let rec is_any_soa_supported_expr
236232 then true
237233 else
238234 match pattern with
239- | FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed. t list )) ->
235+ | FunApp (kind , (exprs : Expr.Typed.t list )) ->
240236 is_any_soa_supported_fun_expr kind exprs
241- | Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list ))
242- | Promotion ( expr , _ , _ ) ->
237+ | Indexed (expr, (_ : Expr. Typed.t Index.t list )) | Promotion (expr, _, _ )
238+ ->
243239 is_any_soa_supported_expr expr
244240 | Var (_ : string ) | Lit ((_ : Expr.Fixed.Pattern.litType ), (_ : string )) ->
245241 true
@@ -252,7 +248,7 @@ let rec is_any_soa_supported_expr
252248 * Return false if the `Fun_kind.t` does not support `SoA`
253249 *)
254250and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t )
255- (exprs : Typed.Meta.t Expr.Fixed .t list ) : bool =
251+ (exprs : Expr.Typed .t list ) : bool =
256252 match kind with
257253 | CompilerInternal (Internal_fun. FnMakeArray | FnMakeRowVec ) -> false
258254 | UserDefined ((_ : string ), (_ : bool Fun_kind.suffix )) -> false
@@ -273,7 +269,7 @@ let rec is_any_ad_real_data_matrix_expr
273269 if UnsizedType. is_dataonlytype adlevel then false
274270 else
275271 match pattern with
276- | FunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed. t list )) ->
272+ | FunApp (kind , (exprs : Expr.Typed.t list )) ->
277273 is_any_ad_real_data_matrix_expr_fun kind exprs
278274 | Indexed (expr , _ ) | Promotion (expr , _ , _ ) ->
279275 is_any_ad_real_data_matrix_expr expr
@@ -291,7 +287,7 @@ let rec is_any_ad_real_data_matrix_expr
291287 * combinations of AutoDiffable Reals and Data Matrices
292288 *)
293289and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t )
294- (exprs : Typed.Meta.t Expr.Fixed .t list ) : bool =
290+ (exprs : Expr.Typed .t list ) : bool =
295291 match kind with
296292 | Fun_kind. StanLib (name , (_ : bool Fun_kind.suffix ), _ ) -> (
297293 match name with
@@ -354,9 +350,7 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
354350 * `query_initial_demotable_expr` for an explanation of the logic.
355351 *)
356352let rec query_initial_demotable_stmt (in_loop : bool ) (acc : string Set.Poly.t )
357- (Stmt.Fixed. {pattern; _} :
358- (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t ) :
359- string Set.Poly. t =
353+ (Stmt.Fixed. {pattern; _} : Stmt.Located.t ) : string Set.Poly.t =
360354 let query_expr (accum : string Set.Poly.t ) =
361355 query_initial_demotable_expr in_loop ~acc: accum in
362356 match pattern with
@@ -446,7 +440,7 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
446440 * @param pattern The Stmt pattern to query.
447441 *)
448442let query_demotable_stmt (aos_exits : string Set.Poly.t )
449- (pattern : (Typed.t, int) Stmt.Fixed.Pattern.t ) : string Set.Poly.t =
443+ (pattern : (Expr. Typed.t, int) Stmt.Fixed.Pattern.t ) : string Set.Poly.t =
450444 match pattern with
451445 | Stmt.Fixed.Pattern. Assignment
452446 ( ( (assign_name : string )
@@ -480,7 +474,7 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
480474 **)
481475let rec modify_kind ?force_demotion :(force = false )
482476 (modifiable_set : string Set.Poly.t ) (kind : 'a Fun_kind.t )
483- (exprs : Expr.Typed.Meta.t Expr.Fixed. t list ) =
477+ (exprs : Expr.Typed.t list ) =
484478 let expr_names =
485479 Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
486480 let is_all_in_list =
@@ -518,12 +512,11 @@ let rec modify_kind ?force_demotion:(force = false)
518512 *)
519513and modify_expr_pattern ?force_demotion :(force = false )
520514 (modifiable_set : string Set.Poly.t )
521- (pattern : Expr.Typed.Meta.t Expr.Fixed. t Expr.Fixed.Pattern.t ) =
515+ (pattern : Expr.Typed.t Expr.Fixed.Pattern.t ) =
522516 let mod_expr ?force_demotion :(forced = false ) =
523517 modify_expr ~force_demotion: forced modifiable_set in
524518 match pattern with
525- | Expr.Fixed.Pattern. FunApp
526- (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list )) ->
519+ | Expr.Fixed.Pattern. FunApp (kind , (exprs : Expr.Typed.t list )) ->
527520 let kind', expr' =
528521 modify_kind ~force_demotion: force modifiable_set kind exprs in
529522 Expr.Fixed.Pattern. FunApp (kind', expr')
@@ -578,10 +571,8 @@ and modify_expr ?force_demotion:(force = false)
578571* @param modifiable_set The name of the variable we are searching for.
579572*)
580573let rec modify_stmt_pattern
581- (pattern :
582- ( Expr.Typed.Meta.t Expr.Fixed.t
583- , (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t )
584- Stmt.Fixed.Pattern. t ) (modifiable_set : string Core_kernel.Set.Poly.t ) =
574+ (pattern : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t )
575+ (modifiable_set : string Core_kernel.Set.Poly.t ) =
585576 let mod_expr force = modify_expr ~force_demotion: force modifiable_set in
586577 let mod_stmt stmt = modify_stmt stmt modifiable_set in
587578 match pattern with
@@ -597,7 +588,7 @@ let rec modify_stmt_pattern
597588 { decl with
598589 decl_type=
599590 Type. Sized (SizedType. modify_sizedtype_mem SoA sized_type) }
600- | NRFunApp (kind , (exprs : Expr.Typed.Meta.t Expr.Fixed. t list )) ->
591+ | NRFunApp (kind , (exprs : Expr.Typed.t list )) ->
601592 let kind', exprs' = modify_kind modifiable_set kind exprs in
602593 NRFunApp (kind', exprs')
603594 | Assignment
0 commit comments