From 5f9a0003add60ae0c0b2f4ecd302552472456c58 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 10:25:24 -0400 Subject: [PATCH 01/19] Start refactor Try to use virtual libraries instead, get cyclic behavior --- docs/core_ideas.mld | 25 ++ docs/exposing_new_functions.mld | 7 +- .../Debug_data_generation.ml | 2 +- .../Dependence_analysis.ml | 7 +- src/analysis_and_optimization/Mem_pattern.ml | 11 +- .../Monotone_framework.ml | 20 +- ...rk_sigs.mli => Monotone_framework_intf.ml} | 0 src/analysis_and_optimization/Optimize.ml | 131 ++++--- src/analysis_and_optimization/Optimize.mli | 55 +-- .../Partial_evaluator.ml | 4 +- .../Partial_evaluator.mli | 4 + .../Pedantic_analysis.ml | 5 +- src/analysis_and_optimization/dune | 4 +- src/frontend/Ast.ml | 4 + src/frontend/Ast_to_Mir.ml | 14 +- src/frontend/Canonicalize.ml | 21 +- src/frontend/Deprecation_analysis.ml | 62 +--- src/frontend/Deprecation_analysis.mli | 5 +- src/frontend/Environment.ml | 4 +- src/frontend/Environment.mli | 12 +- src/frontend/Info.ml | 18 +- src/frontend/Info.mli | 2 +- src/frontend/Library.mli | 40 ++ src/frontend/Semantic_error.ml | 122 +++--- src/frontend/Semantic_error.mli | 57 ++- src/frontend/SignatureMismatch.ml | 23 +- src/frontend/SignatureMismatch.mli | 19 +- src/frontend/Std_library_utils.ml | 31 ++ src/frontend/Typechecker.ml | 290 +++------------ src/frontend/Typechecker.mli | 45 ++- src/frontend/dune | 1 + src/middle/Stan_math_signatures.mli | 78 ---- src/middle/dune | 7 +- src/stan_math_backend/Expression_gen.ml | 6 +- .../Stan_math_signatures.ml | 347 +++++++++++++++--- .../Stan_math_signatures.mli | 69 ++++ src/stan_math_backend/dune | 4 +- src/stan_math_library/Library.ml | 1 + src/stan_math_library/dune | 8 + src/stanc/dune | 7 +- src/stanc/stanc.ml | 4 +- src/stancjs/stancjs.ml | 30 +- test/integration/bad/stanc.expected | 2 - test/unit/Debug_data_generation_tests.ml | 21 +- test/unit/Desugar_test.ml | 3 + test/unit/Optimize.ml | 6 +- test/unit/Test_utils.ml | 12 +- 47 files changed, 878 insertions(+), 772 deletions(-) rename src/analysis_and_optimization/{Monotone_framework_sigs.mli => Monotone_framework_intf.ml} (100%) create mode 100644 src/analysis_and_optimization/Partial_evaluator.mli create mode 100644 src/frontend/Library.mli create mode 100644 src/frontend/Std_library_utils.ml delete mode 100644 src/middle/Stan_math_signatures.mli rename src/{middle => stan_math_backend}/Stan_math_signatures.ml (89%) create mode 100644 src/stan_math_backend/Stan_math_signatures.mli create mode 100644 src/stan_math_library/Library.ml create mode 100644 src/stan_math_library/dune diff --git a/docs/core_ideas.mld b/docs/core_ideas.mld index 111cd474c0..b05bfec243 100644 --- a/docs/core_ideas.mld +++ b/docs/core_ideas.mld @@ -65,6 +65,31 @@ This takes some getting used to, and also can lead to some unhelpful type signat VSCode, because abbreviations are not always used in hover-over text. For example, [Expr.Typed.t], the MIR's typed expression type, actually has a signature of [Expr.Typed.Meta.t Expr.Fixed.t]. +{1 The [Library] interface and functors} + +Many modules of stanc are modeled as OCaml {{:https://ocaml.org/learn/tutorials/functors.html}functors}, +which take in another module as input and produce a module as output. For the most part, +these functors expect an instance of the [Library] interface defined in +[src/frontend/Std_library_utils.ml]. + +This module primarily contains signatures for the Stan standard library. For most users, +you can assume this will be filled in with [src/stan_math_backend/Stan_math_library.ml], +the object representing the {{:https://github.com/stan-dev/math}stan-dev/math} C++ library. + +Usages of these functors are rather simple, e.g. in the core stanc driver the line + +{[ +module Typechecker = Typechecking.Make (Stan_math_library) +]} + +defines a module [Typechecker] by supplying the functor [Typechecking.Make] with +the Stan C++ library module. After this, [Typechecker.check_program] will typecheck +an AST against those specific functions. + +As noted in the above tutorial link, the syntax of functors is often the hardest part +of using and understanding them. The functors which accept [Library] are all relatively +simple, and should serve as good examples to beginners with the concept. + {1 The [Fmt] library and pretty-printing} We extensively use the {{:https://erratique.ch/software/fmt}Fmt} library for our pretty-printing and code diff --git a/docs/exposing_new_functions.mld b/docs/exposing_new_functions.mld index 2d1ff34b32..087d5bffa6 100644 --- a/docs/exposing_new_functions.mld +++ b/docs/exposing_new_functions.mld @@ -7,7 +7,7 @@ For a function to be built into Stan, it has to be included in the Stan Math library and its signature has to be exposed to the compiler. -To do the latter, we have to add a corresponding line in [src/middle/Stan_math_signatures.ml]. +To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_library.ml]. The compiler uses the signatures defined there to do type checking. @@ -130,8 +130,9 @@ For example, the following line defines the signature [add(real, matrix) => matr Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length list of arguments, are {b NOT} added to this list. -These are instead treated as special cases in the [Typechecker] module in the frontend folder. It -it best to consult an existing example of how these are done before proceeding. +These are instead handled by special functions like [is_variadic_function_name]. They +must also be given custom typechecking rules in the private sub-module [Variadic_typechecking]. +It is best to consult an existing example of how these are done before proceeding. {1 Testing} diff --git a/src/analysis_and_optimization/Debug_data_generation.ml b/src/analysis_and_optimization/Debug_data_generation.ml index e3166aa0e4..a6858b4837 100644 --- a/src/analysis_and_optimization/Debug_data_generation.ml +++ b/src/analysis_and_optimization/Debug_data_generation.ml @@ -30,7 +30,7 @@ let rec vect_to_mat l m = let eval_expr m e = let e = Mir_utils.subst_expr m e in - let e = Partial_evaluator.eval_expr e in + let e = Partial_evaluator.try_eval_expr e in let rec strip_promotions (e : Middle.Expr.Typed.t) = match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e in diff --git a/src/analysis_and_optimization/Dependence_analysis.ml b/src/analysis_and_optimization/Dependence_analysis.ml index 0bc86aeeb8..7b4a28951b 100644 --- a/src/analysis_and_optimization/Dependence_analysis.ml +++ b/src/analysis_and_optimization/Dependence_analysis.ml @@ -4,7 +4,7 @@ open Middle open Dataflow_types open Mir_utils open Dataflow_utils -open Monotone_framework_sigs +open Monotone_framework_intf open Monotone_framework (***********************************) @@ -119,9 +119,8 @@ let mir_reaching_definitions (mir : Program.Typed.t) (stmt : Stmt.Located.t) : Map.Poly.map rd_map ~f:(fun {entry; exit} -> {entry= to_rd_set entry; exit= to_rd_set exit} ) -let all_labels - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - : int Set.Poly.t = +let all_labels (module Flowgraph : FLOWGRAPH with type labels = int) : + int Set.Poly.t = let step set = Set.Poly.union set (union_map set ~f:(fun l -> Map.Poly.find_exn Flowgraph.successors l)) diff --git a/src/analysis_and_optimization/Mem_pattern.ml b/src/analysis_and_optimization/Mem_pattern.ml index 7613ea45d0..d71aa3cc8a 100644 --- a/src/analysis_and_optimization/Mem_pattern.ml +++ b/src/analysis_and_optimization/Mem_pattern.ml @@ -111,16 +111,13 @@ let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t) let query_stan_math_mem_pattern_support (name : string) (args : (UnsizedType.autodifftype * UnsizedType.t) list) = - let open Stan_math_signatures in match name with - | x when is_reduce_sum_fn x -> false - | x when is_variadic_ode_fn x -> false - | x when is_variadic_dae_fn x -> false + | x when Frontend.Library.is_variadic_function_name x -> false | _ -> let name = - string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name) - in - let namematches = Hashtbl.find_multi stan_math_signatures name in + Frontend.Library.string_operator_to_function_name + (Utils.stdlib_distribution_name name) in + let namematches = Frontend.Library.get_signatures name in let filteredmatches = List.filter ~f:(fun x -> diff --git a/src/analysis_and_optimization/Monotone_framework.ml b/src/analysis_and_optimization/Monotone_framework.ml index 442781b210..83c231c694 100644 --- a/src/analysis_and_optimization/Monotone_framework.ml +++ b/src/analysis_and_optimization/Monotone_framework.ml @@ -2,7 +2,7 @@ open Core_kernel open Core_kernel.Poly -open Monotone_framework_sigs +open Monotone_framework_intf open Mir_utils open Middle @@ -864,7 +864,7 @@ let rec declared_variables_stmt (List.map ~f:(fun x -> declared_variables_stmt x.pattern) l) let propagation_mfp (prog : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t @@ -897,7 +897,7 @@ let propagation_mfp (prog : Program.Typed.t) Mf.mfp () let reaching_definitions_mfp (mir : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let variables = ( module struct @@ -918,7 +918,7 @@ let reaching_definitions_mfp (mir : Program.Typed.t) Mf.mfp () let initialized_vars_mfp (total : string Set.Poly.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let (module Lattice) = dual_powerset_lattice_empty_initial @@ -949,8 +949,7 @@ let globals (prog : Program.Typed.t) = (** Monotone framework instance for live_variables analysis. Expects reverse flowgraph. *) let live_variables_mfp (prog : Program.Typed.t) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let never_kill = globals prog in let variables = @@ -970,10 +969,8 @@ let live_variables_mfp (prog : Program.Typed.t) (** Instantiate all four instances of the monotone framework for lazy code motion, reusing code between them *) -let lazy_expressions_mfp - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) +let lazy_expressions_mfp (module Flowgraph : FLOWGRAPH with type labels = int) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let all_expressions = used_subexpressions_stmt @@ -1031,8 +1028,7 @@ let lazy_expressions_mfp * *) let minimal_variables_mfp - (module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Circular_Fwd_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (initial_variables : string Set.Poly.t) (gen_variable : diff --git a/src/analysis_and_optimization/Monotone_framework_sigs.mli b/src/analysis_and_optimization/Monotone_framework_intf.ml similarity index 100% rename from src/analysis_and_optimization/Monotone_framework_sigs.mli rename to src/analysis_and_optimization/Monotone_framework_intf.ml diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index eaa793e617..1f3de62d1e 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -6,6 +6,66 @@ open Common open Middle open Mir_utils +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +let settings_const b = + { function_inlining= b + ; static_loop_unrolling= b + ; one_step_loop_unrolling= b + ; list_collapsing= b + ; block_fixing= b + ; allow_uninitialized_decls= b + ; constant_propagation= b + ; expression_propagation= b + ; copy_propagation= b + ; dead_code_elimination= b + ; partial_evaluation= b + ; lazy_code_motion= b + ; optimize_ad_levels= b + ; preserve_stability= not b + ; optimize_soa= b } + +let all_optimizations : optimization_settings = settings_const true +let no_optimizations : optimization_settings = settings_const false + +type optimization_level = O0 | O1 | Oexperimental + +let level_optimizations (lvl : optimization_level) : optimization_settings = + match lvl with + | O0 -> no_optimizations + | O1 -> + { function_inlining= true + ; static_loop_unrolling= false + ; one_step_loop_unrolling= false + ; list_collapsing= true + ; block_fixing= true + ; constant_propagation= true + ; expression_propagation= false + ; copy_propagation= true + ; dead_code_elimination= true + ; partial_evaluation= true + ; lazy_code_motion= false + ; allow_uninitialized_decls= true + ; optimize_ad_levels= false + ; preserve_stability= false + ; optimize_soa= true } + | Oexperimental -> all_optimizations + (** Apply the transformation to each function body and to the rest of the program as one block. @@ -629,7 +689,7 @@ let list_collapsing (mir : Program.Typed.t) = let propagation (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> (module Monotone_framework_sigs.TRANSFER_FUNCTION + -> (module Monotone_framework_intf.TRANSFER_FUNCTION with type labels = int and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t option ) ) (mir : Program.Typed.t) = @@ -712,7 +772,7 @@ let dead_code_elimination (mir : Program.Typed.t) = let dead_code_elim_stmt_base i stmt = (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) let live_variables_s = - (Map.find_exn live_variables i).Monotone_framework_sigs.entry in + (Map.find_exn live_variables i).Monotone_framework_intf.entry in match stmt with | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then stmt @@ -1184,11 +1244,6 @@ let optimize_soa (mir : Program.Typed.t) = List.fold ~init:Set.Poly.empty ~f:(Mem_pattern.query_initial_demotable_stmt false) mir.log_prob in - (* - let print_set s = - Set.Poly.iter ~f:print_endline s in - let () = print_set initial_variables in - *) let mod_exprs aos_exits mod_expr = Mir_utils.map_rec_expr (Mem_pattern.modify_expr_pattern aos_exits) mod_expr in @@ -1207,68 +1262,6 @@ let optimize_soa (mir : Program.Typed.t) = in {mir with log_prob= transform' mir.log_prob} -(* Apparently you need to completely copy/paste type definitions between - ml and mli files?*) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -let settings_const b = - { function_inlining= b - ; static_loop_unrolling= b - ; one_step_loop_unrolling= b - ; list_collapsing= b - ; block_fixing= b - ; allow_uninitialized_decls= b - ; constant_propagation= b - ; expression_propagation= b - ; copy_propagation= b - ; dead_code_elimination= b - ; partial_evaluation= b - ; lazy_code_motion= b - ; optimize_ad_levels= b - ; preserve_stability= not b - ; optimize_soa= b } - -let all_optimizations : optimization_settings = settings_const true -let no_optimizations : optimization_settings = settings_const false - -type optimization_level = O0 | O1 | Oexperimental - -let level_optimizations (lvl : optimization_level) : optimization_settings = - match lvl with - | O0 -> no_optimizations - | O1 -> - { function_inlining= true - ; static_loop_unrolling= false - ; one_step_loop_unrolling= false - ; list_collapsing= true - ; block_fixing= true - ; constant_propagation= true - ; expression_propagation= false - ; copy_propagation= true - ; dead_code_elimination= true - ; partial_evaluation= true - ; lazy_code_motion= false - ; allow_uninitialized_decls= true - ; optimize_ad_levels= false - ; preserve_stability= false - ; optimize_soa= true } - | Oexperimental -> all_optimizations - let optimization_suite ?(settings = all_optimizations) mir = let preserve_stability = settings.preserve_stability in let maybe_optimizations = diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index 2885ce9b44..d5ca472a6d 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -1,6 +1,33 @@ (* Code for optimization passes on the MIR *) + open Middle +(** Interface for turning individual optimizations on/off. Useful for testing + and for top-level interface flags. *) +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +val all_optimizations : optimization_settings +val no_optimizations : optimization_settings + +type optimization_level = O0 | O1 | Oexperimental + +val level_optimizations : optimization_level -> optimization_settings + val function_inlining : Program.Typed.t -> Program.Typed.t (** Inline all functions except for ones with forward declarations (e.g. recursive functions, mutually recursive functions, and @@ -59,35 +86,9 @@ val optimize_ad_levels : Program.Typed.t -> Program.Typed.t val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t (** Marks Decl types such that, if the first assignment after the decl - assigns to the full object, allow the object to be constructed but + assigns to the full object, allow the object to be constructed but not uninitialized. *) -(** Interface for turning individual optimizations on/off. Useful for testing - and for top-level interface flags. *) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -val all_optimizations : optimization_settings -val no_optimizations : optimization_settings - -type optimization_level = O0 | O1 | Oexperimental - -val level_optimizations : optimization_level -> optimization_settings - val optimization_suite : ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t (** Perform all optimizations in this module on the MIR in an appropriate order. *) diff --git a/src/analysis_and_optimization/Partial_evaluator.ml b/src/analysis_and_optimization/Partial_evaluator.ml index 0b898eaa03..cd9794f62b 100644 --- a/src/analysis_and_optimization/Partial_evaluator.ml +++ b/src/analysis_and_optimization/Partial_evaluator.ml @@ -106,11 +106,11 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = Operator.of_string_opt name |> Option.value_map ~f:(fun op -> - Frontend.Typechecker.operator_stan_math_return_type op + Frontend.Typechecker.operator_return_type op argument_types |> Option.map ~f:fst ) ~default: - (Frontend.Typechecker.stan_math_return_type name + (Frontend.Typechecker.library_function_return_type name argument_types ) in let try_partially_evaluate_stanlib e = Expr.Fixed.Pattern.( diff --git a/src/analysis_and_optimization/Partial_evaluator.mli b/src/analysis_and_optimization/Partial_evaluator.mli new file mode 100644 index 0000000000..ce3b24fa58 --- /dev/null +++ b/src/analysis_and_optimization/Partial_evaluator.mli @@ -0,0 +1,4 @@ +open Middle + +val try_eval_expr : Expr.Typed.t -> Expr.Typed.t +val eval_prog : Program.Typed.t -> Program.Typed.t diff --git a/src/analysis_and_optimization/Pedantic_analysis.ml b/src/analysis_and_optimization/Pedantic_analysis.ml index 19b7777f40..f35b171fe3 100644 --- a/src/analysis_and_optimization/Pedantic_analysis.ml +++ b/src/analysis_and_optimization/Pedantic_analysis.ml @@ -487,11 +487,14 @@ let settings_constant_prop = ; copy_propagation= true ; partial_evaluation= true } +(** Pedantic mode is only really valid for the Stan Math backend *) + (* Collect all pedantic mode warnings, sorted, to stderr *) let warn_pedantic (mir_unopt : Program.Typed.t) = (* Some warnings will be stronger when constants are propagated *) let mir = - Optimize.optimization_suite ~settings:settings_constant_prop mir_unopt in + Optimize.optimization_suite ~settings:settings_constant_prop mir_unopt + in (* Try to avoid recomputation by pre-building structures *) let distributions_info = list_distributions mir in let factor_graph = prog_factor_graph mir in diff --git a/src/analysis_and_optimization/dune b/src/analysis_and_optimization/dune index b656409cb0..c7f4a7dd05 100644 --- a/src/analysis_and_optimization/dune +++ b/src/analysis_and_optimization/dune @@ -1,9 +1,7 @@ (library (name analysis_and_optimization) (public_name stanc.analysis) - (libraries core_kernel str fmt common middle frontend) + (libraries core_kernel str fmt common middle frontend stan_math_backend) (inline_tests) - ;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation - (modules_without_implementation monotone_framework_sigs) (preprocess (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) diff --git a/src/frontend/Ast.ml b/src/frontend/Ast.ml index 58502a85a5..c47e9a8daa 100644 --- a/src/frontend/Ast.ml +++ b/src/frontend/Ast.ml @@ -80,6 +80,10 @@ let mk_untyped_expression ~expr ~loc = {expr; emeta= {loc}} let mk_typed_expression ~expr ~loc ~type_ ~ad_level = {expr; emeta= {loc; type_; ad_level}} +let mk_fun_app ~is_cond_dist (kind, id, arguments) = + if is_cond_dist then CondDistApp (kind, id, arguments) + else FunApp (kind, id, arguments) + let expr_loc_lub exprs = match List.map ~f:(fun e -> e.emeta.loc) exprs with | [] -> Location_span.empty diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index d8db6c1e3d..a68157bdae 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -21,14 +21,6 @@ let drop_leading_zeros s = let format_number s = s |> without_underscores |> drop_leading_zeros -let%expect_test "format_number0" = - format_number "0_000." |> print_endline ; - [%expect "0."] - -let%expect_test "format_number1" = - format_number ".123_456" |> print_endline ; - [%expect ".123456"] - let rec op_to_funapp op args type_ = let loc = Ast.expr_loc_lub args in let adlevel = Ast.expr_ad_lub args in @@ -107,8 +99,8 @@ let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t = | None -> ( Ast.StanLib FnPlain , Set.to_list possible_names |> List.hd_exn - , if Stan_math_signatures.is_stan_math_function_name (id.name ^ "_lpmf") - then UnsizedType.UInt + , if Library.is_stdlib_function_name (id.name ^ "_lpmf") then + UnsizedType.UInt else UnsizedType.UReal (* close enough *) ) in let trunc cond_op (x : Ast.typed_expression) y = let smeta = x.Ast.emeta.loc in @@ -431,7 +423,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) = | Ast.IncrementLogProb e | Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap | Ast.Tilde {arg; distribution; args; truncation} -> let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name in + Std_library_utils.dist_name_suffix Library.is_stdlib_function_name ud_dists distribution.name in let name = distribution.name ^ suffix in let kind = let possible_names = diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index d4bec27795..f079db7aea 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -1,6 +1,6 @@ open Core_kernel open Ast -open Deprecation_analysis +module Deprecation = Deprecation_analysis type canonicalizer_settings = {deprecations: bool; parentheses: bool; braces: bool; inline_includes: bool} @@ -20,7 +20,8 @@ let rec repair_syntax_stmt user_dists {stmt; smeta} = { stmt= Tilde { arg - ; distribution= {name= without_suffix user_dists name; id_loc} + ; distribution= + {name= Deprecation.without_suffix user_dists name; id_loc} ; args ; truncation } ; smeta } @@ -46,10 +47,10 @@ let rec replace_deprecated_expr (replace_deprecated_expr deprecated_userdefined {expr= TernaryIf ({expr= Paren c; emeta= c.emeta}, t, e); emeta} ) | FunApp (StanLib suffix, {name; id_loc}, e) -> - if is_deprecated_distribution name then + if Deprecation.is_deprecated_distribution name then CondDistApp ( StanLib suffix - , {name= rename_deprecated deprecated_distributions name; id_loc} + , {name= Deprecation.rename_deprecated_distribution name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) else if String.is_suffix name ~suffix:"_cdf" then CondDistApp @@ -59,14 +60,14 @@ let rec replace_deprecated_expr else FunApp ( StanLib suffix - , {name= rename_deprecated deprecated_functions name; id_loc} + , {name= Deprecation.rename_deprecated_function name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( match String.Map.find deprecated_userdefined name with | Some type_ -> CondDistApp ( UserDefined suffix - , {name= update_suffix name type_; id_loc} + , {name= Deprecation.update_suffix name type_; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) | None -> if String.is_suffix name ~suffix:"_cdf" then @@ -125,7 +126,7 @@ let rec replace_deprecated_stmt | FunDef {returntype; funname= {name; id_loc}; arguments; body} -> let newname = match String.Map.find deprecated_userdefined name with - | Some type_ -> update_suffix name type_ + | Some type_ -> Deprecation.update_suffix name type_ | None -> name in FunDef { returntype @@ -226,7 +227,8 @@ let repair_syntax program settings = if settings.deprecations then program |> map_program - (repair_syntax_stmt (userdef_distributions program.functionblock)) + (repair_syntax_stmt + (Deprecation.userdef_distributions program.functionblock) ) else program let canonicalize_program program settings : typed_program = @@ -234,7 +236,8 @@ let canonicalize_program program settings : typed_program = if settings.deprecations then program |> map_program - (replace_deprecated_stmt (collect_userdef_distributions program)) + (replace_deprecated_stmt + (Deprecation.collect_userdef_distributions program) ) else program in let program = if settings.parentheses then program |> map_program parens_stmt else program diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index a1f5db637f..e4a23df9e0 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -2,47 +2,29 @@ open Core_kernel open Ast open Middle -let deprecated_functions = - String.Map.of_alist_exn - [ ("multiply_log", ("lmultiply", "2.32.0")) - ; ("binomial_coefficient_log", ("lchoose", "2.32.0")) - ; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ] - -let deprecated_odes = - String.Map.of_alist_exn - [ ("integrate_ode", ("ode_rk45", "3.0")) - ; ("integrate_ode_rk45", ("ode_rk45", "3.0")) - ; ("integrate_ode_bdf", ("ode_bdf", "3.0")) - ; ("integrate_ode_adams", ("ode_adams", "3.0")) ] - -let deprecated_distributions = - String.Map.of_alist_exn - (List.map - ~f:(fun (x, y) -> (x, (y, "2.32.0"))) - (List.concat_map Middle.Stan_math_signatures.distributions - ~f:(fun (fnkinds, name, _, _) -> - List.filter_map fnkinds ~f:(function - | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") - | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") - | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") - | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") - | Rng | UnaryVectorized -> None ) ) ) ) - let stan_lib_deprecations = - Map.merge_skewed deprecated_distributions deprecated_functions + Map.merge_skewed Library.deprecated_distributions Library.deprecated_functions ~combine:(fun ~key x y -> Common.FatalError.fatal_error_msg [%message "Common key in deprecation map" (key : string) - (x : string * string) - (y : string * string)] ) + (x : Std_library_utils.deprecation_info) + (y : Std_library_utils.deprecation_info)] ) let is_deprecated_distribution name = - Option.is_some (Map.find deprecated_distributions name) + Map.mem Library.deprecated_distributions name let rename_deprecated map name = - Map.find map name |> Option.map ~f:fst |> Option.value ~default:name + Map.find map name + |> Option.map ~f:(fun Std_library_utils.{replacement; canonicalize_away; _} -> + if canonicalize_away then replacement else name ) + |> Option.value ~default:name + +let rename_deprecated_distribution = + rename_deprecated Library.deprecated_distributions + +let rename_deprecated_function = rename_deprecated Library.deprecated_functions let distribution_suffix name = let open String in @@ -119,12 +101,10 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list) | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> let w = match Map.find stan_lib_deprecations name with - | Some (rename, version) -> + | Some {replacement; version; extra_message; _} -> [ ( emeta.loc , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. This can be automatically changed using the \ - canonicalize flag for stanc" ) ] + ^ ". Use " ^ replacement ^ " instead. " ^ extra_message ) ] | _ when String.is_suffix name ~suffix:"_cdf" -> [ ( emeta.loc , "Use of " ^ name @@ -132,17 +112,7 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list) of a CDF is deprecated and will be removed in Stan 2.32.0. \ This can be automatically changed using the canonicalize \ flag for stanc" ) ] - | _ -> ( - match Map.find deprecated_odes name with - | Some (rename, version) -> - [ ( emeta.loc - , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. \n\ - The new interface is slightly different, see: \ - https://mc-stan.org/users/documentation/case-studies/convert_odes.html" - ) ] - | _ -> [] ) in + | _ -> [] in acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) | _ -> fold_expression collect_deprecated_expr (fun l _ -> l) acc expr diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index 205e400afd..4a4daf5cf0 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -16,8 +16,7 @@ val collect_userdef_distributions : val distribution_suffix : string -> bool val without_suffix : string list -> string -> string val is_deprecated_distribution : string -> bool -val deprecated_distributions : (string * string) String.Map.t -val deprecated_functions : (string * string) String.Map.t -val rename_deprecated : (string * string) String.Map.t -> string -> string +val rename_deprecated_distribution : string -> string +val rename_deprecated_function : string -> string val userdef_distributions : untyped_statement block option -> string list val collect_warnings : typed_program -> Warnings.t list diff --git a/src/frontend/Environment.ml b/src/frontend/Environment.ml index edcdc89c83..0d41e0d460 100644 --- a/src/frontend/Environment.ml +++ b/src/frontend/Environment.ml @@ -27,9 +27,9 @@ type info = type t = info list String.Map.t -let stan_math_environment = +let make_from_library signatures : t = let functions = - Hashtbl.to_alist Stan_math_signatures.stan_math_signatures + Hashtbl.to_alist signatures |> List.map ~f:(fun (key, values) -> ( key , List.map values ~f:(fun (rt, args, mem) -> diff --git a/src/frontend/Environment.mli b/src/frontend/Environment.mli index a0436c3610..0464b2c1de 100644 --- a/src/frontend/Environment.mli +++ b/src/frontend/Environment.mli @@ -29,8 +29,16 @@ type info = type t -val stan_math_environment : t -(** A type environment which contains the Stan math library functions +val make_from_library : + ( string + , ( UnsizedType.returntype + * (UnsizedType.autodifftype * UnsizedType.t) list + * Common.Helpers.mem_pattern ) + list ) + Core_kernel.Hashtbl.t + -> t +(** Make a type environment from a hashtable of functions like those from + [Std_library_utils] *) val find : t -> string -> info list diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index 5ded69f550..78802bbf84 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -45,6 +45,14 @@ let rec get_function_calls_expr (funs, distrs) expr = | _ -> (funs, distrs) in fold_expression get_function_calls_expr (fun acc _ -> acc) acc expr.expr +let includes_json () = + `Assoc + [ ( "included_files" + , `List + ( List.rev !Preprocessor.included_files + |> List.map ~f:(fun str -> `String str) ) ) ] + + let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = let acc = match stmt.stmt with @@ -57,8 +65,7 @@ let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = (funs, distrs) else let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name - in + Std_library_utils.dist_name_suffix Library.is_stdlib_function_name ud_dists distribution.name in let name = distribution.name ^ Utils.unnormalized_suffix suffix in (funs, Set.add distrs name) | _ -> (funs, distrs) in @@ -87,13 +94,6 @@ let function_calls_json p = `List (Set.to_list s |> List.map ~f:(fun str -> `String str)) in `Assoc [("functions", set_to_List funs); ("distributions", set_to_List distrs)] -let includes_json () = - `Assoc - [ ( "included_files" - , `List - ( List.rev !Preprocessor.included_files - |> List.map ~f:(fun str -> `String str) ) ) ] - let info_json ast = List.fold ~f:Util.combine ~init:(`Assoc []) [ block_info_json "inputs" ast.datablock diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index a21aa8f717..ca0447be43 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -10,7 +10,7 @@ - [type]: the base type of the variable (["int"] or ["real"]). - [dimensions]: the number of dimensions ([0] for a scalar, [1] for a vector or row vector, etc.). - + The JSON object also have the fields [stanlib_calls] and [distributions] containing the name of the standard library functions called and distributions used. diff --git a/src/frontend/Library.mli b/src/frontend/Library.mli new file mode 100644 index 0000000000..5f1c399ea3 --- /dev/null +++ b/src/frontend/Library.mli @@ -0,0 +1,40 @@ +(** This module is used as a parameter for many functors which + rely on information about a backend-specific Stan library. *) + +open Middle +open Core_kernel +open Std_library_utils + +val function_signatures : (string, signature list) Hashtbl.t +(** Mapping from names to signature(s) of functions *) + +val distribution_families : string list + +val is_stdlib_function_name : string -> bool +(** Equivalent to [Hashtbl.mem function_signatures s]*) + +val get_signatures : string -> signature list +(** Equivalent to [Hashtbl.find_multi function_signatures s]*) + +val get_operator_signatures : Operator.t -> signature list +val get_assignment_operator_signatures : Operator.t -> signature list +val is_not_overloadable : string -> bool +val is_variadic_function_name : string -> bool +val variadic_function_returntype : string -> UnsizedType.returntype option + +val check_variadic_fn : + Ast.identifier + -> is_cond_dist:bool + -> Location_span.t + -> Environment.originblock + -> Environment.t + -> Ast.typed_expression list + -> Ast.typed_expression +(** This function is responsible for typechecking varadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) + +val operator_to_function_names : Operator.t -> string list +val string_operator_to_function_name : string -> string +val deprecated_distributions : deprecation_info String.Map.t +val deprecated_functions : deprecation_info String.Map.t diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 79594d24fb..e56f5fdb19 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -14,19 +14,13 @@ module TypeError = struct | IntIntArrayOrRangeExpected of UnsizedType.t | IntOrRealContainerExpected of UnsizedType.t | ArrayVectorRowVectorMatrixExpected of UnsizedType.t - | IllTypedAssignment of Operator.t * UnsizedType.t * UnsizedType.t + | IllTypedAssignment of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list | IllTypedTernaryIf of UnsizedType.t * UnsizedType.t * UnsizedType.t - | IllTypedReduceSum of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedReduceSumGeneric of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedVariadicDE of + | IllTypedVariadicFn of string * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list @@ -50,9 +44,15 @@ module TypeError = struct string * UnsizedType.t list * (SignatureMismatch.signature_error list * bool) - | IllTypedBinaryOperator of Operator.t * UnsizedType.t * UnsizedType.t - | IllTypedPrefixOperator of Operator.t * UnsizedType.t - | IllTypedPostfixOperator of Operator.t * UnsizedType.t + | IllTypedBinaryOperator of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list + | IllTypedPrefixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list + | IllTypedPostfixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list | NotIndexable of UnsizedType.t * int let pp ppf = function @@ -101,18 +101,18 @@ module TypeError = struct "Foreach-loop must be over array, vector, row_vector or matrix. \ Instead found expression of type %a." UnsizedType.pp ut - | IllTypedAssignment (Operator.Equals, lt, rt) -> + | IllTypedAssignment (Operator.Equals, lt, rt, _) -> Fmt.pf ppf "Ill-typed arguments supplied to assignment operator =: lhs has type \ %a and rhs has type %a" UnsizedType.pp lt UnsizedType.pp rt - | IllTypedAssignment (op, lt, rt) -> + | IllTypedAssignment (op, lt, rt, sigs) -> Fmt.pf ppf "@[Ill-typed arguments supplied to assignment operator %a=: lhs \ has type %a and rhs has type %a.@ Available signatures for given \ lhs:@]@ %a" Operator.pp op UnsizedType.pp lt UnsizedType.pp rt - SignatureMismatch.pp_math_lib_assignmentoperator_sigs (lt, op) + SignatureMismatch.pp_assignmentoperator_sigs (lt, sigs) | IllTypedTernaryIf (UInt, ut, _) when UnsizedType.is_fun_type ut -> Fmt.pf ppf "Ternary expression cannot have a function type: %a" UnsizedType.pp ut @@ -125,13 +125,7 @@ module TypeError = struct Fmt.pf ppf "Condition in ternary expression must be primitive int; found type=%a" UnsizedType.pp ut1 - | IllTypedReduceSum (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedVariadicDE (name, arg_tys, args, error, return_type) -> + | IllTypedVariadicFn (name, arg_tys, args, error, return_type) -> SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys @@ -224,32 +218,25 @@ module TypeError = struct prefix suffix prefix prefix newsuffix | IllTypedFunctionApp (name, arg_tys, errors) -> SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, errors) - | IllTypedBinaryOperator (op, lt, rt) -> + | IllTypedBinaryOperator (op, lt, rt, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to infix operator %a. Available \ - signatures: %s@[Instead supplied arguments of incompatible type: \ - %a, %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp lt UnsizedType.pp rt - | IllTypedPrefixOperator (op, ut) -> + signatures: @[%a@.@]@[Instead supplied arguments of \ + incompatible type: %a, %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp lt + UnsizedType.pp rt + | IllTypedPrefixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to prefix operator %a. Available \ - signatures: %s@[Instead supplied argument of incompatible type: \ - %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut - | IllTypedPostfixOperator (op, ut) -> + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut + | IllTypedPostfixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to postfix operator %a. Available \ - signatures: %s\n\ - Instead supplied argument of incompatible type: %a." Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut end module IdentifierError = struct @@ -531,8 +518,8 @@ let int_or_real_container_expected loc ut = let array_vector_rowvector_matrix_expected loc ut = TypeError (loc, TypeError.ArrayVectorRowVectorMatrixExpected ut) -let illtyped_assignment loc assignop lt rt = - TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt)) +let illtyped_assignment loc assignop lt rt sigs = + TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt, sigs)) let illtyped_ternary_if loc predt lt rt = TypeError (loc, TypeError.IllTypedTernaryIf (predt, lt, rt)) @@ -540,34 +527,9 @@ let illtyped_ternary_if loc predt lt rt = let returning_fn_expected_nonreturning_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonReturningFound name) -let illtyped_reduce_sum loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedReduceSum (name, arg_tys, args, error)) - -let illtyped_reduce_sum_generic loc name arg_tys expected_args error = +let illtyped_variadic_fn loc name arg_tys args error return_type = TypeError - ( loc - , TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) - ) - -let illtyped_variadic_ode loc name arg_tys args error = - TypeError - ( loc - , TypeError.IllTypedVariadicDE - ( name - , arg_tys - , args - , error - , Stan_math_signatures.variadic_ode_fun_return_type ) ) - -let illtyped_variadic_dae loc name arg_tys args error = - TypeError - ( loc - , TypeError.IllTypedVariadicDE - ( name - , arg_tys - , args - , error - , Stan_math_signatures.variadic_dae_fun_return_type ) ) + (loc, TypeError.IllTypedVariadicFn (name, arg_tys, args, error, return_type)) let ambiguous_function_promotion loc name arg_tys signatures = TypeError @@ -601,14 +563,14 @@ let nonreturning_fn_expected_undeclaredident_found loc name sug = let illtyped_fn_app loc name errors arg_tys = TypeError (loc, TypeError.IllTypedFunctionApp (name, arg_tys, errors)) -let illtyped_binary_op loc op lt rt = - TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt)) +let illtyped_binary_op loc op lt rt sigs = + TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt, sigs)) -let illtyped_prefix_op loc op ut = - TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut)) +let illtyped_prefix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut, sigs)) -let illtyped_postfix_op loc op ut = - TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut)) +let illtyped_postfix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut, sigs)) let not_indexable loc ut nidcs = TypeError (loc, TypeError.NotIndexable (ut, nidcs)) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index 525a320694..18787d79ee 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -25,7 +25,12 @@ val array_vector_rowvector_matrix_expected : Location_span.t -> UnsizedType.t -> t val illtyped_assignment : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t val illtyped_ternary_if : Location_span.t -> UnsizedType.t -> UnsizedType.t -> UnsizedType.t -> t @@ -42,28 +47,13 @@ val returning_fn_expected_undeclared_dist_suffix_found : val returning_fn_expected_wrong_dist_suffix_found : Location_span.t -> string * string -> t -val illtyped_reduce_sum : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - -val illtyped_reduce_sum_generic : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - -val illtyped_variadic_ode : +val illtyped_variadic_fn : Location_span.t -> string -> UnsizedType.t list -> (UnsizedType.autodifftype * UnsizedType.t) list -> SignatureMismatch.function_mismatch + -> UnsizedType.t -> t val ambiguous_function_promotion : @@ -74,14 +64,6 @@ val ambiguous_function_promotion : list -> t -val illtyped_variadic_dae : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t @@ -96,10 +78,27 @@ val illtyped_fn_app : -> t val illtyped_binary_op : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_prefix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_postfix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t -val illtyped_prefix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t -val illtyped_postfix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t val not_indexable : Location_span.t -> UnsizedType.t -> int -> t val ident_is_keyword : Location_span.t -> string -> t val ident_is_model_name : Location_span.t -> string -> t diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index 7454d627fe..5d4d4e08bf 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -245,9 +245,6 @@ let matching_function env name args = UnsizedType.compare_returntype ret1 ret2 ) in find_compatible_rt function_types args -let matching_stanlib_function = - matching_function Environment.stan_math_environment - let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys fun_return args = let minimal_func_type = @@ -286,6 +283,20 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys | (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err | [] -> Error ([], ArgNumMismatch (List.length mandatory_arg_tys, 0)) +let find_matching_first_order_fn tenv matches (fname : Ast.identifier) = + let candidates = + Utils.stdlib_distribution_name fname.name + |> Environment.find tenv |> List.map ~f:matches in + let ok, errs = List.partition_map candidates ~f:Result.to_either in + match unique_minimum_promotion ok with + | Ok a -> UniqueMatch a + | Error (Some promotions) -> + List.filter_map promotions ~f:(function + | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) + | _ -> None ) + |> AmbiguousMatch + | Error None -> SignatureErrors (List.hd_exn errs) + let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = let open Fmt in let ctx = ref TypeMap.empty in @@ -383,10 +394,8 @@ let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = (list ~sep:cut pp_signature) sigs pp_omitted () -let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = +let pp_assignmentoperator_sigs ppf (lt, errors) = let signatures = - let errors = - Stan_math_signatures.make_assignmentoperator_stan_math_signatures op in let errors = List.filter ~f:(fun (_, args, _) -> @@ -398,7 +407,7 @@ let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = | errors, _ -> Some (errors, true) in let pp_sigs ppf (signatures, omitted) = Fmt.pf ppf "@[%a%a@]" - (Fmt.list ~sep:Fmt.cut Stan_math_signatures.pp_math_sig) + (Fmt.list ~sep:Fmt.cut Std_library_utils.pp_math_sig) signatures (if omitted then Fmt.pf else Fmt.nop) "@ (Additional signatures omitted)" in diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index 66a951903c..6927e21716 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -55,12 +55,6 @@ val matching_function : Requires a unique minimum option under type promotion *) -val matching_stanlib_function : - string -> (UnsizedType.autodifftype * UnsizedType.t) list -> match_result -(** Same as [matching_function] but requires specifically that the function - be from StanMath (uses [Environment.stan_math_environment]) -*) - val check_variadic_args : bool -> (UnsizedType.autodifftype * UnsizedType.t) list @@ -75,6 +69,15 @@ val check_variadic_args : If none is found, returns [Error] of the list of args and a function_mismatch. *) +val find_matching_first_order_fn : + Environment.t + -> (Environment.info -> (UnsizedType.t * Promotion.t list, 'a) result) + -> Ast.identifier + -> (UnsizedType.t * Promotion.t list, 'a) generic_match_result +(** Given a constraint function [matches], find any signature which exists + Returns the first [Ok] if any exist, or else [Error] +*) + val pp_signature_mismatch : Format.formatter -> string @@ -86,8 +89,8 @@ val pp_signature_mismatch : * bool ) -> unit -val pp_math_lib_assignmentoperator_sigs : - Format.formatter -> UnsizedType.t * Operator.t -> unit +val pp_assignmentoperator_sigs : + Format.formatter -> UnsizedType.t * Std_library_utils.signature list -> unit val compare_errors : function_mismatch -> function_mismatch -> int val compare_match_results : match_result -> match_result -> int diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml new file mode 100644 index 0000000000..f4ea514d3c --- /dev/null +++ b/src/frontend/Std_library_utils.ml @@ -0,0 +1,31 @@ +open Middle +open Core_kernel + +(* Types for the module representing the standard library *) +type fun_arg = UnsizedType.autodifftype * UnsizedType.t + +type signature = + UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern + +type deprecation_info = + { replacement: string + ; version: string + ; extra_message: string + ; canonicalize_away: bool } +[@@deriving sexp] + +let pp_math_sig ppf ((rt, args, mem_pattern) : signature) = + UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) + +let pp_math_sigs ppf (sigs : signature list) = + (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs + +let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs + +let dist_name_suffix (check : string -> bool) udf_names name = + let is_udf_name s = + List.exists ~f:(fun (n, _) -> String.equal s n) udf_names in + Utils.distribution_suffices + |> List.filter ~f:(fun sfx -> + check (name ^ sfx) || is_udf_name (name ^ sfx) ) + |> List.hd_exn diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index ab40113c23..efb692047f 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -56,10 +56,10 @@ let context block = ; in_udf_dist_def= false ; loop_depth= 0 } -let calculate_autodifftype cf origin ut = +let calculate_autodifftype current_block origin ut = match origin with | Env.(Param | TParam | Model | Functions) - when not (UnsizedType.is_int_type ut || cf.current_block = GQuant) -> + when not (UnsizedType.is_int_type ut || current_block = Env.GQuant) -> UnsizedType.AutoDiffable | _ -> DataOnly @@ -83,6 +83,8 @@ let reserved_keywords = ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" ; "static"; "auto" ] +let std_library_tenv : Env.t = Env.make_from_library Library.function_signatures + let verify_identifier id : unit = if id.name = !model_name then Semantic_error.ident_is_model_name id.id_loc id.name |> error @@ -121,9 +123,7 @@ let verify_name_fresh_udf loc tenv name = if (* variadic functions are currently not in math sigs and aren't overloadable due to their separate typechecking *) - Stan_math_signatures.is_reduce_sum_fn name - || Stan_math_signatures.is_variadic_ode_fn name - || Stan_math_signatures.is_variadic_dae_fn name + Library.is_not_overloadable name then Semantic_error.ident_is_stanmath_name loc name |> error else if Utils.is_unnormalized_distribution name then Semantic_error.udf_is_unnormalized_fn loc name |> error @@ -190,39 +190,35 @@ let match_to_rt_option = function | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt | _ -> None -let stan_math_return_type name arg_tys = +let library_function_return_type name arg_tys = match name with - | x when Stan_math_signatures.is_reduce_sum_fn x -> - Some (UnsizedType.ReturnType UReal) - | x when Stan_math_signatures.is_variadic_ode_fn x -> - Some (UnsizedType.ReturnType (UArray UVector)) - | x when Stan_math_signatures.is_variadic_dae_fn x -> - Some (UnsizedType.ReturnType (UArray UVector)) + | x when Library.is_variadic_function_name x -> + Library.variadic_function_returntype x | _ -> - SignatureMismatch.matching_stanlib_function name arg_tys + SignatureMismatch.matching_function std_library_tenv name arg_tys |> match_to_rt_option -let operator_stan_math_return_type op arg_tys = +let operator_return_type op arg_tys = match (op, arg_tys) with | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion]) | IntDivide, _ -> None | _ -> - Stan_math_signatures.operator_to_stan_math_fns op + Library.operator_to_function_names op |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_stanlib_function name arg_tys + SignatureMismatch.matching_function std_library_tenv name arg_tys |> function | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) | _ -> None ) |> List.hd -let assignmentoperator_stan_math_return_type assop arg_tys = +let assignmentoperator_return_type assop arg_tys = ( match assop with | Operator.Divide -> - SignatureMismatch.matching_stanlib_function "divide" arg_tys + SignatureMismatch.matching_function std_library_tenv "divide" arg_tys |> match_to_rt_option | Plus | Minus | Times | EltTimes | EltDivide -> - operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst + operator_return_type assop arg_tys |> Option.map ~f:fst | _ -> None ) |> Option.bind ~f:(function | ReturnType rtype @@ -234,7 +230,7 @@ let assignmentoperator_stan_math_return_type assop arg_tys = | _ -> None ) let check_binop loc op le re = - let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in + let rt = [le; re] |> get_arg_types |> operator_return_type op in match rt with | Some (ReturnType type_, [p1; p2]) -> mk_typed_expression @@ -243,27 +239,34 @@ let check_binop loc op le re = ~type_ ~loc | _ -> Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ + (Library.get_operator_signatures op) |> error let check_prefixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in + let rt = operator_return_type op [arg_type te] in match rt with | Some (ReturnType type_, _) -> mk_typed_expression ~expr:(PrefixOp (op, te)) ~ad_level:(expr_ad_lub [te]) ~type_ ~loc - | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error + | _ -> + Semantic_error.illtyped_prefix_op loc op te.emeta.type_ + (Library.get_operator_signatures op) + |> error let check_postfixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in + let rt = operator_return_type op [arg_type te] in match rt with | Some (ReturnType type_, _) -> mk_typed_expression ~expr:(PostfixOp (te, op)) ~ad_level:(expr_ad_lub [te]) ~type_ ~loc - | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error + | _ -> + Semantic_error.illtyped_postfix_op loc op te.emeta.type_ + (Library.get_operator_signatures op) + |> error let check_id cf loc tenv id = match Env.find tenv (Utils.stdlib_distribution_name id.name) with @@ -272,7 +275,7 @@ let check_id cf loc tenv id = (Env.nearest_ident tenv id.name) |> error | {kind= `StanMath; _} :: _ -> - ( calculate_autodifftype cf MathLibrary UMathLibraryFunction + ( calculate_autodifftype cf.current_block MathLibrary UMathLibraryFunction , UnsizedType.UMathLibraryFunction ) | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ when cf.in_toplevel_decl -> @@ -284,16 +287,16 @@ let check_id cf loc tenv id = || cf.current_block = Model ) -> Semantic_error.invalid_unnormalized_fn loc |> error | {kind= `Variable {origin; _}; type_} :: _ -> - (calculate_autodifftype cf origin type_, type_) + (calculate_autodifftype cf.current_block origin type_, type_) | { kind= `UserDefined | `UserDeclared _ ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } :: _ -> let type_ = UnsizedType.UFun (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - (calculate_autodifftype cf Functions type_, type_) + (calculate_autodifftype cf.current_block Functions type_, type_) | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> - (calculate_autodifftype cf Functions type_, type_) + (calculate_autodifftype cf.current_block Functions type_, type_) let check_variable cf loc tenv id = let ad_level, type_ = check_id cf loc tenv id in @@ -411,7 +414,6 @@ let inferred_ad_type_of_indexed at uindices = uindices ) (* function checking *) - let verify_conddist_name loc id = if List.exists @@ -460,24 +462,17 @@ let verify_unnormalized cf loc id = && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) then Semantic_error.invalid_unnormalized_fn loc |> error -let mk_fun_app ~is_cond_dist (x, y, z) = - if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) - let check_normal_fn ~is_cond_dist loc tenv id es = match Env.find tenv (Utils.normalized_name id.name) with | {kind= `Variable _; _} :: _ (* variables can sometimes shadow stanlib functions, so we have to check this *) - when not - (Stan_math_signatures.is_stan_math_function_name - (Utils.normalized_name id.name) ) -> + when not (Library.is_stdlib_function_name (Utils.normalized_name id.name)) + -> Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error | [] -> ( match Utils.split_distribution_suffix id.name with | Some (prefix, suffix) -> ( - let known_families = - List.map - ~f:(fun (_, y, _, _) -> y) - Stan_math_signatures.distributions in + let known_families = Library.distribution_families in let is_known_family s = List.mem known_families s ~equal:String.equal in match suffix with @@ -534,201 +529,12 @@ let check_normal_fn ~is_cond_dist loc tenv id es = |> Semantic_error.illtyped_fn_app loc id.name (l, b) |> error ) -(** Given a constraint function [matches], find any signature which exists - Returns the first [Ok] if any exist, or else [Error] -*) -let find_matching_first_order_fn tenv matches fname = - let candidates = - Utils.stdlib_distribution_name fname.name - |> Env.find tenv |> List.map ~f:matches in - let ok, errs = List.partition_map candidates ~f:Result.to_either in - match SignatureMismatch.unique_minimum_promotion ok with - | Ok a -> SignatureMismatch.UniqueMatch a - | Error (Some promotions) -> - List.filter_map promotions ~f:(function - | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) - | _ -> None ) - |> AmbiguousMatch - | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs) - -let make_function_variable cf loc id = function - | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | UnsizedType.UFun _ as type_ -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | type_ -> - Common.FatalError.fatal_error_msg - [%message - "Attempting to create function variable out of " - (type_ : UnsizedType.t)] - let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) = - if Stan_math_signatures.is_reduce_sum_fn id.name then - check_reduce_sum ~is_cond_dist loc cf tenv id tes - else if Stan_math_signatures.is_variadic_ode_fn id.name then - check_variadic_ode ~is_cond_dist loc cf tenv id tes - else if Stan_math_signatures.is_variadic_dae_fn id.name then - check_variadic_dae ~is_cond_dist loc cf tenv id tes + if Library.is_variadic_function_name id.name then + Library.check_variadic_fn id ~is_cond_dist loc cf.current_block tenv tes else check_normal_fn ~is_cond_dist loc tenv id tes -and check_reduce_sum ~is_cond_dist loc cf tenv id tes = - let basic_mismatch () = - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in - SignatureMismatch.check_variadic_args true mandatory_args mandatory_fun_args - UReal (get_arg_types tes) in - let fail () = - let expected_args, err = - basic_mismatch () |> Result.error |> Option.value_exn in - Semantic_error.illtyped_reduce_sum_generic loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es fn = - match fn with - | Env. - { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as - ftype - ; _ } - when List.mem Stan_math_signatures.reduce_sum_slice_types - sliced_arg_fun_type ~equal:( = ) -> - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal arg_types - | _ -> basic_mismatch () in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - (* a valid signature exists *) - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - -and check_variadic_ode ~is_cond_dist loc cf tenv id tes = - let optional_tol_mandatory_args = - if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then - Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types - else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then - Stan_math_signatures.variadic_ode_tol_arg_types - else [] in - let mandatory_arg_types = - Stan_math_signatures.variadic_ode_mandatory_arg_types - @ optional_tol_mandatory_args in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_ode_mandatory_fun_args - Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_ode_mandatory_fun_args - Stan_math_signatures.variadic_ode_fun_return_type arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) - ~type_:Stan_math_signatures.variadic_ode_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - -and check_variadic_dae ~is_cond_dist loc cf tenv id tes = - let optional_tol_mandatory_args = - if Stan_math_signatures.is_variadic_dae_tol_fn id.name then - Stan_math_signatures.variadic_dae_tol_arg_types - else [] in - let mandatory_arg_types = - Stan_math_signatures.variadic_dae_mandatory_arg_types - @ optional_tol_mandatory_args in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_dae_mandatory_fun_args - Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_dae loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_dae_mandatory_fun_args - Stan_math_signatures.variadic_dae_fun_return_type arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) - ~type_:Stan_math_signatures.variadic_dae_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_dae loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) = let name_check = if is_cond_dist then verify_conddist_name else verify_fn_conditioning in @@ -844,7 +650,8 @@ and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error else mk_typed_expression ~expr:GetLP - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) ~type_:UReal ~loc | GetTarget -> (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) @@ -856,7 +663,8 @@ and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error else mk_typed_expression ~expr:GetTarget - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) ~type_:UReal ~loc | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc | RowVectorExpr es -> es |> List.map ~f:ce |> check_rowvector loc @@ -907,7 +715,7 @@ let check_nrfn loc tenv id es = match Env.find tenv id.name with | {kind= `Variable _; _} :: _ (* variables can shadow stanlib functions, so we have to check this *) - when not (Stan_math_signatures.is_stan_math_function_name id.name) -> + when not (Library.is_stdlib_function_name id.name) -> Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error | [] -> Semantic_error.nonreturning_fn_expected_undeclaredident_found loc id.name @@ -967,9 +775,9 @@ let verify_assignment_non_function loc ut id = | _ -> () let check_assignment_operator loc assop lhs rhs = - let err op = + let err op sigs = Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ - in + sigs in match assop with | Assign | ArrowAssign -> ( match @@ -977,11 +785,13 @@ let check_assignment_operator loc assop lhs rhs = rhs.emeta.type_ with | Ok p -> Promotion.promote rhs p - | Error _ -> err Operator.Equals |> error ) + | Error _ -> err Operator.Equals [] |> error ) | OperatorAssign op -> ( let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in - let return_type = assignmentoperator_stan_math_return_type op args in - match return_type with Some Void -> rhs | _ -> err op |> error ) + let return_type = assignmentoperator_return_type op args in + match return_type with + | Some Void -> rhs + | _ -> err op (Library.get_assignment_operator_signatures op) |> error ) let check_lvalue cf tenv = function | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> @@ -1454,7 +1264,7 @@ and check_var_decl_initial_value loc cf tenv id init_val_opt = | Ok p -> Some (Promotion.promote rhs p) | Error _ -> Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ - rhs.emeta.type_ + rhs.emeta.type_ [] |> error ) | None -> None @@ -1749,7 +1559,7 @@ let check_program_exn ; comments } as ast ) = warnings := [] ; (* create a new type environment which has only stan-math functions *) - let tenv = Env.stan_math_environment in + let tenv = std_library_tenv in let tenv, typed_fb = check_toplevel_block Functions tenv fb in verify_functions_have_defn tenv typed_fb ; let tenv, typed_db = check_toplevel_block Data tenv db in diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli index 07662dfe11..55620d1e12 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecker.mli @@ -11,37 +11,50 @@ A type environment {!val:Environment.t} is used to hold variables and functions, including Stan math functions. This is a functional map, meaning it is handled immutably. + + This module is parameterized over a Standard Library of function signatures, See + [Std_library_utils.Library]. For the main compiler, this is + [Stan_math_backend.Stan_math_library] *) open Ast +val model_name : string ref +(** A reference to hold the model name. Relevant for checking variable + clashes and used in code generation. *) + +val check_that_all_functions_have_definition : bool ref +(** A switch to determine whether we check that all functions have a definition *) + +val get_arg_types : typed_expression list -> Std_library_utils.fun_arg list +val type_of_expr_typed : typed_expression -> Middle.UnsizedType.t + +val calculate_autodifftype : + Environment.originblock + -> Environment.originblock + -> Middle.UnsizedType.t + -> Middle.UnsizedType.autodifftype + val check_program_exn : untyped_program -> typed_program * Warnings.t list (** - Type check a full Stan program. - Can raise [Errors.SemanticError] -*) + Type check a full Stan program. + Can raise [Errors.SemanticError] + *) val check_program : untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result (** - The safe version of [check_program_exn]. This catches - all [Errors.SemanticError] exceptions and converts them - into a [Result.t] -*) + The safe version of [check_program_exn]. This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] + *) -val operator_stan_math_return_type : +val operator_return_type : Middle.Operator.t -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> (Middle.UnsizedType.returntype * Promotion.t list) option -val stan_math_return_type : +val library_function_return_type : string -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> Middle.UnsizedType.returntype option - -val model_name : string ref -(** A reference to hold the model name. Relevant for checking variable - clashes and used in code generation. *) - -val check_that_all_functions_have_definition : bool ref -(** A switch to determine whether we check that all functions have a definition *) diff --git a/src/frontend/dune b/src/frontend/dune index 0359a63c1d..ea86e21473 100644 --- a/src/frontend/dune +++ b/src/frontend/dune @@ -3,6 +3,7 @@ (public_name stanc.frontend) (libraries core_kernel re menhirLib fmt middle common yojson) (inline_tests) + (virtual_modules library) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map))) diff --git a/src/middle/Stan_math_signatures.mli b/src/middle/Stan_math_signatures.mli deleted file mode 100644 index 30b906c4d7..0000000000 --- a/src/middle/Stan_math_signatures.mli +++ /dev/null @@ -1,78 +0,0 @@ -(** This module stores a table of all signatures from the Stan - math C++ library which are exposed to Stan, and some helper - functions for dealing with those signatures. -*) - -open Core_kernel - -(** Function arguments are represented by their type an autodiff - type. This is [AutoDiffable] for everything except arguments - marked with the data keyword *) -type fun_arg = UnsizedType.autodifftype * UnsizedType.t - -(** Signatures consist of a return type, a list of arguments, and a flag - for whether or not those arguments can be Struct of Arrays objects *) -type signature = - UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern - -val stan_math_signatures : (string, signature list) Hashtbl.t -(** Mapping from names to signature(s) of functions *) - -val is_stan_math_function_name : string -> bool -(** Equivalent to [Hashtbl.mem stan_math_signatures s]*) - -(** Pretty printers *) - -val pp_math_sig : signature Fmt.t -val pretty_print_all_math_sigs : unit Fmt.t -val pretty_print_all_math_distributions : unit Fmt.t - -type dimensionality - -type fkind = Lpmf | Lpdf | Rng | Cdf | Ccdf | UnaryVectorized -[@@deriving show {with_path= false}] - -val distributions : - (fkind list * string * dimensionality list * Common.Helpers.mem_pattern) list -(** The distribution {e families} exposed by the math library *) - -val dist_name_suffix : (string * 'a) list -> string -> string - -(** Helpers for dealing with operators as signatures *) - -val operator_to_stan_math_fns : Operator.t -> string list -val string_operator_to_stan_math_fns : string -> string -val pretty_print_math_lib_operator_sigs : Operator.t -> string list -val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list - -(** Special functions for the variadic signatures exposed *) - -(* TODO: We should think of a better encapsulization for these, - this doesn't scale well. -*) - -(* reduce_sum helpers *) -val is_reduce_sum_fn : string -> bool -val reduce_sum_slice_types : UnsizedType.t list - -(* variadic ODE helpers *) -val is_variadic_ode_fn : string -> bool -val is_variadic_ode_nonadjoint_tol_fn : string -> bool -val ode_tolerances_suffix : string -val variadic_ode_adjoint_fn : string -val variadic_ode_mandatory_arg_types : fun_arg list -val variadic_ode_mandatory_fun_args : fun_arg list -val variadic_ode_tol_arg_types : fun_arg list -val variadic_ode_adjoint_ctl_tol_arg_types : fun_arg list -val variadic_ode_fun_return_type : UnsizedType.t -val variadic_ode_return_type : UnsizedType.t - -(* variadic DAE helpers *) -val is_variadic_dae_fn : string -> bool -val is_variadic_dae_tol_fn : string -> bool -val dae_tolerances_suffix : string -val variadic_dae_mandatory_arg_types : fun_arg list -val variadic_dae_mandatory_fun_args : fun_arg list -val variadic_dae_tol_arg_types : fun_arg list -val variadic_dae_fun_return_type : UnsizedType.t -val variadic_dae_return_type : UnsizedType.t diff --git a/src/middle/dune b/src/middle/dune index 65e20b7aba..220ee14c6e 100644 --- a/src/middle/dune +++ b/src/middle/dune @@ -4,9 +4,4 @@ (libraries core_kernel str fmt common re) (inline_tests) (preprocess - (pps - ppx_jane - ppx_deriving.map - ppx_deriving.fold - ppx_deriving.create - ppx_deriving.show))) + (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.create))) diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index 44fda9ddee..526c9c6c6b 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -139,10 +139,8 @@ let variadic_dae_functor_suffix = "_daefunctor__" let functor_suffix_select hof = match hof with | x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix - | x when Stan_math_signatures.is_variadic_ode_fn x -> - variadic_ode_functor_suffix - | x when Stan_math_signatures.is_variadic_dae_fn x -> - variadic_dae_functor_suffix + | x when Stan_math_signatures.is_variadic_ode_fn x -> variadic_ode_functor_suffix + | x when Stan_math_signatures.is_variadic_dae_fn x -> variadic_dae_functor_suffix | _ -> functor_suffix let constraint_to_string = function diff --git a/src/middle/Stan_math_signatures.ml b/src/stan_math_backend/Stan_math_signatures.ml similarity index 89% rename from src/middle/Stan_math_signatures.ml rename to src/stan_math_backend/Stan_math_signatures.ml index e7c929dab3..acd1af2c5f 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/stan_math_backend/Stan_math_signatures.ml @@ -1,7 +1,9 @@ (** The signatures of the Stan Math library, which are used for type checking *) -open Core_kernel +open Core_kernel open Core_kernel.Poly +open Middle +open Frontend.Std_library_utils (** The "dimensionality" (bad name?) is supposed to help us represent the vectorized nature of many Stan functions. It allows us to represent when @@ -51,18 +53,13 @@ let rec expand_arg = function type fkind = Lpmf | Lpdf | Rng | Cdf | Ccdf | UnaryVectorized [@@deriving show {with_path= false}] -type fun_arg = UnsizedType.autodifftype * UnsizedType.t - -type signature = - UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern - let is_primitive = function | UnsizedType.UReal -> true | UInt -> true | _ -> false (** The signatures hash table *) -let (stan_math_signatures : (string, signature list) Hashtbl.t) = +let (function_signatures : (string, signature list) Hashtbl.t) = String.Table.create () (** All of the signatures that are added by hand, rather than the ones @@ -239,6 +236,223 @@ let is_variadic_dae_fn f = Set.mem variadic_dae_fns f let is_variadic_dae_tol_fn f = is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix +let is_variadic_function_name name = + is_reduce_sum_fn name || is_variadic_dae_fn name || is_variadic_ode_fn name + +let variadic_function_returntype name = + if is_reduce_sum_fn name then Some (UnsizedType.ReturnType UReal) + else if is_variadic_ode_fn name then + Some (UnsizedType.ReturnType variadic_ode_return_type) + else if is_variadic_dae_fn name then + Some (UnsizedType.ReturnType variadic_dae_return_type) + else None + +let is_not_overloadable = is_variadic_function_name + +module Variadic_typechecking = struct + (** This module serves as the backend-specific portion + of the typechecker. *) + + open Frontend + open Typechecker + open Ast + + let error e = raise (Errors.SemanticError e) + + let make_function_variable current_block loc id = function + | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | UnsizedType.UFun _ as type_ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | type_ -> + Common.FatalError.fatal_error_msg + [%message + "Attempting to create function variable out of " + (type_ : UnsizedType.t)] + + let check_reduce_sum ~is_cond_dist loc current_block tenv id tes = + let basic_mismatch () = + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] + in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal (get_arg_types tes) in + let fail () = + let expected_args, err = + basic_mismatch () |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error in + let matching remaining_es fn = + match fn with + | Environment. + { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as + ftype + ; _ } + when List.mem reduce_sum_slice_types sliced_arg_fun_type ~equal:( = ) -> + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal arg_types + | _ -> basic_mismatch () in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + (* a valid signature exists *) + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error ) + | _ -> fail () + + let check_variadic_ode ~is_cond_dist loc current_block tenv id tes = + let optional_tol_mandatory_args = + if variadic_ode_adjoint_fn = id.name then + variadic_ode_adjoint_ctl_tol_arg_types + else if is_variadic_ode_nonadjoint_tol_fn id.name then + variadic_ode_tol_arg_types + else [] in + let mandatory_arg_types = + variadic_ode_mandatory_arg_types @ optional_tol_mandatory_args in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_ode_mandatory_fun_args variadic_ode_fun_return_type + (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_ode_fun_return_type + |> error in + let matching remaining_es Environment.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_ode_mandatory_fun_args variadic_ode_fun_return_type arg_types + in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:variadic_ode_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_ode_fun_return_type + |> error ) + | _ -> fail () + + let check_variadic_dae ~is_cond_dist loc current_block tenv id tes = + let optional_tol_mandatory_args = + if is_variadic_dae_tol_fn id.name then variadic_dae_tol_arg_types else [] + in + let mandatory_arg_types = + variadic_dae_mandatory_arg_types @ optional_tol_mandatory_args in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_dae_mandatory_fun_args variadic_dae_fun_return_type + (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_dae_fun_return_type + |> error in + let matching remaining_es Environment.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_dae_mandatory_fun_args variadic_dae_fun_return_type arg_types + in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:variadic_dae_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_dae_fun_return_type + |> error ) + | _ -> fail () +end + +let check_variadic_fn id ~is_cond_dist loc current_block tenv tes = + if is_reduce_sum_fn id.Frontend.Ast.name then + Variadic_typechecking.check_reduce_sum ~is_cond_dist loc current_block tenv + id tes + else if is_variadic_ode_fn id.name then + Variadic_typechecking.check_variadic_ode ~is_cond_dist loc current_block + tenv id tes + else if is_variadic_dae_fn id.name then + Variadic_typechecking.check_variadic_dae ~is_cond_dist loc current_block + tenv id tes + else + Common.FatalError.fatal_error_msg + [%message + "Invalid variadic function for Stan Math backend" (id.name : string)] + let distributions = [ ( full_lpmf , "beta_binomial" @@ -312,6 +526,9 @@ let distributions = ; ([Lpdf], "wiener", [DVReal; DVReal; DVReal; DVReal; DVReal], SoA) ; ([Lpdf], "wishart", [DMatrix; DReal; DMatrix], SoA) ] +let distribution_families = + List.map ~f:(fun (_, name, _, _) -> name) distributions + let math_sigs = [ ([UnaryVectorized], "acos", [DDeepVectorized], Common.Helpers.SoA) ; ([UnaryVectorized], "acosh", [DDeepVectorized], SoA) @@ -371,24 +588,11 @@ let all_declarative_sigs = distributions @ math_sigs let declarative_fnsigs = List.concat_map ~f:mk_declarative_sig all_declarative_sigs -let is_stan_math_function_name name = +let is_stdlib_function_name name = let name = Utils.stdlib_distribution_name name in - Hashtbl.mem stan_math_signatures name - -let dist_name_suffix udf_names name = - let is_udf_name s = List.exists ~f:(fun (n, _) -> n = s) udf_names in - match - Utils.distribution_suffices - |> List.filter ~f:(fun sfx -> - is_stan_math_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) - |> List.hd - with - | Some hd -> hd - | None -> - Common.FatalError.fatal_error_msg - [%message "Couldn't find distribution " name] - -let operator_to_stan_math_fns op = + Hashtbl.mem function_signatures name + +let operator_to_function_names op = match op with | Operator.Plus -> ["add"] | PPlus -> ["plus"] @@ -414,21 +618,15 @@ let operator_to_stan_math_fns op = | PNot -> ["logical_negation"] | Transpose -> ["transpose"] -let int_divide_type = - UnsizedType. - ( ReturnType UInt - , [(AutoDiffable, UInt); (AutoDiffable, UInt)] - , Common.Helpers.AoS ) - -let get_sigs name = +let get_signatures name = let name = Utils.stdlib_distribution_name name in - Hashtbl.find_multi stan_math_signatures name |> List.sort ~compare + Hashtbl.find_multi function_signatures name |> List.sort ~compare -let make_assignmentoperator_stan_math_signatures assop = +let get_assignment_operator_signatures assop = ( match assop with | Operator.Divide -> ["divide"] - | assop -> operator_to_stan_math_fns assop ) - |> List.concat_map ~f:get_sigs + | assop -> operator_to_function_names assop ) + |> List.concat_map ~f:get_signatures |> List.concat_map ~f:(function | ReturnType rtype, [(ad1, lhs); (ad2, rhs)], _ when rtype = lhs @@ -441,15 +639,7 @@ let make_assignmentoperator_stan_math_signatures assop = else [(Void, [(ad1, lhs); (ad2, rhs)], SoA)] | _ -> [] ) -let pp_math_sig ppf (rt, args, mem_pattern) = - UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) - -let pp_math_sigs ppf name = - (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf (get_sigs name) - -let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs - -let string_operator_to_stan_math_fns str = +let string_operator_to_function_name str = match str with | "Plus__" -> "add" | "PPlus__" -> "plus" @@ -484,10 +674,10 @@ let pretty_print_all_math_sigs ppf () = (List.map ~f:snd args) UnsizedType.pp_returntype rt in let pp_sigs_for_name ppf name = (list ~sep:cut pp_sig) ppf - (List.map ~f:(fun t -> (name, t)) (get_sigs name)) in + (List.map ~f:(fun t -> (name, t)) (get_signatures name)) in pf ppf "@[%a@]" (list ~sep:cut pp_sigs_for_name) - (List.sort ~compare (Hashtbl.keys stan_math_signatures)) + (List.sort ~compare (Hashtbl.keys function_signatures)) let pretty_print_all_math_distributions ppf () = let open Fmt in @@ -497,10 +687,54 @@ let pretty_print_all_math_distributions ppf () = (List.map ~f:(Fn.compose String.lowercase show_fkind) kinds) in pf ppf "@[%a@]" (list ~sep:cut pp_dist) distributions -let pretty_print_math_lib_operator_sigs op = - if op = Operator.IntDivide then - [Fmt.str "@[@,%a@]" pp_math_sig int_divide_type] - else operator_to_stan_math_fns op |> List.map ~f:pretty_print_math_sigs +let int_divide_type = + UnsizedType. + ( ReturnType UInt + , [(AutoDiffable, UInt); (AutoDiffable, UInt)] + , Common.Helpers.AoS ) + +let get_operator_signatures op = + if op = Operator.IntDivide then [int_divide_type] + else operator_to_function_names op |> List.concat_map ~f:get_signatures + +let deprecated_distributions = + List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> + List.filter_map fnkinds ~f:(function + | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") + | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") + | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") + | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") + | Rng | UnaryVectorized -> None ) ) + |> List.map ~f:(fun (x, y) -> + ( x + , { replacement= y + ; version= "2.32.0" + ; extra_message= + "This can be automatically changed using the canonicalize flag \ + for stanc" + ; canonicalize_away= true } ) ) + |> String.Map.of_alist_exn + +let deprecated_functions = + let make extra_message version canonicalize_away replacement = + {extra_message; replacement; version; canonicalize_away} in + let ode = + make + "\n\ + The new interface is slightly different, see: \ + https://mc-stan.org/users/documentation/case-studies/convert_odes.html" + "3.0" false in + let std = + make + "This can be automatically changed using the canonicalize flag for stanc" + "2.32.0" true in + String.Map.of_alist_exn + [ ("multiply_log", std "lmultiply") + ; ("binomial_coefficient_log", std "lchoose") + ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) + ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") + ; ("integrate_ode_bdf", ode "ode_bdf") + ; ("integrate_ode_adams", ode "ode_adams") ] (* -- Some helper definitions to populate stan_math_signatures -- *) let bare_types = @@ -517,8 +751,7 @@ let all_vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt] let add_qualified (name, rt, argts, supports_soa) = - Hashtbl.add_multi stan_math_signatures ~key:name - ~data:(rt, argts, supports_soa) + Hashtbl.add_multi function_signatures ~key:name ~data:(rt, argts, supports_soa) let add_nullary name = add_unqualified (name, UnsizedType.ReturnType UReal, [], AoS) @@ -784,7 +1017,7 @@ let for_vector_types s = List.iter ~f:s vector_types (* -- Start populating stan_math_signaturess -- *) let () = List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) -> - Hashtbl.add_multi stan_math_signatures ~key ~data:(rt, args, mem_pattern) ) ; + Hashtbl.add_multi function_signatures ~key ~data:(rt, args, mem_pattern) ) ; add_unqualified ("abs", ReturnType UInt, [UInt], SoA) ; add_unqualified ("abs", ReturnType UReal, [UReal], SoA) ; add_unqualified ("abs", ReturnType UReal, [UComplex], AoS) ; @@ -2462,11 +2695,7 @@ let () = for type-checking *) Hashtbl.iteri manual_stan_math_signatures ~f:(fun ~key ~data -> List.iter data ~f:(fun data -> - Hashtbl.add_multi stan_math_signatures ~key ~data ) ) - -let%expect_test "dist name suffix" = - dist_name_suffix [] "normal" |> print_endline ; - [%expect {| _lpdf |}] + Hashtbl.add_multi function_signatures ~key ~data ) ) let%expect_test "declarative distributions" = let special_suffixes = @@ -2477,7 +2706,7 @@ let%expect_test "declarative distributions" = distributions |> List.map ~f:(function _, n, _, _ -> n) |> String.Set.of_list in - Hashtbl.keys stan_math_signatures + Hashtbl.keys function_signatures |> List.filter ~f:(fun name -> match Utils.split_distribution_suffix name with | Some (name, suffix) diff --git a/src/stan_math_backend/Stan_math_signatures.mli b/src/stan_math_backend/Stan_math_signatures.mli new file mode 100644 index 0000000000..1a4ba2d63c --- /dev/null +++ b/src/stan_math_backend/Stan_math_signatures.mli @@ -0,0 +1,69 @@ +(** This module stores a table of all signatures from the Stan + math C++ library which are exposed to Stan, and some helper + functions for dealing with those signatures. +*) + +open Frontend +open Middle +open Std_library_utils +open Core_kernel + +val function_signatures : (string, signature list) Hashtbl.t +(** Mapping from names to signature(s) of functions *) + +val distribution_families : string list + +val is_stdlib_function_name : string -> bool +(** Equivalent to [Hashtbl.mem function_signatures s]*) + +val get_signatures : string -> signature list +(** Equivalent to [Hashtbl.find_multi function_signatures s]*) + +val get_operator_signatures : Operator.t -> signature list +val get_assignment_operator_signatures : Operator.t -> signature list +val is_not_overloadable : string -> bool +val is_variadic_function_name : string -> bool +val variadic_function_returntype : string -> UnsizedType.returntype option + +val check_variadic_fn : + Ast.identifier + -> is_cond_dist:bool + -> Location_span.t + -> Environment.originblock + -> Environment.t + -> Ast.typed_expression list + -> Ast.typed_expression +(** This function is responsible for typechecking varadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) + +val operator_to_function_names : Operator.t -> string list +val string_operator_to_function_name : string -> string +val deprecated_distributions : deprecation_info String.Map.t +val deprecated_functions : deprecation_info String.Map.t +(** These functions are used by the drivers to display + all available functions and distributions. They are + not part of the Library interface since different drivers + for different backends would likely want different behavior + here *) + +val pretty_print_all_math_sigs : unit Fmt.t +val pretty_print_all_math_distributions : unit Fmt.t + +(** These functions related to variadic functions + are specific to this backend and used + during code generation *) + +(* reduce_sum helpers *) +val is_reduce_sum_fn : string -> bool + +(* variadic ODE helpers *) +val is_variadic_ode_fn : string -> bool +val is_variadic_ode_nonadjoint_tol_fn : string -> bool +val ode_tolerances_suffix : string +val variadic_ode_adjoint_fn : string + +(* variadic DAE helpers *) +val is_variadic_dae_fn : string -> bool +val is_variadic_dae_tol_fn : string -> bool +val dae_tolerances_suffix : string diff --git a/src/stan_math_backend/dune b/src/stan_math_backend/dune index d8dd800d2a..1821f261d1 100644 --- a/src/stan_math_backend/dune +++ b/src/stan_math_backend/dune @@ -1,7 +1,7 @@ (library (name stan_math_backend) (public_name stanc.stan_math_backend) - (libraries core_kernel re fmt middle yojson) + (libraries core_kernel re fmt frontend middle yojson) (private_modules mangle cpp_Json @@ -12,4 +12,4 @@ statement_gen) (inline_tests) (preprocess - (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) + (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.show))) diff --git a/src/stan_math_library/Library.ml b/src/stan_math_library/Library.ml new file mode 100644 index 0000000000..b773da97d4 --- /dev/null +++ b/src/stan_math_library/Library.ml @@ -0,0 +1 @@ +include Stan_math_backend.Stan_math_signatures diff --git a/src/stan_math_library/dune b/src/stan_math_library/dune new file mode 100644 index 0000000000..57b5dfc510 --- /dev/null +++ b/src/stan_math_library/dune @@ -0,0 +1,8 @@ +(library + (name stan_math_library) + (public_name stanc.stan_math_library) + (libraries core_kernel middle stan_math_backend) + (implements frontend) + (inline_tests) + (preprocess + (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.show))) diff --git a/src/stanc/dune b/src/stanc/dune index c3e8ab39e9..3763e7cdf7 100644 --- a/src/stanc/dune +++ b/src/stanc/dune @@ -1,6 +1,11 @@ (executable (name stanc) - (libraries frontend middle stan_math_backend analysis_and_optimization) + (libraries + frontend + middle + stan_math_backend + analysis_and_optimization + stan_math_library) (modules Stanc) (public_name stanc) (preprocess diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 6583ae4f58..529c7fe6da 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -3,9 +3,9 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend -open Analysis_and_optimization open Middle +open Analysis_and_optimization +open Stan_math_backend (** The main program. *) let version = "%%NAME%%3 %%VERSION%%" diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index 7277a88e2f..46ee6ef316 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -1,11 +1,19 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend -open Analysis_and_optimization open Middle +open Analysis_and_optimization +open Stan_math_backend open Js_of_ocaml +(* Initialize functors with Stan Math C++ signatures *) +module Typechecker = Typechecking.Make (Stan_math_library) +module Deprecations = Deprecation_analysis.Make (Stan_math_library) +module Canonicalizer = Canonicalize.Make (Deprecations) +module ModelInfo = Info.Make (Stan_math_library) +module Ast2Mir = Ast_to_Mir.Make (Stan_math_library) +module Optimizer = Optimize.Make (Stan_math_library) + let version = "%%NAME%% %%VERSION%%" let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t) @@ -21,8 +29,8 @@ let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t let stan2cpp model_name model_string is_flag_set flag_val = Common.Gensym.reset_danger_use_cautiously () ; - Typechecker.model_name := model_name ; - Typechecker.check_that_all_functions_have_definition := + Typechecking.model_name := model_name ; + Typechecking.check_that_all_functions_have_definition := not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ; Transform_Mir.use_opencl := is_flag_set "use-opencl" ; Stan_math_code_gen.standalone_functions := @@ -45,7 +53,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = >>| fun (typed_ast, type_warnings) -> let warnings = parser_warnings @ type_warnings in if is_flag_set "info" then - r.return (Result.Ok (Info.info typed_ast), warnings, []) ; + r.return (Result.Ok (ModelInfo.info typed_ast), warnings, []) ; let canonicalizer_settings = if is_flag_set "print-canonical" then Canonicalize.all else @@ -67,7 +75,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = flag_val "max-line-length" |> Option.map ~f:int_of_string |> Option.value ~default:78 in - let mir = Ast_to_Mir.trans_prog model_name typed_ast in + let mir = Ast2Mir.trans_prog model_name typed_ast in let tx_mir = Transform_Mir.trans_prog mir in if is_flag_set "auto-format" || is_flag_set "print-canonical" then r.return @@ -76,7 +84,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = ~bare_functions:(is_flag_set "functions-only") ~line_length ~inline_includes:canonicalizer_settings.inline_includes - (Canonicalize.canonicalize_program typed_ast + (Canonicalizer.canonicalize_program typed_ast canonicalizer_settings ) ) , warnings , [] ) ; @@ -92,7 +100,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = r.return ( Result.Ok (Debug_data_generation.print_data_prog - (Ast_to_Mir.gather_data typed_ast) ) + (Ast2Mir.gather_data typed_ast) ) , warnings , [] ) ; let opt_mir = @@ -102,7 +110,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = else if is_flag_set "Oexperimental" || is_flag_set "O" then Optimize.Oexperimental else Optimize.O0 in - Optimize.optimization_suite + Optimizer.optimization_suite ~settings:(Optimize.level_optimizations opt_lvl) tx_mir in if is_flag_set "debug-optimized-mir" then @@ -181,11 +189,11 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) = wrap_result ?printed_filename ~code result ~warnings let dump_stan_math_signatures () = - Js.string @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_sigs () + Js.string @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_sigs () let dump_stan_math_distributions () = Js.string - @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_distributions () + @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_distributions () let () = Js.export "dump_stan_math_signatures" dump_stan_math_signatures ; diff --git a/test/integration/bad/stanc.expected b/test/integration/bad/stanc.expected index a4bfd12395..3263db8341 100644 --- a/test/integration/bad/stanc.expected +++ b/test/integration/bad/stanc.expected @@ -1531,7 +1531,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (vector, real) => vector @@ -1554,7 +1553,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (vector, real) => vector diff --git a/test/unit/Debug_data_generation_tests.ml b/test/unit/Debug_data_generation_tests.ml index 6ab94c65a8..f3d97521e7 100644 --- a/test/unit/Debug_data_generation_tests.ml +++ b/test/unit/Debug_data_generation_tests.ml @@ -1,9 +1,4 @@ -open Analysis_and_optimization open Core_kernel -open Frontend -open Debug_data_generation - -let print_data_prog ast = print_data_prog (Ast_to_Mir.gather_data ast) let%expect_test "whole program data generation check" = let ast = @@ -17,7 +12,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -48,7 +43,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -97,7 +92,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -134,7 +129,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -256,7 +251,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -515,7 +510,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -647,7 +642,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -666,7 +661,7 @@ let%expect_test "Complex numbers program" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| diff --git a/test/unit/Desugar_test.ml b/test/unit/Desugar_test.ml index 525fbfe24b..1bf13ae28b 100644 --- a/test/unit/Desugar_test.ml +++ b/test/unit/Desugar_test.ml @@ -1,6 +1,9 @@ open Core_kernel open Analysis_and_optimization +module Partial_evaluator = + Partial_evaluation.Make (Stan_math_backend.Stan_math_library) + let print_tdata Middle.Program.{prepare_data; _} = Fmt.(str "@[%a@]@," (list ~sep:cut Middle.Stmt.Located.pp) prepare_data) |> print_endline diff --git a/test/unit/Optimize.ml b/test/unit/Optimize.ml index 67154bbbe0..d61c372f82 100644 --- a/test/unit/Optimize.ml +++ b/test/unit/Optimize.ml @@ -1,9 +1,13 @@ open Core_kernel -open Analysis_and_optimization.Optimize open Middle open Common open Analysis_and_optimization.Mir_utils +module Optimizer = + Analysis_and_optimization.Optimize.Make (Stan_math_backend.Stan_math_library) + +open Optimizer + let reset_and_mir_of_string s = Gensym.reset_danger_use_cautiously () ; Test_utils.mir_of_string s diff --git a/test/unit/Test_utils.ml b/test/unit/Test_utils.ml index 5522236e61..03838e7405 100644 --- a/test/unit/Test_utils.ml +++ b/test/unit/Test_utils.ml @@ -1,6 +1,12 @@ open Frontend open Core_kernel +module CppLibrary : Std_library_utils.Library = + Stan_math_backend.Stan_math_library + +module Typechecker = Typechecking.Make (CppLibrary) +module Ast2Mir = Ast_to_Mir.Make (CppLibrary) + let untyped_ast_of_string s = let res, warnings = Parse.parse_string Parser.Incremental.program s in Fmt.epr "%a" (Fmt.list ~sep:Fmt.nop Warnings.pp) warnings ; @@ -15,4 +21,8 @@ let typed_ast_of_string_exn s = |> Result.map_error ~f:Errors.to_string |> Result.ok_or_failwith |> fst -let mir_of_string s = typed_ast_of_string_exn s |> Ast_to_Mir.trans_prog "" +let mir_of_string s = typed_ast_of_string_exn s |> Ast2Mir.trans_prog "" + +let print_data_prog ast = + Analysis_and_optimization.Debug_data_generation.print_data_prog + (Ast2Mir.gather_data ast) From 654810c64e1c0885043c8b9574e76d01a9d4b6a5 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 3 May 2022 12:10:27 -0400 Subject: [PATCH 02/19] Virtual libraries working --- src/analysis_and_optimization/Mir_utils.ml | 16 +- .../Pedantic_analysis.ml | 3 +- src/analysis_and_optimization/dune | 3 +- src/frontend/Ast_to_Mir.ml | 3 +- src/frontend/Info.ml | 4 +- src/frontend/Std_library_utils.ml | 3 +- src/frontend/Typechecker.ml | 17 +- src/frontend/Typechecker.mli | 9 - src/frontend/Typechecker_utils.ml | 21 ++ src/frontend/dune | 4 +- src/stan_math_backend/Expression_gen.ml | 6 +- src/stan_math_backend/Library_utils.ml | 0 src/stan_math_backend/Stan_math_signatures.ml | 315 +----------------- .../Stan_math_signatures.mli | 46 ++- src/stan_math_backend/dune | 2 +- .../stan_math_library/Library.ml | 65 ++++ .../Variadic_typechecking.ml | 256 ++++++++++++++ .../stan_math_library/dune | 1 + src/stan_math_library/Library.ml | 1 - src/stancjs/dune | 3 +- src/stancjs/stancjs.ml | 26 +- test/unit/Desugar_test.ml | 3 - test/unit/Optimize.ml | 6 +- test/unit/Test_utils.ml | 10 +- test/unit/dune | 1 + 25 files changed, 416 insertions(+), 408 deletions(-) create mode 100644 src/frontend/Typechecker_utils.ml create mode 100644 src/stan_math_backend/Library_utils.ml create mode 100644 src/stan_math_backend/stan_math_library/Library.ml create mode 100644 src/stan_math_backend/stan_math_library/Variadic_typechecking.ml rename src/{ => stan_math_backend}/stan_math_library/dune (85%) delete mode 100644 src/stan_math_library/Library.ml diff --git a/src/analysis_and_optimization/Mir_utils.ml b/src/analysis_and_optimization/Mir_utils.ml index 82bf15f570..43bb3e784f 100644 --- a/src/analysis_and_optimization/Mir_utils.ml +++ b/src/analysis_and_optimization/Mir_utils.ml @@ -464,16 +464,16 @@ let cleanup_empty_stmts stmts = (** * Convert a Type.Unsized to a Type.Sized. - * This function is useful in the inlining scheme as - * the Mem_patterns optimization cannot work with decl types - * for unsized types. (Steve: tmk the inline optimization is the only place - * we create Decl's with unsized types.) + * This function is useful in the inlining scheme as + * the Mem_patterns optimization cannot work with decl types + * for unsized types. (Steve: tmk the inline optimization is the only place + * we create Decl's with unsized types.) * * Note that there is no true mapping from Sized types to Unsized types. - * Any sizes are set to 0 and it is assumed that the intent - * of Types.Unsized with inner UFun types is to size the return - * type of the UFun. Any Decl that uses this type should - * have initialize set to false. + * Any sizes are set to 0 and it is assumed that the intent + * of Types.Unsized with inner UFun types is to size the return + * type of the UFun. Any Decl that uses this type should + * have initialize set to false. *) let unsafe_unsized_to_sized_type (rt : Expr.Typed.t Type.t) = match rt with diff --git a/src/analysis_and_optimization/Pedantic_analysis.ml b/src/analysis_and_optimization/Pedantic_analysis.ml index f35b171fe3..96b3a48338 100644 --- a/src/analysis_and_optimization/Pedantic_analysis.ml +++ b/src/analysis_and_optimization/Pedantic_analysis.ml @@ -493,8 +493,7 @@ let settings_constant_prop = let warn_pedantic (mir_unopt : Program.Typed.t) = (* Some warnings will be stronger when constants are propagated *) let mir = - Optimize.optimization_suite ~settings:settings_constant_prop mir_unopt - in + Optimize.optimization_suite ~settings:settings_constant_prop mir_unopt in (* Try to avoid recomputation by pre-building structures *) let distributions_info = list_distributions mir in let factor_graph = prog_factor_graph mir in diff --git a/src/analysis_and_optimization/dune b/src/analysis_and_optimization/dune index c7f4a7dd05..8f6c74f5dd 100644 --- a/src/analysis_and_optimization/dune +++ b/src/analysis_and_optimization/dune @@ -2,6 +2,7 @@ (name analysis_and_optimization) (public_name stanc.analysis) (libraries core_kernel str fmt common middle frontend stan_math_backend) - (inline_tests) + (inline_tests + (libraries stan_math_library)) (preprocess (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index a68157bdae..20d59bf527 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -423,7 +423,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) = | Ast.IncrementLogProb e | Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap | Ast.Tilde {arg; distribution; args; truncation} -> let suffix = - Std_library_utils.dist_name_suffix Library.is_stdlib_function_name ud_dists distribution.name in + Std_library_utils.dist_name_suffix Library.is_stdlib_function_name + ud_dists distribution.name in let name = distribution.name ^ suffix in let kind = let possible_names = diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index 78802bbf84..e82c2e545b 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -52,7 +52,6 @@ let includes_json () = ( List.rev !Preprocessor.included_files |> List.map ~f:(fun str -> `String str) ) ) ] - let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = let acc = match stmt.stmt with @@ -65,7 +64,8 @@ let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = (funs, distrs) else let suffix = - Std_library_utils.dist_name_suffix Library.is_stdlib_function_name ud_dists distribution.name in + Std_library_utils.dist_name_suffix Library.is_stdlib_function_name + ud_dists distribution.name in let name = distribution.name ^ Utils.unnormalized_suffix suffix in (funs, Set.add distrs name) | _ -> (funs, distrs) in diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index f4ea514d3c..6802fd5db5 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -26,6 +26,5 @@ let dist_name_suffix (check : string -> bool) udf_names name = let is_udf_name s = List.exists ~f:(fun (n, _) -> String.equal s n) udf_names in Utils.distribution_suffices - |> List.filter ~f:(fun sfx -> - check (name ^ sfx) || is_udf_name (name ^ sfx) ) + |> List.filter ~f:(fun sfx -> check (name ^ sfx) || is_udf_name (name ^ sfx)) |> List.hd_exn diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index efb692047f..f2ba72df39 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -17,6 +17,7 @@ open Core_kernel open Core_kernel.Poly open Middle open Ast +open Typechecker_utils module Env = Environment (* we only allow errors raised by this function *) @@ -56,22 +57,6 @@ let context block = ; in_udf_dist_def= false ; loop_depth= 0 } -let calculate_autodifftype current_block origin ut = - match origin with - | Env.(Param | TParam | Model | Functions) - when not (UnsizedType.is_int_type ut || current_block = Env.GQuant) -> - UnsizedType.AutoDiffable - | _ -> DataOnly - -let arg_type x = (x.emeta.ad_level, x.emeta.type_) -let get_arg_types = List.map ~f:arg_type -let type_of_expr_typed ue = ue.emeta.type_ -let has_int_type ue = ue.emeta.type_ = UInt -let has_int_array_type ue = ue.emeta.type_ = UArray UInt - -let has_int_or_real_type ue = - match ue.emeta.type_ with UInt | UReal -> true | _ -> false - (* -- General checks ---------------------------------------------- *) let reserved_keywords = [ "for"; "in"; "while"; "repeat"; "until"; "if"; "then"; "else"; "true" diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli index 55620d1e12..288a3c9d51 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecker.mli @@ -26,15 +26,6 @@ val model_name : string ref val check_that_all_functions_have_definition : bool ref (** A switch to determine whether we check that all functions have a definition *) -val get_arg_types : typed_expression list -> Std_library_utils.fun_arg list -val type_of_expr_typed : typed_expression -> Middle.UnsizedType.t - -val calculate_autodifftype : - Environment.originblock - -> Environment.originblock - -> Middle.UnsizedType.t - -> Middle.UnsizedType.autodifftype - val check_program_exn : untyped_program -> typed_program * Warnings.t list (** Type check a full Stan program. diff --git a/src/frontend/Typechecker_utils.ml b/src/frontend/Typechecker_utils.ml new file mode 100644 index 0000000000..3a92d9af29 --- /dev/null +++ b/src/frontend/Typechecker_utils.ml @@ -0,0 +1,21 @@ +open Core_kernel +open Core_kernel.Poly +open Middle +open Ast +module Env = Environment + +let calculate_autodifftype current_block origin ut = + match origin with + | Env.(Param | TParam | Model | Functions) + when not (UnsizedType.is_int_type ut || current_block = Env.GQuant) -> + UnsizedType.AutoDiffable + | _ -> DataOnly + +let arg_type x = (x.emeta.ad_level, x.emeta.type_) +let get_arg_types = List.map ~f:arg_type +let type_of_expr_typed ue = ue.emeta.type_ +let has_int_type ue = ue.emeta.type_ = UInt +let has_int_array_type ue = ue.emeta.type_ = UArray UInt + +let has_int_or_real_type ue = + match ue.emeta.type_ with UInt | UReal -> true | _ -> false diff --git a/src/frontend/dune b/src/frontend/dune index ea86e21473..1d51ecbf9a 100644 --- a/src/frontend/dune +++ b/src/frontend/dune @@ -2,8 +2,10 @@ (name frontend) (public_name stanc.frontend) (libraries core_kernel re menhirLib fmt middle common yojson) - (inline_tests) (virtual_modules library) + (inline_tests + ; we need to specify an implementation of the virtual library for testing + (libraries stan_math_library)) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map))) diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index 526c9c6c6b..44fda9ddee 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -139,8 +139,10 @@ let variadic_dae_functor_suffix = "_daefunctor__" let functor_suffix_select hof = match hof with | x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix - | x when Stan_math_signatures.is_variadic_ode_fn x -> variadic_ode_functor_suffix - | x when Stan_math_signatures.is_variadic_dae_fn x -> variadic_dae_functor_suffix + | x when Stan_math_signatures.is_variadic_ode_fn x -> + variadic_ode_functor_suffix + | x when Stan_math_signatures.is_variadic_dae_fn x -> + variadic_dae_functor_suffix | _ -> functor_suffix let constraint_to_string = function diff --git a/src/stan_math_backend/Library_utils.ml b/src/stan_math_backend/Library_utils.ml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/stan_math_backend/Stan_math_signatures.ml b/src/stan_math_backend/Stan_math_signatures.ml index acd1af2c5f..13846d0d13 100644 --- a/src/stan_math_backend/Stan_math_signatures.ml +++ b/src/stan_math_backend/Stan_math_signatures.ml @@ -3,7 +3,6 @@ open Core_kernel open Core_kernel.Poly open Middle -open Frontend.Std_library_utils (** The "dimensionality" (bad name?) is supposed to help us represent the vectorized nature of many Stan functions. It allows us to represent when @@ -58,6 +57,11 @@ let is_primitive = function | UInt -> true | _ -> false +type fun_arg = UnsizedType.autodifftype * UnsizedType.t + +type signature = + UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern + (** The signatures hash table *) let (function_signatures : (string, signature list) Hashtbl.t) = String.Table.create () @@ -107,64 +111,6 @@ let rec complex_to_real = function | UArray t -> UArray (complex_to_real t) | x -> x -let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7] - -let reduce_sum_slice_types = - let base_slice_type i = - [ bare_array_type (UnsizedType.UReal, i) - ; bare_array_type (UnsizedType.UInt, i) - ; bare_array_type (UnsizedType.UMatrix, i) - ; bare_array_type (UnsizedType.UVector, i) - ; bare_array_type (UnsizedType.URowVector, i) ] in - List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) - -(* Variadic ODE *) -let variadic_ode_adjoint_ctl_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal) - (* real relative_tolerance_forward *) - ; (DataOnly, UVector) (* vector absolute_tolerance_forward *) - ; (DataOnly, UReal) (* real relative_tolerance_backward *) - ; (DataOnly, UVector) (* real absolute_tolerance_backward *) - ; (DataOnly, UReal) (* real relative_tolerance_quadrature *) - ; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) - ; (DataOnly, UInt) (* int max_num_steps *) - ; (DataOnly, UInt) (* int num_steps_between_checkpoints *) - ; (DataOnly, UInt) (* int interpolation_polynomial *) - ; (DataOnly, UInt) (* int solver_forward *); (DataOnly, UInt) - (* int solver_backward *) ] - -let variadic_ode_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) - ; (DataOnly, UInt) ] - -let variadic_ode_mandatory_arg_types = - [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (AutoDiffable, UReal) - ; (AutoDiffable, UArray UReal) ] - -let variadic_ode_mandatory_fun_args = - [ (UnsizedType.AutoDiffable, UnsizedType.UReal) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] - -let variadic_ode_fun_return_type = UnsizedType.UVector -let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector - -let variadic_dae_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) - ; (DataOnly, UInt) ] - -let variadic_dae_mandatory_arg_types = - [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) - (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) - (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] - -let variadic_dae_mandatory_fun_args = - [ (UnsizedType.AutoDiffable, UnsizedType.UReal) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] - -let variadic_dae_fun_return_type = UnsizedType.UVector -let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector - let mk_declarative_sig (fnkinds, name, args, mem_pattern) = let is_glm = String.is_suffix ~suffix:"_glm" name in let sfxes = function @@ -239,220 +185,8 @@ let is_variadic_dae_tol_fn f = let is_variadic_function_name name = is_reduce_sum_fn name || is_variadic_dae_fn name || is_variadic_ode_fn name -let variadic_function_returntype name = - if is_reduce_sum_fn name then Some (UnsizedType.ReturnType UReal) - else if is_variadic_ode_fn name then - Some (UnsizedType.ReturnType variadic_ode_return_type) - else if is_variadic_dae_fn name then - Some (UnsizedType.ReturnType variadic_dae_return_type) - else None - let is_not_overloadable = is_variadic_function_name -module Variadic_typechecking = struct - (** This module serves as the backend-specific portion - of the typechecker. *) - - open Frontend - open Typechecker - open Ast - - let error e = raise (Errors.SemanticError e) - - let make_function_variable current_block loc id = function - | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype current_block Functions type_) - ~type_ ~loc - | UnsizedType.UFun _ as type_ -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype current_block Functions type_) - ~type_ ~loc - | type_ -> - Common.FatalError.fatal_error_msg - [%message - "Attempting to create function variable out of " - (type_ : UnsizedType.t)] - - let check_reduce_sum ~is_cond_dist loc current_block tenv id tes = - let basic_mismatch () = - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] - in - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal (get_arg_types tes) in - let fail () = - let expected_args, err = - basic_mismatch () |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_fn loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err UReal - |> error in - let matching remaining_es fn = - match fn with - | Environment. - { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as - ftype - ; _ } - when List.mem reduce_sum_slice_types sliced_arg_fun_type ~equal:( = ) -> - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in - let arg_types = - (calculate_autodifftype current_block Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal arg_types - | _ -> basic_mismatch () in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match - SignatureMismatch.find_matching_first_order_fn tenv - (matching remaining_es) fname - with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - (* a valid signature exists *) - let tes = - make_function_variable current_block loc fname ftype :: remaining_es - in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_fn loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err UReal - |> error ) - | _ -> fail () - - let check_variadic_ode ~is_cond_dist loc current_block tenv id tes = - let optional_tol_mandatory_args = - if variadic_ode_adjoint_fn = id.name then - variadic_ode_adjoint_ctl_tol_arg_types - else if is_variadic_ode_nonadjoint_tol_fn id.name then - variadic_ode_tol_arg_types - else [] in - let mandatory_arg_types = - variadic_ode_mandatory_arg_types @ optional_tol_mandatory_args in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - variadic_ode_mandatory_fun_args variadic_ode_fun_return_type - (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_fn loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err variadic_ode_fun_return_type - |> error in - let matching remaining_es Environment.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype current_block Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - variadic_ode_mandatory_fun_args variadic_ode_fun_return_type arg_types - in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match - SignatureMismatch.find_matching_first_order_fn tenv - (matching remaining_es) fname - with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = - make_function_variable current_block loc fname ftype :: remaining_es - in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:variadic_ode_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_fn loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err variadic_ode_fun_return_type - |> error ) - | _ -> fail () - - let check_variadic_dae ~is_cond_dist loc current_block tenv id tes = - let optional_tol_mandatory_args = - if is_variadic_dae_tol_fn id.name then variadic_dae_tol_arg_types else [] - in - let mandatory_arg_types = - variadic_dae_mandatory_arg_types @ optional_tol_mandatory_args in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - variadic_dae_mandatory_fun_args variadic_dae_fun_return_type - (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_fn loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err variadic_dae_fun_return_type - |> error in - let matching remaining_es Environment.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype current_block Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - variadic_dae_mandatory_fun_args variadic_dae_fun_return_type arg_types - in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match - SignatureMismatch.find_matching_first_order_fn tenv - (matching remaining_es) fname - with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = - make_function_variable current_block loc fname ftype :: remaining_es - in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:variadic_dae_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_fn loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err variadic_dae_fun_return_type - |> error ) - | _ -> fail () -end - -let check_variadic_fn id ~is_cond_dist loc current_block tenv tes = - if is_reduce_sum_fn id.Frontend.Ast.name then - Variadic_typechecking.check_reduce_sum ~is_cond_dist loc current_block tenv - id tes - else if is_variadic_ode_fn id.name then - Variadic_typechecking.check_variadic_ode ~is_cond_dist loc current_block - tenv id tes - else if is_variadic_dae_fn id.name then - Variadic_typechecking.check_variadic_dae ~is_cond_dist loc current_block - tenv id tes - else - Common.FatalError.fatal_error_msg - [%message - "Invalid variadic function for Stan Math backend" (id.name : string)] - let distributions = [ ( full_lpmf , "beta_binomial" @@ -697,45 +431,6 @@ let get_operator_signatures op = if op = Operator.IntDivide then [int_divide_type] else operator_to_function_names op |> List.concat_map ~f:get_signatures -let deprecated_distributions = - List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> - List.filter_map fnkinds ~f:(function - | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") - | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") - | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") - | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") - | Rng | UnaryVectorized -> None ) ) - |> List.map ~f:(fun (x, y) -> - ( x - , { replacement= y - ; version= "2.32.0" - ; extra_message= - "This can be automatically changed using the canonicalize flag \ - for stanc" - ; canonicalize_away= true } ) ) - |> String.Map.of_alist_exn - -let deprecated_functions = - let make extra_message version canonicalize_away replacement = - {extra_message; replacement; version; canonicalize_away} in - let ode = - make - "\n\ - The new interface is slightly different, see: \ - https://mc-stan.org/users/documentation/case-studies/convert_odes.html" - "3.0" false in - let std = - make - "This can be automatically changed using the canonicalize flag for stanc" - "2.32.0" true in - String.Map.of_alist_exn - [ ("multiply_log", std "lmultiply") - ; ("binomial_coefficient_log", std "lchoose") - ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) - ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") - ; ("integrate_ode_bdf", ode "ode_bdf") - ; ("integrate_ode_adams", ode "ode_adams") ] - (* -- Some helper definitions to populate stan_math_signatures -- *) let bare_types = [ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix diff --git a/src/stan_math_backend/Stan_math_signatures.mli b/src/stan_math_backend/Stan_math_signatures.mli index 1a4ba2d63c..a2effe455f 100644 --- a/src/stan_math_backend/Stan_math_signatures.mli +++ b/src/stan_math_backend/Stan_math_signatures.mli @@ -3,14 +3,39 @@ functions for dealing with those signatures. *) -open Frontend open Middle -open Std_library_utils open Core_kernel +type fun_arg = UnsizedType.autodifftype * UnsizedType.t + +type signature = + UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern + +type fkind = Lpmf | Lpdf | Rng | Cdf | Ccdf | UnaryVectorized +[@@deriving show {with_path= false}] + +type dimensionality = + | DInt + | DReal + | DVector + | DMatrix + | DIntArray + (* Vectorizable int *) + | DVInt + (* Vectorizable real *) + | DVReal + (* DEPRECATED; vectorizable ints or reals *) + | DIntAndReals + (* Vectorizable vectors - for multivariate functions *) + | DVectors + | DDeepVectorized + val function_signatures : (string, signature list) Hashtbl.t (** Mapping from names to signature(s) of functions *) +val distributions : + (fkind list * string * dimensionality list * Common.Helpers.mem_pattern) list + val distribution_families : string list val is_stdlib_function_name : string -> bool @@ -23,24 +48,9 @@ val get_operator_signatures : Operator.t -> signature list val get_assignment_operator_signatures : Operator.t -> signature list val is_not_overloadable : string -> bool val is_variadic_function_name : string -> bool -val variadic_function_returntype : string -> UnsizedType.returntype option - -val check_variadic_fn : - Ast.identifier - -> is_cond_dist:bool - -> Location_span.t - -> Environment.originblock - -> Environment.t - -> Ast.typed_expression list - -> Ast.typed_expression -(** This function is responsible for typechecking varadic function - calls. It needs to live in the Library since this is usually - bespoke per-function. *) - val operator_to_function_names : Operator.t -> string list val string_operator_to_function_name : string -> string -val deprecated_distributions : deprecation_info String.Map.t -val deprecated_functions : deprecation_info String.Map.t + (** These functions are used by the drivers to display all available functions and distributions. They are not part of the Library interface since different drivers diff --git a/src/stan_math_backend/dune b/src/stan_math_backend/dune index 1821f261d1..07581843c4 100644 --- a/src/stan_math_backend/dune +++ b/src/stan_math_backend/dune @@ -1,7 +1,7 @@ (library (name stan_math_backend) (public_name stanc.stan_math_backend) - (libraries core_kernel re fmt frontend middle yojson) + (libraries core_kernel re fmt middle yojson) (private_modules mangle cpp_Json diff --git a/src/stan_math_backend/stan_math_library/Library.ml b/src/stan_math_backend/stan_math_library/Library.ml new file mode 100644 index 0000000000..625e0cef23 --- /dev/null +++ b/src/stan_math_backend/stan_math_library/Library.ml @@ -0,0 +1,65 @@ +open Core_kernel +open Std_library_utils + +(** Many of the required functions are exposed in the backend specific file, so we include it *) +include Stan_math_backend.Stan_math_signatures + +let deprecated_distributions = + List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> + List.filter_map fnkinds ~f:(function + | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") + | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") + | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") + | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") + | Rng | UnaryVectorized -> None ) ) + |> List.map ~f:(fun (x, y) -> + ( x + , { replacement= y + ; version= "2.32.0" + ; extra_message= + "This can be automatically changed using the canonicalize flag \ + for stanc" + ; canonicalize_away= true } ) ) + |> String.Map.of_alist_exn + +let deprecated_functions = + let make extra_message version canonicalize_away replacement = + {extra_message; replacement; version; canonicalize_away} in + let ode = + make + "\n\ + The new interface is slightly different, see: \ + https://mc-stan.org/users/documentation/case-studies/convert_odes.html" + "3.0" false in + let std = + make + "This can be automatically changed using the canonicalize flag for stanc" + "2.32.0" true in + String.Map.of_alist_exn + [ ("multiply_log", std "lmultiply") + ; ("binomial_coefficient_log", std "lchoose") + ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) + ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") + ; ("integrate_ode_bdf", ode "ode_bdf") + ; ("integrate_ode_adams", ode "ode_adams") ] + +(** This function is responsible for typechecking varadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) +let check_variadic_fn id ~is_cond_dist loc current_block tenv tes = + if is_reduce_sum_fn id.Ast.name then + Variadic_typechecking.check_reduce_sum ~is_cond_dist loc current_block tenv + id tes + else if is_variadic_ode_fn id.name then + Variadic_typechecking.check_variadic_ode ~is_cond_dist loc current_block + tenv id tes + else if is_variadic_dae_fn id.name then + Variadic_typechecking.check_variadic_dae ~is_cond_dist loc current_block + tenv id tes + else + Common.FatalError.fatal_error_msg + [%message + "Invalid variadic function for Stan Math backend" (id.name : string)] + +let variadic_function_returntype = + Variadic_typechecking.variadic_function_returntype diff --git a/src/stan_math_backend/stan_math_library/Variadic_typechecking.ml b/src/stan_math_backend/stan_math_library/Variadic_typechecking.ml new file mode 100644 index 0000000000..9aa1adc428 --- /dev/null +++ b/src/stan_math_backend/stan_math_library/Variadic_typechecking.ml @@ -0,0 +1,256 @@ +(** This module serves as the backend-specific portion + of the typechecker. *) + +open Core_kernel +open Core_kernel.Poly +open Ast +open Typechecker_utils +open Middle +open Stan_math_backend.Stan_math_signatures + +let error e = raise (Errors.SemanticError e) +let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7] + +let rec bare_array_type (t, i) = + match i with 0 -> t | j -> UnsizedType.UArray (bare_array_type (t, j - 1)) + +let reduce_sum_slice_types = + let base_slice_type i = + [ bare_array_type (UnsizedType.UReal, i) + ; bare_array_type (UnsizedType.UInt, i) + ; bare_array_type (UnsizedType.UMatrix, i) + ; bare_array_type (UnsizedType.UVector, i) + ; bare_array_type (UnsizedType.URowVector, i) ] in + List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) + +(* Variadic ODE *) +let variadic_ode_adjoint_ctl_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal) + (* real relative_tolerance_forward *) + ; (DataOnly, UVector) (* vector absolute_tolerance_forward *) + ; (DataOnly, UReal) (* real relative_tolerance_backward *) + ; (DataOnly, UVector) (* real absolute_tolerance_backward *) + ; (DataOnly, UReal) (* real relative_tolerance_quadrature *) + ; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) + ; (DataOnly, UInt) (* int max_num_steps *) + ; (DataOnly, UInt) (* int num_steps_between_checkpoints *) + ; (DataOnly, UInt) (* int interpolation_polynomial *) + ; (DataOnly, UInt) (* int solver_forward *); (DataOnly, UInt) + (* int solver_backward *) ] + +let variadic_ode_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_ode_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (AutoDiffable, UReal) + ; (AutoDiffable, UArray UReal) ] + +let variadic_ode_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_ode_fun_return_type = UnsizedType.UVector +let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector + +let variadic_dae_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_dae_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) + (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) + (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] + +let variadic_dae_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_dae_fun_return_type = UnsizedType.UVector +let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector + +let variadic_function_returntype name = + if is_reduce_sum_fn name then Some (UnsizedType.ReturnType UReal) + else if is_variadic_ode_fn name then + Some (UnsizedType.ReturnType variadic_ode_return_type) + else if is_variadic_dae_fn name then + Some (UnsizedType.ReturnType variadic_dae_return_type) + else None + +let make_function_variable current_block loc id = function + | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | UnsizedType.UFun _ as type_ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | type_ -> + Common.FatalError.fatal_error_msg + [%message + "Attempting to create function variable out of " + (type_ : UnsizedType.t)] + +let check_reduce_sum ~is_cond_dist loc current_block tenv id tes = + let basic_mismatch () = + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in + SignatureMismatch.check_variadic_args true mandatory_args mandatory_fun_args + UReal (get_arg_types tes) in + let fail () = + let expected_args, err = + basic_mismatch () |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error in + let matching remaining_es fn = + match fn with + | Environment. + { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as + ftype + ; _ } + when List.mem reduce_sum_slice_types sliced_arg_fun_type ~equal:( = ) -> + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal arg_types + | _ -> basic_mismatch () in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + (* a valid signature exists *) + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error ) + | _ -> fail () + +let check_variadic_ode ~is_cond_dist loc current_block tenv id tes = + let optional_tol_mandatory_args = + if variadic_ode_adjoint_fn = id.name then + variadic_ode_adjoint_ctl_tol_arg_types + else if is_variadic_ode_nonadjoint_tol_fn id.name then + variadic_ode_tol_arg_types + else [] in + let mandatory_arg_types = + variadic_ode_mandatory_arg_types @ optional_tol_mandatory_args in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_ode_mandatory_fun_args variadic_ode_fun_return_type + (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_ode_fun_return_type + |> error in + let matching remaining_es Environment.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_ode_mandatory_fun_args variadic_ode_fun_return_type arg_types + in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:variadic_ode_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_ode_fun_return_type + |> error ) + | _ -> fail () + +let check_variadic_dae ~is_cond_dist loc current_block tenv id tes = + let optional_tol_mandatory_args = + if is_variadic_dae_tol_fn id.name then variadic_dae_tol_arg_types else [] + in + let mandatory_arg_types = + variadic_dae_mandatory_arg_types @ optional_tol_mandatory_args in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_dae_mandatory_fun_args variadic_dae_fun_return_type + (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_dae_fun_return_type + |> error in + let matching remaining_es Environment.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_dae_mandatory_fun_args variadic_dae_fun_return_type arg_types + in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:variadic_dae_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_dae_fun_return_type + |> error ) + | _ -> fail () diff --git a/src/stan_math_library/dune b/src/stan_math_backend/stan_math_library/dune similarity index 85% rename from src/stan_math_library/dune rename to src/stan_math_backend/stan_math_library/dune index 57b5dfc510..549fdd4069 100644 --- a/src/stan_math_library/dune +++ b/src/stan_math_backend/stan_math_library/dune @@ -3,6 +3,7 @@ (public_name stanc.stan_math_library) (libraries core_kernel middle stan_math_backend) (implements frontend) + (private_modules variadic_typechecking) (inline_tests) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.show))) diff --git a/src/stan_math_library/Library.ml b/src/stan_math_library/Library.ml deleted file mode 100644 index b773da97d4..0000000000 --- a/src/stan_math_library/Library.ml +++ /dev/null @@ -1 +0,0 @@ -include Stan_math_backend.Stan_math_signatures diff --git a/src/stancjs/dune b/src/stancjs/dune index d983673834..e8bafe4d8b 100644 --- a/src/stancjs/dune +++ b/src/stancjs/dune @@ -5,7 +5,8 @@ frontend middle analysis_and_optimization - stan_math_backend) + stan_math_backend + stan_math_library) (preprocess (pps js_of_ocaml-ppx ppx_jane)) (modes js)) diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index 46ee6ef316..af795a9416 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -6,14 +6,6 @@ open Analysis_and_optimization open Stan_math_backend open Js_of_ocaml -(* Initialize functors with Stan Math C++ signatures *) -module Typechecker = Typechecking.Make (Stan_math_library) -module Deprecations = Deprecation_analysis.Make (Stan_math_library) -module Canonicalizer = Canonicalize.Make (Deprecations) -module ModelInfo = Info.Make (Stan_math_library) -module Ast2Mir = Ast_to_Mir.Make (Stan_math_library) -module Optimizer = Optimize.Make (Stan_math_library) - let version = "%%NAME%% %%VERSION%%" let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t) @@ -29,8 +21,8 @@ let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t let stan2cpp model_name model_string is_flag_set flag_val = Common.Gensym.reset_danger_use_cautiously () ; - Typechecking.model_name := model_name ; - Typechecking.check_that_all_functions_have_definition := + Typechecker.model_name := model_name ; + Typechecker.check_that_all_functions_have_definition := not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ; Transform_Mir.use_opencl := is_flag_set "use-opencl" ; Stan_math_code_gen.standalone_functions := @@ -53,7 +45,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = >>| fun (typed_ast, type_warnings) -> let warnings = parser_warnings @ type_warnings in if is_flag_set "info" then - r.return (Result.Ok (ModelInfo.info typed_ast), warnings, []) ; + r.return (Result.Ok (Info.info typed_ast), warnings, []) ; let canonicalizer_settings = if is_flag_set "print-canonical" then Canonicalize.all else @@ -75,7 +67,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = flag_val "max-line-length" |> Option.map ~f:int_of_string |> Option.value ~default:78 in - let mir = Ast2Mir.trans_prog model_name typed_ast in + let mir = Ast_to_Mir.trans_prog model_name typed_ast in let tx_mir = Transform_Mir.trans_prog mir in if is_flag_set "auto-format" || is_flag_set "print-canonical" then r.return @@ -84,7 +76,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = ~bare_functions:(is_flag_set "functions-only") ~line_length ~inline_includes:canonicalizer_settings.inline_includes - (Canonicalizer.canonicalize_program typed_ast + (Canonicalize.canonicalize_program typed_ast canonicalizer_settings ) ) , warnings , [] ) ; @@ -100,7 +92,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = r.return ( Result.Ok (Debug_data_generation.print_data_prog - (Ast2Mir.gather_data typed_ast) ) + (Ast_to_Mir.gather_data typed_ast) ) , warnings , [] ) ; let opt_mir = @@ -110,7 +102,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = else if is_flag_set "Oexperimental" || is_flag_set "O" then Optimize.Oexperimental else Optimize.O0 in - Optimizer.optimization_suite + Optimize.optimization_suite ~settings:(Optimize.level_optimizations opt_lvl) tx_mir in if is_flag_set "debug-optimized-mir" then @@ -189,11 +181,11 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) = wrap_result ?printed_filename ~code result ~warnings let dump_stan_math_signatures () = - Js.string @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_sigs () + Js.string @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_sigs () let dump_stan_math_distributions () = Js.string - @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_distributions () + @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_distributions () let () = Js.export "dump_stan_math_signatures" dump_stan_math_signatures ; diff --git a/test/unit/Desugar_test.ml b/test/unit/Desugar_test.ml index 1bf13ae28b..525fbfe24b 100644 --- a/test/unit/Desugar_test.ml +++ b/test/unit/Desugar_test.ml @@ -1,9 +1,6 @@ open Core_kernel open Analysis_and_optimization -module Partial_evaluator = - Partial_evaluation.Make (Stan_math_backend.Stan_math_library) - let print_tdata Middle.Program.{prepare_data; _} = Fmt.(str "@[%a@]@," (list ~sep:cut Middle.Stmt.Located.pp) prepare_data) |> print_endline diff --git a/test/unit/Optimize.ml b/test/unit/Optimize.ml index d61c372f82..93c8c98bd3 100644 --- a/test/unit/Optimize.ml +++ b/test/unit/Optimize.ml @@ -2,11 +2,7 @@ open Core_kernel open Middle open Common open Analysis_and_optimization.Mir_utils - -module Optimizer = - Analysis_and_optimization.Optimize.Make (Stan_math_backend.Stan_math_library) - -open Optimizer +open Analysis_and_optimization.Optimize let reset_and_mir_of_string s = Gensym.reset_danger_use_cautiously () ; diff --git a/test/unit/Test_utils.ml b/test/unit/Test_utils.ml index 03838e7405..13cd4a4435 100644 --- a/test/unit/Test_utils.ml +++ b/test/unit/Test_utils.ml @@ -1,12 +1,6 @@ open Frontend open Core_kernel -module CppLibrary : Std_library_utils.Library = - Stan_math_backend.Stan_math_library - -module Typechecker = Typechecking.Make (CppLibrary) -module Ast2Mir = Ast_to_Mir.Make (CppLibrary) - let untyped_ast_of_string s = let res, warnings = Parse.parse_string Parser.Incremental.program s in Fmt.epr "%a" (Fmt.list ~sep:Fmt.nop Warnings.pp) warnings ; @@ -21,8 +15,8 @@ let typed_ast_of_string_exn s = |> Result.map_error ~f:Errors.to_string |> Result.ok_or_failwith |> fst -let mir_of_string s = typed_ast_of_string_exn s |> Ast2Mir.trans_prog "" +let mir_of_string s = typed_ast_of_string_exn s |> Ast_to_Mir.trans_prog "" let print_data_prog ast = Analysis_and_optimization.Debug_data_generation.print_data_prog - (Ast2Mir.gather_data ast) + (Ast_to_Mir.gather_data ast) diff --git a/test/unit/dune b/test/unit/dune index 37cffb6770..4fed56676c 100644 --- a/test/unit/dune +++ b/test/unit/dune @@ -4,6 +4,7 @@ core_kernel frontend stan_math_backend + stan_math_library middle analysis_and_optimization) (inline_tests) From 8ffd93eafd587acb3bf2e3dd452453203754aa41 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 3 May 2022 12:24:57 -0400 Subject: [PATCH 03/19] Cleanup --- src/analysis_and_optimization/Optimize.ml | 120 +++++++++--------- src/analysis_and_optimization/Optimize.mli | 52 ++++---- .../Partial_evaluator.mli | 4 - .../Pedantic_analysis.ml | 2 - src/frontend/Ast.ml | 4 - src/frontend/Ast_to_Mir.ml | 8 ++ src/frontend/Canonicalize.ml | 21 ++- src/frontend/Info.ml | 14 +- src/frontend/Library.mli | 5 +- src/frontend/Semantic_error.ml | 6 +- src/frontend/SignatureMismatch.ml | 2 +- src/frontend/Std_library_utils.ml | 8 +- src/frontend/Typechecker.mli | 28 ++-- src/frontend/Typechecker_utils.ml | 4 + src/stanc/stanc.ml | 2 +- src/stancjs/stancjs.ml | 2 +- test/unit/Optimize.ml | 2 +- 17 files changed, 139 insertions(+), 145 deletions(-) delete mode 100644 src/analysis_and_optimization/Partial_evaluator.mli diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index 1f3de62d1e..352d284021 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -6,66 +6,6 @@ open Common open Middle open Mir_utils -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -let settings_const b = - { function_inlining= b - ; static_loop_unrolling= b - ; one_step_loop_unrolling= b - ; list_collapsing= b - ; block_fixing= b - ; allow_uninitialized_decls= b - ; constant_propagation= b - ; expression_propagation= b - ; copy_propagation= b - ; dead_code_elimination= b - ; partial_evaluation= b - ; lazy_code_motion= b - ; optimize_ad_levels= b - ; preserve_stability= not b - ; optimize_soa= b } - -let all_optimizations : optimization_settings = settings_const true -let no_optimizations : optimization_settings = settings_const false - -type optimization_level = O0 | O1 | Oexperimental - -let level_optimizations (lvl : optimization_level) : optimization_settings = - match lvl with - | O0 -> no_optimizations - | O1 -> - { function_inlining= true - ; static_loop_unrolling= false - ; one_step_loop_unrolling= false - ; list_collapsing= true - ; block_fixing= true - ; constant_propagation= true - ; expression_propagation= false - ; copy_propagation= true - ; dead_code_elimination= true - ; partial_evaluation= true - ; lazy_code_motion= false - ; allow_uninitialized_decls= true - ; optimize_ad_levels= false - ; preserve_stability= false - ; optimize_soa= true } - | Oexperimental -> all_optimizations - (** Apply the transformation to each function body and to the rest of the program as one block. @@ -1262,6 +1202,66 @@ let optimize_soa (mir : Program.Typed.t) = in {mir with log_prob= transform' mir.log_prob} +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +let settings_const b = + { function_inlining= b + ; static_loop_unrolling= b + ; one_step_loop_unrolling= b + ; list_collapsing= b + ; block_fixing= b + ; allow_uninitialized_decls= b + ; constant_propagation= b + ; expression_propagation= b + ; copy_propagation= b + ; dead_code_elimination= b + ; partial_evaluation= b + ; lazy_code_motion= b + ; optimize_ad_levels= b + ; preserve_stability= not b + ; optimize_soa= b } + +let all_optimizations : optimization_settings = settings_const true +let no_optimizations : optimization_settings = settings_const false + +type optimization_level = O0 | O1 | Oexperimental + +let level_optimizations (lvl : optimization_level) : optimization_settings = + match lvl with + | O0 -> no_optimizations + | O1 -> + { function_inlining= true + ; static_loop_unrolling= false + ; one_step_loop_unrolling= false + ; list_collapsing= true + ; block_fixing= true + ; constant_propagation= true + ; expression_propagation= false + ; copy_propagation= true + ; dead_code_elimination= true + ; partial_evaluation= true + ; lazy_code_motion= false + ; allow_uninitialized_decls= true + ; optimize_ad_levels= false + ; preserve_stability= false + ; optimize_soa= true } + | Oexperimental -> all_optimizations + let optimization_suite ?(settings = all_optimizations) mir = let preserve_stability = settings.preserve_stability in let maybe_optimizations = diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index d5ca472a6d..524e5e4e71 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -2,32 +2,6 @@ open Middle -(** Interface for turning individual optimizations on/off. Useful for testing - and for top-level interface flags. *) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -val all_optimizations : optimization_settings -val no_optimizations : optimization_settings - -type optimization_level = O0 | O1 | Oexperimental - -val level_optimizations : optimization_level -> optimization_settings - val function_inlining : Program.Typed.t -> Program.Typed.t (** Inline all functions except for ones with forward declarations (e.g. recursive functions, mutually recursive functions, and @@ -89,6 +63,32 @@ val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t assigns to the full object, allow the object to be constructed but not uninitialized. *) +(** Interface for turning individual optimizations on/off. Useful for testing + and for top-level interface flags. *) +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +val all_optimizations : optimization_settings +val no_optimizations : optimization_settings + +type optimization_level = O0 | O1 | Oexperimental + +val level_optimizations : optimization_level -> optimization_settings + val optimization_suite : ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t (** Perform all optimizations in this module on the MIR in an appropriate order. *) diff --git a/src/analysis_and_optimization/Partial_evaluator.mli b/src/analysis_and_optimization/Partial_evaluator.mli deleted file mode 100644 index ce3b24fa58..0000000000 --- a/src/analysis_and_optimization/Partial_evaluator.mli +++ /dev/null @@ -1,4 +0,0 @@ -open Middle - -val try_eval_expr : Expr.Typed.t -> Expr.Typed.t -val eval_prog : Program.Typed.t -> Program.Typed.t diff --git a/src/analysis_and_optimization/Pedantic_analysis.ml b/src/analysis_and_optimization/Pedantic_analysis.ml index 96b3a48338..19b7777f40 100644 --- a/src/analysis_and_optimization/Pedantic_analysis.ml +++ b/src/analysis_and_optimization/Pedantic_analysis.ml @@ -487,8 +487,6 @@ let settings_constant_prop = ; copy_propagation= true ; partial_evaluation= true } -(** Pedantic mode is only really valid for the Stan Math backend *) - (* Collect all pedantic mode warnings, sorted, to stderr *) let warn_pedantic (mir_unopt : Program.Typed.t) = (* Some warnings will be stronger when constants are propagated *) diff --git a/src/frontend/Ast.ml b/src/frontend/Ast.ml index c47e9a8daa..58502a85a5 100644 --- a/src/frontend/Ast.ml +++ b/src/frontend/Ast.ml @@ -80,10 +80,6 @@ let mk_untyped_expression ~expr ~loc = {expr; emeta= {loc}} let mk_typed_expression ~expr ~loc ~type_ ~ad_level = {expr; emeta= {loc; type_; ad_level}} -let mk_fun_app ~is_cond_dist (kind, id, arguments) = - if is_cond_dist then CondDistApp (kind, id, arguments) - else FunApp (kind, id, arguments) - let expr_loc_lub exprs = match List.map ~f:(fun e -> e.emeta.loc) exprs with | [] -> Location_span.empty diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 20d59bf527..4090b4cf45 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -21,6 +21,14 @@ let drop_leading_zeros s = let format_number s = s |> without_underscores |> drop_leading_zeros +let%expect_test "format_number0" = + format_number "0_000." |> print_endline ; + [%expect "0."] + +let%expect_test "format_number1" = + format_number ".123_456" |> print_endline ; + [%expect ".123456"] + let rec op_to_funapp op args type_ = let loc = Ast.expr_loc_lub args in let adlevel = Ast.expr_ad_lub args in diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index f079db7aea..0d451a7f51 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -1,6 +1,6 @@ open Core_kernel open Ast -module Deprecation = Deprecation_analysis +open Deprecation_analysis type canonicalizer_settings = {deprecations: bool; parentheses: bool; braces: bool; inline_includes: bool} @@ -20,8 +20,7 @@ let rec repair_syntax_stmt user_dists {stmt; smeta} = { stmt= Tilde { arg - ; distribution= - {name= Deprecation.without_suffix user_dists name; id_loc} + ; distribution= {name= without_suffix user_dists name; id_loc} ; args ; truncation } ; smeta } @@ -47,10 +46,10 @@ let rec replace_deprecated_expr (replace_deprecated_expr deprecated_userdefined {expr= TernaryIf ({expr= Paren c; emeta= c.emeta}, t, e); emeta} ) | FunApp (StanLib suffix, {name; id_loc}, e) -> - if Deprecation.is_deprecated_distribution name then + if is_deprecated_distribution name then CondDistApp ( StanLib suffix - , {name= Deprecation.rename_deprecated_distribution name; id_loc} + , {name= rename_deprecated_distribution name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) else if String.is_suffix name ~suffix:"_cdf" then CondDistApp @@ -60,14 +59,14 @@ let rec replace_deprecated_expr else FunApp ( StanLib suffix - , {name= Deprecation.rename_deprecated_function name; id_loc} + , {name= rename_deprecated_function name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( match String.Map.find deprecated_userdefined name with | Some type_ -> CondDistApp ( UserDefined suffix - , {name= Deprecation.update_suffix name type_; id_loc} + , {name= update_suffix name type_; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) | None -> if String.is_suffix name ~suffix:"_cdf" then @@ -126,7 +125,7 @@ let rec replace_deprecated_stmt | FunDef {returntype; funname= {name; id_loc}; arguments; body} -> let newname = match String.Map.find deprecated_userdefined name with - | Some type_ -> Deprecation.update_suffix name type_ + | Some type_ -> update_suffix name type_ | None -> name in FunDef { returntype @@ -227,8 +226,7 @@ let repair_syntax program settings = if settings.deprecations then program |> map_program - (repair_syntax_stmt - (Deprecation.userdef_distributions program.functionblock) ) + (repair_syntax_stmt (userdef_distributions program.functionblock)) else program let canonicalize_program program settings : typed_program = @@ -236,8 +234,7 @@ let canonicalize_program program settings : typed_program = if settings.deprecations then program |> map_program - (replace_deprecated_stmt - (Deprecation.collect_userdef_distributions program) ) + (replace_deprecated_stmt (collect_userdef_distributions program)) else program in let program = if settings.parentheses then program |> map_program parens_stmt else program diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index e82c2e545b..b3868434ca 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -45,13 +45,6 @@ let rec get_function_calls_expr (funs, distrs) expr = | _ -> (funs, distrs) in fold_expression get_function_calls_expr (fun acc _ -> acc) acc expr.expr -let includes_json () = - `Assoc - [ ( "included_files" - , `List - ( List.rev !Preprocessor.included_files - |> List.map ~f:(fun str -> `String str) ) ) ] - let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = let acc = match stmt.stmt with @@ -94,6 +87,13 @@ let function_calls_json p = `List (Set.to_list s |> List.map ~f:(fun str -> `String str)) in `Assoc [("functions", set_to_List funs); ("distributions", set_to_List distrs)] +let includes_json () = + `Assoc + [ ( "included_files" + , `List + ( List.rev !Preprocessor.included_files + |> List.map ~f:(fun str -> `String str) ) ) ] + let info_json ast = List.fold ~f:Util.combine ~init:(`Assoc []) [ block_info_json "inputs" ast.datablock diff --git a/src/frontend/Library.mli b/src/frontend/Library.mli index 5f1c399ea3..3841a7fa9e 100644 --- a/src/frontend/Library.mli +++ b/src/frontend/Library.mli @@ -1,5 +1,6 @@ -(** This module is used as a parameter for many functors which - rely on information about a backend-specific Stan library. *) +(** This is a {{:https://dune.readthedocs.io/en/latest/variants.html}virtual module} + which is filled in at link time with a module specifying a backend-specific + a backend-specific Stan library. *) open Middle open Core_kernel diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index e56f5fdb19..bc7c7d939d 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -223,20 +223,20 @@ module TypeError = struct "Ill-typed arguments supplied to infix operator %a. Available \ signatures: @[%a@.@]@[Instead supplied arguments of \ incompatible type: %a, %a.@]" - Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp lt + Operator.pp op Std_library_utils.pp_signatures sigs UnsizedType.pp lt UnsizedType.pp rt | IllTypedPrefixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to prefix operator %a. Available \ signatures: @[%a@.@]@[Instead supplied argument of \ incompatible type: %a.@]" - Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut + Operator.pp op Std_library_utils.pp_signatures sigs UnsizedType.pp ut | IllTypedPostfixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to postfix operator %a. Available \ signatures: @[%a@.@]@[Instead supplied argument of \ incompatible type: %a.@]" - Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut + Operator.pp op Std_library_utils.pp_signatures sigs UnsizedType.pp ut end module IdentifierError = struct diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index 5d4d4e08bf..cd8dc077ee 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -407,7 +407,7 @@ let pp_assignmentoperator_sigs ppf (lt, errors) = | errors, _ -> Some (errors, true) in let pp_sigs ppf (signatures, omitted) = Fmt.pf ppf "@[%a%a@]" - (Fmt.list ~sep:Fmt.cut Std_library_utils.pp_math_sig) + (Fmt.list ~sep:Fmt.cut Std_library_utils.pp_signature) signatures (if omitted then Fmt.pf else Fmt.nop) "@ (Additional signatures omitted)" in diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index 6802fd5db5..ec2296da71 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -14,13 +14,11 @@ type deprecation_info = ; canonicalize_away: bool } [@@deriving sexp] -let pp_math_sig ppf ((rt, args, mem_pattern) : signature) = +let pp_signature ppf ((rt, args, mem_pattern) : signature) = UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) -let pp_math_sigs ppf (sigs : signature list) = - (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs - -let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs +let pp_signatures ppf (sigs : signature list) = + (Fmt.list ~sep:Fmt.cut pp_signature) ppf sigs let dist_name_suffix (check : string -> bool) udf_names name = let is_udf_name s = diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli index 288a3c9d51..54cadc5bcc 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecker.mli @@ -11,33 +11,22 @@ A type environment {!val:Environment.t} is used to hold variables and functions, including Stan math functions. This is a functional map, meaning it is handled immutably. - - This module is parameterized over a Standard Library of function signatures, See - [Std_library_utils.Library]. For the main compiler, this is - [Stan_math_backend.Stan_math_library] *) open Ast -val model_name : string ref -(** A reference to hold the model name. Relevant for checking variable - clashes and used in code generation. *) - -val check_that_all_functions_have_definition : bool ref -(** A switch to determine whether we check that all functions have a definition *) - val check_program_exn : untyped_program -> typed_program * Warnings.t list (** - Type check a full Stan program. - Can raise [Errors.SemanticError] + Type check a full Stan program. + Can raise [Errors.SemanticError] *) val check_program : untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result (** - The safe version of [check_program_exn]. This catches - all [Errors.SemanticError] exceptions and converts them - into a [Result.t] + The safe version of [check_program_exn]. This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] *) val operator_return_type : @@ -49,3 +38,10 @@ val library_function_return_type : string -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> Middle.UnsizedType.returntype option + +val model_name : string ref +(** A reference to hold the model name. Relevant for checking variable + clashes and used in code generation. *) + +val check_that_all_functions_have_definition : bool ref +(** A switch to determine whether we check that all functions have a definition *) diff --git a/src/frontend/Typechecker_utils.ml b/src/frontend/Typechecker_utils.ml index 3a92d9af29..b26269f2c4 100644 --- a/src/frontend/Typechecker_utils.ml +++ b/src/frontend/Typechecker_utils.ml @@ -4,6 +4,10 @@ open Middle open Ast module Env = Environment +let mk_fun_app ~is_cond_dist (kind, id, arguments) = + if is_cond_dist then CondDistApp (kind, id, arguments) + else FunApp (kind, id, arguments) + let calculate_autodifftype current_block origin ut = match origin with | Env.(Param | TParam | Model | Functions) diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 529c7fe6da..8d4219702e 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -3,9 +3,9 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Middle open Analysis_and_optimization open Stan_math_backend +open Middle (** The main program. *) let version = "%%NAME%%3 %%VERSION%%" diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index af795a9416..49d7d9b9f2 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -1,9 +1,9 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Middle open Analysis_and_optimization open Stan_math_backend +open Middle open Js_of_ocaml let version = "%%NAME%% %%VERSION%%" diff --git a/test/unit/Optimize.ml b/test/unit/Optimize.ml index 93c8c98bd3..67154bbbe0 100644 --- a/test/unit/Optimize.ml +++ b/test/unit/Optimize.ml @@ -1,8 +1,8 @@ open Core_kernel +open Analysis_and_optimization.Optimize open Middle open Common open Analysis_and_optimization.Mir_utils -open Analysis_and_optimization.Optimize let reset_and_mir_of_string s = Gensym.reset_danger_use_cautiously () ; From 95412b4163a0d925c084887440d12940c802dabb Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 3 May 2022 12:47:26 -0400 Subject: [PATCH 04/19] Update docs --- docs/core_ideas.mld | 28 ++++--------------- docs/exposing_new_functions.mld | 2 +- .../Stan_math_signatures.mli | 7 +++++ .../stan_math_library/Library.ml | 4 +++ 4 files changed, 18 insertions(+), 23 deletions(-) diff --git a/docs/core_ideas.mld b/docs/core_ideas.mld index b05bfec243..2e929ca99d 100644 --- a/docs/core_ideas.mld +++ b/docs/core_ideas.mld @@ -65,30 +65,14 @@ This takes some getting used to, and also can lead to some unhelpful type signat VSCode, because abbreviations are not always used in hover-over text. For example, [Expr.Typed.t], the MIR's typed expression type, actually has a signature of [Expr.Typed.Meta.t Expr.Fixed.t]. -{1 The [Library] interface and functors} +{1 The [Library] virtual module} -Many modules of stanc are modeled as OCaml {{:https://ocaml.org/learn/tutorials/functors.html}functors}, -which take in another module as input and produce a module as output. For the most part, -these functors expect an instance of the [Library] interface defined in -[src/frontend/Std_library_utils.ml]. +The [Frontend] library is a {{:https://dune.readthedocs.io/en/latest/variants.html}virtual library} +where the module [Library] is unimplemented. This allows the rest of the library to operate without +making backend-specific assumptions about any one library. -This module primarily contains signatures for the Stan standard library. For most users, -you can assume this will be filled in with [src/stan_math_backend/Stan_math_library.ml], -the object representing the {{:https://github.com/stan-dev/math}stan-dev/math} C++ library. - -Usages of these functors are rather simple, e.g. in the core stanc driver the line - -{[ -module Typechecker = Typechecking.Make (Stan_math_library) -]} - -defines a module [Typechecker] by supplying the functor [Typechecking.Make] with -the Stan C++ library module. After this, [Typechecker.check_program] will typecheck -an AST against those specific functions. - -As noted in the above tutorial link, the syntax of functors is often the hardest part -of using and understanding them. The functors which accept [Library] are all relatively -simple, and should serve as good examples to beginners with the concept. +This must be supplied when the executable is built in the [dune] file. +For Stanc, we supply the [stan_math_library] instantiation defined in [src/stan_math_backend/stan_math_library]. {1 The [Fmt] library and pretty-printing} diff --git a/docs/exposing_new_functions.mld b/docs/exposing_new_functions.mld index 087d5bffa6..aadfbb50e5 100644 --- a/docs/exposing_new_functions.mld +++ b/docs/exposing_new_functions.mld @@ -7,7 +7,7 @@ For a function to be built into Stan, it has to be included in the Stan Math library and its signature has to be exposed to the compiler. -To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_library.ml]. +To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_signatures.ml]. The compiler uses the signatures defined there to do type checking. diff --git a/src/stan_math_backend/Stan_math_signatures.mli b/src/stan_math_backend/Stan_math_signatures.mli index a2effe455f..a601be69f4 100644 --- a/src/stan_math_backend/Stan_math_signatures.mli +++ b/src/stan_math_backend/Stan_math_signatures.mli @@ -1,6 +1,13 @@ (** This module stores a table of all signatures from the Stan math C++ library which are exposed to Stan, and some helper functions for dealing with those signatures. + + It is used by + {!modules: + + Stan_math_library + } + to create a [Library] instance which can be consumed by the frontend of the compiler. *) open Middle diff --git a/src/stan_math_backend/stan_math_library/Library.ml b/src/stan_math_backend/stan_math_library/Library.ml index 625e0cef23..a70d151e98 100644 --- a/src/stan_math_backend/stan_math_library/Library.ml +++ b/src/stan_math_backend/stan_math_library/Library.ml @@ -1,3 +1,7 @@ +(** This is the {e implementation} of the Library virtual module + for the Stan Math C++ backend + *) + open Core_kernel open Std_library_utils From 1a03bf70fa850a77f31daa0949ceae30b125e102 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 4 May 2022 09:30:47 -0400 Subject: [PATCH 05/19] Cleanup --- src/frontend/Library.mli | 2 +- src/frontend/Typechecker.ml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/frontend/Library.mli b/src/frontend/Library.mli index 3841a7fa9e..38ca7dd15f 100644 --- a/src/frontend/Library.mli +++ b/src/frontend/Library.mli @@ -1,6 +1,6 @@ (** This is a {{:https://dune.readthedocs.io/en/latest/variants.html}virtual module} which is filled in at link time with a module specifying a backend-specific - a backend-specific Stan library. *) + Stan library. *) open Middle open Core_kernel diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index f2ba72df39..2b6fc4ee49 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -70,6 +70,9 @@ let reserved_keywords = let std_library_tenv : Env.t = Env.make_from_library Library.function_signatures +let matching_library_function = + SignatureMismatch.matching_function std_library_tenv + let verify_identifier id : unit = if id.name = !model_name then Semantic_error.ident_is_model_name id.id_loc id.name |> error @@ -179,9 +182,7 @@ let library_function_return_type name arg_tys = match name with | x when Library.is_variadic_function_name x -> Library.variadic_function_returntype x - | _ -> - SignatureMismatch.matching_function std_library_tenv name arg_tys - |> match_to_rt_option + | _ -> matching_library_function name arg_tys |> match_to_rt_option let operator_return_type op arg_tys = match (op, arg_tys) with @@ -191,7 +192,7 @@ let operator_return_type op arg_tys = | _ -> Library.operator_to_function_names op |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_function std_library_tenv name arg_tys + matching_library_function name arg_tys |> function | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) | _ -> None ) @@ -200,8 +201,7 @@ let operator_return_type op arg_tys = let assignmentoperator_return_type assop arg_tys = ( match assop with | Operator.Divide -> - SignatureMismatch.matching_function std_library_tenv "divide" arg_tys - |> match_to_rt_option + matching_library_function "divide" arg_tys |> match_to_rt_option | Plus | Minus | Times | EltTimes | EltDivide -> operator_return_type assop arg_tys |> Option.map ~f:fst | _ -> None ) From 6af15ae59baa50420fd6fe59c0ec289a606b4bed Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 4 May 2022 15:10:55 -0400 Subject: [PATCH 06/19] Add docstrings to Library interface --- src/frontend/Library.mli | 49 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/src/frontend/Library.mli b/src/frontend/Library.mli index 38ca7dd15f..a020227672 100644 --- a/src/frontend/Library.mli +++ b/src/frontend/Library.mli @@ -7,21 +7,49 @@ open Core_kernel open Std_library_utils val function_signatures : (string, signature list) Hashtbl.t -(** Mapping from names to signature(s) of functions *) +(** Mapping from names to signature(s) of functions + Used in [Environment] to produce the base type environment +*) val distribution_families : string list +(** A list of the families of distribution are available, + e.g. "normal", "bernoulli". Used to produce better + errors in typechecking *) val is_stdlib_function_name : string -> bool -(** Equivalent to [Hashtbl.mem function_signatures s]*) +(** Equivalent to [Hashtbl.mem function_signatures s] *) val get_signatures : string -> signature list (** Equivalent to [Hashtbl.find_multi function_signatures s]*) val get_operator_signatures : Operator.t -> signature list +(* Much like get_signatures, this returns the available + signatures for a given operator, which may represent + multiple internal functions *) + val get_assignment_operator_signatures : Operator.t -> signature list +(* Get the signatures allowed when this operator is used as an infix + assignment, e.g. [a += b]. By convention these have returntype [Void] +*) + val is_not_overloadable : string -> bool +(** This controls which functions the typechecker will not allow + to be overloaded with additional signatures from user defined functions. + In the Stan C++ library, this is equal to [is_variadic_function_name], + but it could be more or less broad, to the limit of disallowing all overloading + by setting it equal to [is_stdlib_function_name] +*) + val is_variadic_function_name : string -> bool +(** Variadic functions are {b not} included in the normal signatures + above, but instead recognized by this function and special-cased during + typechecking +*) + val variadic_function_returntype : string -> UnsizedType.returntype option +(** We currently have the restriction that variadic functions must have the same + return type regardless of their argument types. This function should return that type, + or None if it is given a name that is not a variadic function. *) val check_variadic_fn : Ast.identifier @@ -36,6 +64,23 @@ val check_variadic_fn : bespoke per-function. *) val operator_to_function_names : Operator.t -> string list +(** This should return the name of function(s) (in the above signature table) + which handle the given operator. For example, in Stan's C++ backend, + [Operator.Plus] is represented by the functions [["add"]] +*) + val string_operator_to_function_name : string -> string +(** Serves the same role as [operator_to_function_names], + but the input is the serialized string version used later + in the compiler, e.g. [Operator.PMinus] is ["PMinus__"] + *) + val deprecated_distributions : deprecation_info String.Map.t +(** This should map any deprecated distribution functions, e.g. "normal_log" + to information about their replacements and removal version. +*) + val deprecated_functions : deprecation_info String.Map.t +(** This should map any deprecated distribution functions, e.g. "cov_exp_quad" + to information about their replacements and removal version. +*) From 140f969b88daaa6cd98e2fd822157c61ee9e4fb5 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 5 May 2022 10:18:11 -0400 Subject: [PATCH 07/19] Move if_else deprecation --- src/frontend/Deprecation_analysis.ml | 8 -- .../stan_math_library/Library.ml | 3 +- test/integration/good/warning/pretty.expected | 104 +++++++++--------- 3 files changed, 54 insertions(+), 61 deletions(-) diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index e4a23df9e0..a448bef63b 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -90,14 +90,6 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list) , "Use of the `abs` function with real-valued arguments is \ deprecated; use function `fabs` instead." ) ] ) e - | FunApp (StanLib FnPlain, {name= "if_else"; _}, l) -> - acc - @ [ ( emeta.loc - , "The function `if_else` is deprecated and will be removed in Stan \ - 2.32.0. Use the conditional operator (x ? y : z) instead; this \ - can be automatically changed using the canonicalize flag for \ - stanc" ) ] - @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> let w = match Map.find stan_lib_deprecations name with diff --git a/src/stan_math_backend/stan_math_library/Library.ml b/src/stan_math_backend/stan_math_library/Library.ml index a70d151e98..eb058a1391 100644 --- a/src/stan_math_backend/stan_math_library/Library.ml +++ b/src/stan_math_backend/stan_math_library/Library.ml @@ -45,7 +45,8 @@ let deprecated_functions = ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") ; ("integrate_ode_bdf", ode "ode_bdf") - ; ("integrate_ode_adams", ode "ode_adams") ] + ; ("integrate_ode_adams", ode "ode_adams") + ; ("if_else", std "the conditional operator (x ? y : z)") ] (** This function is responsible for typechecking varadic function calls. It needs to live in the Library since this is usually diff --git a/test/integration/good/warning/pretty.expected b/test/integration/good/warning/pretty.expected index 84a4e467df..82d695eb0c 100644 --- a/test/integration/good/warning/pretty.expected +++ b/test/integration/good/warning/pretty.expected @@ -1712,58 +1712,58 @@ model { y_p ~ normal(0, 1); } -Warning in 'if_else.stan', line 9, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 10, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 11, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 12, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 21, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 22, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 23, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 24, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 26, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 27, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 28, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 29, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 30, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc +Warning in 'if_else.stan', line 9, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 10, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 11, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 12, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 21, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 22, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 23, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 24, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 26, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 27, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 28, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 29, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 30, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc $ ../../../../../install/default/bin/stanc --auto-format increment_log_prob.stan transformed data { int n; From fe7a140c8f34acc2f0e474b4d07311ec93be798f Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 5 May 2022 12:14:32 -0400 Subject: [PATCH 08/19] Remove unused dune stanza --- src/stan_math_backend/stan_math_library/dune | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stan_math_backend/stan_math_library/dune b/src/stan_math_backend/stan_math_library/dune index 549fdd4069..df25df84af 100644 --- a/src/stan_math_backend/stan_math_library/dune +++ b/src/stan_math_backend/stan_math_library/dune @@ -4,6 +4,5 @@ (libraries core_kernel middle stan_math_backend) (implements frontend) (private_modules variadic_typechecking) - (inline_tests) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.show))) From 4c143c2f2c6f800f662cd935e9aba06f96db2ae7 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 10 May 2022 09:14:44 -0400 Subject: [PATCH 09/19] Simplify include path splitting --- src/stanc/stanc.ml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 13532b9197..c73d32f3d7 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -191,16 +191,12 @@ let options = ) ; ( "--include-paths" , Arg.String - (fun str -> - Preprocessor.include_paths := String.split_on_chars ~on:[','] str ) + (fun str -> Preprocessor.include_paths := String.split ~on:',' str) , " Takes a comma-separated list of directories that may contain a file \ in an #include directive (default = \"\")" ) ; ( "--include_paths" , Arg.String - (fun str -> - Preprocessor.include_paths := - !Preprocessor.include_paths @ String.split_on_chars ~on:[','] str - ) + (fun str -> Preprocessor.include_paths := String.split ~on:',' str) , " Deprecated. Same as --include-paths. Will be removed in Stan 2.32.0" ) ; ( "--use-opencl" From e6d7bf6ecb6e8ab879be66ff9317ceaba33017ee Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 12 May 2022 09:15:50 -0400 Subject: [PATCH 10/19] Update dune project to prevent opam file from changing --- dune-project | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dune-project b/dune-project index 1afc3dfd55..b623f5f940 100644 --- a/dune-project +++ b/dune-project @@ -13,8 +13,8 @@ (ppx_deriving (= 5.2.1)) (fmt (= 0.8.8)) (yojson (= 1.7.0)) - (ocamlformat (and :with-test (= 0.19))) - (merlin (and :with-test (= 4.3.1))) + (ocamlformat (and :with-test (= 0.19.0))) + (merlin :with-test) (utop :with-test) (ocp-indent :with-test) (patdiff :with-test) From ec5ab3b17f7ebc31480f5398054e348fd748e1f6 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 22 Sep 2022 17:42:40 -0400 Subject: [PATCH 11/19] Empty commit From 56efb2da5c0e9bbebe76679077bd2b5a3efe0c0a Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Oct 2022 12:00:42 -0400 Subject: [PATCH 12/19] Empty commit From 37aa03e586cc115a3e54115ab97d5da6ae922e08 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Oct 2022 12:07:47 -0400 Subject: [PATCH 13/19] Force rename try? --- src/stan_math_backend/stan_math_library/Library.ml | 2 +- .../{Special_typechecking.ml => Stan_math_extras.ml} | 0 src/stan_math_backend/stan_math_library/dune | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/stan_math_backend/stan_math_library/{Special_typechecking.ml => Stan_math_extras.ml} (100%) diff --git a/src/stan_math_backend/stan_math_library/Library.ml b/src/stan_math_backend/stan_math_library/Library.ml index 6990c3f464..7933024d6c 100644 --- a/src/stan_math_backend/stan_math_library/Library.ml +++ b/src/stan_math_backend/stan_math_library/Library.ml @@ -8,7 +8,7 @@ open Std_library_utils (** Many of the required functions are exposed in the backend specific file, so we include it *) include Stan_math_backend.Stan_math_signatures -include Special_typechecking +include Stan_math_extras let deprecated_distributions = List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> diff --git a/src/stan_math_backend/stan_math_library/Special_typechecking.ml b/src/stan_math_backend/stan_math_library/Stan_math_extras.ml similarity index 100% rename from src/stan_math_backend/stan_math_library/Special_typechecking.ml rename to src/stan_math_backend/stan_math_library/Stan_math_extras.ml diff --git a/src/stan_math_backend/stan_math_library/dune b/src/stan_math_backend/stan_math_library/dune index 3d3deb692b..c7aaa7abe0 100644 --- a/src/stan_math_backend/stan_math_library/dune +++ b/src/stan_math_backend/stan_math_library/dune @@ -3,6 +3,6 @@ (public_name stanc.stan_math_library) (libraries core_kernel middle stan_math_backend) (implements frontend) - (private_modules special_typechecking) + (private_modules stan_math_extras) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.show))) From 72b64de8664766453da0a7ccb9c3093d9e745152 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Oct 2022 12:14:25 -0400 Subject: [PATCH 14/19] Try to force different stash --- Jenkinsfile | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 471ee38336..fb3df5eeb4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -113,7 +113,7 @@ pipeline { dune subst """ - stash 'Stanc3Setup' + stash 'Stanc3Setup-debug' def stanMathSigs = ['test/integration/signatures/stan_math_signatures.t'].join(" ") skipExpressionTests = utils.verifyChanges(stanMathSigs, "master") @@ -149,7 +149,7 @@ pipeline { } } steps { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runShell(""" eval \$(opam env) dune build @install @@ -176,7 +176,7 @@ pipeline { } } steps { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh """ eval \$(opam env) make format || @@ -209,7 +209,7 @@ pipeline { } } steps { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runShell(""" eval \$(opam env) dune runtest @@ -225,7 +225,7 @@ pipeline { } } steps { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runShell(""" eval \$(opam env) dune build @runjstest @@ -253,7 +253,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/compile-tests-good"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" script { runPerformanceTests("../test/integration/good", params.stanc_flags) } @@ -286,7 +286,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-tests-example"){ script { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runPerformanceTests("example-models", params.stanc_flags) } @@ -322,7 +322,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/compile-good-O1"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" script { runPerformanceTests("../test/integration/good", "--O1") } @@ -360,7 +360,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-example-O1"){ script { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runPerformanceTests("example-models", "--O1") } @@ -397,7 +397,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-end-to-end"){ script { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" unstash 'ubuntu-exe' sh """ git clone --recursive --depth 50 https://github.com/stan-dev/performance-tests-cmdstan @@ -467,7 +467,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-end-to-end-O=1"){ script { - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" unstash 'ubuntu-exe' sh """ git clone --recursive --depth 50 https://github.com/stan-dev/performance-tests-cmdstan @@ -530,7 +530,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/compile-expressions"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" unstash 'ubuntu-exe' script { sh """ @@ -570,7 +570,7 @@ pipeline { agent { label 'osx' } steps { dir("${env.WORKSPACE}/osx"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" withEnv(['SDKROOT=/Library/Developer/CommandLineTools/SDKs/MacOSX10.11.sdk', 'MACOSX_DEPLOYMENT_TARGET=10.11']) { runShell(""" export PATH=/Users/jenkins/brew/bin:\$PATH @@ -604,7 +604,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/stancjs"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runShell(""" eval \$(opam env) dune build --root=. --profile release src/stancjs @@ -634,7 +634,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runShell(""" eval \$(opam env) dune build @install --profile static --root=. @@ -664,7 +664,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-mips64el"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh "bash -x scripts/build_multiarch_stanc3.sh mips64el" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-mips64el-stanc" @@ -693,7 +693,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-ppc64el"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh "bash -x scripts/build_multiarch_stanc3.sh ppc64el" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-ppc64el-stanc" stash name:'linux-ppc64el-exe', includes:'bin/*' @@ -720,7 +720,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-s390x"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh "bash -x scripts/build_multiarch_stanc3.sh s390x" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-s390x-stanc" stash name:'linux-s390x-exe', includes:'bin/*' @@ -747,7 +747,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-arm64"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh "bash -x scripts/build_multiarch_stanc3.sh arm64" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-arm64-stanc" stash name:'linux-arm64-exe', includes:'bin/*' @@ -774,7 +774,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-armhf"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh "bash -x scripts/build_multiarch_stanc3.sh armhf" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-armhf-stanc" stash name:'linux-armhf-exe', includes:'bin/*' @@ -801,7 +801,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-armel"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" sh "bash -x scripts/build_multiarch_stanc3.sh armel" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-armel-stanc" stash name:'linux-armel-exe', includes:'bin/*' @@ -828,7 +828,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/windows"){ - unstash "Stanc3Setup" + unstash "Stanc3Setup-debug" runShell(""" eval \$(opam env) dune build -x windows --root=. From 438320efc978638105c1a2f2c6d0ad5b915156ef Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Oct 2022 12:20:41 -0400 Subject: [PATCH 15/19] More debug --- Jenkinsfile | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index fb3df5eeb4..1b47da8fdf 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -113,7 +113,7 @@ pipeline { dune subst """ - stash 'Stanc3Setup-debug' + stash 'Stanc3Setup' def stanMathSigs = ['test/integration/signatures/stan_math_signatures.t'].join(" ") skipExpressionTests = utils.verifyChanges(stanMathSigs, "master") @@ -149,7 +149,7 @@ pipeline { } } steps { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runShell(""" eval \$(opam env) dune build @install @@ -176,7 +176,7 @@ pipeline { } } steps { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh """ eval \$(opam env) make format || @@ -209,7 +209,7 @@ pipeline { } } steps { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runShell(""" eval \$(opam env) dune runtest @@ -225,8 +225,9 @@ pipeline { } } steps { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runShell(""" + rm src/stan_math_backend/stan_math_library/*_typechecking.ml eval \$(opam env) dune build @runjstest """) @@ -253,7 +254,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/compile-tests-good"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" script { runPerformanceTests("../test/integration/good", params.stanc_flags) } @@ -286,7 +287,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-tests-example"){ script { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runPerformanceTests("example-models", params.stanc_flags) } @@ -322,7 +323,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/compile-good-O1"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" script { runPerformanceTests("../test/integration/good", "--O1") } @@ -360,7 +361,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-example-O1"){ script { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runPerformanceTests("example-models", "--O1") } @@ -397,7 +398,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-end-to-end"){ script { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" unstash 'ubuntu-exe' sh """ git clone --recursive --depth 50 https://github.com/stan-dev/performance-tests-cmdstan @@ -467,7 +468,7 @@ pipeline { steps { dir("${env.WORKSPACE}/compile-end-to-end-O=1"){ script { - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" unstash 'ubuntu-exe' sh """ git clone --recursive --depth 50 https://github.com/stan-dev/performance-tests-cmdstan @@ -530,7 +531,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/compile-expressions"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" unstash 'ubuntu-exe' script { sh """ @@ -570,7 +571,7 @@ pipeline { agent { label 'osx' } steps { dir("${env.WORKSPACE}/osx"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" withEnv(['SDKROOT=/Library/Developer/CommandLineTools/SDKs/MacOSX10.11.sdk', 'MACOSX_DEPLOYMENT_TARGET=10.11']) { runShell(""" export PATH=/Users/jenkins/brew/bin:\$PATH @@ -604,7 +605,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/stancjs"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runShell(""" eval \$(opam env) dune build --root=. --profile release src/stancjs @@ -634,7 +635,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runShell(""" eval \$(opam env) dune build @install --profile static --root=. @@ -664,7 +665,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-mips64el"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh "bash -x scripts/build_multiarch_stanc3.sh mips64el" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-mips64el-stanc" @@ -693,7 +694,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-ppc64el"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh "bash -x scripts/build_multiarch_stanc3.sh ppc64el" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-ppc64el-stanc" stash name:'linux-ppc64el-exe', includes:'bin/*' @@ -720,7 +721,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-s390x"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh "bash -x scripts/build_multiarch_stanc3.sh s390x" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-s390x-stanc" stash name:'linux-s390x-exe', includes:'bin/*' @@ -747,7 +748,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-arm64"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh "bash -x scripts/build_multiarch_stanc3.sh arm64" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-arm64-stanc" stash name:'linux-arm64-exe', includes:'bin/*' @@ -774,7 +775,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-armhf"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh "bash -x scripts/build_multiarch_stanc3.sh armhf" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-armhf-stanc" stash name:'linux-armhf-exe', includes:'bin/*' @@ -801,7 +802,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/linux-armel"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" sh "bash -x scripts/build_multiarch_stanc3.sh armel" sh "mkdir -p bin && mv `find _build -name stanc.exe` bin/linux-armel-stanc" stash name:'linux-armel-exe', includes:'bin/*' @@ -828,7 +829,7 @@ pipeline { } steps { dir("${env.WORKSPACE}/windows"){ - unstash "Stanc3Setup-debug" + unstash "Stanc3Setup" runShell(""" eval \$(opam env) dune build -x windows --root=. From 58b3e030c11f85512e801b9fc94800df8240da65 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Oct 2022 12:24:59 -0400 Subject: [PATCH 16/19] Delete files after ocaml tests --- Jenkinsfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 1b47da8fdf..b0cf8f8cce 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -215,6 +215,7 @@ pipeline { dune runtest """) } + post { always { runShell("rm -rf ./*") }} } stage("stancjs tests") { agent { @@ -227,11 +228,11 @@ pipeline { steps { unstash "Stanc3Setup" runShell(""" - rm src/stan_math_backend/stan_math_library/*_typechecking.ml eval \$(opam env) dune build @runjstest """) } + post { always { runShell("rm -rf ./*") }} } } } From 0f44e26bae1b89424d1dce04e5e14c5250b186dc Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 14 Oct 2022 12:34:10 -0400 Subject: [PATCH 17/19] Restore jenkinsfile --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index b0cf8f8cce..471ee38336 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -215,7 +215,6 @@ pipeline { dune runtest """) } - post { always { runShell("rm -rf ./*") }} } stage("stancjs tests") { agent { @@ -232,7 +231,6 @@ pipeline { dune build @runjstest """) } - post { always { runShell("rm -rf ./*") }} } } } From 20dcf0a0e57724385c33f8665414caa09be888cc Mon Sep 17 00:00:00 2001 From: Nicusor Serban <48496524+serban-nicusor-toptal@users.noreply.github.com> Date: Fri, 14 Oct 2022 18:37:46 +0200 Subject: [PATCH 18/19] Ensure ocaml tests proper cleanup --- Jenkinsfile | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 471ee38336..352e3e9ce9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -209,12 +209,15 @@ pipeline { } } steps { - unstash "Stanc3Setup" - runShell(""" - eval \$(opam env) - dune runtest - """) + dir("${env.WORKSPACE}/dune-tests"){ + unstash "Stanc3Setup" + runShell(""" + eval \$(opam env) + dune --root=. runtest + """) + } } + post { always { runShell("rm -rf ${env.WORKSPACE}/dune-tests/*") }} } stage("stancjs tests") { agent { @@ -225,12 +228,15 @@ pipeline { } } steps { - unstash "Stanc3Setup" - runShell(""" - eval \$(opam env) - dune build @runjstest - """) + dir("${env.WORKSPACE}/stancjs-tests"){ + unstash "Stanc3Setup" + runShell(""" + eval \$(opam env) + dune --root=. build @runjstest + """) + } } + post { always { runShell("rm -rf ${env.WORKSPACE}/stancjs-tests/*") }} } } } From f30e7d9e9d22b905924417a81e6a26fbe9eeb55b Mon Sep 17 00:00:00 2001 From: Nicusor Serban <48496524+serban-nicusor-toptal@users.noreply.github.com> Date: Fri, 14 Oct 2022 18:41:37 +0200 Subject: [PATCH 19/19] Change order of --root in dune calls --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 352e3e9ce9..1fb56649cc 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -213,7 +213,7 @@ pipeline { unstash "Stanc3Setup" runShell(""" eval \$(opam env) - dune --root=. runtest + dune runtest --root=. """) } } @@ -232,7 +232,7 @@ pipeline { unstash "Stanc3Setup" runShell(""" eval \$(opam env) - dune --root=. build @runjstest + dune build @runjstest --root=. """) } }