Skip to content

Commit 918c2fd

Browse files
committed
use Expr.Typed.t and Stmt.Located.t type aliases
1 parent 07d3bba commit 918c2fd

11 files changed

Lines changed: 43 additions & 71 deletions

File tree

src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ let gen_bounded m gen e =
7878
| Some unpacked_e -> List.map ~f:gen unpacked_e
7979
| None ->
8080
Common.FatalError.fatal_error_msg
81-
[%message
82-
"Bad bounded (upper OR lower) expr: "
83-
(e : Expr.Typed.Meta.t Expr.Fixed.t)]
81+
[%message "Bad bounded (upper OR lower) expr: " (e : Expr.Typed.t)]
8482

8583
let gen_ul_bounded m gen e1 e2 =
8684
let create_bounds l u =

src/analysis_and_optimization/Dependence_analysis.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,7 @@ let mir_uninitialized_variables (mir : Program.Typed.t) :
203203
(Set.Poly.union arg_vars globals)
204204
fdbody ) ) ) ]
205205

206-
let build_dep_info_map (mir : Program.Typed.t)
207-
(stmt : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t) :
206+
let build_dep_info_map (mir : Program.Typed.t) (stmt : Stmt.Located.t) :
208207
( label
209208
, (Expr.Typed.t, label) Stmt.Fixed.Pattern.t * node_dep_info )
210209
Map.Poly.t =

src/analysis_and_optimization/Dependence_analysis.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ val node_vars_dependencies :
6868

6969
val build_dep_info_map :
7070
Program.Typed.t
71-
-> (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t
71+
-> Stmt.Located.t
7272
-> ( label
7373
, (Expr.Typed.t, label) Stmt.Fixed.Pattern.t * node_dep_info )
7474
Map.Poly.t

src/analysis_and_optimization/Mem_pattern.ml

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
open Core_kernel
22
open Core_kernel.Poly
33
open Middle
4-
open Middle.Expr
54

65
(**
76
* Return a Var expression of the name for each type
@@ -26,8 +25,7 @@ let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta
2625
* Return a set of all types containing autodiffable Eigen matrices
2726
* in an expression.
2827
*)
29-
let query_var_eigen_names (expr : Typed.Meta.t Expr.Fixed.t) : string Set.Poly.t
30-
=
28+
let query_var_eigen_names (expr : Expr.Typed.t) : string Set.Poly.t =
3129
let get_expr_eigen_names
3230
(Dataflow_types.VVar s, Expr.Typed.Meta.{adlevel; type_; _}) =
3331
if
@@ -52,7 +50,7 @@ let is_nonzero_subset ~set ~subset =
5250
*)
5351
let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
5452
match pattern with
55-
| Expr.Fixed.Pattern.FunApp (_, (exprs : Typed.Meta.t Expr.Fixed.t list)) ->
53+
| Expr.Fixed.Pattern.FunApp (_, (exprs : Expr.Typed.t list)) ->
5654
List.fold_left ~init:acc ~f:count_single_idx_exprs exprs
5755
| TernaryIf (predicate, texpr, fexpr) ->
5856
acc
@@ -79,8 +77,7 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int =
7977
* for a Single index. All and Between cannot be Single cell access
8078
* and so pass acc along.
8179
*)
82-
and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
83-
=
80+
and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) =
8481
match idx with
8582
| Index.All | Between _ | Upfrom _ | MultiIndex _ -> acc
8683
| Single _ -> acc + 1
@@ -96,7 +93,7 @@ and count_single_idx (acc : int) (idx : Expr.Typed.Meta.t Expr.Fixed.t Index.t)
9693
* either at the top level or within the `Index` types of the list.
9794
*)
9895
let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t)
99-
(index : Typed.Meta.t Expr.Fixed.t Index.t list) =
96+
(index : Expr.Typed.t Index.t list) =
10097
match in_loop with
10198
| false -> false
10299
| true -> (
@@ -155,7 +152,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
155152
let query_expr (accum : string Set.Poly.t) =
156153
query_initial_demotable_expr in_loop ~acc:accum in
157154
match pattern with
158-
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
155+
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
159156
query_initial_demotable_funs in_loop acc kind exprs
160157
| Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) ->
161158
let index_set =
@@ -202,8 +199,7 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
202199
* exprs The expression list passed to the functions.
203200
*)
204201
and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
205-
(kind : 'a Fun_kind.t) (exprs : Typed.Meta.t Expr.Fixed.t list) :
206-
string Set.Poly.t =
202+
(kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t =
207203
let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in
208204
let top_level_eigen_names =
209205
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
@@ -236,10 +232,10 @@ let rec is_any_soa_supported_expr
236232
then true
237233
else
238234
match pattern with
239-
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
235+
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
240236
is_any_soa_supported_fun_expr kind exprs
241-
| Indexed (expr, (_ : Typed.Meta.t Fixed.t Index.t list))
242-
|Promotion (expr, _, _) ->
237+
| Indexed (expr, (_ : Expr.Typed.t Index.t list)) | Promotion (expr, _, _)
238+
->
243239
is_any_soa_supported_expr expr
244240
| Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) ->
245241
true
@@ -252,7 +248,7 @@ let rec is_any_soa_supported_expr
252248
* Return false if the `Fun_kind.t` does not support `SoA`
253249
*)
254250
and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t)
255-
(exprs : Typed.Meta.t Expr.Fixed.t list) : bool =
251+
(exprs : Expr.Typed.t list) : bool =
256252
match kind with
257253
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> false
258254
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false
@@ -273,7 +269,7 @@ let rec is_any_ad_real_data_matrix_expr
273269
if UnsizedType.is_dataonlytype adlevel then false
274270
else
275271
match pattern with
276-
| FunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
272+
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
277273
is_any_ad_real_data_matrix_expr_fun kind exprs
278274
| Indexed (expr, _) | Promotion (expr, _, _) ->
279275
is_any_ad_real_data_matrix_expr expr
@@ -291,7 +287,7 @@ let rec is_any_ad_real_data_matrix_expr
291287
* combinations of AutoDiffable Reals and Data Matrices
292288
*)
293289
and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
294-
(exprs : Typed.Meta.t Expr.Fixed.t list) : bool =
290+
(exprs : Expr.Typed.t list) : bool =
295291
match kind with
296292
| Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> (
297293
match name with
@@ -354,9 +350,7 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t)
354350
* `query_initial_demotable_expr` for an explanation of the logic.
355351
*)
356352
let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
357-
(Stmt.Fixed.{pattern; _} :
358-
(Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t ) :
359-
string Set.Poly.t =
353+
(Stmt.Fixed.{pattern; _} : Stmt.Located.t) : string Set.Poly.t =
360354
let query_expr (accum : string Set.Poly.t) =
361355
query_initial_demotable_expr in_loop ~acc:accum in
362356
match pattern with
@@ -446,7 +440,7 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
446440
* @param pattern The Stmt pattern to query.
447441
*)
448442
let query_demotable_stmt (aos_exits : string Set.Poly.t)
449-
(pattern : (Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t =
443+
(pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t =
450444
match pattern with
451445
| Stmt.Fixed.Pattern.Assignment
452446
( ( (assign_name : string)
@@ -480,7 +474,7 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
480474
**)
481475
let rec modify_kind ?force_demotion:(force = false)
482476
(modifiable_set : string Set.Poly.t) (kind : 'a Fun_kind.t)
483-
(exprs : Expr.Typed.Meta.t Expr.Fixed.t list) =
477+
(exprs : Expr.Typed.t list) =
484478
let expr_names =
485479
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
486480
let is_all_in_list =
@@ -518,12 +512,11 @@ let rec modify_kind ?force_demotion:(force = false)
518512
*)
519513
and modify_expr_pattern ?force_demotion:(force = false)
520514
(modifiable_set : string Set.Poly.t)
521-
(pattern : Expr.Typed.Meta.t Expr.Fixed.t Expr.Fixed.Pattern.t) =
515+
(pattern : Expr.Typed.t Expr.Fixed.Pattern.t) =
522516
let mod_expr ?force_demotion:(forced = false) =
523517
modify_expr ~force_demotion:forced modifiable_set in
524518
match pattern with
525-
| Expr.Fixed.Pattern.FunApp
526-
(kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
519+
| Expr.Fixed.Pattern.FunApp (kind, (exprs : Expr.Typed.t list)) ->
527520
let kind', expr' =
528521
modify_kind ~force_demotion:force modifiable_set kind exprs in
529522
Expr.Fixed.Pattern.FunApp (kind', expr')
@@ -578,10 +571,8 @@ and modify_expr ?force_demotion:(force = false)
578571
* @param modifiable_set The name of the variable we are searching for.
579572
*)
580573
let rec modify_stmt_pattern
581-
(pattern :
582-
( Expr.Typed.Meta.t Expr.Fixed.t
583-
, (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t )
584-
Stmt.Fixed.Pattern.t ) (modifiable_set : string Core_kernel.Set.Poly.t) =
574+
(pattern : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t)
575+
(modifiable_set : string Core_kernel.Set.Poly.t) =
585576
let mod_expr force = modify_expr ~force_demotion:force modifiable_set in
586577
let mod_stmt stmt = modify_stmt stmt modifiable_set in
587578
match pattern with
@@ -597,7 +588,7 @@ let rec modify_stmt_pattern
597588
{ decl with
598589
decl_type=
599590
Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type) }
600-
| NRFunApp (kind, (exprs : Expr.Typed.Meta.t Expr.Fixed.t list)) ->
591+
| NRFunApp (kind, (exprs : Expr.Typed.t list)) ->
601592
let kind', exprs' = modify_kind modifiable_set kind exprs in
602593
NRFunApp (kind', exprs')
603594
| Assignment

src/analysis_and_optimization/Mir_utils.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ let rec fold_expr ~take_expr ~(init : 'c) (expr : Expr.Typed.t) : 'c =
1212

1313
let fold_stmts ~take_expr ~take_stmt ~(init : 'c) (stmts : Stmt.Located.t List.t)
1414
: 'c =
15-
(* let rec fold_expr (state : 'c) (expr : Expr.Typed.Meta.t Expr.Fixed.t) =
15+
(* let rec fold_expr (state : 'c) (expr : Expr.t) =
1616
* Expr.Fixed.Pattern.fold_left
1717
* ~f:(fun a e -> fold_expr (take_expr a e) e)
1818
* ~init:state
@@ -313,7 +313,7 @@ let rec fn_subst_expr m e =
313313
match m e with
314314
| Some e' ->
315315
(* let print_expr (e:Expr.Typed.t) = *)
316-
(* [%sexp (e.pattern : Expr.Typed.Meta.t Expr.Fixed.t Expr.Fixed.Pattern.t)] |> Sexp.to_string *)
316+
(* [%sexp (e.pattern : Expr.Typed.t Expr.Fixed.Pattern.t)] |> Sexp.to_string *)
317317
(* in *)
318318
(* let _ = print_endline ("Replaced expr: " ^ print_expr e ^ " -> " ^ print_expr e') in *)
319319
e'

src/analysis_and_optimization/Mir_utils.mli

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,10 @@ val parameter_names_set :
2828
?include_transformed:bool -> Program.Typed.t -> string Set.Poly.t
2929

3030
val fold_expr :
31-
take_expr:('c -> Expr.Typed.Meta.t Expr.Fixed.t -> 'c)
32-
-> init:'c
33-
-> Expr.Typed.t
34-
-> 'c
31+
take_expr:('c -> Expr.Typed.t -> 'c) -> init:'c -> Expr.Typed.t -> 'c
3532

3633
val fold_stmts :
37-
take_expr:('c -> Expr.Typed.Meta.t Expr.Fixed.t -> 'c)
34+
take_expr:('c -> Expr.Typed.t -> 'c)
3835
-> take_stmt:('c -> Stmt.Located.t -> 'c)
3936
-> init:'c
4037
-> Stmt.Located.t List.t
@@ -137,8 +134,7 @@ val index_var_set :
137134
For use in RHS sets, not LHS assignment sets, except in a target term
138135
*)
139136

140-
val expr_var_names_set :
141-
Expr.Typed.Meta.t Expr.Fixed.t -> string Core_kernel.Set.Poly.t
137+
val expr_var_names_set : Expr.Typed.t -> string Core_kernel.Set.Poly.t
142138
(**
143139
Return the names of the variables in an expression.
144140
*)

src/analysis_and_optimization/Optimize.ml

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,7 @@ let rec find_assignment_idx (name : string) Stmt.Fixed.{pattern; _} =
765765
* in their first assignment and mark them as not needing to be
766766
* initialized.
767767
*)
768-
and unenforce_initialize
769-
(lst : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t list) =
768+
and unenforce_initialize (lst : Stmt.Located.t list) =
770769
let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst =
771770
match pattern with
772771
| Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl_pat) -> (
@@ -817,9 +816,8 @@ and unenforce_initialize
817816
* Stmts.
818817
*)
819818
let transform_mir_blocks (mir : (Expr.Typed.t, Stmt.Located.t) Program.t)
820-
(transformer :
821-
(Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t list
822-
-> Stmt.Located.t list ) : (Expr.Typed.t, Stmt.Located.t) Program.t =
819+
(transformer : Stmt.Located.t list -> Stmt.Located.t list) :
820+
(Expr.Typed.t, Stmt.Located.t) Program.t =
823821
let transformed_functions =
824822
List.map mir.functions_block ~f:(fun fs ->
825823
let new_body =
@@ -1049,11 +1047,11 @@ let optimize_minimal_variables
10491047
-> string Set.Poly.t )
10501048
~(update_expr : string Set.Poly.t -> Expr.Typed.t -> Expr.Typed.t)
10511049
~(update_stmt :
1052-
( Expr.Typed.Meta.t Expr.Fixed.t
1050+
( Expr.Typed.t
10531051
, (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t )
10541052
Stmt.Fixed.Pattern.t
10551053
-> string Core_kernel.Set.Poly.t
1056-
-> ( Expr.Typed.Meta.t Expr.Fixed.t
1054+
-> ( Expr.Typed.t
10571055
, (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t )
10581056
Stmt.Fixed.Pattern.t )
10591057
~(extra_variables : string -> string Set.Poly.t)
@@ -1169,10 +1167,7 @@ let optimize_soa (mir : Program.Typed.t) =
11691167
stmt ~extra_variables:(fun _ -> initial_variables) in
11701168
let transform' s =
11711169
match transform {pattern= SList s; meta= Location_span.empty} with
1172-
| { pattern=
1173-
SList (l : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t list)
1174-
; _ } ->
1175-
l
1170+
| {pattern= SList (l : Stmt.Located.t list); _} -> l
11761171
| _ ->
11771172
raise
11781173
(Failure "Something went wrong with program transformation packing!")

src/analysis_and_optimization/Pedantic_analysis.ml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,7 @@ let expr_collect_exprs (expr : Expr.Typed.t) ~f : 'a Set.Poly.t =
206206
match f expr with Some a -> Set.Poly.add s a | _ -> s in
207207
fold_expr ~init:Set.Poly.empty ~take_expr:(fun s e -> collect_expr s e) expr
208208

209-
let stmts_collect_exprs
210-
(stmts : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t List.t) ~f :
211-
'a Set.Poly.t =
209+
let stmts_collect_exprs (stmts : Stmt.Located.t List.t) ~f : 'a Set.Poly.t =
212210
let collect_expr s (expr : Expr.Typed.t) =
213211
match f expr with Some a -> Set.Poly.add s a | _ -> s in
214212
fold_stmts ~init:Set.Poly.empty

src/frontend/Ast_to_Mir.mli

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@ open Middle
33

44
val trans_data :
55
Ast.typed_program
6-
-> ( Expr.Typed.Meta.t Expr.Fixed.t SizedType.t
7-
* Expr.Typed.Meta.t Expr.Fixed.t Transformation.t
8-
* string )
9-
list
6+
-> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list
107

118
val trans_prog : string -> Ast.typed_program -> Program.Typed.t

src/middle/Expr.ml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,7 @@ module Helpers = struct
267267
| _ ->
268268
(* These should go away with Ryan's LHS *)
269269
Common.FatalError.fatal_error_msg
270-
[%message
271-
"Expected Var or Indexed but found " (e : Typed.Meta.t Fixed.t)]
272-
in
270+
[%message "Expected Var or Indexed but found " (e : Typed.t)] in
273271
Fixed.{meta; pattern}
274272

275273
(** TODO: Make me tail recursive *)

0 commit comments

Comments
 (0)