diff --git a/docs/core_ideas.mld b/docs/core_ideas.mld index 111cd474c0..b05bfec243 100644 --- a/docs/core_ideas.mld +++ b/docs/core_ideas.mld @@ -65,6 +65,31 @@ This takes some getting used to, and also can lead to some unhelpful type signat VSCode, because abbreviations are not always used in hover-over text. For example, [Expr.Typed.t], the MIR's typed expression type, actually has a signature of [Expr.Typed.Meta.t Expr.Fixed.t]. +{1 The [Library] interface and functors} + +Many modules of stanc are modeled as OCaml {{:https://ocaml.org/learn/tutorials/functors.html}functors}, +which take in another module as input and produce a module as output. For the most part, +these functors expect an instance of the [Library] interface defined in +[src/frontend/Std_library_utils.ml]. + +This interface primarily contains signatures for the Stan standard library. For most users, +you can assume this will be filled in with [src/stan_math_backend/Stan_math_library.ml], +the module representing the {{:https://github.com/stan-dev/math}stan-dev/math} C++ library. + +Using these functors is rather simple; e.g., in the core stanc driver the line + +{[ +module Typechecker = Typechecking.Make (Stan_math_library) +]} + +defines a module [Typechecker] by supplying the functor [Typechecking.Make] with +the Stan C++ library module. After this, [Typechecker.check_program] will typecheck +an AST against those specific functions. + +As noted in the above tutorial link, the syntax of functors is often the hardest part +of using and understanding them. The functors which accept [Library] are all relatively +simple, and should serve as good examples for beginners new to the concept. 
+ {1 The [Fmt] library and pretty-printing} We extensively use the {{:https://erratique.ch/software/fmt}Fmt} library for our pretty-printing and code diff --git a/docs/exposing_new_functions.mld b/docs/exposing_new_functions.mld index 7194ccc03c..b7905a989d 100644 --- a/docs/exposing_new_functions.mld +++ b/docs/exposing_new_functions.mld @@ -7,7 +7,7 @@ For a function to be built into Stan, it has to be included in the Stan Math library and its signature has to be exposed to the compiler. -To do the latter, we have to add a corresponding line in [src/middle/Stan_math_signatures.ml]. +To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_library.ml]. The compiler uses the signatures defined there to do type checking. @@ -127,15 +127,14 @@ For example, the following line defines the signature [add(real, matrix) => matr {1 Higher-Order Variadic functions} -Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length -list of arguments, are {b NOT} added to this list. -"Nice" variadic functions are added to the hashtable [Stan_math_signatures.stan_math_variadic_signatures]. +"Nice" variadic functions are added to the hashtable [StdLibrary.variadic_signatures]. This is probably sufficient for most variadic functions, e.g. all the ODE solvers and DAE solvers are done via this method. [reduce_sum] is not "nice", since it is both variadic and {e polymorphic}, requiring certain arguments to have the same (but {e not predetermined}) type. Therefore, [reduce_sum] is treated as special case in the [Typechecker] -module in the frontend folder. +module in the frontend folder. Such functions are instead recognized by predicates like [is_special_function_name]. They +must also be given custom typechecking rules in the private sub-module [Special_typechecking]. Note that higher-order functions also usually require changes to the C++ code generation to work properly. 
It is best to consult an existing example of how these are done before proceeding. diff --git a/src/analysis_and_optimization/Debug_data_generation.ml b/src/analysis_and_optimization/Debug_data_generation.ml index ac83672212..e80a1016b2 100644 --- a/src/analysis_and_optimization/Debug_data_generation.ml +++ b/src/analysis_and_optimization/Debug_data_generation.ml @@ -1,6 +1,9 @@ open Core_kernel open Middle +module Partial_evaluator = + Partial_evaluation.Make (Frontend.Std_library_utils.NullLibrary) + let rec transpose = function | [] :: _ -> [] | rows -> @@ -8,7 +11,7 @@ let rec transpose = function let tl = List.map ~f:List.tl_exn rows in hd :: transpose tl -let reject loc msg = raise (Partial_evaluator.Rejected (loc, msg)) +let reject loc msg = raise (Partial_evaluation.Rejected (loc, msg)) let dotproduct xs ys = List.fold2_exn xs ys ~init:0. ~f:(fun accum x y -> accum +. (x *. y)) @@ -32,7 +35,7 @@ let rec vect_to_mat l m = let eval_expr m e = let e = Mir_utils.subst_expr m e in - let e = Partial_evaluator.eval_expr e in + let e = Partial_evaluator.try_eval_expr e in let rec strip_promotions (e : Middle.Expr.Typed.t) = match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e in @@ -361,4 +364,4 @@ let gen_values_json_exn ?(new_only = false) ?(context = Map.Poly.empty) decls = let gen_values_json ?(new_only = false) ?(context = Map.Poly.empty) decls = try Ok (gen_values_json_exn ~new_only ~context decls) - with Partial_evaluator.Rejected (loc, msg) -> Error (loc, msg) + with Partial_evaluation.Rejected (loc, msg) -> Error (loc, msg) diff --git a/src/analysis_and_optimization/Dependence_analysis.ml b/src/analysis_and_optimization/Dependence_analysis.ml index dd1907009a..6ec3b862b7 100644 --- a/src/analysis_and_optimization/Dependence_analysis.ml +++ b/src/analysis_and_optimization/Dependence_analysis.ml @@ -4,7 +4,7 @@ open Middle open Dataflow_types open Mir_utils open Dataflow_utils -open Monotone_framework_sigs +open Monotone_framework_intf 
open Monotone_framework (***********************************) @@ -119,9 +119,8 @@ let mir_reaching_definitions (mir : Program.Typed.t) (stmt : Stmt.Located.t) : Map.Poly.map rd_map ~f:(fun {entry; exit} -> {entry= to_rd_set entry; exit= to_rd_set exit} ) -let all_labels - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - : int Set.Poly.t = +let all_labels (module Flowgraph : FLOWGRAPH with type labels = int) : + int Set.Poly.t = let step set = Set.Poly.union set (union_map set ~f:(fun l -> Map.Poly.find_exn Flowgraph.successors l)) diff --git a/src/analysis_and_optimization/Memory_patterns.ml b/src/analysis_and_optimization/Memory_patterns.ml index 9af186e0cc..58604be6c9 100644 --- a/src/analysis_and_optimization/Memory_patterns.ml +++ b/src/analysis_and_optimization/Memory_patterns.ml @@ -2,48 +2,49 @@ open Core_kernel open Core_kernel.Poly open Middle -(** +module Make (StdLibrary : Frontend.Std_library_utils.Library) = struct + (** Return a Var expression of the name for each type containing an eigen matrix *) -let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta} - = - let union_recur exprs = Set.Poly.union_list (List.map exprs ~f:matrix_set) in - if UnsizedType.contains_eigen_type type_ then - match pattern with - | Var s -> Set.Poly.singleton (Dataflow_types.VVar s, meta) - | Lit _ -> Set.Poly.empty - | FunApp (_, exprs) -> - if UnsizedType.contains_eigen_type type_ then union_recur exprs - else Set.Poly.empty - | TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3] - | Indexed (expr, _) | Promotion (expr, _, _) -> matrix_set expr - | EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2] - else Set.Poly.empty + let rec matrix_set + Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta} = + let union_recur exprs = Set.Poly.union_list (List.map exprs ~f:matrix_set) in + if UnsizedType.contains_eigen_type type_ then + match pattern with + | Var s -> Set.Poly.singleton 
(Dataflow_types.VVar s, meta) + | Lit _ -> Set.Poly.empty + | FunApp (_, exprs) -> + if UnsizedType.contains_eigen_type type_ then union_recur exprs + else Set.Poly.empty + | TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3] + | Indexed (expr, _) | Promotion (expr, _, _) -> matrix_set expr + | EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2] + else Set.Poly.empty -(** + (** Return a set of all types containing autodiffable Eigen matrices in an expression. *) -let query_var_eigen_names (expr : Expr.Typed.t) : string Set.Poly.t = - let get_expr_eigen_names - (Dataflow_types.VVar s, Expr.Typed.Meta.{adlevel; type_; _}) = - if - UnsizedType.contains_eigen_type type_ - && UnsizedType.is_autodifftype adlevel - then Some s - else None in - Set.Poly.filter_map ~f:get_expr_eigen_names (matrix_set expr) + let query_var_eigen_names (expr : Expr.Typed.t) : string Set.Poly.t = + let get_expr_eigen_names + (Dataflow_types.VVar s, Expr.Typed.Meta.{adlevel; type_; _}) = + if + UnsizedType.contains_eigen_type type_ + && UnsizedType.is_autodifftype adlevel + then Some s + else None in + Set.Poly.filter_map ~f:get_expr_eigen_names (matrix_set expr) -(** + (** Check whether one set is a nonzero subset of another set. *) -let is_nonzero_subset ~set ~subset = - Set.Poly.is_subset subset ~of_:set - && (not (Set.Poly.is_empty set)) - && not (Set.Poly.is_empty subset) + let is_nonzero_subset ~set ~subset = + Set.Poly.is_subset subset ~of_:set + && (not (Set.Poly.is_empty set)) + && not (Set.Poly.is_empty subset) -(** + (** Check an Index to count how many times we see a single index. @param acc An accumulator from previous folds of multiple expressions. @param idx An Index to match. For Single types this adds 1 to the @@ -51,12 +52,12 @@ let is_nonzero_subset ~set ~subset = for a Single index. All and Between cannot be Single cell access and so pass acc along. 
*) -and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) = - match idx with - | Index.All | Between _ | Upfrom _ | MultiIndex _ -> acc - | Single _ -> acc + 1 + and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) = + match idx with + | Index.All | Between _ | Upfrom _ | MultiIndex _ -> acc + | Single _ -> acc + 1 -(** + (** Find indices on Matrix and Vector types that perform single cell access. Returns true if it finds a vector, row vector, matrix, or matrix with single cell access @@ -66,51 +67,50 @@ and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) = @param index This list is checked for Single cell access either at the top level or within the [Index] types of the list. *) -let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t) - (index : Expr.Typed.t Index.t list) = - match in_loop with - | false -> false - | true -> ( - let contains_single_idx = - List.fold_left ~init:0 ~f:count_single_idx index in - match (ut, index) with - | (UnsizedType.UVector | URowVector), _ when contains_single_idx > 0 -> - true - | UMatrix, _ when contains_single_idx > 1 -> true - | (UArray t | UFun (_, ReturnType t, _, _)), index -> ( - match List.tl index with - | Some cut_list -> is_uni_eigen_loop_indexing in_loop t cut_list - | None -> false ) - | _ -> false ) + let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t) + (index : Expr.Typed.t Index.t list) = + match in_loop with + | false -> false + | true -> ( + let contains_single_idx = + List.fold_left ~init:0 ~f:count_single_idx index in + match (ut, index) with + | (UnsizedType.UVector | URowVector), _ when contains_single_idx > 0 -> + true + | UMatrix, _ when contains_single_idx > 1 -> true + | (UArray t | UFun (_, ReturnType t, _, _)), index -> ( + match List.tl index with + | Some cut_list -> is_uni_eigen_loop_indexing in_loop t cut_list + | None -> false ) + | _ -> false ) -let query_stan_math_mem_pattern_support (name : string) - (args : (UnsizedType.autodifftype * 
UnsizedType.t) list) = - let open Stan_math_signatures in - match name with - | x when is_stan_math_variadic_function_name x -> false - | x when is_reduce_sum_fn x -> false - | _ -> - let name = - string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name) - in - let namematches = Hashtbl.find_multi stan_math_signatures name in - let filteredmatches = - List.filter - ~f:(fun x -> - Frontend.SignatureMismatch.check_compatible_arguments_mod_conv - (snd3 x) args - |> Result.is_ok ) - namematches in - let is_soa ((_ : UnsizedType.returntype), _, mem) = - mem = Mem_pattern.SoA in - List.exists ~f:is_soa filteredmatches + let query_stan_math_mem_pattern_support (name : string) + (args : (UnsizedType.autodifftype * UnsizedType.t) list) = + match name with + | x when StdLibrary.is_variadic_function_name x -> false + | x when StdLibrary.is_special_function_name x -> false + | _ -> + let name = + StdLibrary.string_operator_to_function_name + (Utils.stdlib_distribution_name name) in + let namematches = StdLibrary.get_signatures name in + let filteredmatches = + List.filter + ~f:(fun x -> + Frontend.SignatureMismatch.check_compatible_arguments_mod_conv + (snd3 x) args + |> Result.is_ok ) + namematches in + let is_soa ((_ : UnsizedType.returntype), _, mem) = + mem = Mem_pattern.SoA in + List.exists ~f:is_soa filteredmatches -(*Validate whether a function can support SoA matrices*) -let is_fun_soa_supported name exprs = - let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in - query_stan_math_mem_pattern_support name fun_args + (*Validate whether a function can support SoA matrices*) + let is_fun_soa_supported name exprs = + let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in + query_stan_math_mem_pattern_support name fun_args -(** + (** Query to find the initial set of objects that cannot be SoA. This is mostly recursing over expressions, with the exceptions being functions and indexing expressions. 
For the logic on functions @@ -120,41 +120,43 @@ let is_fun_soa_supported name exprs = will be returned if the matrix or vector is accessed by single cell indexing. *) -let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t) - Expr.Fixed.{pattern; _} : string Set.Poly.t = - let query_expr (accum : string Set.Poly.t) = - query_initial_demotable_expr in_loop ~acc:accum in - match pattern with - | FunApp (kind, (exprs : Expr.Typed.t list)) -> - query_initial_demotable_funs in_loop acc kind exprs - | Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) -> - let index_set = - Set.Poly.union_list - (List.map - ~f: - (Index.apply ~default:Set.Poly.empty ~merge:Set.Poly.union - (query_expr acc) ) - indexed ) in - let index_demotes = - if is_uni_eigen_loop_indexing in_loop type_ indexed then - Set.Poly.union (query_var_eigen_names expr) index_set - else Set.Poly.union (query_expr acc expr) index_set in - Set.Poly.union acc index_demotes - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - acc - | Promotion (expr, _, _) -> query_expr acc expr - | TernaryIf (predicate, texpr, fexpr) -> - let predicate_demotes = query_expr acc predicate in - Set.Poly.union - (Set.Poly.union predicate_demotes (query_var_eigen_names texpr)) - (query_var_eigen_names fexpr) - | EAnd (lhs, rhs) | EOr (lhs, rhs) -> - (*We need to get the demotes from both sides*) - let full_lhs_rhs = - Set.Poly.union (query_expr acc lhs) (query_expr acc rhs) in - Set.Poly.union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs) + let rec query_initial_demotable_expr (in_loop : bool) + ~(acc : string Set.Poly.t) Expr.Fixed.{pattern; _} : string Set.Poly.t = + let query_expr (accum : string Set.Poly.t) = + query_initial_demotable_expr in_loop ~acc:accum in + match pattern with + | FunApp (kind, (exprs : Expr.Typed.t list)) -> + query_initial_demotable_funs in_loop acc kind exprs + | Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) -> + 
let index_set = + Set.Poly.union_list + (List.map + ~f: + (Index.apply ~default:Set.Poly.empty ~merge:Set.Poly.union + (query_expr acc) ) + indexed ) in + let index_demotes = + if is_uni_eigen_loop_indexing in_loop type_ indexed then + Set.Poly.union (query_var_eigen_names expr) index_set + else Set.Poly.union (query_expr acc expr) index_set in + Set.Poly.union acc index_demotes + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> + acc + | Promotion (expr, _, _) -> query_expr acc expr + | TernaryIf (predicate, texpr, fexpr) -> + let predicate_demotes = query_expr acc predicate in + Set.Poly.union + (Set.Poly.union predicate_demotes (query_var_eigen_names texpr)) + (query_var_eigen_names fexpr) + | EAnd (lhs, rhs) | EOr (lhs, rhs) -> + (*We need to get the demotes from both sides*) + let full_lhs_rhs = + Set.Poly.union (query_expr acc lhs) (query_expr acc rhs) in + Set.Poly.union + (query_expr full_lhs_rhs lhs) + (query_expr full_lhs_rhs rhs) -(** + (** Query a function to detect if it or any of its used expression's objects or expressions should be demoted to AoS. * @@ -171,134 +173,136 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t) to the UDF. exprs The expression list passed to the functions. 
*) -and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t) - (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t = - let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in - let top_level_eigen_names = - Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in - let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in - let demoted_and_top_level_names = - Set.Poly.union demoted_eigen_names top_level_eigen_names in - match kind with - | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( - match name with - | "check_matching_dims" -> acc - | name -> ( - match is_fun_soa_supported name exprs with - | true -> Set.Poly.union acc demoted_eigen_names - | false -> Set.Poly.union acc demoted_and_top_level_names ) ) - | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> - Set.Poly.union acc demoted_and_top_level_names - | CompilerInternal (_ : 'a Internal_fun.t) -> acc - | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> - Set.Poly.union acc demoted_and_top_level_names + and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t) + (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t = + let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in + let top_level_eigen_names = + Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in + let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in + let demoted_and_top_level_names = + Set.Poly.union demoted_eigen_names top_level_eigen_names in + match kind with + | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( + match name with + | "check_matching_dims" -> acc + | name -> ( + match is_fun_soa_supported name exprs with + | true -> Set.Poly.union acc demoted_eigen_names + | false -> Set.Poly.union acc demoted_and_top_level_names ) ) + | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> + Set.Poly.union acc demoted_and_top_level_names 
+ | CompilerInternal (_ : 'a Internal_fun.t) -> acc + | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> + Set.Poly.union acc demoted_and_top_level_names -(** + (** Check whether any functions in the right hand side expression of an assignment support SoA. If so then return true, otherwise return false. *) -let rec is_any_soa_supported_expr - Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; type_; _}} : bool = - if - UnsizedType.is_dataonlytype adlevel - || not (UnsizedType.contains_eigen_type type_) - then true - else - match pattern with - | FunApp (kind, (exprs : Expr.Typed.t list)) -> - is_any_soa_supported_fun_expr kind exprs - | Indexed (expr, (_ : Expr.Typed.t Index.t list)) | Promotion (expr, _, _) - -> - is_any_soa_supported_expr expr - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - true - | TernaryIf (_, texpr, fexpr) -> - is_any_soa_supported_expr texpr && is_any_soa_supported_expr fexpr - | EAnd (lhs, rhs) | EOr (lhs, rhs) -> - is_any_soa_supported_expr lhs && is_any_soa_supported_expr rhs + let rec is_any_soa_supported_expr + Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; type_; _}} : bool = + if + UnsizedType.is_dataonlytype adlevel + || not (UnsizedType.contains_eigen_type type_) + then true + else + match pattern with + | FunApp (kind, (exprs : Expr.Typed.t list)) -> + is_any_soa_supported_fun_expr kind exprs + | Indexed (expr, (_ : Expr.Typed.t Index.t list)) | Promotion (expr, _, _) + -> + is_any_soa_supported_expr expr + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) + -> + true + | TernaryIf (_, texpr, fexpr) -> + is_any_soa_supported_expr texpr && is_any_soa_supported_expr fexpr + | EAnd (lhs, rhs) | EOr (lhs, rhs) -> + is_any_soa_supported_expr lhs && is_any_soa_supported_expr rhs -(** + (** Return false if the [Fun_kind.t] does not support [SoA] *) -and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t) - (exprs : Expr.Typed.t list) : bool = - match kind with - | 
CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> false - | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false - | CompilerInternal (_ : 'a Internal_fun.t) -> true - | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( - match name with - | "check_matching_dims" -> true - | _ -> - is_fun_soa_supported name exprs - && List.exists ~f:is_any_soa_supported_expr exprs ) + and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t) + (exprs : Expr.Typed.t list) : bool = + match kind with + | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> false + | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false + | CompilerInternal (_ : 'a Internal_fun.t) -> true + | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( + match name with + | "check_matching_dims" -> true + | _ -> + is_fun_soa_supported name exprs + && List.exists ~f:is_any_soa_supported_expr exprs ) -(** + (** Return true if the rhs expression of an assignment contains only combinations of AutoDiffable Reals and Data Matrices *) -let rec is_any_ad_real_data_matrix_expr - Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; _}} : bool = - if UnsizedType.is_dataonlytype adlevel then false - else - match pattern with - | FunApp (kind, (exprs : Expr.Typed.t list)) -> - is_any_ad_real_data_matrix_expr_fun kind exprs - | Indexed (expr, _) | Promotion (expr, _, _) -> - is_any_ad_real_data_matrix_expr expr - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - false - | TernaryIf (_, texpr, fexpr) -> - is_any_ad_real_data_matrix_expr texpr - || is_any_ad_real_data_matrix_expr fexpr - | EAnd (lhs, rhs) | EOr (lhs, rhs) -> - is_any_ad_real_data_matrix_expr lhs - && is_any_ad_real_data_matrix_expr rhs + let rec is_any_ad_real_data_matrix_expr + Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; _}} : bool = + if UnsizedType.is_dataonlytype adlevel then false + else + match pattern with + | FunApp (kind, (exprs : Expr.Typed.t list)) -> + 
is_any_ad_real_data_matrix_expr_fun kind exprs + | Indexed (expr, _) | Promotion (expr, _, _) -> + is_any_ad_real_data_matrix_expr expr + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) + -> + false + | TernaryIf (_, texpr, fexpr) -> + is_any_ad_real_data_matrix_expr texpr + || is_any_ad_real_data_matrix_expr fexpr + | EAnd (lhs, rhs) | EOr (lhs, rhs) -> + is_any_ad_real_data_matrix_expr lhs + && is_any_ad_real_data_matrix_expr rhs -(** + (** Return true if the expressions in a function call are all combinations of AutoDiffable Reals and Data Matrices *) -and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t) - (exprs : Expr.Typed.t list) : bool = - match kind with - | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( - match name with - | "check_matching_dims" -> false - | _ -> ( - let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in - (*Right now we can't handle AD real and data matrix funcs - that return a matrix :-/*) - let is_args_autodiff_real_data_matrix = - (*If there are any autodiffable vars*) - List.exists - ~f:(fun (x, y) -> - match (x, y) with - | UnsizedType.AutoDiffable, UnsizedType.UReal -> true - | _ -> false ) - fun_args - (*And there are any data matrices*) - && List.exists - ~f:(fun (x, y) -> - match (x, UnsizedType.is_container y) with - | UnsizedType.DataOnly, true -> true - | _ -> false ) - fun_args - (*And there are no Autodiffable matrices*) - && List.exists - ~f:(fun (x, y) -> - match (x, UnsizedType.contains_eigen_type y) with - | UnsizedType.AutoDiffable, true -> false - | _ -> true ) - fun_args in - match is_args_autodiff_real_data_matrix with - | true -> true - | false -> List.exists ~f:is_any_ad_real_data_matrix_expr exprs ) ) - | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> true - | CompilerInternal (_ : 'a Internal_fun.t) -> false - | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false + and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t) + (exprs 
: Expr.Typed.t list) : bool = + match kind with + | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( + match name with + | "check_matching_dims" -> false + | _ -> ( + let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in + (*Right now we can't handle AD real and data matrix funcs + that return a matrix :-/*) + let is_args_autodiff_real_data_matrix = + (*If there are any autodiffable vars*) + List.exists + ~f:(fun (x, y) -> + match (x, y) with + | UnsizedType.AutoDiffable, UnsizedType.UReal -> true + | _ -> false ) + fun_args + (*And there are any data matrices*) + && List.exists + ~f:(fun (x, y) -> + match (x, UnsizedType.is_container y) with + | UnsizedType.DataOnly, true -> true + | _ -> false ) + fun_args + (*And there are no Autodiffable matrices*) + && List.exists + ~f:(fun (x, y) -> + match (x, UnsizedType.contains_eigen_type y) with + | UnsizedType.AutoDiffable, true -> false + | _ -> true ) + fun_args in + match is_args_autodiff_real_data_matrix with + | true -> true + | false -> List.exists ~f:is_any_ad_real_data_matrix_expr exprs ) ) + | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> true + | CompilerInternal (_ : 'a Internal_fun.t) -> false + | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false -(** + (** Query to find the initial set of objects in statements that cannot be SoA. This is mostly recursive over expressions and statements, with the exception of functions and Assignments. @@ -322,97 +326,100 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t) @param in_loop A boolean to specify the logic of indexing expressions. See [query_initial_demotable_expr] for an explanation of the logic. 
*) -let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t) - (Stmt.Fixed.{pattern; _} : Stmt.Located.t) : string Set.Poly.t = - let query_expr (accum : string Set.Poly.t) = - query_initial_demotable_expr in_loop ~acc:accum in - match pattern with - | Stmt.Fixed.Pattern.Assignment - ( ((name : string), (ut : UnsizedType.t), idx) - , (Expr.Fixed.{meta= Expr.Typed.Meta.{type_; adlevel; _}; _} as rhs) ) -> - let idx_list = - List.fold ~init:acc - ~f:(fun accum x -> - Index.folder accum - (fun acc -> query_initial_demotable_expr in_loop ~acc) - x ) - idx in - let idx_demotable = - (* RHS (2)*) - match is_uni_eigen_loop_indexing in_loop ut idx with - | true -> Set.Poly.add idx_list name - | false -> idx_list in - let rhs_demotable_names = query_expr acc rhs in - (* RHS (3)*) - let check_if_rhs_ad_real_data_matrix_expr = - match (UnsizedType.contains_eigen_type type_, adlevel) with - | true, UnsizedType.AutoDiffable -> - is_any_ad_real_data_matrix_expr rhs - || not (is_any_soa_supported_expr rhs) - | _ -> false in - (* RHS (1)*) - let is_all_rhs_aos = - let all_rhs_eigen_names = query_var_eigen_names rhs in - is_nonzero_subset ~subset:all_rhs_eigen_names ~set:rhs_demotable_names - in - let is_not_supported_func = - match rhs.pattern with - | FunApp (CompilerInternal _, _) -> false - | FunApp (UserDefined _, _) -> true - | _ -> false in - let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in - let assign_demotes = - if - is_eigen_stmt - && ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr - || is_not_supported_func ) - then - let base_set = Set.Poly.union idx_demotable rhs_demotable_names in - Set.Poly.add - (Set.Poly.union base_set (query_var_eigen_names rhs)) - name - else Set.Poly.union idx_demotable rhs_demotable_names in - Set.Poly.union acc assign_demotes - | NRFunApp (kind, exprs) -> - query_initial_demotable_funs in_loop acc kind exprs - | IfElse (predicate, true_stmt, op_false_stmt) -> - let predicate_acc = 
query_expr acc predicate in - Set.Poly.union acc - (Set.Poly.union_list - [ predicate_acc - ; query_initial_demotable_stmt in_loop predicate_acc true_stmt - ; Option.value_map - ~f:(query_initial_demotable_stmt in_loop predicate_acc) - ~default:Set.Poly.empty op_false_stmt ] ) - | Return optional_expr -> - Option.value_map ~f:(query_expr acc) ~default:Set.Poly.empty optional_expr - | SList lst | Profile (_, lst) | Block lst -> - Set.Poly.union_list - (List.map ~f:(query_initial_demotable_stmt in_loop acc) lst) - | TargetPE expr -> query_expr acc expr - (* NOTE: loops generated by inlining are not actually loops; - we do not unconditionally set "in_loop" *) - | For - { lower= Expr.Fixed.{pattern= Lit (Int, lb); _} - ; upper= Expr.Fixed.{pattern= Lit (Int, ub); _} - ; body - ; _ } - when lb = "1" && ub = "1" -> - query_initial_demotable_stmt in_loop acc body - | For {lower; upper; body; _} -> - Set.Poly.union - (Set.Poly.union (query_expr acc lower) (query_expr acc upper)) - (query_initial_demotable_stmt true acc body) - | While (predicate, body) -> - Set.Poly.union_list - [ acc; query_expr acc predicate - ; query_initial_demotable_stmt true acc body ] - | Decl {decl_type= Type.Sized st; decl_id; _} - when SizedType.is_complex_type st -> - Set.Poly.add acc decl_id - | Skip | Break | Continue | Decl _ -> acc + let rec query_initial_demotable_stmt (in_loop : bool) + (acc : string Set.Poly.t) (Stmt.Fixed.{pattern; _} : Stmt.Located.t) : + string Set.Poly.t = + let query_expr (accum : string Set.Poly.t) = + query_initial_demotable_expr in_loop ~acc:accum in + match pattern with + | Stmt.Fixed.Pattern.Assignment + ( ((name : string), (ut : UnsizedType.t), idx) + , (Expr.Fixed.{meta= Expr.Typed.Meta.{type_; adlevel; _}; _} as rhs) ) + -> + let idx_list = + List.fold ~init:acc + ~f:(fun accum x -> + Index.folder accum + (fun acc -> query_initial_demotable_expr in_loop ~acc) + x ) + idx in + let idx_demotable = + (* RHS (2)*) + match is_uni_eigen_loop_indexing in_loop ut idx 
with + | true -> Set.Poly.add idx_list name + | false -> idx_list in + let rhs_demotable_names = query_expr acc rhs in + (* RHS (3)*) + let check_if_rhs_ad_real_data_matrix_expr = + match (UnsizedType.contains_eigen_type type_, adlevel) with + | true, UnsizedType.AutoDiffable -> + is_any_ad_real_data_matrix_expr rhs + || not (is_any_soa_supported_expr rhs) + | _ -> false in + (* RHS (1)*) + let is_all_rhs_aos = + let all_rhs_eigen_names = query_var_eigen_names rhs in + is_nonzero_subset ~subset:all_rhs_eigen_names ~set:rhs_demotable_names + in + let is_not_supported_func = + match rhs.pattern with + | FunApp (CompilerInternal _, _) -> false + | FunApp (UserDefined _, _) -> true + | _ -> false in + let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in + let assign_demotes = + if + is_eigen_stmt + && ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr + || is_not_supported_func ) + then + let base_set = Set.Poly.union idx_demotable rhs_demotable_names in + Set.Poly.add + (Set.Poly.union base_set (query_var_eigen_names rhs)) + name + else Set.Poly.union idx_demotable rhs_demotable_names in + Set.Poly.union acc assign_demotes + | NRFunApp (kind, exprs) -> + query_initial_demotable_funs in_loop acc kind exprs + | IfElse (predicate, true_stmt, op_false_stmt) -> + let predicate_acc = query_expr acc predicate in + Set.Poly.union acc + (Set.Poly.union_list + [ predicate_acc + ; query_initial_demotable_stmt in_loop predicate_acc true_stmt + ; Option.value_map + ~f:(query_initial_demotable_stmt in_loop predicate_acc) + ~default:Set.Poly.empty op_false_stmt ] ) + | Return optional_expr -> + Option.value_map ~f:(query_expr acc) ~default:Set.Poly.empty + optional_expr + | SList lst | Profile (_, lst) | Block lst -> + Set.Poly.union_list + (List.map ~f:(query_initial_demotable_stmt in_loop acc) lst) + | TargetPE expr -> query_expr acc expr + (* NOTE: loops generated by inlining are not actually loops; + we do not unconditionally set "in_loop" *) + | For + { 
lower= Expr.Fixed.{pattern= Lit (Int, lb); _} + ; upper= Expr.Fixed.{pattern= Lit (Int, ub); _} + ; body + ; _ } + when lb = "1" && ub = "1" -> + query_initial_demotable_stmt in_loop acc body + | For {lower; upper; body; _} -> + Set.Poly.union + (Set.Poly.union (query_expr acc lower) (query_expr acc upper)) + (query_initial_demotable_stmt true acc body) + | While (predicate, body) -> + Set.Poly.union_list + [ acc; query_expr acc predicate + ; query_initial_demotable_stmt true acc body ] + | Decl {decl_type= Type.Sized st; decl_id; _} + when SizedType.is_complex_type st -> + Set.Poly.add acc decl_id + | Skip | Break | Continue | Decl _ -> acc -(** Look through a statement to see whether the objects used in it need to be + (** Look through a statement to see whether the objects used in it need to be modified from SoA to AoS. Returns the set of object names that need demoted in a statement, if any. This function looks at Assignment statements, and returns back the @@ -424,25 +431,27 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t) @param aos_exits A set of variables that can be demoted. @param pattern The Stmt pattern to query. 
*) -let query_demotable_stmt (aos_exits : string Set.Poly.t) - (pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t = - match pattern with - | Stmt.Fixed.Pattern.Assignment - ( ( (assign_name : string) - , (_ : UnsizedType.t) - , (_ : Expr.Typed.t Index.t list) ) - , (rhs : Expr.Typed.t) ) -> ( - let all_rhs_eigen_names = query_var_eigen_names rhs in - if Set.Poly.mem aos_exits assign_name then - Set.Poly.add all_rhs_eigen_names assign_name - else - match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with - | true -> Set.Poly.add all_rhs_eigen_names assign_name - | false -> Set.Poly.empty ) - (* All other statements do not need logic here*) - | _ -> Set.Poly.empty + let query_demotable_stmt (aos_exits : string Set.Poly.t) + (pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t = + match pattern with + | Stmt.Fixed.Pattern.Assignment + ( ( (assign_name : string) + , (_ : UnsizedType.t) + , (_ : Expr.Typed.t Index.t list) ) + , (rhs : Expr.Typed.t) ) -> ( + let all_rhs_eigen_names = query_var_eigen_names rhs in + if Set.Poly.mem aos_exits assign_name then + Set.Poly.add all_rhs_eigen_names assign_name + else + match + is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names + with + | true -> Set.Poly.add all_rhs_eigen_names assign_name + | false -> Set.Poly.empty ) + (* All other statements do not need logic here*) + | _ -> Set.Poly.empty -(** + (** Modify a function and it's subexpressions from SoA <-> AoS and vice versa. This performs demotion for sub expressions recursively. The top level expression and it's sub expressions are demoted to SoA if @@ -457,31 +466,36 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t) @param kind A [Fun_kind.t] @param exprs A list of expressions going into the function. 
**) -let rec modify_kind ?force_demotion:(force = false) - (modifiable_set : string Set.Poly.t) (kind : 'a Fun_kind.t) - (exprs : Expr.Typed.t list) = - let expr_names = - Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in - let is_all_in_list = - is_nonzero_subset ~set:modifiable_set ~subset:expr_names in - match kind with - | Fun_kind.StanLib (name, sfx, (_ : Mem_pattern.t)) -> - if is_all_in_list || (not (is_fun_soa_supported name exprs)) || force then - (*Force demotion of all subexprs*) - let exprs' = - List.map ~f:(modify_expr ~force_demotion:true expr_names) exprs in - (Fun_kind.StanLib (name, sfx, Mem_pattern.AoS), exprs') - else - ( Fun_kind.StanLib (name, sfx, SoA) + let rec modify_kind ?force_demotion:(force = false) + (modifiable_set : string Set.Poly.t) (kind : 'a Fun_kind.t) + (exprs : Expr.Typed.t list) = + let expr_names = + Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in + let is_all_in_list = + is_nonzero_subset ~set:modifiable_set ~subset:expr_names in + match kind with + | Fun_kind.StanLib (name, sfx, (_ : Mem_pattern.t)) -> + if is_all_in_list || (not (is_fun_soa_supported name exprs)) || force + then + (*Force demotion of all subexprs*) + let exprs' = + List.map ~f:(modify_expr ~force_demotion:true expr_names) exprs + in + (Fun_kind.StanLib (name, sfx, Mem_pattern.AoS), exprs') + else + ( Fun_kind.StanLib (name, sfx, SoA) + , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs + ) + | UserDefined _ as udf -> + ( udf + , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs + ) + | (_ : 'a Fun_kind.t) -> + ( kind , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs ) - | UserDefined _ as udf -> - (udf, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs) - | (_ : 'a Fun_kind.t) -> - ( kind - , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs ) -(** + (** Modify an expression and it's subexpressions from SoA <-> AoS and vice 
versa. The only real paths in the below is on the functions and ternary expressions. @@ -495,42 +509,42 @@ let rec modify_kind ?force_demotion:(force = false) associated expressions we want to modify. @param pattern The expression to modify. *) -and modify_expr_pattern ?force_demotion:(force = false) - (modifiable_set : string Set.Poly.t) - (pattern : Expr.Typed.t Expr.Fixed.Pattern.t) = - let mod_expr ?force_demotion:(forced = false) = - modify_expr ~force_demotion:forced modifiable_set in - match pattern with - | Expr.Fixed.Pattern.FunApp (kind, (exprs : Expr.Typed.t list)) -> - let kind', expr' = - modify_kind ~force_demotion:force modifiable_set kind exprs in - Expr.Fixed.Pattern.FunApp (kind', expr') - | TernaryIf (predicate, texpr, fexpr) -> - let is_eigen_return = - UnsizedType.contains_eigen_type fexpr.meta.type_ - || UnsizedType.contains_eigen_type texpr.meta.type_ in - if is_eigen_return then - TernaryIf - ( mod_expr ~force_demotion:force predicate - , mod_expr ~force_demotion:true texpr - , mod_expr ~force_demotion:true fexpr ) - else - TernaryIf - ( mod_expr ~force_demotion:force predicate - , mod_expr ~force_demotion:force texpr - , mod_expr ~force_demotion:force fexpr ) - | Indexed (idx_expr, indexed) -> - Indexed - ( mod_expr idx_expr - , List.map ~f:(Index.map (mod_expr ~force_demotion:force)) indexed ) - | EAnd (lhs, rhs) -> EAnd (mod_expr lhs, mod_expr rhs) - | EOr (lhs, rhs) -> EOr (mod_expr lhs, mod_expr rhs) - | Promotion (expr, type_, ad_level) -> - Promotion (mod_expr expr, type_, ad_level) - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - pattern + and modify_expr_pattern ?force_demotion:(force = false) + (modifiable_set : string Set.Poly.t) + (pattern : Expr.Typed.t Expr.Fixed.Pattern.t) = + let mod_expr ?force_demotion:(forced = false) = + modify_expr ~force_demotion:forced modifiable_set in + match pattern with + | Expr.Fixed.Pattern.FunApp (kind, (exprs : Expr.Typed.t list)) -> + let kind', expr' = + 
modify_kind ~force_demotion:force modifiable_set kind exprs in + Expr.Fixed.Pattern.FunApp (kind', expr') + | TernaryIf (predicate, texpr, fexpr) -> + let is_eigen_return = + UnsizedType.contains_eigen_type fexpr.meta.type_ + || UnsizedType.contains_eigen_type texpr.meta.type_ in + if is_eigen_return then + TernaryIf + ( mod_expr ~force_demotion:force predicate + , mod_expr ~force_demotion:true texpr + , mod_expr ~force_demotion:true fexpr ) + else + TernaryIf + ( mod_expr ~force_demotion:force predicate + , mod_expr ~force_demotion:force texpr + , mod_expr ~force_demotion:force fexpr ) + | Indexed (idx_expr, indexed) -> + Indexed + ( mod_expr idx_expr + , List.map ~f:(Index.map (mod_expr ~force_demotion:force)) indexed ) + | EAnd (lhs, rhs) -> EAnd (mod_expr lhs, mod_expr rhs) + | EOr (lhs, rhs) -> EOr (mod_expr lhs, mod_expr rhs) + | Promotion (expr, type_, ad_level) -> + Promotion (mod_expr expr, type_, ad_level) + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> + pattern -(** + (** Given a Set of strings containing the names of objects that can be modified from AoS <-> SoA and vice versa, modify them within the expression. @param mem_pattern The memory pattern to change expressions to. @@ -538,12 +552,13 @@ and modify_expr_pattern ?force_demotion:(force = false) associated expressions we want to modify. @param expr the expression to modify. 
*) -and modify_expr ?force_demotion:(force = false) - (modifiable_set : string Set.Poly.t) (Expr.Fixed.{pattern; _} as expr) = - { expr with - pattern= modify_expr_pattern ~force_demotion:force modifiable_set pattern } + and modify_expr ?force_demotion:(force = false) + (modifiable_set : string Set.Poly.t) (Expr.Fixed.{pattern; _} as expr) = + { expr with + pattern= modify_expr_pattern ~force_demotion:force modifiable_set pattern + } -(** + (** Modify statement patterns in the MIR from AoS <-> SoA and vice versa For [Decl] and [Assignment]'s reading in parameters, we demote to AoS if the [decl_id] (or assign name) is in the modifiable set and @@ -555,81 +570,82 @@ and modify_expr ?force_demotion:(force = false) @param pattern The statement pattern to modify @param modifiable_set The name of the variable we are searching for. *) -let rec modify_stmt_pattern - (pattern : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) - (modifiable_set : string Core_kernel.Set.Poly.t) = - let mod_expr force = modify_expr ~force_demotion:force modifiable_set in - let mod_stmt stmt = modify_stmt stmt modifiable_set in - match pattern with - | Stmt.Fixed.Pattern.Decl - ({decl_id; decl_type= Type.Sized sized_type; _} as decl) -> - if Set.Poly.mem modifiable_set decl_id then - Stmt.Fixed.Pattern.Decl - { decl with - decl_type= - Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type) } - else - Decl - { decl with - decl_type= - Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type) } - | NRFunApp (kind, (exprs : Expr.Typed.t list)) -> - let kind', exprs' = modify_kind modifiable_set kind exprs in - NRFunApp (kind', exprs') - | Assignment - ( (name, ut, lhs) - , ( {pattern= FunApp (CompilerInternal (FnReadParam read_param), args); _} - as assigner ) ) -> - if Set.Poly.mem modifiable_set name then - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) - , { assigner with - pattern= - FunApp - ( CompilerInternal - (FnReadParam {read_param with mem_pattern= 
AoS}) - , List.map ~f:(mod_expr true) args ) } ) - else - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) - , { assigner with - pattern= - FunApp - ( CompilerInternal - (FnReadParam {read_param with mem_pattern= SoA}) - , List.map ~f:(mod_expr false) args ) } ) - | Assignment (((name : string), (ut : UnsizedType.t), idx), rhs) -> - if Set.Poly.mem modifiable_set name then - (*If assignee is in bad set, force demotion of rhs functions*) - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) - , mod_expr true rhs ) - else - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) - , (mod_expr false) rhs ) - | IfElse (predicate, true_stmt, op_false_stmt) -> - IfElse - ( (mod_expr false) predicate - , mod_stmt true_stmt - , Option.map ~f:mod_stmt op_false_stmt ) - | Block stmts -> Block (List.map ~f:mod_stmt stmts) - | SList stmts -> SList (List.map ~f:mod_stmt stmts) - | For ({lower; upper; body; _} as loop) -> - Stmt.Fixed.Pattern.For - { loop with - lower= mod_expr false lower - ; upper= mod_expr false upper - ; body= mod_stmt body } - | TargetPE expr -> TargetPE ((mod_expr false) expr) - | Return optional_expr -> - Return (Option.map ~f:(mod_expr false) optional_expr) - | Profile ((p_name : string), stmt) -> - Profile (p_name, List.map ~f:mod_stmt stmt) - | While (predicate, body) -> While ((mod_expr false) predicate, mod_stmt body) - | Skip | Break | Continue | Decl _ -> pattern + let rec modify_stmt_pattern + (pattern : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) + (modifiable_set : string Core_kernel.Set.Poly.t) = + let mod_expr force = modify_expr ~force_demotion:force modifiable_set in + let mod_stmt stmt = modify_stmt stmt modifiable_set in + match pattern with + | Stmt.Fixed.Pattern.Decl + ({decl_id; decl_type= Type.Sized sized_type; _} as decl) -> + if Set.Poly.mem modifiable_set decl_id then + Stmt.Fixed.Pattern.Decl + { decl with + decl_type= + Type.Sized (SizedType.modify_sizedtype_mem 
AoS sized_type) } + else + Decl + { decl with + decl_type= + Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type) } + | NRFunApp (kind, (exprs : Expr.Typed.t list)) -> + let kind', exprs' = modify_kind modifiable_set kind exprs in + NRFunApp (kind', exprs') + | Assignment + ( (name, ut, lhs) + , ( { pattern= FunApp (CompilerInternal (FnReadParam read_param), args) + ; _ } as assigner ) ) -> + if Set.Poly.mem modifiable_set name then + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) + , { assigner with + pattern= + FunApp + ( CompilerInternal + (FnReadParam {read_param with mem_pattern= AoS}) + , List.map ~f:(mod_expr true) args ) } ) + else + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) + , { assigner with + pattern= + FunApp + ( CompilerInternal + (FnReadParam {read_param with mem_pattern= SoA}) + , List.map ~f:(mod_expr false) args ) } ) + | Assignment (((name : string), (ut : UnsizedType.t), idx), rhs) -> + if Set.Poly.mem modifiable_set name then + (*If assignee is in bad set, force demotion of rhs functions*) + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) + , mod_expr true rhs ) + else + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) + , (mod_expr false) rhs ) + | IfElse (predicate, true_stmt, op_false_stmt) -> + IfElse + ( (mod_expr false) predicate + , mod_stmt true_stmt + , Option.map ~f:mod_stmt op_false_stmt ) + | Block stmts -> Block (List.map ~f:mod_stmt stmts) + | SList stmts -> SList (List.map ~f:mod_stmt stmts) + | For ({lower; upper; body; _} as loop) -> + Stmt.Fixed.Pattern.For + { loop with + lower= mod_expr false lower + ; upper= mod_expr false upper + ; body= mod_stmt body } + | TargetPE expr -> TargetPE ((mod_expr false) expr) + | Return optional_expr -> + Return (Option.map ~f:(mod_expr false) optional_expr) + | Profile ((p_name : string), stmt) -> + Profile (p_name, List.map ~f:mod_stmt stmt) + | While (predicate, body) -> + While 
((mod_expr false) predicate, mod_stmt body) + | Skip | Break | Continue | Decl _ -> pattern -(** + (** Modify statement patterns in the MIR from AoS <-> SoA and vice versa @param mem_pattern A mem_pattern to modify expressions to. For the given memory pattern, this modifies @@ -637,9 +653,10 @@ let rec modify_stmt_pattern @param stmt The statement to modify. @param modifiable_set The name of the variable we are searching for. *) -and modify_stmt (Stmt.Fixed.{pattern; _} as stmt) - (modifiable_set : string Set.Poly.t) = - {stmt with pattern= modify_stmt_pattern pattern modifiable_set} + and modify_stmt (Stmt.Fixed.{pattern; _} as stmt) + (modifiable_set : string Set.Poly.t) = + {stmt with pattern= modify_stmt_pattern pattern modifiable_set} +end let collect_mem_pattern_variables stmts = let take_stmt acc = function diff --git a/src/analysis_and_optimization/Monotone_framework.ml b/src/analysis_and_optimization/Monotone_framework.ml index 135a2dc247..ba8cabd3a0 100644 --- a/src/analysis_and_optimization/Monotone_framework.ml +++ b/src/analysis_and_optimization/Monotone_framework.ml @@ -2,7 +2,7 @@ open Core_kernel open Core_kernel.Poly -open Monotone_framework_sigs +open Monotone_framework_intf open Mir_utils open Middle @@ -309,7 +309,9 @@ let minimal_variables_lattice initial_variables = end ) (* The transfer function for a constant propagation analysis *) -let constant_propagation_transfer ?(preserve_stability = false) +let constant_propagation_transfer + (module Partial_evaluator : Partial_evaluation.PARTIAL_EVALUATOR) + ?(preserve_stability = false) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = ( module struct type labels = int @@ -864,7 +866,7 @@ let rec declared_variables_stmt (List.map ~f:(fun x -> declared_variables_stmt x.pattern) l) let propagation_mfp (prog : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : 
(int, Stmt.Located.Non_recursive.t) Map.Poly.t) (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t @@ -897,7 +899,7 @@ let propagation_mfp (prog : Program.Typed.t) Mf.mfp () let reaching_definitions_mfp (mir : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let variables = ( module struct @@ -918,7 +920,7 @@ let reaching_definitions_mfp (mir : Program.Typed.t) Mf.mfp () let initialized_vars_mfp (total : string Set.Poly.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let (module Lattice) = dual_powerset_lattice_empty_initial @@ -949,8 +951,7 @@ let globals (prog : Program.Typed.t) = (** Monotone framework instance for live_variables analysis. Expects reverse flowgraph. 
*) let live_variables_mfp (prog : Program.Typed.t) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let never_kill = globals prog in let variables = @@ -970,10 +971,8 @@ let live_variables_mfp (prog : Program.Typed.t) (** Instantiate all four instances of the monotone framework for lazy code motion, reusing code between them *) -let lazy_expressions_mfp - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) +let lazy_expressions_mfp (module Flowgraph : FLOWGRAPH with type labels = int) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let all_expressions = used_subexpressions_stmt @@ -1031,8 +1030,7 @@ let lazy_expressions_mfp * *) let minimal_variables_mfp - (module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Circular_Fwd_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (initial_variables : string Set.Poly.t) (gen_variable : diff --git a/src/analysis_and_optimization/Monotone_framework_sigs.mli b/src/analysis_and_optimization/Monotone_framework_intf.ml similarity index 100% rename from src/analysis_and_optimization/Monotone_framework_sigs.mli rename to src/analysis_and_optimization/Monotone_framework_intf.ml diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index bee45c2afa..3cc51f9076 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -5,1067 +5,1146 @@ open Core_kernel.Poly open Common open Middle open Mir_utils +open Optimize_intf -(** +let settings_const b = + { function_inlining= b + ; 
static_loop_unrolling= b + ; one_step_loop_unrolling= b + ; list_collapsing= b + ; block_fixing= b + ; allow_uninitialized_decls= b + ; constant_propagation= b + ; expression_propagation= b + ; copy_propagation= b + ; dead_code_elimination= b + ; partial_evaluation= b + ; lazy_code_motion= b + ; optimize_ad_levels= b + ; preserve_stability= not b + ; optimize_soa= b } + +let all_optimizations : optimization_settings = settings_const true +let no_optimizations : optimization_settings = settings_const false + +type optimization_level = O0 | O1 | Oexperimental + +let level_optimizations (lvl : optimization_level) : optimization_settings = + match lvl with + | O0 -> no_optimizations + | O1 -> + { function_inlining= true + ; static_loop_unrolling= false + ; one_step_loop_unrolling= false + ; list_collapsing= true + ; block_fixing= true + ; constant_propagation= true + ; expression_propagation= false + ; copy_propagation= true + ; dead_code_elimination= true + ; partial_evaluation= true + ; lazy_code_motion= false + ; allow_uninitialized_decls= true + ; optimize_ad_levels= false + ; preserve_stability= false + ; optimize_soa= true } + | Oexperimental -> all_optimizations + +module Make (StdLibrary : Frontend.Std_library_utils.Library) : OPTIMIZER = +struct + module Mem = Memory_patterns.Make (StdLibrary) + module Partial_evaluator = Partial_evaluation.Make (StdLibrary) + + (** Apply the transformation to each function body and to the rest of the program as one block. 
*) -let transform_program (mir : Program.Typed.t) - (transform : Stmt.Located.t -> Stmt.Located.t) : Program.Typed.t = - let packed_prog_body = - transform - { pattern= + let transform_program (mir : Program.Typed.t) + (transform : Stmt.Located.t -> Stmt.Located.t) : Program.Typed.t = + let packed_prog_body = + transform + { pattern= + SList + (List.map + ~f:(fun x -> + Stmt.Fixed.{pattern= SList x; meta= Location_span.empty} ) + [ mir.prepare_data; mir.transform_inits; mir.log_prob + ; mir.generate_quantities ] ) + ; meta= Location_span.empty } in + let transformed_prog_body = transform packed_prog_body in + let transformed_functions = + List.map mir.functions_block ~f:(fun fs -> + {fs with fdbody= Option.map ~f:transform fs.fdbody} ) in + match transformed_prog_body with + | { pattern= SList - (List.map - ~f:(fun x -> - Stmt.Fixed.{pattern= SList x; meta= Location_span.empty} ) - [ mir.prepare_data; mir.transform_inits; mir.log_prob - ; mir.generate_quantities ] ) - ; meta= Location_span.empty } in - let transformed_prog_body = transform packed_prog_body in - let transformed_functions = - List.map mir.functions_block ~f:(fun fs -> - {fs with fdbody= Option.map ~f:transform fs.fdbody} ) in - match transformed_prog_body with - | { pattern= - SList - [ {pattern= SList prepare_data'; _} - ; {pattern= SList transform_inits'; _}; {pattern= SList log_prob'; _} - ; {pattern= SList generate_quantities'; _} ] - ; _ } -> - { mir with - functions_block= transformed_functions - ; prepare_data= prepare_data' - ; transform_inits= transform_inits' - ; log_prob= log_prob' - ; generate_quantities= generate_quantities' } - | _ -> - Common.FatalError.fatal_error_msg - [%message "Something went wrong with program transformation packing!"] + [ {pattern= SList prepare_data'; _} + ; {pattern= SList transform_inits'; _}; {pattern= SList log_prob'; _} + ; {pattern= SList generate_quantities'; _} ] + ; _ } -> + { mir with + functions_block= transformed_functions + ; prepare_data= 
prepare_data' + ; transform_inits= transform_inits' + ; log_prob= log_prob' + ; generate_quantities= generate_quantities' } + | _ -> + raise + (Failure "Something went wrong with program transformation packing!") -(** + (** Apply the transformation to each function body and to each program block separately. *) -let transform_program_blockwise (mir : Program.Typed.t) - (transform : - Stmt.Located.t Program.fun_def option -> Stmt.Located.t -> Stmt.Located.t - ) : Program.Typed.t = - let transform' fd s = - match transform fd {pattern= SList s; meta= Location_span.empty} with - | {pattern= SList l; _} -> l - | _ -> - Common.FatalError.fatal_error_msg - [%message "Something went wrong with program transformation packing!"] - in - let transformed_functions = - List.map mir.functions_block ~f:(fun fs -> - {fs with fdbody= Option.map ~f:(transform (Some fs)) fs.fdbody} ) in - { mir with - functions_block= transformed_functions - ; prepare_data= transform' None mir.prepare_data - ; transform_inits= transform' None mir.transform_inits - ; log_prob= transform' None mir.log_prob - ; generate_quantities= transform' None mir.generate_quantities } + let transform_program_blockwise (mir : Program.Typed.t) + (transform : + Stmt.Located.t Program.fun_def option + -> Stmt.Located.t + -> Stmt.Located.t ) : Program.Typed.t = + let transform' fd s = + match transform fd {pattern= SList s; meta= Location_span.empty} with + | {pattern= SList l; _} -> l + | _ -> + raise + (Failure "Something went wrong with program transformation packing!") + in + let transformed_functions = + List.map mir.functions_block ~f:(fun fs -> + {fs with fdbody= Option.map ~f:(transform (Some fs)) fs.fdbody} ) + in + { mir with + functions_block= transformed_functions + ; prepare_data= transform' None mir.prepare_data + ; transform_inits= transform' None mir.transform_inits + ; log_prob= transform' None mir.log_prob + ; generate_quantities= transform' None mir.generate_quantities } -let map_no_loc l = - List.map 
~f:(fun s -> Stmt.Fixed.{pattern= s; meta= Location_span.empty}) l + let map_no_loc l = + List.map ~f:(fun s -> Stmt.Fixed.{pattern= s; meta= Location_span.empty}) l -let slist_no_loc l = Stmt.Fixed.Pattern.SList (map_no_loc l) -let block_no_loc l = Stmt.Fixed.Pattern.Block (map_no_loc l) + let slist_no_loc l = Stmt.Fixed.Pattern.SList (map_no_loc l) + let block_no_loc l = Stmt.Fixed.Pattern.Block (map_no_loc l) -let slist_concat_no_loc l stmt = - match l with [] -> stmt | l -> slist_no_loc (l @ [stmt]) + let slist_concat_no_loc l stmt = + match l with [] -> stmt | l -> slist_no_loc (l @ [stmt]) -let gen_inline_var (name : string) (id_var : string) = - Gensym.generate ~prefix:("inline_" ^ name ^ "_" ^ id_var ^ "_") () + let gen_inline_var (name : string) (id_var : string) = + Gensym.generate ~prefix:("inline_" ^ name ^ "_" ^ id_var ^ "_") () -let replace_fresh_local_vars (fname : string) stmt = - let f (m : (string, string) Core_kernel.Map.Poly.t) = function - | Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} -> - let new_name = - match Map.Poly.find m decl_id with - | Some existing -> existing - | None -> gen_inline_var fname decl_id in - ( Stmt.Fixed.Pattern.Decl - {decl_adtype; decl_id= new_name; decl_type; initialize} - , Map.Poly.set m ~key:decl_id ~data:new_name ) - | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> - let new_name = - match Map.Poly.find m loopvar with - | Some existing -> existing - | None -> gen_inline_var fname loopvar in - ( Stmt.Fixed.Pattern.For {loopvar= new_name; lower; upper; body} - , Map.Poly.set m ~key:loopvar ~data:new_name ) - | Assignment ((var_name, ut, l), e) -> - let var_name = - match Map.Poly.find m var_name with - | None -> var_name - | Some var_name -> var_name in - (Stmt.Fixed.Pattern.Assignment ((var_name, ut, l), e), m) - | x -> (x, m) in - let s, m = map_rec_state_stmt_loc f Map.Poly.empty stmt in - name_subst_stmt m s + let replace_fresh_local_vars (fname : string) stmt = + let f (m : 
(string, string) Core_kernel.Map.Poly.t) = function + | Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} -> + let new_name = + match Map.Poly.find m decl_id with + | Some existing -> existing + | None -> gen_inline_var fname decl_id in + ( Stmt.Fixed.Pattern.Decl + {decl_adtype; decl_id= new_name; decl_type; initialize} + , Map.Poly.set m ~key:decl_id ~data:new_name ) + | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> + let new_name = + match Map.Poly.find m loopvar with + | Some existing -> existing + | None -> gen_inline_var fname loopvar in + ( Stmt.Fixed.Pattern.For {loopvar= new_name; lower; upper; body} + , Map.Poly.set m ~key:loopvar ~data:new_name ) + | Assignment ((var_name, ut, l), e) -> + let var_name = + match Map.Poly.find m var_name with + | None -> var_name + | Some var_name -> var_name in + (Stmt.Fixed.Pattern.Assignment ((var_name, ut, l), e), m) + | x -> (x, m) in + let s, m = map_rec_state_stmt_loc f Map.Poly.empty stmt in + name_subst_stmt m s -let subst_args_stmt args es = - let m = Map.Poly.of_alist_exn (List.zip_exn args es) in - subst_stmt m + let subst_args_stmt args es = + let m = Map.Poly.of_alist_exn (List.zip_exn args es) in + subst_stmt m -(** + (** * Count the number of returns that happen in a statement *) -let rec count_returns Stmt.Fixed.{pattern; _} : int = - Stmt.Fixed.Pattern.fold - (fun acc _ -> acc) - (fun acc -> function - | Stmt.Fixed.{pattern= Return _; _} -> acc + 1 - | stmt -> acc + count_returns stmt ) - 0 pattern + let rec count_returns Stmt.Fixed.{pattern; _} : int = + Stmt.Fixed.Pattern.fold + (fun acc _ -> acc) + (fun acc -> function + | Stmt.Fixed.{pattern= Return _; _} -> acc + 1 + | stmt -> acc + count_returns stmt ) + 0 pattern -(* The strategy here is to wrap the function body in a dummy loop, then replace - returns with breaks. One issue is early return from internal loops - in - those cases, a break would only break out of the inner loop. 
The solution is - a flag variable to indicate whether a 'return' break has been called, and - then to check if that flag is set after each loop. Then, if a 'return' break - is called from an inner loop, there's a cascade of breaks all the way out of - the dummy loop. *) -let handle_early_returns (fname : string) opt_var stmt = - let returned = gen_inline_var fname "early_ret_check" in - let generate_inner_breaks num_returns stmt_pattern = - match stmt_pattern with - | Stmt.Fixed.Pattern.Return opt_ret -> ( - match (opt_var, opt_ret) with - | None, None when num_returns > 1 -> Stmt.Fixed.Pattern.Break - | None, None -> Stmt.Fixed.Pattern.Block [] - | Some name, Some e when num_returns > 1 -> - SList - [ Stmt.Fixed. + (* The strategy here is to wrap the function body in a dummy loop, then replace + returns with breaks. One issue is early return from internal loops - in + those cases, a break would only break out of the inner loop. The solution is + a flag variable to indicate whether a 'return' break has been called, and + then to check if that flag is set after each loop. Then, if a 'return' break + is called from an inner loop, there's a cascade of breaks all the way out of + the dummy loop. *) + let handle_early_returns (fname : string) opt_var stmt = + let returned = gen_inline_var fname "early_ret_check" in + let generate_inner_breaks num_returns stmt_pattern = + match stmt_pattern with + | Stmt.Fixed.Pattern.Return opt_ret -> ( + match (opt_var, opt_ret) with + | None, None when num_returns > 1 -> Stmt.Fixed.Pattern.Break + | None, None -> Stmt.Fixed.Pattern.Block [] + | Some name, Some e when num_returns > 1 -> + SList + [ Stmt.Fixed. + { pattern= + Assignment + ( (returned, UInt, []) + , Expr.Fixed. + { pattern= Lit (Int, "1") + ; meta= + Expr.Typed.Meta. + { type_= UInt + ; adlevel= DataOnly + ; loc= Location_span.empty } } ) + ; meta= Location_span.empty } + ; Stmt.Fixed. 
+ { pattern= Assignment ((name, Expr.Typed.type_of e, []), e) + ; meta= Location_span.empty } + ; {pattern= Break; meta= Location_span.empty} ] + | Some name, Some e -> Assignment ((name, Expr.Typed.type_of e, []), e) + | Some _, None -> + Common.FatalError.fatal_error_msg + [%message + ( "Function should return a value but found an empty return \ + statement." + : string )] + | None, Some _ -> + Common.FatalError.fatal_error_msg + [%message + ( "Expected a void function but found a non-empty return \ + statement." + : string )] ) + | Stmt.Fixed.Pattern.For _ as loop when num_returns > 1 -> + Stmt.Fixed.Pattern.SList + [ Stmt.Fixed.{pattern= loop; meta= Location_span.empty} + ; Stmt.Fixed. { pattern= - Assignment - ( (returned, UInt, []) - , Expr.Fixed. - { pattern= Lit (Int, "1") + IfElse + ( Expr.Fixed. + { pattern= Var returned ; meta= Expr.Typed.Meta. { type_= UInt ; adlevel= DataOnly - ; loc= Location_span.empty } } ) - ; meta= Location_span.empty } - ; Stmt.Fixed. - { pattern= Assignment ((name, Expr.Typed.type_of e, []), e) - ; meta= Location_span.empty } - ; {pattern= Break; meta= Location_span.empty} ] - | Some name, Some e -> Assignment ((name, Expr.Typed.type_of e, []), e) - | Some _, None -> - Common.FatalError.fatal_error_msg - [%message - ( "Function should return a value but found an empty return \ - statement." - : string )] - | None, Some _ -> - Common.FatalError.fatal_error_msg - [%message - ( "Expected a void function but found a non-empty return \ - statement." - : string )] ) - | Stmt.Fixed.Pattern.For _ as loop when num_returns > 1 -> - Stmt.Fixed.Pattern.SList - [ Stmt.Fixed.{pattern= loop; meta= Location_span.empty} - ; Stmt.Fixed. - { pattern= - IfElse - ( Expr.Fixed. - { pattern= Var returned + ; loc= Location_span.empty } } + , {pattern= Break; meta= Location_span.empty} + , None ) + ; meta= Location_span.empty } ] + | x -> x in + let num_returns = count_returns stmt in + if num_returns > 1 then + Stmt.Fixed.Pattern.SList + [ Stmt.Fixed. 
+ { pattern= + Decl + { decl_adtype= DataOnly + ; decl_id= returned + ; decl_type= Sized SInt + ; initialize= true } + ; meta= Location_span.empty } + ; Stmt.Fixed. + { pattern= + Assignment + ( (returned, UInt, []) + , Expr.Fixed. + { pattern= Lit (Int, "0") + ; meta= + Expr.Typed.Meta. + { type_= UInt + ; adlevel= DataOnly + ; loc= Location_span.empty } } ) + ; meta= Location_span.empty } + ; Stmt.Fixed. + { pattern= + Stmt.Fixed.Pattern.For + { loopvar= gen_inline_var fname "iterator" + ; lower= + Expr.Fixed. + { pattern= Lit (Int, "1") ; meta= Expr.Typed.Meta. { type_= UInt ; adlevel= DataOnly ; loc= Location_span.empty } } - , {pattern= Break; meta= Location_span.empty} - , None ) - ; meta= Location_span.empty } ] - | x -> x in - let num_returns = count_returns stmt in - if num_returns > 1 then - Stmt.Fixed.Pattern.SList - [ Stmt.Fixed. - { pattern= - Decl - { decl_adtype= DataOnly - ; decl_id= returned - ; decl_type= Sized SInt - ; initialize= true } - ; meta= Location_span.empty } - ; Stmt.Fixed. - { pattern= - Assignment - ( (returned, UInt, []) - , Expr.Fixed. - { pattern= Lit (Int, "0") - ; meta= - Expr.Typed.Meta. - { type_= UInt - ; adlevel= DataOnly - ; loc= Location_span.empty } } ) - ; meta= Location_span.empty } - ; Stmt.Fixed. - { pattern= - Stmt.Fixed.Pattern.For - { loopvar= gen_inline_var fname "iterator" - ; lower= - Expr.Fixed. + ; upper= { pattern= Lit (Int, "1") ; meta= - Expr.Typed.Meta. 
- { type_= UInt - ; adlevel= DataOnly - ; loc= Location_span.empty } } - ; upper= - { pattern= Lit (Int, "1") - ; meta= - { type_= UInt - ; adlevel= DataOnly - ; loc= Location_span.empty } } - ; body= - map_rec_stmt_loc (generate_inner_breaks num_returns) stmt } - ; meta= Location_span.empty } ] - else (map_rec_stmt_loc (generate_inner_breaks num_returns) stmt).pattern + { type_= UInt + ; adlevel= DataOnly + ; loc= Location_span.empty } } + ; body= + map_rec_stmt_loc (generate_inner_breaks num_returns) stmt + } + ; meta= Location_span.empty } ] + else (map_rec_stmt_loc (generate_inner_breaks num_returns) stmt).pattern -let inline_list f es = - let dse_list = List.map ~f es in - (* function arguments are evaluated from right to left in C++, so we need to reverse *) - let d_list = - List.concat (List.rev (List.map ~f:(function x, _, _ -> x) dse_list)) in - let s_list = - List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list)) in - let es = List.map ~f:(function _, _, x -> x) dse_list in - (d_list, s_list, es) + let inline_list f es = + let dse_list = List.map ~f es in + (* function arguments are evaluated from right to left in C++, so we need to reverse *) + let d_list = + List.concat (List.rev (List.map ~f:(function x, _, _ -> x) dse_list)) + in + let s_list = + List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list)) + in + let es = List.map ~f:(function _, _, x -> x) dse_list in + (d_list, s_list, es) -(* Triple is (declaration list, statement list, return expression) *) -let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e) - = - match pattern with - | Var _ -> ([], [], e) - | Lit (_, _) -> ([], [], e) - | Promotion (expr, ut, ad) -> - let d, sl, expr' = inline_function_expression propto adt fim expr in - (d, sl, {e with pattern= Promotion (expr', ut, ad)}) - | FunApp (kind, es) -> ( - let d_list, s_list, es = - inline_list (inline_function_expression propto adt fim) es in - match kind with - | 
CompilerInternal _ -> - (d_list, s_list, {e with pattern= FunApp (kind, es)}) - | UserDefined (fname, suffix) | StanLib (fname, suffix, _) -> ( - let suffix, fname' = - match suffix with - | FnLpdf propto' when propto' && propto -> - ( Fun_kind.FnLpdf true - , Utils.with_unnormalized_suffix fname - |> Option.value ~default:fname ) - | FnLpdf _ -> (Fun_kind.FnLpdf false, fname) - | _ -> (suffix, fname) in - match Map.find fim fname' with - | None -> - let fun_kind = - match kind with - | Fun_kind.UserDefined _ -> Fun_kind.UserDefined (fname, suffix) - | _ -> StanLib (fname, suffix, AoS) in - (d_list, s_list, {e with pattern= FunApp (fun_kind, es)}) - | Some (rt, args, body) -> - let inline_return_name = gen_inline_var fname "return" in - let handle = - handle_early_returns fname (Some inline_return_name) in - let d_list2, s_list2, (e : Expr.Typed.t) = - let decl_type = - Option.map ~f:Mir_utils.unsafe_unsized_to_sized_type rt - |> Option.value_exn in - ( [ Stmt.Fixed.Pattern.Decl - { decl_adtype= adt - ; decl_id= inline_return_name - ; decl_type - ; initialize= false } ] - (* We should minimize the code that's having its variables - replaced to avoid conflict with the (two) new dummy - variables introduced by inlining *) - , [ handle - (subst_args_stmt args es - (replace_fresh_local_vars fname body) ) ] - , { pattern= Var inline_return_name - ; meta= - Expr.Typed.Meta. 
- { type_= Type.to_unsized decl_type - ; adlevel= adt - ; loc= Location_span.empty } } ) in - let d_list = d_list @ d_list2 in - let s_list = s_list @ s_list2 in - (d_list, s_list, e) ) ) - | TernaryIf (e1, e2, e3) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - let dl3, sl3, e3 = inline_function_expression propto adt fim e3 in - ( dl1 @ dl2 @ dl3 - , sl1 - @ [ Stmt.Fixed.( + (* Triple is (declaration list, statement list, return expression) *) + let rec inline_function_expression propto adt fim + (Expr.Fixed.{pattern; _} as e) = + match pattern with + | Var _ -> ([], [], e) + | Lit (_, _) -> ([], [], e) + | Promotion (expr, ut, ad) -> + let d, sl, expr' = inline_function_expression propto adt fim expr in + (d, sl, {e with pattern= Promotion (expr', ut, ad)}) + | FunApp (kind, es) -> ( + let d_list, s_list, es = + inline_list (inline_function_expression propto adt fim) es in + match kind with + | CompilerInternal _ -> + (d_list, s_list, {e with pattern= FunApp (kind, es)}) + | UserDefined (fname, suffix) | StanLib (fname, suffix, _) -> ( + let suffix, fname' = + match suffix with + | FnLpdf propto' when propto' && propto -> + ( Fun_kind.FnLpdf true + , Utils.with_unnormalized_suffix fname + |> Option.value ~default:fname ) + | FnLpdf _ -> (Fun_kind.FnLpdf false, fname) + | _ -> (suffix, fname) in + match Map.find fim fname' with + | None -> + let fun_kind = + match kind with + | Fun_kind.UserDefined _ -> + Fun_kind.UserDefined (fname, suffix) + | _ -> StanLib (fname, suffix, AoS) in + (d_list, s_list, {e with pattern= FunApp (fun_kind, es)}) + | Some (rt, args, body) -> + let inline_return_name = gen_inline_var fname "return" in + let handle = + handle_early_returns fname (Some inline_return_name) in + let d_list2, s_list2, (e : Expr.Typed.t) = + let decl_type = + Option.map ~f:Mir_utils.unsafe_unsized_to_sized_type rt + |> Option.value_exn in + ( [ 
Stmt.Fixed.Pattern.Decl + { decl_adtype= adt + ; decl_id= inline_return_name + ; decl_type + ; initialize= false } ] + (* We should minimize the code that's having its variables + replaced to avoid conflict with the (two) new dummy + variables introduced by inlining *) + , [ handle + (subst_args_stmt args es + (replace_fresh_local_vars fname body) ) ] + , { pattern= Var inline_return_name + ; meta= + Expr.Typed.Meta. + { type_= Type.to_unsized decl_type + ; adlevel= adt + ; loc= Location_span.empty } } ) in + let d_list = d_list @ d_list2 in + let s_list = s_list @ s_list2 in + (d_list, s_list, e) ) ) + | TernaryIf (e1, e2, e3) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + let dl3, sl3, e3 = inline_function_expression propto adt fim e3 in + ( dl1 @ dl2 @ dl3 + , sl1 + @ [ Stmt.Fixed.( + Pattern.IfElse + ( e1 + , {pattern= block_no_loc sl2; meta= Location_span.empty} + , Some {pattern= block_no_loc sl3; meta= Location_span.empty} + )) ] + , {e with pattern= TernaryIf (e1, e2, e3)} ) + | Indexed (e', i_list) -> + let dl, sl, e' = inline_function_expression propto adt fim e' in + let d_list, s_list, i_list = + inline_list (inline_function_index propto adt fim) i_list in + (d_list @ dl, s_list @ sl, {e with pattern= Indexed (e', i_list)}) + | EAnd (e1, e2) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + let sl2 = + [ Stmt.Fixed.( Pattern.IfElse ( e1 - , {pattern= block_no_loc sl2; meta= Location_span.empty} - , Some {pattern= block_no_loc sl3; meta= Location_span.empty} )) - ] - , {e with pattern= TernaryIf (e1, e2, e3)} ) - | Indexed (e', i_list) -> - let dl, sl, e' = inline_function_expression propto adt fim e' in - let d_list, s_list, i_list = - inline_list (inline_function_index propto adt fim) i_list in - (d_list @ dl, s_list @ sl, {e with pattern= Indexed (e', i_list)}) - | 
EAnd (e1, e2) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - let sl2 = - [ Stmt.Fixed.( - Pattern.IfElse - ( e1 - , {pattern= Block (map_no_loc sl2); meta= Location_span.empty} - , None )) ] in - (dl1 @ dl2, sl1 @ sl2, {e with pattern= EAnd (e1, e2)}) - | EOr (e1, e2) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - let sl2 = - [ Stmt.Fixed.( - Pattern.IfElse - ( e1 - , {pattern= Skip; meta= Location_span.empty} - , Some {pattern= Block (map_no_loc sl2); meta= Location_span.empty} - )) ] in - (dl1 @ dl2, sl1 @ sl2, {e with pattern= EOr (e1, e2)}) + , {pattern= Block (map_no_loc sl2); meta= Location_span.empty} + , None )) ] in + (dl1 @ dl2, sl1 @ sl2, {e with pattern= EAnd (e1, e2)}) + | EOr (e1, e2) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + let sl2 = + [ Stmt.Fixed.( + Pattern.IfElse + ( e1 + , {pattern= Skip; meta= Location_span.empty} + , Some + {pattern= Block (map_no_loc sl2); meta= Location_span.empty} + )) ] in + (dl1 @ dl2, sl1 @ sl2, {e with pattern= EOr (e1, e2)}) -and inline_function_index propto adt fim i = - match i with - | All -> ([], [], All) - | Single e -> - let dl, sl, e = inline_function_expression propto adt fim e in - (dl, sl, Single e) - | Upfrom e -> - let dl, sl, e = inline_function_expression propto adt fim e in - (dl, sl, Upfrom e) - | Between (e1, e2) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - (dl1 @ dl2, sl1 @ sl2, Between (e1, e2)) - | MultiIndex e -> - let dl, sl, e = inline_function_expression propto adt fim e in - (dl, sl, MultiIndex e) + and inline_function_index propto adt fim i = + match i with + | All -> ([], [], All) + | Single e -> + let dl, 
sl, e = inline_function_expression propto adt fim e in + (dl, sl, Single e) + | Upfrom e -> + let dl, sl, e = inline_function_expression propto adt fim e in + (dl, sl, Upfrom e) + | Between (e1, e2) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + (dl1 @ dl2, sl1 @ sl2, Between (e1, e2)) + | MultiIndex e -> + let dl, sl, e = inline_function_expression propto adt fim e in + (dl, sl, MultiIndex e) -let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} = - Stmt.Fixed. - { pattern= - ( match pattern with - | Assignment ((assignee, ut, idx_lst), rhs) -> - let dl1, sl1, new_idx_lst = - inline_list (inline_function_index propto adt fim) idx_lst in - let dl2, sl2, new_rhs = - inline_function_expression propto adt fim rhs in - slist_concat_no_loc - (dl2 @ dl1 @ sl2 @ sl1) - (Assignment ((assignee, ut, new_idx_lst), new_rhs)) - | TargetPE e -> - let d, s, e = inline_function_expression propto adt fim e in - slist_concat_no_loc (d @ s) (TargetPE e) - | NRFunApp (kind, exprs) -> - let d_list, s_list, es = - inline_list (inline_function_expression propto adt fim) exprs - in - slist_concat_no_loc (d_list @ s_list) - ( match kind with - | CompilerInternal _ -> NRFunApp (kind, es) - | UserDefined (s, _) | StanLib (s, _, _) -> ( - match Map.find fim s with - | None -> NRFunApp (kind, es) - | Some (_, args, b) -> - let b = replace_fresh_local_vars s b in - let b = handle_early_returns s None b in - (subst_args_stmt args es - {pattern= b; meta= Location_span.empty} ) - .pattern ) ) - | Return e -> ( - match e with - | None -> Return None - | Some expr -> + let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} = + Stmt.Fixed. 
+ { pattern= + ( match pattern with + | Assignment ((assignee, ut, idx_lst), rhs) -> + let dl1, sl1, new_idx_lst = + inline_list (inline_function_index propto adt fim) idx_lst in + let dl2, sl2, new_rhs = + inline_function_expression propto adt fim rhs in + slist_concat_no_loc + (dl2 @ dl1 @ sl2 @ sl1) + (Assignment ((assignee, ut, new_idx_lst), new_rhs)) + | TargetPE e -> + let d, s, e = inline_function_expression propto adt fim e in + slist_concat_no_loc (d @ s) (TargetPE e) + | NRFunApp (kind, exprs) -> + let d_list, s_list, es = + inline_list (inline_function_expression propto adt fim) exprs + in + slist_concat_no_loc (d_list @ s_list) + ( match kind with + | CompilerInternal _ -> NRFunApp (kind, es) + | UserDefined (s, _) | StanLib (s, _, _) -> ( + match Map.find fim s with + | None -> NRFunApp (kind, es) + | Some (_, args, b) -> + let b = replace_fresh_local_vars s b in + let b = handle_early_returns s None b in + (subst_args_stmt args es + {pattern= b; meta= Location_span.empty} ) + .pattern ) ) + | Return e -> ( + match e with + | None -> Return None + | Some expr -> + let d, s, e = inline_function_expression propto adt fim expr in + slist_concat_no_loc (d @ s) (Return (Some e)) ) + | IfElse (expr, s1, s2) -> let d, s, e = inline_function_expression propto adt fim expr in - slist_concat_no_loc (d @ s) (Return (Some e)) ) - | IfElse (expr, s1, s2) -> - let d, s, e = inline_function_expression propto adt fim expr in - slist_concat_no_loc (d @ s) - (IfElse - ( e - , inline_function_statement propto adt fim s1 - , Option.map ~f:(inline_function_statement propto adt fim) s2 - ) ) - | While (expr, stmt) -> - let d', s', e = inline_function_expression propto adt fim expr in - slist_concat_no_loc (d' @ s') - (While - ( e - , match s' with - | [] -> inline_function_statement propto adt fim stmt - | _ -> - { pattern= - Block - ( [inline_function_statement propto adt fim stmt] - @ map_no_loc s' ) - ; meta= Location_span.empty } ) ) - | For {loopvar; lower; upper; 
body} -> - let d_lower, s_lower, lower = - inline_function_expression propto adt fim lower in - let d_upper, s_upper, upper = - inline_function_expression propto adt fim upper in - slist_concat_no_loc - (d_lower @ d_upper @ s_lower @ s_upper) - (For - { loopvar - ; lower - ; upper - ; body= - ( match s_upper with - | [] -> inline_function_statement propto adt fim body + slist_concat_no_loc (d @ s) + (IfElse + ( e + , inline_function_statement propto adt fim s1 + , Option.map ~f:(inline_function_statement propto adt fim) s2 + ) ) + | While (expr, stmt) -> + let d', s', e = inline_function_expression propto adt fim expr in + slist_concat_no_loc (d' @ s') + (While + ( e + , match s' with + | [] -> inline_function_statement propto adt fim stmt | _ -> { pattern= Block - ( [inline_function_statement propto adt fim body] - @ map_no_loc s_upper ) - ; meta= Location_span.empty } ) } ) - | Profile (name, l) -> - Profile - (name, List.map l ~f:(inline_function_statement propto adt fim)) - | Block l -> - Block (List.map l ~f:(inline_function_statement propto adt fim)) - | SList l -> - SList (List.map l ~f:(inline_function_statement propto adt fim)) - | Decl r -> Decl r - | Skip -> Skip - | Break -> Break - | Continue -> Continue ) - ; meta } + ( [inline_function_statement propto adt fim stmt] + @ map_no_loc s' ) + ; meta= Location_span.empty } ) ) + | For {loopvar; lower; upper; body} -> + let d_lower, s_lower, lower = + inline_function_expression propto adt fim lower in + let d_upper, s_upper, upper = + inline_function_expression propto adt fim upper in + slist_concat_no_loc + (d_lower @ d_upper @ s_lower @ s_upper) + (For + { loopvar + ; lower + ; upper + ; body= + ( match s_upper with + | [] -> inline_function_statement propto adt fim body + | _ -> + { pattern= + Block + ( [ inline_function_statement propto adt fim + body ] + @ map_no_loc s_upper ) + ; meta= Location_span.empty } ) } ) + | Profile (name, l) -> + Profile + (name, List.map l ~f:(inline_function_statement 
propto adt fim)) + | Block l -> + Block (List.map l ~f:(inline_function_statement propto adt fim)) + | SList l -> + SList (List.map l ~f:(inline_function_statement propto adt fim)) + | Decl r -> Decl r + | Skip -> Skip + | Break -> Break + | Continue -> Continue ) + ; meta } -let create_function_inline_map adt l = - let f accum Program.{fdname; fdargs; fdbody; fdrt; _} = - match fdbody with - | None -> accum - | Some fdbody -> ( - let create_data propto = - ( Option.map - ~f:(fun x -> Type.Unsized x) - (UnsizedType.returntype_to_type_opt fdrt) - , List.map ~f:(fun (_, name, _) -> name) fdargs - , inline_function_statement propto adt accum fdbody ) in - match Middle.Utils.with_unnormalized_suffix fdname with - | None -> ( - let data = create_data true in - match Map.add accum ~key:fdname ~data with - | `Ok m -> m - | `Duplicate -> accum ) - | Some fdname' -> - let data = create_data false in - let data' = create_data true in - let m = Map.Poly.of_alist_exn [(fdname, data); (fdname', data')] in - Map.merge_skewed accum m ~combine:(fun ~key:_ f _ -> f) ) in - List.fold l ~init:Map.Poly.empty ~f + let create_function_inline_map adt l = + let f accum Program.{fdname; fdargs; fdbody; fdrt; _} = + match fdbody with + | None -> accum + | Some fdbody -> ( + let create_data propto = + ( Option.map + ~f:(fun x -> Type.Unsized x) + (UnsizedType.returntype_to_type_opt fdrt) + , List.map ~f:(fun (_, name, _) -> name) fdargs + , inline_function_statement propto adt accum fdbody ) in + match Middle.Utils.with_unnormalized_suffix fdname with + | None -> ( + let data = create_data true in + match Map.add accum ~key:fdname ~data with + | `Ok m -> m + | `Duplicate -> accum ) + | Some fdname' -> + let data = create_data false in + let data' = create_data true in + let m = Map.Poly.of_alist_exn [(fdname, data); (fdname', data')] in + Map.merge_skewed accum m ~combine:(fun ~key:_ f _ -> f) ) in + List.fold l ~init:Map.Poly.empty ~f -let function_inlining (mir : Program.Typed.t) = - (* We 
add only the functions with a single definition to the inline map. - Overloaded functions cannot be inlined. *) - let can_inline = - List.fold mir.functions_block ~init:String.Map.empty - ~f:(fun accum Program.{fdname; _} -> - Map.update accum fdname - ~f:(Option.value_map ~default:true ~f:(fun _ -> false)) ) in - let inlineable_functions = - List.filter mir.functions_block ~f:(fun Program.{fdname; _} -> - Map.find_exn can_inline fdname ) in - let dataonly_inline_map = - create_function_inline_map UnsizedType.DataOnly inlineable_functions in - let autodiff_inline_map = - create_function_inline_map UnsizedType.AutoDiffable inlineable_functions - in - let dataonly_inline_function_statements = - List.map - ~f: - (inline_function_statement true UnsizedType.DataOnly dataonly_inline_map) - in - let autodiffable_inline_function_statements = - List.map - ~f: - (inline_function_statement true UnsizedType.AutoDiffable - autodiff_inline_map ) in - { mir with - transform_inits= autodiffable_inline_function_statements mir.transform_inits - ; unconstrain_array= - autodiffable_inline_function_statements mir.unconstrain_array - ; log_prob= autodiffable_inline_function_statements mir.log_prob - ; generate_quantities= - dataonly_inline_function_statements mir.generate_quantities } + let function_inlining (mir : Program.Typed.t) = + (* We add only the functions with a single definition to the inline map. + Overloaded functions cannot be inlined. 
*) + let can_inline = + List.fold mir.functions_block ~init:String.Map.empty + ~f:(fun accum Program.{fdname; _} -> + Map.update accum fdname + ~f:(Option.value_map ~default:true ~f:(fun _ -> false)) ) in + let inlineable_functions = + List.filter mir.functions_block ~f:(fun Program.{fdname; _} -> + Map.find_exn can_inline fdname ) in + let dataonly_inline_map = + create_function_inline_map UnsizedType.DataOnly inlineable_functions in + let autodiff_inline_map = + create_function_inline_map UnsizedType.AutoDiffable inlineable_functions + in + let dataonly_inline_function_statements = + List.map + ~f: + (inline_function_statement true UnsizedType.DataOnly + dataonly_inline_map ) in + let autodiffable_inline_function_statements = + List.map + ~f: + (inline_function_statement true UnsizedType.AutoDiffable + autodiff_inline_map ) in + { mir with + transform_inits= + autodiffable_inline_function_statements mir.transform_inits + ; unconstrain_array= + autodiffable_inline_function_statements mir.unconstrain_array + ; log_prob= autodiffable_inline_function_statements mir.log_prob + ; generate_quantities= + dataonly_inline_function_statements mir.generate_quantities } -let rec contains_top_break_or_continue Stmt.Fixed.{pattern; _} = - match pattern with - | Break | Continue -> true - | Assignment (_, _) - |TargetPE _ - |NRFunApp (_, _) - |Return _ | Decl _ - |While (_, _) - |For _ | Skip -> - false - | Profile (_, l) | Block l | SList l -> - List.exists l ~f:contains_top_break_or_continue - | IfElse (_, b1, b2) -> ( - contains_top_break_or_continue b1 - || - match b2 with None -> false | Some b -> contains_top_break_or_continue b ) + let rec contains_top_break_or_continue Stmt.Fixed.{pattern; _} = + match pattern with + | Break | Continue -> true + | Assignment (_, _) + |TargetPE _ + |NRFunApp (_, _) + |Return _ | Decl _ + |While (_, _) + |For _ | Skip -> + false + | Profile (_, l) | Block l | SList l -> + List.exists l ~f:contains_top_break_or_continue + | IfElse (_, b1, 
b2) -> ( + contains_top_break_or_continue b1 + || + match b2 with + | None -> false + | Some b -> contains_top_break_or_continue b ) -let unroll_static_limit = 32 + let unroll_static_limit = 32 -let unroll_static_loops_statement _ = - let f stmt = - match stmt with - | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> ( - let lower = Partial_evaluator.try_eval_expr lower in - let upper = Partial_evaluator.try_eval_expr upper in - match - (contains_top_break_or_continue body, lower.pattern, upper.pattern) - with - | false, Lit (Int, low_str), Lit (Int, up_str) -> - let low = Int.of_string low_str in - let up = Int.of_string up_str in - if up - low > unroll_static_limit then stmt - else - let range = - List.map - ~f:(fun i -> - Expr.Fixed. - { pattern= Lit (Int, Int.to_string i) - ; meta= - Expr.Typed.Meta. - { type_= UInt - ; loc= Location_span.empty - ; adlevel= DataOnly } } ) - (List.range ~start:`inclusive ~stop:`inclusive low up) in - let stmts = - List.map - ~f:(fun i -> - subst_args_stmt [loopvar] [i] - {pattern= body.pattern; meta= Location_span.empty} ) - range in - Stmt.Fixed.Pattern.SList stmts - | _ -> stmt ) - | _ -> stmt in - top_down_map_rec_stmt_loc f + let unroll_static_loops_statement _ = + let f stmt = + match stmt with + | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> ( + let lower = Partial_evaluator.try_eval_expr lower in + let upper = Partial_evaluator.try_eval_expr upper in + match + (contains_top_break_or_continue body, lower.pattern, upper.pattern) + with + | false, Lit (Int, low_str), Lit (Int, up_str) -> + let low = Int.of_string low_str in + let up = Int.of_string up_str in + if up - low > unroll_static_limit then stmt + else + let range = + List.map + ~f:(fun i -> + Expr.Fixed. + { pattern= Lit (Int, Int.to_string i) + ; meta= + Expr.Typed.Meta. 
+ { type_= UInt + ; loc= Location_span.empty + ; adlevel= DataOnly } } ) + (List.range ~start:`inclusive ~stop:`inclusive low up) in + let stmts = + List.map + ~f:(fun i -> + subst_args_stmt [loopvar] [i] + {pattern= body.pattern; meta= Location_span.empty} ) + range in + Stmt.Fixed.Pattern.SList stmts + | _ -> stmt ) + | _ -> stmt in + top_down_map_rec_stmt_loc f -let static_loop_unrolling mir = - transform_program_blockwise mir unroll_static_loops_statement + let static_loop_unrolling mir = + transform_program_blockwise mir unroll_static_loops_statement -let unroll_loop_one_step_statement _ = - let f stmt = - match stmt with - | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> - if contains_top_break_or_continue body then stmt - else - IfElse - ( Expr.Fixed. - { lower with - pattern= - FunApp (StanLib ("Geq__", FnPlain, AoS), [upper; lower]) } - , { pattern= - (let body_unrolled = - subst_args_stmt [loopvar] [lower] - {pattern= body.pattern; meta= Location_span.empty} in - let (body' : Stmt.Located.t) = - { pattern= - Stmt.Fixed.Pattern.For - { loopvar - ; upper - ; body - ; lower= - { lower with - pattern= - FunApp - ( StanLib ("Plus__", FnPlain, AoS) - , [lower; Expr.Helpers.loop_bottom] ) } } - ; meta= Location_span.empty } in - match body_unrolled.pattern with - | Block stmts -> Block (stmts @ [body']) - | _ -> Stmt.Fixed.Pattern.Block [body_unrolled; body'] ) - ; meta= Location_span.empty } - , None ) - | While (e, body) -> - if contains_top_break_or_continue body then stmt - else - IfElse - ( e - , { pattern= Block [body; {body with pattern= While (e, body)}] - ; meta= Location_span.empty } - , None ) - | _ -> stmt in - map_rec_stmt_loc f + let unroll_loop_one_step_statement _ = + let f stmt = + match stmt with + | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> + if contains_top_break_or_continue body then stmt + else + IfElse + ( Expr.Fixed. 
+ { lower with + pattern= + FunApp (StanLib ("Geq__", FnPlain, AoS), [upper; lower]) + } + , { pattern= + (let body_unrolled = + subst_args_stmt [loopvar] [lower] + {pattern= body.pattern; meta= Location_span.empty} + in + let (body' : Stmt.Located.t) = + { pattern= + Stmt.Fixed.Pattern.For + { loopvar + ; upper + ; body + ; lower= + { lower with + pattern= + FunApp + ( StanLib ("Plus__", FnPlain, AoS) + , [lower; Expr.Helpers.loop_bottom] ) } + } + ; meta= Location_span.empty } in + match body_unrolled.pattern with + | Block stmts -> Block (stmts @ [body']) + | _ -> Stmt.Fixed.Pattern.Block [body_unrolled; body'] ) + ; meta= Location_span.empty } + , None ) + | While (e, body) -> + if contains_top_break_or_continue body then stmt + else + IfElse + ( e + , { pattern= Block [body; {body with pattern= While (e, body)}] + ; meta= Location_span.empty } + , None ) + | _ -> stmt in + map_rec_stmt_loc f -let one_step_loop_unrolling mir = - transform_program_blockwise mir unroll_loop_one_step_statement + let one_step_loop_unrolling mir = + transform_program_blockwise mir unroll_loop_one_step_statement -let collapse_lists_statement _ = - let rec collapse_lists l = - match l with - | [] -> [] - | Stmt.Fixed.{pattern= SList l'; _} :: rest -> l' @ collapse_lists rest - | x :: rest -> x :: collapse_lists rest in - let f = function - | Stmt.Fixed.Pattern.Block l -> Stmt.Fixed.Pattern.Block (collapse_lists l) - | SList l -> SList (collapse_lists l) - | x -> x in - map_rec_stmt_loc f + let collapse_lists_statement _ = + let rec collapse_lists l = + match l with + | [] -> [] + | Stmt.Fixed.{pattern= SList l'; _} :: rest -> l' @ collapse_lists rest + | x :: rest -> x :: collapse_lists rest in + let f = function + | Stmt.Fixed.Pattern.Block l -> + Stmt.Fixed.Pattern.Block (collapse_lists l) + | SList l -> SList (collapse_lists l) + | x -> x in + map_rec_stmt_loc f -let list_collapsing (mir : Program.Typed.t) = - transform_program_blockwise mir collapse_lists_statement + let 
list_collapsing (mir : Program.Typed.t) = + transform_program_blockwise mir collapse_lists_statement -let propagation - (propagation_transfer : - (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> (module Monotone_framework_sigs.TRANSFER_FUNCTION - with type labels = int - and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t - option ) ) (mir : Program.Typed.t) = - let transform stmt = - let flowgraph, flowgraph_to_mir = - Monotone_framework.forward_flowgraph_of_stmt stmt in - let (module Flowgraph) = flowgraph in - let values = - Monotone_framework.propagation_mfp mir - (module Flowgraph) - flowgraph_to_mir propagation_transfer in - let propagate_stmt = - map_rec_stmt_loc_num flowgraph_to_mir (fun i -> - subst_stmt_base - (Option.value ~default:Map.Poly.empty (Map.find_exn values i).entry) ) - in - propagate_stmt (Map.find_exn flowgraph_to_mir 1) in - transform_program mir transform + let propagation + (propagation_transfer : + (int, Stmt.Located.Non_recursive.t) Map.Poly.t + -> (module Monotone_framework_intf.TRANSFER_FUNCTION + with type labels = int + and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t + option ) ) (mir : Program.Typed.t) = + let transform stmt = + let flowgraph, flowgraph_to_mir = + Monotone_framework.forward_flowgraph_of_stmt stmt in + let (module Flowgraph) = flowgraph in + let values = + Monotone_framework.propagation_mfp mir + (module Flowgraph) + flowgraph_to_mir propagation_transfer in + let propagate_stmt = + map_rec_stmt_loc_num flowgraph_to_mir (fun i -> + subst_stmt_base + (Option.value ~default:Map.Poly.empty + (Map.find_exn values i).entry ) ) in + propagate_stmt (Map.find_exn flowgraph_to_mir 1) in + transform_program mir transform -let constant_propagation ?(preserve_stability = false) = - propagation - (Monotone_framework.constant_propagation_transfer ~preserve_stability) + let constant_propagation ?(preserve_stability = false) = + propagation + (Monotone_framework.constant_propagation_transfer + (module 
Partial_evaluator) + ~preserve_stability ) -let rec expr_any pred (e : Expr.Typed.t) = - match e.pattern with - | Indexed (e, is) -> expr_any pred e || List.exists ~f:(idx_any pred) is - | _ -> pred e || Expr.Fixed.Pattern.fold (accum_any pred) false e.pattern + let rec expr_any pred (e : Expr.Typed.t) = + match e.pattern with + | Indexed (e, is) -> expr_any pred e || List.exists ~f:(idx_any pred) is + | _ -> pred e || Expr.Fixed.Pattern.fold (accum_any pred) false e.pattern -and idx_any pred (i : Expr.Typed.t Index.t) = - Index.fold (accum_any pred) false i + and idx_any pred (i : Expr.Typed.t Index.t) = + Index.fold (accum_any pred) false i -and accum_any pred b e = b || expr_any pred e + and accum_any pred b e = b || expr_any pred e -let can_side_effect_top_expr (e : Expr.Typed.t) = - match e.pattern with - | FunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget, _)), _) -> true - | FunApp (CompilerInternal internal_fn, _) -> - Internal_fun.can_side_effect internal_fn - | _ -> false + let can_side_effect_top_expr (e : Expr.Typed.t) = + match e.pattern with + | FunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget, _)), _) -> true + | FunApp (CompilerInternal internal_fn, _) -> + Internal_fun.can_side_effect internal_fn + | _ -> false -let cannot_duplicate_expr ?(preserve_stability = false) (e : Expr.Typed.t) = - let pred e = - can_side_effect_top_expr e - || ( match e.pattern with - | FunApp ((UserDefined (_, FnRng) | StanLib (_, FnRng, _)), _) -> true - | _ -> false ) - || (preserve_stability && UnsizedType.is_autodiffable e.meta.type_) in - expr_any pred e + let cannot_duplicate_expr ?(preserve_stability = false) (e : Expr.Typed.t) = + let pred e = + can_side_effect_top_expr e + || ( match e.pattern with + | FunApp ((UserDefined (_, FnRng) | StanLib (_, FnRng, _)), _) -> true + | _ -> false ) + || (preserve_stability && UnsizedType.is_autodiffable e.meta.type_) in + expr_any pred e -let cannot_remove_expr (e : Expr.Typed.t) = expr_any 
can_side_effect_top_expr e + let cannot_remove_expr (e : Expr.Typed.t) = + expr_any can_side_effect_top_expr e -let expression_propagation ?(preserve_stability = false) mir = - propagation - (Monotone_framework.expression_propagation_transfer ~preserve_stability - (cannot_duplicate_expr ~preserve_stability) ) - mir + let expression_propagation ?(preserve_stability = false) mir = + propagation + (Monotone_framework.expression_propagation_transfer ~preserve_stability + (cannot_duplicate_expr ~preserve_stability) ) + mir -let copy_propagation mir = - let globals = Monotone_framework.globals mir in - propagation (Monotone_framework.copy_propagation_transfer globals) mir + let copy_propagation mir = + let globals = Monotone_framework.globals mir in + propagation (Monotone_framework.copy_propagation_transfer globals) mir -let is_skip_break_continue s = - match s with Stmt.Fixed.Pattern.Skip | Break | Continue -> true | _ -> false + let is_skip_break_continue s = + match s with + | Stmt.Fixed.Pattern.Skip | Break | Continue -> true + | _ -> false -(* TODO: could also implement partial dead code elimination *) -let dead_code_elimination (mir : Program.Typed.t) = - (* TODO: think about whether we should treat function bodies as local scopes in the statement - from the POV of a live variables analysis. - (Obviously, this shouldn't be the case for the purposes of reaching definitions, - constant propagation, expressions analyses. But I do think that's the right way to - go about live variables. 
*) - let transform s = - let rev_flowgraph, flowgraph_to_mir = - Monotone_framework.inverse_flowgraph_of_stmt s in - let (module Rev_Flowgraph) = rev_flowgraph in - let live_variables = - Monotone_framework.live_variables_mfp mir - (module Rev_Flowgraph) - flowgraph_to_mir in - let dead_code_elim_stmt_base i stmt = - (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) - let live_variables_s = - (Map.find_exn live_variables i).Monotone_framework_sigs.entry in - match stmt with - | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> - if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then stmt - else Skip - | Assignment ((x, _, is), rhs) -> - if - Set.Poly.mem live_variables_s x - || cannot_remove_expr rhs - || List.exists ~f:(idx_any cannot_remove_expr) is - then stmt - else Skip - (* NOTE: we never get rid of declarations as we might not be able to - remove an assignment to a variable - due to side effects. *) - (* TODO: maybe we should revisit that. *) - | Decl _ | TargetPE _ - |NRFunApp (_, _) - |Break | Continue | Return _ | Skip -> - stmt - | IfElse (e, b1, b2) -> ( - if - (* TODO: check if e has side effects, like print, reject, then don't optimize? 
*) - (not (cannot_remove_expr e)) - && b1.Stmt.Fixed.pattern = Skip - && ( Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 - = Some Skip - || Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 - = None ) - then Skip - else - match e.pattern with - | Lit (Int, "0") | Lit (Real, "0.0") -> ( - match b2 with Some x -> x.pattern | None -> Skip ) - | Lit (_, _) -> b1.pattern - | _ -> IfElse (e, b1, b2) ) - | While (e, b) -> ( - if (not (cannot_remove_expr e)) && b.pattern = Break then Skip - else - match e.pattern with - | Lit (Int, "0") | Lit (Real, "0.0") -> Skip - | _ -> While (e, b) ) - | For {loopvar; lower; upper; body} -> - if - (not (cannot_remove_expr lower)) - && (not (cannot_remove_expr upper)) - && is_skip_break_continue body.pattern - then Skip - else For {loopvar; lower; upper; body} - | Profile (name, l) -> - let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in - if List.length l' = 0 then Skip else Profile (name, l') - | Block l -> - let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in - if List.length l' = 0 then Skip else Block l' - | SList l -> - let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in - SList l' in - let dead_code_elim_stmt = - map_rec_stmt_loc_num flowgraph_to_mir dead_code_elim_stmt_base in - dead_code_elim_stmt (Map.find_exn flowgraph_to_mir 1) in - transform_program mir transform + (* TODO: could also implement partial dead code elimination *) + let dead_code_elimination (mir : Program.Typed.t) = + (* TODO: think about whether we should treat function bodies as local scopes in the statement + from the POV of a live variables analysis. + (Obviously, this shouldn't be the case for the purposes of reaching definitions, + constant propagation, expressions analyses. But I do think that's the right way to + go about live variables. 
*) + let transform s = + let rev_flowgraph, flowgraph_to_mir = + Monotone_framework.inverse_flowgraph_of_stmt s in + let (module Rev_Flowgraph) = rev_flowgraph in + let live_variables = + Monotone_framework.live_variables_mfp mir + (module Rev_Flowgraph) + flowgraph_to_mir in + let dead_code_elim_stmt_base i stmt = + (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) + let live_variables_s = + (Map.find_exn live_variables i).Monotone_framework_intf.entry in + match stmt with + | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> + if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then + stmt + else Skip + | Assignment ((x, _, is), rhs) -> + if + Set.Poly.mem live_variables_s x + || cannot_remove_expr rhs + || List.exists ~f:(idx_any cannot_remove_expr) is + then stmt + else Skip + (* NOTE: we never get rid of declarations as we might not be able to + remove an assignment to a variable + due to side effects. *) + (* TODO: maybe we should revisit that. *) + | Decl _ | TargetPE _ + |NRFunApp (_, _) + |Break | Continue | Return _ | Skip -> + stmt + | IfElse (e, b1, b2) -> ( + if + (* TODO: check if e has side effects, like print, reject, then don't optimize? 
*) + (not (cannot_remove_expr e)) + && b1.Stmt.Fixed.pattern = Skip + && ( Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 + = Some Skip + || Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 + = None ) + then Skip + else + match e.pattern with + | Lit (Int, "0") | Lit (Real, "0.0") -> ( + match b2 with Some x -> x.pattern | None -> Skip ) + | Lit (_, _) -> b1.pattern + | _ -> IfElse (e, b1, b2) ) + | While (e, b) -> ( + if (not (cannot_remove_expr e)) && b.pattern = Break then Skip + else + match e.pattern with + | Lit (Int, "0") | Lit (Real, "0.0") -> Skip + | _ -> While (e, b) ) + | For {loopvar; lower; upper; body} -> + if + (not (cannot_remove_expr lower)) + && (not (cannot_remove_expr upper)) + && is_skip_break_continue body.pattern + then Skip + else For {loopvar; lower; upper; body} + | Profile (name, l) -> + let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in + if List.length l' = 0 then Skip else Profile (name, l') + | Block l -> + let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in + if List.length l' = 0 then Skip else Block l' + | SList l -> + let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in + SList l' in + let dead_code_elim_stmt = + map_rec_stmt_loc_num flowgraph_to_mir dead_code_elim_stmt_base in + dead_code_elim_stmt (Map.find_exn flowgraph_to_mir 1) in + transform_program mir transform -let partial_evaluation = Partial_evaluator.eval_prog + let partial_evaluation = Partial_evaluator.eval_prog -(** + (** * Given a name and Stmt, search the statement for the first assignment * where that name is the assignee. 
*) -let rec find_assignment_idx (name : string) Stmt.Fixed.{pattern; _} = - let is_index = function Expr.Fixed.Pattern.Indexed _ -> true | _ -> false in - match pattern with - | Stmt.Fixed.Pattern.Assignment - ((assign_name, lhs_ut, idx_lst), (rhs : 'a Expr.Fixed.t)) - when name = assign_name - && (not (Set.Poly.mem (expr_var_names_set rhs) assign_name)) - && not - ( rhs.meta.adlevel = UnsizedType.DataOnly - && UnsizedType.is_array lhs_ut ) -> - Some (idx_lst, is_index rhs.pattern) - | _ -> None + let rec find_assignment_idx (name : string) Stmt.Fixed.{pattern; _} = + let is_index = function + | Expr.Fixed.Pattern.Indexed _ -> true + | _ -> false in + match pattern with + | Stmt.Fixed.Pattern.Assignment + ((assign_name, lhs_ut, idx_lst), (rhs : 'a Expr.Fixed.t)) + when name = assign_name + && (not (Set.Poly.mem (expr_var_names_set rhs) assign_name)) + && not + ( rhs.meta.adlevel = UnsizedType.DataOnly + && UnsizedType.is_array lhs_ut ) -> + Some (idx_lst, is_index rhs.pattern) + | _ -> None -(** + (** * Given a list of Stmts, find Decls whose objects are fully assigned to * in their first assignment and mark them as not needing to be * initialized. 
*) -and unenforce_initialize (lst : Stmt.Located.t list) = - let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst = - match pattern with - | Stmt.Fixed.Pattern.Decl ({decl_id; decl_type; _} as decl_pat) -> ( - let is_soa = - match decl_type with - | Type.Sized s -> SizedType.get_mem_pattern s = Mem_pattern.SoA - | _ -> false in - match List.hd sub_lst with - | Some next_stmt -> ( - match find_assignment_idx decl_id next_stmt with - | Some - (([] | [Index.All] | [Index.All; Index.All]), is_assigned_to_index) - when not (is_soa && is_assigned_to_index) -> - { stmt with - pattern= - Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} } - | None | Some _ -> stmt ) - | None -> stmt ) - | Block block_lst -> - {stmt with pattern= Block (unenforce_initialize block_lst)} - | SList s_lst -> {stmt with pattern= SList (unenforce_initialize s_lst)} - (*[] here because we do not want to check out of scope*) - | While (expr, stmt) -> - {stmt with pattern= While (expr, unenforce_initialize_patt stmt [])} - | For ({body; _} as pat) -> - { stmt with - pattern= For {pat with body= unenforce_initialize_patt body []} } - | Profile ((pname : string), stmts) -> - {stmt with pattern= Profile (pname, unenforce_initialize stmts)} - | IfElse ((expr : 'a Expr.Fixed.t), true_stmt, op_false_stmt) -> - let mod_false_stmt = - Option.map ~f:(fun x -> unenforce_initialize_patt x []) op_false_stmt - in - { stmt with - pattern= - IfElse (expr, unenforce_initialize_patt true_stmt [], mod_false_stmt) - } - | _ -> stmt in - match List.hd lst with - | Some stmt -> ( - match List.tl lst with - | Some sub_lst -> - List.cons - (unenforce_initialize_patt stmt sub_lst) - (unenforce_initialize sub_lst) - | None -> lst ) - | None -> lst + and unenforce_initialize (lst : Stmt.Located.t list) = + let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst + = + match pattern with + | Stmt.Fixed.Pattern.Decl ({decl_id; decl_type; _} as decl_pat) -> ( + let is_soa = + 
match decl_type with + | Type.Sized s -> SizedType.get_mem_pattern s = Mem_pattern.SoA + | _ -> false in + match List.hd sub_lst with + | Some next_stmt -> ( + match find_assignment_idx decl_id next_stmt with + | Some + ( ([] | [Index.All] | [Index.All; Index.All]) + , is_assigned_to_index ) + when not (is_soa && is_assigned_to_index) -> + { stmt with + pattern= + Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} } + | None | Some _ -> stmt ) + | None -> stmt ) + | Block block_lst -> + {stmt with pattern= Block (unenforce_initialize block_lst)} + | SList s_lst -> {stmt with pattern= SList (unenforce_initialize s_lst)} + (*[] here because we do not want to check out of scope*) + | While (expr, stmt) -> + {stmt with pattern= While (expr, unenforce_initialize_patt stmt [])} + | For ({body; _} as pat) -> + { stmt with + pattern= For {pat with body= unenforce_initialize_patt body []} } + | Profile ((pname : string), stmts) -> + {stmt with pattern= Profile (pname, unenforce_initialize stmts)} + | IfElse ((expr : 'a Expr.Fixed.t), true_stmt, op_false_stmt) -> + let mod_false_stmt = + Option.map + ~f:(fun x -> unenforce_initialize_patt x []) + op_false_stmt in + { stmt with + pattern= + IfElse + (expr, unenforce_initialize_patt true_stmt [], mod_false_stmt) + } + | _ -> stmt in + match List.hd lst with + | Some stmt -> ( + match List.tl lst with + | Some sub_lst -> + List.cons + (unenforce_initialize_patt stmt sub_lst) + (unenforce_initialize sub_lst) + | None -> lst ) + | None -> lst -(** + (** * Take the Mir and perform a transform that requires searching * across the list inside of each piece of the Mir. * @param mir The mir * @param transformer a function that takes in and returns a list of * Stmts. 
*) -let transform_mir_blocks (mir : Program.Typed.t) - (transformer : Stmt.Located.t list -> Stmt.Located.t list) : Program.Typed.t - = - let transformed_functions = - List.map mir.functions_block ~f:(fun fs -> - let new_body = - match fs.fdbody with - | Some (Stmt.Fixed.{pattern= SList lst; _} as stmt) -> - Some {stmt with pattern= SList (transformer lst)} - | Some (Stmt.Fixed.{pattern= Block lst; _} as stmt) -> - Some {stmt with pattern= Block (transformer lst)} - | alt -> alt in - {fs with fdbody= new_body} ) in - { Program.functions_block= transformed_functions - ; input_vars= mir.input_vars - ; prepare_data= transformer mir.prepare_data - ; log_prob= transformer mir.log_prob - ; generate_quantities= transformer mir.generate_quantities - ; transform_inits= transformer mir.transform_inits - ; unconstrain_array= transformer mir.unconstrain_array - ; output_vars= mir.output_vars - ; prog_name= mir.prog_name - ; prog_path= mir.prog_path } + let transform_mir_blocks (mir : Program.Typed.t) + (transformer : Stmt.Located.t list -> Stmt.Located.t list) : + Program.Typed.t = + let transformed_functions = + List.map mir.functions_block ~f:(fun fs -> + let new_body = + match fs.fdbody with + | Some (Stmt.Fixed.{pattern= SList lst; _} as stmt) -> + Some {stmt with pattern= SList (transformer lst)} + | Some (Stmt.Fixed.{pattern= Block lst; _} as stmt) -> + Some {stmt with pattern= Block (transformer lst)} + | alt -> alt in + {fs with fdbody= new_body} ) in + { Program.functions_block= transformed_functions + ; input_vars= mir.input_vars + ; prepare_data= transformer mir.prepare_data + ; log_prob= transformer mir.log_prob + ; generate_quantities= transformer mir.generate_quantities + ; transform_inits= transformer mir.transform_inits + ; unconstrain_array= transformer mir.unconstrain_array + ; output_vars= mir.output_vars + ; prog_name= mir.prog_name + ; prog_path= mir.prog_path } -let allow_uninitialized_decls mir = - transform_mir_blocks mir unenforce_initialize + let 
allow_uninitialized_decls mir = + transform_mir_blocks mir unenforce_initialize -let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) = - (* TODO: clean up this code. It is not very pretty. *) - (* TODO: make lazy code motion operate on transformed parameters and models blocks - simultaneously *) - let preprocess_flowgraph = - let preprocess_flowgraph_base - (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) = - match stmt with - | IfElse (e, b1, Some b2) -> - Stmt.Fixed.( - Pattern.IfElse + let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) = + (* TODO: clean up this code. It is not very pretty. *) + (* TODO: make lazy code motion operate on transformed parameters and models blocks + simultaneously *) + let preprocess_flowgraph = + let preprocess_flowgraph_base + (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) = + match stmt with + | IfElse (e, b1, Some b2) -> + Stmt.Fixed.( + Pattern.IfElse + ( e + , { pattern= + Block [b1; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } + , Some + { pattern= + Block [b2; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } )) + | IfElse (e, b, None) -> + IfElse ( e - , { pattern= Block [b1; {pattern= Skip; meta= Location_span.empty}] + , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] ; meta= Location_span.empty } - , Some + , Some {pattern= Skip; meta= Location_span.empty} ) + | While (e, b) -> + While + ( e + , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } ) + | For {loopvar; lower; upper; body= b} -> + For + { loopvar + ; lower + ; upper + ; body= { pattern= - Block [b2; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } )) - | IfElse (e, b, None) -> - IfElse - ( e - , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } - , Some {pattern= Skip; meta= Location_span.empty} ) 
- | While (e, b) -> - While - ( e - , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } ) - | For {loopvar; lower; upper; body= b} -> - For - { loopvar - ; lower - ; upper - ; body= - { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } } - | _ -> stmt in - map_rec_stmt_loc preprocess_flowgraph_base in - let transform s = - let rev_flowgraph, flowgraph_to_mir = - Monotone_framework.inverse_flowgraph_of_stmt ~blocks_after_body:false s - in - let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in - let latest_expr, used_not_latest_expressions_mfp = - Monotone_framework.lazy_expressions_mfp fwd_flowgraph rev_flowgraph - flowgraph_to_mir in - let expression_map = - let rec collect_expressions accum (e : Expr.Typed.t) = - match e.pattern with - | Lit (_, _) -> accum - | Var _ -> accum - | _ when cannot_duplicate_expr ~preserve_stability e -> - (* Immovable expressions might have movable subexpressions *) - Expr.Fixed.Pattern.fold collect_expressions accum e.pattern - | _ -> Map.set accum ~key:e ~data:(Gensym.generate ~prefix:"lcm_" ()) + Block [b; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } } + | _ -> stmt in + map_rec_stmt_loc preprocess_flowgraph_base in + let transform s = + let rev_flowgraph, flowgraph_to_mir = + Monotone_framework.inverse_flowgraph_of_stmt ~blocks_after_body:false s in - Set.fold - (Monotone_framework.used_expressions_stmt s.pattern) - ~init:Expr.Typed.Map.empty ~f:collect_expressions in - (* TODO: it'd be more efficient to just not accumulate constants in the static analysis *) - let declarations_list = - Map.fold expression_map ~init:[] ~f:(fun ~key ~data accum -> - Stmt.Fixed. 
- { pattern= - Pattern.Decl - { decl_adtype= Expr.Typed.adlevel_of key - ; decl_id= data - ; decl_type= Type.Unsized (Expr.Typed.type_of key) - ; initialize= true } - ; meta= Location_span.empty } - :: accum ) in - let lazy_code_motion_base i stmt = - let latest_and_used_after_i = - Set.inter - (Map.find_exn latest_expr i) - (Map.find_exn used_not_latest_expressions_mfp i).entry in - let to_assign_in_s = - latest_and_used_after_i - |> Set.filter ~f:(fun x -> Map.mem expression_map x) - |> Set.to_list - |> List.sort ~compare:(fun e e' -> - compare_int (expr_depth e) (expr_depth e') ) in - (* TODO: is this sort doing anything or are they already stored in the right order by - chance? It appears to not do anything. *) - let assignments_to_add_to_s = - List.map - ~f:(fun e -> + let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in + let latest_expr, used_not_latest_expressions_mfp = + Monotone_framework.lazy_expressions_mfp fwd_flowgraph rev_flowgraph + flowgraph_to_mir in + let expression_map = + let rec collect_expressions accum (e : Expr.Typed.t) = + match e.pattern with + | Lit (_, _) -> accum + | Var _ -> accum + | _ when cannot_duplicate_expr ~preserve_stability e -> + (* Immovable expressions might have movable subexpressions *) + Expr.Fixed.Pattern.fold collect_expressions accum e.pattern + | _ -> Map.set accum ~key:e ~data:(Gensym.generate ~prefix:"lcm_" ()) + in + Set.fold + (Monotone_framework.used_expressions_stmt s.pattern) + ~init:Expr.Typed.Map.empty ~f:collect_expressions in + (* TODO: it'd be more efficient to just not accumulate constants in the static analysis *) + let declarations_list = + Map.fold expression_map ~init:[] ~f:(fun ~key ~data accum -> Stmt.Fixed. 
{ pattern= - Assignment - ((Map.find_exn expression_map e, e.meta.type_, []), e) - ; meta= Location_span.empty } ) - to_assign_in_s in - let expr_subst_stmt_except_initial_assign m = - let f stmt = - match stmt with - | Stmt.Fixed.Pattern.Assignment ((x, _, []), e') - when Map.mem m e' - && Expr.Typed.equal {e' with pattern= Var x} - (Map.find_exn m e') -> - expr_subst_stmt_base (Map.remove m e') stmt - | _ -> expr_subst_stmt_base m stmt in - map_rec_stmt_loc f in - let expr_map = - Map.filter_keys - ~f:(fun key -> - Set.mem latest_and_used_after_i key - || Set.mem (Map.find_exn used_not_latest_expressions_mfp i).exit key - ) - (Map.mapi expression_map ~f:(fun ~key ~data -> - {key with pattern= Var data} ) ) in - let f = expr_subst_stmt_except_initial_assign expr_map in - if List.length assignments_to_add_to_s = 0 then - (f Stmt.Fixed.{pattern= stmt; meta= Location_span.empty}).pattern - else - SList - (List.map ~f - ( assignments_to_add_to_s - @ [{pattern= stmt; meta= Location_span.empty}] ) ) in - let lazy_code_motion_stmt = - map_rec_stmt_loc_num flowgraph_to_mir lazy_code_motion_base in - Stmt.Fixed. - { pattern= + Pattern.Decl + { decl_adtype= Expr.Typed.adlevel_of key + ; decl_id= data + ; decl_type= Type.Unsized (Expr.Typed.type_of key) + ; initialize= true } + ; meta= Location_span.empty } + :: accum ) in + let lazy_code_motion_base i stmt = + let latest_and_used_after_i = + Set.inter + (Map.find_exn latest_expr i) + (Map.find_exn used_not_latest_expressions_mfp i).entry in + let to_assign_in_s = + latest_and_used_after_i + |> Set.filter ~f:(fun x -> Map.mem expression_map x) + |> Set.to_list + |> List.sort ~compare:(fun e e' -> + compare_int (expr_depth e) (expr_depth e') ) in + (* TODO: is this sort doing anything or are they already stored in the right order by + chance? It appears to not do anything. *) + let assignments_to_add_to_s = + List.map + ~f:(fun e -> + Stmt.Fixed. 
+ { pattern= + Assignment + ((Map.find_exn expression_map e, e.meta.type_, []), e) + ; meta= Location_span.empty } ) + to_assign_in_s in + let expr_subst_stmt_except_initial_assign m = + let f stmt = + match stmt with + | Stmt.Fixed.Pattern.Assignment ((x, _, []), e') + when Map.mem m e' + && Expr.Typed.equal {e' with pattern= Var x} + (Map.find_exn m e') -> + expr_subst_stmt_base (Map.remove m e') stmt + | _ -> expr_subst_stmt_base m stmt in + map_rec_stmt_loc f in + let expr_map = + Map.filter_keys + ~f:(fun key -> + Set.mem latest_and_used_after_i key + || Set.mem (Map.find_exn used_not_latest_expressions_mfp i).exit + key ) + (Map.mapi expression_map ~f:(fun ~key ~data -> + {key with pattern= Var data} ) ) in + let f = expr_subst_stmt_except_initial_assign expr_map in + if List.length assignments_to_add_to_s = 0 then + (f Stmt.Fixed.{pattern= stmt; meta= Location_span.empty}).pattern + else SList - ( declarations_list - @ [lazy_code_motion_stmt (Map.find_exn flowgraph_to_mir 1)] ) - ; meta= Location_span.empty } in - let cleanup = - let cleanup_base (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) - : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t = - match stmt with - | Stmt.Fixed.( - Pattern.IfElse + (List.map ~f + ( assignments_to_add_to_s + @ [{pattern= stmt; meta= Location_span.empty}] ) ) in + let lazy_code_motion_stmt = + map_rec_stmt_loc_num flowgraph_to_mir lazy_code_motion_base in + Stmt.Fixed. 
+ { pattern= + SList + ( declarations_list + @ [lazy_code_motion_stmt (Map.find_exn flowgraph_to_mir 1)] ) + ; meta= Location_span.empty } in + let cleanup = + let cleanup_base + (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) : + (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t = + match stmt with + | Stmt.Fixed.( + Pattern.IfElse + ( e + , {pattern= Block [b1; {pattern= Skip; _}]; _} + , Some {pattern= Block [b2; {pattern= Skip; _}]; _} )) -> + IfElse (e, b1, Some b2) + | IfElse ( e - , {pattern= Block [b1; {pattern= Skip; _}]; _} - , Some {pattern= Block [b2; {pattern= Skip; _}]; _} )) -> - IfElse (e, b1, Some b2) - | IfElse - ( e - , {pattern= Block [b; {pattern= Skip; _}]; _} - , Some {pattern= Skip; _} ) -> - IfElse (e, b, None) - | While (e, {pattern= Block [b; {pattern= Skip; _}]; _}) -> While (e, b) - | For - { loopvar - ; lower - ; upper - ; body= {pattern= Block [b; {pattern= Skip; _}]; _} } -> - For {loopvar; lower; upper; body= b} - | _ -> stmt in - map_rec_stmt_loc cleanup_base in - transform_program_blockwise mir (fun _ x -> - cleanup (transform (preprocess_flowgraph x)) ) + , {pattern= Block [b; {pattern= Skip; _}]; _} + , Some {pattern= Skip; _} ) -> + IfElse (e, b, None) + | While (e, {pattern= Block [b; {pattern= Skip; _}]; _}) -> While (e, b) + | For + { loopvar + ; lower + ; upper + ; body= {pattern= Block [b; {pattern= Skip; _}]; _} } -> + For {loopvar; lower; upper; body= b} + | _ -> stmt in + map_rec_stmt_loc cleanup_base in + transform_program_blockwise mir (fun _ x -> + cleanup (transform (preprocess_flowgraph x)) ) -let block_fixing mir = - transform_program_blockwise mir (fun _ x -> - (map_rec_stmt_loc (fun stmt -> - match stmt with - | IfElse - ( e - , {pattern= SList l; meta} - , Some {pattern= SList l'; meta= smeta'} ) -> - IfElse + let block_fixing mir = + transform_program_blockwise mir (fun _ x -> + (map_rec_stmt_loc (fun stmt -> + match stmt with + | IfElse ( e - , {pattern= Block l; meta} - , Some {pattern= Block 
l'; meta= smeta'} ) - | IfElse (e, {pattern= SList l; meta}, b) -> - IfElse (e, {pattern= Block l; meta}, b) - | IfElse (e, b, Some {pattern= SList l'; meta= smeta'}) -> - IfElse (e, b, Some {pattern= Block l'; meta= smeta'}) - | While (e, {pattern= SList l; meta}) -> - While (e, {pattern= Block l; meta}) - | For {loopvar; lower; upper; body= {pattern= SList l; meta}} -> - For {loopvar; lower; upper; body= {pattern= Block l; meta}} - | _ -> stmt ) ) - x ) + , {pattern= SList l; meta} + , Some {pattern= SList l'; meta= smeta'} ) -> + IfElse + ( e + , {pattern= Block l; meta} + , Some {pattern= Block l'; meta= smeta'} ) + | IfElse (e, {pattern= SList l; meta}, b) -> + IfElse (e, {pattern= Block l; meta}, b) + | IfElse (e, b, Some {pattern= SList l'; meta= smeta'}) -> + IfElse (e, b, Some {pattern= Block l'; meta= smeta'}) + | While (e, {pattern= SList l; meta}) -> + While (e, {pattern= Block l; meta}) + | For {loopvar; lower; upper; body= {pattern= SList l; meta}} -> + For {loopvar; lower; upper; body= {pattern= Block l; meta}} + | _ -> stmt ) ) + x ) -(* TODO: implement SlicStan style optimizer for choosing best program block for each statement. *) -(* TODO: add optimization pass to move declarations down as much as possible and introduce as - tight as possible local scopes *) -(* TODO: add tests *) -(* TODO: add pass to get rid of redundant declarations? *) + (* TODO: implement SlicStan style optimizer for choosing best program block for each statement. *) + (* TODO: add optimization pass to move declarations down as much as possible and introduce as + tight as possible local scopes *) + (* TODO: add tests *) + (* TODO: add pass to get rid of redundant declarations? *) -(** + (** * A generic optimization pass for finding a minimal set of variables that * are generated by some circumstance, and then updating the MIR with that set. 
* @param gen_variables: the variables that must be added to the set at @@ -1079,87 +1158,89 @@ let block_fixing mir = * @param initial_variables: the initial known members of the set of variables * @param stmt the MIR statement to optimize. *) -let optimize_minimal_variables - ~(gen_variables : - (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> int - -> string Set.Poly.t - -> string Set.Poly.t ) - ~(update_expr : string Set.Poly.t -> Expr.Typed.t -> Expr.Typed.t) - ~(update_stmt : - ( Expr.Typed.t - , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) - Stmt.Fixed.Pattern.t - -> string Core_kernel.Set.Poly.t - -> ( Expr.Typed.t - , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) - Stmt.Fixed.Pattern.t ) - ~(extra_variables : string -> string Set.Poly.t) - ~(initial_variables : string Set.Poly.t) (stmt : Stmt.Located.t) = - let rev_flowgraph, flowgraph_to_mir = - Monotone_framework.inverse_flowgraph_of_stmt stmt in - let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in - let (module Circular_Fwd_Flowgraph) = - Monotone_framework.make_circular_flowgraph fwd_flowgraph rev_flowgraph in - let mfp_variables = - Monotone_framework.minimal_variables_mfp - (module Circular_Fwd_Flowgraph) - flowgraph_to_mir initial_variables gen_variables in - let optimize_min_vars_stmt_base i stmt_pattern = - let variable_set = - let exits = (Map.find_exn mfp_variables i).exit in - Set.Poly.union exits (union_map exits ~f:extra_variables) in - let stmt_val = - Stmt.Fixed.Pattern.map (update_expr variable_set) - (fun x -> x) - stmt_pattern in - update_stmt stmt_val variable_set in - map_rec_stmt_loc_num flowgraph_to_mir optimize_min_vars_stmt_base - (Map.find_exn flowgraph_to_mir 1) + let optimize_minimal_variables + ~(gen_variables : + (int, Stmt.Located.Non_recursive.t) Map.Poly.t + -> int + -> string Set.Poly.t + -> string Set.Poly.t ) + ~(update_expr : string Set.Poly.t -> Expr.Typed.t -> Expr.Typed.t) + ~(update_stmt : + ( Expr.Typed.t + , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) + 
Stmt.Fixed.Pattern.t + -> string Core_kernel.Set.Poly.t + -> ( Expr.Typed.t + , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) + Stmt.Fixed.Pattern.t ) + ~(extra_variables : string -> string Set.Poly.t) + ~(initial_variables : string Set.Poly.t) (stmt : Stmt.Located.t) = + let rev_flowgraph, flowgraph_to_mir = + Monotone_framework.inverse_flowgraph_of_stmt stmt in + let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in + let (module Circular_Fwd_Flowgraph) = + Monotone_framework.make_circular_flowgraph fwd_flowgraph rev_flowgraph + in + let mfp_variables = + Monotone_framework.minimal_variables_mfp + (module Circular_Fwd_Flowgraph) + flowgraph_to_mir initial_variables gen_variables in + let optimize_min_vars_stmt_base i stmt_pattern = + let variable_set = + let exits = (Map.find_exn mfp_variables i).exit in + Set.Poly.union exits (union_map exits ~f:extra_variables) in + let stmt_val = + Stmt.Fixed.Pattern.map (update_expr variable_set) + (fun x -> x) + stmt_pattern in + update_stmt stmt_val variable_set in + map_rec_stmt_loc_num flowgraph_to_mir optimize_min_vars_stmt_base + (Map.find_exn flowgraph_to_mir 1) -let optimize_ad_levels (mir : Program.Typed.t) = - let gen_ad_variables - (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) - (l : int) (ad_variables : string Set.Poly.t) = - let mir_node = (Map.find_exn flowgraph_to_mir l).pattern in - match mir_node with - | Assignment ((x, _, _), e) - when Expr.Typed.adlevel_of (update_expr_ad_levels ad_variables e) - = AutoDiffable -> - Set.Poly.singleton x - | _ -> Set.Poly.empty in - let global_initial_ad_variables = - Set.Poly.of_list - (List.filter_map - ~f:(fun (v, _, Program.{out_block; _}) -> - match out_block with Parameters -> Some v | _ -> None ) - mir.output_vars ) in - let initial_ad_variables fundef_opt _ = - match (fundef_opt : Stmt.Located.t Program.fun_def option) with - | None -> global_initial_ad_variables - | Some {fdargs; _} -> - Set.Poly.union global_initial_ad_variables - 
(Set.Poly.of_list - (List.filter_map fdargs ~f:(fun (_, name, ut) -> - if UnsizedType.is_autodiffable ut then Some name else None ) - ) ) in - let extra_variables v = Set.Poly.singleton (v ^ "_in__") in - let update_stmt stmt_pattern variable_set = - match stmt_pattern with - | Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl) - when Set.mem variable_set decl_id -> - Stmt.Fixed.Pattern.Decl {decl with decl_adtype= UnsizedType.AutoDiffable} - | Decl ({decl_id; _} as decl) when not (Set.mem variable_set decl_id) -> - Decl {decl with decl_adtype= DataOnly} - | s -> s in - let transform fundef_opt stmt = - optimize_minimal_variables ~gen_variables:gen_ad_variables - ~update_expr:update_expr_ad_levels ~update_stmt ~extra_variables - ~initial_variables:(initial_ad_variables fundef_opt stmt) - stmt in - transform_program_blockwise mir transform + let optimize_ad_levels (mir : Program.Typed.t) = + let gen_ad_variables + (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) + (l : int) (ad_variables : string Set.Poly.t) = + let mir_node = (Map.find_exn flowgraph_to_mir l).pattern in + match mir_node with + | Assignment ((x, _, _), e) + when Expr.Typed.adlevel_of (update_expr_ad_levels ad_variables e) + = AutoDiffable -> + Set.Poly.singleton x + | _ -> Set.Poly.empty in + let global_initial_ad_variables = + Set.Poly.of_list + (List.filter_map + ~f:(fun (v, _, Program.{out_block; _}) -> + match out_block with Parameters -> Some v | _ -> None ) + mir.output_vars ) in + let initial_ad_variables fundef_opt _ = + match (fundef_opt : Stmt.Located.t Program.fun_def option) with + | None -> global_initial_ad_variables + | Some {fdargs; _} -> + Set.Poly.union global_initial_ad_variables + (Set.Poly.of_list + (List.filter_map fdargs ~f:(fun (_, name, ut) -> + if UnsizedType.is_autodiffable ut then Some name else None ) + ) ) in + let extra_variables v = Set.Poly.singleton (v ^ "_in__") in + let update_stmt stmt_pattern variable_set = + match stmt_pattern with + | 
Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl) + when Set.mem variable_set decl_id -> + Stmt.Fixed.Pattern.Decl + {decl with decl_adtype= UnsizedType.AutoDiffable} + | Decl ({decl_id; _} as decl) when not (Set.mem variable_set decl_id) -> + Decl {decl with decl_adtype= DataOnly} + | s -> s in + let transform fundef_opt stmt = + optimize_minimal_variables ~gen_variables:gen_ad_variables + ~update_expr:update_expr_ad_levels ~update_stmt ~extra_variables + ~initial_variables:(initial_ad_variables fundef_opt stmt) + stmt in + transform_program_blockwise mir transform -(** + (** * Deduces whether types can be Structures of Arrays (SoA/fast) or * Arrays of Structs (AoS/slow). See the docs in * Mem_pattern.query_demote_stmt/exprs* functions for @@ -1180,149 +1261,81 @@ let optimize_ad_levels (mir : Program.Typed.t) = * * @param mir: The program's whole MIR. *) -let optimize_soa (mir : Program.Typed.t) = - let gen_aos_variables - (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) - (l : int) (aos_variables : string Set.Poly.t) = - let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in - match (mir_node l).pattern with - | stmt -> Memory_patterns.query_demotable_stmt aos_variables stmt in - let initial_variables = - List.fold ~init:Set.Poly.empty - ~f:(Memory_patterns.query_initial_demotable_stmt false) - mir.log_prob in - (* - let print_set s = - Set.Poly.iter ~f:print_endline s in - let () = print_set initial_variables in - *) - let mod_exprs aos_exits mod_expr = - Mir_utils.map_rec_expr - (Memory_patterns.modify_expr_pattern aos_exits) - mod_expr in - let modify_stmt_patt stmt_pattern variable_set = - Memory_patterns.modify_stmt_pattern stmt_pattern variable_set in - let transform stmt = - optimize_minimal_variables ~gen_variables:gen_aos_variables - ~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables - stmt ~extra_variables:(fun _ -> initial_variables) in - let transform' s = - match transform {pattern= SList s; meta= 
Location_span.empty} with - | {pattern= SList (l : Stmt.Located.t list); _} -> l - | _ -> - Common.FatalError.fatal_error_msg - [%message "Something went wrong with program transformation packing!"] - in - {mir with log_prob= transform' mir.log_prob} - -(* Apparently you need to completely copy/paste type definitions between - ml and mli files?*) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -let settings_const b = - { function_inlining= b - ; static_loop_unrolling= b - ; one_step_loop_unrolling= b - ; list_collapsing= b - ; block_fixing= b - ; allow_uninitialized_decls= b - ; constant_propagation= b - ; expression_propagation= b - ; copy_propagation= b - ; dead_code_elimination= b - ; partial_evaluation= b - ; lazy_code_motion= b - ; optimize_ad_levels= b - ; preserve_stability= not b - ; optimize_soa= b } - -let all_optimizations : optimization_settings = settings_const true -let no_optimizations : optimization_settings = settings_const false - -type optimization_level = O0 | O1 | Oexperimental - -let level_optimizations (lvl : optimization_level) : optimization_settings = - match lvl with - | O0 -> no_optimizations - | O1 -> - { function_inlining= true - ; static_loop_unrolling= false - ; one_step_loop_unrolling= false - ; list_collapsing= true - ; block_fixing= true - ; constant_propagation= true - ; expression_propagation= false - ; copy_propagation= true - ; dead_code_elimination= true - ; partial_evaluation= true - ; lazy_code_motion= false - ; allow_uninitialized_decls= true - ; optimize_ad_levels= false - ; preserve_stability= false - ; 
optimize_soa= true } - | Oexperimental -> all_optimizations + let optimize_soa (mir : Program.Typed.t) = + let gen_aos_variables + (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) + (l : int) (aos_variables : string Set.Poly.t) = + let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in + match (mir_node l).pattern with + | stmt -> Mem.query_demotable_stmt aos_variables stmt in + let initial_variables = + List.fold ~init:Set.Poly.empty + ~f:(Mem.query_initial_demotable_stmt false) + mir.log_prob in + let mod_exprs aos_exits mod_expr = + Mir_utils.map_rec_expr (Mem.modify_expr_pattern aos_exits) mod_expr in + let modify_stmt_patt stmt_pattern variable_set = + Mem.modify_stmt_pattern stmt_pattern variable_set in + let transform stmt = + optimize_minimal_variables ~gen_variables:gen_aos_variables + ~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables + stmt ~extra_variables:(fun _ -> initial_variables) in + let transform' s = + match transform {pattern= SList s; meta= Location_span.empty} with + | {pattern= SList (l : Stmt.Located.t list); _} -> l + | _ -> + raise + (Failure "Something went wrong with program transformation packing!") + in + {mir with log_prob= transform' mir.log_prob} -let optimization_suite ?(settings = all_optimizations) mir = - let preserve_stability = settings.preserve_stability in - let maybe_optimizations = - [ (* Phase order. 
See phase-ordering-nodes.org for details *) - (* Book section A *) - (* Book section B *) - (* Book: Procedure integration *) - (function_inlining, settings.function_inlining) - (* Book: Sparse conditional constant propagation *) - ; (constant_propagation ~preserve_stability, settings.constant_propagation) - (* Book section C *) - (* Book: Local and global copy propagation *) - ; (copy_propagation, settings.copy_propagation) - (* Book: Sparse conditional constant propagation *) - ; (constant_propagation ~preserve_stability, settings.constant_propagation) - (* Book: Dead-code elimination *) - ; (dead_code_elimination, settings.dead_code_elimination) - (* Matthijs: Before lazy code motion to get loop-invariant code motion *) - ; (one_step_loop_unrolling, settings.one_step_loop_unrolling) - (* Matthjis: expression_propagation < partial_evaluation *) - ; ( expression_propagation ~preserve_stability - , settings.expression_propagation ) - (* Matthjis: partial_evaluation < lazy_code_motion *) - ; (partial_evaluation, settings.partial_evaluation) - (* Book: Loop-invariant code motion *) - ; (lazy_code_motion ~preserve_stability, settings.lazy_code_motion) - (* Matthijs: lazy_code_motion < copy_propagation TODO: Check if this is necessary *) - ; (copy_propagation, settings.copy_propagation) - (* Matthijs: Constant propagation before static loop unrolling *) - ; (constant_propagation ~preserve_stability, settings.constant_propagation) - (* Book: Loop simplification *) - ; (static_loop_unrolling, settings.static_loop_unrolling) - (* Book: Dead-code elimination *) - (* Matthijs: Everything < Dead-code elimination *) - ; (dead_code_elimination, settings.dead_code_elimination) - (* Book: Machine idioms and instruction combining *) - ; (list_collapsing, settings.list_collapsing) - (* Book: Machine idioms and instruction combining *) - ; (optimize_ad_levels, settings.optimize_ad_levels) - ; (optimize_soa, settings.optimize_soa) - (*Remove decls immediately assigned to*) - ; 
(allow_uninitialized_decls, settings.allow_uninitialized_decls) - (* Book: Machine idioms and instruction combining *) - (* Matthijs: Everything < block_fixing *) - ; (block_fixing, settings.block_fixing) ] in - let optimizations = - List.filter_map maybe_optimizations ~f:(fun (fn, flag) -> - if flag then Some fn else None ) in - List.fold optimizations ~init:mir ~f:(fun mir opt -> opt mir) + let optimization_suite ?(settings = all_optimizations) mir = + let preserve_stability = settings.preserve_stability in + let maybe_optimizations = + [ (* Phase order. See phase-ordering-nodes.org for details *) + (* Book section A *) + (* Book section B *) + (* Book: Procedure integration *) + (function_inlining, settings.function_inlining) + (* Book: Sparse conditional constant propagation *) + ; (constant_propagation ~preserve_stability, settings.constant_propagation) + (* Book section C *) + (* Book: Local and global copy propagation *) + ; (copy_propagation, settings.copy_propagation) + (* Book: Sparse conditional constant propagation *) + ; (constant_propagation ~preserve_stability, settings.constant_propagation) + (* Book: Dead-code elimination *) + ; (dead_code_elimination, settings.dead_code_elimination) + (* Matthijs: Before lazy code motion to get loop-invariant code motion *) + ; (one_step_loop_unrolling, settings.one_step_loop_unrolling) + (* Matthijs: expression_propagation < partial_evaluation *) + ; ( expression_propagation ~preserve_stability + , settings.expression_propagation ) + (* Matthijs: partial_evaluation < lazy_code_motion *) + ; (partial_evaluation, settings.partial_evaluation) + (* Book: Loop-invariant code motion *) + ; (lazy_code_motion ~preserve_stability, settings.lazy_code_motion) + (* Matthijs: lazy_code_motion < copy_propagation TODO: Check if this is necessary *) + ; (copy_propagation, settings.copy_propagation) + (* Matthijs: Constant propagation before static loop unrolling *) + ; (constant_propagation ~preserve_stability, 
settings.constant_propagation) + (* Book: Loop simplification *) + ; (static_loop_unrolling, settings.static_loop_unrolling) + (* Book: Dead-code elimination *) + (* Matthijs: Everything < Dead-code elimination *) + ; (dead_code_elimination, settings.dead_code_elimination) + (* Book: Machine idioms and instruction combining *) + ; (list_collapsing, settings.list_collapsing) + (* Book: Machine idioms and instruction combining *) + ; (optimize_ad_levels, settings.optimize_ad_levels) + ; (optimize_soa, settings.optimize_soa) + (*Remove decls immediately assigned to*) + ; (allow_uninitialized_decls, settings.allow_uninitialized_decls) + (* Book: Machine idioms and instruction combining *) + (* Matthijs: Everything < block_fixing *) + ; (block_fixing, settings.block_fixing) ] in + let optimizations = + List.filter_map maybe_optimizations ~f:(fun (fn, flag) -> + if flag then Some fn else None ) in + List.fold optimizations ~init:mir ~f:(fun mir opt -> opt mir) +end diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index 2885ce9b44..665e5d3914 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -1,85 +1,5 @@ (* Code for optimization passes on the MIR *) -open Middle - -val function_inlining : Program.Typed.t -> Program.Typed.t -(** Inline all functions except for ones with forward declarations - (e.g. 
recursive functions, mutually recursive functions, and - functions without a definition *) - -val static_loop_unrolling : Program.Typed.t -> Program.Typed.t -(** Unroll all for-loops with constant bounds, as long as they do - not contain break or continue statements in their body at the - top level *) - -val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t -(** Unroll all loops for one iteration, as long as they do - not contain break or continue statements in their body at the - top level *) - -val list_collapsing : Program.Typed.t -> Program.Typed.t -(** Remove redundant SList constructors from the Mir that might have - been introduced by other optimizations *) - -val block_fixing : Program.Typed.t -> Program.Typed.t -(** Make sure that SList constructors directly under if, for, while or fundef - constructors are replaced with Block constructors. - This should probably be run before we generate code. *) - -val constant_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t -(** Propagate constant values through variable assignments *) - -val expression_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t -(** Propagate arbitrary expressions through variable assignments. - This can be useful for opening up new possibilities for partial evaluation. - It should be followed by some CSE or lazy code motion pass, however. *) - -val copy_propagation : Program.Typed.t -> Program.Typed.t -(** Propagate copies of variables through assignments. *) - -val dead_code_elimination : Program.Typed.t -> Program.Typed.t -(** Eliminate semantically redundant code branches. - This includes removing redundant assignments (because they will be overwritten) - and removing redundant code in program branches that will never be reached. *) - -val partial_evaluation : Program.Typed.t -> Program.Typed.t -(** Partially evaluate expressions in the program. 
This includes simplification using - algebraic identities of logical and arithmetic operators as well as Stan math functions. *) - -val lazy_code_motion : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t -(** Perform partial redundancy elmination using the lazy code motion algorithm. This - subsumes common subexpression elimination and loop-invariant code motion. *) - -val optimize_ad_levels : Program.Typed.t -> Program.Typed.t -(** Assign the optimal ad-levels to local variables. That means, make sure that - variables only ever get treated as autodiff variables if they have some - dependency on a parameter *) - -val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t -(** Marks Decl types such that, if the first assignment after the decl - assigns to the full object, allow the object to be constructed but - not uninitialized. *) - -(** Interface for turning individual optimizations on/off. Useful for testing - and for top-level interface flags. *) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } +open Optimize_intf val all_optimizations : optimization_settings val no_optimizations : optimization_settings @@ -88,6 +8,8 @@ type optimization_level = O0 | O1 | Oexperimental val level_optimizations : optimization_level -> optimization_settings -val optimization_suite : - ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t -(** Perform all optimizations in this module on the MIR in an appropriate order. *) +(** Produce an optimizer for the MIR which is parameterized by the + given library of functions. 
These are used in the partial evaluator + and memory optimizations + *) +module Make (StdLibrary : Frontend.Std_library_utils.Library) : OPTIMIZER diff --git a/src/analysis_and_optimization/Optimize_intf.ml b/src/analysis_and_optimization/Optimize_intf.ml new file mode 100644 index 0000000000..d3b4fac4c0 --- /dev/null +++ b/src/analysis_and_optimization/Optimize_intf.ml @@ -0,0 +1,87 @@ +open Middle + +(** Interface for turning individual optimizations on/off. Useful for testing + and for top-level interface flags. *) +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +module type OPTIMIZER = sig + val function_inlining : Program.Typed.t -> Program.Typed.t + (** Inline all functions except for ones with forward declarations + (e.g. 
recursive functions, mutually recursive functions, and + functions without a definition *) + + val static_loop_unrolling : Program.Typed.t -> Program.Typed.t + (** Unroll all for-loops with constant bounds, as long as they do + not contain break or continue statements in their body at the + top level *) + + val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t + (** Unroll all loops for one iteration, as long as they do + not contain break or continue statements in their body at the + top level *) + + val list_collapsing : Program.Typed.t -> Program.Typed.t + (** Remove redundant SList constructors from the Mir that might have + been introduced by other optimizations *) + + val block_fixing : Program.Typed.t -> Program.Typed.t + (** Make sure that SList constructors directly under if, for, while or fundef + constructors are replaced with Block constructors. + This should probably be run before we generate code. *) + + val constant_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Propagate constant values through variable assignments *) + + val expression_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Propagate arbitrary expressions through variable assignments. + This can be useful for opening up new possibilities for partial evaluation. + It should be followed by some CSE or lazy code motion pass, however. *) + + val copy_propagation : Program.Typed.t -> Program.Typed.t + (** Propagate copies of variables through assignments. *) + + val dead_code_elimination : Program.Typed.t -> Program.Typed.t + (** Eliminate semantically redundant code branches. + This includes removing redundant assignments (because they will be overwritten) + and removing redundant code in program branches that will never be reached. *) + + val partial_evaluation : Program.Typed.t -> Program.Typed.t + (** Partially evaluate expressions in the program. 
This includes simplification using + algebraic identities of logical and arithmetic operators as well as Stan math functions. *) + + val lazy_code_motion : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Perform partial redundancy elimination using the lazy code motion algorithm. This + subsumes common subexpression elimination and loop-invariant code motion. *) + + val optimize_ad_levels : Program.Typed.t -> Program.Typed.t + (** Assign the optimal ad-levels to local variables. That means, make sure that + variables only ever get treated as autodiff variables if they have some + dependency on a parameter *) + + val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t + (** Marks Decl types such that, if the first assignment after the decl + assigns to the full object, allow the object to be constructed but + not initialized. *) + + val optimization_suite : + ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t + (** Perform all optimizations in this module on the MIR in an appropriate order. 
*) +end diff --git a/src/analysis_and_optimization/Partial_evaluation.ml b/src/analysis_and_optimization/Partial_evaluation.ml new file mode 100644 index 0000000000..af7e50738c --- /dev/null +++ b/src/analysis_and_optimization/Partial_evaluation.ml @@ -0,0 +1,1196 @@ +(* A partial evaluator for use in static analysis and optimization *) + +open Core_kernel +open Core_kernel.Poly +open Middle + +exception Rejected of Location_span.t * string + +let rec is_int query Expr.Fixed.{pattern; _} = + match pattern with + | Lit (Int, i) | Lit (Real, i) -> float_of_string i = float_of_int query + | Promotion (e, _, _) -> is_int query e + | _ -> false + +let apply_prefix_operator_int (op : string) i = + Expr.Fixed.Pattern.Lit + ( Int + , Int.to_string + ( match op with + | "PPlus__" -> i + | "PMinus__" -> -i + | "PNot__" -> if i = 0 then 1 else 0 + | s -> + Common.FatalError.fatal_error_msg + [%message "Not an int prefix operator: " s] ) ) + +let apply_prefix_operator_real (op : string) i = + Expr.Fixed.Pattern.Lit + ( Real + , Float.to_string + ( match op with + | "PPlus__" -> i + | "PMinus__" -> -.i + | s -> + Common.FatalError.fatal_error_msg + [%message "Not a real prefix operator: " s] ) ) + +let apply_operator_int (op : string) i1 i2 = + Expr.Fixed.Pattern.Lit + ( Int + , Int.to_string + ( match op with + | "Plus__" -> i1 + i2 + | "Minus__" -> i1 - i2 + | "Times__" -> i1 * i2 + | "Divide__" | "IntDivide__" -> i1 / i2 + | "Modulo__" -> i1 % i2 + | "Equals__" -> Bool.to_int (i1 = i2) + | "NEquals__" -> Bool.to_int (i1 <> i2) + | "Less__" -> Bool.to_int (i1 < i2) + | "Leq__" -> Bool.to_int (i1 <= i2) + | "Greater__" -> Bool.to_int (i1 > i2) + | "Geq__" -> Bool.to_int (i1 >= i2) + | s -> + Common.FatalError.fatal_error_msg + [%message "Not an int operator: " s] ) ) + +let apply_arithmetic_operator_real (op : string) r1 r2 = + Expr.Fixed.Pattern.Lit + ( Real + , Float.to_string + ( match op with + | "Plus__" -> r1 +. r2 + | "Minus__" -> r1 -. r2 + | "Times__" -> r1 *. 
r2 + | "Divide__" -> r1 /. r2 + | s -> + Common.FatalError.fatal_error_msg + [%message "Not a real operator: " s] ) ) + +let apply_logical_operator_real (op : string) r1 r2 = + Expr.Fixed.Pattern.Lit + ( Int + , Int.to_string + ( match op with + | "Equals__" -> Bool.to_int (r1 = r2) + | "NEquals__" -> Bool.to_int (r1 <> r2) + | "Less__" -> Bool.to_int (r1 < r2) + | "Leq__" -> Bool.to_int (r1 <= r2) + | "Greater__" -> Bool.to_int (r1 > r2) + | "Geq__" -> Bool.to_int (r1 >= r2) + | s -> + Common.FatalError.fatal_error_msg + [%message "Not a logical operator: " s] ) ) + +let is_multi_index = function + | Index.MultiIndex _ | Upfrom _ | Between _ | All -> true + | Single _ -> false + +module type PARTIAL_EVALUATOR = sig + val try_eval_expr : Expr.Typed.t -> Expr.Typed.t + val eval_prog : Program.Typed.t -> Program.Typed.t +end + +module Make (StdLibrary : Frontend.Std_library_utils.Library) : + PARTIAL_EVALUATOR = struct + module TC = Frontend.Typechecking.Make (StdLibrary) + + let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = + { e with + pattern= + ( match e.pattern with + | Var _ | Lit (_, _) -> e.pattern + | Promotion (expr, ut, ad) -> + Promotion (eval_expr ~preserve_stability expr, ut, ad) + | FunApp (kind, l) -> ( + let l = List.map ~f:(eval_expr ~preserve_stability) l in + match kind with + | UserDefined _ | CompilerInternal _ -> FunApp (kind, l) + | StanLib (f, suffix, mem_type) -> + let get_fun_or_op_rt_opt name l' = + let argument_types = + List.map + ~f:(fun x -> Expr.Typed.(adlevel_of x, type_of x)) + l' in + Operator.of_string_opt name + |> Option.value_map + ~f:(fun op -> + TC.operator_return_type op argument_types + |> Option.map ~f:fst ) + ~default: + (TC.library_function_return_type name argument_types) + in + let try_partially_evaluate_stanlib e = + Expr.Fixed.Pattern.( + match e with + | FunApp (StanLib (f', suffix', mem_type), l') -> ( + match get_fun_or_op_rt_opt f' l' with + | Some _ -> FunApp (StanLib (f', suffix', mem_type), 
l') + | None -> FunApp (StanLib (f, suffix, mem_type), l) ) + | e -> e) in + let lub_mem_pat lst = + Mem_pattern.lub_mem_pat (List.cons mem_type lst) in + try_partially_evaluate_stanlib + ( match (f, l) with + (* TODO: deal with tilde statements and unnormalized distributions properly here *) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ alpha + ; { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "bernoulli_logit_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , 
[y; x; alpha; beta] ) + | ( "bernoulli_logit_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_logit_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + FunApp + ( StanLib + ( "bernoulli_logit_glm_lpmf" + , suffix + , lub_mem_pat [mem] ) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("bernoulli_logit_lpmf", suffix, lub_mem_pat [mem]) + , [y; alpha] ) + | ( "bernoulli_rng" + , [ { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("bernoulli_logit_rng", suffix, lub_mem_pat [mem]) + , [alpha] ) + | ( "binomial_lpmf" + , [ y; n + ; { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("binomial_logit_lpmf", suffix, lub_mem_pat [mem]) + , [y; n; alpha] ) + | ( "categorical_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("categorical_logit_lpmf", suffix, lub_mem_pat [mem]) + , [y; alpha] ) + | ( "categorical_rng" + , [ { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("categorical_logit_rng", suffix, lub_mem_pat [mem]) + , [alpha] ) + | "columns_dot_product", [x; y] when Expr.Typed.equal x y -> + FunApp + (StanLib ("columns_dot_self", suffix, mem_type), [x]) + | "dot_product", [x; y] when Expr.Typed.equal x y -> + FunApp (StanLib ("dot_self", suffix, mem_type), [x]) + | ( "inv" + , [{pattern= 
FunApp (StanLib ("sqrt", FnPlain, mem), l); _}] + ) -> + FunApp (StanLib ("inv_sqrt", suffix, mem), l) + | ( "inv" + , [ { pattern= FunApp (StanLib ("square", FnPlain, mem), [x]) + ; _ } ] ) -> + FunApp + (StanLib ("inv_square", suffix, lub_mem_pat [mem]), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem1) + , [ y + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } ] ) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log1m_exp", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem1) + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem2) + , [x] ) + ; _ } ] ) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log1m_inv_logit", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp (StanLib ("Minus__", FnPlain, mem), [y; x]) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + FunApp (StanLib ("log1m", suffix, lub_mem_pat [mem]), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ y + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } ] ) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log1p_exp", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp (StanLib ("Plus__", FnPlain, mem), [y; x]) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + FunApp (StanLib ("log1p", suffix, lub_mem_pat [mem]), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib (("fabs" | "abs"), FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("determinant", FnPlain, mem2) + , [x] ) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log_determinant", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Minus__", 
FnPlain, mem1) + , [ { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem3), [y]) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp (StanLib ("log_diff_exp", suffix, lub_mem), [x; y]) + (* TODO: log_mix?*) + | ( "log" + , [ { pattern= + FunApp + (StanLib ("falling_factorial", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + ( StanLib + ("log_falling_factorial", suffix, lub_mem_pat [mem]) + , l ) + | ( "log" + , [ { pattern= + FunApp + (StanLib ("rising_factorial", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + ( StanLib + ("log_rising_factorial", suffix, lub_mem_pat [mem]) + , l ) + | ( "log" + , [ { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + (StanLib ("log_inv_logit", suffix, lub_mem_pat [mem]), l) + | ( "log" + , [ { pattern= FunApp (StanLib ("softmax", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + (StanLib ("log_softmax", suffix, lub_mem_pat [mem]), l) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("sum", FnPlain, mem1) + , [ { pattern= + FunApp (StanLib ("exp", FnPlain, mem2), l) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log_sum_exp", suffix, lub_mem), l) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem3), [y]) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp (StanLib ("log_sum_exp", suffix, lub_mem), [x; y]) + | ( "multi_normal_lpdf" + , [ y; mu + ; { pattern= + FunApp (StanLib ("inverse", FnPlain, mem), [tau]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("multi_normal_prec_lpdf", suffix, lub_mem) + , [y; mu; tau] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib 
("Plus__", FnPlain, mem2) + , [ alpha + ; { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta; sigma] ) + | ( "neg_binomial_2_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib + 
("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_log_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta; sigma] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ }; phi ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("neg_binomial_2_log_lpmf", suffix, lub_mem) + , [y; eta; phi] ) + | ( "neg_binomial_2_rng" + , [ { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ }; phi ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("neg_binomial_2_log_rng", suffix, lub_mem) + , [eta; phi] ) + | ( "normal_lpdf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "normal_lpdf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "normal_lpdf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta; sigma] ) + | ( "poisson_lpmf" + , [ y + ; { 
pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ alpha + ; { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "poisson_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, 
lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_log_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + (StanLib ("poisson_log_lpmf", suffix, lub_mem), [y; eta]) + | ( "poisson_rng" + , [ { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + (StanLib ("poisson_log_rng", suffix, lub_mem), [eta]) + | "pow", [y; x] when is_int 2 y -> + FunApp (StanLib ("exp2", suffix, mem_type), [x]) + | "rows_dot_product", [x; y] when Expr.Typed.equal x y -> + FunApp (StanLib ("rows_dot_self", suffix, mem_type), [x]) + | "pow", [x; {pattern= Lit (Int, "2"); _}] -> + FunApp (StanLib ("square", suffix, mem_type), [x]) + | "pow", [x; {pattern= Lit (Real, "0.5"); _}] -> + FunApp (StanLib ("sqrt", suffix, mem_type), [x]) + | ( "pow" + , [ x + ; { pattern= + FunApp (StanLib ("Divide__", FnPlain, mem), [y; z]) + ; _ } ] ) + when is_int 1 y && is_int 2 z -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("sqrt", suffix, lub_mem), [x]) + (* This is wrong; if both are type UInt the exponent rounds down to zero. 
*) + | ( "square" + , [{pattern= FunApp (StanLib ("sd", FnPlain, mem), [x]); _}] + ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("variance", suffix, lub_mem), [x]) + | "sqrt", [x] when is_int 2 x -> + FunApp (StanLib ("sqrt2", suffix, mem_type), []) + | ( "sum" + , [ { pattern= + FunApp + ( StanLib ("square", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem2) + , [x; y] ) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + (StanLib ("squared_distance", suffix, lub_mem), [x; y]) + | ( "sum" + , [ { pattern= FunApp (StanLib ("diagonal", FnPlain, mem), l) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("trace", suffix, lub_mem), l) + | ( "trace" + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [ d + ; { pattern= + FunApp + ( StanLib + ( "transpose" + , FnPlain + , mem4 ) + , [b] ) + ; _ } ] ) + ; _ }; a ] ) + ; _ }; c ] ) + ; _ } ] ) + when Expr.Typed.equal b c -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3; mem4] in + FunApp + ( StanLib ("trace_gen_quad_form", suffix, lub_mem) + , [d; a; b] ) + | ( "trace" + , [ { pattern= + FunApp (StanLib ("quad_form", FnPlain, mem), [a; b]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + (StanLib ("trace_quad_form", suffix, lub_mem), [a; b]) + | ( ("Plus__" | "add") + , [ ({pattern= Lit (Imaginary, i); _} as im) + ; ({pattern= Lit ((Real | Int), _); _} as r) ] ) + |( ("Plus__" | "add") + , [ ({pattern= Lit ((Real | Int), _); _} as r) + ; ({pattern= Lit (Imaginary, i); _} as im) ] ) + |( ("Plus__" | "add") + , [ ({pattern= Lit (Imaginary, i); _} as im) + ; { pattern= + Promotion + ( ({pattern= Lit ((Real | Int), _); _} as r) + , UComplex + , _ ) + ; _ } ] ) + |( ("Plus__" | "add") + , [ { pattern= + Promotion + ( ({pattern= Lit ((Real | Int), _); _} as r) + , UComplex + 
, _ ) + ; _ }; ({pattern= Lit (Imaginary, i); _} as im) ] ) -> + let im_part = + Expr.Fixed. + { pattern= Lit (Real, i) + ; meta= {im.meta with type_= UReal} } in + FunApp + (StanLib ("to_complex", suffix, mem_type), [r; im_part]) + | ( "Minus__" + , [ x + ; {pattern= FunApp (StanLib ("erf", FnPlain, mem), l); _} + ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("erfc", suffix, lub_mem), l) + | ( "Minus__" + , [ x + ; {pattern= FunApp (StanLib ("erfc", FnPlain, mem), l); _} + ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("erf", suffix, lub_mem), l) + | ( "Minus__" + , [ {pattern= FunApp (StanLib ("exp", FnPlain, mem), l'); _} + ; x ] ) + when is_int 1 x && not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("expm1", suffix, lub_mem), l') + | ( "Plus__" + , [ { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) + ; _ }; z ] ) + when (not preserve_stability) + && not + ( UnsizedType.is_eigen_type x.meta.type_ + && UnsizedType.is_eigen_type y.meta.type_ ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Plus__" + , [ z + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) + ; _ } ] ) + when (not preserve_stability) + && not + ( UnsizedType.is_eigen_type x.meta.type_ + && UnsizedType.is_eigen_type y.meta.type_ ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Plus__" + , [ { pattern= + FunApp + ( StanLib + (("elt_multiply" | "EltTimes__"), FnPlain, mem) + , [x; y] ) + ; _ }; z ] ) + when not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Plus__" + , [ z + ; { pattern= + FunApp + ( StanLib + (("elt_multiply" | "EltTimes__"), FnPlain, mem) + , [x; y] ) + ; _ } ] ) + when not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, 
lub_mem), [x; y; z]) + | ( "Minus__" + , [ x + ; { pattern= FunApp (StanLib ("gamma_p", FnPlain, mem), l) + ; _ } ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("gamma_q", suffix, lub_mem), l) + | ( "Minus__" + , [ x + ; { pattern= FunApp (StanLib ("gamma_q", FnPlain, mem), l) + ; _ } ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("gamma_p", suffix, lub_mem), l) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("matrix_exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [t; a] ) + ; _ } ] ) + ; _ }; b ] ) + when Expr.Typed.type_of t = UInt + || Expr.Typed.type_of t = UReal -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) + , [t; a; b] ) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("matrix_exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [a; t] ) + ; _ } ] ) + ; _ }; b ] ) + when Expr.Typed.type_of t = UInt + || Expr.Typed.type_of t = UReal -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) + , [t; a; b] ) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("matrix_exp", FnPlain, mem), [a]) + ; _ }; b ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("matrix_exp_multiply", suffix, lub_mem) + , [a; b] ) + | ( "Times__" + , [ x + ; {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} + ] ) + |( "Times__" + , [ {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} + ; x ] ) + when not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("lmultiply", suffix, lub_mem), [x; y]) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem1), [v]) + ; _ } + ; { pattern= + FunApp + ( StanLib ("diag_post_multiply", FnPlain, mem2) + , [a; w] ) + ; _ } ] ) + when Expr.Typed.equal v w -> + let lub_mem = lub_mem_pat [mem1; 
mem2] in + FunApp + (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("diag_pre_multiply", FnPlain, mem1) + , [v; a] ) + ; _ } + ; { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem2), [w]) + ; _ } ] ) + when Expr.Typed.equal v w -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("transpose", FnPlain, mem1), [b]) + ; _ } + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem2), [a; c]) + ; _ } ] ) + when Expr.Typed.equal b c -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("transpose", FnPlain, mem2) + , [b] ) + ; _ }; a ] ) + ; _ }; c ] ) + when Expr.Typed.equal b c -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) + | ( "Times__" + , [ e1' + ; { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("diag_post_multiply", suffix, lub_mem) + , [e1'; v] ) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) + ; _ }; e2' ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("diag_pre_multiply", suffix, lub_mem) + , [v; e2'] ) + (* Constant folding for operators *) + | op, [{pattern= Lit (Int, i); _}] -> ( + match op with + | "PPlus__" | "PMinus__" | "PNot__" -> + apply_prefix_operator_int op (Int.of_string i) + | _ -> FunApp (kind, l) ) + | op, [{pattern= Lit (Real, r); _}] -> ( + match op with + | "PPlus__" | "PMinus__" -> + apply_prefix_operator_real op (Float.of_string r) + | _ -> FunApp (kind, l) ) + | ( ("Divide__" | "IntDivide__") + , [{meta= {type_= UInt; _}; _}; {pattern= Lit (Int, i2); _}] + ) + when Int.of_string i2 = 0 -> + 
raise (Rejected (e.meta.loc, "Integer division by zero")) + | ( op + , [{pattern= Lit (Int, i1); _}; {pattern= Lit (Int, i2); _}] + ) -> ( + match op with + | "Plus__" | "Minus__" | "Times__" | "Divide__" + |"IntDivide__" | "Modulo__" | "Or__" | "And__" + |"Equals__" | "NEquals__" | "Less__" | "Leq__" + |"Greater__" | "Geq__" -> + apply_operator_int op (Int.of_string i1) + (Int.of_string i2) + | _ -> FunApp (kind, l) ) + | ( op + , [ {pattern= Lit (Real, i1); _} + ; {pattern= Lit (Real, i2); _} ] ) + |( op + , [{pattern= Lit (Int, i1); _}; {pattern= Lit (Real, i2); _}] + ) + |( op + , [{pattern= Lit (Real, i1); _}; {pattern= Lit (Int, i2); _}] + ) -> ( + match op with + | "Plus__" | "Minus__" | "Times__" | "Divide__" -> + apply_arithmetic_operator_real op (Float.of_string i1) + (Float.of_string i2) + | "Or__" | "And__" | "Equals__" | "NEquals__" | "Less__" + |"Leq__" | "Greater__" | "Geq__" -> + apply_logical_operator_real op (Float.of_string i1) + (Float.of_string i2) + | _ -> FunApp (kind, l) ) + | _ -> FunApp (kind, l) ) ) + | TernaryIf (e1, e2, e3) -> ( + match + ( eval_expr ~preserve_stability e1 + , eval_expr ~preserve_stability e2 + , eval_expr ~preserve_stability e3 ) + with + | x, _, e3' when is_int 0 x -> e3'.pattern + | {pattern= Lit (Int, _); _}, e2', _ -> e2'.pattern + | e1', e2', e3' -> TernaryIf (e1', e2', e3') ) + | EAnd (e1, e2) -> ( + match + (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) + with + | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> + let i1, i2 = (Int.of_string s1, Int.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 && i2 <> 0))) + | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> + let r1, r2 = (Float.of_string s1, Float.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. 
&& r2 <> 0.))) + | e1', e2' -> EAnd (e1', e2') ) + | EOr (e1, e2) -> ( + match + (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) + with + | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> + let i1, i2 = (Int.of_string s1, Int.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 || i2 <> 0))) + | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> + let r1, r2 = (Float.of_string s1, Float.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. || r2 <> 0.))) + | e1', e2' -> EOr (e1', e2') ) + | Indexed (e, l) -> + (* TODO: do something clever with array and matrix expressions here? + Note that we could also constant fold array sizes if we keep those around on declarations. *) + Indexed (eval_expr e, List.map ~f:(Index.map eval_expr) l) ) } + + let rec simplify_index_expr pattern = + Expr.Fixed.( + match pattern with + | Pattern.Indexed + ( { pattern= + Indexed (obj, inner_indices) + (* , Single ({emeta= {type_= UArray UInt; _} as emeta; _} as multi) + * :: inner_tl ) *) + ; meta } + , ( Single ({meta= Expr.Typed.Meta.{type_= UInt; _}; _} as single_e) + as single ) + :: outer_tl ) + when List.exists ~f:is_multi_index inner_indices -> ( + match List.split_while ~f:(Fn.non is_multi_index) inner_indices with + | inner_singles, MultiIndex first_multi :: inner_tl -> + (* foo [arr1, ..., arrN] [i1, ..., iN] -> + foo [arr1[i1]] [arr2[i2]] ... 
[arrN[iN]] *) + simplify_index_expr + (Indexed + ( { pattern= + Indexed + ( obj + , inner_singles + @ [ Index.Single + { pattern= Indexed (first_multi, [single]) + ; meta= {meta with type_= UInt} } ] + @ inner_tl ) + ; meta } + , outer_tl ) ) + | inner_singles, All :: inner_tl -> + (* v[:x][i] -> v[i] *) + (* v[:][i] -> v[i] *) + (* XXX generate check *) + simplify_index_expr + (Indexed + ( { pattern= Indexed (obj, inner_singles @ [single] @ inner_tl) + ; meta } + , outer_tl ) ) + | inner_singles, Between (bot, _) :: inner_tl + |inner_singles, Upfrom bot :: inner_tl -> + (* v[x:y][z] -> v[x+z-1] *) + (* XXX generate check *) + simplify_index_expr + (Indexed + ( { pattern= + Indexed + ( obj + , inner_singles + @ [ Index.Single + Expr.Helpers.( + binop (binop bot Plus single_e) Minus + loop_bottom) ] + @ inner_tl ) + ; meta } + , outer_tl ) ) + | inner_singles, (([] | Single _ :: _) as multis) -> + Common.FatalError.fatal_error_msg + [%message + " There must be a multi-index." + (inner_singles : Expr.Typed.t Index.t list) + (multis : Expr.Typed.t Index.t list)] ) + | e -> e) + + let remove_trailing_alls_expr = function + | Expr.Fixed.Pattern.Indexed (obj, indices) -> + (* a[2][:] -> a[2] *) + let rec remove_trailing_alls indices = + match List.rev indices with + | Index.All :: tl -> remove_trailing_alls (List.rev tl) + | _ -> indices in + Expr.Fixed.Pattern.Indexed (obj, remove_trailing_alls indices) + | e -> e + + let rec simplify_indices_expr expr = + Expr.Fixed.( + let pattern = + expr.pattern |> remove_trailing_alls_expr |> simplify_index_expr + |> Expr.Fixed.Pattern.map simplify_indices_expr in + {expr with pattern}) + + let try_eval_expr expr = try eval_expr expr with Rejected _ -> expr + + let rec eval_stmt s = + try + Stmt.Fixed. 
+ { s with + pattern= + Pattern.map + (Fn.compose eval_expr simplify_indices_expr) + eval_stmt s.pattern } + with Rejected (loc, m) -> + { Stmt.Fixed.pattern= + NRFunApp (CompilerInternal FnReject, [Expr.Helpers.str m]) + ; meta= loc } + + let eval_prog = Program.map try_eval_expr eval_stmt Fn.id +end diff --git a/src/analysis_and_optimization/Partial_evaluation.mli b/src/analysis_and_optimization/Partial_evaluation.mli new file mode 100644 index 0000000000..942872485d --- /dev/null +++ b/src/analysis_and_optimization/Partial_evaluation.mli @@ -0,0 +1,11 @@ +open Middle + +exception Rejected of Location_span.t * string + +module type PARTIAL_EVALUATOR = sig + val try_eval_expr : Expr.Typed.t -> Expr.Typed.t + val eval_prog : Program.Typed.t -> Program.Typed.t +end + +module Make (StdLibrary : Frontend.Std_library_utils.Library) : + PARTIAL_EVALUATOR diff --git a/src/analysis_and_optimization/Partial_evaluator.ml b/src/analysis_and_optimization/Partial_evaluator.ml deleted file mode 100644 index 5aef0bab75..0000000000 --- a/src/analysis_and_optimization/Partial_evaluator.ml +++ /dev/null @@ -1,1151 +0,0 @@ -(* A partial evaluator for use in static analysis and optimization *) - -open Core_kernel -open Core_kernel.Poly -open Middle - -exception Rejected of Location_span.t * string - -let rec is_int query Expr.Fixed.{pattern; _} = - match pattern with - | Lit (Int, i) | Lit (Real, i) -> float_of_string i = float_of_int query - | Promotion (e, _, _) -> is_int query e - | _ -> false - -let apply_prefix_operator_int (op : string) i = - Expr.Fixed.Pattern.Lit - ( Int - , Int.to_string - ( match op with - | "PPlus__" -> i - | "PMinus__" -> -i - | "PNot__" -> if i = 0 then 1 else 0 - | s -> - Common.FatalError.fatal_error_msg - [%message "Not an int prefix operator: " s] ) ) - -let apply_prefix_operator_real (op : string) i = - Expr.Fixed.Pattern.Lit - ( Real - , Float.to_string - ( match op with - | "PPlus__" -> i - | "PMinus__" -> -.i - | s -> - 
Common.FatalError.fatal_error_msg - [%message "Not a real prefix operator: " s] ) ) - -let apply_operator_int (op : string) i1 i2 = - Expr.Fixed.Pattern.Lit - ( Int - , Int.to_string - ( match op with - | "Plus__" -> i1 + i2 - | "Minus__" -> i1 - i2 - | "Times__" -> i1 * i2 - | "Divide__" | "IntDivide__" -> i1 / i2 - | "Modulo__" -> i1 % i2 - | "Equals__" -> Bool.to_int (i1 = i2) - | "NEquals__" -> Bool.to_int (i1 <> i2) - | "Less__" -> Bool.to_int (i1 < i2) - | "Leq__" -> Bool.to_int (i1 <= i2) - | "Greater__" -> Bool.to_int (i1 > i2) - | "Geq__" -> Bool.to_int (i1 >= i2) - | s -> - Common.FatalError.fatal_error_msg - [%message "Not an int operator: " s] ) ) - -let apply_arithmetic_operator_real (op : string) r1 r2 = - Expr.Fixed.Pattern.Lit - ( Real - , Float.to_string - ( match op with - | "Plus__" -> r1 +. r2 - | "Minus__" -> r1 -. r2 - | "Times__" -> r1 *. r2 - | "Divide__" -> r1 /. r2 - | s -> - Common.FatalError.fatal_error_msg - [%message "Not a real operator: " s] ) ) - -let apply_logical_operator_real (op : string) r1 r2 = - Expr.Fixed.Pattern.Lit - ( Int - , Int.to_string - ( match op with - | "Equals__" -> Bool.to_int (r1 = r2) - | "NEquals__" -> Bool.to_int (r1 <> r2) - | "Less__" -> Bool.to_int (r1 < r2) - | "Leq__" -> Bool.to_int (r1 <= r2) - | "Greater__" -> Bool.to_int (r1 > r2) - | "Geq__" -> Bool.to_int (r1 >= r2) - | s -> - Common.FatalError.fatal_error_msg - [%message "Not a logical operator: " s] ) ) - -let is_multi_index = function - | Index.MultiIndex _ | Upfrom _ | Between _ | All -> true - | Single _ -> false - -let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = - { e with - pattern= - ( match e.pattern with - | Var _ | Lit (_, _) -> e.pattern - | Promotion (expr, ut, ad) -> - Promotion (eval_expr ~preserve_stability expr, ut, ad) - | FunApp (kind, l) -> ( - let l = List.map ~f:(eval_expr ~preserve_stability) l in - match kind with - | UserDefined _ | CompilerInternal _ -> FunApp (kind, l) - | StanLib (f, suffix, 
mem_type) -> - let get_fun_or_op_rt_opt name l' = - let argument_types = - List.map ~f:(fun x -> Expr.Typed.(adlevel_of x, type_of x)) l' - in - Operator.of_string_opt name - |> Option.value_map - ~f:(fun op -> - Frontend.Typechecker.operator_stan_math_return_type op - argument_types - |> Option.map ~f:fst ) - ~default: - (Frontend.Typechecker.stan_math_return_type name - argument_types ) in - let try_partially_evaluate_stanlib e = - Expr.Fixed.Pattern.( - match e with - | FunApp (StanLib (f', suffix', mem_type), l') -> ( - match get_fun_or_op_rt_opt f' l' with - | Some _ -> FunApp (StanLib (f', suffix', mem_type), l') - | None -> FunApp (StanLib (f, suffix, mem_type), l) ) - | e -> e) in - let lub_mem_pat lst = - Mem_pattern.lub_mem_pat (List.cons mem_type lst) in - try_partially_evaluate_stanlib - ( match (f, l) with - (* TODO: deal with tilde statements and unnormalized distributions properly here *) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("inv_logit", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ alpha - ; { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("inv_logit", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("inv_logit", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib 
("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "bernoulli_logit_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_logit_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_logit_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - FunApp - ( StanLib - ("bernoulli_logit_glm_lpmf", suffix, lub_mem_pat [mem]) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("bernoulli_logit_lpmf", suffix, lub_mem_pat [mem]) - , [y; alpha] ) - | ( "bernoulli_rng" - , [ { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("bernoulli_logit_rng", suffix, lub_mem_pat [mem]) - , [alpha] ) - | ( "binomial_lpmf" - , [ y; n - ; { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("binomial_logit_lpmf", suffix, lub_mem_pat [mem]) - , [y; n; alpha] ) - | ( "categorical_lpmf" - , [ y - ; { pattern= - FunApp 
(StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("categorical_logit_lpmf", suffix, lub_mem_pat [mem]) - , [y; alpha] ) - | ( "categorical_rng" - , [ { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("categorical_logit_rng", suffix, lub_mem_pat [mem]) - , [alpha] ) - | "columns_dot_product", [x; y] when Expr.Typed.equal x y -> - FunApp (StanLib ("columns_dot_self", suffix, mem_type), [x]) - | "dot_product", [x; y] when Expr.Typed.equal x y -> - FunApp (StanLib ("dot_self", suffix, mem_type), [x]) - | ( "inv" - , [{pattern= FunApp (StanLib ("sqrt", FnPlain, mem), l); _}] ) - -> - FunApp (StanLib ("inv_sqrt", suffix, mem), l) - | ( "inv" - , [ { pattern= FunApp (StanLib ("square", FnPlain, mem), [x]) - ; _ } ] ) -> - FunApp - (StanLib ("inv_square", suffix, lub_mem_pat [mem]), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem1) - , [ y - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } ] ) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log1m_exp", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem1) - , [ y - ; { pattern= - FunApp - (StanLib ("inv_logit", FnPlain, mem2), [x]) - ; _ } ] ) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log1m_inv_logit", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp (StanLib ("Minus__", FnPlain, mem), [y; x]) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - FunApp (StanLib ("log1m", suffix, lub_mem_pat [mem]), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ y - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } ] ) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - let lub_mem = lub_mem_pat [mem1; mem2] in - 
FunApp (StanLib ("log1p_exp", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp (StanLib ("Plus__", FnPlain, mem), [y; x]) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - FunApp (StanLib ("log1p", suffix, lub_mem_pat [mem]), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib (("fabs" | "abs"), FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("determinant", FnPlain, mem2) - , [x] ) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log_determinant", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem1) - , [ { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem3), [y]) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp (StanLib ("log_diff_exp", suffix, lub_mem), [x; y]) - (* TODO: log_mix?*) - | ( "log" - , [ { pattern= - FunApp (StanLib ("falling_factorial", FnPlain, mem), l) - ; _ } ] ) -> - FunApp - ( StanLib - ("log_falling_factorial", suffix, lub_mem_pat [mem]) - , l ) - | ( "log" - , [ { pattern= - FunApp (StanLib ("rising_factorial", FnPlain, mem), l) - ; _ } ] ) -> - FunApp - ( StanLib - ("log_rising_factorial", suffix, lub_mem_pat [mem]) - , l ) - | ( "log" - , [ { pattern= FunApp (StanLib ("inv_logit", FnPlain, mem), l) - ; _ } ] ) -> - FunApp - (StanLib ("log_inv_logit", suffix, lub_mem_pat [mem]), l) - | ( "log" - , [{pattern= FunApp (StanLib ("softmax", FnPlain, mem), l); _}] - ) -> - FunApp - (StanLib ("log_softmax", suffix, lub_mem_pat [mem]), l) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("sum", FnPlain, mem1) - , [ { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), l) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log_sum_exp", suffix, lub_mem), l) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), 
[x]) - ; _ } - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem3), [y]) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp (StanLib ("log_sum_exp", suffix, lub_mem), [x; y]) - | ( "multi_normal_lpdf" - , [ y; mu - ; { pattern= - FunApp (StanLib ("inverse", FnPlain, mem), [tau]) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("multi_normal_prec_lpdf", suffix, lub_mem) - , [y; mu; tau] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ alpha - ; { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta; sigma] ) - | ( "neg_binomial_2_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", 
FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_log_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta; sigma] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; {pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _} - ; phi ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("neg_binomial_2_log_lpmf", suffix, lub_mem) - , [y; eta; phi] ) - | ( "neg_binomial_2_rng" - , [ {pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _} - ; phi ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("neg_binomial_2_log_rng", suffix, lub_mem) - , [eta; phi] ) - | ( "normal_lpdf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "normal_lpdf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , 
[x; beta] ) - ; _ }; alpha ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "normal_lpdf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta; sigma] ) - | ( "poisson_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ alpha - ; { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "poisson_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) 
- , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_log_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "poisson_lpmf" - , [ y - ; {pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _} - ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("poisson_log_lpmf", suffix, lub_mem), [y; eta]) - | ( "poisson_rng" - , [{pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _}] - ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("poisson_log_rng", suffix, lub_mem), [eta]) - | "pow", [y; x] when is_int 2 y -> - FunApp (StanLib ("exp2", suffix, mem_type), [x]) - | "rows_dot_product", [x; y] when Expr.Typed.equal x y -> - FunApp (StanLib ("rows_dot_self", suffix, mem_type), [x]) - | "pow", [x; {pattern= Lit (Int, "2"); _}] -> - FunApp (StanLib ("square", suffix, mem_type), [x]) - | "pow", [x; {pattern= Lit (Real, "0.5"); _}] -> - FunApp (StanLib ("sqrt", suffix, mem_type), [x]) - | ( "pow" - , [ x - ; { pattern= - FunApp (StanLib ("Divide__", FnPlain, mem), [y; z]) - ; _ } ] ) - when is_int 1 y && is_int 2 z -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("sqrt", suffix, lub_mem), [x]) - (* This is wrong; if both are type UInt the exponent is rounds down to 
zero. *) - | ( "square" - , [{pattern= FunApp (StanLib ("sd", FnPlain, mem), [x]); _}] ) - -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("variance", suffix, lub_mem), [x]) - | "sqrt", [x] when is_int 2 x -> - FunApp (StanLib ("sqrt2", suffix, mem_type), []) - | ( "sum" - , [ { pattern= - FunApp - ( StanLib ("square", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem2) - , [x; y] ) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - (StanLib ("squared_distance", suffix, lub_mem), [x; y]) - | ( "sum" - , [ { pattern= FunApp (StanLib ("diagonal", FnPlain, mem), l) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("trace", suffix, lub_mem), l) - | ( "trace" - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [ d - ; { pattern= - FunApp - ( StanLib - ( "transpose" - , FnPlain - , mem4 ) - , [b] ) - ; _ } ] ) - ; _ }; a ] ) - ; _ }; c ] ) - ; _ } ] ) - when Expr.Typed.equal b c -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3; mem4] in - FunApp - ( StanLib ("trace_gen_quad_form", suffix, lub_mem) - , [d; a; b] ) - | ( "trace" - , [ { pattern= - FunApp (StanLib ("quad_form", FnPlain, mem), [a; b]) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("trace_quad_form", suffix, lub_mem), [a; b]) - | ( ("Plus__" | "add") - , [ ({pattern= Lit (Imaginary, i); _} as im) - ; ({pattern= Lit ((Real | Int), _); _} as r) ] ) - |( ("Plus__" | "add") - , [ ({pattern= Lit ((Real | Int), _); _} as r) - ; ({pattern= Lit (Imaginary, i); _} as im) ] ) - |( ("Plus__" | "add") - , [ ({pattern= Lit (Imaginary, i); _} as im) - ; { pattern= - Promotion - ( ({pattern= Lit ((Real | Int), _); _} as r) - , UComplex - , _ ) - ; _ } ] ) - |( ("Plus__" | "add") - , [ { pattern= - Promotion - ( ({pattern= Lit ((Real | Int), _); _} as r) - , 
UComplex - , _ ) - ; _ }; ({pattern= Lit (Imaginary, i); _} as im) ] ) -> - let im_part = - Expr.Fixed. - { pattern= Lit (Real, i) - ; meta= {im.meta with type_= UReal} } in - FunApp - (StanLib ("to_complex", suffix, mem_type), [r; im_part]) - | ( "Minus__" - , [x; {pattern= FunApp (StanLib ("erf", FnPlain, mem), l); _}] - ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("erfc", suffix, lub_mem), l) - | ( "Minus__" - , [x; {pattern= FunApp (StanLib ("erfc", FnPlain, mem), l); _}] - ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("erf", suffix, lub_mem), l) - | ( "Minus__" - , [{pattern= FunApp (StanLib ("exp", FnPlain, mem), l'); _}; x] - ) - when is_int 1 x && not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("expm1", suffix, lub_mem), l') - | ( "Plus__" - , [ { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) - ; _ }; z ] ) - when (not preserve_stability) - && not - ( UnsizedType.is_eigen_type x.meta.type_ - && UnsizedType.is_eigen_type y.meta.type_ ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Plus__" - , [ z - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) - ; _ } ] ) - when (not preserve_stability) - && not - ( UnsizedType.is_eigen_type x.meta.type_ - && UnsizedType.is_eigen_type y.meta.type_ ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Plus__" - , [ { pattern= - FunApp - ( StanLib - (("elt_multiply" | "EltTimes__"), FnPlain, mem) - , [x; y] ) - ; _ }; z ] ) - when not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Plus__" - , [ z - ; { pattern= - FunApp - ( StanLib - (("elt_multiply" | "EltTimes__"), FnPlain, mem) - , [x; y] ) - ; _ } ] ) - when not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, 
lub_mem), [x; y; z]) - | ( "Minus__" - , [ x - ; {pattern= FunApp (StanLib ("gamma_p", FnPlain, mem), l); _} - ] ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("gamma_q", suffix, lub_mem), l) - | ( "Minus__" - , [ x - ; {pattern= FunApp (StanLib ("gamma_q", FnPlain, mem), l); _} - ] ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("gamma_p", suffix, lub_mem), l) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("matrix_exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [t; a] ) - ; _ } ] ) - ; _ }; b ] ) - when Expr.Typed.type_of t = UInt - || Expr.Typed.type_of t = UReal -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) - , [t; a; b] ) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("matrix_exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [a; t] ) - ; _ } ] ) - ; _ }; b ] ) - when Expr.Typed.type_of t = UInt - || Expr.Typed.type_of t = UReal -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) - , [t; a; b] ) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("matrix_exp", FnPlain, mem), [a]) - ; _ }; b ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("matrix_exp_multiply", suffix, lub_mem), [a; b]) - | ( "Times__" - , [ x - ; {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} - ] ) - |( "Times__" - , [ {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} - ; x ] ) - when not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("lmultiply", suffix, lub_mem), [x; y]) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem1), [v]) - ; _ } - ; { pattern= - FunApp - ( StanLib ("diag_post_multiply", FnPlain, mem2) - , [a; w] ) - ; _ } ] ) - when Expr.Typed.equal v w -> - let lub_mem = lub_mem_pat [mem1; mem2] in - 
FunApp (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("diag_pre_multiply", FnPlain, mem1) - , [v; a] ) - ; _ } - ; { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem2), [w]) - ; _ } ] ) - when Expr.Typed.equal v w -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("transpose", FnPlain, mem1), [b]) - ; _ } - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem2), [a; c]) - ; _ } ] ) - when Expr.Typed.equal b c -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem1) - , [ { pattern= - FunApp - (StanLib ("transpose", FnPlain, mem2), [b]) - ; _ }; a ] ) - ; _ }; c ] ) - when Expr.Typed.equal b c -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) - | ( "Times__" - , [ e1' - ; { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("diag_post_multiply", suffix, lub_mem), [e1'; v]) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) - ; _ }; e2' ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("diag_pre_multiply", suffix, lub_mem), [v; e2']) - (* Constant folding for operators *) - | op, [{pattern= Lit (Int, i); _}] -> ( - match op with - | "PPlus__" | "PMinus__" | "PNot__" -> - apply_prefix_operator_int op (Int.of_string i) - | _ -> FunApp (kind, l) ) - | op, [{pattern= Lit (Real, r); _}] -> ( - match op with - | "PPlus__" | "PMinus__" -> - apply_prefix_operator_real op (Float.of_string r) - | _ -> FunApp (kind, l) ) - | ( ("Divide__" | "IntDivide__") - , [{meta= {type_= UInt; _}; _}; {pattern= Lit (Int, i2); _}] ) - when Int.of_string i2 = 0 -> - raise (Rejected (e.meta.loc, 
"Integer division by zero")) - | op, [{pattern= Lit (Int, i1); _}; {pattern= Lit (Int, i2); _}] - -> ( - match op with - | "Plus__" | "Minus__" | "Times__" | "Divide__" - |"IntDivide__" | "Modulo__" | "Or__" | "And__" | "Equals__" - |"NEquals__" | "Less__" | "Leq__" | "Greater__" | "Geq__" -> - apply_operator_int op (Int.of_string i1) - (Int.of_string i2) - | _ -> FunApp (kind, l) ) - | ( op - , [{pattern= Lit (Real, i1); _}; {pattern= Lit (Real, i2); _}] - ) - |op, [{pattern= Lit (Int, i1); _}; {pattern= Lit (Real, i2); _}] - |op, [{pattern= Lit (Real, i1); _}; {pattern= Lit (Int, i2); _}] - -> ( - match op with - | "Plus__" | "Minus__" | "Times__" | "Divide__" -> - apply_arithmetic_operator_real op (Float.of_string i1) - (Float.of_string i2) - | "Or__" | "And__" | "Equals__" | "NEquals__" | "Less__" - |"Leq__" | "Greater__" | "Geq__" -> - apply_logical_operator_real op (Float.of_string i1) - (Float.of_string i2) - | _ -> FunApp (kind, l) ) - | _ -> FunApp (kind, l) ) ) - | TernaryIf (e1, e2, e3) -> ( - match - ( eval_expr ~preserve_stability e1 - , eval_expr ~preserve_stability e2 - , eval_expr ~preserve_stability e3 ) - with - | x, _, e3' when is_int 0 x -> e3'.pattern - | {pattern= Lit (Int, _); _}, e2', _ -> e2'.pattern - | e1', e2', e3' -> TernaryIf (e1', e2', e3') ) - | EAnd (e1, e2) -> ( - match - (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) - with - | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> - let i1, i2 = (Int.of_string s1, Int.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 && i2 <> 0))) - | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> - let r1, r2 = (Float.of_string s1, Float.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. 
&& r2 <> 0.))) - | e1', e2' -> EAnd (e1', e2') ) - | EOr (e1, e2) -> ( - match - (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) - with - | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> - let i1, i2 = (Int.of_string s1, Int.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 || i2 <> 0))) - | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> - let r1, r2 = (Float.of_string s1, Float.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. || r2 <> 0.))) - | e1', e2' -> EOr (e1', e2') ) - | Indexed (e, l) -> - (* TODO: do something clever with array and matrix expressions here? - Note that we could also constant fold array sizes if we keep those around on declarations. *) - Indexed (eval_expr e, List.map ~f:(Index.map eval_expr) l) ) } - -let rec simplify_index_expr pattern = - Expr.Fixed.( - match pattern with - | Pattern.Indexed - ( { pattern= - Indexed (obj, inner_indices) - (* , Single ({emeta= {type_= UArray UInt; _} as emeta; _} as multi) - * :: inner_tl ) *) - ; meta } - , ( Single ({meta= Expr.Typed.Meta.{type_= UInt; _}; _} as single_e) as - single ) - :: outer_tl ) - when List.exists ~f:is_multi_index inner_indices -> ( - match List.split_while ~f:(Fn.non is_multi_index) inner_indices with - | inner_singles, MultiIndex first_multi :: inner_tl -> - (* foo [arr1, ..., arrN] [i1, ..., iN] -> - foo [arr1[i1]] [arr[i2]] ... 
[arrN[iN]] *) - simplify_index_expr - (Indexed - ( { pattern= - Indexed - ( obj - , inner_singles - @ [ Index.Single - { pattern= Indexed (first_multi, [single]) - ; meta= {meta with type_= UInt} } ] - @ inner_tl ) - ; meta } - , outer_tl ) ) - | inner_singles, All :: inner_tl -> - (* v[:x][i] -> v[i] *) - (* v[:][i] -> v[i] *) - (* XXX generate check *) - simplify_index_expr - (Indexed - ( { pattern= Indexed (obj, inner_singles @ [single] @ inner_tl) - ; meta } - , outer_tl ) ) - | inner_singles, Between (bot, _) :: inner_tl - |inner_singles, Upfrom bot :: inner_tl -> - (* v[x:y][z] -> v[x+z-1] *) - (* XXX generate check *) - simplify_index_expr - (Indexed - ( { pattern= - Indexed - ( obj - , inner_singles - @ [ Index.Single - Expr.Helpers.( - binop (binop bot Plus single_e) Minus - loop_bottom) ] - @ inner_tl ) - ; meta } - , outer_tl ) ) - | inner_singles, (([] | Single _ :: _) as multis) -> - Common.FatalError.fatal_error_msg - [%message - " There must be a multi-index." - (inner_singles : Expr.Typed.t Index.t list) - (multis : Expr.Typed.t Index.t list)] ) - | e -> e) - -let remove_trailing_alls_expr = function - | Expr.Fixed.Pattern.Indexed (obj, indices) -> - (* a[2][:] -> a[2] *) - let rec remove_trailing_alls indices = - match List.rev indices with - | Index.All :: tl -> remove_trailing_alls (List.rev tl) - | _ -> indices in - Expr.Fixed.Pattern.Indexed (obj, remove_trailing_alls indices) - | e -> e - -let rec simplify_indices_expr expr = - Expr.Fixed.( - let pattern = - expr.pattern |> remove_trailing_alls_expr |> simplify_index_expr - |> Expr.Fixed.Pattern.map simplify_indices_expr in - {expr with pattern}) - -let try_eval_expr expr = try eval_expr expr with Rejected _ -> expr - -let rec eval_stmt s = - try - Stmt.Fixed. 
- { s with - pattern= - Pattern.map - (Fn.compose eval_expr simplify_indices_expr) - eval_stmt s.pattern } - with Rejected (loc, m) -> - { Stmt.Fixed.pattern= - NRFunApp (CompilerInternal FnReject, [Expr.Helpers.str m]) - ; meta= loc } - -let eval_prog p : Program.Typed.t = Program.map try_eval_expr eval_stmt Fn.id p diff --git a/src/analysis_and_optimization/Pedantic_analysis.ml b/src/analysis_and_optimization/Pedantic_analysis.ml index 2f629326c1..2c3a0cb3e8 100644 --- a/src/analysis_and_optimization/Pedantic_analysis.ml +++ b/src/analysis_and_optimization/Pedantic_analysis.ml @@ -492,11 +492,15 @@ let settings_constant_prop = ; copy_propagation= true ; partial_evaluation= true } +(** Pedantic mode is only really valid for the Stan Math backend *) +module Optimizer = Optimize.Make (Stan_math_backend.Stan_math_library) + (* Collect all pedantic mode warnings, sorted, to stderr *) let warn_pedantic (mir_unopt : Program.Typed.t) = (* Some warnings will be stronger when constants are propagated *) let mir = - Optimize.optimization_suite ~settings:settings_constant_prop mir_unopt in + Optimizer.optimization_suite ~settings:settings_constant_prop mir_unopt + in (* Try to avoid recomputation by pre-building structures *) let distributions_info = list_distributions mir in let factor_graph = prog_factor_graph mir in diff --git a/src/analysis_and_optimization/dune b/src/analysis_and_optimization/dune index 8153633605..36228b6995 100644 --- a/src/analysis_and_optimization/dune +++ b/src/analysis_and_optimization/dune @@ -1,11 +1,9 @@ (library (name analysis_and_optimization) (public_name stanc.analysis) - (libraries core_kernel str fmt common middle frontend) + (libraries core_kernel str fmt common middle frontend stan_math_backend) (instrumentation (backend bisect_ppx)) (inline_tests) - ;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation - (modules_without_implementation monotone_framework_sigs) (preprocess (pps ppx_jane 
ppx_deriving.map ppx_deriving.fold))) diff --git a/src/frontend/Ast.ml b/src/frontend/Ast.ml index 253d6e2165..1758c10097 100644 --- a/src/frontend/Ast.ml +++ b/src/frontend/Ast.ml @@ -90,6 +90,15 @@ let expr_loc_lub exprs = let expr_ad_lub exprs = exprs |> List.map ~f:(fun x -> x.emeta.ad_level) |> UnsizedType.lub_ad_type +let mk_fun_app ~is_cond_dist ~loc kind name args ~type_ : typed_expression = + let fn = + if is_cond_dist then CondDistApp (kind, name, args) + else FunApp (kind, name, args) in + mk_typed_expression ~expr:fn ~loc ~type_ + ~ad_level: + ( if UnsizedType.is_int_type type_ then UnsizedType.DataOnly + else expr_ad_lub args ) + (** Assignment operators *) type assignmentoperator = | Assign diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 21eaee43ea..5789d81cbe 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -2,893 +2,1067 @@ open Core_kernel open Core_kernel.Poly open Middle -let trans_fn_kind kind name = - let fname = Utils.stdlib_distribution_name name in - match kind with - | Ast.StanLib suffix -> Fun_kind.StanLib (fname, suffix, AoS) - | UserDefined suffix -> UserDefined (fname, suffix) - -let without_underscores = String.filter ~f:(( <> ) '_') - -let drop_leading_zeros s = - match String.lfindi ~f:(fun _ c -> c <> '0') s with - | Some p when p > 0 -> ( - match s.[p] with - | 'e' | '.' -> String.drop_prefix s (p - 1) - | _ -> String.drop_prefix s p ) - | Some _ -> s - | None -> "0" - -let format_number s = s |> without_underscores |> drop_leading_zeros - -let%expect_test "format_number0" = - format_number "0_000." |> print_endline ; - [%expect "0."] - -let%expect_test "format_number1" = - format_number ".123_456" |> print_endline ; - [%expect ".123456"] - -let rec op_to_funapp op args type_ = - let loc = Ast.expr_loc_lub args in - let adlevel = Ast.expr_ad_lub args in - Expr. 
- { Fixed.pattern= - FunApp (StanLib (Operator.to_string op, FnPlain, AoS), trans_exprs args) - ; meta= Expr.Typed.Meta.create ~type_ ~adlevel ~loc () } - -and trans_expr {Ast.expr; Ast.emeta} = - let ewrap pattern = +module type AST_MIR_TRANSLATOR = sig + val gather_declarations : + Ast.typed_statement Ast.block option + -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list + + val trans_prog : string -> Ast.typed_program -> Program.Typed.t +end + +module Make (StdLibrary : Std_library_utils.Library) : AST_MIR_TRANSLATOR = +struct + let trans_fn_kind kind name = + let fname = Utils.stdlib_distribution_name name in + match kind with + | Ast.StanLib suffix -> Fun_kind.StanLib (fname, suffix, AoS) + | UserDefined suffix -> UserDefined (fname, suffix) + + let without_underscores = String.filter ~f:(( <> ) '_') + + let drop_leading_zeros s = + match String.lfindi ~f:(fun _ c -> c <> '0') s with + | Some p when p > 0 -> ( + match s.[p] with + | 'e' | '.' -> String.drop_prefix s (p - 1) + | _ -> String.drop_prefix s p ) + | Some _ -> s + | None -> "0" + + let format_number s = s |> without_underscores |> drop_leading_zeros + + let rec op_to_funapp op args type_ = + let loc = Ast.expr_loc_lub args in + let adlevel = Ast.expr_ad_lub args in + Expr. + { Fixed.pattern= + FunApp + (StanLib (Operator.to_string op, FnPlain, AoS), trans_exprs args) + ; meta= Expr.Typed.Meta.create ~type_ ~adlevel ~loc () } + + and trans_expr {Ast.expr; Ast.emeta} = + let ewrap pattern = + Expr. + { Fixed.pattern + ; meta= + Typed.Meta. 
+ {type_= emeta.Ast.type_; adlevel= emeta.ad_level; loc= emeta.loc} + } in + match expr with + | Ast.Paren x -> trans_expr x + | BinOp (lhs, And, rhs) -> EAnd (trans_expr lhs, trans_expr rhs) |> ewrap + | BinOp (lhs, Or, rhs) -> EOr (trans_expr lhs, trans_expr rhs) |> ewrap + | BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs] emeta.type_ + | PrefixOp (op, e) | Ast.PostfixOp (e, op) -> + op_to_funapp op [e] emeta.type_ + | Ast.TernaryIf (cond, ifb, elseb) -> + Expr.Fixed.Pattern.TernaryIf + (trans_expr cond, trans_expr ifb, trans_expr elseb) + |> ewrap + | Variable {name; _} -> Var name |> ewrap + | IntNumeral x -> Lit (Int, format_number x) |> ewrap + | RealNumeral x -> Lit (Real, format_number x) |> ewrap + | ImagNumeral x -> Lit (Imaginary, format_number x) |> ewrap + | FunApp (fn_kind, {name; _}, args) | CondDistApp (fn_kind, {name; _}, args) + -> + FunApp (trans_fn_kind fn_kind name, trans_exprs args) |> ewrap + | GetLP | GetTarget -> + FunApp (StanLib ("target", FnTarget, AoS), []) |> ewrap + | ArrayExpr eles -> + FunApp (CompilerInternal FnMakeArray, trans_exprs eles) |> ewrap + | RowVectorExpr eles -> + FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap + | Indexed (lhs, indices) -> + Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap + | Promotion (e, ty, ad) -> Promotion (trans_expr e, ty, ad) |> ewrap + + and trans_idx = function + | Ast.All -> All + | Ast.Upfrom e -> Upfrom (trans_expr e) + | Ast.Downfrom e -> Between (Expr.Helpers.loop_bottom, trans_expr e) + | Ast.Between (lb, ub) -> Between (trans_expr lb, trans_expr ub) + | Ast.Single e -> ( + match e.emeta.type_ with + | UInt -> Single (trans_expr e) + | UArray _ -> MultiIndex (trans_expr e) + | _ -> + Common.FatalError.fatal_error_msg + [%message "Expecting int or array" (e.emeta.type_ : UnsizedType.t)] + ) + + and trans_exprs exprs = List.map ~f:trans_expr exprs + + let trans_sizedtype = SizedType.map trans_expr + + let neg_inf = Expr. 
- { Fixed.pattern + { Fixed.pattern= FunApp (CompilerInternal FnNegInf, []) ; meta= - Typed.Meta. - {type_= emeta.Ast.type_; adlevel= emeta.ad_level; loc= emeta.loc} } - in - match expr with - | Ast.Paren x -> trans_expr x - | BinOp (lhs, And, rhs) -> EAnd (trans_expr lhs, trans_expr rhs) |> ewrap - | BinOp (lhs, Or, rhs) -> EOr (trans_expr lhs, trans_expr rhs) |> ewrap - | BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs] emeta.type_ - | PrefixOp (op, e) | Ast.PostfixOp (e, op) -> op_to_funapp op [e] emeta.type_ - | Ast.TernaryIf (cond, ifb, elseb) -> - Expr.Fixed.Pattern.TernaryIf - (trans_expr cond, trans_expr ifb, trans_expr elseb) - |> ewrap - | Variable {name; _} -> Var name |> ewrap - | IntNumeral x -> Lit (Int, format_number x) |> ewrap - | RealNumeral x -> Lit (Real, format_number x) |> ewrap - | ImagNumeral x -> Lit (Imaginary, format_number x) |> ewrap - | FunApp (fn_kind, {name; _}, args) | CondDistApp (fn_kind, {name; _}, args) - -> - FunApp (trans_fn_kind fn_kind name, trans_exprs args) |> ewrap - | GetLP | GetTarget -> FunApp (StanLib ("target", FnTarget, AoS), []) |> ewrap - | ArrayExpr eles -> - FunApp (CompilerInternal FnMakeArray, trans_exprs eles) |> ewrap - | RowVectorExpr eles -> - FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap - | Indexed (lhs, indices) -> - Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap - | Promotion (e, ty, ad) -> Promotion (trans_expr e, ty, ad) |> ewrap - -and trans_idx = function - | Ast.All -> All - | Ast.Upfrom e -> Upfrom (trans_expr e) - | Ast.Downfrom e -> Between (Expr.Helpers.loop_bottom, trans_expr e) - | Ast.Between (lb, ub) -> Between (trans_expr lb, trans_expr ub) - | Ast.Single e -> ( - match e.emeta.type_ with - | UInt -> Single (trans_expr e) - | UArray _ -> MultiIndex (trans_expr e) - | _ -> - Common.FatalError.fatal_error_msg - [%message "Expecting int or array" (e.emeta.type_ : UnsizedType.t)] ) - -and trans_exprs exprs = List.map ~f:trans_expr exprs - -let 
trans_sizedtype = SizedType.map trans_expr - -let neg_inf = - Expr. - { Fixed.pattern= FunApp (CompilerInternal FnNegInf, []) - ; meta= - Typed.Meta.{type_= UReal; loc= Location_span.empty; adlevel= DataOnly} - } - -let trans_arg (adtype, ut, ident) = (adtype, ident.Ast.name, ut) - -let truncate_dist ud_dists (id : Ast.identifier) - (ast_obs : Ast.typed_expression) ast_args t = - let cdf_suffices = ["_lcdf"; "_cdf_log"] in - let ccdf_suffices = ["_lccdf"; "_ccdf_log"] in - let find_function_info sfx = - let possible_names = List.map ~f:(( ^ ) id.name) sfx in - match - List.find - ~f:(fun (n, _) -> List.mem ~equal:String.equal possible_names n) - ud_dists - with - | Some (name, tp) -> (Ast.UserDefined FnPlain, name, tp) - | None -> - ( Ast.StanLib FnPlain - , List.hd_exn possible_names - , if Stan_math_signatures.is_stan_math_function_name (id.name ^ "_lpmf") - then UnsizedType.UInt - else UnsizedType.UReal (* close enough *) ) in - let targetme loc e = - { Stmt.Fixed.meta= loc - ; pattern= TargetPE (Expr.Helpers.unary_op Operator.PMinus e) } in - let trunc cond_op extrema (x : Expr.Typed.t) y = - let smeta = x.meta.loc in - let ast_obs = - if UnsizedType.is_container ast_obs.Ast.emeta.type_ then - Ast.mk_typed_expression - ~expr: - (FunApp - ( Ast.StanLib FnPlain - , Ast.{name= extrema; id_loc= smeta} - , [ast_obs] ) ) - ~loc:smeta ~type_:UnsizedType.UReal ~ad_level:ast_obs.emeta.ad_level - else ast_obs in - { Stmt.Fixed.meta= smeta - ; pattern= - IfElse - ( Expr.Helpers.binop (trans_expr ast_obs) cond_op x - , {Stmt.Fixed.meta= smeta; pattern= TargetPE neg_inf} - , Some y ) } in - let funapp meta kind name args = - Expr.{Fixed.pattern= FunApp (trans_fn_kind kind name, args); meta} in - let inclusive_bound tp (lb : Expr.Typed.t) = - if UnsizedType.is_int_type tp then - Expr.Helpers.binop lb Minus Expr.Helpers.one - else lb in - let size_adjust e = - if - (not (UnsizedType.is_container ast_obs.Ast.emeta.type_)) - || List.exists - ~f:(fun (i : Ast.typed_expression) 
-> - UnsizedType.is_container i.emeta.type_ ) - ast_args - then e - else - (* Container y but scalar args - need to multiply by size(y) *) - let trans_ast_obs = trans_expr ast_obs in - let type_ = {trans_ast_obs.meta with type_= UnsizedType.UReal} in - Expr.Helpers.binop e Times - (Expr.Helpers.internal_funapp FnLength [trans_ast_obs] type_) in - match t with - | Ast.NoTruncate -> [] - | TruncateUpFrom lb -> - let fk, fn, tp = find_function_info ccdf_suffices in - let lb = trans_expr lb in - [ trunc Less "min" lb - (targetme lb.meta.loc - (size_adjust - (funapp lb.meta fk fn - (inclusive_bound tp lb :: trans_exprs ast_args) ) ) ) ] - | TruncateDownFrom ub -> - let fk, fn, _ = find_function_info cdf_suffices in - let ub = trans_expr ub in - [ trunc Greater "max" ub - (targetme ub.meta.loc - (size_adjust (funapp ub.meta fk fn (ub :: trans_exprs ast_args))) ) - ] - | TruncateBetween (lb, ub) -> - let fk, fn, tp = find_function_info cdf_suffices in - let lb, ub = (trans_expr lb, trans_expr ub) in - let expr args = - funapp ub.meta (Ast.StanLib FnPlain) "log_diff_exp" - [ funapp ub.meta fk fn (ub :: args) - ; funapp ub.meta fk fn (inclusive_bound tp lb :: args) ] in - let statement = - match - List.findi - ~f:(fun (_ : int) (e : Ast.typed_expression) -> - UnsizedType.is_container e.emeta.type_ ) - ast_args - with - (* If any of the arguments (besides the data) are vectors, need to generate a loop - This can go away if https://github.com/stan-dev/stan/issues/1154 is implemented - *) - | Some (i, _) -> - let ast_args = trans_exprs ast_args in - (* avoid recomputing in each iteration of the loop *) - let temp_decls, ast_args, symbol_reset = - Stmt.Helpers.temp_vars ast_args in - let bound = - let e = List.nth_exn ast_args i in - Expr.Helpers.internal_funapp FnLength [e] - {e.meta with type_= UnsizedType.UInt} in - let bodyfn (idx : Expr.Typed.t) = - let args = - List.map - ~f:(fun (e : Expr.Typed.t) -> - if UnsizedType.is_container e.meta.type_ then - 
Expr.Helpers.add_int_index e (Index.Single idx) - else e ) - ast_args in - targetme ub.meta.loc (size_adjust (expr args)) in - let loop = Stmt.Helpers.mk_for bound bodyfn ub.meta.loc in - symbol_reset () ; - Stmt.{Fixed.pattern= Block (temp_decls @ [loop]); meta= loop.meta} - | None -> - targetme ub.meta.loc (size_adjust (expr (trans_exprs ast_args))) - in - [trunc Less "min" lb (trunc Greater "max" ub statement)] - -let unquote s = - if s.[0] = '"' && s.[String.length s - 1] = '"' then - String.drop_suffix (String.drop_prefix s 1) 1 - else s - -let trans_printables mloc (ps : Ast.typed_expression Ast.printable list) = - List.map - ~f:(function - | Ast.PString s -> - { (Expr.Helpers.str (unquote s)) with - meta= - Expr.Typed.Meta.create ~type_:UReal ~loc:mloc ~adlevel:DataOnly () - } - | Ast.PExpr e -> trans_expr e ) - ps - -(** These types signal the context for a declaration during statement translation. + Typed.Meta.{type_= UReal; loc= Location_span.empty; adlevel= DataOnly} + } + + let trans_arg (adtype, ut, ident) = (adtype, ident.Ast.name, ut) + + let truncate_dist ud_dists (id : Ast.identifier) + (ast_obs : Ast.typed_expression) ast_args t = + let cdf_suffices = ["_lcdf"; "_cdf_log"] in + let ccdf_suffices = ["_lccdf"; "_ccdf_log"] in + let find_function_info sfx = + let possible_names = List.map ~f:(( ^ ) id.name) sfx in + match + List.find + ~f:(fun (n, _) -> List.mem ~equal:String.equal possible_names n) + ud_dists + with + | Some (name, tp) -> (Ast.UserDefined FnPlain, name, tp) + | None -> + ( Ast.StanLib FnPlain + , List.hd_exn possible_names + , if StdLibrary.is_stdlib_function_name (id.name ^ "_lpmf") then + UnsizedType.UInt + else UnsizedType.UReal (* close enough *) ) in + let targetme loc e = + { Stmt.Fixed.meta= loc + ; pattern= TargetPE (Expr.Helpers.unary_op Operator.PMinus e) } in + let trunc cond_op extrema (x : Expr.Typed.t) y = + let smeta = x.meta.loc in + let ast_obs = + if UnsizedType.is_container ast_obs.Ast.emeta.type_ then + 
Ast.mk_typed_expression + ~expr: + (FunApp + ( Ast.StanLib FnPlain + , Ast.{name= extrema; id_loc= smeta} + , [ast_obs] ) ) + ~loc:smeta ~type_:UnsizedType.UReal ~ad_level:ast_obs.emeta.ad_level + else ast_obs in + { Stmt.Fixed.meta= smeta + ; pattern= + IfElse + ( Expr.Helpers.binop (trans_expr ast_obs) cond_op x + , {Stmt.Fixed.meta= smeta; pattern= TargetPE neg_inf} + , Some y ) } in + let funapp meta kind name args = + Expr.{Fixed.pattern= FunApp (trans_fn_kind kind name, args); meta} in + let inclusive_bound tp (lb : Expr.Typed.t) = + if UnsizedType.is_int_type tp then + Expr.Helpers.binop lb Minus Expr.Helpers.one + else lb in + let size_adjust e = + if + (not (UnsizedType.is_container ast_obs.Ast.emeta.type_)) + || List.exists + ~f:(fun (i : Ast.typed_expression) -> + UnsizedType.is_container i.emeta.type_ ) + ast_args + then e + else + (* Container y but scalar args - need to multiply by size(y) *) + let trans_ast_obs = trans_expr ast_obs in + let type_ = {trans_ast_obs.meta with type_= UnsizedType.UReal} in + Expr.Helpers.binop e Times + (Expr.Helpers.internal_funapp FnLength [trans_ast_obs] type_) in + match t with + | Ast.NoTruncate -> [] + | TruncateUpFrom lb -> + let fk, fn, tp = find_function_info ccdf_suffices in + let lb = trans_expr lb in + [ trunc Less "min" lb + (targetme lb.meta.loc + (size_adjust + (funapp lb.meta fk fn + (inclusive_bound tp lb :: trans_exprs ast_args) ) ) ) ] + | TruncateDownFrom ub -> + let fk, fn, _ = find_function_info cdf_suffices in + let ub = trans_expr ub in + [ trunc Greater "max" ub + (targetme ub.meta.loc + (size_adjust (funapp ub.meta fk fn (ub :: trans_exprs ast_args))) ) + ] + | TruncateBetween (lb, ub) -> + let fk, fn, tp = find_function_info cdf_suffices in + let lb, ub = (trans_expr lb, trans_expr ub) in + let expr args = + funapp ub.meta (Ast.StanLib FnPlain) "log_diff_exp" + [ funapp ub.meta fk fn (ub :: args) + ; funapp ub.meta fk fn (inclusive_bound tp lb :: args) ] in + let statement = + match + List.findi 
+ ~f:(fun (_ : int) (e : Ast.typed_expression) -> + UnsizedType.is_container e.emeta.type_ ) + ast_args + with + (* If any of the arguments (besides the data) are vectors, need to generate a loop + This can go away if https://github.com/stan-dev/stan/issues/1154 is implemented + *) + | Some (i, _) -> + let ast_args = trans_exprs ast_args in + (* avoid recomputing in each iteration of the loop *) + let temp_decls, ast_args, symbol_reset = + Stmt.Helpers.temp_vars ast_args in + let bound = + let e = List.nth_exn ast_args i in + Expr.Helpers.internal_funapp FnLength [e] + {e.meta with type_= UnsizedType.UInt} in + let bodyfn (idx : Expr.Typed.t) = + let args = + List.map + ~f:(fun (e : Expr.Typed.t) -> + if UnsizedType.is_container e.meta.type_ then + Expr.Helpers.add_int_index e (Index.Single idx) + else e ) + ast_args in + targetme ub.meta.loc (size_adjust (expr args)) in + let loop = Stmt.Helpers.mk_for bound bodyfn ub.meta.loc in + symbol_reset () ; + Stmt.{Fixed.pattern= Block (temp_decls @ [loop]); meta= loop.meta} + | None -> + targetme ub.meta.loc (size_adjust (expr (trans_exprs ast_args))) + in + [trunc Less "min" lb (trunc Greater "max" ub statement)] + + let unquote s = + if s.[0] = '"' && s.[String.length s - 1] = '"' then + String.drop_suffix (String.drop_prefix s 1) 1 + else s + + let trans_printables mloc (ps : Ast.typed_expression Ast.printable list) = + List.map + ~f:(function + | Ast.PString s -> + { (Expr.Helpers.str (unquote s)) with + meta= + Expr.Typed.Meta.create ~type_:UReal ~loc:mloc ~adlevel:DataOnly + () } + | Ast.PExpr e -> trans_expr e ) + ps + + (** These types signal the context for a declaration during statement translation. 
They are only interpreted by trans_decl.*) -type transform_action = Check | Constrain | Unconstrain | IgnoreTransform -[@@deriving sexp] + type transform_action = Check | Constrain | Unconstrain | IgnoreTransform + [@@deriving sexp] + + type decl_context = + {transform_action: transform_action; dadlevel: UnsizedType.autodifftype} + + let same_shape decl_id decl_var id var meta = + if UnsizedType.is_scalar_type (Expr.Typed.type_of var) then [] + else + [ Stmt. + { Fixed.pattern= + NRFunApp + ( StanLib ("check_matching_dims", FnPlain, AoS) + , Expr.Helpers. + [str "constraint"; str decl_id; decl_var; str id; var] ) + ; meta } ] + + let check_transform_shape decl_id decl_var meta = function + | Transformation.Offset e -> same_shape decl_id decl_var "offset" e meta + | Multiplier e -> same_shape decl_id decl_var "multiplier" e meta + | Lower e -> same_shape decl_id decl_var "lower" e meta + | Upper e -> same_shape decl_id decl_var "upper" e meta + | OffsetMultiplier (e1, e2) -> + same_shape decl_id decl_var "offset" e1 meta + @ same_shape decl_id decl_var "multiplier" e2 meta + | LowerUpper (e1, e2) -> + same_shape decl_id decl_var "lower" e1 meta + @ same_shape decl_id decl_var "upper" e2 meta + | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered + |PositiveOrdered | Simplex | UnitVector | Identity -> + [] + + let copy_indices indexed (var : Expr.Typed.t) = + if UnsizedType.is_scalar_type var.meta.type_ then var + else + match Expr.Helpers.collect_indices indexed with + | [] -> var + | indices -> + Expr.Fixed. 
+ { pattern= Indexed (var, indices) + ; meta= + { var.meta with + type_= + Expr.Helpers.infer_type_of_indexed var.meta.type_ indices } + } + + let extract_transform_args var = function + | Transformation.Lower a | Upper a -> [copy_indices var a] + | Offset a -> + [copy_indices var a; {a with Expr.Fixed.pattern= Lit (Int, "1")}] + | Multiplier a -> [{a with pattern= Lit (Int, "0")}; copy_indices var a] + | LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) -> + [copy_indices var a1; copy_indices var a2] + | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered + |PositiveOrdered | Simplex | UnitVector | Identity -> + [] -type decl_context = - {transform_action: transform_action; dadlevel: UnsizedType.autodifftype} + let param_size transform sizedtype = + let rec shrink_eigen f st = + match st with + | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen f t, d) + | SVector (mem_pattern, d) | SMatrix (mem_pattern, d, _) -> + SVector (mem_pattern, f d) + | SInt | SReal | SComplex | SRowVector _ | SComplexRowVector _ + |SComplexVector _ | SComplexMatrix _ -> + Common.FatalError.fatal_error_msg + [%message + "Expecting SVector or SMatrix, got " + (st : Expr.Typed.t SizedType.t)] in + let rec shrink_eigen_mat f st = + match st with + | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen_mat f t, d) + | SMatrix (mem_pattern, d1, d2) -> SVector (mem_pattern, f d1 d2) + | SInt | SReal | SComplex | SRowVector _ | SVector _ + |SComplexRowVector _ | SComplexVector _ | SComplexMatrix _ -> + Common.FatalError.fatal_error_msg + [%message "Expecting SMatrix, got " (st : Expr.Typed.t SizedType.t)] + in + let k_choose_2 k = + Expr.Helpers.( + binop (binop k Times (binop k Minus (int 1))) Divide (int 2)) in + match transform with + | Transformation.Identity | Lower _ | Upper _ + |LowerUpper (_, _) + |Offset _ | Multiplier _ + |OffsetMultiplier (_, _) + |Ordered | PositiveOrdered | UnitVector -> + sizedtype + | Simplex -> + shrink_eigen (fun d -> Expr.Helpers.(binop 
d Minus (int 1))) sizedtype + | CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype + | CholeskyCov -> + (* (N * (N + 1)) / 2 + (M - N) * N *) + shrink_eigen_mat + (fun m n -> + Expr.Helpers.( + binop + (binop (k_choose_2 n) Plus n) + Plus + (binop (binop m Minus n) Times n)) ) + sizedtype + | Covariance -> + shrink_eigen + (fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k))) + sizedtype + + let rec check_decl var decl_type' decl_id decl_trans smeta adlevel = + match decl_trans with + | Transformation.LowerUpper (lb, ub) -> + check_decl var decl_type' decl_id (Lower lb) smeta adlevel + @ check_decl var decl_type' decl_id (Upper ub) smeta adlevel + | _ when Transformation.has_check decl_trans -> + let check_id id = + let var_name = Fmt.str "%a" Expr.Typed.pp id in + let args = extract_transform_args id decl_trans in + Stmt.Helpers.internal_nrfunapp + (FnCheck {trans= decl_trans; var_name; var= id}) + args smeta in + [check_id var] + | _ -> [] -let same_shape decl_id decl_var id var meta = - if UnsizedType.is_scalar_type (Expr.Typed.type_of var) then [] - else - [ Stmt. + let check_sizedtype name st = + let check x = function + | {Expr.Fixed.pattern= Lit (Int, i); _} when float_of_string i >= 0. -> [] + | n -> + [ Stmt.Helpers.internal_nrfunapp FnValidateSize + Expr.Helpers. 
+ [ str name + ; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n ] + n.meta.loc ] in + let rec sizedtype = function + | SizedType.(SInt | SReal | SComplex) as t -> ([], t) + | SVector (mem_pattern, s) -> + let e = trans_expr s in + (check s e, SizedType.SVector (mem_pattern, e)) + | SRowVector (mem_pattern, s) -> + let e = trans_expr s in + (check s e, SizedType.SRowVector (mem_pattern, e)) + | SMatrix (mem_pattern, r, c) -> + let er = trans_expr r in + let ec = trans_expr c in + (check r er @ check c ec, SizedType.SMatrix (mem_pattern, er, ec)) + | SComplexVector s -> + let e = trans_expr s in + (check s e, SizedType.SComplexVector e) + | SComplexRowVector s -> + let e = trans_expr s in + (check s e, SizedType.SComplexRowVector e) + | SComplexMatrix (r, c) -> + let er = trans_expr r in + let ec = trans_expr c in + (check r er @ check c ec, SizedType.SComplexMatrix (er, ec)) + | SArray (t, s) -> + let e = trans_expr s in + let ll, t = sizedtype t in + (check s e @ ll, SizedType.SArray (t, e)) in + let ll, st = sizedtype st in + (ll, Type.Sized st) + + let trans_decl {transform_action; dadlevel} smeta + (decl_type : Ast.typed_expression SizedType.t) transform identifier + initial_value = + let decl_id = identifier.Ast.name in + let rhs = Option.map ~f:trans_expr initial_value in + let size_checks, dt = check_sizedtype identifier.name decl_type in + let decl_adtype = dadlevel in + let decl_var = + Expr. + { Fixed.pattern= Var decl_id + ; meta= + Typed.Meta.create ~adlevel:dadlevel ~loc:smeta + ~type_:(SizedType.to_unsized decl_type) + () } in + let decl = + Stmt. { Fixed.pattern= - NRFunApp - ( StanLib ("check_matching_dims", FnPlain, AoS) - , Expr.Helpers. 
- [str "constraint"; str decl_id; decl_var; str id; var] ) - ; meta } ] - -let check_transform_shape decl_id decl_var meta = function - | Transformation.Offset e -> same_shape decl_id decl_var "offset" e meta - | Multiplier e -> same_shape decl_id decl_var "multiplier" e meta - | Lower e -> same_shape decl_id decl_var "lower" e meta - | Upper e -> same_shape decl_id decl_var "upper" e meta - | OffsetMultiplier (e1, e2) -> - same_shape decl_id decl_var "offset" e1 meta - @ same_shape decl_id decl_var "multiplier" e2 meta - | LowerUpper (e1, e2) -> - same_shape decl_id decl_var "lower" e1 meta - @ same_shape decl_id decl_var "upper" e2 meta - | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered - |PositiveOrdered | Simplex | UnitVector | Identity -> - [] - -let copy_indices indexed (var : Expr.Typed.t) = - if UnsizedType.is_scalar_type var.meta.type_ then var - else - match Expr.Helpers.collect_indices indexed with - | [] -> var - | indices -> - Expr.Fixed. - { pattern= Indexed (var, indices) - ; meta= - { var.meta with - type_= Expr.Helpers.infer_type_of_indexed var.meta.type_ indices - } } - -let extract_transform_args var = function - | Transformation.Lower a | Upper a -> [copy_indices var a] - | Offset a -> [copy_indices var a; {a with Expr.Fixed.pattern= Lit (Int, "1")}] - | Multiplier a -> [{a with pattern= Lit (Int, "0")}; copy_indices var a] - | LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) -> - [copy_indices var a1; copy_indices var a2] - | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered - |PositiveOrdered | Simplex | UnitVector | Identity -> - [] - -let param_size transform sizedtype = - let rec shrink_eigen f st = - match st with - | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen f t, d) - | SVector (mem_pattern, d) | SMatrix (mem_pattern, d, _) -> - SVector (mem_pattern, f d) - | SInt | SReal | SComplex | SRowVector _ | SComplexRowVector _ - |SComplexVector _ | SComplexMatrix _ -> + Decl {decl_adtype; decl_id; 
decl_type= dt; initialize= true} + ; meta= smeta } in + let rhs_assignment = + Option.map + ~f:(fun e -> + Stmt.Fixed. + {pattern= Assignment ((decl_id, e.meta.type_, []), e); meta= smeta} + ) + rhs + |> Option.to_list in + if Utils.is_user_ident decl_id then + let constrain_checks = + match transform_action with + | Constrain | Unconstrain -> + Common.FatalError.fatal_error_msg + [%message "Constraints must use trans_sizedtype_decl instead"] + | Check -> + check_transform_shape decl_id decl_var smeta transform + @ check_decl decl_var dt decl_id transform smeta dadlevel + | IgnoreTransform -> [] in + size_checks @ (decl :: rhs_assignment) @ constrain_checks + else size_checks @ (decl :: rhs_assignment) + + let unwrap_block_or_skip = function + | [({Stmt.Fixed.pattern= Block _; _} as b)] -> Some b + | [{pattern= Skip; _}] -> None + | x -> + Common.FatalError.fatal_error_msg + [%message "Expecting a block or skip, not" (x : Stmt.Located.t list)] + + let stmt_contains_check stmt = + let is_check = function + | Fun_kind.CompilerInternal (Internal_fun.FnCheck _) -> true + | _ -> false in + Stmt.Helpers.contains_fn_kind is_check stmt + + let migrate_checks_to_end_of_block stmts = + let checks, not_checks = List.partition_tf ~f:stmt_contains_check stmts in + not_checks @ checks + + let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) + = + let stmt_typed = ts.stmt and smeta = ts.smeta.loc in + let trans_stmt = + trans_stmt ud_dists {declc with transform_action= IgnoreTransform} in + let trans_single_stmt s = + match trans_stmt s with + | [s] -> s + | s -> Stmt.Fixed.{pattern= SList s; meta= smeta} in + let swrap pattern = [Stmt.Fixed.{meta= smeta; pattern}] in + let mloc = smeta in + match stmt_typed with + | Ast.Assignment {assign_lhs; assign_rhs; assign_op} -> + let rec get_lhs_base = function + | {Ast.lval= Ast.LIndexed (l, _); _} -> get_lhs_base l + | {lval= LVariable s; lmeta} -> (s, lmeta) in + let assign_identifier, lmeta = get_lhs_base 
assign_lhs in + let id_ad_level = lmeta.Ast.ad_level in + let id_type_ = lmeta.Ast.type_ in + let lhs_type_ = assign_lhs.Ast.lmeta.type_ in + let lhs_ad_level = assign_lhs.Ast.lmeta.ad_level in + let rec get_lhs_indices = function + | {Ast.lval= Ast.LIndexed (l, i); _} -> get_lhs_indices l @ i + | {Ast.lval= Ast.LVariable _; _} -> [] in + let assign_indices = get_lhs_indices assign_lhs in + let assignee = + { Ast.expr= + ( match assign_indices with + | [] -> Ast.Variable assign_identifier + | _ -> + Ast.Indexed + ( { expr= Ast.Variable assign_identifier + ; emeta= + { Ast.loc= Location_span.empty + ; ad_level= id_ad_level + ; type_= id_type_ } } + , assign_indices ) ) + ; emeta= + { Ast.loc= assign_lhs.lmeta.loc + ; ad_level= lhs_ad_level + ; type_= lhs_type_ } } in + let rhs = + match assign_op with + | Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs + | Ast.OperatorAssign op -> + op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in + Assignment + ( ( assign_identifier.Ast.name + , id_type_ + , List.map ~f:trans_idx assign_indices ) + , rhs ) + |> swrap + | Ast.NRFunApp (fn_kind, {name; _}, args) -> + NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap + | Ast.IncrementLogProb e | Ast.TargetPE e -> + TargetPE (trans_expr e) |> swrap + | Ast.Tilde {arg; distribution; args; truncation} -> + let suffix = + Std_library_utils.dist_name_suffix + (module StdLibrary) + ud_dists distribution.name in + let name = distribution.name ^ suffix in + let kind = + let possible_names = + List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices + |> String.Set.of_list in + if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists + then Fun_kind.UserDefined (name, FnLpdf true) + else StanLib (name, FnLpdf true, AoS) in + let add_dist = + Stmt.Fixed.Pattern.TargetPE + Expr. 
+ { Fixed.pattern= FunApp (kind, trans_exprs (arg :: args)) + ; meta= + Typed.Meta.create ~type_:UReal ~loc:mloc + ~adlevel:(Ast.expr_ad_lub (arg :: args)) + () } in + truncate_dist ud_dists distribution arg args truncation @ swrap add_dist + | Ast.Print ps -> + NRFunApp (CompilerInternal FnPrint, trans_printables smeta ps) |> swrap + | Ast.Reject ps -> + NRFunApp (CompilerInternal FnReject, trans_printables smeta ps) |> swrap + | Ast.IfThenElse (cond, ifb, elseb) -> + IfElse + ( trans_expr cond + , trans_single_stmt ifb + , Option.map ~f:trans_single_stmt elseb ) + |> swrap + | Ast.While (cond, body) -> + While (trans_expr cond, trans_single_stmt body) |> swrap + | Ast.For {loop_variable; lower_bound; upper_bound; loop_body} -> + let body = + match trans_single_stmt loop_body with + | {pattern= Block _; _} as b -> b + | x -> {x with pattern= Block [x]} in + For + { loopvar= loop_variable.Ast.name + ; lower= trans_expr lower_bound + ; upper= trans_expr upper_bound + ; body } + |> swrap + | Ast.ForEach (loopvar, iteratee, body) -> + let iteratee' = trans_expr iteratee in + let body_stmts = + match trans_single_stmt body with + | {pattern= Block body_stmts; _} -> body_stmts + | b -> [b] in + let decl_type = + match Expr.Typed.type_of iteratee' with + | UMatrix -> UnsizedType.UReal + | t -> + Expr.Helpers.(infer_type_of_indexed t [Index.Single loop_bottom]) + in + let decl_loopvar = + Stmt.Fixed. + { meta= smeta + ; pattern= + Decl + { decl_adtype= Expr.Typed.adlevel_of iteratee' + ; decl_id= loopvar.name + ; decl_type= Unsized decl_type + ; initialize= true } } in + let assignment var = + Stmt.Fixed. + { pattern= Assignment ((loopvar.name, decl_type, []), var) + ; meta= smeta } in + let bodyfn var = + Stmt.Fixed. 
+ { pattern= Block (decl_loopvar :: assignment var :: body_stmts) + ; meta= smeta } in + Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta] + | Ast.FunDef _ -> Common.FatalError.fatal_error_msg [%message - "Expecting SVector or SMatrix, got " (st : Expr.Typed.t SizedType.t)] - in - let rec shrink_eigen_mat f st = - match st with - | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen_mat f t, d) - | SMatrix (mem_pattern, d1, d2) -> SVector (mem_pattern, f d1 d2) - | SInt | SReal | SComplex | SRowVector _ | SVector _ | SComplexRowVector _ - |SComplexVector _ | SComplexMatrix _ -> + "Found function definition statement outside of function block"] + | Ast.VarDecl {decl_type; transformation; variables; is_global= _} -> + List.concat_map + ~f:(fun {identifier; initial_value} -> + trans_decl declc smeta decl_type + (Transformation.map trans_expr transformation) + identifier initial_value ) + variables + | Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap + | Ast.Profile (name, stmts) -> + Profile (name, List.concat_map ~f:trans_stmt stmts) |> swrap + | Ast.Return e -> Return (Some (trans_expr e)) |> swrap + | Ast.ReturnVoid -> Return None |> swrap + | Ast.Break -> Break |> swrap + | Ast.Continue -> Continue |> swrap + | Ast.Skip -> Skip |> swrap + + let trans_fun_def ud_dists (ts : Ast.typed_statement) = + match ts.stmt with + | Ast.FunDef {returntype; funname; arguments; body} -> + [ Program. 
+ { fdrt= returntype + ; fdname= funname.name + ; fdsuffix= + Fun_kind.(suffix_from_name funname.name |> without_propto) + ; fdargs= List.map ~f:trans_arg arguments + ; fdbody= + trans_stmt ud_dists + {transform_action= IgnoreTransform; dadlevel= AutoDiffable} + body + |> unwrap_block_or_skip + ; fdloc= ts.smeta.loc } ] + | _ -> Common.FatalError.fatal_error_msg - [%message "Expecting SMatrix, got " (st : Expr.Typed.t SizedType.t)] - in - let k_choose_2 k = - Expr.Helpers.(binop (binop k Times (binop k Minus (int 1))) Divide (int 2)) - in - match transform with - | Transformation.Identity | Lower _ | Upper _ - |LowerUpper (_, _) - |Offset _ | Multiplier _ - |OffsetMultiplier (_, _) - |Ordered | PositiveOrdered | UnitVector -> - sizedtype - | Simplex -> - shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype - | CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype - | CholeskyCov -> - (* (N * (N + 1)) / 2 + (M - N) * N *) - shrink_eigen_mat - (fun m n -> - Expr.Helpers.( - binop - (binop (k_choose_2 n) Plus n) - Plus - (binop (binop m Minus n) Times n)) ) - sizedtype - | Covariance -> - shrink_eigen - (fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k))) - sizedtype + [%message "Found non-function definition statement in function block"] -let rec check_decl var decl_type' decl_id decl_trans smeta adlevel = - match decl_trans with - | Transformation.LowerUpper (lb, ub) -> - check_decl var decl_type' decl_id (Lower lb) smeta adlevel - @ check_decl var decl_type' decl_id (Upper ub) smeta adlevel - | _ when Transformation.has_check decl_trans -> - let check_id id = - let var_name = Fmt.str "%a" Expr.Typed.pp id in - let args = extract_transform_args id decl_trans in - Stmt.Helpers.internal_nrfunapp - (FnCheck {trans= decl_trans; var_name; var= id}) - args smeta in - [check_id var] - | _ -> [] - -let check_sizedtype name st = - let check x = function - | {Expr.Fixed.pattern= Lit (Int, i); _} when float_of_string i >= 0. 
-> [] - | n -> - [ Stmt.Helpers.internal_nrfunapp FnValidateSize - Expr.Helpers. - [ str name - ; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n ] - n.meta.loc ] in - let rec sizedtype = function - | SizedType.(SInt | SReal | SComplex) as t -> ([], t) - | SVector (mem_pattern, s) -> - let e = trans_expr s in - (check s e, SizedType.SVector (mem_pattern, e)) - | SRowVector (mem_pattern, s) -> - let e = trans_expr s in - (check s e, SizedType.SRowVector (mem_pattern, e)) - | SMatrix (mem_pattern, r, c) -> - let er = trans_expr r in - let ec = trans_expr c in - (check r er @ check c ec, SizedType.SMatrix (mem_pattern, er, ec)) - | SComplexVector s -> - let e = trans_expr s in - (check s e, SizedType.SComplexVector e) - | SComplexRowVector s -> - let e = trans_expr s in - (check s e, SizedType.SComplexRowVector e) - | SComplexMatrix (r, c) -> - let er = trans_expr r in - let ec = trans_expr c in - (check r er @ check c ec, SizedType.SComplexMatrix (er, ec)) - | SArray (t, s) -> - let e = trans_expr s in - let ll, t = sizedtype t in - (check s e @ ll, SizedType.SArray (t, e)) in - let ll, st = sizedtype st in - (ll, Type.Sized st) - -let trans_decl {transform_action; dadlevel} smeta - (decl_type : Ast.typed_expression SizedType.t) transform identifier - initial_value = - let decl_id = identifier.Ast.name in - let rhs = Option.map ~f:trans_expr initial_value in - let size_checks, dt = check_sizedtype identifier.name decl_type in - let decl_adtype = dadlevel in - let decl_var = - Expr. - { Fixed.pattern= Var decl_id - ; meta= - Typed.Meta.create ~adlevel:dadlevel ~loc:smeta - ~type_:(SizedType.to_unsized decl_type) - () } in - let decl = - Stmt. - { Fixed.pattern= - Decl {decl_adtype; decl_id; decl_type= dt; initialize= true} - ; meta= smeta } in - let rhs_assignment = - Option.map - ~f:(fun e -> - Stmt.Fixed. 
- {pattern= Assignment ((decl_id, e.meta.type_, []), e); meta= smeta} ) - rhs - |> Option.to_list in - if Utils.is_user_ident decl_id then - let constrain_checks = - match transform_action with - | Constrain | Unconstrain -> - Common.FatalError.fatal_error_msg - [%message "Constraints must use trans_sizedtype_decl instead"] - | Check -> - check_transform_shape decl_id decl_var smeta transform - @ check_decl decl_var dt decl_id transform smeta dadlevel - | IgnoreTransform -> [] in - size_checks @ (decl :: rhs_assignment) @ constrain_checks - else size_checks @ (decl :: rhs_assignment) - -let unwrap_block_or_skip = function - | [({Stmt.Fixed.pattern= Block _; _} as b)] -> Some b - | [{pattern= Skip; _}] -> None - | x -> - Common.FatalError.fatal_error_msg - [%message "Expecting a block or skip, not" (x : Stmt.Located.t list)] - -let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) = - let stmt_typed = ts.stmt and smeta = ts.smeta.loc in - let trans_stmt = - trans_stmt ud_dists {declc with transform_action= IgnoreTransform} in - let trans_single_stmt s = - match trans_stmt s with - | [s] -> s - | s -> Stmt.Fixed.{pattern= SList s; meta= smeta} in - let swrap pattern = [Stmt.Fixed.{meta= smeta; pattern}] in - let mloc = smeta in - match stmt_typed with - | Ast.Assignment {assign_lhs; assign_rhs; assign_op} -> - let rec get_lhs_base = function - | {Ast.lval= Ast.LIndexed (l, _); _} -> get_lhs_base l - | {lval= LVariable s; lmeta} -> (s, lmeta) in - let assign_identifier, lmeta = get_lhs_base assign_lhs in - let id_ad_level = lmeta.Ast.ad_level in - let id_type_ = lmeta.Ast.type_ in - let lhs_type_ = assign_lhs.Ast.lmeta.type_ in - let lhs_ad_level = assign_lhs.Ast.lmeta.ad_level in - let rec get_lhs_indices = function - | {Ast.lval= Ast.LIndexed (l, i); _} -> get_lhs_indices l @ i - | {Ast.lval= Ast.LVariable _; _} -> [] in - let assign_indices = get_lhs_indices assign_lhs in - let assignee = - { Ast.expr= - ( match assign_indices with - | [] -> 
Ast.Variable assign_identifier - | _ -> - Ast.Indexed - ( { expr= Ast.Variable assign_identifier - ; emeta= - { Ast.loc= Location_span.empty - ; ad_level= id_ad_level - ; type_= id_type_ } } - , assign_indices ) ) - ; emeta= - { Ast.loc= assign_lhs.lmeta.loc - ; ad_level= lhs_ad_level - ; type_= lhs_type_ } } in - let rhs = - match assign_op with - | Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs - | Ast.OperatorAssign op -> - op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in - Assignment - ( ( assign_identifier.Ast.name - , id_type_ - , List.map ~f:trans_idx assign_indices ) - , rhs ) - |> swrap - | Ast.NRFunApp (fn_kind, {name; _}, args) -> - NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap - | Ast.IncrementLogProb e | Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap - | Ast.Tilde {arg; distribution; args; truncation} -> - let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name in - let name = distribution.name ^ suffix in - let kind = - let possible_names = - List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices - |> String.Set.of_list in - if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists then - Fun_kind.UserDefined (name, FnLpdf true) - else StanLib (name, FnLpdf true, AoS) in - let add_dist = - Stmt.Fixed.Pattern.TargetPE - Expr. 
- { Fixed.pattern= FunApp (kind, trans_exprs (arg :: args)) - ; meta= - Typed.Meta.create ~type_:UReal ~loc:mloc - ~adlevel:(Ast.expr_ad_lub (arg :: args)) - () } in - swrap add_dist @ truncate_dist ud_dists distribution arg args truncation - | Ast.Print ps -> - NRFunApp (CompilerInternal FnPrint, trans_printables smeta ps) |> swrap - | Ast.Reject ps -> - NRFunApp (CompilerInternal FnReject, trans_printables smeta ps) |> swrap - | Ast.IfThenElse (cond, ifb, elseb) -> - IfElse - ( trans_expr cond - , trans_single_stmt ifb - , Option.map ~f:trans_single_stmt elseb ) - |> swrap - | Ast.While (cond, body) -> - While (trans_expr cond, trans_single_stmt body) |> swrap - | Ast.For {loop_variable; lower_bound; upper_bound; loop_body} -> - let body = - match trans_single_stmt loop_body with - | {pattern= Block _; _} as b -> b - | x -> {x with pattern= Block [x]} in - For - { loopvar= loop_variable.Ast.name - ; lower= trans_expr lower_bound - ; upper= trans_expr upper_bound - ; body } - |> swrap - | Ast.ForEach (loopvar, iteratee, body) -> - let iteratee' = trans_expr iteratee in - let body_stmts = - match trans_single_stmt body with - | {pattern= Block body_stmts; _} -> body_stmts - | b -> [b] in - let decl_type = - match Expr.Typed.type_of iteratee' with - | UMatrix -> UnsizedType.UReal - | t -> Expr.Helpers.(infer_type_of_indexed t [Index.Single loop_bottom]) - in - let decl_loopvar = - Stmt.Fixed. - { meta= smeta - ; pattern= - Decl - { decl_adtype= Expr.Typed.adlevel_of iteratee' - ; decl_id= loopvar.name - ; decl_type= Unsized decl_type - ; initialize= true } } in - let assignment var = - Stmt.Fixed. - {pattern= Assignment ((loopvar.name, decl_type, []), var); meta= smeta} - in - let bodyfn var = - Stmt.Fixed. 
- { pattern= Block (decl_loopvar :: assignment var :: body_stmts) - ; meta= smeta } in - Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta] - | Ast.FunDef _ -> - Common.FatalError.fatal_error_msg - [%message - "Found function definition statement outside of function block"] - | Ast.VarDecl {decl_type; transformation; variables; is_global= _} -> - List.concat_map - ~f:(fun {identifier; initial_value} -> - trans_decl declc smeta decl_type - (Transformation.map trans_expr transformation) - identifier initial_value ) - variables - | Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap - | Ast.Profile (name, stmts) -> - Profile (name, List.concat_map ~f:trans_stmt stmts) |> swrap - | Ast.Return e -> Return (Some (trans_expr e)) |> swrap - | Ast.ReturnVoid -> Return None |> swrap - | Ast.Break -> Break |> swrap - | Ast.Continue -> Continue |> swrap - | Ast.Skip -> Skip |> swrap - -let trans_fun_def ud_dists (ts : Ast.typed_statement) = - match ts.stmt with - | Ast.FunDef {returntype; funname; arguments; body} -> - [ Program. - { fdrt= returntype - ; fdname= funname.name - ; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto) - ; fdargs= List.map ~f:trans_arg arguments - ; fdbody= - trans_stmt ud_dists - {transform_action= IgnoreTransform; dadlevel= AutoDiffable} - body - |> unwrap_block_or_skip - ; fdloc= ts.smeta.loc } ] - | _ -> - Common.FatalError.fatal_error_msg - [%message "Found non-function definition statement in function block"] - -let get_block block prog = - match block with - | Program.Parameters -> prog.Ast.parametersblock - | TransformedParameters -> prog.transformedparametersblock - | GeneratedQuantities -> prog.generatedquantitiesblock - -let trans_sizedtype_decl declc tr name = - let check fn x n = - Stmt.Helpers.internal_nrfunapp fn - Expr.Helpers. 
- [str name; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n] - n.meta.loc in - let grab_size fn n = function - | Ast.{expr= IntNumeral i; _} as s when float_of_string i >= 2. -> - ([], trans_expr s) - | Ast.({expr= IntNumeral _; _} | {expr= Variable _; _}) as s -> - let e = trans_expr s in - ([check fn s e], e) - | s -> - let e = trans_expr s in - let decl_id = Fmt.str "%s_%ddim__" name n in - let decl = - { Stmt.Fixed.pattern= - Decl - { decl_type= Sized SInt - ; decl_id - ; decl_adtype= DataOnly - ; initialize= true } - ; meta= e.meta.loc } in - let assign = - { Stmt.Fixed.pattern= Assignment ((decl_id, UInt, []), e) - ; meta= e.meta.loc } in - let var = - Expr. - { Fixed.pattern= Var decl_id - ; meta= - Typed.Meta. - { type_= s.Ast.emeta.Ast.type_ - ; adlevel= s.emeta.ad_level - ; loc= s.emeta.loc } } in - ([decl; assign; check fn s var], var) in - let rec go n = function - | SizedType.(SInt | SReal | SComplex) as t -> ([], t) - | SVector (mem_pattern, s) -> - let fn = - match (declc.transform_action, tr) with - | Constrain, Transformation.Simplex -> - Internal_fun.FnValidateSizeSimplex - | Constrain, UnitVector -> FnValidateSizeUnitVector - | _ -> FnValidateSize in - let l, s = grab_size fn n s in - (l, SizedType.SVector (mem_pattern, s)) - | SRowVector (mem_pattern, s) -> - let l, s = grab_size FnValidateSize n s in - (l, SizedType.SRowVector (mem_pattern, s)) - | SComplexRowVector s -> - let l, s = grab_size FnValidateSize n s in - (l, SizedType.SComplexRowVector s) - | SComplexVector s -> - let l, s = grab_size FnValidateSize n s in - (l, SizedType.SComplexVector s) - | SMatrix (mem_pattern, r, c) -> - let l1, r = grab_size FnValidateSize n r in - let l2, c = grab_size FnValidateSize (n + 1) c in - let cf_cov = - match (declc.transform_action, tr) with - | Constrain, CholeskyCov -> - [ { Stmt.Fixed.pattern= - NRFunApp - ( StanLib ("check_greater_or_equal", FnPlain, AoS) - , Expr.Helpers. 
- [ str ("cholesky_factor_cov " ^ name) - ; str - "num rows (must be greater or equal to num cols)" - ; r; c ] ) - ; meta= r.Expr.Fixed.meta.Expr.Typed.Meta.loc } ] - | _ -> [] in - (l1 @ l2 @ cf_cov, SizedType.SMatrix (mem_pattern, r, c)) - | SComplexMatrix (r, c) -> - let l1, r = grab_size FnValidateSize n r in - let l2, c = grab_size FnValidateSize (n + 1) c in - (l1 @ l2, SizedType.SComplexMatrix (r, c)) - | SArray (t, s) -> - let l, s = grab_size FnValidateSize n s in - let ll, t = go (n + 1) t in - (l @ ll, SizedType.SArray (t, s)) in - go 1 - -let trans_block ud_dists declc block prog = - let f stmt (accum1, accum2, accum3) = - match stmt with - | { Ast.stmt= - VarDecl {decl_type= type_; variables; transformation; is_global= true} - ; smeta } -> - let outvars, sizes, stmts = - List.unzip3 - @@ List.map - ~f:(fun {identifier; initial_value} -> - let decl_id = identifier.Ast.name in - let transform = Transformation.map trans_expr transformation in - let rhs = Option.map ~f:trans_expr initial_value in - let size, type_ = - trans_sizedtype_decl declc transform identifier.name type_ - in - let decl_adtype = declc.dadlevel in - let decl_var = - Expr. - { Fixed.pattern= Var decl_id - ; meta= - Typed.Meta.create ~adlevel:declc.dadlevel - ~loc:smeta.Ast.loc - ~type_:(SizedType.to_unsized type_) - () } in - let decl = - Stmt. - { Fixed.pattern= - Decl - { decl_adtype - ; decl_id - ; decl_type= Sized type_ - ; initialize= true } - ; meta= smeta.loc } in - let rhs_assignment = - Option.map - ~f:(fun e -> - Stmt.Fixed. - { pattern= Assignment ((decl_id, e.meta.type_, []), e) - ; meta= smeta.loc } ) - rhs - |> Option.to_list in - let outvar = - ( identifier.name - , smeta.loc - , Program. 
- { out_constrained_st= type_ - ; out_unconstrained_st= param_size transform type_ - ; out_block= block - ; out_trans= transform } ) in - let stmts = - if Utils.is_user_ident decl_id then - let constrain_checks = - match declc.transform_action with - | Constrain | Unconstrain -> - check_transform_shape decl_id decl_var smeta.loc - transform - | Check -> - check_transform_shape decl_id decl_var smeta.loc - transform - @ check_decl decl_var (Type.Sized type_) decl_id - transform smeta.loc declc.dadlevel - | IgnoreTransform -> [] in - (decl :: rhs_assignment) @ constrain_checks - else decl :: rhs_assignment in - (outvar, size, stmts) ) - variables in - ( outvars @ accum1 - , List.concat sizes @ accum2 - , List.concat stmts @ accum3 ) - | stmt -> (accum1, accum2, trans_stmt ud_dists declc stmt @ accum3) in - Ast.get_stmts (get_block block prog) |> List.fold_right ~f ~init:([], [], []) - -let stmt_contains_check stmt = - let is_check = function - | Fun_kind.CompilerInternal (Internal_fun.FnCheck _) -> true - | _ -> false in - Stmt.Helpers.contains_fn_kind is_check stmt - -let migrate_checks_to_end_of_block stmts = - let checks, not_checks = List.partition_tf ~f:stmt_contains_check stmts in - not_checks @ checks - -let gather_declarations (b : Ast.typed_statement Ast.block option) = - let data = Ast.get_stmts b in - List.concat_map data ~f:(function - | {stmt= VarDecl {decl_type= sizedtype; transformation; variables; _}; _} -> - List.map - ~f:(fun {identifier; _} -> - ( SizedType.map trans_expr sizedtype - , Transformation.map trans_expr transformation - , identifier.name ) ) - variables - | _ -> [] ) - -let trans_prog filename (p : Ast.typed_program) : Program.Typed.t = - let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = - Deprecation_analysis.remove_unneeded_forward_decls p in - let map f list_op = - Option.value_map ~default:[] - ~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts) - list_op in - let grab_fundef_names_and_types = function - | 
{Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> - [(funname.name, type_)] - | _ -> [] in - let ud_dists = map grab_fundef_names_and_types functionblock in - let trans_stmt = trans_stmt ud_dists in - let get_name_size (s : Ast.typed_statement) = - match s.Ast.stmt with - | Ast.VarDecl {decl_type= st; variables; transformation; _} -> - List.map - ~f:(fun {identifier; _} -> - ( identifier.name - , trans_sizedtype st - , transformation - , s.Ast.smeta.loc ) ) + let get_block block prog = + match block with + | Program.Parameters -> prog.Ast.parametersblock + | TransformedParameters -> prog.transformedparametersblock + | GeneratedQuantities -> prog.generatedquantitiesblock + + let trans_sizedtype_decl declc tr name = + let check fn x n = + Stmt.Helpers.internal_nrfunapp fn + Expr.Helpers. + [str name; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n] + n.meta.loc in + let grab_size fn n = function + | Ast.{expr= IntNumeral i; _} as s when float_of_string i >= 2. -> + ([], trans_expr s) + | Ast.({expr= IntNumeral _; _} | {expr= Variable _; _}) as s -> + let e = trans_expr s in + ([check fn s e], e) + | s -> + let e = trans_expr s in + let decl_id = Fmt.str "%s_%ddim__" name n in + let decl = + { Stmt.Fixed.pattern= + Decl + { decl_type= Sized SInt + ; decl_id + ; decl_adtype= DataOnly + ; initialize= true } + ; meta= e.meta.loc } in + let assign = + { Stmt.Fixed.pattern= Assignment ((decl_id, UInt, []), e) + ; meta= e.meta.loc } in + let var = + Expr. + { Fixed.pattern= Var decl_id + ; meta= + Typed.Meta. 
+ { type_= s.Ast.emeta.Ast.type_ + ; adlevel= s.emeta.ad_level + ; loc= s.emeta.loc } } in + ([decl; assign; check fn s var], var) in + let rec go n = function + | SizedType.(SInt | SReal | SComplex) as t -> ([], t) + | SVector (mem_pattern, s) -> + let fn = + match (declc.transform_action, tr) with + | Constrain, Transformation.Simplex -> + Internal_fun.FnValidateSizeSimplex + | Constrain, UnitVector -> FnValidateSizeUnitVector + | _ -> FnValidateSize in + let l, s = grab_size fn n s in + (l, SizedType.SVector (mem_pattern, s)) + | SRowVector (mem_pattern, s) -> + let l, s = grab_size FnValidateSize n s in + (l, SizedType.SRowVector (mem_pattern, s)) + | SComplexRowVector s -> + let l, s = grab_size FnValidateSize n s in + (l, SizedType.SComplexRowVector s) + | SComplexVector s -> + let l, s = grab_size FnValidateSize n s in + (l, SizedType.SComplexVector s) + | SMatrix (mem_pattern, r, c) -> + let l1, r = grab_size FnValidateSize n r in + let l2, c = grab_size FnValidateSize (n + 1) c in + let cf_cov = + match (declc.transform_action, tr) with + | Constrain, CholeskyCov -> + [ { Stmt.Fixed.pattern= + NRFunApp + ( StanLib ("check_greater_or_equal", FnPlain, AoS) + , Expr.Helpers. 
+ [ str ("cholesky_factor_cov " ^ name) + ; str + "num rows (must be greater or equal to num \ + cols)"; r; c ] ) + ; meta= r.Expr.Fixed.meta.Expr.Typed.Meta.loc } ] + | _ -> [] in + (l1 @ l2 @ cf_cov, SizedType.SMatrix (mem_pattern, r, c)) + | SComplexMatrix (r, c) -> + let l1, r = grab_size FnValidateSize n r in + let l2, c = grab_size FnValidateSize (n + 1) c in + (l1 @ l2, SizedType.SComplexMatrix (r, c)) + | SArray (t, s) -> + let l, s = grab_size FnValidateSize n s in + let ll, t = go (n + 1) t in + (l @ ll, SizedType.SArray (t, s)) in + go 1 + + let trans_block ud_dists declc block prog = + let f stmt (accum1, accum2, accum3) = + match stmt with + | { Ast.stmt= + VarDecl + {decl_type= type_; transformation; variables; is_global= true} + ; smeta } -> + let outvars, sizes, stmts = + List.unzip3 + @@ List.map + ~f:(fun {identifier; initial_value} -> + let decl_id = identifier.Ast.name in + let transform = + Transformation.map trans_expr transformation in + let rhs = Option.map ~f:trans_expr initial_value in + let size, type_ = + trans_sizedtype_decl declc transform identifier.name type_ + in + let decl_adtype = declc.dadlevel in + let decl_var = + Expr. + { Fixed.pattern= Var decl_id + ; meta= + Typed.Meta.create ~adlevel:declc.dadlevel + ~loc:smeta.Ast.loc + ~type_:(SizedType.to_unsized type_) + () } in + let decl = + Stmt. + { Fixed.pattern= + Decl + { decl_adtype + ; decl_id + ; decl_type= Sized type_ + ; initialize= true } + ; meta= smeta.loc } in + let rhs_assignment = + Option.map + ~f:(fun e -> + Stmt.Fixed. + { pattern= Assignment ((decl_id, e.meta.type_, []), e) + ; meta= smeta.loc } ) + rhs + |> Option.to_list in + let outvar = + ( identifier.name + , smeta.loc + , Program. 
+ { out_constrained_st= type_ + ; out_unconstrained_st= param_size transform type_ + ; out_block= block + ; out_trans= transform } ) in + let stmts = + if Utils.is_user_ident decl_id then + let constrain_checks = + match declc.transform_action with + | Constrain | Unconstrain -> + check_transform_shape decl_id decl_var smeta.loc + transform + | Check -> + check_transform_shape decl_id decl_var smeta.loc + transform + @ check_decl decl_var (Type.Sized type_) decl_id + transform smeta.loc declc.dadlevel + | IgnoreTransform -> [] in + (decl :: rhs_assignment) @ constrain_checks + else decl :: rhs_assignment in + (outvar, size, stmts) ) + variables in + ( outvars @ accum1 + , List.concat sizes @ accum2 + , List.concat stmts @ accum3 ) + | stmt -> (accum1, accum2, trans_stmt ud_dists declc stmt @ accum3) in + Ast.get_stmts (get_block block prog) |> List.fold_right ~f ~init:([], [], []) + + let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) + = + let stmt_typed = ts.stmt and smeta = ts.smeta.loc in + let trans_stmt = + trans_stmt ud_dists {declc with transform_action= IgnoreTransform} in + let trans_single_stmt s = + match trans_stmt s with + | [s] -> s + | s -> Stmt.Fixed.{pattern= SList s; meta= smeta} in + let swrap pattern = [Stmt.Fixed.{meta= smeta; pattern}] in + let mloc = smeta in + match stmt_typed with + | Ast.Assignment {assign_lhs; assign_rhs; assign_op} -> + let rec get_lhs_base = function + | {Ast.lval= Ast.LIndexed (l, _); _} -> get_lhs_base l + | {lval= LVariable s; lmeta} -> (s, lmeta) in + let assign_identifier, lmeta = get_lhs_base assign_lhs in + let id_ad_level = lmeta.Ast.ad_level in + let id_type_ = lmeta.Ast.type_ in + let lhs_type_ = assign_lhs.Ast.lmeta.type_ in + let lhs_ad_level = assign_lhs.Ast.lmeta.ad_level in + let rec get_lhs_indices = function + | {Ast.lval= Ast.LIndexed (l, i); _} -> get_lhs_indices l @ i + | {Ast.lval= Ast.LVariable _; _} -> [] in + let assign_indices = get_lhs_indices assign_lhs in + let 
assignee = + { Ast.expr= + ( match assign_indices with + | [] -> Ast.Variable assign_identifier + | _ -> + Ast.Indexed + ( { expr= Ast.Variable assign_identifier + ; emeta= + { Ast.loc= Location_span.empty + ; ad_level= id_ad_level + ; type_= id_type_ } } + , assign_indices ) ) + ; emeta= + { Ast.loc= assign_lhs.lmeta.loc + ; ad_level= lhs_ad_level + ; type_= lhs_type_ } } in + let rhs = + match assign_op with + | Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs + | Ast.OperatorAssign op -> + op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in + Assignment + ( ( assign_identifier.Ast.name + , id_type_ + , List.map ~f:trans_idx assign_indices ) + , rhs ) + |> swrap + | Ast.NRFunApp (fn_kind, {name; _}, args) -> + NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap + | Ast.IncrementLogProb e | Ast.TargetPE e -> + TargetPE (trans_expr e) |> swrap + | Ast.Tilde {arg; distribution; args; truncation} -> + let suffix = + Std_library_utils.dist_name_suffix + (module StdLibrary) + ud_dists distribution.name in + let name = distribution.name ^ suffix in + let kind = + let possible_names = + List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices + |> String.Set.of_list in + if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists + then Fun_kind.UserDefined (name, FnLpdf true) + else StanLib (name, FnLpdf true, AoS) in + let add_dist = + Stmt.Fixed.Pattern.TargetPE + Expr. 
+ { Fixed.pattern= FunApp (kind, trans_exprs (arg :: args)) + ; meta= + Typed.Meta.create ~type_:UReal ~loc:mloc + ~adlevel:(Ast.expr_ad_lub (arg :: args)) + () } in + swrap add_dist @ truncate_dist ud_dists distribution arg args truncation + | Ast.Print ps -> + NRFunApp (CompilerInternal FnPrint, trans_printables smeta ps) |> swrap + | Ast.Reject ps -> + NRFunApp (CompilerInternal FnReject, trans_printables smeta ps) |> swrap + | Ast.IfThenElse (cond, ifb, elseb) -> + IfElse + ( trans_expr cond + , trans_single_stmt ifb + , Option.map ~f:trans_single_stmt elseb ) + |> swrap + | Ast.While (cond, body) -> + While (trans_expr cond, trans_single_stmt body) |> swrap + | Ast.For {loop_variable; lower_bound; upper_bound; loop_body} -> + let body = + match trans_single_stmt loop_body with + | {pattern= Block _; _} as b -> b + | x -> {x with pattern= Block [x]} in + For + { loopvar= loop_variable.Ast.name + ; lower= trans_expr lower_bound + ; upper= trans_expr upper_bound + ; body } + |> swrap + | Ast.ForEach (loopvar, iteratee, body) -> + let iteratee' = trans_expr iteratee in + let body_stmts = + match trans_single_stmt body with + | {pattern= Block body_stmts; _} -> body_stmts + | b -> [b] in + let decl_type = + match Expr.Typed.type_of iteratee' with + | UMatrix -> UnsizedType.UReal + | t -> + Expr.Helpers.(infer_type_of_indexed t [Index.Single loop_bottom]) + in + let decl_loopvar = + Stmt.Fixed. + { meta= smeta + ; pattern= + Decl + { decl_adtype= Expr.Typed.adlevel_of iteratee' + ; decl_id= loopvar.name + ; decl_type= Unsized decl_type + ; initialize= true } } in + let assignment var = + Stmt.Fixed. + { pattern= Assignment ((loopvar.name, decl_type, []), var) + ; meta= smeta } in + let bodyfn var = + Stmt.Fixed. 
+ { pattern= Block (decl_loopvar :: assignment var :: body_stmts) + ; meta= smeta } in + Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta] + | Ast.FunDef _ -> + Common.FatalError.fatal_error_msg + [%message + "Found function definition statement outside of function block"] + | Ast.VarDecl {decl_type; transformation; variables; is_global= _} -> + List.concat_map + ~f:(fun {identifier; initial_value} -> + trans_decl declc smeta decl_type + (Transformation.map trans_expr transformation) + identifier initial_value ) variables - | _ -> [] in - let input_vars = - map get_name_size datablock - |> List.map ~f:(fun (n, st, _, loc) -> (n, loc, st)) in - let declc = {transform_action= IgnoreTransform; dadlevel= DataOnly} in - let datab = map (trans_stmt {declc with transform_action= Check}) datablock in - let _, _, param = - trans_block ud_dists - {transform_action= Constrain; dadlevel= AutoDiffable} - Parameters p in - (* Backends will add to transform_inits and unconstrain_array as needed *) - let transform_inits = [] in - let unconstrain_array = [] in - let out_param, paramsizes, param_gq = - trans_block ud_dists {declc with transform_action= Constrain} Parameters p - in - let _, _, txparam = - trans_block ud_dists - {transform_action= Check; dadlevel= AutoDiffable} - TransformedParameters p in - let out_tparam, tparamsizes, txparam_gq = - trans_block ud_dists - {declc with transform_action= Check} - TransformedParameters p in - let out_gq, gq_sizes, gq_stmts = - trans_block ud_dists - {declc with transform_action= Check} - GeneratedQuantities p in - let output_vars = out_param @ out_tparam @ out_gq in - let prepare_data = - datab - @ ( map - (trans_stmt {declc with transform_action= Check}) - transformeddatablock - |> migrate_checks_to_end_of_block ) - @ paramsizes @ tparamsizes @ gq_sizes in - let modelb = map (trans_stmt {declc with dadlevel= AutoDiffable}) modelblock in - let log_prob = - param - @ (txparam |> migrate_checks_to_end_of_block) - @ - match 
modelb with - | [] -> [] - | hd :: _ -> [{pattern= Block modelb; meta= hd.meta}] in - let txparam_decls, txparam_checks, txparam_stmts = - txparam_gq - |> List.partition3_map ~f:(function - | {pattern= Decl _; _} as d -> `Fst d - | s when stmt_contains_check s -> `Snd s - | s -> `Trd s ) in - let compiler_if_return cond = - Stmt.Fixed. - { pattern= - IfElse (cond, {pattern= Return None; meta= Location_span.empty}, None) - ; meta= Location_span.empty } in - let iexpr pattern = Expr.{pattern; Fixed.meta= Typed.Meta.empty} in - let fnot e = - FunApp (StanLib (Operator.to_string PNot, FnPlain, AoS), [e]) |> iexpr in - let tparam_early_return = - let to_var fv = iexpr (Var (Flag_vars.to_string fv)) in - let v1 = to_var EmitTransformedParameters in - let v2 = to_var EmitGeneratedQuantities in - [compiler_if_return (fnot (EOr (v1, v2) |> iexpr))] in - let gq_early_return = - [ compiler_if_return - (fnot (Var (Flag_vars.to_string EmitGeneratedQuantities) |> iexpr)) ] - in - let generate_quantities = - param_gq @ txparam_decls @ tparam_early_return @ txparam_stmts - @ txparam_checks @ gq_early_return - @ migrate_checks_to_end_of_block gq_stmts in - let normalize_prog_name prog_name = - if String.length prog_name > 0 && not (Char.is_alpha prog_name.[0]) then - "_" ^ prog_name - else prog_name in - { functions_block= map (trans_fun_def ud_dists) functionblock - ; input_vars - ; prepare_data - ; log_prob - ; generate_quantities - ; transform_inits - ; unconstrain_array - ; output_vars - ; prog_name= normalize_prog_name !Typechecker.model_name - ; prog_path= filename } + | Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap + | Ast.Profile (name, stmts) -> + Profile (name, List.concat_map ~f:trans_stmt stmts) |> swrap + | Ast.Return e -> Return (Some (trans_expr e)) |> swrap + | Ast.ReturnVoid -> Return None |> swrap + | Ast.Break -> Break |> swrap + | Ast.Continue -> Continue |> swrap + | Ast.Skip -> Skip |> swrap + + let gather_declarations (b : 
Ast.typed_statement Ast.block option) = + let data = Ast.get_stmts b in + List.concat_map data ~f:(function + | {stmt= VarDecl {decl_type= sizedtype; transformation; variables; _}; _} + -> + List.map + ~f:(fun {identifier; _} -> + ( SizedType.map trans_expr sizedtype + , Transformation.map trans_expr transformation + , identifier.name ) ) + variables + | _ -> [] ) + + let trans_prog filename (p : Ast.typed_program) : Program.Typed.t = + let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = + Deprecation_analysis.remove_unneeded_forward_decls p in + let map f list_op = + Option.value_map ~default:[] + ~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts) + list_op in + let grab_fundef_names_and_types = function + | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> + [(funname.name, type_)] + | _ -> [] in + let ud_dists = map grab_fundef_names_and_types functionblock in + let trans_stmt = trans_stmt ud_dists in + let get_name_size (s : Ast.typed_statement) = + match s.Ast.stmt with + | Ast.VarDecl {decl_type= st; variables; transformation; _} -> + List.map + ~f:(fun {identifier; _} -> + ( identifier.name + , trans_sizedtype st + , transformation + , s.Ast.smeta.loc ) ) + variables + | _ -> [] in + let input_vars = + map get_name_size datablock + |> List.map ~f:(fun (n, st, _, loc) -> (n, loc, st)) in + let declc = {transform_action= IgnoreTransform; dadlevel= DataOnly} in + let datab = + map (trans_stmt {declc with transform_action= Check}) datablock in + let _, _, param = + trans_block ud_dists + {transform_action= Constrain; dadlevel= AutoDiffable} + Parameters p in + (* Backends will add to transform_inits and unconstrain_array as needed *) + let transform_inits = [] in + let unconstrain_array = [] in + let out_param, paramsizes, param_gq = + trans_block ud_dists {declc with transform_action= Constrain} Parameters p + in + let _, _, txparam = + trans_block ud_dists + {transform_action= Check; dadlevel= AutoDiffable} + 
TransformedParameters p in + let out_tparam, tparamsizes, txparam_gq = + trans_block ud_dists + {declc with transform_action= Check} + TransformedParameters p in + let out_gq, gq_sizes, gq_stmts = + trans_block ud_dists + {declc with transform_action= Check} + GeneratedQuantities p in + let output_vars = out_param @ out_tparam @ out_gq in + let prepare_data = + datab + @ ( map + (trans_stmt {declc with transform_action= Check}) + transformeddatablock + |> migrate_checks_to_end_of_block ) + @ paramsizes @ tparamsizes @ gq_sizes in + let modelb = + map (trans_stmt {declc with dadlevel= AutoDiffable}) modelblock in + let log_prob = + param + @ (txparam |> migrate_checks_to_end_of_block) + @ + match modelb with + | [] -> [] + | hd :: _ -> [{pattern= Block modelb; meta= hd.meta}] in + let txparam_decls, txparam_checks, txparam_stmts = + txparam_gq + |> List.partition3_map ~f:(function + | {pattern= Decl _; _} as d -> `Fst d + | s when stmt_contains_check s -> `Snd s + | s -> `Trd s ) in + let compiler_if_return cond = + Stmt.Fixed. 
+ { pattern= + IfElse + (cond, {pattern= Return None; meta= Location_span.empty}, None) + ; meta= Location_span.empty } in + let iexpr pattern = Expr.{pattern; Fixed.meta= Typed.Meta.empty} in + let fnot e = + FunApp (StanLib (Operator.to_string PNot, FnPlain, AoS), [e]) |> iexpr + in + let tparam_early_return = + let to_var fv = iexpr (Var (Flag_vars.to_string fv)) in + let v1 = to_var EmitTransformedParameters in + let v2 = to_var EmitGeneratedQuantities in + [compiler_if_return (fnot (EOr (v1, v2) |> iexpr))] in + let gq_early_return = + [ compiler_if_return + (fnot (Var (Flag_vars.to_string EmitGeneratedQuantities) |> iexpr)) ] + in + let generate_quantities = + param_gq @ txparam_decls @ tparam_early_return @ txparam_stmts + @ txparam_checks @ gq_early_return + @ migrate_checks_to_end_of_block gq_stmts in + let normalize_prog_name prog_name = + if String.length prog_name > 0 && not (Char.is_alpha prog_name.[0]) then + "_" ^ prog_name + else prog_name in + { functions_block= map (trans_fun_def ud_dists) functionblock + ; input_vars + ; prepare_data + ; log_prob + ; generate_quantities + ; transform_inits + ; unconstrain_array + ; output_vars + ; prog_name= normalize_prog_name !Typechecking.model_name + ; prog_path= filename } +end diff --git a/src/frontend/Ast_to_Mir.mli b/src/frontend/Ast_to_Mir.mli index 5d6c7f6f6a..6aef789fb7 100644 --- a/src/frontend/Ast_to_Mir.mli +++ b/src/frontend/Ast_to_Mir.mli @@ -1,8 +1,12 @@ (** Translate from the AST to the MIR *) open Middle -val gather_declarations : - Ast.typed_statement Ast.block option - -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list +module type AST_MIR_TRANSLATOR = sig + val gather_declarations : + Ast.typed_statement Ast.block option + -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list -val trans_prog : string -> Ast.typed_program -> Program.Typed.t + val trans_prog : string -> Ast.typed_program -> Program.Typed.t +end + +module Make (StdLibrary : 
Std_library_utils.Library) : AST_MIR_TRANSLATOR diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index 73e923cad1..ce1822720f 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -1,6 +1,5 @@ open Core_kernel open Ast -open Deprecation_analysis type canonicalizer_settings = { deprecations: bool @@ -23,261 +22,281 @@ let none = ; braces= false ; strip_comments= false } -let rec repair_syntax_stmt user_dists {stmt; smeta} = - match stmt with - | Tilde {arg; distribution= {name; id_loc}; args; truncation} -> - { stmt= - Tilde - { arg - ; distribution= {name= without_suffix user_dists name; id_loc} - ; args - ; truncation } - ; smeta } - | _ -> - { stmt= - map_statement ident (repair_syntax_stmt user_dists) ident ident stmt - ; smeta } +module type CANONICALIZER = sig + val repair_syntax : + untyped_program -> canonicalizer_settings -> untyped_program -let rec replace_deprecated_expr - (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) - {expr; emeta} = - let expr = - match expr with - | GetLP -> GetTarget - | FunApp (StanLib FnPlain, {name= "if_else"; _}, [c; t; e]) -> - Paren - (replace_deprecated_expr deprecated_userdefined - {expr= TernaryIf ({expr= Paren c; emeta= c.emeta}, t, e); emeta} ) - | FunApp (StanLib suffix, {name; id_loc}, e) -> - if is_deprecated_distribution name then - CondDistApp - ( StanLib suffix - , {name= rename_deprecated deprecated_distributions name; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - else if String.is_suffix name ~suffix:"_cdf" then - CondDistApp - ( StanLib suffix - , {name; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - else - FunApp - ( StanLib suffix - , {name= rename_deprecated deprecated_functions name; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( - match String.Map.find deprecated_userdefined name with 
- | Some type_ -> - CondDistApp - ( UserDefined suffix - , {name= update_suffix name type_; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - | None -> - if String.is_suffix name ~suffix:"_cdf" then + val canonicalize_program : + typed_program -> canonicalizer_settings -> typed_program +end + +module Make (Deprecation : Deprecation_analysis.DEPRECATION_ANALYZER) : + CANONICALIZER = struct + let rec repair_syntax_stmt user_dists {stmt; smeta} = + match stmt with + | Tilde {arg; distribution= {name; id_loc}; args; truncation} -> + { stmt= + Tilde + { arg + ; distribution= + {name= Deprecation.without_suffix user_dists name; id_loc} + ; args + ; truncation } + ; smeta } + | _ -> + { stmt= + map_statement ident (repair_syntax_stmt user_dists) ident ident stmt + ; smeta } + + let rec replace_deprecated_expr + (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) + {expr; emeta} = + let expr = + match expr with + | GetLP -> GetTarget + | FunApp (StanLib FnPlain, {name= "if_else"; _}, [c; t; e]) -> + Paren + (replace_deprecated_expr deprecated_userdefined + {expr= TernaryIf ({expr= Paren c; emeta= c.emeta}, t, e); emeta} ) + | FunApp (StanLib suffix, {name; id_loc}, e) -> + if Deprecation.is_deprecated_distribution name then CondDistApp - ( UserDefined suffix + ( StanLib suffix + , {name= Deprecation.rename_deprecated_distribution name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) + else if String.is_suffix name ~suffix:"_cdf" then + CondDistApp + ( StanLib suffix , {name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) else FunApp + ( StanLib suffix + , {name= Deprecation.rename_deprecated_function name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) + | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( + match String.Map.find deprecated_userdefined name with + | Some type_ -> + CondDistApp ( UserDefined suffix - , {name; id_loc} 
+ , {name= Deprecation.update_suffix name type_; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e - ) ) - | PrefixOp (PNot, e) -> - PrefixOp - (PNot, replace_boolean_real ~parens:true deprecated_userdefined e) - | BinOp (e1, ((And | Or) as op), e2) -> - BinOp - ( replace_boolean_real ~parens:true deprecated_userdefined e1 - , op - , replace_boolean_real ~parens:true deprecated_userdefined e2 ) - | _ -> - map_expression - (replace_deprecated_expr deprecated_userdefined) - ident expr in - {expr; emeta} - -and replace_boolean_real ?(parens = false) deprecated_userdefined e = - match e with - | {emeta= {type_= UReal; _}; _} when parens -> - { emeta= {e.emeta with type_= UInt} - ; expr= - Paren (replace_boolean_real ~parens:false deprecated_userdefined e) } - | {emeta= {type_= UReal; _}; _} -> - { emeta= {e.emeta with type_= UInt} - ; expr= + ) + | None -> + if String.is_suffix name ~suffix:"_cdf" then + CondDistApp + ( UserDefined suffix + , {name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) + else + FunApp + ( UserDefined suffix + , {name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) ) + | PrefixOp (PNot, e) -> + PrefixOp + (PNot, replace_boolean_real ~parens:true deprecated_userdefined e) + | BinOp (e1, ((And | Or) as op), e2) -> BinOp - ( replace_deprecated_expr deprecated_userdefined e - , NEquals - , { expr= RealNumeral "0.0" - ; emeta= - { type_= UInt - ; loc= Middle.Location_span.empty - ; ad_level= DataOnly } } ) } - | _ -> replace_deprecated_expr deprecated_userdefined e + ( replace_boolean_real ~parens:true deprecated_userdefined e1 + , op + , replace_boolean_real ~parens:true deprecated_userdefined e2 ) + | _ -> + map_expression + (replace_deprecated_expr deprecated_userdefined) + ident expr in + {expr; emeta} -let replace_deprecated_lval deprecated_userdefined {lval; lmeta} = - let is_multiindex = function - | Single {emeta= {type_= Middle.UnsizedType.UInt; _}; _} 
-> false - | _ -> true in - let rec flatten_multi = function - | LVariable id -> (LVariable id, None) - | LIndexed ({lval; lmeta}, idcs) -> ( - let outer = - List.map idcs - ~f:(map_index (replace_deprecated_expr deprecated_userdefined)) - in - let unwrap = Option.value_map ~default:[] ~f:fst in - match flatten_multi lval with - | lval, inner when List.exists ~f:is_multiindex outer -> - (lval, Some (unwrap inner @ outer, lmeta)) - | lval, None -> (LIndexed ({lval; lmeta}, outer), None) - | lval, Some (inner, _) -> (lval, Some (inner @ outer, lmeta)) ) in - let lval = - match flatten_multi lval with - | lval, None -> lval - | lval, Some (idcs, lmeta) -> LIndexed ({lval; lmeta}, idcs) in - {lval; lmeta} + and replace_boolean_real ?(parens = false) deprecated_userdefined e = + match e with + | {emeta= {type_= UReal; _}; _} when parens -> + { emeta= {e.emeta with type_= UInt} + ; expr= + Paren (replace_boolean_real ~parens:false deprecated_userdefined e) + } + | {emeta= {type_= UReal; _}; _} -> + { emeta= {e.emeta with type_= UInt} + ; expr= + BinOp + ( replace_deprecated_expr deprecated_userdefined e + , NEquals + , { expr= RealNumeral "0.0" + ; emeta= + { type_= UInt + ; loc= Middle.Location_span.empty + ; ad_level= DataOnly } } ) } + | _ -> replace_deprecated_expr deprecated_userdefined e -let rec replace_deprecated_stmt - (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) - ({stmt; smeta} : typed_statement) = - let stmt = - match stmt with - | IncrementLogProb e -> - TargetPE (replace_deprecated_expr deprecated_userdefined e) - | Assignment {assign_lhs= l; assign_op= ArrowAssign; assign_rhs= e} -> - Assignment - { assign_lhs= replace_deprecated_lval deprecated_userdefined l - ; assign_op= Assign - ; assign_rhs= (replace_deprecated_expr deprecated_userdefined) e } - | FunDef {returntype; funname= {name; id_loc}; arguments; body} -> - let newname = - match String.Map.find deprecated_userdefined name with - | Some type_ -> update_suffix name 
type_ - | None -> name in - FunDef - { returntype - ; funname= {name= newname; id_loc} - ; arguments - ; body= replace_deprecated_stmt deprecated_userdefined body } - | IfThenElse (({emeta= {type_= UReal; _}; _} as cond), ifb, elseb) -> - IfThenElse - ( replace_boolean_real deprecated_userdefined cond - , replace_deprecated_stmt deprecated_userdefined ifb - , Option.map ~f:(replace_deprecated_stmt deprecated_userdefined) elseb - ) - | While (({emeta= {type_= UReal; _}; _} as cond), body) -> - While - ( replace_boolean_real deprecated_userdefined cond - , replace_deprecated_stmt deprecated_userdefined body ) - | _ -> - map_statement - (replace_deprecated_expr deprecated_userdefined) - (replace_deprecated_stmt deprecated_userdefined) - (replace_deprecated_lval deprecated_userdefined) - ident stmt in - {stmt; smeta} + let replace_deprecated_lval deprecated_userdefined {lval; lmeta} = + let is_multiindex = function + | Single {emeta= {type_= Middle.UnsizedType.UInt; _}; _} -> false + | _ -> true in + let rec flatten_multi = function + | LVariable id -> (LVariable id, None) + | LIndexed ({lval; lmeta}, idcs) -> ( + let outer = + List.map idcs + ~f:(map_index (replace_deprecated_expr deprecated_userdefined)) + in + let unwrap = Option.value_map ~default:[] ~f:fst in + match flatten_multi lval with + | lval, inner when List.exists ~f:is_multiindex outer -> + (lval, Some (unwrap inner @ outer, lmeta)) + | lval, None -> (LIndexed ({lval; lmeta}, outer), None) + | lval, Some (inner, _) -> (lval, Some (inner @ outer, lmeta)) ) in + let lval = + match flatten_multi lval with + | lval, None -> lval + | lval, Some (idcs, lmeta) -> LIndexed ({lval; lmeta}, idcs) in + {lval; lmeta} -let rec no_parens {expr; emeta} = - match expr with - | Paren e -> no_parens e - | Variable _ | IntNumeral _ | RealNumeral _ | ImagNumeral _ | GetLP - |GetTarget -> - {expr; emeta} - | BinOp (({expr= BinOp (_, op1, _); _} as e1), op2, e2) - when Middle.Operator.(is_cmp op1 && is_cmp op2) -> - { expr= 
BinOp ({e1 with expr= Paren (no_parens e1)}, op2, keep_parens e2) - ; emeta } - | TernaryIf _ | BinOp _ | PrefixOp _ | PostfixOp _ -> - {expr= map_expression keep_parens ident expr; emeta} - | Indexed (e, l) -> - { expr= - Indexed - ( keep_parens e - , List.map - ~f:(function - | Single e -> Single (no_parens e) - | i -> map_index keep_parens i ) - l ) - ; emeta } - | ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ -> - {expr= map_expression no_parens ident expr; emeta} + let rec replace_deprecated_stmt + (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) + ({stmt; smeta} : typed_statement) = + let stmt = + match stmt with + | IncrementLogProb e -> + TargetPE (replace_deprecated_expr deprecated_userdefined e) + | Assignment {assign_lhs= l; assign_op= ArrowAssign; assign_rhs= e} -> + Assignment + { assign_lhs= replace_deprecated_lval deprecated_userdefined l + ; assign_op= Assign + ; assign_rhs= (replace_deprecated_expr deprecated_userdefined) e } + | FunDef {returntype; funname= {name; id_loc}; arguments; body} -> + let newname = + match String.Map.find deprecated_userdefined name with + | Some type_ -> Deprecation.update_suffix name type_ + | None -> name in + FunDef + { returntype + ; funname= {name= newname; id_loc} + ; arguments + ; body= replace_deprecated_stmt deprecated_userdefined body } + | IfThenElse (({emeta= {type_= UReal; _}; _} as cond), ifb, elseb) -> + IfThenElse + ( replace_boolean_real deprecated_userdefined cond + , replace_deprecated_stmt deprecated_userdefined ifb + , Option.map + ~f:(replace_deprecated_stmt deprecated_userdefined) + elseb ) + | While (({emeta= {type_= UReal; _}; _} as cond), body) -> + While + ( replace_boolean_real deprecated_userdefined cond + , replace_deprecated_stmt deprecated_userdefined body ) + | _ -> + map_statement + (replace_deprecated_expr deprecated_userdefined) + (replace_deprecated_stmt deprecated_userdefined) + (replace_deprecated_lval deprecated_userdefined) + ident 
stmt in + {stmt; smeta} -and keep_parens {expr; emeta} = - match expr with - | Promotion (e, ut, ad) -> {expr= Promotion (keep_parens e, ut, ad); emeta} - | Paren ({expr= Paren _; _} as e) -> keep_parens e - | Paren ({expr= BinOp _; _} as e) - |Paren ({expr= PrefixOp _; _} as e) - |Paren ({expr= PostfixOp _; _} as e) - |Paren ({expr= TernaryIf _; _} as e) -> - {expr= Paren (no_parens e); emeta} - | _ -> no_parens {expr; emeta} + let rec no_parens {expr; emeta} = + match expr with + | Paren e -> no_parens e + | Variable _ | IntNumeral _ | RealNumeral _ | ImagNumeral _ | GetLP + |GetTarget -> + {expr; emeta} + | BinOp (({expr= BinOp (_, op1, _); _} as e1), op2, e2) + when Middle.Operator.(is_cmp op1 && is_cmp op2) -> + { expr= BinOp ({e1 with expr= Paren (no_parens e1)}, op2, keep_parens e2) + ; emeta } + | TernaryIf _ | BinOp _ | PrefixOp _ | PostfixOp _ -> + {expr= map_expression keep_parens ident expr; emeta} + | Indexed (e, l) -> + { expr= + Indexed + ( keep_parens e + , List.map + ~f:(function + | Single e -> Single (no_parens e) + | i -> map_index keep_parens i ) + l ) + ; emeta } + | ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ -> + {expr= map_expression no_parens ident expr; emeta} -let parens_lval = map_lval_with no_parens ident + and keep_parens {expr; emeta} = + match expr with + | Promotion (e, ut, ad) -> {expr= Promotion (keep_parens e, ut, ad); emeta} + | Paren ({expr= Paren _; _} as e) -> keep_parens e + | Paren ({expr= BinOp _; _} as e) + |Paren ({expr= PrefixOp _; _} as e) + |Paren ({expr= PostfixOp _; _} as e) + |Paren ({expr= TernaryIf _; _} as e) -> + {expr= Paren (no_parens e); emeta} + | _ -> no_parens {expr; emeta} -let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement = - let stmt = - match stmt with - | VarDecl {decl_type= d; transformation= t; variables; is_global} -> - VarDecl - { decl_type= Middle.SizedType.map no_parens d - ; transformation= Middle.Transformation.map keep_parens t - ; variables= 
List.map ~f:(map_variable no_parens) variables - ; is_global } - | For {loop_variable; lower_bound; upper_bound; loop_body} -> - For - { loop_variable - ; lower_bound= keep_parens lower_bound - ; upper_bound= keep_parens upper_bound - ; loop_body= parens_stmt loop_body } - | _ -> map_statement no_parens parens_stmt parens_lval ident stmt in - {stmt; smeta} + let parens_lval = map_lval_with no_parens ident -let rec blocks_stmt ({stmt; smeta} : typed_statement) : typed_statement = - let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement = - match stmt with - | Block _ -> blocks_stmt {stmt; smeta} - | _ -> - blocks_stmt - @@ mk_typed_statement - ~stmt:(Block [{stmt; smeta}]) - ~return_type:smeta.return_type ~loc:smeta.loc in - let stmt = - match stmt with - | While (e, s) -> While (e, stmt_to_block s) - | IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2)) - |IfThenElse (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _}) - -> - (* Flatten if ... else if ... 
constructs *) - IfThenElse (e, stmt_to_block s1, Some (blocks_stmt s2)) - | IfThenElse (e, s1, s2) -> - IfThenElse (e, stmt_to_block s1, Option.map ~f:stmt_to_block s2) - | For ({loop_body; _} as f) -> - For {f with loop_body= stmt_to_block loop_body} - | _ -> map_statement ident blocks_stmt ident ident stmt in - {stmt; smeta} + let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement = + let stmt = + match stmt with + | VarDecl {decl_type= d; transformation= t; variables; is_global} -> + VarDecl + { decl_type= Middle.SizedType.map no_parens d + ; transformation= Middle.Transformation.map keep_parens t + ; variables= List.map ~f:(map_variable no_parens) variables + ; is_global } + | For {loop_variable; lower_bound; upper_bound; loop_body} -> + For + { loop_variable + ; lower_bound= keep_parens lower_bound + ; upper_bound= keep_parens upper_bound + ; loop_body= parens_stmt loop_body } + | _ -> map_statement no_parens parens_stmt parens_lval ident stmt in + {stmt; smeta} -let repair_syntax program settings = - if settings.deprecations then - program - |> map_program - (repair_syntax_stmt (userdef_distributions program.functionblock)) - else program + let rec blocks_stmt ({stmt; smeta} : typed_statement) : typed_statement = + let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement = + match stmt with + | Block _ -> blocks_stmt {stmt; smeta} + | _ -> + blocks_stmt + @@ mk_typed_statement + ~stmt:(Block [{stmt; smeta}]) + ~return_type:smeta.return_type ~loc:smeta.loc in + let stmt = + match stmt with + | While (e, s) -> While (e, stmt_to_block s) + | IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2)) + |IfThenElse + (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _}) -> + (* Flatten if ... else if ... 
constructs *) + IfThenElse (e, stmt_to_block s1, Some (blocks_stmt s2)) + | IfThenElse (e, s1, s2) -> + IfThenElse (e, stmt_to_block s1, Option.map ~f:stmt_to_block s2) + | For ({loop_body; _} as f) -> + For {f with loop_body= stmt_to_block loop_body} + | _ -> map_statement ident blocks_stmt ident ident stmt in + {stmt; smeta} -let canonicalize_program program settings : typed_program = - let program = + let repair_syntax program settings = if settings.deprecations then - remove_unneeded_forward_decls program + program |> map_program - (replace_deprecated_stmt (collect_userdef_distributions program)) - else program in - let program = - if settings.parentheses then program |> map_program parens_stmt else program - in - let program = - if settings.braces then program |> map_program blocks_stmt else program - in - program + (repair_syntax_stmt + (Deprecation.userdef_distributions program.functionblock) ) + else program + + let canonicalize_program program settings : typed_program = + let program = + if settings.deprecations then + Deprecation_analysis.remove_unneeded_forward_decls program + |> map_program + (replace_deprecated_stmt + (Deprecation.collect_userdef_distributions program) ) + else program in + let program = + if settings.parentheses then program |> map_program parens_stmt + else program in + let program = + if settings.braces then program |> map_program blocks_stmt else program + in + program +end diff --git a/src/frontend/Canonicalize.mli b/src/frontend/Canonicalize.mli index 344c1401b5..ca01fe1850 100644 --- a/src/frontend/Canonicalize.mli +++ b/src/frontend/Canonicalize.mli @@ -19,11 +19,17 @@ val legacy : canonicalizer_settings val none : canonicalizer_settings -val repair_syntax : untyped_program -> canonicalizer_settings -> untyped_program -(** When deprecation canonicalization is enabled, this runs before typechecking +module type CANONICALIZER = sig + val repair_syntax : + untyped_program -> canonicalizer_settings -> untyped_program + (** When 
deprecation canonicalization is enabled, this runs before typechecking and removes suffixes from ~ statements, which are otherwise forbidden by the typechecker *) -val canonicalize_program : - typed_program -> canonicalizer_settings -> typed_program -(** "Canonicalize" the program by removing deprecations, adding or removing parenthesis + val canonicalize_program : + typed_program -> canonicalizer_settings -> typed_program + (** "Canonicalize" the program by removing deprecations, adding or removing parenthesis and braces, etc. *) +end + +module Make (Deprecation : Deprecation_analysis.DEPRECATION_ANALYZER) : + CANONICALIZER diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index 9daae63dfa..f2c5925add 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -2,238 +2,37 @@ open Core_kernel open Ast open Middle -let deprecated_functions = - String.Map.of_alist_exn - [ ("multiply_log", ("lmultiply", "2.33.0")) - ; ("binomial_coefficient_log", ("lchoose", "2.33.0")) - ; ("cov_exp_quad", ("gp_exp_quad_cov", "2.33.0")) - ; ("fabs", ("abs", "2.33.0")) ] - -let deprecated_odes = - String.Map.of_alist_exn - [ ("integrate_ode", ("ode_rk45", "3.0")) - ; ("integrate_ode_rk45", ("ode_rk45", "3.0")) - ; ("integrate_ode_bdf", ("ode_bdf", "3.0")) - ; ("integrate_ode_adams", ("ode_adams", "3.0")) ] - -let deprecated_distributions = - String.Map.of_alist_exn - (List.map - ~f:(fun (x, y) -> (x, (y, "2.33.0"))) - (List.concat_map Middle.Stan_math_signatures.distributions - ~f:(fun (fnkinds, name, _, _) -> - List.filter_map fnkinds ~f:(function - | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") - | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") - | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") - | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") - | Rng | Log | UnaryVectorized _ -> None ) ) ) ) - -let stan_lib_deprecations = - Map.merge_skewed deprecated_distributions deprecated_functions - ~combine:(fun 
~key x y -> - Common.FatalError.fatal_error_msg - [%message - "Common key in deprecation map" - (key : string) - (x : string * string) - (y : string * string)] ) - -let is_deprecated_distribution name = - Option.is_some (Map.find deprecated_distributions name) - -let rename_deprecated map name = - Map.find map name |> Option.map ~f:fst |> Option.value ~default:name +module type DEPRECATION_ANALYZER = sig + val find_udf_log_suffix : + typed_statement -> (string * Middle.UnsizedType.t) option + + val update_suffix : string -> Middle.UnsizedType.t -> string + + val collect_userdef_distributions : + typed_program -> Middle.UnsizedType.t String.Map.t + + val without_suffix : string list -> string -> string + val is_deprecated_distribution : string -> bool + val rename_deprecated_distribution : string -> string + val rename_deprecated_function : string -> string + val userdef_distributions : untyped_statement block option -> string list + val collect_warnings : typed_program -> Warnings.t list +end let userdef_functions program = match program.functionblock with - | None -> Hash_set.Poly.create () + | None -> [] | Some {stmts; _} -> List.filter_map stmts ~f:(function | {stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> None | {stmt= FunDef {funname; arguments; _}; _} -> Some (funname.name, Ast.type_of_arguments arguments) | _ -> None ) - |> Hash_set.Poly.of_list let is_redundant_forwarddecl fundefs funname arguments = - Hash_set.mem fundefs (funname.name, Ast.type_of_arguments arguments) - -let userdef_distributions stmts = - let open String in - List.filter_map - ~f:(function - | {stmt= FunDef {funname= {name; _}; _}; _} -> - if - is_suffix ~suffix:"_log_lpdf" name - || is_suffix ~suffix:"_log_lpmf" name - then Some (drop_suffix name 5) - else if is_suffix ~suffix:"_log_log" name then - Some (drop_suffix name 4) - else None - | _ -> None ) - (Ast.get_stmts stmts) - -let without_suffix user_dists name = - let open String in - if is_suffix ~suffix:"_lpdf" name || is_suffix 
~suffix:"_lpmf" name then - drop_suffix name 5 - else if - is_suffix ~suffix:"_log" name - && not - ( is_deprecated_distribution (name ^ "_log") - || List.exists ~f:(( = ) name) user_dists ) - then drop_suffix name 4 - else name - -let update_suffix name type_ = - let open String in - if is_suffix ~suffix:"_cdf_log" name then drop_suffix name 8 ^ "_lcdf" - else if is_suffix ~suffix:"_ccdf_log" name then drop_suffix name 9 ^ "_lccdf" - else if Middle.UnsizedType.is_int_type type_ then drop_suffix name 4 ^ "_lpmf" - else drop_suffix name 4 ^ "_lpdf" - -let find_udf_log_suffix = function - | { stmt= - FunDef - { funname= {name; _} - ; arguments= (_, ((UReal | UInt) as type_), _) :: _ - ; _ } - ; smeta= _ } - when String.is_suffix ~suffix:"_log" name -> - Some (name, type_) - | _ -> None - -let rec collect_deprecated_expr (acc : (Location_span.t * string) list) - ({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) : - (Location_span.t * string) list = - match expr with - | FunApp (StanLib FnPlain, {name= "if_else"; _}, l) -> - acc - @ [ ( emeta.loc - , "The function `if_else` is deprecated and will be removed in Stan \ - 2.33.0. Use the conditional operator (x ? y : z) instead; this \ - can be automatically changed using the canonicalize flag for \ - stanc" ) ] - @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) - | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> - let w = - match Map.find stan_lib_deprecations name with - | Some (rename, version) -> - [ ( emeta.loc - , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. This can be automatically changed using the \ - canonicalize flag for stanc" ) ] - | _ when String.is_suffix name ~suffix:"_cdf" -> - [ ( emeta.loc - , "Use of " ^ name - ^ " without a vertical bar (|) between the first two arguments \ - of a CDF is deprecated and will be removed in Stan 2.33.0. 
\ - This can be automatically changed using the canonicalize \ - flag for stanc" ) ] - | _ -> ( - match Map.find deprecated_odes name with - | Some (rename, version) -> - [ ( emeta.loc - , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. \n\ - The new interface is slightly different, see: \ - https://mc-stan.org/users/documentation/case-studies/convert_odes.html" - ) ] - | _ -> [] ) in - acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) - | PrefixOp (PNot, ({emeta= {type_= UReal; loc; _}; _} as e)) -> - let acc = - acc - @ [ ( loc - , "Using a real as a boolean value is deprecated and will be \ - disallowed in Stan 2.34. Use an explicit != 0 comparison \ - instead. This can be automatically changed using the \ - canonicalize flag for stanc" ) ] in - collect_deprecated_expr acc e - | BinOp (({emeta= {type_= UReal; loc; _}; _} as e1), (And | Or), e2) - |BinOp (e1, (And | Or), ({emeta= {type_= UReal; loc; _}; _} as e2)) -> - let acc = - acc - @ [ ( loc - , "Using a real as a boolean value is deprecated and will be \ - disallowed in Stan 2.34. Use an explicit != 0 comparison \ - instead. This can be automatically changed using the \ - canonicalize flag for stanc" ) ] in - let acc = collect_deprecated_expr acc e1 in - let acc = collect_deprecated_expr acc e2 in - acc - | _ -> fold_expression collect_deprecated_expr (fun l _ -> l) acc expr - -let collect_deprecated_lval acc l = - fold_lval_with collect_deprecated_expr (fun x _ -> x) acc l - -let rec collect_deprecated_stmt fundefs (acc : (Location_span.t * string) list) - {stmt; _} : (Location_span.t * string) list = - match stmt with - | FunDef {body= {stmt= Skip; _}; funname; arguments; _} - when is_redundant_forwarddecl fundefs funname arguments -> - acc - @ [ ( funname.id_loc - , "Functions do not need to be declared before definition; all user \ - defined function names are always in scope regardless of \ - defintion order." 
) ] - | FunDef - { body - ; funname= {name; id_loc} - ; arguments= (_, ((UReal | UInt) as type_), _) :: _ - ; _ } - when String.is_suffix ~suffix:"_log" name -> - let acc = - acc - @ [ ( id_loc - , "Use of the _log suffix in user defined probability functions is \ - deprecated and will be removed in Stan 2.33.0, use name '" - ^ update_suffix name type_ - ^ "' instead if you intend on using this function in ~ \ - statements or calling unnormalized probability functions \ - inside of it." ) ] in - collect_deprecated_stmt fundefs acc body - | FunDef {body; _} -> collect_deprecated_stmt fundefs acc body - | IfThenElse ({emeta= {type_= UReal; loc; _}; _}, ifb, elseb) -> - let acc = - acc - @ [ ( loc - , "Condition of type real is deprecated and will be disallowed in \ - Stan 2.34. Use an explicit != 0 comparison instead. This can be \ - automatically changed using the canonicalize flag for stanc" ) ] - in - let acc = collect_deprecated_stmt fundefs acc ifb in - Option.value_map ~default:acc - ~f:(collect_deprecated_stmt fundefs acc) - elseb - | While ({emeta= {type_= UReal; loc; _}; _}, body) -> - let acc = - acc - @ [ ( loc - , "Condition of type real is deprecated and will be disallowed in \ - Stan 2.34. Use an explicit != 0 comparison instead. 
This can be \ - automatically changed using the canonicalize flag for stanc" ) ] - in - collect_deprecated_stmt fundefs acc body - | _ -> - fold_statement collect_deprecated_expr - (collect_deprecated_stmt fundefs) - collect_deprecated_lval - (fun l _ -> l) - acc stmt - -let collect_userdef_distributions program = - program.functionblock |> Ast.get_stmts - |> List.filter_map ~f:find_udf_log_suffix - |> List.dedup_and_sort ~compare:(fun (x, _) (y, _) -> String.compare x y) - |> String.Map.of_alist_exn - -let collect_warnings (program : typed_program) = - let fundefs = userdef_functions program in - fold_program (collect_deprecated_stmt fundefs) [] program + let equal (id1, a1) (id2, a2) = + String.equal id1 id2 && UnsizedType.equal_argumentlist a1 a2 in + List.mem ~equal fundefs (funname.name, Ast.type_of_arguments arguments) let remove_unneeded_forward_decls program = let fundefs = userdef_functions program in @@ -246,3 +45,191 @@ let remove_unneeded_forward_decls program = functionblock= Option.map program.functionblock ~f:(fun x -> {x with stmts= List.filter ~f:drop_forwarddecl x.stmts} ) } + +module Make (StdLibrary : Std_library_utils.Library) : DEPRECATION_ANALYZER = +struct + let stan_lib_deprecations = + Map.merge_skewed StdLibrary.deprecated_distributions + StdLibrary.deprecated_functions ~combine:(fun ~key x y -> + Common.FatalError.fatal_error_msg + [%message + "Common key in deprecation map" + (key : string) + (x : Std_library_utils.deprecation_info) + (y : Std_library_utils.deprecation_info)] ) + + let is_deprecated_distribution name = + Map.mem StdLibrary.deprecated_distributions name + + let rename_deprecated map name = + Map.find map name + |> Option.map + ~f:(fun Std_library_utils.{replacement; canonicalize_away; _} -> + if canonicalize_away then replacement else name ) + |> Option.value ~default:name + + let rename_deprecated_distribution = + rename_deprecated StdLibrary.deprecated_distributions + + let rename_deprecated_function = + 
rename_deprecated StdLibrary.deprecated_functions + + let userdef_distributions stmts = + let open String in + List.filter_map + ~f:(function + | {stmt= FunDef {funname= {name; _}; _}; _} -> + if + is_suffix ~suffix:"_log_lpdf" name + || is_suffix ~suffix:"_log_lpmf" name + then Some (drop_suffix name 5) + else if is_suffix ~suffix:"_log_log" name then + Some (drop_suffix name 4) + else None + | _ -> None ) + (Ast.get_stmts stmts) + + let without_suffix user_dists name = + let open String in + if is_suffix ~suffix:"_lpdf" name || is_suffix ~suffix:"_lpmf" name then + drop_suffix name 5 + else if + is_suffix ~suffix:"_log" name + && not + ( is_deprecated_distribution (name ^ "_log") + || List.exists ~f:(( = ) name) user_dists ) + then drop_suffix name 4 + else name + + let update_suffix name type_ = + let open String in + if is_suffix ~suffix:"_cdf_log" name then drop_suffix name 8 ^ "_lcdf" + else if is_suffix ~suffix:"_ccdf_log" name then + drop_suffix name 9 ^ "_lccdf" + else if Middle.UnsizedType.is_int_type type_ then + drop_suffix name 4 ^ "_lpmf" + else drop_suffix name 4 ^ "_lpdf" + + let find_udf_log_suffix = function + | { stmt= + FunDef + { funname= {name; _} + ; arguments= (_, ((UReal | UInt) as type_), _) :: _ + ; _ } + ; smeta= _ } + when String.is_suffix ~suffix:"_log" name -> + Some (name, type_) + | _ -> None + + let rec collect_deprecated_expr (acc : (Location_span.t * string) list) + ({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) : + (Location_span.t * string) list = + match expr with + | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> + let w = + match Map.find stan_lib_deprecations name with + | Some {replacement; version; extra_message; _} -> + [ ( emeta.loc + , name ^ " is deprecated and will be removed in Stan " ^ version + ^ ". Use " ^ replacement ^ " instead. 
" ^ extra_message ) ] + | _ when String.is_suffix name ~suffix:"_cdf" -> + [ ( emeta.loc + , "Use of " ^ name + ^ " without a vertical bar (|) between the first two \ + arguments of a CDF is deprecated and will be removed in \ + Stan 2.33.0. This can be automatically changed using the \ + canonicalize flag for stanc" ) ] + | _ -> [] in + acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) + | PrefixOp (PNot, ({emeta= {type_= UReal; loc; _}; _} as e)) -> + let acc = + acc + @ [ ( loc + , "Using a real as a boolean value is deprecated and will be \ + disallowed in Stan 2.34. Use an explicit != 0 comparison \ + instead. This can be automatically changed using the \ + canonicalize flag for stanc" ) ] in + collect_deprecated_expr acc e + | BinOp (({emeta= {type_= UReal; loc; _}; _} as e1), (And | Or), e2) + |BinOp (e1, (And | Or), ({emeta= {type_= UReal; loc; _}; _} as e2)) -> + let acc = + acc + @ [ ( loc + , "Using a real as a boolean value is deprecated and will be \ + disallowed in Stan 2.34. Use an explicit != 0 comparison \ + instead. This can be automatically changed using the \ + canonicalize flag for stanc" ) ] in + let acc = collect_deprecated_expr acc e1 in + let acc = collect_deprecated_expr acc e2 in + acc + | _ -> fold_expression collect_deprecated_expr (fun l _ -> l) acc expr + + let collect_deprecated_lval acc l = + fold_lval_with collect_deprecated_expr (fun x _ -> x) acc l + + let rec collect_deprecated_stmt fundefs + (acc : (Location_span.t * string) list) {stmt; _} : + (Location_span.t * string) list = + match stmt with + | FunDef {body= {stmt= Skip; _}; funname; arguments; _} + when is_redundant_forwarddecl fundefs funname arguments -> + acc + @ [ ( funname.id_loc + , "Functions do not need to be declared before definition; all \ + user defined function names are always in scope regardless of \ + defintion order." 
) ] + | FunDef + { body + ; funname= {name; id_loc} + ; arguments= (_, ((UReal | UInt) as type_), _) :: _ + ; _ } + when String.is_suffix ~suffix:"_log" name -> + let acc = + acc + @ [ ( id_loc + , "Use of the _log suffix in user defined probability functions \ + is deprecated and will be removed in Stan 2.33.0, use name '" + ^ update_suffix name type_ + ^ "' instead if you intend on using this function in ~ \ + statements or calling unnormalized probability functions \ + inside of it." ) ] in + collect_deprecated_stmt fundefs acc body + | FunDef {body; _} -> collect_deprecated_stmt fundefs acc body + | IfThenElse ({emeta= {type_= UReal; loc; _}; _}, ifb, elseb) -> + let acc = + acc + @ [ ( loc + , "Condition of type real is deprecated and will be disallowed \ + in Stan 2.34. Use an explicit != 0 comparison instead. This \ + can be automatically changed using the canonicalize flag for \ + stanc" ) ] in + let acc = collect_deprecated_stmt fundefs acc ifb in + Option.value_map ~default:acc + ~f:(collect_deprecated_stmt fundefs acc) + elseb + | While ({emeta= {type_= UReal; loc; _}; _}, body) -> + let acc = + acc + @ [ ( loc + , "Condition of type real is deprecated and will be disallowed \ + in Stan 2.34. Use an explicit != 0 comparison instead. 
This \ + can be automatically changed using the canonicalize flag for \ + stanc" ) ] in + collect_deprecated_stmt fundefs acc body + | _ -> + fold_statement collect_deprecated_expr + (collect_deprecated_stmt fundefs) + collect_deprecated_lval + (fun l _ -> l) + acc stmt + + let collect_userdef_distributions program = + program.functionblock |> Ast.get_stmts + |> List.filter_map ~f:find_udf_log_suffix + |> List.dedup_and_sort ~compare:(fun (x, _) (y, _) -> String.compare x y) + |> String.Map.of_alist_exn + + let collect_warnings (program : typed_program) = + let fundefs = userdef_functions program in + fold_program (collect_deprecated_stmt fundefs) [] program +end diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index eeabc62b7c..93221564a5 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -5,19 +5,23 @@ open Core_kernel open Ast -val find_udf_log_suffix : - typed_statement -> (string * Middle.UnsizedType.t) option +module type DEPRECATION_ANALYZER = sig + val find_udf_log_suffix : + typed_statement -> (string * Middle.UnsizedType.t) option -val update_suffix : string -> Middle.UnsizedType.t -> string + val update_suffix : string -> Middle.UnsizedType.t -> string -val collect_userdef_distributions : - typed_program -> Middle.UnsizedType.t String.Map.t + val collect_userdef_distributions : + typed_program -> Middle.UnsizedType.t String.Map.t + + val without_suffix : string list -> string -> string + val is_deprecated_distribution : string -> bool + val rename_deprecated_distribution : string -> string + val rename_deprecated_function : string -> string + val userdef_distributions : untyped_statement block option -> string list + val collect_warnings : typed_program -> Warnings.t list +end -val without_suffix : string list -> string -> string -val is_deprecated_distribution : string -> bool -val deprecated_distributions : (string * string) String.Map.t -val deprecated_functions : 
(string * string) String.Map.t -val rename_deprecated : (string * string) String.Map.t -> string -> string -val userdef_distributions : untyped_statement block option -> string list -val collect_warnings : typed_program -> Warnings.t list val remove_unneeded_forward_decls : typed_program -> typed_program + +module Make (StdLibrary : Std_library_utils.Library) : DEPRECATION_ANALYZER diff --git a/src/frontend/Environment.ml b/src/frontend/Environment.ml index 5267d6f277..d1674cc433 100644 --- a/src/frontend/Environment.ml +++ b/src/frontend/Environment.ml @@ -27,9 +27,9 @@ type info = type t = info list String.Map.t -let stan_math_environment = +let make_from_library signatures : t = let functions = - Hashtbl.to_alist Stan_math_signatures.stan_math_signatures + Hashtbl.to_alist signatures |> List.map ~f:(fun (key, values) -> ( key , List.map values ~f:(fun (rt, args, mem) -> diff --git a/src/frontend/Environment.mli b/src/frontend/Environment.mli index 71e07c4a0c..752331d796 100644 --- a/src/frontend/Environment.mli +++ b/src/frontend/Environment.mli @@ -26,8 +26,16 @@ type info = type t -val stan_math_environment : t -(** A type environment which contains the Stan math library functions +val make_from_library : + ( string + , ( UnsizedType.returntype + * (UnsizedType.autodifftype * UnsizedType.t) list + * Mem_pattern.t ) + list ) + Core_kernel.Hashtbl.t + -> t +(** Make a type environment from a hashtable of functions like those from + [Std_library_utils] *) val find : t -> string -> info list diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index f5bb82fedc..3f67761eb4 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -49,50 +49,6 @@ let rec get_function_calls_expr (funs, distrs) expr = | _ -> (funs, distrs) in fold_expression get_function_calls_expr (fun acc _ -> acc) acc expr.expr -let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = - let acc = - match stmt.stmt with - | NRFunApp (StanLib _, f, _) -> (Set.add funs f.name, distrs) 
- | Print _ -> (Set.add funs "print", distrs) - | Reject _ -> (Set.add funs "reject", distrs) - | Tilde {distribution; _} -> - let possible_names = - List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices - |> String.Set.of_list in - if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists then - (funs, distrs) - else - let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name - in - let name = distribution.name ^ Utils.unnormalized_suffix suffix in - (funs, Set.add distrs name) - | _ -> (funs, distrs) in - fold_statement get_function_calls_expr - (get_function_calls_stmt ud_dists) - (fun acc _ -> acc) - (fun acc _ -> acc) - acc stmt.stmt - -let function_calls_json p = - let map f list_op = - Option.value_map ~default:[] - ~f:(fun {stmts; _} -> List.concat_map ~f stmts) - list_op in - let grab_fundef_names_and_types = function - | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> - [(funname.name, type_)] - | _ -> [] in - let ud_dists = map grab_fundef_names_and_types p.functionblock in - let funs, distrs = - fold_program - (get_function_calls_stmt ud_dists) - (String.Set.empty, String.Set.empty) - p in - let set_to_List s = - `List (Set.to_list s |> List.map ~f:(fun str -> `String str)) in - `Assoc [("functions", set_to_List funs); ("distributions", set_to_List distrs)] - let includes_json () = `Assoc [ ( "included_files" @@ -100,12 +56,64 @@ let includes_json () = ( List.rev !Preprocessor.included_files |> List.map ~f:(fun str -> `String str) ) ) ] -let info_json ast = - List.fold ~f:Util.combine ~init:(`Assoc []) - [ block_info_json "inputs" ast.datablock - ; block_info_json "parameters" ast.parametersblock - ; block_info_json "transformed parameters" ast.transformedparametersblock - ; block_info_json "generated quantities" ast.generatedquantitiesblock - ; function_calls_json ast; includes_json () ] +module type INFO = sig + val info : Ast.typed_program -> string +end + +module Make (StdLibrary : 
Std_library_utils.Library) : INFO = struct + let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = + let acc = + match stmt.stmt with + | NRFunApp (StanLib _, f, _) -> (Set.add funs f.name, distrs) + | Print _ -> (Set.add funs "print", distrs) + | Reject _ -> (Set.add funs "reject", distrs) + | Tilde {distribution; _} -> + let possible_names = + List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices + |> String.Set.of_list in + if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists + then (funs, distrs) + else + let suffix = + Std_library_utils.dist_name_suffix + (module StdLibrary) + ud_dists distribution.name in + let name = distribution.name ^ Utils.unnormalized_suffix suffix in + (funs, Set.add distrs name) + | _ -> (funs, distrs) in + fold_statement get_function_calls_expr + (get_function_calls_stmt ud_dists) + (fun acc _ -> acc) + (fun acc _ -> acc) + acc stmt.stmt + + let function_calls_json p = + let map f list_op = + Option.value_map ~default:[] + ~f:(fun {stmts; _} -> List.concat_map ~f stmts) + list_op in + let grab_fundef_names_and_types = function + | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> + [(funname.name, type_)] + | _ -> [] in + let ud_dists = map grab_fundef_names_and_types p.functionblock in + let funs, distrs = + fold_program + (get_function_calls_stmt ud_dists) + (String.Set.empty, String.Set.empty) + p in + let set_to_List s = + `List (Set.to_list s |> List.map ~f:(fun str -> `String str)) in + `Assoc + [("functions", set_to_List funs); ("distributions", set_to_List distrs)] + + let info_json ast = + List.fold ~f:Util.combine ~init:(`Assoc []) + [ block_info_json "inputs" ast.datablock + ; block_info_json "parameters" ast.parametersblock + ; block_info_json "transformed parameters" ast.transformedparametersblock + ; block_info_json "generated quantities" ast.generatedquantitiesblock + ; function_calls_json ast; includes_json () ] -let info ast = pretty_to_string (info_json ast) 
+ let info ast = pretty_to_string (info_json ast) +end diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index a21aa8f717..1d07aab8fd 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -10,10 +10,14 @@ - [type]: the base type of the variable (["int"] or ["real"]). - [dimensions]: the number of dimensions ([0] for a scalar, [1] for a vector or row vector, etc.). - + The JSON object also have the fields [stanlib_calls] and [distributions] containing the name of the standard library functions called and distributions used. *) -val info : Ast.typed_program -> string +module type INFO = sig + val info : Ast.typed_program -> string +end + +module Make (StdLibrary : Std_library_utils.Library) : INFO diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 6cc1df3b73..4a4ed12c9c 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -14,14 +14,13 @@ module TypeError = struct | IntIntArrayOrRangeExpected of UnsizedType.t | IntOrRealContainerExpected of UnsizedType.t | ArrayVectorRowVectorMatrixExpected of UnsizedType.t - | IllTypedAssignment of Operator.t * UnsizedType.t * UnsizedType.t + | IllTypedAssignment of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list | IllTypedTernaryIf of UnsizedType.t * UnsizedType.t * UnsizedType.t - | IllTypedReduceSum of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedVariadic of + | IllTypedVariadicFn of string * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list @@ -45,9 +44,15 @@ module TypeError = struct string * UnsizedType.t list * (SignatureMismatch.signature_error list * bool) - | IllTypedBinaryOperator of Operator.t * UnsizedType.t * UnsizedType.t - | IllTypedPrefixOperator of Operator.t * UnsizedType.t - | IllTypedPostfixOperator of Operator.t * UnsizedType.t + | IllTypedBinaryOperator of + Operator.t + * 
UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list + | IllTypedPrefixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list + | IllTypedPostfixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list | NotIndexable of UnsizedType.t * int let pp ppf = function @@ -96,18 +101,18 @@ module TypeError = struct "Foreach-loop must be over array, vector, row_vector or matrix. \ Instead found expression of type %a." UnsizedType.pp ut - | IllTypedAssignment (Operator.Equals, lt, rt) -> + | IllTypedAssignment (Operator.Equals, lt, rt, _) -> Fmt.pf ppf "Ill-typed arguments supplied to assignment operator =: lhs has type \ %a and rhs has type %a" UnsizedType.pp lt UnsizedType.pp rt - | IllTypedAssignment (op, lt, rt) -> + | IllTypedAssignment (op, lt, rt, sigs) -> Fmt.pf ppf "@[Ill-typed arguments supplied to assignment operator %a=: lhs \ has type %a and rhs has type %a.@ Available signatures for given \ lhs:@]@ %a" Operator.pp op UnsizedType.pp lt UnsizedType.pp rt - SignatureMismatch.pp_math_lib_assignmentoperator_sigs (lt, op) + SignatureMismatch.pp_assignmentoperator_sigs (lt, sigs) | IllTypedTernaryIf (UInt, ut, _) when UnsizedType.is_fun_type ut -> Fmt.pf ppf "Ternary expression cannot have a function type: %a" UnsizedType.pp ut @@ -120,10 +125,7 @@ module TypeError = struct Fmt.pf ppf "Condition in ternary expression must be primitive int; found type=%a" UnsizedType.pp ut1 - | IllTypedReduceSum (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedVariadic (name, arg_tys, args, error, return_type) -> + | IllTypedVariadicFn (name, arg_tys, args, error, return_type) -> SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys @@ -216,32 +218,25 @@ module TypeError = struct prefix suffix prefix prefix newsuffix | IllTypedFunctionApp (name, arg_tys, errors) -> 
SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, errors) - | IllTypedBinaryOperator (op, lt, rt) -> + | IllTypedBinaryOperator (op, lt, rt, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to infix operator %a. Available \ - signatures: %s@[Instead supplied arguments of incompatible type: \ - %a, %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp lt UnsizedType.pp rt - | IllTypedPrefixOperator (op, ut) -> + signatures: @[%a@.@]@[Instead supplied arguments of \ + incompatible type: %a, %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp lt + UnsizedType.pp rt + | IllTypedPrefixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to prefix operator %a. Available \ - signatures: %s@[Instead supplied argument of incompatible type: \ - %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut - | IllTypedPostfixOperator (op, ut) -> + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut + | IllTypedPostfixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to postfix operator %a. Available \ - signatures: %s\n\ - Instead supplied argument of incompatible type: %a." 
Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut end module IdentifierError = struct @@ -519,8 +514,8 @@ let int_or_real_container_expected loc ut = let array_vector_rowvector_matrix_expected loc ut = TypeError (loc, TypeError.ArrayVectorRowVectorMatrixExpected ut) -let illtyped_assignment loc assignop lt rt = - TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt)) +let illtyped_assignment loc assignop lt rt sigs = + TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt, sigs)) let illtyped_ternary_if loc predt lt rt = TypeError (loc, TypeError.IllTypedTernaryIf (predt, lt, rt)) @@ -528,11 +523,9 @@ let illtyped_ternary_if loc predt lt rt = let returning_fn_expected_nonreturning_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonReturningFound name) -let illtyped_reduce_sum loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedReduceSum (name, arg_tys, args, error)) - -let illtyped_variadic loc name arg_tys args fn_rt error = - TypeError (loc, TypeError.IllTypedVariadic (name, arg_tys, args, error, fn_rt)) +let illtyped_variadic_fn loc name arg_tys args error return_type = + TypeError + (loc, TypeError.IllTypedVariadicFn (name, arg_tys, args, error, return_type)) let ambiguous_function_promotion loc name arg_tys signatures = TypeError @@ -566,14 +559,14 @@ let nonreturning_fn_expected_undeclaredident_found loc name sug = let illtyped_fn_app loc name errors arg_tys = TypeError (loc, TypeError.IllTypedFunctionApp (name, arg_tys, errors)) -let illtyped_binary_op loc op lt rt = - TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt)) +let illtyped_binary_op loc op lt rt sigs = + TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt, sigs)) -let illtyped_prefix_op loc op ut = - 
TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut)) +let illtyped_prefix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut, sigs)) -let illtyped_postfix_op loc op ut = - TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut)) +let illtyped_postfix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut, sigs)) let not_indexable loc ut nidcs = TypeError (loc, TypeError.NotIndexable (ut, nidcs)) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index 2ca63b0814..5121aebef6 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -25,7 +25,12 @@ val array_vector_rowvector_matrix_expected : Location_span.t -> UnsizedType.t -> t val illtyped_assignment : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t val illtyped_ternary_if : Location_span.t -> UnsizedType.t -> UnsizedType.t -> UnsizedType.t -> t @@ -42,14 +47,6 @@ val returning_fn_expected_undeclared_dist_suffix_found : val returning_fn_expected_wrong_dist_suffix_found : Location_span.t -> string * string -> t -val illtyped_reduce_sum : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - val ambiguous_function_promotion : Location_span.t -> string @@ -58,13 +55,13 @@ val ambiguous_function_promotion : list -> t -val illtyped_variadic : +val illtyped_variadic_fn : Location_span.t -> string -> UnsizedType.t list -> (UnsizedType.autodifftype * UnsizedType.t) list - -> UnsizedType.t -> SignatureMismatch.function_mismatch + -> UnsizedType.t -> t val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t @@ -81,10 +78,27 @@ val illtyped_fn_app : -> t val illtyped_binary_op : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t 
+ Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_prefix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_postfix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t -val illtyped_prefix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t -val illtyped_postfix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t val not_indexable : Location_span.t -> UnsizedType.t -> int -> t val ident_is_keyword : Location_span.t -> string -> t val ident_is_model_name : Location_span.t -> string -> t diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index f351e50d5e..3628f7923f 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -243,6 +243,20 @@ let find_compatible_rt function_types args = let errors, omitted = List.split_n errors max_n_errors in SignatureErrors (errors, not (List.is_empty omitted)) +let find_matching_first_order_fn tenv matches (fname : Ast.identifier) = + let candidates = + Utils.stdlib_distribution_name fname.name + |> Environment.find tenv |> List.map ~f:matches in + let ok, errs = List.partition_map candidates ~f:Result.to_either in + match unique_minimum_promotion ok with + | Ok a -> UniqueMatch a + | Error (Some promotions) -> + List.filter_map promotions ~f:(function + | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) + | _ -> None ) + |> AmbiguousMatch + | Error None -> SignatureErrors (List.hd_exn errs) + let matching_function env name args = let name = Utils.stdlib_distribution_name name in let function_types = @@ -252,9 +266,6 @@ let matching_function env name args = UnsizedType.compare_returntype ret1 ret2 ) in find_compatible_rt function_types args -let matching_stanlib_function = - matching_function Environment.stan_math_environment - let check_variadic_args 
~allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys fun_return args = let minimal_func_type = @@ -390,10 +401,8 @@ let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = (list ~sep:cut pp_signature) sigs pp_omitted () -let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = +let pp_assignmentoperator_sigs ppf (lt, errors) = let signatures = - let errors = - Stan_math_signatures.make_assignmentoperator_stan_math_signatures op in let errors = List.filter ~f:(fun (_, args, _) -> @@ -405,7 +414,7 @@ let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = | errors, _ -> Some (errors, true) in let pp_sigs ppf (signatures, omitted) = Fmt.pf ppf "@[%a%a@]" - (Fmt.list ~sep:Fmt.cut Stan_math_signatures.pp_math_sig) + (Fmt.list ~sep:Fmt.cut Std_library_utils.pp_math_sig) signatures (if omitted then Fmt.pf else Fmt.nop) "@ (Additional signatures omitted)" in diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index 312e0a3818..d8617c89b1 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -54,12 +54,6 @@ val matching_function : Requires a unique minimum option under type promotion *) -val matching_stanlib_function : - string -> (UnsizedType.autodifftype * UnsizedType.t) list -> match_result -(** Same as [matching_function] but requires specifically that the function - be from StanMath (uses [Environment.stan_math_environment]) -*) - val check_variadic_args : allow_lpdf:bool -> (UnsizedType.autodifftype * UnsizedType.t) list @@ -74,6 +68,15 @@ val check_variadic_args : If none is found, returns [Error] of the list of args and a function_mismatch. 
*) +val find_matching_first_order_fn : + Environment.t + -> (Environment.info -> (UnsizedType.t * Promotion.t list, 'a) result) + -> Ast.identifier + -> (UnsizedType.t * Promotion.t list, 'a) generic_match_result +(** Given a constraint function [matches], find any signature which exists + Returns the first [Ok] if any exist, or else [Error] +*) + val pp_signature_mismatch : Format.formatter -> string @@ -85,8 +88,8 @@ val pp_signature_mismatch : * bool ) -> unit -val pp_math_lib_assignmentoperator_sigs : - Format.formatter -> UnsizedType.t * Operator.t -> unit +val pp_assignmentoperator_sigs : + Format.formatter -> UnsizedType.t * Std_library_utils.signature list -> unit val compare_errors : function_mismatch -> function_mismatch -> int val compare_match_results : match_result -> match_result -> int diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml new file mode 100644 index 0000000000..dd6d7ea0a1 --- /dev/null +++ b/src/frontend/Std_library_utils.ml @@ -0,0 +1,124 @@ +(** General functions and signatures for a Standard Library *) + +open Middle +open Core_kernel + +(* Types for the module representing the standard library *) +type fun_arg = UnsizedType.autodifftype * UnsizedType.t +type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t + +type variadic_signature = + { return_type: UnsizedType.t + ; control_args: fun_arg list + ; required_fn_rt: UnsizedType.t + ; required_fn_args: fun_arg list } +[@@deriving create] + +type deprecation_info = + { replacement: string + ; version: string + ; extra_message: string + ; canonicalize_away: bool } +[@@deriving sexp] + +(* We could consider breaking up this module more, so we would have + more type-level guarantees about what each Functor is able to do + with the library. Most of them only need is_stdlib_function_name, + maybe get_signatures. 
+ + The Stan_math_library could still satisfy all of them by + using [include] +*) + +module type Library = sig + (** This module is used as a parameter for many functors which + rely on information about a backend-specific Stan library. *) + + val function_signatures : (string, signature list) Hashtbl.t + (** Mapping from names to signature(s) of functions *) + + val variadic_signatures : (string, variadic_signature) Hashtbl.t + (** Mapping from names to description of a variadic function. + Note that these function names cannot be overloaded, and usually require + customized code-gen in the backend. +*) + + val distribution_families : string list + + val is_stdlib_function_name : string -> bool + (** Equivalent to [Hashtbl.mem function_signatures s]*) + + val get_signatures : string -> signature list + (** Equivalent to [Hashtbl.find_multi function_signatures s]*) + + val get_operator_signatures : Operator.t -> signature list + val get_assignment_operator_signatures : Operator.t -> signature list + val is_not_overloadable : string -> bool + val is_variadic_function_name : string -> bool + val is_special_function_name : string -> bool + val special_function_returntype : string -> UnsizedType.returntype option + + val check_special_fn : + is_cond_dist:bool + -> Location_span.t + -> Environment.originblock + -> Environment.t + -> Ast.identifier + -> Ast.typed_expression list + -> Ast.typed_expression + (** This function is responsible for typechecking variadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) + + val operator_to_function_names : Operator.t -> string list + val string_operator_to_function_name : string -> string + val deprecated_distributions : deprecation_info String.Map.t + val deprecated_functions : deprecation_info String.Map.t +end + +(** A "standard library" for Stan which contains no functions. 
+ Useful only for testing + *) +module NullLibrary : Library = struct + let function_signatures : (string, signature list) Hashtbl.t = + String.Table.create () + + let variadic_signatures : (string, variadic_signature) Hashtbl.t = + String.Table.create () + + let distribution_families : string list = [] + let is_stdlib_function_name _ = false + let get_signatures _ = [] + let get_assignment_operator_signatures _ = [] + let get_operator_signatures _ = [] + let is_not_overloadable _ = false + let is_variadic_function_name _ = false + let is_special_function_name _ = false + let special_function_returntype _ = None + + let check_special_fn ~is_cond_dist _ _ _ _ _ : Ast.typed_expression = + ignore (is_cond_dist : bool) ; + Common.FatalError.fatal_error_msg [%message "Impossible"] + + let operator_to_function_names _ = [] + let string_operator_to_function_name s = s + let deprecated_distributions = String.Map.empty + let deprecated_functions = String.Map.empty +end + +let pp_math_sig ppf ((rt, args, mem_pattern) : signature) = + UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) + +let pp_math_sigs ppf (sigs : signature list) = + (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs + +let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs + +let dist_name_suffix (module StdLibrary : Library) udf_names name = + let is_udf_name s = + List.exists ~f:(fun (n, _) -> String.equal s n) udf_names in + Utils.distribution_suffices + |> List.filter ~f:(fun sfx -> + StdLibrary.is_stdlib_function_name (name ^ sfx) + || is_udf_name (name ^ sfx) ) + |> List.hd_exn diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml deleted file mode 100644 index 7530428eb2..0000000000 --- a/src/frontend/Typechecker.ml +++ /dev/null @@ -1,1754 +0,0 @@ -(** a type/semantic checker for Stan ASTs - - Functions which begin with "check_" return a typed version of their input - Functions which begin with "verify_" return unit if a check succeeds, or else - throw an 
Errors.SemanticError exception. - Other functions which begin with "infer"/"calculate" vary. Usually they return - a value, but a few do have error conditions. - - All Error.SemanticError excpetions are caught by check_program - which turns the ast or exception into a Result.t for external usage - - A type environment (Env.t) is used to hold variables and functions, including - stan math functions. This is a functional map, meaning it is handled immutably. -*) - -open Core_kernel -open Core_kernel.Poly -open Middle -open Ast -module Env = Environment - -(* we only allow errors raised by this function *) -let error e = raise (Errors.SemanticError e) - -(* warnings are built up in a list *) -let warnings : Warnings.t list ref = ref [] - -let add_warning (span : Location_span.t) (message : string) = - warnings := (span, message) :: !warnings - -let attach_warnings x = (x, List.rev !warnings) - -(* model name - don't love this here *) -let model_name = ref "" -let check_that_all_functions_have_definition = ref true - -(* Record structure holding flags and other markers about context to be - used for error reporting. 
*) -type context_flags_record = - { current_block: Env.originblock - ; in_toplevel_decl: bool - ; in_fun_def: bool - ; in_returning_fun_def: bool - ; in_rng_fun_def: bool - ; in_lp_fun_def: bool - ; in_udf_dist_def: bool - ; loop_depth: int } - -let context block = - { current_block= block - ; in_toplevel_decl= false - ; in_fun_def= false - ; in_returning_fun_def= false - ; in_rng_fun_def= false - ; in_lp_fun_def= false - ; in_udf_dist_def= false - ; loop_depth= 0 } - -let calculate_autodifftype cf origin ut = - match origin with - | Env.(Param | TParam | Model | Functions) - when not (UnsizedType.is_int_type ut || cf.current_block = GQuant) -> - UnsizedType.AutoDiffable - | _ -> DataOnly - -let arg_type x = (x.emeta.ad_level, x.emeta.type_) -let get_arg_types = List.map ~f:arg_type -let type_of_expr_typed ue = ue.emeta.type_ -let has_int_type ue = ue.emeta.type_ = UInt -let has_int_array_type ue = ue.emeta.type_ = UArray UInt - -let has_int_or_real_type ue = - match ue.emeta.type_ with UInt | UReal -> true | _ -> false - -(* -- General checks ---------------------------------------------- *) -let reserved_keywords = - [ "for"; "in"; "while"; "repeat"; "until"; "if"; "then"; "else"; "true" - ; "false"; "target"; "int"; "real"; "complex"; "void"; "vector"; "simplex" - ; "unit_vector"; "ordered"; "positive_ordered"; "row_vector"; "matrix" - ; "cholesky_factor_corr"; "cholesky_factor_cov"; "corr_matrix"; "cov_matrix" - ; "functions"; "model"; "data"; "parameters"; "quantities"; "transformed" - ; "generated"; "profile"; "return"; "break"; "continue"; "increment_log_prob" - ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" - ; "static"; "auto" ] - -let verify_identifier id : unit = - if id.name = !model_name then - Semantic_error.ident_is_model_name id.id_loc id.name |> error - else if - String.is_suffix id.name ~suffix:"__" - || List.mem reserved_keywords id.name ~equal:String.equal - then Semantic_error.ident_is_keyword id.id_loc id.name |> 
error - -let distribution_name_variants name = - if name = "multiply_log" || name = "binomial_coefficient_log" then [name] - else - (* this will have some duplicates, but preserves order better *) - match Utils.split_distribution_suffix name with - | Some (stem, "lpmf") | Some (stem, "lpdf") | Some (stem, "log") -> - [name; stem ^ "_lpmf"; stem ^ "_lpdf"; stem ^ "_log"] - | Some (stem, "lcdf") | Some (stem, "cdf_log") -> - [name; stem ^ "_lcdf"; stem ^ "_cdf_log"] - | Some (stem, "lccdf") | Some (stem, "ccdf_log") -> - [name; stem ^ "_lccdf"; stem ^ "_ccdf_log"] - | _ -> [name] - -(** verify that the variable being declared is previous unused. - allowed to shadow StanLib *) -let verify_name_fresh_var loc tenv name = - if Utils.is_unnormalized_distribution name then - Semantic_error.ident_has_unnormalized_suffix loc name |> error - else if - List.exists (Env.find tenv name) ~f:(function - | {kind= `Variable _; _} -> true - | _ -> false (* user variables can shadow function names *) ) - then Semantic_error.ident_in_use loc name |> error - -(** verify that the variable being declared is previous unused. 
*) -let verify_name_fresh_udf loc tenv name = - if - (* variadic functions are currently not in math sigs and aren't - overloadable due to their separate typechecking *) - Stan_math_signatures.is_reduce_sum_fn name - || Stan_math_signatures.is_stan_math_variadic_function_name name - then Semantic_error.ident_is_stanmath_name loc name |> error - else if Utils.is_unnormalized_distribution name then - Semantic_error.udf_is_unnormalized_fn loc name |> error - else if - (* if a variable is already defined with this name - - not really possible as all functions are defined before data, - but future-proofing is good *) - List.exists - ~f:(function {kind= `Variable _; _} -> true | _ -> false) - (Env.find tenv name) - then Semantic_error.ident_in_use loc name |> error - -(** Checks that a variable/function name: - - a function/identifier does not have the _lupdf/_lupmf suffix - - is not already in use (for now) -*) -let verify_name_fresh tenv id ~is_udf = - let f = - if is_udf then verify_name_fresh_udf id.id_loc tenv - else verify_name_fresh_var id.id_loc tenv in - List.iter ~f (distribution_name_variants id.name) - -let is_of_compatible_return_type rt1 srt2 = - UnsizedType.( - match (rt1, srt2) with - | Void, NoReturnType - |Void, Incomplete Void - |Void, Complete Void - |Void, AnyReturnType -> - true - | ReturnType UReal, Complete (ReturnType UInt) -> true - | ReturnType UComplex, Complete (ReturnType UReal) -> true - | ReturnType UComplex, Complete (ReturnType UInt) -> true - | ReturnType rt1, Complete (ReturnType rt2) -> rt1 = rt2 - | ReturnType _, AnyReturnType -> true - | _ -> false) - -(* -- Expressions ------------------------------------------------- *) -let check_ternary_if loc pe te fe = - let promote expr type_ ad_level = - if - (not (UnsizedType.equal expr.emeta.type_ type_)) - || UnsizedType.compare_autodifftype expr.emeta.ad_level ad_level <> 0 - then - { expr= Promotion (expr, UnsizedType.internal_scalar type_, ad_level) - ; emeta= {expr.emeta with type_; 
ad_level} } - else expr in - match - ( pe.emeta.type_ - , UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) - , expr_ad_lub [pe; te; fe] ) - with - | UInt, Some type_, ad_level when not (UnsizedType.is_fun_type type_) -> - mk_typed_expression - ~expr: - (TernaryIf (pe, promote te type_ ad_level, promote fe type_ ad_level)) - ~ad_level ~type_ ~loc - | _, _, _ -> - Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_ - fe.emeta.type_ - |> error - -let match_to_rt_option = function - | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt - | _ -> None - -let stan_math_return_type name arg_tys = - match - Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures name - with - | Some {return_type; _} -> Some (UnsizedType.ReturnType return_type) - | None when Stan_math_signatures.is_reduce_sum_fn name -> - Some (UnsizedType.ReturnType UReal) - | None -> - SignatureMismatch.matching_stanlib_function name arg_tys - |> match_to_rt_option - -let operator_stan_math_return_type op arg_tys = - match (op, arg_tys) with - | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> - Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion]) - | IntDivide, _ -> None - | _ -> - Stan_math_signatures.operator_to_stan_math_fns op - |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_stanlib_function name arg_tys - |> function - | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) - | _ -> None ) - |> List.hd - -let assignmentoperator_stan_math_return_type assop arg_tys = - ( match assop with - | Operator.Divide -> - SignatureMismatch.matching_stanlib_function "divide" arg_tys - |> match_to_rt_option - | Plus | Minus | Times | EltTimes | EltDivide -> - operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst - | _ -> None ) - |> Option.bind ~f:(function - | ReturnType rtype - when rtype = snd (List.hd_exn arg_tys) - && not - ( (assop = Operator.EltTimes || assop = Operator.EltDivide) - && UnsizedType.is_scalar_type 
rtype ) -> - Some UnsizedType.Void - | _ -> None ) - -let check_binop loc op le re = - let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in - match rt with - | Some (ReturnType type_, [p1; p2]) -> - mk_typed_expression - ~expr:(BinOp (Promotion.promote le p1, op, Promotion.promote re p2)) - ~ad_level:(expr_ad_lub [le; re]) - ~type_ ~loc - | _ -> - Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ - |> error - -let check_prefixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in - match rt with - | Some (ReturnType type_, _) -> - mk_typed_expression - ~expr:(PrefixOp (op, te)) - ~ad_level:(expr_ad_lub [te]) - ~type_ ~loc - | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error - -let check_postfixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in - match rt with - | Some (ReturnType type_, _) -> - mk_typed_expression - ~expr:(PostfixOp (te, op)) - ~ad_level:(expr_ad_lub [te]) - ~type_ ~loc - | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error - -let check_id cf loc tenv id = - match Env.find tenv (Utils.stdlib_distribution_name id.name) with - | [] -> - Semantic_error.ident_not_in_scope loc id.name - (Env.nearest_ident tenv id.name) - |> error - | {kind= `StanMath; _} :: _ -> - ( calculate_autodifftype cf MathLibrary UMathLibraryFunction - , UnsizedType.UMathLibraryFunction ) - | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ - when cf.in_toplevel_decl -> - Semantic_error.non_data_variable_size_decl loc |> error - | _ :: _ - when Utils.is_unnormalized_distribution id.name - && not - ( (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def)) - || cf.current_block = Model ) -> - Semantic_error.invalid_unnormalized_fn loc |> error - | {kind= `Variable {origin; _}; type_} :: _ -> - (calculate_autodifftype cf origin type_, type_) - | { kind= `UserDefined | `UserDeclared _ - ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } 
- :: _ -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - (calculate_autodifftype cf Functions type_, type_) - | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> - (calculate_autodifftype cf Functions type_, type_) - -let check_variable cf loc tenv id = - let ad_level, type_ = check_id cf loc tenv id in - mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc - -let get_consistent_types type_ es = - let ad = - UnsizedType.lub_ad_type (List.map ~f:(fun e -> e.emeta.ad_level) es) in - let f state e = - match state with - | Error e -> Error e - | Ok ty -> ( - match UnsizedType.common_type (ty, e.emeta.type_) with - | Some ty -> Ok ty - | None -> Error (ty, e.emeta) ) in - List.fold ~init:(Ok type_) ~f es - |> Result.map ~f:(fun ty -> - let promotions = - List.map (get_arg_types es) - ~f:(Promotion.get_type_promotion_exn (ad, ty)) in - (ad, ty, promotions) ) - -let check_array_expr loc es = - match es with - | [] -> - (* NB: This is actually disallowed by parser *) - Semantic_error.empty_array loc |> error - | {emeta= {type_; _}; _} :: _ -> ( - match get_consistent_types type_ es with - | Error (ty, meta) -> - Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error - | Ok (ad_level, type_, promotions) -> - let type_ = UnsizedType.UArray type_ in - mk_typed_expression - ~expr:(ArrayExpr (Promotion.promote_list es promotions)) - ~ad_level ~type_ ~loc ) - -let check_rowvector loc es = - match es with - | {emeta= {type_= UnsizedType.URowVector; _}; _} :: _ -> ( - match get_consistent_types URowVector es with - | Ok (ad_level, typ, promotions) -> - mk_typed_expression - ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) - ~ad_level - ~type_:(if typ = UComplexRowVector then UComplexMatrix else UMatrix) - ~loc - | Error (_, meta) -> - Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) - | {emeta= {type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> ( - match 
get_consistent_types UComplexRowVector es with - | Ok (ad_level, _, promotions) -> - mk_typed_expression - ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) - ~ad_level ~type_:UComplexMatrix ~loc - | Error (_, meta) -> - Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) - | _ -> ( - match get_consistent_types UReal es with - | Ok (ad_level, typ, promotions) -> - mk_typed_expression - ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) - ~ad_level - ~type_:(if typ = UComplex then UComplexRowVector else URowVector) - ~loc - | Error (_, meta) -> - Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error ) - -(* index checking *) - -let indexing_type idx = - match idx with - | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single - | _ -> `Multi - -let is_multiindex i = - match indexing_type i with `Single -> false | `Multi -> true - -let inferred_unsizedtype_of_indexed ~loc ut indices = - let rec aux type_ idcs = - let vec, rowvec, scalar = - if UnsizedType.is_complex_type type_ then - UnsizedType.(UComplexVector, UComplexRowVector, UComplex) - else (UVector, URowVector, UReal) in - match (type_, idcs) with - | _, [] -> type_ - | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl - | UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray - | (UVector | URowVector | UComplexRowVector | UComplexVector), [`Single] - |(UMatrix | UComplexMatrix), [`Single; `Single] -> - scalar - | ( ( UVector | URowVector | UMatrix | UComplexVector | UComplexMatrix - | UComplexRowVector ) - , [`Multi] ) - |(UMatrix | UComplexMatrix), [`Multi; `Multi] -> - type_ - | (UMatrix | UComplexMatrix), ([`Single] | [`Single; `Multi]) -> rowvec - | (UMatrix | UComplexMatrix), [`Multi; `Single] -> vec - | (UMatrix | UComplexMatrix), _ :: _ :: _ :: _ - |(UVector | URowVector | UComplexRowVector | UComplexVector), _ :: _ :: _ - |(UInt | UReal | UComplex | UFun _ | UMathLibraryFunction), _ :: _ -> - Semantic_error.not_indexable loc ut 
(List.length indices) |> error in - aux ut (List.map ~f:indexing_type indices) - -let inferred_ad_type_of_indexed at uindices = - UnsizedType.lub_ad_type - ( at - :: List.map - ~f:(function - | All -> UnsizedType.DataOnly - | Single ue1 | Upfrom ue1 | Downfrom ue1 -> - UnsizedType.lub_ad_type [at; ue1.emeta.ad_level] - | Between (ue1, ue2) -> - UnsizedType.lub_ad_type - [at; ue1.emeta.ad_level; ue2.emeta.ad_level] ) - uindices ) - -(* function checking *) - -let verify_conddist_name loc id = - if - List.exists - ~f:(fun x -> String.is_suffix id.name ~suffix:x) - Utils.conditioning_suffices - then () - else Semantic_error.conditional_notation_not_allowed loc |> error - -let verify_fn_conditioning loc id = - if - List.exists - ~f:(fun suffix -> String.is_suffix id.name ~suffix) - Utils.conditioning_suffices - && not (String.is_suffix id.name ~suffix:"_cdf") - then Semantic_error.conditioning_required loc |> error - -(** `Target+=` can only be used in model and functions - with right suffix (same for tilde etc) -*) -let verify_fn_target_plus_equals cf loc id = - if - String.is_suffix id.name ~suffix:"_lp" - && not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then Semantic_error.target_plusequals_outside_model_or_logprob loc |> error - -(** Rng functions cannot be used in Tp or Model and only - in function defs with the right suffix -*) -let verify_fn_rng cf loc id = - if String.is_suffix id.name ~suffix:"_rng" && cf.in_toplevel_decl then - Semantic_error.invalid_decl_rng_fn loc |> error - else if - String.is_suffix id.name ~suffix:"_rng" - && ( (cf.in_fun_def && not cf.in_rng_fun_def) - || cf.current_block = TParam || cf.current_block = Model ) - then Semantic_error.invalid_rng_fn loc |> error - -(** unnormalized _lpdf/_lpmf functions can only be used in _lpdf/_lpmf/_lp udfs - or the model block -*) -let verify_unnormalized cf loc id = - if - Utils.is_unnormalized_distribution id.name - && not ((cf.in_fun_def && 
cf.in_udf_dist_def) || cf.current_block = Model) - then Semantic_error.invalid_unnormalized_fn loc |> error - -let mk_fun_app ~is_cond_dist ~loc kind name args ~type_ : Ast.typed_expression = - let fn = - if is_cond_dist then CondDistApp (kind, name, args) - else FunApp (kind, name, args) in - mk_typed_expression ~expr:fn ~loc ~type_ - ~ad_level: - ( if UnsizedType.is_int_type type_ then UnsizedType.DataOnly - else expr_ad_lub args ) - -let check_normal_fn ~is_cond_dist loc tenv id es = - match Env.find tenv (Utils.normalized_name id.name) with - | {kind= `Variable _; _} :: _ - (* variables can sometimes shadow stanlib functions, so we have to check this *) - when not - (Stan_math_signatures.is_stan_math_function_name - (Utils.normalized_name id.name) ) -> - Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error - | [] -> - ( match Utils.split_distribution_suffix id.name with - | Some (prefix, suffix) -> ( - let known_families = - List.map - ~f:(fun (_, y, _, _) -> y) - Stan_math_signatures.distributions in - let is_known_family s = - List.mem known_families s ~equal:String.equal in - match suffix with - | ("lpmf" | "lumpf") when Env.mem tenv (prefix ^ "_lpdf") -> - Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc - (prefix, suffix) - | ("lpdf" | "lumdf") when Env.mem tenv (prefix ^ "_lpmf") -> - Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc - (prefix, suffix) - | _ -> - if - is_known_family prefix - && List.mem ~equal:String.equal - Utils.cumulative_distribution_suffices_w_rng suffix - then - Semantic_error - .returning_fn_expected_undeclared_dist_suffix_found loc - (prefix, suffix) - else - Semantic_error.returning_fn_expected_undeclaredident_found loc - id.name - (Env.nearest_ident tenv id.name) ) - | None -> - Semantic_error.returning_fn_expected_undeclaredident_found loc id.name - (Env.nearest_ident tenv id.name) ) - |> error - | _ (* a function *) -> ( - (* NB: At present, [SignatureMismatch.matching_function] 
cannot handle overloaded function types. - This is not needed until UDFs can be higher-order, as it is special cased for - variadic functions - *) - match - SignatureMismatch.matching_function tenv id.name (get_arg_types es) - with - | UniqueMatch (Void, _, _) -> - Semantic_error.returning_fn_expected_nonreturning_found loc id.name - |> error - | UniqueMatch (ReturnType ut, fnk, promotions) -> - mk_fun_app ~is_cond_dist ~loc - (fnk (Fun_kind.suffix_from_name id.name)) - id - (Promotion.promote_list es promotions) - ~type_:ut - | AmbiguousMatch sigs -> - Semantic_error.ambiguous_function_promotion loc id.name - (Some (List.map ~f:type_of_expr_typed es)) - sigs - |> error - | SignatureErrors (l, b) -> - es - |> List.map ~f:(fun e -> e.emeta.type_) - |> Semantic_error.illtyped_fn_app loc id.name (l, b) - |> error ) - -(** Given a constraint function [matches], find any signature which exists - Returns the first [Ok] if any exist, or else [Error] -*) -let find_matching_first_order_fn tenv matches fname = - let candidates = - Utils.stdlib_distribution_name fname.name - |> Env.find tenv |> List.map ~f:matches in - let ok, errs = List.partition_map candidates ~f:Result.to_either in - match SignatureMismatch.unique_minimum_promotion ok with - | Ok a -> SignatureMismatch.UniqueMatch a - | Error (Some promotions) -> - List.filter_map promotions ~f:(function - | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) - | _ -> None ) - |> AmbiguousMatch - | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs) - -let make_function_variable cf loc id = function - | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | UnsizedType.UFun _ as type_ -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc 
- | type_ -> - Common.FatalError.fatal_error_msg - [%message - "Attempting to create function variable out of " - (type_ : UnsizedType.t)] - -let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) - = - if Stan_math_signatures.is_stan_math_variadic_function_name id.name then - check_variadic ~is_cond_dist loc cf tenv id tes - else if Stan_math_signatures.is_reduce_sum_fn id.name then - check_reduce_sum ~is_cond_dist loc cf tenv id tes - else check_normal_fn ~is_cond_dist loc tenv id tes - -(** Reduce sum is a special case, even compared to the other - variadic functions, because it is polymorphic in the type of the - first argument. The first, fourth, and fifth arguments must agree, - which is too complicated to be captured declaratively. *) -and check_reduce_sum ~is_cond_dist loc cf tenv id tes = - let basic_mismatch () = - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in - SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args - mandatory_fun_args UReal (get_arg_types tes) in - let matching remaining_es fn = - match fn with - | Env. 
- { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as - ftype - ; _ } - when List.mem Stan_math_signatures.reduce_sum_slice_types - sliced_arg_fun_type ~equal:( = ) -> - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args - mandatory_fun_args UReal arg_types - | _ -> basic_mismatch () in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - (* a valid signature exists *) - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_fun_app ~is_cond_dist ~loc (StanLib FnPlain) id - (Promotion.promote_list tes promotions) - ~type_:UnsizedType.UReal - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> - let expected_args, err = - basic_mismatch () |> Result.error |> Option.value_exn in - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error - -and check_variadic ~is_cond_dist loc cf tenv id tes = - let Stan_math_signatures. 
- {control_args; required_fn_args; required_fn_rt; return_type} = - Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures id.name - in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args - required_fn_args required_fn_rt arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_fun_app ~is_cond_dist ~loc (StanLib FnPlain) id - (Promotion.promote_list tes promotions) - ~type_:return_type - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args required_fn_rt err - |> error ) - | _ -> - let expected_args, err = - SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args - required_fn_args required_fn_rt (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args required_fn_rt err - |> error - -and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) = - let name_check = - if is_cond_dist then verify_conddist_name else verify_fn_conditioning in - let res = check_fn ~is_cond_dist loc cf tenv id es in - verify_identifier id ; - name_check loc id ; - verify_fn_target_plus_equals cf loc id ; - verify_fn_rng cf loc id ; - verify_unnormalized cf loc id ; - res - -and check_indexed loc cf tenv e indices = - let tindices = List.map ~f:(check_index cf tenv) indices in - let te = check_expression cf tenv e in - let ad_level = 
inferred_ad_type_of_indexed te.emeta.ad_level tindices in - let type_ = inferred_unsizedtype_of_indexed ~loc te.emeta.type_ tindices in - mk_typed_expression ~expr:(Indexed (te, tindices)) ~ad_level ~type_ ~loc - -and check_index cf tenv = function - | All -> All - (* Check that indexes have int (container) type *) - | Single e -> - let te = check_expression cf tenv e in - if has_int_type te || has_int_array_type te then Single te - else - Semantic_error.int_intarray_or_range_expected te.emeta.loc - te.emeta.type_ - |> error - | Upfrom e -> check_expression_of_int_type cf tenv e "Range bound" |> Upfrom - | Downfrom e -> - check_expression_of_int_type cf tenv e "Range bound" |> Downfrom - | Between (e1, e2) -> - let le = check_expression_of_int_type cf tenv e1 "Range bound" in - let ue = check_expression_of_int_type cf tenv e2 "Range bound" in - Between (le, ue) - -and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : - Ast.typed_expression = - let loc = emeta.loc in - let ce = check_expression cf tenv in - match expr with - | TernaryIf (e1, e2, e3) -> - let pe = ce e1 in - let te = ce e2 in - let fe = ce e3 in - check_ternary_if loc pe te fe - | BinOp (e1, op, e2) -> - let le = ce e1 in - let re = ce e2 in - let binop_type_warnings x y = - match (x.emeta.type_, y.emeta.type_, op) with - | UInt, UInt, Divide -> - let hint ppf () = - match (x.expr, y.expr) with - | IntNumeral x, _ -> - Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression y - | _, Ast.IntNumeral y -> - Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x y - | _ -> - Fmt.pf ppf "%a * 1.0 / %a" Pretty_printing.pp_typed_expression - x Pretty_printing.pp_typed_expression y in - let s = - Fmt.str - "@[@[Found int division:@]@ @[%a@]@,\ - @[%a@]@ @[%a@]@,\ - @[%a@]@]" - Pretty_printing.pp_expression {expr; emeta} Fmt.text - "Values will be rounded towards zero. 
If rounding is not \ - desired you can write the division as" - hint () Fmt.text - "If rounding is intended please use the integer division \ - operator %/%." in - add_warning x.emeta.loc s - | (UArray UMatrix | UMatrix), (UInt | UReal), Pow -> - let s = - Fmt.str - "@[@[Found matrix^scalar:@]@ @[%a@]@,\ - @[%a@]@ @[%a@]@]" Pretty_printing.pp_expression - {expr; emeta} Fmt.text - "matrix ^ number is interpreted as element-wise \ - exponentiation. If this is intended, you can silence this \ - warning by using elementwise operator .^" - Fmt.text - "If you intended matrix exponentiation, use the function \ - matrix_power(matrix,int) instead." in - add_warning x.emeta.loc s - | _ when Operator.is_cmp op -> ( - match le.expr with - | BinOp (e1, op2, e2) when Operator.is_cmp op2 -> - let pp_e = Pretty_printing.pp_typed_expression in - let pp = Operator.pp in - add_warning loc - (Fmt.str - "Found %a. This is interpreted as %a. Consider if the \ - intended meaning was %a instead.@ You can silence this \ - warning by adding explicit parenthesis. 
This can be \ - automatically changed using the canonicalize flag for \ - stanc" - (fun ppf () -> - Fmt.pf ppf "@[%a %a %a@]" pp_e le pp op2 pp_e re ) - () - (fun ppf () -> - Fmt.pf ppf "@[(%a) %a %a@]" pp_e le pp op2 pp_e re ) - () - (fun ppf () -> - Fmt.pf ppf "@[%a %a %a && %a %a %a@]" pp_e e1 pp op - pp_e e2 pp_e e2 pp op2 pp_e re ) - () ) - | _ -> () ) - | _ -> () in - binop_type_warnings le re ; check_binop loc op le re - | PrefixOp (op, e) -> ce e |> check_prefixop loc op - | PostfixOp (e, op) -> ce e |> check_postfixop loc op - | Variable id -> - verify_identifier id ; - check_variable cf loc tenv id - | IntNumeral s -> ( - match float_of_string_opt s with - | Some i when i < 2_147_483_648.0 -> - mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly ~type_:UInt - ~loc - | _ -> Semantic_error.bad_int_literal loc |> error ) - | RealNumeral s -> - mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal - ~loc - | ImagNumeral s -> - mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly - ~type_:UComplex ~loc - | GetLP -> - (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) - if - not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then - Semantic_error.target_plusequals_outside_model_or_logprob loc |> error - else - mk_typed_expression ~expr:GetLP - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) - ~type_:UReal ~loc - | GetTarget -> - (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) - if - not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then - Semantic_error.target_plusequals_outside_model_or_logprob loc |> error - else - mk_typed_expression ~expr:GetTarget - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) - ~type_:UReal ~loc - | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc - | RowVectorExpr es -> es |> List.map ~f:ce |> 
check_rowvector loc - | Paren e -> - let te = ce e in - mk_typed_expression ~expr:(Paren te) ~ad_level:te.emeta.ad_level - ~type_:te.emeta.type_ ~loc - | Indexed (e, indices) -> check_indexed loc cf tenv e indices - | FunApp ((), id, es) -> - es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id - | CondDistApp ((), id, es) -> - es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id - | Promotion (e, _, _) -> - (* Should never happen: promotions are produced during typechecking *) - Common.FatalError.fatal_error_msg - [%message "Promotion in untyped AST" (e : Ast.untyped_expression)] - -and check_expression_of_int_type cf tenv e name = - let te = check_expression cf tenv e in - if has_int_type te then te - else Semantic_error.int_expected te.emeta.loc name te.emeta.type_ |> error - -let check_expression_of_int_or_real_type cf tenv e name = - let te = check_expression cf tenv e in - if has_int_or_real_type te then te - else - Semantic_error.int_or_real_expected te.emeta.loc name te.emeta.type_ - |> error - -let check_expression_of_scalar_or_type cf tenv t e name = - let te = check_expression cf tenv e in - if UnsizedType.is_scalar_type te.emeta.type_ || te.emeta.type_ = t then te - else - Semantic_error.scalar_or_type_expected te.emeta.loc name t te.emeta.type_ - |> error - -(* -- Statements ------------------------------------------------- *) -(* non returning functions *) -let verify_nrfn_target loc cf id = - if - String.is_suffix id.name ~suffix:"_lp" - && not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then Semantic_error.target_plusequals_outside_model_or_logprob loc |> error - -let check_nrfn loc tenv id es = - match Env.find tenv id.name with - | {kind= `Variable _; _} :: _ - (* variables can shadow stanlib functions, so we have to check this *) - when not (Stan_math_signatures.is_stan_math_function_name id.name) -> - Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> 
error - | [] -> - Semantic_error.nonreturning_fn_expected_undeclaredident_found loc id.name - (Env.nearest_ident tenv id.name) - |> error - | _ (* a function *) -> ( - match - SignatureMismatch.matching_function tenv id.name (get_arg_types es) - with - | UniqueMatch (Void, fnk, promotions) -> - mk_typed_statement - ~stmt: - (NRFunApp - ( fnk (Fun_kind.suffix_from_name id.name) - , id - , Promotion.promote_list es promotions ) ) - ~return_type:NoReturnType ~loc - | UniqueMatch (ReturnType _, _, _) -> - Semantic_error.nonreturning_fn_expected_returning_found loc id.name - |> error - | AmbiguousMatch sigs -> - Semantic_error.ambiguous_function_promotion loc id.name - (Some (List.map ~f:type_of_expr_typed es)) - sigs - |> error - | SignatureErrors (l, b) -> - es - |> List.map ~f:type_of_expr_typed - |> Semantic_error.illtyped_fn_app loc id.name (l, b) - |> error ) - -let check_nr_fn_app loc cf tenv id es = - let tes = List.map ~f:(check_expression cf tenv) es in - verify_identifier id ; - verify_nrfn_target loc cf id ; - check_nrfn loc tenv id tes - -(* assignments *) -let verify_assignment_read_only loc is_readonly id = - if is_readonly then - Semantic_error.cannot_assign_to_read_only loc id.name |> error - -(* Variables from previous blocks are read-only. 
- In particular, data and parameters never assigned to -*) -let verify_assignment_global loc cf block is_global id = - if (not is_global) || block = cf.current_block then () - else Semantic_error.cannot_assign_to_global loc id.name |> error - -(* Until function types are added to the user language, we - disallow assignments to function values -*) -let verify_assignment_non_function loc ut id = - match ut with - | UnsizedType.UFun _ | UMathLibraryFunction -> - Semantic_error.cannot_assign_function loc ut id.name |> error - | _ -> () - -let check_assignment_operator loc assop lhs rhs = - let err op = - Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ - in - match assop with - | Assign | ArrowAssign -> ( - match - SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ - rhs.emeta.type_ - with - | Ok p -> Promotion.promote rhs p - | Error _ -> err Operator.Equals |> error ) - | OperatorAssign op -> ( - let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in - let return_type = assignmentoperator_stan_math_return_type op args in - match return_type with Some Void -> rhs | _ -> err op |> error ) - -let check_lvalue cf tenv = function - | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> - verify_identifier id ; - let ad_level, type_ = check_id cf loc tenv id in - {lval= LVariable id; lmeta= {ad_level; type_; loc}} - | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> - let rec check_inner = function - | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> - verify_identifier id ; - let ad_level, type_ = check_id cf loc tenv id in - let var = {lval= LVariable id; lmeta= {ad_level; type_; loc}} in - (var, var, []) - | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> - let lval, var, flat = check_inner lval in - let idcs = List.map ~f:(check_index cf tenv) idcs in - let ad_level = - inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in - let type_ = - inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in - ( {lval= 
LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} - , var - , flat @ idcs ) in - let lval, var, flat = check_inner lval in - let idcs = List.map ~f:(check_index cf tenv) idcs in - let ad_level = inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in - let type_ = inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in - if List.exists ~f:is_multiindex flat then ( - add_warning loc - "Nested multi-indexing on the left hand side of assignment does not \ - behave the same as nested indexing in expressions. This is \ - considered a bug and will be disallowed in Stan 2.33.0. The \ - indexing can be automatically fixed using the canonicalize flag for \ - stanc." ; - let lvalue_rvalue_types_differ = - try - let flat_type = - inferred_unsizedtype_of_indexed ~loc var.lmeta.type_ (flat @ idcs) - in - let rec can_assign = function - | UnsizedType.(UArray t1, UArray t2) -> can_assign (t1, t2) - | UVector, URowVector | URowVector, UVector -> false - | t1, t2 -> UnsizedType.compare t1 t2 <> 0 in - can_assign (flat_type, type_) - with Errors.SemanticError _ -> true in - if lvalue_rvalue_types_differ then - Semantic_error.cannot_assign_to_multiindex loc |> error ) ; - {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} - -let check_assignment loc cf tenv assign_lhs assign_op assign_rhs = - let assign_id = Ast.id_of_lvalue assign_lhs in - let lhs = check_lvalue cf tenv assign_lhs in - let rhs = check_expression cf tenv assign_rhs in - let block, global, readonly = - let var = Env.find tenv assign_id.name in - match var with - | {kind= `Variable {origin; global; readonly}; _} :: _ -> - (origin, global, readonly) - | {kind= `StanMath; _} :: _ -> (MathLibrary, true, false) - | {kind= `UserDefined | `UserDeclared _; _} :: _ -> (Functions, true, false) - | _ -> - Semantic_error.ident_not_in_scope loc assign_id.name - (Env.nearest_ident tenv assign_id.name) - |> error in - verify_assignment_global loc cf block global assign_id ; - verify_assignment_read_only loc 
readonly assign_id ; - verify_assignment_non_function loc rhs.emeta.type_ assign_id ; - let rhs' = check_assignment_operator loc assign_op lhs rhs in - mk_typed_statement ~return_type:NoReturnType ~loc - ~stmt:(Assignment {assign_lhs= lhs; assign_op; assign_rhs= rhs'}) - -(* target plus-equals / increment log-prob *) - -let verify_target_pe_expr_type loc e = - if UnsizedType.is_fun_type e.emeta.type_ then - Semantic_error.int_or_real_container_expected loc e.emeta.type_ |> error - -let verify_target_pe_usage loc cf = - if cf.in_lp_fun_def || cf.current_block = Model then () - else Semantic_error.target_plusequals_outside_model_or_logprob loc |> error - -let check_target_pe loc cf tenv e = - let te = check_expression cf tenv e in - verify_target_pe_usage loc cf ; - verify_target_pe_expr_type loc te ; - mk_typed_statement ~stmt:(TargetPE te) ~return_type:NoReturnType ~loc - -let check_incr_logprob loc cf tenv e = - let te = check_expression cf tenv e in - verify_target_pe_usage loc cf ; - verify_target_pe_expr_type loc te ; - mk_typed_statement ~stmt:(IncrementLogProb te) ~return_type:NoReturnType ~loc - -(* tilde/sampling notation*) -let verify_sampling_pdf_pmf id = - if - String.( - is_suffix id.name ~suffix:"_lpdf" - || is_suffix id.name ~suffix:"_lpmf" - || is_suffix id.name ~suffix:"_lupdf" - || is_suffix id.name ~suffix:"_lupmf") - then Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc |> error - -let verify_sampling_cdf_ccdf loc id = - if - String.( - is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf") - then Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name |> error - -(* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) -let verify_valid_sampling_pos loc cf = - if cf.in_lp_fun_def || cf.current_block = Model then () - else Semantic_error.target_plusequals_outside_model_or_logprob loc |> error - -let verify_sampling_distribution loc tenv id arguments = - let name = id.name in - let 
argumenttypes = List.map ~f:arg_type arguments in - let name_w_suffix_sampling_dist suffix = - SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes in - let sampling_dists = - List.map ~f:name_w_suffix_sampling_dist Utils.distribution_suffices in - let is_sampling_dist_defined = - List.exists - ~f:(function UniqueMatch (ReturnType UReal, _, _) -> true | _ -> false) - sampling_dists - && name <> "binomial_coefficient" - && name <> "multiply" in - if is_sampling_dist_defined then () - else - match - List.max_elt sampling_dists - ~compare:SignatureMismatch.compare_match_results - with - | None | Some (UniqueMatch _) | Some (SignatureErrors ([], _)) -> - (* Either non-existant or a very odd case, - output the old non-informative error *) - Semantic_error.invalid_sampling_no_such_dist loc name |> error - | Some (AmbiguousMatch sigs) -> - Semantic_error.ambiguous_function_promotion loc id.name - (Some (List.map ~f:type_of_expr_typed arguments)) - sigs - |> error - | Some (SignatureErrors (l, b)) -> - arguments - |> List.map ~f:(fun e -> e.emeta.type_) - |> Semantic_error.illtyped_fn_app loc id.name (l, b) - |> error - -let is_cumulative_density_defined tenv id arguments = - let name = id.name in - let argumenttypes = List.map ~f:arg_type arguments in - let valid_arg_types_for_suffix suffix = - match - SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes - with - | UniqueMatch (ReturnType UReal, _, _) -> true - | _ -> false in - (valid_arg_types_for_suffix "_lcdf" || valid_arg_types_for_suffix "_cdf_log") - && ( valid_arg_types_for_suffix "_lccdf" - || valid_arg_types_for_suffix "_ccdf_log" ) - -let verify_sampling_cdf_defined loc tenv id truncation args = - let check e = - if not (is_cumulative_density_defined tenv id (e :: args)) then - Semantic_error.invalid_truncation_cdf_or_ccdf loc - (get_arg_types (e :: args)) - |> error in - match truncation with - | NoTruncate -> () - | TruncateUpFrom e | TruncateDownFrom e -> check e - | 
TruncateBetween (e1, e2) -> check e1 ; check e2 - -let check_truncation cf tenv truncation = - let check e = - check_expression_of_int_or_real_type cf tenv e "Truncation bound" in - match truncation with - | NoTruncate -> NoTruncate - | TruncateUpFrom e -> check e |> TruncateUpFrom - | TruncateDownFrom e -> check e |> TruncateDownFrom - | TruncateBetween (e1, e2) -> (check e1, check e2) |> TruncateBetween - -let check_tilde loc cf tenv distribution truncation arg args = - let te = check_expression cf tenv arg in - let tes = List.map ~f:(check_expression cf tenv) args in - let ttrunc = check_truncation cf tenv truncation in - verify_identifier distribution ; - verify_sampling_pdf_pmf distribution ; - verify_valid_sampling_pos loc cf ; - verify_sampling_cdf_ccdf loc distribution ; - verify_sampling_distribution loc tenv distribution (te :: tes) ; - verify_sampling_cdf_defined loc tenv distribution ttrunc tes ; - let stmt = Tilde {arg= te; distribution; args= tes; truncation= ttrunc} in - mk_typed_statement ~stmt ~loc ~return_type:NoReturnType - -(* Break and continue only occur in loops. 
*) -let check_break loc cf = - if cf.loop_depth = 0 then Semantic_error.break_outside_loop loc |> error - else mk_typed_statement ~stmt:Break ~return_type:NoReturnType ~loc - -let check_continue loc cf = - if cf.loop_depth = 0 then Semantic_error.continue_outside_loop loc |> error - else mk_typed_statement ~stmt:Continue ~return_type:NoReturnType ~loc - -let check_return loc cf tenv e = - if not cf.in_returning_fun_def then - Semantic_error.expression_return_outside_returning_fn loc |> error - else - let te = check_expression cf tenv e in - mk_typed_statement ~stmt:(Return te) - ~return_type:(Complete (ReturnType te.emeta.type_)) ~loc - -let check_returnvoid loc cf = - if (not cf.in_fun_def) || cf.in_returning_fun_def then - Semantic_error.void_outside_nonreturning_fn loc |> error - else mk_typed_statement ~stmt:ReturnVoid ~return_type:(Complete Void) ~loc - -let check_printable cf tenv = function - | PString s -> PString s - (* Print/reject expressions cannot be of function type. *) - | PExpr e -> ( - let te = check_expression cf tenv e in - match te.emeta.type_ with - | UFun _ | UMathLibraryFunction -> - Semantic_error.not_printable te.emeta.loc |> error - | _ -> PExpr te ) - -let check_print loc cf tenv ps = - let tps = List.map ~f:(check_printable cf tenv) ps in - mk_typed_statement ~stmt:(Print tps) ~return_type:NoReturnType ~loc - -let check_reject loc cf tenv ps = - let tps = List.map ~f:(check_printable cf tenv) ps in - mk_typed_statement ~stmt:(Reject tps) ~return_type:AnyReturnType ~loc - -let check_skip loc = - mk_typed_statement ~stmt:Skip ~return_type:NoReturnType ~loc - -let rec stmt_is_escape {stmt; _} = - match stmt with - | Break | Continue | Reject _ | Return _ | ReturnVoid -> true - | _ -> false - -and list_until_escape xs = - let rec aux accu = function - | [next; next'] when stmt_is_escape next' -> List.rev (next' :: next :: accu) - | next :: next' :: unreachable :: _ when stmt_is_escape next' -> - add_warning unreachable.smeta.loc - 
"Unreachable statement (following a reject, break, continue, or \ - return) found, is this intended?" ; - List.rev (next' :: next :: accu) - | next :: rest -> aux (next :: accu) rest - | [] -> List.rev accu in - aux [] xs - -let returntype_leastupperbound loc rt1 rt2 = - match (rt1, rt2) with - | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt - |ReturnType UInt, ReturnType UReal -> - UnsizedType.ReturnType UReal - | _, _ when rt1 = rt2 -> rt2 - | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> error - -let try_compute_block_statement_returntype loc srt1 srt2 = - match (srt1, srt2) with - | Complete rt1, Complete rt2 | Incomplete rt1, Complete rt2 -> - Complete (returntype_leastupperbound loc rt1 rt2) - | Incomplete rt1, Incomplete rt2 | Complete rt1, Incomplete rt2 -> - Incomplete (returntype_leastupperbound loc rt1 rt2) - | NoReturnType, NoReturnType -> NoReturnType - | AnyReturnType, Incomplete rt - |Complete rt, NoReturnType - |NoReturnType, Incomplete rt - |Incomplete rt, NoReturnType -> - Incomplete rt - | NoReturnType, Complete rt - |Complete rt, AnyReturnType - |Incomplete rt, AnyReturnType - |AnyReturnType, Complete rt -> - Complete rt - | AnyReturnType, NoReturnType - |NoReturnType, AnyReturnType - |AnyReturnType, AnyReturnType -> - AnyReturnType - -let try_compute_ifthenelse_statement_returntype loc srt1 srt2 = - match (srt1, srt2) with - | Complete rt1, Complete rt2 -> - returntype_leastupperbound loc rt1 rt2 |> Complete - | Incomplete rt1, Incomplete rt2 - |Complete rt1, Incomplete rt2 - |Incomplete rt1, Complete rt2 -> - returntype_leastupperbound loc rt1 rt2 |> Incomplete - | AnyReturnType, NoReturnType - |NoReturnType, AnyReturnType - |NoReturnType, NoReturnType -> - NoReturnType - | AnyReturnType, Incomplete rt - |Incomplete rt, AnyReturnType - |Complete rt, NoReturnType - |NoReturnType, Complete rt - |NoReturnType, Incomplete rt - |Incomplete rt, NoReturnType -> - Incomplete rt - | Complete rt, AnyReturnType | AnyReturnType, 
Complete rt -> Complete rt - | AnyReturnType, AnyReturnType -> AnyReturnType - -(* statements which contain statements, and therefore need to be mutually recursive - with check_statement -*) -let rec check_if_then_else loc cf tenv pred_e s_true s_false_opt = - (* we don't need these nested type environments *) - let _, ts_true = check_statement cf tenv s_true in - let ts_false_opt = - s_false_opt |> Option.map ~f:(check_statement cf tenv) |> Option.map ~f:snd - in - let te = - check_expression_of_int_or_real_type cf tenv pred_e - "Condition in conditional" in - let stmt = IfThenElse (te, ts_true, ts_false_opt) in - let srt1 = ts_true.smeta.return_type in - let srt2 = - ts_false_opt - |> Option.map ~f:(fun s -> s.smeta.return_type) - |> Option.value ~default:NoReturnType in - let return_type = try_compute_ifthenelse_statement_returntype loc srt1 srt2 in - mk_typed_statement ~stmt ~return_type ~loc - -and check_while loc cf tenv cond_e loop_body = - let _, ts = - check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body - and te = - check_expression_of_int_or_real_type cf tenv cond_e - "Condition in while-loop" in - mk_typed_statement - ~stmt:(While (te, ts)) - ~return_type:ts.smeta.return_type ~loc - -and check_for loc cf tenv loop_var lower_bound_e upper_bound_e loop_body = - let te1 = - check_expression_of_int_type cf tenv lower_bound_e "Lower bound of for-loop" - and te2 = - check_expression_of_int_type cf tenv upper_bound_e "Upper bound of for-loop" - in - verify_identifier loop_var ; - let ts = check_loop_body cf tenv loop_var UnsizedType.UInt loop_body in - mk_typed_statement - ~stmt: - (For - { loop_variable= loop_var - ; lower_bound= te1 - ; upper_bound= te2 - ; loop_body= ts } ) - ~return_type:ts.smeta.return_type ~loc - -and check_foreach_loop_identifier_type loc ty = - match ty with - | UnsizedType.UArray ut -> ut - | UVector | URowVector | UMatrix -> UnsizedType.UReal - | _ -> Semantic_error.array_vector_rowvector_matrix_expected loc ty |> 
error - -and check_foreach loc cf tenv loop_var foreach_e loop_body = - let te = check_expression cf tenv foreach_e in - verify_identifier loop_var ; - let loop_var_ty = - check_foreach_loop_identifier_type te.emeta.loc te.emeta.type_ in - let ts = check_loop_body cf tenv loop_var loop_var_ty loop_body in - mk_typed_statement - ~stmt:(ForEach (loop_var, te, ts)) - ~return_type:ts.smeta.return_type ~loc - -and check_loop_body cf tenv loop_var loop_var_ty loop_body = - verify_name_fresh tenv loop_var ~is_udf:false ; - (* Add to type environment as readonly. - Check that function args and loop identifiers are not modified in - function. (passed by const ref) - *) - let tenv = - Env.add tenv loop_var.name loop_var_ty - (`Variable {origin= cf.current_block; global= false; readonly= true}) - in - snd (check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body) - -and check_block loc cf tenv stmts = - let _, checked_stmts = - List.fold_map stmts ~init:tenv ~f:(check_statement cf) in - let return_type = - checked_stmts |> list_until_escape - |> List.map ~f:(fun s -> s.smeta.return_type) - |> List.fold ~init:NoReturnType - ~f:(try_compute_block_statement_returntype loc) in - mk_typed_statement ~stmt:(Block checked_stmts) ~return_type ~loc - -and check_profile loc cf tenv name stmts = - let _, checked_stmts = - List.fold_map stmts ~init:tenv ~f:(check_statement cf) in - let return_type = - checked_stmts |> list_until_escape - |> List.map ~f:(fun s -> s.smeta.return_type) - |> List.fold ~init:NoReturnType - ~f:(try_compute_block_statement_returntype loc) in - mk_typed_statement ~stmt:(Profile (name, checked_stmts)) ~return_type ~loc - -(* variable declarations *) -and verify_valid_transformation_for_type loc is_global sized_ty trans = - let is_real {emeta; _} = emeta.type_ = UReal in - let is_real_transformation = - match trans with - | Transformation.Lower e -> is_real e - | Upper e -> is_real e - | LowerUpper (e1, e2) -> is_real e1 || is_real e2 - | _ -> false 
in - if is_global && sized_ty = SizedType.SInt && is_real_transformation then - Semantic_error.non_int_bounds loc |> error ; - let is_transformation = - match trans with Transformation.Identity -> false | _ -> true in - if is_global && SizedType.(contains_complex sized_ty) && is_transformation - then Semantic_error.complex_transform loc |> error - -and verify_transformed_param_ty loc cf is_global unsized_ty = - if - is_global - && (cf.current_block = Param || cf.current_block = TParam) - && UnsizedType.is_int_type unsized_ty - then Semantic_error.transformed_params_int loc |> error - -and check_sizedtype cf tenv sizedty = - let check e msg = check_expression_of_int_type cf tenv e msg in - match sizedty with - | SizedType.SInt -> SizedType.SInt - | SReal -> SReal - | SComplex -> SComplex - | SVector (mem_pattern, e) -> - let te = check e "Vector sizes" in - SVector (mem_pattern, te) - | SRowVector (mem_pattern, e) -> - let te = check e "Row vector sizes" in - SRowVector (mem_pattern, te) - | SMatrix (mem_pattern, e1, e2) -> - let te1 = check e1 "Matrix row size" in - let te2 = check e2 "Matrix column size" in - SMatrix (mem_pattern, te1, te2) - | SComplexVector e -> - let te = check e "complex vector sizes" in - SComplexVector te - | SComplexRowVector e -> - let te = check e "complex row vector sizes" in - SComplexRowVector te - | SComplexMatrix (e1, e2) -> - let te1 = check e1 "Complex matrix row size" in - let te2 = check e2 "Complex matrix column size" in - SComplexMatrix (te1, te2) - | SArray (st, e) -> - let tst = check_sizedtype cf tenv st in - let te = check e "Array sizes" in - SArray (tst, te) - -and check_var_decl_initial_value loc cf tenv {identifier; initial_value} = - match initial_value with - | Some e -> ( - let lhs = - check_lvalue cf tenv {lval= LVariable identifier; lmeta= {loc}} in - let rhs = check_expression cf tenv e in - match - SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ - rhs.emeta.type_ - with - | Ok p -> Ast.{identifier; 
initial_value= Some (Promotion.promote rhs p)} - | Error _ -> - Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ - rhs.emeta.type_ - |> error ) - | None -> Ast.{identifier; initial_value= None} - -and check_transformation cf tenv ut trans = - let check e msg = check_expression_of_scalar_or_type cf tenv ut e msg in - match trans with - | Transformation.Identity -> Transformation.Identity - | Lower e -> check e "Lower bound" |> Lower - | Upper e -> check e "Upper bound" |> Upper - | LowerUpper (e1, e2) -> - (check e1 "Lower bound", check e2 "Upper bound") |> LowerUpper - | Offset e -> check e "Offset" |> Offset - | Multiplier e -> check e "Multiplier" |> Multiplier - | OffsetMultiplier (e1, e2) -> - (check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier - | Ordered -> Ordered - | PositiveOrdered -> PositiveOrdered - | Simplex -> Simplex - | UnitVector -> UnitVector - | CholeskyCorr -> CholeskyCorr - | CholeskyCov -> CholeskyCov - | Correlation -> Correlation - | Covariance -> Covariance - -and check_var_decl loc cf tenv sized_ty trans - (variables : untyped_expression Ast.variable list) is_global = - let checked_type = - check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty in - let unsized_type = SizedType.to_unsized checked_type in - let checked_trans = check_transformation cf tenv unsized_type trans in - let tenv, tvariables = - List.fold_map ~init:tenv - ~f:(fun tenv' ({identifier; _} as var) -> - verify_identifier identifier ; - verify_name_fresh tenv' identifier ~is_udf:false ; - let tenv'' = - Env.add tenv' identifier.name unsized_type - (`Variable - {origin= cf.current_block; global= is_global; readonly= false} ) - in - (tenv'', check_var_decl_initial_value loc cf tenv'' var) ) - variables in - verify_valid_transformation_for_type loc is_global checked_type checked_trans ; - verify_transformed_param_ty loc cf is_global unsized_type ; - let stmt = - VarDecl - { decl_type= checked_type - ; transformation= checked_trans - ; 
variables= tvariables - ; is_global } in - (tenv, mk_typed_statement ~stmt ~loc ~return_type:NoReturnType) - -(* function definitions *) -and exists_matching_fn_declared tenv id arg_tys rt = - let options = - List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) - in - let f = function - | Env.{kind= `UserDeclared _; type_= UFun (listedtypes, rt', _, _)} - when arg_tys = listedtypes && rt = rt' -> - true - | _ -> false in - List.exists ~f options - -and verify_unique_signature tenv loc id arg_tys rt = - let existing = - List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) - in - let same_args = function - | Env.{type_= UFun (listedtypes, _, _, _); _} - when List.map ~f:snd arg_tys = List.map ~f:snd listedtypes -> - true - | _ -> false in - match List.filter existing ~f:same_args with - | [] -> () - | {type_= UFun (_, rt', _, _); _} :: _ when rt <> rt' -> - Semantic_error.fn_overload_rt_only loc id.name rt rt' |> error - | {kind; _} :: _ -> - Semantic_error.fn_decl_redefined loc id.name - ~stan_math:(kind = `StanMath) - (UnsizedType.UFun (arg_tys, rt, Fun_kind.suffix_from_name id.name, AoS)) - |> error - -and verify_fundef_overloaded loc tenv id arg_tys rt = - if exists_matching_fn_declared tenv id arg_tys rt then - (* this is the definition to an existing forward declaration *) - () - else - (* this should be an overload with a unique signature *) - verify_unique_signature tenv loc id arg_tys rt ; - verify_name_fresh tenv id ~is_udf:true - -and get_fn_decl_or_defn loc tenv id arg_tys rt body = - match body with - | {stmt= Skip; _} -> - if exists_matching_fn_declared tenv id arg_tys rt then - Semantic_error.fn_decl_exists loc id.name |> error - else `UserDeclared id.id_loc - | _ -> `UserDefined - -and verify_fundef_dist_rt loc id return_ty = - let is_dist = - List.exists - ~f:(fun x -> String.is_suffix id.name ~suffix:x) - Utils.conditioning_suffices_w_log in - if is_dist then - match return_ty with - | UnsizedType.ReturnType 
UReal -> () - | _ -> Semantic_error.non_real_prob_fn_def loc |> error - -and verify_pdf_fundef_first_arg_ty loc id arg_tys = - if String.is_suffix id.name ~suffix:"_lpdf" then - let rt = List.hd arg_tys |> Option.map ~f:snd in - match rt with - | Some rt when not (UnsizedType.is_int_type rt) -> () - | _ -> Semantic_error.prob_density_non_real_variate loc rt |> error - -and verify_pmf_fundef_first_arg_ty loc id arg_tys = - if String.is_suffix id.name ~suffix:"_lpmf" then - let rt = List.hd arg_tys |> Option.map ~f:snd in - match rt with - | Some rt when UnsizedType.is_int_type rt -> () - | _ -> Semantic_error.prob_mass_non_int_variate loc rt |> error - -and verify_fundef_distinct_arg_ids loc arg_names = - let dup_exists l = - List.find_a_dup ~compare:String.compare l |> Option.is_some in - if dup_exists arg_names then Semantic_error.duplicate_arg_names loc |> error - -and verify_fundef_return_tys loc return_type body = - if - body.stmt = Skip - || is_of_compatible_return_type return_type body.smeta.return_type - then () - else Semantic_error.incompatible_return_types loc |> error - -and add_function tenv name type_ defined = - (* if we're providing a definition, we remove prior declarations - to simplify the environment *) - if defined = `UserDefined then - let existing_defns = Env.find tenv name in - let defns = - List.filter - ~f:(function - | Env.{kind= `UserDeclared _; type_= type'} when type' = type_ -> - false - | _ -> true ) - existing_defns in - let new_fn = Env.{kind= `UserDefined; type_} in - Env.set_raw tenv name (new_fn :: defns) - else Env.add tenv name type_ defined - -and check_fundef loc cf tenv return_ty id args body = - List.iter args ~f:(fun (_, _, id) -> verify_identifier id) ; - verify_identifier id ; - let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in - let arg_identifiers = List.map ~f:(fun (_, _, z) -> z) args in - let arg_names = List.map ~f:(fun x -> x.name) arg_identifiers in - verify_fundef_dist_rt loc id return_ty ; - 
verify_pdf_fundef_first_arg_ty loc id arg_types ; - verify_pmf_fundef_first_arg_ty loc id arg_types ; - List.iter - ~f:(fun id -> verify_name_fresh tenv id ~is_udf:false) - arg_identifiers ; - verify_fundef_distinct_arg_ids loc arg_names ; - (* We treat DataOnly arguments as if they are data and AutoDiffable arguments - as if they are parameters, for the purposes of type checking. - *) - let arg_types_internal = - List.map - ~f:(function - | UnsizedType.DataOnly, ut -> (Env.Data, ut) - | AutoDiffable, ut -> (Param, ut) ) - arg_types in - let tenv_body = - List.fold2_exn arg_names arg_types_internal ~init:tenv - ~f:(fun env name (origin, typ) -> - Env.add env name typ - (* readonly so that function args and loop identifiers - are not modified in function. (passed by const ref) *) - (`Variable {origin; readonly= true; global= false}) ) in - let context = - let is_udf_dist name = - List.exists - ~f:(fun suffix -> String.is_suffix name ~suffix) - Utils.distribution_suffices in - { cf with - in_fun_def= true - ; in_rng_fun_def= String.is_suffix id.name ~suffix:"_rng" - ; in_lp_fun_def= String.is_suffix id.name ~suffix:"_lp" - ; in_udf_dist_def= is_udf_dist id.name - ; in_returning_fun_def= return_ty <> Void } in - let _, checked_body = check_statement context tenv_body body in - verify_fundef_return_tys loc return_ty checked_body ; - let stmt = - FunDef - {returntype= return_ty; funname= id; arguments= args; body= checked_body} - in - (* NB: **not** tenv_body, so args don't leak out *) - (tenv, mk_typed_statement ~return_type:NoReturnType ~loc ~stmt) - -and check_statement (cf : context_flags_record) (tenv : Env.t) - (s : Ast.untyped_statement) : Env.t * typed_statement = - let loc = s.smeta.loc in - match s.stmt with - | NRFunApp (_, id, es) -> (tenv, check_nr_fn_app loc cf tenv id es) - | Assignment {assign_lhs; assign_op; assign_rhs} -> - (tenv, check_assignment loc cf tenv assign_lhs assign_op assign_rhs) - | TargetPE e -> (tenv, check_target_pe loc cf tenv e) - | 
IncrementLogProb e -> (tenv, check_incr_logprob loc cf tenv e) - | Tilde {arg; distribution; args; truncation} -> - (tenv, check_tilde loc cf tenv distribution truncation arg args) - | Break -> (tenv, check_break loc cf) - | Continue -> (tenv, check_continue loc cf) - | Return e -> (tenv, check_return loc cf tenv e) - | ReturnVoid -> (tenv, check_returnvoid loc cf) - | Print ps -> (tenv, check_print loc cf tenv ps) - | Reject ps -> (tenv, check_reject loc cf tenv ps) - | Skip -> (tenv, check_skip loc) - (* the following can contain further statements *) - | IfThenElse (e, s1, os2) -> (tenv, check_if_then_else loc cf tenv e s1 os2) - | While (e, s) -> (tenv, check_while loc cf tenv e s) - | For {loop_variable; lower_bound; upper_bound; loop_body} -> - ( tenv - , check_for loc cf tenv loop_variable lower_bound upper_bound loop_body ) - | ForEach (id, e, s) -> (tenv, check_foreach loc cf tenv id e s) - | Block stmts -> (tenv, check_block loc cf tenv stmts) - | Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl) - (* these two are special in that they're allowed to change the type environment *) - | VarDecl {decl_type; transformation; variables; is_global} -> - check_var_decl loc cf tenv decl_type transformation variables is_global - | FunDef {returntype; funname; arguments; body} -> - check_fundef loc cf tenv returntype funname arguments body - -let verify_fun_def_body_in_block = function - | {stmt= FunDef {body= {stmt= Block _; _}; _}; _} - |{stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> - () - | {stmt= FunDef {body= {stmt= _; smeta}; _}; _} -> - Semantic_error.fn_decl_needs_block smeta.loc |> error - | _ -> () - -let verify_functions_have_defn tenv function_block_stmts_opt = - let error_on_undefined name funs = - List.iter (List.rev funs) ~f:(fun f -> - match f with - | Env.{kind= `UserDeclared loc; _} -> - Semantic_error.fn_decl_without_def loc name |> error - | _ -> () ) in - if !check_that_all_functions_have_definition then - Env.iteri tenv 
error_on_undefined ; - match function_block_stmts_opt with - | Some {stmts= []; _} | None -> () - | Some {stmts= ls; _} -> List.iter ~f:verify_fun_def_body_in_block ls - -let add_userdefined_functions tenv stmts_opt = - match stmts_opt with - | None -> tenv - | Some {stmts; _} -> - let f tenv (s : Ast.untyped_statement) = - match s with - | {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} -> - let arg_types = Ast.type_of_arguments arguments in - verify_fundef_overloaded loc tenv funname arg_types returntype ; - let defined = - get_fn_decl_or_defn loc tenv funname arg_types returntype body - in - add_function tenv funname.name - (UFun - ( arg_types - , returntype - , Fun_kind.suffix_from_name funname.name - , AoS ) ) - defined - | _ -> tenv in - List.fold ~init:tenv ~f stmts - -let check_toplevel_block block tenv stmts_opt = - let cf = context block in - match stmts_opt with - | Some {stmts; xloc} -> - let tenv', stmts = - List.fold_map stmts ~init:tenv ~f:(check_statement cf) in - (tenv', Some {stmts; xloc}) - | None -> (tenv, None) - -let verify_correctness_invariant (ast : untyped_program) - (decorated_ast : typed_program) = - let detyped = untyped_program_of_typed_program decorated_ast in - if compare_untyped_program ast detyped = 0 then () - else - Common.FatalError.fatal_error_msg - [%message - "Type checked AST does not match original AST. 
" - (detyped : untyped_program) - (ast : untyped_program)] - -let check_program_exn - ( { functionblock= fb - ; datablock= db - ; transformeddatablock= tdb - ; parametersblock= pb - ; transformedparametersblock= tpb - ; modelblock= mb - ; generatedquantitiesblock= gqb - ; comments } as ast ) = - warnings := [] ; - (* create a new type environment which has only stan-math functions *) - let tenv = Env.stan_math_environment in - let tenv = add_userdefined_functions tenv fb in - let tenv, typed_fb = check_toplevel_block Functions tenv fb in - verify_functions_have_defn tenv typed_fb ; - let tenv, typed_db = check_toplevel_block Data tenv db in - let tenv, typed_tdb = check_toplevel_block TData tenv tdb in - let tenv, typed_pb = check_toplevel_block Param tenv pb in - let tenv, typed_tpb = check_toplevel_block TParam tenv tpb in - let _, typed_mb = check_toplevel_block Model tenv mb in - let _, typed_gqb = check_toplevel_block GQuant tenv gqb in - let prog = - { functionblock= typed_fb - ; datablock= typed_db - ; transformeddatablock= typed_tdb - ; parametersblock= typed_pb - ; transformedparametersblock= typed_tpb - ; modelblock= typed_mb - ; generatedquantitiesblock= typed_gqb - ; comments } in - verify_correctness_invariant ast prog ; - attach_warnings prog - -let check_program ast = - try Result.Ok (check_program_exn ast) - with Errors.SemanticError err -> Result.Error err diff --git a/src/frontend/Typechecking.ml b/src/frontend/Typechecking.ml new file mode 100644 index 0000000000..5598c8a9d6 --- /dev/null +++ b/src/frontend/Typechecking.ml @@ -0,0 +1,1707 @@ +(** a type/semantic checker for Stan ASTs + + Functions which begin with "check_" return a typed version of their input + Functions which begin with "verify_" return unit if a check succeeds, or else + throw an Errors.SemanticError exception. + Other functions which begin with "infer"/"calculate" vary. Usually they return + a value, but a few do have error conditions. 
+ + All Errors.SemanticError exceptions are caught by check_program + which turns the ast or exception into a Result.t for external usage + + A type environment (Env.t) is used to hold variables and functions, including + stan math functions. This is a functional map, meaning it is handled immutably. +*) + +open Core_kernel +open Core_kernel.Poly +open Middle +open Ast +open Typechecking_intf +module Env = Environment + +(* we only allow errors raised by this function *) +let error e = raise (Errors.SemanticError e) + +(* warnings are built up in a list *) +let warnings : Warnings.t list ref = ref [] + +let add_warning (span : Location_span.t) (message : string) = + warnings := (span, message) :: !warnings + +let attach_warnings x = (x, List.rev !warnings) + +(* model name - don't love this here *) +let model_name = ref "" +let check_that_all_functions_have_definition = ref true + +(* Record structure holding flags and other markers about context to be + used for error reporting. *) +type context_flags_record = + { current_block: Env.originblock + ; in_toplevel_decl: bool + ; in_fun_def: bool + ; in_returning_fun_def: bool + ; in_rng_fun_def: bool + ; in_lp_fun_def: bool + ; in_udf_dist_def: bool + ; loop_depth: int } + +let context block = + { current_block= block + ; in_toplevel_decl= false + ; in_fun_def= false + ; in_returning_fun_def= false + ; in_rng_fun_def= false + ; in_lp_fun_def= false + ; in_udf_dist_def= false + ; loop_depth= 0 } + +let calculate_autodifftype current_block origin ut = + match origin with + | Env.(Param | TParam | Model | Functions) + when not (UnsizedType.is_int_type ut || current_block = Env.GQuant) -> + UnsizedType.AutoDiffable + | _ -> DataOnly + +let arg_type x = (x.emeta.ad_level, x.emeta.type_) +let get_arg_types = List.map ~f:arg_type +let type_of_expr_typed ue = ue.emeta.type_ +let has_int_type ue = ue.emeta.type_ = UInt +let has_int_array_type ue = ue.emeta.type_ = UArray UInt + +let has_int_or_real_type ue = + match 
ue.emeta.type_ with UInt | UReal -> true | _ -> false + +let make_function_variable current_block loc id = function + | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | UnsizedType.UFun _ as type_ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | type_ -> + Common.FatalError.fatal_error_msg + [%message + "Attempting to create function variable out of " + (type_ : UnsizedType.t)] + +(* -- General checks ---------------------------------------------- *) +let reserved_keywords = + [ "for"; "in"; "while"; "repeat"; "until"; "if"; "then"; "else"; "true" + ; "false"; "target"; "int"; "real"; "complex"; "void"; "vector"; "simplex" + ; "unit_vector"; "ordered"; "positive_ordered"; "row_vector"; "matrix" + ; "cholesky_factor_corr"; "cholesky_factor_cov"; "corr_matrix"; "cov_matrix" + ; "functions"; "model"; "data"; "parameters"; "quantities"; "transformed" + ; "generated"; "profile"; "return"; "break"; "continue"; "increment_log_prob" + ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" + ; "static"; "auto" ] + +module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER = struct + let std_library_tenv : Env.t = + Env.make_from_library StdLibrary.function_signatures + + let matching_library_function = + SignatureMismatch.matching_function std_library_tenv + + let verify_identifier id : unit = + if id.name = !model_name then + Semantic_error.ident_is_model_name id.id_loc id.name |> error + else if + String.is_suffix id.name ~suffix:"__" + || List.mem reserved_keywords id.name ~equal:String.equal + then Semantic_error.ident_is_keyword id.id_loc id.name |> error + + let distribution_name_variants name = + if name = "multiply_log" || name = 
"binomial_coefficient_log" then [name] + else + (* this will have some duplicates, but preserves order better *) + match Utils.split_distribution_suffix name with + | Some (stem, "lpmf") | Some (stem, "lpdf") | Some (stem, "log") -> + [name; stem ^ "_lpmf"; stem ^ "_lpdf"; stem ^ "_log"] + | Some (stem, "lcdf") | Some (stem, "cdf_log") -> + [name; stem ^ "_lcdf"; stem ^ "_cdf_log"] + | Some (stem, "lccdf") | Some (stem, "ccdf_log") -> + [name; stem ^ "_lccdf"; stem ^ "_ccdf_log"] + | _ -> [name] + + (** verify that the variable being declared is previously unused. + allowed to shadow StanLib *) + let verify_name_fresh_var loc tenv name = + if Utils.is_unnormalized_distribution name then + Semantic_error.ident_has_unnormalized_suffix loc name |> error + else if + List.exists (Env.find tenv name) ~f:(function + | {kind= `Variable _; _} -> true + | _ -> false (* user variables can shadow function names *) ) + then Semantic_error.ident_in_use loc name |> error + + (** verify that the variable being declared is previously unused. 
*) + let verify_name_fresh_udf loc tenv name = + if + (* variadic functions are currently not in math sigs and aren't + overloadable due to their separate typechecking *) + StdLibrary.is_not_overloadable name + then Semantic_error.ident_is_stanmath_name loc name |> error + else if Utils.is_unnormalized_distribution name then + Semantic_error.udf_is_unnormalized_fn loc name |> error + else if + (* if a variable is already defined with this name + - not really possible as all functions are defined before data, + but future-proofing is good *) + List.exists + ~f:(function {kind= `Variable _; _} -> true | _ -> false) + (Env.find tenv name) + then Semantic_error.ident_in_use loc name |> error + + (** Checks that a variable/function name: + - a function/identifier does not have the _lupdf/_lupmf suffix + - is not already in use (for now) +*) + let verify_name_fresh tenv id ~is_udf = + let f = + if is_udf then verify_name_fresh_udf id.id_loc tenv + else verify_name_fresh_var id.id_loc tenv in + List.iter ~f (distribution_name_variants id.name) + + let is_of_compatible_return_type rt1 srt2 = + UnsizedType.( + match (rt1, srt2) with + | Void, NoReturnType + |Void, Incomplete Void + |Void, Complete Void + |Void, AnyReturnType -> + true + | ReturnType UReal, Complete (ReturnType UInt) -> true + | ReturnType UComplex, Complete (ReturnType UReal) -> true + | ReturnType UComplex, Complete (ReturnType UInt) -> true + | ReturnType rt1, Complete (ReturnType rt2) -> rt1 = rt2 + | ReturnType _, AnyReturnType -> true + | _ -> false) + + (* -- Expressions ------------------------------------------------- *) + let check_ternary_if loc pe te fe = + let promote expr type_ ad_level = + if + (not (UnsizedType.equal expr.emeta.type_ type_)) + || UnsizedType.compare_autodifftype expr.emeta.ad_level ad_level <> 0 + then + { expr= Promotion (expr, UnsizedType.internal_scalar type_, ad_level) + ; emeta= {expr.emeta with type_; ad_level} } + else expr in + match + ( pe.emeta.type_ + , 
UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) + , expr_ad_lub [pe; te; fe] ) + with + | UInt, Some type_, ad_level when not (UnsizedType.is_fun_type type_) -> + mk_typed_expression + ~expr: + (TernaryIf (pe, promote te type_ ad_level, promote fe type_ ad_level) + ) + ~ad_level ~type_ ~loc + | _, _, _ -> + Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_ + fe.emeta.type_ + |> error + + let match_to_rt_option = function + | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt + | _ -> None + + let library_function_return_type name arg_tys = + match Hashtbl.find StdLibrary.variadic_signatures name with + | Some {return_type; _} -> Some (UnsizedType.ReturnType return_type) + | None when StdLibrary.is_special_function_name name -> + StdLibrary.special_function_returntype name + | None -> matching_library_function name arg_tys |> match_to_rt_option + + let operator_return_type op arg_tys = + match (op, arg_tys) with + | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> + Some + (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion]) + | IntDivide, _ -> None + | _ -> + StdLibrary.operator_to_function_names op + |> List.filter_map ~f:(fun name -> + matching_library_function name arg_tys + |> function + | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) + | _ -> None ) + |> List.hd + + let assignmentoperator_return_type assop arg_tys = + ( match assop with + | Operator.Divide -> + matching_library_function "divide" arg_tys |> match_to_rt_option + | Plus | Minus | Times | EltTimes | EltDivide -> + operator_return_type assop arg_tys |> Option.map ~f:fst + | _ -> None ) + |> Option.bind ~f:(function + | ReturnType rtype + when rtype = snd (List.hd_exn arg_tys) + && not + ( (assop = Operator.EltTimes || assop = Operator.EltDivide) + && UnsizedType.is_scalar_type rtype ) -> + Some UnsizedType.Void + | _ -> None ) + + let check_binop loc op le re = + let rt = [le; re] |> get_arg_types |> operator_return_type op in + 
match rt with + | Some (ReturnType type_, [p1; p2]) -> + mk_typed_expression + ~expr:(BinOp (Promotion.promote le p1, op, Promotion.promote re p2)) + ~ad_level:(expr_ad_lub [le; re]) + ~type_ ~loc + | _ -> + Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ + (StdLibrary.get_operator_signatures op) + |> error + + let check_prefixop loc op te = + let rt = operator_return_type op [arg_type te] in + match rt with + | Some (ReturnType type_, _) -> + mk_typed_expression + ~expr:(PrefixOp (op, te)) + ~ad_level:(expr_ad_lub [te]) + ~type_ ~loc + | _ -> + Semantic_error.illtyped_prefix_op loc op te.emeta.type_ + (StdLibrary.get_operator_signatures op) + |> error + + let check_postfixop loc op te = + let rt = operator_return_type op [arg_type te] in + match rt with + | Some (ReturnType type_, _) -> + mk_typed_expression + ~expr:(PostfixOp (te, op)) + ~ad_level:(expr_ad_lub [te]) + ~type_ ~loc + | _ -> + Semantic_error.illtyped_postfix_op loc op te.emeta.type_ + (StdLibrary.get_operator_signatures op) + |> error + + let check_id cf loc tenv id = + match Env.find tenv (Utils.stdlib_distribution_name id.name) with + | [] -> + Semantic_error.ident_not_in_scope loc id.name + (Env.nearest_ident tenv id.name) + |> error + | {kind= `StanMath; _} :: _ -> + ( calculate_autodifftype cf.current_block MathLibrary + UMathLibraryFunction + , UnsizedType.UMathLibraryFunction ) + | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ + when cf.in_toplevel_decl -> + Semantic_error.non_data_variable_size_decl loc |> error + | _ :: _ + when Utils.is_unnormalized_distribution id.name + && not + ( (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def)) + || cf.current_block = Model ) -> + Semantic_error.invalid_unnormalized_fn loc |> error + | {kind= `Variable {origin; _}; type_} :: _ -> + (calculate_autodifftype cf.current_block origin type_, type_) + | { kind= `UserDefined | `UserDeclared _ + ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } + :: _ -> + let 
type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + (calculate_autodifftype cf.current_block Functions type_, type_) + | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> + (calculate_autodifftype cf.current_block Functions type_, type_) + + let check_variable cf loc tenv id = + let ad_level, type_ = check_id cf loc tenv id in + mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc + + let get_consistent_types type_ es = + let ad = + UnsizedType.lub_ad_type (List.map ~f:(fun e -> e.emeta.ad_level) es) in + let f state e = + match state with + | Error e -> Error e + | Ok ty -> ( + match UnsizedType.common_type (ty, e.emeta.type_) with + | Some ty -> Ok ty + | None -> Error (ty, e.emeta) ) in + List.fold ~init:(Ok type_) ~f es + |> Result.map ~f:(fun ty -> + let promotions = + List.map (get_arg_types es) + ~f:(Promotion.get_type_promotion_exn (ad, ty)) in + (ad, ty, promotions) ) + + let check_array_expr loc es = + match es with + | [] -> + (* NB: This is actually disallowed by parser *) + Semantic_error.empty_array loc |> error + | {emeta= {type_; _}; _} :: _ -> ( + match get_consistent_types type_ es with + | Error (ty, meta) -> + Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error + | Ok (ad_level, type_, promotions) -> + let type_ = UnsizedType.UArray type_ in + mk_typed_expression + ~expr:(ArrayExpr (Promotion.promote_list es promotions)) + ~ad_level ~type_ ~loc ) + + let check_rowvector loc es = + match es with + | {emeta= {type_= UnsizedType.URowVector; _}; _} :: _ -> ( + match get_consistent_types URowVector es with + | Ok (ad_level, typ, promotions) -> + mk_typed_expression + ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) + ~ad_level + ~type_:(if typ = UComplex then UComplexMatrix else UMatrix) + ~loc + | Error (_, meta) -> + Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) + | {emeta= {type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> ( + match 
get_consistent_types UComplexRowVector es with + | Ok (ad_level, _, promotions) -> + mk_typed_expression + ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) + ~ad_level ~type_:UComplexMatrix ~loc + | Error (_, meta) -> + Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) + | _ -> ( + match get_consistent_types UReal es with + | Ok (ad_level, typ, promotions) -> + mk_typed_expression + ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) + ~ad_level + ~type_:(if typ = UComplex then UComplexRowVector else URowVector) + ~loc + | Error (_, meta) -> + Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error ) + + (* index checking *) + + let indexing_type idx = + match idx with + | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single + | _ -> `Multi + + let is_multiindex i = + match indexing_type i with `Single -> false | `Multi -> true + + let inferred_unsizedtype_of_indexed ~loc ut indices = + let rec aux type_ idcs = + let vec, rowvec, scalar = + if UnsizedType.is_complex_type type_ then + UnsizedType.(UComplexVector, UComplexRowVector, UComplex) + else (UVector, URowVector, UReal) in + match (type_, idcs) with + | _, [] -> type_ + | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl + | UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray + | (UVector | URowVector | UComplexRowVector | UComplexVector), [`Single] + |(UMatrix | UComplexMatrix), [`Single; `Single] -> + scalar + | ( ( UVector | URowVector | UMatrix | UComplexVector | UComplexMatrix + | UComplexRowVector ) + , [`Multi] ) + |(UMatrix | UComplexMatrix), [`Multi; `Multi] -> + type_ + | (UMatrix | UComplexMatrix), ([`Single] | [`Single; `Multi]) -> rowvec + | (UMatrix | UComplexMatrix), [`Multi; `Single] -> vec + | (UMatrix | UComplexMatrix), _ :: _ :: _ :: _ + |(UVector | URowVector | UComplexRowVector | UComplexVector), _ :: _ :: _ + |(UInt | UReal | UComplex | UFun _ | UMathLibraryFunction), _ :: _ -> + Semantic_error.not_indexable loc 
ut (List.length indices) |> error + in + aux ut (List.map ~f:indexing_type indices) + + let inferred_ad_type_of_indexed at uindices = + UnsizedType.lub_ad_type + ( at + :: List.map + ~f:(function + | All -> UnsizedType.DataOnly + | Single ue1 | Upfrom ue1 | Downfrom ue1 -> + UnsizedType.lub_ad_type [at; ue1.emeta.ad_level] + | Between (ue1, ue2) -> + UnsizedType.lub_ad_type + [at; ue1.emeta.ad_level; ue2.emeta.ad_level] ) + uindices ) + + (* function checking *) + let verify_conddist_name loc id = + if + List.exists + ~f:(fun x -> String.is_suffix id.name ~suffix:x) + Utils.conditioning_suffices + then () + else Semantic_error.conditional_notation_not_allowed loc |> error + + let verify_fn_conditioning loc id = + if + List.exists + ~f:(fun suffix -> String.is_suffix id.name ~suffix) + Utils.conditioning_suffices + && not (String.is_suffix id.name ~suffix:"_cdf") + then Semantic_error.conditioning_required loc |> error + + (** `Target+=` can only be used in model and functions + with right suffix (same for tilde etc) +*) + let verify_fn_target_plus_equals cf loc id = + if + String.is_suffix id.name ~suffix:"_lp" + && not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then Semantic_error.target_plusequals_outside_model_or_logprob loc |> error + + (** Rng functions cannot be used in Tp or Model and only + in function defs with the right suffix +*) + let verify_fn_rng cf loc id = + if String.is_suffix id.name ~suffix:"_rng" && cf.in_toplevel_decl then + Semantic_error.invalid_decl_rng_fn loc |> error + else if + String.is_suffix id.name ~suffix:"_rng" + && ( (cf.in_fun_def && not cf.in_rng_fun_def) + || cf.current_block = TParam || cf.current_block = Model ) + then Semantic_error.invalid_rng_fn loc |> error + + (** unnormalized _lpdf/_lpmf functions can only be used in _lpdf/_lpmf/_lp udfs + or the model block +*) + let verify_unnormalized cf loc id = + if + Utils.is_unnormalized_distribution id.name + && not ((cf.in_fun_def && 
cf.in_udf_dist_def) || cf.current_block = Model) + then Semantic_error.invalid_unnormalized_fn loc |> error + + (** Typecheck the application of a returning function that is neither special nor variadic (see [check_fn]). [id] is looked up in [tenv] under its normalized name: a variable binding that does not shadow a standard-library function is rejected; an unbound name is reported as undeclared, with a more specific message when the name splits into a distribution prefix and suffix; otherwise the argument types of [es] are matched against the available signatures and a function application over the promoted arguments is built. All failures are raised through [error]. *) + let check_normal_fn ~is_cond_dist loc tenv id es = + match Env.find tenv (Utils.normalized_name id.name) with + | {kind= `Variable _; _} :: _ + (* variables can sometimes shadow stanlib functions, so we have to check this *) + when not + (StdLibrary.is_stdlib_function_name + (Utils.normalized_name id.name) ) -> + Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error + | [] -> + ( match Utils.split_distribution_suffix id.name with + | Some (prefix, suffix) -> ( + let known_families = StdLibrary.distribution_families in + let is_known_family s = + List.mem known_families s ~equal:String.equal in + (* A pmf suffix used where only the pdf exists (or vice versa) gets a dedicated error; the unnormalized spellings are "lupmf" / "lupdf" *) + match suffix with + | ("lpmf" | "lupmf") when Env.mem tenv (prefix ^ "_lpdf") -> + Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc + (prefix, suffix) + | ("lpdf" | "lupdf") when Env.mem tenv (prefix ^ "_lpmf") -> + Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc + (prefix, suffix) + | _ -> + if + is_known_family prefix + && List.mem ~equal:String.equal + Utils.cumulative_distribution_suffices_w_rng suffix + then + Semantic_error + .returning_fn_expected_undeclared_dist_suffix_found loc + (prefix, suffix) + else + Semantic_error.returning_fn_expected_undeclaredident_found loc + id.name + (Env.nearest_ident tenv id.name) ) + | None -> + Semantic_error.returning_fn_expected_undeclaredident_found loc + id.name + (Env.nearest_ident tenv id.name) ) + |> error + | _ (* a function *) -> ( + (* NB: At present, [SignatureMismatch.matching_function] cannot handle overloaded function types. 
+ This is not needed until UDFs can be higher-order, as it is special cased for + variadic functions + *) + match + SignatureMismatch.matching_function tenv id.name (get_arg_types es) + with + | UniqueMatch (Void, _, _) -> + Semantic_error.returning_fn_expected_nonreturning_found loc id.name + |> error + | UniqueMatch (ReturnType ut, fnk, promotions) -> + mk_fun_app ~is_cond_dist + (fnk (Fun_kind.suffix_from_name id.name)) + id + (Promotion.promote_list es promotions) + ~type_:ut ~loc + | AmbiguousMatch sigs -> + Semantic_error.ambiguous_function_promotion loc id.name + (Some (List.map ~f:type_of_expr_typed es)) + sigs + |> error + | SignatureErrors (l, b) -> + es + |> List.map ~f:(fun e -> e.emeta.type_) + |> Semantic_error.illtyped_fn_app loc id.name (l, b) + |> error ) + + (** Dispatch on the function name: special functions are handed to [StdLibrary.check_special_fn], variadic ones to [check_variadic], and everything else to [check_normal_fn]. *) + let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) + = + if StdLibrary.is_special_function_name id.name then + StdLibrary.check_special_fn ~is_cond_dist loc cf.current_block tenv id tes + else if StdLibrary.is_variadic_function_name id.name then + check_variadic ~is_cond_dist loc cf.current_block tenv id tes + else check_normal_fn ~is_cond_dist loc tenv id tes + + and check_variadic ~is_cond_dist loc cf tenv id tes = + let Std_library_utils. 
+ {control_args; required_fn_args; required_fn_rt; return_type} = + Hashtbl.find_exn StdLibrary.variadic_signatures id.name in + let matching remaining_es Env.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype cf Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args + required_fn_args required_fn_rt arg_types in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = make_function_variable cf loc fname ftype :: remaining_es in + mk_fun_app ~is_cond_dist (StanLib FnPlain) id + (Promotion.promote_list tes promotions) + ~type_:return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err required_fn_rt + |> error ) + | _ -> + let expected_args, err = + SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args + required_fn_args required_fn_rt (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err required_fn_rt + |> error + + and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) + = + let name_check = + if is_cond_dist then verify_conddist_name else verify_fn_conditioning + in + let res = check_fn ~is_cond_dist loc cf tenv id es in + verify_identifier id ; + name_check loc id ; + verify_fn_target_plus_equals cf loc id ; + verify_fn_rng cf loc id ; + verify_unnormalized cf loc id ; + res + + and check_indexed loc cf tenv e indices = + let tindices = List.map ~f:(check_index cf tenv) indices in + let te = check_expression cf tenv e in + let 
ad_level = inferred_ad_type_of_indexed te.emeta.ad_level tindices in + let type_ = inferred_unsizedtype_of_indexed ~loc te.emeta.type_ tindices in + mk_typed_expression ~expr:(Indexed (te, tindices)) ~ad_level ~type_ ~loc + + and check_index cf tenv = function + | All -> All + (* Check that indexes have int (container) type *) + | Single e -> + let te = check_expression cf tenv e in + if has_int_type te || has_int_array_type te then Single te + else + Semantic_error.int_intarray_or_range_expected te.emeta.loc + te.emeta.type_ + |> error + | Upfrom e -> check_expression_of_int_type cf tenv e "Range bound" |> Upfrom + | Downfrom e -> + check_expression_of_int_type cf tenv e "Range bound" |> Downfrom + | Between (e1, e2) -> + let le = check_expression_of_int_type cf tenv e1 "Range bound" in + let ue = check_expression_of_int_type cf tenv e2 "Range bound" in + Between (le, ue) + + (** Typecheck an untyped expression and return its typed counterpart. Sub-expressions are checked recursively in the same context [cf] and environment [tenv]. Semantic errors are raised through [error]; for binary operators, warnings about int division, matrix ^ scalar and chained comparisons are recorded through [add_warning] before the operator itself is checked. *) + and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : + Ast.typed_expression = + let loc = emeta.loc in + let ce = check_expression cf tenv in + match expr with + | TernaryIf (e1, e2, e3) -> + let pe = ce e1 in + let te = ce e2 in + let fe = ce e3 in + check_ternary_if loc pe te fe + | BinOp (e1, op, e2) -> + let le = ce e1 in + let re = ce e2 in + let binop_type_warnings x y = + match (x.emeta.type_, y.emeta.type_, op) with + | UInt, UInt, Divide -> + let hint ppf () = + match (x.expr, y.expr) with + | IntNumeral x, _ -> + Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression + y + | _, Ast.IntNumeral y -> + Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x + y + | _ -> + Fmt.pf ppf "%a * 1.0 / %a" + Pretty_printing.pp_typed_expression x + Pretty_printing.pp_typed_expression y in + let s = + Fmt.str + "@[@[Found int division:@]@ @[%a@]@,\ + @[%a@]@ @[%a@]@,\ + @[%a@]@]" + Pretty_printing.pp_expression {expr; emeta} Fmt.text + "Values will be rounded towards zero. 
If rounding is not \ + desired you can write the division as" + hint () Fmt.text + "If rounding is intended please use the integer division \ + operator %/%." in + add_warning x.emeta.loc s + | (UArray UMatrix | UMatrix), (UInt | UReal), Pow -> + let s = + Fmt.str + "@[@[Found matrix^scalar:@]@ @[%a@]@,\ + @[%a@]@ @[%a@]@]" Pretty_printing.pp_expression + {expr; emeta} Fmt.text + "matrix ^ number is interpreted as element-wise \ + exponentiation. If this is intended, you can silence this \ + warning by using elementwise operator .^" + Fmt.text + "If you intended matrix exponentiation, use the function \ + matrix_power(matrix,int) instead." in + add_warning x.emeta.loc s + | _ when Operator.is_cmp op -> ( + match le.expr with + | BinOp (e1, op2, e2) when Operator.is_cmp op2 -> + (* [op2] is the inner comparison (inside [le]); [op] is the outer one joining [le] and [re] *) + let pp_e = Pretty_printing.pp_typed_expression in + let pp = Operator.pp in + add_warning loc + (Fmt.str + "Found %a. This is interpreted as %a. Consider if the \ + intended meaning was %a instead.@ You can silence this \ + warning by adding explicit parenthesis. 
This can be \ + automatically changed using the canonicalize flag for \ + stanc" + (fun ppf () -> + Fmt.pf ppf "@[%a %a %a@]" pp_e le pp op pp_e re ) + () + (fun ppf () -> + Fmt.pf ppf "@[(%a) %a %a@]" pp_e le pp op pp_e re + ) + () + (fun ppf () -> + Fmt.pf ppf "@[%a %a %a && %a %a %a@]" pp_e e1 pp op2 + pp_e e2 pp_e e2 pp op pp_e re ) + () ) + | _ -> () ) + | _ -> () in + binop_type_warnings le re ; check_binop loc op le re + | PrefixOp (op, e) -> ce e |> check_prefixop loc op + | PostfixOp (e, op) -> ce e |> check_postfixop loc op + | Variable id -> + verify_identifier id ; + check_variable cf loc tenv id + | IntNumeral s -> ( + match float_of_string_opt s with + | Some i when i < 2_147_483_648.0 -> + mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly + ~type_:UInt ~loc + | _ -> Semantic_error.bad_int_literal loc |> error ) + | RealNumeral s -> + mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly + ~type_:UReal ~loc + | ImagNumeral s -> + mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly + ~type_:UComplex ~loc + | GetLP -> + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + if + not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then + Semantic_error.target_plusequals_outside_model_or_logprob loc |> error + else + mk_typed_expression ~expr:GetLP + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) + ~type_:UReal ~loc + | GetTarget -> + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + if + not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then + Semantic_error.target_plusequals_outside_model_or_logprob loc |> error + else + mk_typed_expression ~expr:GetTarget + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) + ~type_:UReal ~loc + | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc + | RowVectorExpr 
es -> es |> List.map ~f:ce |> check_rowvector loc + | Paren e -> + let te = ce e in + mk_typed_expression ~expr:(Paren te) ~ad_level:te.emeta.ad_level + ~type_:te.emeta.type_ ~loc + | Indexed (e, indices) -> check_indexed loc cf tenv e indices + | FunApp ((), id, es) -> + es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id + | CondDistApp ((), id, es) -> + es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id + | Promotion (e, _, _) -> + (* Should never happen: promotions are produced during typechecking *) + Common.FatalError.fatal_error_msg + [%message "Promotion in untyped AST" (e : Ast.untyped_expression)] + + and check_expression_of_int_type cf tenv e name = + let te = check_expression cf tenv e in + if has_int_type te then te + else Semantic_error.int_expected te.emeta.loc name te.emeta.type_ |> error + + let check_expression_of_int_or_real_type cf tenv e name = + let te = check_expression cf tenv e in + if has_int_or_real_type te then te + else + Semantic_error.int_or_real_expected te.emeta.loc name te.emeta.type_ + |> error + + let check_expression_of_scalar_or_type cf tenv t e name = + let te = check_expression cf tenv e in + if UnsizedType.is_scalar_type te.emeta.type_ || te.emeta.type_ = t then te + else + Semantic_error.scalar_or_type_expected te.emeta.loc name t te.emeta.type_ + |> error + + (* -- Statements ------------------------------------------------- *) + (* non returning functions *) + let verify_nrfn_target loc cf id = + if + String.is_suffix id.name ~suffix:"_lp" + && not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then Semantic_error.target_plusequals_outside_model_or_logprob loc |> error + + let check_nrfn loc tenv id es = + match Env.find tenv id.name with + | {kind= `Variable _; _} :: _ + (* variables can shadow stanlib functions, so we have to check this *) + when not (StdLibrary.is_stdlib_function_name id.name) -> + 
Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error + | [] -> + Semantic_error.nonreturning_fn_expected_undeclaredident_found loc + id.name + (Env.nearest_ident tenv id.name) + |> error + | _ (* a function *) -> ( + match + SignatureMismatch.matching_function tenv id.name (get_arg_types es) + with + | UniqueMatch (Void, fnk, promotions) -> + mk_typed_statement + ~stmt: + (NRFunApp + ( fnk (Fun_kind.suffix_from_name id.name) + , id + , Promotion.promote_list es promotions ) ) + ~return_type:NoReturnType ~loc + | UniqueMatch (ReturnType _, _, _) -> + Semantic_error.nonreturning_fn_expected_returning_found loc id.name + |> error + | AmbiguousMatch sigs -> + Semantic_error.ambiguous_function_promotion loc id.name + (Some (List.map ~f:type_of_expr_typed es)) + sigs + |> error + | SignatureErrors (l, b) -> + es + |> List.map ~f:type_of_expr_typed + |> Semantic_error.illtyped_fn_app loc id.name (l, b) + |> error ) + + let check_nr_fn_app loc cf tenv id es = + let tes = List.map ~f:(check_expression cf tenv) es in + verify_identifier id ; + verify_nrfn_target loc cf id ; + check_nrfn loc tenv id tes + + (* assignments *) + let verify_assignment_read_only loc is_readonly id = + if is_readonly then + Semantic_error.cannot_assign_to_read_only loc id.name |> error + + (* Variables from previous blocks are read-only. 
+ In particular, data and parameters never assigned to + *) + let verify_assignment_global loc cf block is_global id = + if (not is_global) || block = cf.current_block then () + else Semantic_error.cannot_assign_to_global loc id.name |> error + + (* Until function types are added to the user language, we + disallow assignments to function values + *) + let verify_assignment_non_function loc ut id = + match ut with + | UnsizedType.UFun _ | UMathLibraryFunction -> + Semantic_error.cannot_assign_function loc ut id.name |> error + | _ -> () + + let check_assignment_operator loc assop lhs rhs = + let err op sigs = + Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ + sigs in + match assop with + | Assign | ArrowAssign -> ( + match + SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ + rhs.emeta.type_ + with + | Ok p -> Promotion.promote rhs p + | Error _ -> err Operator.Equals [] |> error ) + | OperatorAssign op -> ( + let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in + let return_type = assignmentoperator_return_type op args in + match return_type with + | Some Void -> rhs + | _ -> + err op (StdLibrary.get_assignment_operator_signatures op) |> error ) + + let check_lvalue cf tenv = function + | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> + verify_identifier id ; + let ad_level, type_ = check_id cf loc tenv id in + {lval= LVariable id; lmeta= {ad_level; type_; loc}} + | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> + let rec check_inner = function + | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> + verify_identifier id ; + let ad_level, type_ = check_id cf loc tenv id in + let var = {lval= LVariable id; lmeta= {ad_level; type_; loc}} in + (var, var, []) + | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> + let lval, var, flat = check_inner lval in + let idcs = List.map ~f:(check_index cf tenv) idcs in + let ad_level = + inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in + let type_ = + 
inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in + ( {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} + , var + , flat @ idcs ) in + let lval, var, flat = check_inner lval in + let idcs = List.map ~f:(check_index cf tenv) idcs in + let ad_level = inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in + let type_ = inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in + if List.exists ~f:is_multiindex flat then ( + add_warning loc + "Nested multi-indexing on the left hand side of assignment does \ + not behave the same as nested indexing in expressions. This is \ + considered a bug and will be disallowed in Stan 2.33.0. The \ + indexing can be automatically fixed using the canonicalize flag \ + for stanc." ; + let lvalue_rvalue_types_differ = + try + let flat_type = + inferred_unsizedtype_of_indexed ~loc var.lmeta.type_ + (flat @ idcs) in + let rec can_assign = function + | UnsizedType.(UArray t1, UArray t2) -> can_assign (t1, t2) + | UVector, URowVector | URowVector, UVector -> false + | t1, t2 -> UnsizedType.compare t1 t2 <> 0 in + can_assign (flat_type, type_) + with Errors.SemanticError _ -> true in + if lvalue_rvalue_types_differ then + Semantic_error.cannot_assign_to_multiindex loc |> error ) ; + {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} + + let check_assignment loc cf tenv assign_lhs assign_op assign_rhs = + let assign_id = Ast.id_of_lvalue assign_lhs in + let lhs = check_lvalue cf tenv assign_lhs in + let rhs = check_expression cf tenv assign_rhs in + let block, global, readonly = + let var = Env.find tenv assign_id.name in + match var with + | {kind= `Variable {origin; global; readonly}; _} :: _ -> + (origin, global, readonly) + | {kind= `StanMath; _} :: _ -> (MathLibrary, true, false) + | {kind= `UserDefined | `UserDeclared _; _} :: _ -> + (Functions, true, false) + | _ -> + Semantic_error.ident_not_in_scope loc assign_id.name + (Env.nearest_ident tenv assign_id.name) + |> error in + 
verify_assignment_global loc cf block global assign_id ; + verify_assignment_read_only loc readonly assign_id ; + verify_assignment_non_function loc rhs.emeta.type_ assign_id ; + let rhs' = check_assignment_operator loc assign_op lhs rhs in + mk_typed_statement ~return_type:NoReturnType ~loc + ~stmt:(Assignment {assign_lhs= lhs; assign_op; assign_rhs= rhs'}) + + (* target plus-equals / increment log-prob *) + + let verify_target_pe_expr_type loc e = + if UnsizedType.is_fun_type e.emeta.type_ then + Semantic_error.int_or_real_container_expected loc e.emeta.type_ |> error + + let verify_target_pe_usage loc cf = + if cf.in_lp_fun_def || cf.current_block = Model then () + else Semantic_error.target_plusequals_outside_model_or_logprob loc |> error + + let check_target_pe loc cf tenv e = + let te = check_expression cf tenv e in + verify_target_pe_usage loc cf ; + verify_target_pe_expr_type loc te ; + mk_typed_statement ~stmt:(TargetPE te) ~return_type:NoReturnType ~loc + + let check_incr_logprob loc cf tenv e = + let te = check_expression cf tenv e in + verify_target_pe_usage loc cf ; + verify_target_pe_expr_type loc te ; + mk_typed_statement ~stmt:(IncrementLogProb te) ~return_type:NoReturnType + ~loc + + (* tilde/sampling notation*) + let verify_sampling_pdf_pmf id = + if + String.( + is_suffix id.name ~suffix:"_lpdf" + || is_suffix id.name ~suffix:"_lpmf" + || is_suffix id.name ~suffix:"_lupdf" + || is_suffix id.name ~suffix:"_lupmf") + then Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc |> error + + let verify_sampling_cdf_ccdf loc id = + if + String.( + is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf") + then Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name |> error + + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + let verify_valid_sampling_pos loc cf = + if cf.in_lp_fun_def || cf.current_block = Model then () + else Semantic_error.target_plusequals_outside_model_or_logprob loc |> 
error + + let verify_sampling_distribution loc tenv id arguments = + let name = id.name in + let argumenttypes = List.map ~f:arg_type arguments in + let name_w_suffix_sampling_dist suffix = + SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes + in + let sampling_dists = + List.map ~f:name_w_suffix_sampling_dist Utils.distribution_suffices in + let is_sampling_dist_defined = + List.exists + ~f:(function UniqueMatch (ReturnType UReal, _, _) -> true | _ -> false) + sampling_dists + && name <> "binomial_coefficient" + && name <> "multiply" in + if is_sampling_dist_defined then () + else + match + List.max_elt sampling_dists + ~compare:SignatureMismatch.compare_match_results + with + | None | Some (UniqueMatch _) | Some (SignatureErrors ([], _)) -> + (* Either non-existant or a very odd case, + output the old non-informative error *) + Semantic_error.invalid_sampling_no_such_dist loc name |> error + | Some (AmbiguousMatch sigs) -> + Semantic_error.ambiguous_function_promotion loc id.name + (Some (List.map ~f:type_of_expr_typed arguments)) + sigs + |> error + | Some (SignatureErrors (l, b)) -> + arguments + |> List.map ~f:(fun e -> e.emeta.type_) + |> Semantic_error.illtyped_fn_app loc id.name (l, b) + |> error + + let is_cumulative_density_defined tenv id arguments = + let name = id.name in + let argumenttypes = List.map ~f:arg_type arguments in + let valid_arg_types_for_suffix suffix = + match + SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes + with + | UniqueMatch (ReturnType UReal, _, _) -> true + | _ -> false in + (valid_arg_types_for_suffix "_lcdf" || valid_arg_types_for_suffix "_cdf_log") + && ( valid_arg_types_for_suffix "_lccdf" + || valid_arg_types_for_suffix "_ccdf_log" ) + + let verify_sampling_cdf_defined loc tenv id truncation args = + let check e = + if not (is_cumulative_density_defined tenv id (e :: args)) then + Semantic_error.invalid_truncation_cdf_or_ccdf loc + (get_arg_types (e :: args)) + |> error in + 
match truncation with + | NoTruncate -> () + | TruncateUpFrom e | TruncateDownFrom e -> check e + | TruncateBetween (e1, e2) -> check e1 ; check e2 + + let check_truncation cf tenv truncation = + let check e = + check_expression_of_int_or_real_type cf tenv e "Truncation bound" in + match truncation with + | NoTruncate -> NoTruncate + | TruncateUpFrom e -> check e |> TruncateUpFrom + | TruncateDownFrom e -> check e |> TruncateDownFrom + | TruncateBetween (e1, e2) -> (check e1, check e2) |> TruncateBetween + + let check_tilde loc cf tenv distribution truncation arg args = + let te = check_expression cf tenv arg in + let tes = List.map ~f:(check_expression cf tenv) args in + let ttrunc = check_truncation cf tenv truncation in + verify_identifier distribution ; + verify_sampling_pdf_pmf distribution ; + verify_valid_sampling_pos loc cf ; + verify_sampling_cdf_ccdf loc distribution ; + verify_sampling_distribution loc tenv distribution (te :: tes) ; + verify_sampling_cdf_defined loc tenv distribution ttrunc tes ; + let stmt = Tilde {arg= te; distribution; args= tes; truncation= ttrunc} in + mk_typed_statement ~stmt ~loc ~return_type:NoReturnType + + (* Break and continue only occur in loops. 
*) + let check_break loc cf = + if cf.loop_depth = 0 then Semantic_error.break_outside_loop loc |> error + else mk_typed_statement ~stmt:Break ~return_type:NoReturnType ~loc + + let check_continue loc cf = + if cf.loop_depth = 0 then Semantic_error.continue_outside_loop loc |> error + else mk_typed_statement ~stmt:Continue ~return_type:NoReturnType ~loc + + let check_return loc cf tenv e = + if not cf.in_returning_fun_def then + Semantic_error.expression_return_outside_returning_fn loc |> error + else + let te = check_expression cf tenv e in + mk_typed_statement ~stmt:(Return te) + ~return_type:(Complete (ReturnType te.emeta.type_)) ~loc + + let check_returnvoid loc cf = + if (not cf.in_fun_def) || cf.in_returning_fun_def then + Semantic_error.void_outside_nonreturning_fn loc |> error + else mk_typed_statement ~stmt:ReturnVoid ~return_type:(Complete Void) ~loc + + let check_printable cf tenv = function + | PString s -> PString s + (* Print/reject expressions cannot be of function type. *) + | PExpr e -> ( + let te = check_expression cf tenv e in + match te.emeta.type_ with + | UFun _ | UMathLibraryFunction -> + Semantic_error.not_printable te.emeta.loc |> error + | _ -> PExpr te ) + + let check_print loc cf tenv ps = + let tps = List.map ~f:(check_printable cf tenv) ps in + mk_typed_statement ~stmt:(Print tps) ~return_type:NoReturnType ~loc + + let check_reject loc cf tenv ps = + let tps = List.map ~f:(check_printable cf tenv) ps in + mk_typed_statement ~stmt:(Reject tps) ~return_type:AnyReturnType ~loc + + let check_skip loc = + mk_typed_statement ~stmt:Skip ~return_type:NoReturnType ~loc + + let rec stmt_is_escape {stmt; _} = + match stmt with + | Break | Continue | Reject _ | Return _ | ReturnVoid -> true + | _ -> false + + and list_until_escape xs = + let rec aux accu = function + | [next; next'] when stmt_is_escape next' -> + List.rev (next' :: next :: accu) + | next :: next' :: unreachable :: _ when stmt_is_escape next' -> + add_warning unreachable.smeta.loc 
+ "Unreachable statement (following a reject, break, continue, or \ + return) found, is this intended?" ; + List.rev (next' :: next :: accu) + | next :: rest -> aux (next :: accu) rest + | [] -> List.rev accu in + aux [] xs + + let returntype_leastupperbound loc rt1 rt2 = + match (rt1, rt2) with + | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt + |ReturnType UInt, ReturnType UReal -> + UnsizedType.ReturnType UReal + | _, _ when rt1 = rt2 -> rt2 + | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> error + + let try_compute_block_statement_returntype loc srt1 srt2 = + match (srt1, srt2) with + | Complete rt1, Complete rt2 | Incomplete rt1, Complete rt2 -> + Complete (returntype_leastupperbound loc rt1 rt2) + | Incomplete rt1, Incomplete rt2 | Complete rt1, Incomplete rt2 -> + Incomplete (returntype_leastupperbound loc rt1 rt2) + | NoReturnType, NoReturnType -> NoReturnType + | AnyReturnType, Incomplete rt + |Complete rt, NoReturnType + |NoReturnType, Incomplete rt + |Incomplete rt, NoReturnType -> + Incomplete rt + | NoReturnType, Complete rt + |Complete rt, AnyReturnType + |Incomplete rt, AnyReturnType + |AnyReturnType, Complete rt -> + Complete rt + | AnyReturnType, NoReturnType + |NoReturnType, AnyReturnType + |AnyReturnType, AnyReturnType -> + AnyReturnType + + let try_compute_ifthenelse_statement_returntype loc srt1 srt2 = + match (srt1, srt2) with + | Complete rt1, Complete rt2 -> + returntype_leastupperbound loc rt1 rt2 |> Complete + | Incomplete rt1, Incomplete rt2 + |Complete rt1, Incomplete rt2 + |Incomplete rt1, Complete rt2 -> + returntype_leastupperbound loc rt1 rt2 |> Incomplete + | AnyReturnType, NoReturnType + |NoReturnType, AnyReturnType + |NoReturnType, NoReturnType -> + NoReturnType + | AnyReturnType, Incomplete rt + |Incomplete rt, AnyReturnType + |Complete rt, NoReturnType + |NoReturnType, Complete rt + |NoReturnType, Incomplete rt + |Incomplete rt, NoReturnType -> + Incomplete rt + | Complete rt, AnyReturnType | 
AnyReturnType, Complete rt -> Complete rt + | AnyReturnType, AnyReturnType -> AnyReturnType + + (* statements which contain statements, and therefore need to be mutually recursive + with check_statement + *) + let rec check_if_then_else loc cf tenv pred_e s_true s_false_opt = + (* we don't need these nested type environments *) + let _, ts_true = check_statement cf tenv s_true in + let ts_false_opt = + s_false_opt + |> Option.map ~f:(check_statement cf tenv) + |> Option.map ~f:snd in + let te = + check_expression_of_int_or_real_type cf tenv pred_e + "Condition in conditional" in + let stmt = IfThenElse (te, ts_true, ts_false_opt) in + let srt1 = ts_true.smeta.return_type in + let srt2 = + ts_false_opt + |> Option.map ~f:(fun s -> s.smeta.return_type) + |> Option.value ~default:NoReturnType in + let return_type = + try_compute_ifthenelse_statement_returntype loc srt1 srt2 in + mk_typed_statement ~stmt ~return_type ~loc + + and check_while loc cf tenv cond_e loop_body = + let _, ts = + check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body + and te = + check_expression_of_int_or_real_type cf tenv cond_e + "Condition in while-loop" in + mk_typed_statement + ~stmt:(While (te, ts)) + ~return_type:ts.smeta.return_type ~loc + + and check_for loc cf tenv loop_var lower_bound_e upper_bound_e loop_body = + let te1 = + check_expression_of_int_type cf tenv lower_bound_e + "Lower bound of for-loop" + and te2 = + check_expression_of_int_type cf tenv upper_bound_e + "Upper bound of for-loop" in + verify_identifier loop_var ; + let ts = check_loop_body cf tenv loop_var UnsizedType.UInt loop_body in + mk_typed_statement + ~stmt: + (For + { loop_variable= loop_var + ; lower_bound= te1 + ; upper_bound= te2 + ; loop_body= ts } ) + ~return_type:ts.smeta.return_type ~loc + + and check_foreach_loop_identifier_type loc ty = + match ty with + | UnsizedType.UArray ut -> ut + | UVector | URowVector | UMatrix -> UnsizedType.UReal + | _ -> 
Semantic_error.array_vector_rowvector_matrix_expected loc ty |> error + + and check_foreach loc cf tenv loop_var foreach_e loop_body = + let te = check_expression cf tenv foreach_e in + verify_identifier loop_var ; + let loop_var_ty = + check_foreach_loop_identifier_type te.emeta.loc te.emeta.type_ in + let ts = check_loop_body cf tenv loop_var loop_var_ty loop_body in + mk_typed_statement + ~stmt:(ForEach (loop_var, te, ts)) + ~return_type:ts.smeta.return_type ~loc + + and check_loop_body cf tenv loop_var loop_var_ty loop_body = + verify_name_fresh tenv loop_var ~is_udf:false ; + (* Add to type environment as readonly. + Check that function args and loop identifiers are not modified in + function. (passed by const ref) + *) + let tenv = + Env.add tenv loop_var.name loop_var_ty + (`Variable {origin= cf.current_block; global= false; readonly= true}) + in + snd (check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body) + + and check_block loc cf tenv stmts = + let _, checked_stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) in + let return_type = + checked_stmts |> list_until_escape + |> List.map ~f:(fun s -> s.smeta.return_type) + |> List.fold ~init:NoReturnType + ~f:(try_compute_block_statement_returntype loc) in + mk_typed_statement ~stmt:(Block checked_stmts) ~return_type ~loc + + and check_profile loc cf tenv name stmts = + let _, checked_stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) in + let return_type = + checked_stmts |> list_until_escape + |> List.map ~f:(fun s -> s.smeta.return_type) + |> List.fold ~init:NoReturnType + ~f:(try_compute_block_statement_returntype loc) in + mk_typed_statement ~stmt:(Profile (name, checked_stmts)) ~return_type ~loc + + (* variable declarations *) + and verify_valid_transformation_for_type loc is_global sized_ty trans = + let is_real {emeta; _} = emeta.type_ = UReal in + let is_real_transformation = + match trans with + | Transformation.Lower e -> is_real e + | Upper e -> is_real 
e + | LowerUpper (e1, e2) -> is_real e1 || is_real e2 + | _ -> false in + if is_global && sized_ty = SizedType.SInt && is_real_transformation then + Semantic_error.non_int_bounds loc |> error ; + let is_transformation = + match trans with Transformation.Identity -> false | _ -> true in + if is_global && SizedType.(contains_complex sized_ty) && is_transformation + then Semantic_error.complex_transform loc |> error + + and verify_transformed_param_ty loc cf is_global unsized_ty = + if + is_global + && (cf.current_block = Param || cf.current_block = TParam) + && UnsizedType.is_int_type unsized_ty + then Semantic_error.transformed_params_int loc |> error + + and check_sizedtype cf tenv sizedty = + let check e msg = check_expression_of_int_type cf tenv e msg in + match sizedty with + | SizedType.SInt -> SizedType.SInt + | SReal -> SReal + | SComplex -> SComplex + | SVector (mem_pattern, e) -> + let te = check e "Vector sizes" in + SVector (mem_pattern, te) + | SRowVector (mem_pattern, e) -> + let te = check e "Row vector sizes" in + SRowVector (mem_pattern, te) + | SMatrix (mem_pattern, e1, e2) -> + let te1 = check e1 "Matrix row size" in + let te2 = check e2 "Matrix column size" in + SMatrix (mem_pattern, te1, te2) + | SComplexVector e -> + let te = check e "complex vector sizes" in + SComplexVector te + | SComplexRowVector e -> + let te = check e "complex row vector sizes" in + SComplexRowVector te + | SComplexMatrix (e1, e2) -> + let te1 = check e1 "Complex matrix row size" in + let te2 = check e2 "Complex matrix column size" in + SComplexMatrix (te1, te2) + | SArray (st, e) -> + let tst = check_sizedtype cf tenv st in + let te = check e "Array sizes" in + SArray (tst, te) + + and check_var_decl_initial_value loc cf tenv {identifier; initial_value} = + match initial_value with + | Some e -> ( + let lhs = + check_lvalue cf tenv {lval= LVariable identifier; lmeta= {loc}} in + let rhs = check_expression cf tenv e in + match + 
SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ + rhs.emeta.type_ + with + | Ok p -> Ast.{identifier; initial_value= Some (Promotion.promote rhs p)} + | Error _ -> + Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ + rhs.emeta.type_ [] + |> error ) + | None -> Ast.{identifier; initial_value= None} + + and check_transformation cf tenv ut trans = + let check e msg = check_expression_of_scalar_or_type cf tenv ut e msg in + match trans with + | Transformation.Identity -> Transformation.Identity + | Lower e -> check e "Lower bound" |> Lower + | Upper e -> check e "Upper bound" |> Upper + | LowerUpper (e1, e2) -> + (check e1 "Lower bound", check e2 "Upper bound") |> LowerUpper + | Offset e -> check e "Offset" |> Offset + | Multiplier e -> check e "Multiplier" |> Multiplier + | OffsetMultiplier (e1, e2) -> + (check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier + | Ordered -> Ordered + | PositiveOrdered -> PositiveOrdered + | Simplex -> Simplex + | UnitVector -> UnitVector + | CholeskyCorr -> CholeskyCorr + | CholeskyCov -> CholeskyCov + | Correlation -> Correlation + | Covariance -> Covariance + + and check_var_decl loc cf tenv sized_ty trans + (variables : untyped_expression Ast.variable list) is_global = + let checked_type = + check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty in + let unsized_type = SizedType.to_unsized checked_type in + let checked_trans = check_transformation cf tenv unsized_type trans in + let tenv, tvariables = + List.fold_map ~init:tenv + ~f:(fun tenv' ({identifier; _} as var) -> + verify_identifier identifier ; + verify_name_fresh tenv' identifier ~is_udf:false ; + let tenv'' = + Env.add tenv' identifier.name unsized_type + (`Variable + {origin= cf.current_block; global= is_global; readonly= false} + ) in + (tenv'', check_var_decl_initial_value loc cf tenv'' var) ) + variables in + verify_valid_transformation_for_type loc is_global checked_type + checked_trans ; + verify_transformed_param_ty 
loc cf is_global unsized_type ; + let stmt = + VarDecl + { decl_type= checked_type + ; transformation= checked_trans + ; variables= tvariables + ; is_global } in + (tenv, mk_typed_statement ~stmt ~loc ~return_type:NoReturnType) + + (* function definitions *) + and exists_matching_fn_declared tenv id arg_tys rt = + let options = + List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) + in + let f = function + | Env.{kind= `UserDeclared _; type_= UFun (listedtypes, rt', _, _)} + when arg_tys = listedtypes && rt = rt' -> + true + | _ -> false in + List.exists ~f options + + and verify_unique_signature tenv loc id arg_tys rt = + let existing = + List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) + in + let same_args = function + | Env.{type_= UFun (listedtypes, _, _, _); _} + when List.map ~f:snd arg_tys = List.map ~f:snd listedtypes -> + true + | _ -> false in + match List.filter existing ~f:same_args with + | [] -> () + | {type_= UFun (_, rt', _, _); _} :: _ when rt <> rt' -> + Semantic_error.fn_overload_rt_only loc id.name rt rt' |> error + | {kind; _} :: _ -> + Semantic_error.fn_decl_redefined loc id.name + ~stan_math:(kind = `StanMath) + (UnsizedType.UFun (arg_tys, rt, Fun_kind.suffix_from_name id.name, AoS) + ) + |> error + + and verify_fundef_overloaded loc tenv id arg_tys rt = + if exists_matching_fn_declared tenv id arg_tys rt then + (* this is the definition to an existing forward declaration *) + () + else + (* this should be an overload with a unique signature *) + verify_unique_signature tenv loc id arg_tys rt ; + verify_name_fresh tenv id ~is_udf:true + + and get_fn_decl_or_defn loc tenv id arg_tys rt body = + match body with + | {stmt= Skip; _} -> + if exists_matching_fn_declared tenv id arg_tys rt then + Semantic_error.fn_decl_exists loc id.name |> error + else `UserDeclared id.id_loc + | _ -> `UserDefined + + and verify_fundef_dist_rt loc id return_ty = + let is_dist = + List.exists + ~f:(fun x -> String.is_suffix 
id.name ~suffix:x) + Utils.conditioning_suffices_w_log in + if is_dist then + match return_ty with + | UnsizedType.ReturnType UReal -> () + | _ -> Semantic_error.non_real_prob_fn_def loc |> error + + and verify_pdf_fundef_first_arg_ty loc id arg_tys = + if String.is_suffix id.name ~suffix:"_lpdf" then + let rt = List.hd arg_tys |> Option.map ~f:snd in + match rt with + | Some rt when not (UnsizedType.is_int_type rt) -> () + | _ -> Semantic_error.prob_density_non_real_variate loc rt |> error + + and verify_pmf_fundef_first_arg_ty loc id arg_tys = + if String.is_suffix id.name ~suffix:"_lpmf" then + let rt = List.hd arg_tys |> Option.map ~f:snd in + match rt with + | Some rt when UnsizedType.is_int_type rt -> () + | _ -> Semantic_error.prob_mass_non_int_variate loc rt |> error + + and verify_fundef_distinct_arg_ids loc arg_names = + let dup_exists l = + List.find_a_dup ~compare:String.compare l |> Option.is_some in + if dup_exists arg_names then Semantic_error.duplicate_arg_names loc |> error + + and verify_fundef_return_tys loc return_type body = + if + body.stmt = Skip + || is_of_compatible_return_type return_type body.smeta.return_type + then () + else Semantic_error.incompatible_return_types loc |> error + + and add_function tenv name type_ defined = + (* if we're providing a definition, we remove prior declarations + to simplify the environment *) + if defined = `UserDefined then + let existing_defns = Env.find tenv name in + let defns = + List.filter + ~f:(function + | Env.{kind= `UserDeclared _; type_= type'} when type' = type_ -> + false + | _ -> true ) + existing_defns in + let new_fn = Env.{kind= `UserDefined; type_} in + Env.set_raw tenv name (new_fn :: defns) + else Env.add tenv name type_ defined + + and check_fundef loc cf tenv return_ty id args body = + List.iter args ~f:(fun (_, _, id) -> verify_identifier id) ; + verify_identifier id ; + let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in + let arg_identifiers = List.map ~f:(fun (_, _, z) 
-> z) args in + let arg_names = List.map ~f:(fun x -> x.name) arg_identifiers in + verify_fundef_dist_rt loc id return_ty ; + verify_pdf_fundef_first_arg_ty loc id arg_types ; + verify_pmf_fundef_first_arg_ty loc id arg_types ; + List.iter + ~f:(fun id -> verify_name_fresh tenv id ~is_udf:false) + arg_identifiers ; + verify_fundef_distinct_arg_ids loc arg_names ; + (* We treat DataOnly arguments as if they are data and AutoDiffable arguments + as if they are parameters, for the purposes of type checking. + *) + let arg_types_internal = + List.map + ~f:(function + | UnsizedType.DataOnly, ut -> (Env.Data, ut) + | AutoDiffable, ut -> (Param, ut) ) + arg_types in + let tenv_body = + List.fold2_exn arg_names arg_types_internal ~init:tenv + ~f:(fun env name (origin, typ) -> + Env.add env name typ + (* readonly so that function args and loop identifiers + are not modified in function. (passed by const ref) *) + (`Variable {origin; readonly= true; global= false}) ) in + let context = + let is_udf_dist name = + List.exists + ~f:(fun suffix -> String.is_suffix name ~suffix) + Utils.distribution_suffices in + { cf with + in_fun_def= true + ; in_rng_fun_def= String.is_suffix id.name ~suffix:"_rng" + ; in_lp_fun_def= String.is_suffix id.name ~suffix:"_lp" + ; in_udf_dist_def= is_udf_dist id.name + ; in_returning_fun_def= return_ty <> Void } in + let _, checked_body = check_statement context tenv_body body in + verify_fundef_return_tys loc return_ty checked_body ; + let stmt = + FunDef + {returntype= return_ty; funname= id; arguments= args; body= checked_body} + in + (* NB: **not** tenv_body, so args don't leak out *) + (tenv, mk_typed_statement ~return_type:NoReturnType ~loc ~stmt) + + and check_statement (cf : context_flags_record) (tenv : Env.t) + (s : Ast.untyped_statement) : Env.t * typed_statement = + let loc = s.smeta.loc in + match s.stmt with + | NRFunApp (_, id, es) -> (tenv, check_nr_fn_app loc cf tenv id es) + | Assignment {assign_lhs; assign_op; assign_rhs} -> + 
(tenv, check_assignment loc cf tenv assign_lhs assign_op assign_rhs) + | TargetPE e -> (tenv, check_target_pe loc cf tenv e) + | IncrementLogProb e -> (tenv, check_incr_logprob loc cf tenv e) + | Tilde {arg; distribution; args; truncation} -> + (tenv, check_tilde loc cf tenv distribution truncation arg args) + | Break -> (tenv, check_break loc cf) + | Continue -> (tenv, check_continue loc cf) + | Return e -> (tenv, check_return loc cf tenv e) + | ReturnVoid -> (tenv, check_returnvoid loc cf) + | Print ps -> (tenv, check_print loc cf tenv ps) + | Reject ps -> (tenv, check_reject loc cf tenv ps) + | Skip -> (tenv, check_skip loc) + (* the following can contain further statements *) + | IfThenElse (e, s1, os2) -> (tenv, check_if_then_else loc cf tenv e s1 os2) + | While (e, s) -> (tenv, check_while loc cf tenv e s) + | For {loop_variable; lower_bound; upper_bound; loop_body} -> + ( tenv + , check_for loc cf tenv loop_variable lower_bound upper_bound loop_body + ) + | ForEach (id, e, s) -> (tenv, check_foreach loc cf tenv id e s) + | Block stmts -> (tenv, check_block loc cf tenv stmts) + | Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl) + (* these two are special in that they're allowed to change the type environment *) + | VarDecl {decl_type; transformation; variables; is_global} -> + check_var_decl loc cf tenv decl_type transformation variables is_global + | FunDef {returntype; funname; arguments; body} -> + check_fundef loc cf tenv returntype funname arguments body + + let verify_fun_def_body_in_block = function + | {stmt= FunDef {body= {stmt= Block _; _}; _}; _} + |{stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> + () + | {stmt= FunDef {body= {stmt= _; smeta}; _}; _} -> + Semantic_error.fn_decl_needs_block smeta.loc |> error + | _ -> () + + let verify_functions_have_defn tenv function_block_stmts_opt = + let error_on_undefined name funs = + List.iter (List.rev funs) ~f:(fun f -> + match f with + | Env.{kind= `UserDeclared loc; _} -> + 
Semantic_error.fn_decl_without_def loc name |> error + | _ -> () ) in + if !check_that_all_functions_have_definition then + Env.iteri tenv error_on_undefined ; + match function_block_stmts_opt with + | Some {stmts= []; _} | None -> () + | Some {stmts= ls; _} -> List.iter ~f:verify_fun_def_body_in_block ls + + let add_userdefined_functions tenv stmts_opt = + match stmts_opt with + | None -> tenv + | Some {stmts; _} -> + let f tenv (s : Ast.untyped_statement) = + match s with + | {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} + -> + let arg_types = Ast.type_of_arguments arguments in + verify_fundef_overloaded loc tenv funname arg_types returntype ; + let defined = + get_fn_decl_or_defn loc tenv funname arg_types returntype body + in + add_function tenv funname.name + (UFun + ( arg_types + , returntype + , Fun_kind.suffix_from_name funname.name + , AoS ) ) + defined + | _ -> tenv in + List.fold ~init:tenv ~f stmts + + let check_toplevel_block block tenv stmts_opt = + let cf = context block in + match stmts_opt with + | Some {stmts; xloc} -> + let tenv', stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) in + (tenv', Some {stmts; xloc}) + | None -> (tenv, None) + + let verify_correctness_invariant (ast : untyped_program) + (decorated_ast : typed_program) = + let detyped = untyped_program_of_typed_program decorated_ast in + if compare_untyped_program ast detyped = 0 then () + else + Common.FatalError.fatal_error_msg + [%message + "Type checked AST does not match original AST. 
" + (detyped : untyped_program) + (ast : untyped_program)] + + let check_program_exn + ( { functionblock= fb + ; datablock= db + ; transformeddatablock= tdb + ; parametersblock= pb + ; transformedparametersblock= tpb + ; modelblock= mb + ; generatedquantitiesblock= gqb + ; comments } as ast ) = + warnings := [] ; + (* create a new type environment which has only stan-math functions *) + let tenv = std_library_tenv in + let tenv = add_userdefined_functions tenv fb in + let tenv, typed_fb = check_toplevel_block Functions tenv fb in + verify_functions_have_defn tenv typed_fb ; + let tenv, typed_db = check_toplevel_block Data tenv db in + let tenv, typed_tdb = check_toplevel_block TData tenv tdb in + let tenv, typed_pb = check_toplevel_block Param tenv pb in + let tenv, typed_tpb = check_toplevel_block TParam tenv tpb in + let _, typed_mb = check_toplevel_block Model tenv mb in + let _, typed_gqb = check_toplevel_block GQuant tenv gqb in + let prog = + { functionblock= typed_fb + ; datablock= typed_db + ; transformeddatablock= typed_tdb + ; parametersblock= typed_pb + ; transformedparametersblock= typed_tpb + ; modelblock= typed_mb + ; generatedquantitiesblock= typed_gqb + ; comments } in + verify_correctness_invariant ast prog ; + attach_warnings prog + + let check_program ast = + try Result.Ok (check_program_exn ast) + with Errors.SemanticError err -> Result.Error err +end diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecking.mli similarity index 54% rename from src/frontend/Typechecker.mli rename to src/frontend/Typechecking.mli index 07662dfe11..87958d2a23 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecking.mli @@ -11,37 +11,36 @@ A type environment {!val:Environment.t} is used to hold variables and functions, including Stan math functions. This is a functional map, meaning it is handled immutably. -*) - -open Ast -val check_program_exn : untyped_program -> typed_program * Warnings.t list -(** - Type check a full Stan program. 
- Can raise [Errors.SemanticError] + This module is parameterized over a Standard Library of function signatures, See + [Std_library_utils.Library]. For the main compiler, this is + [Stan_math_backend.Stan_math_library] *) -val check_program : - untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result -(** - The safe version of [check_program_exn]. This catches - all [Errors.SemanticError] exceptions and converts them - into a [Result.t] -*) - -val operator_stan_math_return_type : - Middle.Operator.t - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> (Middle.UnsizedType.returntype * Promotion.t list) option - -val stan_math_return_type : - string - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> Middle.UnsizedType.returntype option +open Ast +open Typechecking_intf val model_name : string ref (** A reference to hold the model name. Relevant for checking variable - clashes and used in code generation. *) + clashes and used in code generation. 
*) val check_that_all_functions_have_definition : bool ref (** A switch to determine whether we check that all functions have a definition *) + +val get_arg_types : typed_expression list -> Std_library_utils.fun_arg list +val type_of_expr_typed : typed_expression -> Middle.UnsizedType.t + +val calculate_autodifftype : + Environment.originblock + -> Environment.originblock + -> Middle.UnsizedType.t + -> Middle.UnsizedType.autodifftype + +val make_function_variable : + Environment.originblock + -> Middle.Location_span.t + -> identifier + -> Middle.UnsizedType.t + -> Ast.typed_expression + +module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER diff --git a/src/frontend/Typechecking_intf.ml b/src/frontend/Typechecking_intf.ml new file mode 100644 index 0000000000..ffdd57ba84 --- /dev/null +++ b/src/frontend/Typechecking_intf.ml @@ -0,0 +1,29 @@ +open Ast + +(** Signature for a Stan typechecker *) +module type TYPECHECKER = sig + val check_program_exn : untyped_program -> typed_program * Warnings.t list + (** + Type check a full Stan program. + Can raise [Errors.SemanticError] + *) + + val check_program : + untyped_program + -> (typed_program * Warnings.t list, Semantic_error.t) result + (** + The safe version of [check_program_exn]. 
This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] + *) + + val operator_return_type : + Middle.Operator.t + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> (Middle.UnsizedType.returntype * Promotion.t list) option + + val library_function_return_type : + string + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> Middle.UnsizedType.returntype option +end diff --git a/src/frontend/dune b/src/frontend/dune index 25da52c2e9..ff607ad7ec 100644 --- a/src/frontend/dune +++ b/src/frontend/dune @@ -6,7 +6,7 @@ (backend bisect_ppx)) (inline_tests) (preprocess - (pps ppx_jane ppx_deriving.fold ppx_deriving.map))) + (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.create))) (ocamllex lexer) diff --git a/src/middle/Stan_math_signatures.mli b/src/middle/Stan_math_signatures.mli deleted file mode 100644 index f177959196..0000000000 --- a/src/middle/Stan_math_signatures.mli +++ /dev/null @@ -1,75 +0,0 @@ -(** This module stores a table of all signatures from the Stan - math C++ library which are exposed to Stan, and some helper - functions for dealing with those signatures. -*) - -open Core_kernel - -(** Function arguments are represented by their type an autodiff - type. 
This is [AutoDiffable] for everything except arguments - marked with the data keyword *) -type fun_arg = UnsizedType.autodifftype * UnsizedType.t - -(** Signatures consist of a return type, a list of arguments, and a flag - for whether or not those arguments can be Struct of Arrays objects *) -type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t - -val stan_math_signatures : (string, signature list) Hashtbl.t -(** Mapping from names to signature(s) of functions *) - -val is_stan_math_function_name : string -> bool -(** Equivalent to [Hashtbl.mem stan_math_signatures s]*) - -type variadic_signature = - { return_type: UnsizedType.t - ; control_args: fun_arg list - ; required_fn_rt: UnsizedType.t - ; required_fn_args: fun_arg list } - -val stan_math_variadic_signatures : (string, variadic_signature) Hashtbl.t -(** Mapping from names to description of a variadic function. - - Note that these function names cannot be overloaded, and usually require - customized code-gen in the backend. 
-*) - -val is_stan_math_variadic_function_name : string -> bool -(** Equivalent to [Hashtbl.mem stan_math_variadic_signatures s]*) - -(** Pretty printers *) - -val pp_math_sig : signature Fmt.t -val pretty_print_all_math_sigs : unit Fmt.t -val pretty_print_all_math_distributions : unit Fmt.t - -type dimensionality -type return_behavior - -type fkind = private - | Lpmf - | Lpdf - | Log - | Rng - | Cdf - | Ccdf - | UnaryVectorized of return_behavior -[@@deriving show {with_path= false}] - -val distributions : - (fkind list * string * dimensionality list * Mem_pattern.t) list -(** The distribution {e families} exposed by the math library *) - -val dist_name_suffix : (string * 'a) list -> string -> string - -(** Helpers for dealing with operators as signatures *) - -val operator_to_stan_math_fns : Operator.t -> string list -val string_operator_to_stan_math_fns : string -> string -val pretty_print_math_lib_operator_sigs : Operator.t -> string list -val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list - -(** Special functions for the variadic signatures exposed *) - -(* reduce_sum helpers *) -val is_reduce_sum_fn : string -> bool -val reduce_sum_slice_types : UnsizedType.t list diff --git a/src/middle/dune b/src/middle/dune index 2884272674..2c6ede8b5a 100644 --- a/src/middle/dune +++ b/src/middle/dune @@ -6,9 +6,4 @@ (backend bisect_ppx)) (inline_tests) (preprocess - (pps - ppx_jane - ppx_deriving.map - ppx_deriving.fold - ppx_deriving.create - ppx_deriving.show))) + (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.create))) diff --git a/src/stan_math_backend/Lower_expr.ml b/src/stan_math_backend/Lower_expr.ml index a1185325be..509aaeaedb 100644 --- a/src/stan_math_backend/Lower_expr.ml +++ b/src/stan_math_backend/Lower_expr.ml @@ -45,9 +45,9 @@ type variadic = FixedArgs | ReduceSum | VariadicHOF of int [@@deriving compare, hash] let functor_type hof = - match Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures hof 
with + match Hashtbl.find Stan_math_library.variadic_signatures hof with | Some {required_fn_args; _} -> VariadicHOF (List.length required_fn_args) - | None when Stan_math_signatures.is_reduce_sum_fn hof -> ReduceSum + | None when Stan_math_library.is_reduce_sum_fn hof -> ReduceSum | None -> FixedArgs let functor_suffix_select = function @@ -329,7 +329,7 @@ and lower_functionals fname suffix es mem_pattern = | ( x , {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _} :: grainsize :: container :: tl ) - when Stan_math_signatures.is_reduce_sum_fn x -> + when Stan_math_library.is_reduce_sum_fn x -> let chop_functor_suffix = String.chop_suffix_exn ~suffix:reduce_sum_functor_suffix in let propto_template = @@ -343,11 +343,9 @@ and lower_functionals fname suffix es mem_pattern = ^ reduce_sum_functor_suffix in ( Fmt.str "%s<%s%s>" fname normalized_dist_functor propto_template , grainsize :: container :: msgs :: tl ) - | _, _ - when Stan_math_signatures.is_stan_math_variadic_function_name fname -> - let Stan_math_signatures.{control_args; _} = - Hashtbl.find_exn - Stan_math_signatures.stan_math_variadic_signatures fname in + | _, _ when Stan_math_library.is_variadic_function_name fname -> + let Frontend.Std_library_utils.{control_args; _} = + Hashtbl.find_exn Stan_math_library.variadic_signatures fname in let hd, tl = List.split_n converted_es (List.length control_args + 1) in (fname, hd @ (msgs :: tl)) diff --git a/src/middle/Stan_math_signatures.ml b/src/stan_math_backend/Stan_math_library.ml similarity index 94% rename from src/middle/Stan_math_signatures.ml rename to src/stan_math_backend/Stan_math_library.ml index e6c37bfc1d..df91128a90 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/stan_math_backend/Stan_math_library.ml @@ -1,7 +1,9 @@ (** The signatures of the Stan Math library, which are used for type checking *) -open Core_kernel +open Core_kernel open Core_kernel.Poly +open Middle +open Frontend.Std_library_utils (** The "dimensionality" 
(bad name?) is supposed to help us represent the vectorized nature of many Stan functions. It allows us to represent when @@ -87,23 +89,13 @@ type fkind = | UnaryVectorized of return_behavior [@@deriving show {with_path= false}] -type fun_arg = UnsizedType.autodifftype * UnsizedType.t -type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t - -type variadic_signature = - { return_type: UnsizedType.t - ; control_args: fun_arg list - ; required_fn_rt: UnsizedType.t - ; required_fn_args: fun_arg list } -[@@deriving create] - let is_primitive = function | UnsizedType.UReal -> true | UInt -> true | _ -> false (** The signatures hash table *) -let (stan_math_signatures : (string, signature list) Hashtbl.t) = +let (function_signatures : (string, signature list) Hashtbl.t) = String.Table.create () (** All of the signatures that are added by hand, rather than the ones @@ -115,7 +107,7 @@ let (manual_stan_math_signatures : (string, signature list) Hashtbl.t) = These functions cannot be overloaded. *) -let (stan_math_variadic_signatures : (string, variadic_signature) Hashtbl.t) = +let (variadic_signatures : (string, variadic_signature) Hashtbl.t) = String.Table.create () (* XXX The correct word here isn't combination - what is it? 
*) @@ -298,6 +290,9 @@ let distributions = ; ([Lpdf], "wishart_cholesky", [DMatrix; DReal; DMatrix], SoA) ; ([Lpdf; Log], "wishart", [DMatrix; DReal; DMatrix], SoA) ] +let distribution_families = + List.map ~f:(fun (_, name, _, _) -> name) distributions + let basic_vectorized = UnaryVectorized IntsToReals let math_sigs = @@ -369,27 +364,11 @@ let all_declarative_sigs = distributions @ math_sigs let declarative_fnsigs = List.concat_map ~f:mk_declarative_sig all_declarative_sigs -let is_stan_math_function_name name = +let is_stdlib_function_name name = let name = Utils.stdlib_distribution_name name in - Hashtbl.mem stan_math_signatures name - -let is_stan_math_variadic_function_name name = - Hashtbl.mem stan_math_variadic_signatures name - -let dist_name_suffix udf_names name = - let is_udf_name s = List.exists ~f:(fun (n, _) -> n = s) udf_names in - match - Utils.distribution_suffices - |> List.filter ~f:(fun sfx -> - is_stan_math_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) - |> List.hd - with - | Some hd -> hd - | None -> - Common.FatalError.fatal_error_msg - [%message "Couldn't find distribution " name] - -let operator_to_stan_math_fns op = + Hashtbl.mem function_signatures name + +let operator_to_function_names op = match op with | Operator.Plus -> ["add"] | PPlus -> ["plus"] @@ -415,21 +394,15 @@ let operator_to_stan_math_fns op = | PNot -> ["logical_negation"] | Transpose -> ["transpose"] -let int_divide_type = - UnsizedType. 
- ( ReturnType UInt - , [(AutoDiffable, UInt); (AutoDiffable, UInt)] - , Mem_pattern.AoS ) - -let get_sigs name = +let get_signatures name = let name = Utils.stdlib_distribution_name name in - Hashtbl.find_multi stan_math_signatures name |> List.sort ~compare + Hashtbl.find_multi function_signatures name |> List.sort ~compare -let make_assignmentoperator_stan_math_signatures assop = +let get_assignment_operator_signatures assop = ( match assop with | Operator.Divide -> ["divide"] - | assop -> operator_to_stan_math_fns assop ) - |> List.concat_map ~f:get_sigs + | assop -> operator_to_function_names assop ) + |> List.concat_map ~f:get_signatures |> List.concat_map ~f:(function | ReturnType rtype, [(ad1, lhs); (ad2, rhs)], _ when rtype = lhs @@ -442,15 +415,7 @@ let make_assignmentoperator_stan_math_signatures assop = else [(Void, [(ad1, lhs); (ad2, rhs)], SoA)] | _ -> [] ) -let pp_math_sig ppf (rt, args, mem_pattern) = - UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) - -let pp_math_sigs ppf name = - (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf (get_sigs name) - -let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs - -let string_operator_to_stan_math_fns str = +let string_operator_to_function_name str = match str with | "Plus__" -> "add" | "PPlus__" -> "plus" @@ -485,10 +450,10 @@ let pretty_print_all_math_sigs ppf () = (List.map ~f:snd args) UnsizedType.pp_returntype rt in let pp_sigs_for_name ppf name = (list ~sep:cut pp_sig) ppf - (List.map ~f:(fun t -> (name, t)) (get_sigs name)) in + (List.map ~f:(fun t -> (name, t)) (get_signatures name)) in pf ppf "@[%a@]" (list ~sep:cut pp_sigs_for_name) - (List.sort ~compare (Hashtbl.keys stan_math_signatures)) + (List.sort ~compare (Hashtbl.keys function_signatures)) let pretty_print_all_math_distributions ppf () = let open Fmt in @@ -498,16 +463,61 @@ let pretty_print_all_math_distributions ppf () = (List.map ~f:(Fn.compose String.lowercase show_fkind) kinds) in pf ppf "@[%a@]" (list ~sep:cut pp_dist) 
distributions -let pretty_print_math_lib_operator_sigs op = - if op = Operator.IntDivide then - [Fmt.str "@[@,%a@]" pp_math_sig int_divide_type] - else operator_to_stan_math_fns op |> List.map ~f:pretty_print_math_sigs +let int_divide_type = + UnsizedType. + ( ReturnType UInt + , [(AutoDiffable, UInt); (AutoDiffable, UInt)] + , Mem_pattern.AoS ) + +let get_operator_signatures op = + if op = Operator.IntDivide then [int_divide_type] + else operator_to_function_names op |> List.concat_map ~f:get_signatures + +let deprecated_distributions = + List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> + List.filter_map fnkinds ~f:(function + | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") + | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") + | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") + | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") + | Rng | Log | UnaryVectorized _ -> None ) ) + |> List.map ~f:(fun (x, y) -> + ( x + , { replacement= y + ; version= "2.33.0" + ; extra_message= + "This can be automatically changed using the canonicalize flag \ + for stanc" + ; canonicalize_away= true } ) ) + |> String.Map.of_alist_exn + +let deprecated_functions = + let make extra_message version canonicalize_away replacement = + {extra_message; replacement; version; canonicalize_away} in + let ode = + make + "\n\ + The new interface is slightly different, see: \ + https://mc-stan.org/users/documentation/case-studies/convert_odes.html" + "3.0" false in + let std = + make + "This can be automatically changed using the canonicalize flag for stanc" + "2.33.0" true in + String.Map.of_alist_exn + [ ("multiply_log", std "lmultiply") + ; ("binomial_coefficient_log", std "lchoose") + ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) + ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") + ; ("integrate_ode_bdf", ode "ode_bdf") + ; ("integrate_ode_adams", ode "ode_adams") + ; ("if_else", std "the conditional operator (x ? 
y : z)") + ; ("fabs", std "abs") ] (* -- Some helper definitions to populate stan_math_signatures -- *) let add_qualified (name, rt, argts, supports_soa) = - Hashtbl.add_multi stan_math_signatures ~key:name - ~data:(rt, argts, supports_soa) + Hashtbl.add_multi function_signatures ~key:name ~data:(rt, argts, supports_soa) let add_nullary name = add_unqualified (name, UnsizedType.ReturnType UReal, [], AoS) @@ -751,7 +761,7 @@ let for_vector_types s = List.iter ~f:s vector_types (* -- Start populating stan_math_signaturess -- *) let () = List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) -> - Hashtbl.add_multi stan_math_signatures ~key ~data:(rt, args, mem_pattern) ) ; + Hashtbl.add_multi function_signatures ~key ~data:(rt, args, mem_pattern) ) ; add_unqualified ("acos", ReturnType UComplex, [UComplex], AoS) ; add_unqualified ("acosh", ReturnType UComplex, [UComplex], AoS) ; List.iter @@ -2560,7 +2570,7 @@ let () = for type-checking *) Hashtbl.iteri manual_stan_math_signatures ~f:(fun ~key ~data -> List.iter data ~f:(fun data -> - Hashtbl.add_multi stan_math_signatures ~key ~data ) ) + Hashtbl.add_multi function_signatures ~key ~data ) ) (* variadics *) @@ -2629,12 +2639,21 @@ let variadic_ode_nonadjoint_fns = let ode_tolerances_suffix = "_tol" let is_reduce_sum_fn f = Set.mem reduce_sum_functions f +let is_special_function_name = is_reduce_sum_fn +let is_variadic_function_name name = Hashtbl.mem variadic_signatures name + +let special_function_returntype name = + if is_reduce_sum_fn name then Some (UnsizedType.ReturnType UReal) else None + +let is_not_overloadable name = + is_variadic_function_name name || is_special_function_name name + let variadic_dae_fun_return_type = UnsizedType.UVector let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector let add_variadic_fn name ~return_type ?control_args ~required_fn_rt ?required_fn_args () = - Hashtbl.add_exn stan_math_variadic_signatures ~key:name + Hashtbl.add_exn variadic_signatures 
~key:name ~data: (create_variadic_signature ~return_type ?control_args ?required_fn_args ~required_fn_rt () ) @@ -2693,9 +2712,76 @@ let () = ~required_fn_args:[UnsizedType.(AutoDiffable, UVector)] () -let%expect_test "dist name suffix" = - dist_name_suffix [] "normal" |> print_endline ; - [%expect {| _lpdf |}] +module Special_typechecking = struct + (** This module serves as the backend-specific portion + of the typechecker. *) + + open Frontend + open Typechecking + open Ast + + let error e = raise (Errors.SemanticError e) + + let check_reduce_sum ~is_cond_dist loc current_block tenv id tes = + let basic_mismatch () = + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] + in + SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args + mandatory_fun_args UReal (get_arg_types tes) in + let matching remaining_es fn = + match fn with + | Environment. 
+ { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as + ftype + ; _ } + when List.mem reduce_sum_slice_types sliced_arg_fun_type ~equal:( = ) -> + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args + mandatory_fun_args UReal arg_types + | _ -> basic_mismatch () in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + (* a valid signature exists *) + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_fun_app ~is_cond_dist (StanLib FnPlain) id + (Promotion.promote_list tes promotions) + ~type_:UnsizedType.UReal ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error ) + | _ -> + let expected_args, err = + basic_mismatch () |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error +end + +let check_special_fn = Special_typechecking.check_reduce_sum let%expect_test "declarative distributions" = let special_suffixes = @@ -2706,7 +2792,7 @@ let%expect_test "declarative distributions" = distributions |> List.map ~f:(function _, n, _, _ -> n) |> String.Set.of_list in - Hashtbl.keys stan_math_signatures + Hashtbl.keys function_signatures |> List.filter ~f:(fun name -> match Utils.split_distribution_suffix name with | 
Some (name, suffix) diff --git a/src/stan_math_backend/Stan_math_library.mli b/src/stan_math_backend/Stan_math_library.mli new file mode 100644 index 0000000000..31f0ee15cd --- /dev/null +++ b/src/stan_math_backend/Stan_math_library.mli @@ -0,0 +1,22 @@ +(** This module stores a table of all signatures from the Stan + math C++ library which are exposed to Stan, and some helper + functions for dealing with those signatures. +*) + +include Frontend.Std_library_utils.Library + +(** These functions are used by the drivers to display + all available functions and distributions. They are + not part of the Library interface since different drivers + for different backends would likely want different behavior + here *) + +val pretty_print_all_math_sigs : unit Fmt.t +val pretty_print_all_math_distributions : unit Fmt.t + +(** These functions related to variadic functions + are specific to this backend and used + during code generation *) + +(* reduce_sum helpers *) +val is_reduce_sum_fn : string -> bool diff --git a/src/stan_math_backend/dune b/src/stan_math_backend/dune index 05ec7c18b1..84badc1996 100644 --- a/src/stan_math_backend/dune +++ b/src/stan_math_backend/dune @@ -1,7 +1,7 @@ (library (name stan_math_backend) (public_name stanc.stan_math_backend) - (libraries core_kernel re fmt middle yojson) + (libraries core_kernel re fmt frontend middle yojson) (instrumentation (backend bisect_ppx)) (private_modules @@ -14,4 +14,10 @@ numbering) (inline_tests) (preprocess - (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.make))) + (pps + ppx_jane + ppx_deriving.map + ppx_deriving.fold + ppx_deriving.show + ppx_deriving.create + ppx_deriving.make))) diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index fbc9c33e87..88ce19e16b 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -3,9 +3,17 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend -open Analysis_and_optimization open Middle +open Analysis_and_optimization +open 
Stan_math_backend + +(* Initialize functor modules with the Stan Math Library *) +module Typechecker = Typechecking.Make (Stan_math_library) +module Deprecations = Deprecation_analysis.Make (Stan_math_library) +module Canonicalizer = Canonicalize.Make (Deprecations) +module ModelInfo = Info.Make (Stan_math_library) +module Ast2Mir = Ast_to_Mir.Make (Stan_math_library) +module Optimizer = Optimize.Make (Stan_math_library) (** The main program. *) let version = "%%NAME%%3 %%VERSION%%" @@ -165,7 +173,7 @@ let options = exit 0 ) , " Display stanc version number" ) ; ( "--name" - , Arg.Set_string Typechecker.model_name + , Arg.Set_string Typechecking.model_name , " Take a string to set the model name (default = \ \"$model_filename_model\")" ) ; ( "--O0" @@ -199,12 +207,11 @@ let options = , Arg.Set print_model_cpp , " If set, output the generated C++ Stan model class to stdout." ) ; ( "--allow-undefined" - , Arg.Clear Typechecker.check_that_all_functions_have_definition + , Arg.Clear Typechecking.check_that_all_functions_have_definition , " Do not fail if a function is declared but not defined" ) ; ( "--include-paths" , Arg.String - (fun str -> - Preprocessor.include_paths := String.split_on_chars ~on:[','] str ) + (fun str -> Preprocessor.include_paths := String.split ~on:',' str) , " Takes a comma-separated list of directories that may contain a file \ in an #include directive (default = \"\")" ) ; ( "--use-opencl" @@ -295,11 +302,11 @@ let use_file filename = ~print_warnings:(not !canonicalize_settings.deprecations) ~bare_functions:!bare_functions in (* must be before typecheck to fix up deprecated syntax which gets rejected *) - let ast = Canonicalize.repair_syntax ast !canonicalize_settings in + let ast = Canonicalizer.repair_syntax ast !canonicalize_settings in Debugging.ast_logger ast ; let typed_ast = type_ast_or_exit ?printed_filename ast in let canonical_ast = - Canonicalize.canonicalize_program typed_ast !canonicalize_settings in + 
Canonicalizer.canonicalize_program typed_ast !canonicalize_settings in if !pretty_print_program then print_or_write (Pretty_printing.pretty_print_typed_program @@ -307,13 +314,13 @@ let use_file filename = ~inline_includes:!canonicalize_settings.inline_includes canonical_ast ~strip_comments:!canonicalize_settings.strip_comments ) ; if !print_info_json then ( - print_endline (Info.info canonical_ast) ; + print_endline (ModelInfo.info canonical_ast) ; exit 0 ) ; if not !canonicalize_settings.deprecations then Warnings.pp_warnings Fmt.stderr ?printed_filename - (Deprecation_analysis.collect_warnings typed_ast) ; + (Deprecations.collect_warnings typed_ast) ; if !generate_data then ( - let decls = Ast_to_Mir.gather_declarations typed_ast.datablock in + let decls = Ast2Mir.gather_declarations typed_ast.datablock in let context = match !data_file with | None -> Map.Poly.empty @@ -331,11 +338,11 @@ let use_file filename = | None -> Map.Poly.empty | Some file -> Debug_data_generation.json_to_mir - (Ast_to_Mir.gather_declarations typed_ast.datablock) + (Ast2Mir.gather_declarations typed_ast.datablock) (Yojson.Basic.from_file file) in match Debug_data_generation.gen_values_json ~new_only:true ~context - (Ast_to_Mir.gather_declarations typed_ast.parametersblock) + (Ast2Mir.gather_declarations typed_ast.parametersblock) with | Ok s -> print_or_write s ; exit 0 | Error e -> @@ -347,7 +354,7 @@ let use_file filename = Fmt.pf Fmt.stderr "Warning: ignoring --debug-data-file" ; Debugging.typed_ast_logger typed_ast ; if not !pretty_print_program then ( - let mir = Ast_to_Mir.trans_prog filename typed_ast in + let mir = Ast2Mir.trans_prog filename typed_ast in if !dump_mir then Sexp.pp_hum Format.std_formatter [%sexp (mir : Middle.Program.Typed.t)] ; if !dump_mir_pretty then Program.Typed.pp Format.std_formatter mir ; @@ -367,7 +374,7 @@ let use_file filename = if !no_soa_opt then {base_optims with optimize_soa= false} else if !soa_opt then {base_optims with optimize_soa= true} else 
base_optims in - Optimize.optimization_suite ~settings:set_optims tx_mir in + Optimizer.optimization_suite ~settings:set_optims tx_mir in if !dump_mem_pattern then Memory_patterns.pp_mem_patterns Format.std_formatter opt_mir ; if !dump_opt_mir then @@ -393,11 +400,11 @@ let main () = Arg.parse options add_file usage ; (* Deal with multiple modalities *) if !dump_stan_math_sigs then ( - Stan_math_signatures.pretty_print_all_math_sigs Format.std_formatter () ; + Stan_math_library.pretty_print_all_math_sigs Format.std_formatter () ; exit 0 ) ; if !dump_stan_math_distributions then ( - Stan_math_signatures.pretty_print_all_math_distributions - Format.std_formatter () ; + Stan_math_library.pretty_print_all_math_distributions Format.std_formatter + () ; exit 0 ) ; if !model_file = "" then model_file_err () ; let stanc_args_to_print = @@ -417,12 +424,12 @@ let main () = Lower_program.standalone_functions := true ; bare_functions := true ) ; (* Just translate a stan program *) - if !Typechecker.model_name = "" then - Typechecker.model_name := + if !Typechecking.model_name = "" then + Typechecking.model_name := mangle (remove_dotstan List.(hd_exn (rev (String.split !model_file ~on:'/')))) ^ "_model" - else Typechecker.model_name := mangle !Typechecker.model_name ; + else Typechecking.model_name := mangle !Typechecking.model_name ; use_file !model_file let () = main () diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index 3e9f5ab155..a81d051fc7 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -1,16 +1,24 @@ open Core_kernel open Frontend -open Stan_math_backend -open Analysis_and_optimization open Middle +open Analysis_and_optimization +open Stan_math_backend open Js_of_ocaml +(* Initialize functors with Stan Math C++ signatures *) +module Typechecker = Typechecking.Make (Stan_math_library) +module Deprecations = Deprecation_analysis.Make (Stan_math_library) +module Canonicalizer = Canonicalize.Make (Deprecations) +module ModelInfo = Info.Make 
(Stan_math_library) +module Ast2Mir = Ast_to_Mir.Make (Stan_math_library) +module Optimizer = Optimize.Make (Stan_math_library) + let version = "%%NAME%% %%VERSION%%" let stan2cpp model_name model_string is_flag_set flag_val = Common.Gensym.reset_danger_use_cautiously () ; - Typechecker.model_name := model_name ; - Typechecker.check_that_all_functions_have_definition := + Typechecking.model_name := model_name ; + Typechecking.check_that_all_functions_have_definition := not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ; Transform_Mir.use_opencl := is_flag_set "use-opencl" ; Lower_program.standalone_functions := @@ -33,7 +41,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = >>| fun (typed_ast, type_warnings) -> let warnings = parser_warnings @ type_warnings in if is_flag_set "info" then - r.return (Result.Ok (Info.info typed_ast), warnings, []) ; + r.return (Result.Ok (ModelInfo.info typed_ast), warnings, []) ; let canonicalizer_settings = if is_flag_set "print-canonical" then Canonicalize.legacy else @@ -56,11 +64,11 @@ let stan2cpp model_name model_string is_flag_set flag_val = flag_val "max-line-length" |> Option.map ~f:int_of_string |> Option.value ~default:78 in + let mir = Ast2Mir.trans_prog model_name typed_ast in let deprecation_warnings = if canonicalizer_settings.deprecations then [] - else Deprecation_analysis.collect_warnings typed_ast in + else Deprecations.collect_warnings typed_ast in let warnings = warnings @ deprecation_warnings in - let mir = Ast_to_Mir.trans_prog model_name typed_ast in let tx_mir = Transform_Mir.trans_prog mir in if is_flag_set "auto-format" || is_flag_set "print-canonical" then r.return @@ -70,7 +78,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = ~line_length ~inline_includes:canonicalizer_settings.inline_includes ~strip_comments:canonicalizer_settings.strip_comments - (Canonicalize.canonicalize_program typed_ast + (Canonicalizer.canonicalize_program typed_ast 
canonicalizer_settings ) ) , warnings , [] ) ; @@ -87,7 +95,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = ( Result.map_error ~f:(fun e -> Errors.DebugDataError e) (Debug_data_generation.gen_values_json - (Ast_to_Mir.gather_declarations typed_ast.datablock) ) + (Ast2Mir.gather_declarations typed_ast.datablock) ) , warnings , [] ) ; if is_flag_set "debug-generate-inits" then @@ -95,7 +103,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = ( Result.map_error ~f:(fun e -> Errors.DebugDataError e) (Debug_data_generation.gen_values_json - (Ast_to_Mir.gather_declarations typed_ast.parametersblock) ) + (Ast2Mir.gather_declarations typed_ast.parametersblock) ) , warnings , [] ) ; let opt_mir = @@ -105,7 +113,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = else if is_flag_set "Oexperimental" || is_flag_set "O" then Optimize.Oexperimental else Optimize.O0 in - Optimize.optimization_suite + Optimizer.optimization_suite ~settings:(Optimize.level_optimizations opt_lvl) tx_mir in if is_flag_set "debug-optimized-mir" then @@ -210,11 +218,11 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) = wrap_result ?printed_filename ~code result ~warnings let dump_stan_math_signatures () = - Js.string @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_sigs () + Js.string @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_sigs () let dump_stan_math_distributions () = Js.string - @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_distributions () + @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_distributions () let () = Js.export "dump_stan_math_signatures" dump_stan_math_signatures ; diff --git a/test/integration/bad/stanc.expected b/test/integration/bad/stanc.expected index 360b0a0bf9..b5b2653f96 100644 --- a/test/integration/bad/stanc.expected +++ b/test/integration/bad/stanc.expected @@ -1637,7 +1637,6 @@ Ill-typed arguments supplied to infix operator /. 
Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (real, vector) => vector @@ -1669,7 +1668,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (real, vector) => vector diff --git a/test/integration/good/warning/pretty.expected b/test/integration/good/warning/pretty.expected index 3da261fe53..730658df14 100644 --- a/test/integration/good/warning/pretty.expected +++ b/test/integration/good/warning/pretty.expected @@ -1801,58 +1801,58 @@ model { y_p ~ normal(0, 1); } -Warning in 'if_else.stan', line 9, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 10, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 11, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 12, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 21, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? 
y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 22, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 23, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 24, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 26, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 27, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 28, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 29, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 30, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.33.0. Use the conditional - operator (x ? 
y : z) instead; this can be automatically changed using the - canonicalize flag for stanc +Warning in 'if_else.stan', line 9, column 26: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 10, column 26: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 11, column 26: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 12, column 26: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 21, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 22, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 23, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 24, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. 
This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 26, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 27, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 28, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 29, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 30, column 28: if_else is deprecated and will + be removed in Stan 2.33.0. Use the conditional operator (x ? y : z) + instead. 
This can be automatically changed using the canonicalize flag + for stanc $ ../../../../../install/default/bin/stanc --auto-format increment_log_prob.stan transformed data { int n; diff --git a/test/unit/Debug_data_generation_tests.ml b/test/unit/Debug_data_generation_tests.ml index ab9837a48f..f3d97521e7 100644 --- a/test/unit/Debug_data_generation_tests.ml +++ b/test/unit/Debug_data_generation_tests.ml @@ -1,11 +1,4 @@ -open Analysis_and_optimization open Core_kernel -open Frontend -open Debug_data_generation - -let print_data_prog ast = - gen_values_json (Ast_to_Mir.gather_declarations ast.Ast.datablock) - |> Result.ok |> Option.value_exn let%expect_test "whole program data generation check" = let ast = @@ -19,7 +12,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -50,7 +43,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -99,7 +92,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -136,7 +129,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -258,7 +251,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -517,7 +510,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -649,7 +642,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog 
ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -668,7 +661,7 @@ let%expect_test "Complex numbers program" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| diff --git a/test/unit/Desugar_test.ml b/test/unit/Desugar_test.ml index 525fbfe24b..1bf13ae28b 100644 --- a/test/unit/Desugar_test.ml +++ b/test/unit/Desugar_test.ml @@ -1,6 +1,9 @@ open Core_kernel open Analysis_and_optimization +module Partial_evaluator = + Partial_evaluation.Make (Stan_math_backend.Stan_math_library) + let print_tdata Middle.Program.{prepare_data; _} = Fmt.(str "@[%a@]@," (list ~sep:cut Middle.Stmt.Located.pp) prepare_data) |> print_endline diff --git a/test/unit/Optimize.ml b/test/unit/Optimize.ml index c581730023..fc9f05ceec 100644 --- a/test/unit/Optimize.ml +++ b/test/unit/Optimize.ml @@ -1,9 +1,13 @@ open Core_kernel -open Analysis_and_optimization.Optimize open Middle open Common open Analysis_and_optimization.Mir_utils +module Optimizer = + Analysis_and_optimization.Optimize.Make (Stan_math_backend.Stan_math_library) + +open Optimizer + let reset_and_mir_of_string s = Gensym.reset_danger_use_cautiously () ; Test_utils.mir_of_string s @@ -448,8 +452,8 @@ let%expect_test "recursive functions" = } |}] let%expect_test "do not try to inline extern functions" = - let before = !Frontend.Typechecker.check_that_all_functions_have_definition in - Frontend.Typechecker.check_that_all_functions_have_definition := false ; + let before = !Frontend.Typechecking.check_that_all_functions_have_definition in + Frontend.Typechecking.check_that_all_functions_have_definition := false ; let mir = reset_and_mir_of_string {| @@ -461,7 +465,7 @@ let%expect_test "do not try to inline extern functions" = } |} in - Frontend.Typechecker.check_that_all_functions_have_definition := before ; + Frontend.Typechecking.check_that_all_functions_have_definition := before ; let mir = function_inlining mir in 
Fmt.str "@[%a@]" Program.Typed.pp mir |> print_endline ; [%expect diff --git a/test/unit/Test_utils.ml b/test/unit/Test_utils.ml index 5522236e61..146451c9ac 100644 --- a/test/unit/Test_utils.ml +++ b/test/unit/Test_utils.ml @@ -1,6 +1,12 @@ open Frontend open Core_kernel +module CppLibrary : Std_library_utils.Library = + Stan_math_backend.Stan_math_library + +module Typechecker = Typechecking.Make (CppLibrary) +module Ast2Mir = Ast_to_Mir.Make (CppLibrary) + let untyped_ast_of_string s = let res, warnings = Parse.parse_string Parser.Incremental.program s in Fmt.epr "%a" (Fmt.list ~sep:Fmt.nop Warnings.pp) warnings ; @@ -15,4 +21,9 @@ let typed_ast_of_string_exn s = |> Result.map_error ~f:Errors.to_string |> Result.ok_or_failwith |> fst -let mir_of_string s = typed_ast_of_string_exn s |> Ast_to_Mir.trans_prog "" +let mir_of_string s = typed_ast_of_string_exn s |> Ast2Mir.trans_prog "" + +let print_data_prog ast = + Analysis_and_optimization.Debug_data_generation.gen_values_json + (Ast2Mir.gather_declarations ast.Ast.datablock) + |> Result.ok |> Option.value_exn