diff --git a/docs/core_ideas.mld b/docs/core_ideas.mld index 111cd474c0..2e929ca99d 100644 --- a/docs/core_ideas.mld +++ b/docs/core_ideas.mld @@ -65,6 +65,15 @@ This takes some getting used to, and also can lead to some unhelpful type signat VSCode, because abbreviations are not always used in hover-over text. For example, [Expr.Typed.t], the MIR's typed expression type, actually has a signature of [Expr.Typed.Meta.t Expr.Fixed.t]. +{1 The [Library] virtual module} + +The [Frontend] library is a {{:https://dune.readthedocs.io/en/latest/variants.html}virtual library} +where the module [Library] is unimplemented. This allows the rest of the library to operate without +making backend-specific assumptions about any one library. + +This must be supplied when the executable is built in the [dune] file. +For Stanc, we supply the [stan_math_library] instantiation defined in [src/stan_math_backend/stan_math_library]. + {1 The [Fmt] library and pretty-printing} We extensively use the {{:https://erratique.ch/software/fmt}Fmt} library for our pretty-printing and code diff --git a/docs/exposing_new_functions.mld b/docs/exposing_new_functions.mld index 7194ccc03c..89be15ef49 100644 --- a/docs/exposing_new_functions.mld +++ b/docs/exposing_new_functions.mld @@ -7,7 +7,7 @@ For a function to be built into Stan, it has to be included in the Stan Math library and its signature has to be exposed to the compiler. -To do the latter, we have to add a corresponding line in [src/middle/Stan_math_signatures.ml]. +To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_signatures.ml]. The compiler uses the signatures defined there to do type checking. @@ -130,12 +130,13 @@ For example, the following line defines the signature [add(real, matrix) => matr Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length list of arguments, are {b NOT} added to this list. -"Nice" variadic functions are added to the hashtable [Stan_math_signatures.stan_math_variadic_signatures]. +"Nice" variadic functions are added to the hashtable [Library.variadic_signatures]. This is probably sufficient for most variadic functions, e.g. all the ODE solvers and DAE solvers are done via this method. [reduce_sum] is not "nice", since it is both variadic and {e polymorphic}, requiring certain arguments to have the same (but {e not predetermined}) type. Therefore, [reduce_sum] is treated as special case in the [Typechecker] -module in the frontend folder. +module in the frontend folder. These are instead handled by special functions like [is_special_function_name]. They +must also be given custom typechecking rules in the private sub-module [Special_typechecking]. Note that higher-order functions also usually require changes to the C++ code generation to work properly. It is best to consult an existing example of how these are done before proceeding. diff --git a/src/analysis_and_optimization/Debug_data_generation.ml b/src/analysis_and_optimization/Debug_data_generation.ml index e3166aa0e4..a6858b4837 100644 --- a/src/analysis_and_optimization/Debug_data_generation.ml +++ b/src/analysis_and_optimization/Debug_data_generation.ml @@ -30,7 +30,7 @@ let rec vect_to_mat l m = let eval_expr m e = let e = Mir_utils.subst_expr m e in - let e = Partial_evaluator.eval_expr e in + let e = Partial_evaluator.try_eval_expr e in let rec strip_promotions (e : Middle.Expr.Typed.t) = match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e in diff --git a/src/analysis_and_optimization/Dependence_analysis.ml b/src/analysis_and_optimization/Dependence_analysis.ml index 0bc86aeeb8..7b4a28951b 100644 --- a/src/analysis_and_optimization/Dependence_analysis.ml +++ b/src/analysis_and_optimization/Dependence_analysis.ml @@ -4,7 +4,7 @@ open Middle open Dataflow_types open Mir_utils open Dataflow_utils -open Monotone_framework_sigs +open Monotone_framework_intf open Monotone_framework (***********************************) @@ -119,9 +119,8 @@ let mir_reaching_definitions (mir : Program.Typed.t) (stmt : Stmt.Located.t) : Map.Poly.map rd_map ~f:(fun {entry; exit} -> {entry= to_rd_set entry; exit= to_rd_set exit} ) -let all_labels - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - : int Set.Poly.t = +let all_labels (module Flowgraph : FLOWGRAPH with type labels = int) : + int Set.Poly.t = let step set = Set.Poly.union set (union_map set ~f:(fun l -> Map.Poly.find_exn Flowgraph.successors l)) diff --git a/src/analysis_and_optimization/Memory_patterns.ml b/src/analysis_and_optimization/Memory_patterns.ml index 1b795a25e1..c368f13153 100644 --- a/src/analysis_and_optimization/Memory_patterns.ml +++ b/src/analysis_and_optimization/Memory_patterns.ml @@ -111,15 +111,14 @@ let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t) let query_stan_math_mem_pattern_support (name : string) (args : (UnsizedType.autodifftype * UnsizedType.t) list) = - let open Stan_math_signatures in match name with - | x when is_stan_math_variadic_function_name x -> false - | x when is_reduce_sum_fn x -> false + | x when Frontend.Library.is_variadic_function_name x -> false + | x when Frontend.Library.is_special_function_name x -> false | _ -> let name = - string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name) - in - let namematches = Hashtbl.find_multi stan_math_signatures name in + Frontend.Library.string_operator_to_function_name + (Utils.stdlib_distribution_name name) in + let namematches = Frontend.Library.get_signatures name in let filteredmatches = List.filter ~f:(fun x -> diff --git a/src/analysis_and_optimization/Monotone_framework.ml b/src/analysis_and_optimization/Monotone_framework.ml index 442781b210..83c231c694 100644 --- a/src/analysis_and_optimization/Monotone_framework.ml +++ b/src/analysis_and_optimization/Monotone_framework.ml @@ -2,7 +2,7 @@ open Core_kernel open Core_kernel.Poly -open Monotone_framework_sigs +open Monotone_framework_intf open Mir_utils open Middle @@ -864,7 +864,7 @@ let rec declared_variables_stmt (List.map ~f:(fun x -> declared_variables_stmt x.pattern) l) let propagation_mfp (prog : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t @@ -897,7 +897,7 @@ let propagation_mfp (prog : Program.Typed.t) Mf.mfp () let reaching_definitions_mfp (mir : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let variables = ( module struct @@ -918,7 +918,7 @@ let reaching_definitions_mfp (mir : Program.Typed.t) Mf.mfp () let initialized_vars_mfp (total : string Set.Poly.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let (module Lattice) = dual_powerset_lattice_empty_initial @@ -949,8 +949,7 @@ let globals (prog : Program.Typed.t) = (** Monotone framework instance for live_variables analysis. Expects reverse flowgraph. *) let live_variables_mfp (prog : Program.Typed.t) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let never_kill = globals prog in let variables = @@ -970,10 +969,8 @@ let live_variables_mfp (prog : Program.Typed.t) (** Instantiate all four instances of the monotone framework for lazy code motion, reusing code between them *) -let lazy_expressions_mfp - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) +let lazy_expressions_mfp (module Flowgraph : FLOWGRAPH with type labels = int) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let all_expressions = used_subexpressions_stmt @@ -1031,8 +1028,7 @@ let lazy_expressions_mfp * *) let minimal_variables_mfp - (module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Circular_Fwd_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (initial_variables : string Set.Poly.t) (gen_variable : diff --git a/src/analysis_and_optimization/Monotone_framework_sigs.mli b/src/analysis_and_optimization/Monotone_framework_intf.ml similarity index 100% rename from src/analysis_and_optimization/Monotone_framework_sigs.mli rename to src/analysis_and_optimization/Monotone_framework_intf.ml diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index 0f039b97f3..c45510ffec 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -628,7 +628,7 @@ let list_collapsing (mir : Program.Typed.t) = let propagation (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> (module Monotone_framework_sigs.TRANSFER_FUNCTION + -> (module Monotone_framework_intf.TRANSFER_FUNCTION with type labels = int and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t option ) ) (mir : Program.Typed.t) = @@ -711,7 +711,7 @@ let dead_code_elimination (mir : Program.Typed.t) = let dead_code_elim_stmt_base i stmt = (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) let live_variables_s = - (Map.find_exn live_variables i).Monotone_framework_sigs.entry in + (Map.find_exn live_variables i).Monotone_framework_intf.entry in match stmt with | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then stmt @@ -1183,11 +1183,6 @@ let optimize_soa (mir : Program.Typed.t) = List.fold ~init:Set.Poly.empty ~f:(Memory_patterns.query_initial_demotable_stmt false) mir.log_prob in - (* - let print_set s = - Set.Poly.iter ~f:print_endline s in - let () = print_set initial_variables in - *) let mod_exprs aos_exits mod_expr = Mir_utils.map_rec_expr (Memory_patterns.modify_expr_pattern aos_exits) @@ -1207,8 +1202,6 @@ let optimize_soa (mir : Program.Typed.t) = in {mir with log_prob= transform' mir.log_prob} -(* Apparently you need to completely copy/paste type definitions between - ml and mli files?*) type optimization_settings = { function_inlining: bool ; static_loop_unrolling: bool diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index 2885ce9b44..524e5e4e71 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -1,4 +1,5 @@ (* Code for optimization passes on the MIR *) + open Middle val function_inlining : Program.Typed.t -> Program.Typed.t @@ -59,7 +60,7 @@ val optimize_ad_levels : Program.Typed.t -> Program.Typed.t val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t (** Marks Decl types such that, if the first assignment after the decl - assigns to the full object, allow the object to be constructed but + assigns to the full object, allow the object to be constructed but not uninitialized. *) (** Interface for turning individual optimizations on/off. Useful for testing diff --git a/src/analysis_and_optimization/Partial_evaluator.ml b/src/analysis_and_optimization/Partial_evaluator.ml index 2a9941c5e6..70afeb5a0b 100644 --- a/src/analysis_and_optimization/Partial_evaluator.ml +++ b/src/analysis_and_optimization/Partial_evaluator.ml @@ -106,11 +106,11 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = Operator.of_string_opt name |> Option.value_map ~f:(fun op -> - Frontend.Typechecker.operator_stan_math_return_type op + Frontend.Typechecker.operator_return_type op argument_types |> Option.map ~f:fst ) ~default: - (Frontend.Typechecker.stan_math_return_type name + (Frontend.Typechecker.library_function_return_type name argument_types ) in let try_partially_evaluate_stanlib e = Expr.Fixed.Pattern.( diff --git a/src/analysis_and_optimization/dune b/src/analysis_and_optimization/dune index b656409cb0..8f6c74f5dd 100644 --- a/src/analysis_and_optimization/dune +++ b/src/analysis_and_optimization/dune @@ -1,9 +1,8 @@ (library (name analysis_and_optimization) (public_name stanc.analysis) - (libraries core_kernel str fmt common middle frontend) - (inline_tests) - ;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation - (modules_without_implementation monotone_framework_sigs) + (libraries core_kernel str fmt common middle frontend stan_math_backend) + (inline_tests + (libraries stan_math_library)) (preprocess (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 1ea33e888d..7446a0b100 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -107,8 +107,8 @@ let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t = | None -> ( Ast.StanLib FnPlain , Set.to_list possible_names |> List.hd_exn - , if Stan_math_signatures.is_stan_math_function_name (id.name ^ "_lpmf") - then UnsizedType.UInt + , if Library.is_stdlib_function_name (id.name ^ "_lpmf") then + UnsizedType.UInt else UnsizedType.UReal (* close enough *) ) in let trunc cond_op (x : Ast.typed_expression) y = let smeta = x.Ast.emeta.loc in @@ -429,7 +429,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) = | Ast.IncrementLogProb e | Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap | Ast.Tilde {arg; distribution; args; truncation} -> let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name in + Std_library_utils.dist_name_suffix Library.is_stdlib_function_name + ud_dists distribution.name in let name = distribution.name ^ suffix in let kind = let possible_names = diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index 1e84a62681..ab0f5a1aa4 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -43,7 +43,7 @@ let rec replace_deprecated_expr if is_deprecated_distribution name then CondDistApp ( StanLib suffix - , {name= rename_deprecated deprecated_distributions name; id_loc} + , {name= rename_deprecated_distribution name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) else if String.is_suffix name ~suffix:"_cdf" then CondDistApp @@ -53,7 +53,7 @@ let rec replace_deprecated_expr else FunApp ( StanLib suffix - , {name= rename_deprecated deprecated_functions name; id_loc} + , {name= rename_deprecated_function name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( match String.Map.find deprecated_userdefined name with diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index 0df3f30f1e..17dd35f772 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -2,48 +2,29 @@ open Core_kernel open Ast open Middle -let deprecated_functions = - String.Map.of_alist_exn - [ ("multiply_log", ("lmultiply", "2.32.0")) - ; ("binomial_coefficient_log", ("lchoose", "2.32.0")) - ; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) - ; ("fabs", ("abs", "2.33.0")) ] - -let deprecated_odes = - String.Map.of_alist_exn - [ ("integrate_ode", ("ode_rk45", "3.0")) - ; ("integrate_ode_rk45", ("ode_rk45", "3.0")) - ; ("integrate_ode_bdf", ("ode_bdf", "3.0")) - ; ("integrate_ode_adams", ("ode_adams", "3.0")) ] - -let deprecated_distributions = - String.Map.of_alist_exn - (List.map - ~f:(fun (x, y) -> (x, (y, "2.32.0"))) - (List.concat_map Middle.Stan_math_signatures.distributions - ~f:(fun (fnkinds, name, _, _) -> - List.filter_map fnkinds ~f:(function - | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") - | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") - | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") - | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") - | Rng | Log | UnaryVectorized _ -> None ) ) ) ) - let stan_lib_deprecations = - Map.merge_skewed deprecated_distributions deprecated_functions + Map.merge_skewed Library.deprecated_distributions Library.deprecated_functions ~combine:(fun ~key x y -> Common.FatalError.fatal_error_msg [%message "Common key in deprecation map" (key : string) - (x : string * string) - (y : string * string)] ) + (x : Std_library_utils.deprecation_info) + (y : Std_library_utils.deprecation_info)] ) let is_deprecated_distribution name = - Option.is_some (Map.find deprecated_distributions name) + Map.mem Library.deprecated_distributions name let rename_deprecated map name = - Map.find map name |> Option.map ~f:fst |> Option.value ~default:name + Map.find map name + |> Option.map ~f:(fun Std_library_utils.{replacement; canonicalize_away; _} -> + if canonicalize_away then replacement else name ) + |> Option.value ~default:name + +let rename_deprecated_distribution = + rename_deprecated Library.deprecated_distributions + +let rename_deprecated_function = rename_deprecated Library.deprecated_functions let distribution_suffix name = let open String in @@ -101,23 +82,13 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list) ({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) : (Location_span.t * string) list = match expr with - | FunApp (StanLib FnPlain, {name= "if_else"; _}, l) -> - acc - @ [ ( emeta.loc - , "The function `if_else` is deprecated and will be removed in Stan \ - 2.32.0. Use the conditional operator (x ? y : z) instead; this \ - can be automatically changed using the canonicalize flag for \ - stanc" ) ] - @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> let w = match Map.find stan_lib_deprecations name with - | Some (rename, version) -> + | Some {replacement; version; extra_message; _} -> [ ( emeta.loc , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. This can be automatically changed using the \ - canonicalize flag for stanc" ) ] + ^ ". Use " ^ replacement ^ " instead. " ^ extra_message ) ] | _ when String.is_suffix name ~suffix:"_cdf" -> [ ( emeta.loc , "Use of " ^ name @@ -125,17 +96,7 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list) of a CDF is deprecated and will be removed in Stan 2.32.0. \ This can be automatically changed using the canonicalize \ flag for stanc" ) ] - | _ -> ( - match Map.find deprecated_odes name with - | Some (rename, version) -> - [ ( emeta.loc - , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. \n\ - The new interface is slightly different, see: \ - https://mc-stan.org/users/documentation/case-studies/convert_odes.html" - ) ] - | _ -> [] ) in + | _ -> [] in acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) | PrefixOp (PNot, ({emeta= {type_= UReal; loc; _}; _} as e)) -> let acc = diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index 205e400afd..4a4daf5cf0 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -16,8 +16,7 @@ val collect_userdef_distributions : val distribution_suffix : string -> bool val without_suffix : string list -> string -> string val is_deprecated_distribution : string -> bool -val deprecated_distributions : (string * string) String.Map.t -val deprecated_functions : (string * string) String.Map.t -val rename_deprecated : (string * string) String.Map.t -> string -> string +val rename_deprecated_distribution : string -> string +val rename_deprecated_function : string -> string val userdef_distributions : untyped_statement block option -> string list val collect_warnings : typed_program -> Warnings.t list diff --git a/src/frontend/Environment.ml b/src/frontend/Environment.ml index 5267d6f277..d1674cc433 100644 --- a/src/frontend/Environment.ml +++ b/src/frontend/Environment.ml @@ -27,9 +27,9 @@ type info = type t = info list String.Map.t -let stan_math_environment = +let make_from_library signatures : t = let functions = - Hashtbl.to_alist Stan_math_signatures.stan_math_signatures + Hashtbl.to_alist signatures |> List.map ~f:(fun (key, values) -> ( key , List.map values ~f:(fun (rt, args, mem) -> diff --git a/src/frontend/Environment.mli b/src/frontend/Environment.mli index 0679805867..864e281c8c 100644 --- a/src/frontend/Environment.mli +++ b/src/frontend/Environment.mli @@ -29,8 +29,16 @@ type info = type t -val stan_math_environment : t -(** A type environment which contains the Stan math library functions +val make_from_library : + ( string + , ( UnsizedType.returntype + * (UnsizedType.autodifftype * UnsizedType.t) list + * Mem_pattern.t ) + list ) + Core_kernel.Hashtbl.t + -> t +(** Make a type environment from a hashtable of functions like those from + [Std_library_utils] *) val find : t -> string -> info list diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index f5bb82fedc..93ce165518 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -63,8 +63,8 @@ let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = (funs, distrs) else let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name - in + Std_library_utils.dist_name_suffix Library.is_stdlib_function_name + ud_dists distribution.name in let name = distribution.name ^ Utils.unnormalized_suffix suffix in (funs, Set.add distrs name) | _ -> (funs, distrs) in diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index a21aa8f717..ca0447be43 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -10,7 +10,7 @@ - [type]: the base type of the variable (["int"] or ["real"]). - [dimensions]: the number of dimensions ([0] for a scalar, [1] for a vector or row vector, etc.). - + The JSON object also have the fields [stanlib_calls] and [distributions] containing the name of the standard library functions called and distributions used. diff --git a/src/frontend/Library.mli b/src/frontend/Library.mli new file mode 100644 index 0000000000..d5bec52cbb --- /dev/null +++ b/src/frontend/Library.mli @@ -0,0 +1,97 @@ +(** This is a {{:https://dune.readthedocs.io/en/latest/variants.html}virtual module} + which is filled in at link time with a module specifying a backend-specific + Stan library. *) + +open Middle +open Core_kernel +open Std_library_utils + +val function_signatures : (string, signature list) Hashtbl.t +(** Mapping from names to signature(s) of functions + Used in [Environment] to produce the base type environment +*) + +val variadic_signatures : (string, variadic_signature) Hashtbl.t +(** Mapping from names to description of a variadic function. +Note that these function names cannot be overloaded, and usually require +customized code-gen in the backend. +*) + +val distribution_families : string list +(** A list of the families of distribution are available, + e.g. "normal", "bernoulli". Used to produce better + errors in typechecking *) + +val is_stdlib_function_name : string -> bool +(** Equivalent to [Hashtbl.mem function_signatures s] *) + +val get_signatures : string -> signature list +(** Equivalent to [Hashtbl.find_multi function_signatures s]*) + +val get_operator_signatures : Operator.t -> signature list +(* Much like get_signatures, this returns the available + signatures for a given operator, which may represent + multiple internal functions *) + +val get_assignment_operator_signatures : Operator.t -> signature list +(* Get the signatures allowed when this operator is used as an infix + assignment, e.g. [a += b]. By convention these have returntype [Void] +*) + +val is_not_overloadable : string -> bool +(** This controls which functions the typechecker will not allow + to be overloaded with additional signatures from user defined functions. + In the Stan C++ library, this is equal to [is_variadic_function_name], + but it could be more or less broad, to the limit of disallowing all overloading + by setting it equal to [is_stdlib_function_name] +*) + +val is_variadic_function_name : string -> bool +(** Variadic functions are handled as generally as possible + using the above hashtable +*) + +val is_special_function_name : string -> bool +(** Special functions like [reduce_sum] are {b not} included in the normal signatures + above, but instead recognized by this function and special-cased during + typechecking +*) + +val special_function_returntype : string -> UnsizedType.returntype option +(** We currently have the restriction that variadic functions must have the same + return type regardless of their argument types. This function should return that type, + or None if it is given a name that is not a variadic function. *) + +val check_special_fn : + is_cond_dist:bool + -> Location_span.t + -> Environment.originblock + -> Environment.t + -> Ast.identifier + -> Ast.typed_expression list + -> Ast.typed_expression +(** This function is responsible for typechecking varadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) + +val operator_to_function_names : Operator.t -> string list +(** This should return the name of function(s) (in the above signature table) + which handle the given operator. For example, in Stan's C++ backend, + [Operator.Plus] is represented by the functions [["add"]] +*) + +val string_operator_to_function_name : string -> string +(** Serves the same role as [operator_to_function_names], + but the input is the serialized string version used later + in the compiler, e.g. [Operator.PMinus] is ["PMinus__"] + *) + +val deprecated_distributions : deprecation_info String.Map.t +(** This should map any deprecated distribution functions, e.g. "normal_log" + to information about their replacements and removal version. +*) + +val deprecated_functions : deprecation_info String.Map.t +(** This should map any deprecated distribution functions, e.g. "cov_exp_quad" + to information about their replacements and removal version. +*) diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 8a089e880c..2655038dbe 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -14,19 +14,13 @@ module TypeError = struct | IntIntArrayOrRangeExpected of UnsizedType.t | IntOrRealContainerExpected of UnsizedType.t | ArrayVectorRowVectorMatrixExpected of UnsizedType.t - | IllTypedAssignment of Operator.t * UnsizedType.t * UnsizedType.t + | IllTypedAssignment of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list | IllTypedTernaryIf of UnsizedType.t * UnsizedType.t * UnsizedType.t - | IllTypedReduceSum of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedReduceSumGeneric of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedVariadic of + | IllTypedVariadicFn of string * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list @@ -50,9 +44,15 @@ module TypeError = struct string * UnsizedType.t list * (SignatureMismatch.signature_error list * bool) - | IllTypedBinaryOperator of Operator.t * UnsizedType.t * UnsizedType.t - | IllTypedPrefixOperator of Operator.t * UnsizedType.t - | IllTypedPostfixOperator of Operator.t * UnsizedType.t + | IllTypedBinaryOperator of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list + | IllTypedPrefixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list + | IllTypedPostfixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list | NotIndexable of UnsizedType.t * int let pp ppf = function @@ -101,18 +101,18 @@ module TypeError = struct "Foreach-loop must be over array, vector, row_vector or matrix. \ Instead found expression of type %a." UnsizedType.pp ut - | IllTypedAssignment (Operator.Equals, lt, rt) -> + | IllTypedAssignment (Operator.Equals, lt, rt, _) -> Fmt.pf ppf "Ill-typed arguments supplied to assignment operator =: lhs has type \ %a and rhs has type %a" UnsizedType.pp lt UnsizedType.pp rt - | IllTypedAssignment (op, lt, rt) -> + | IllTypedAssignment (op, lt, rt, sigs) -> Fmt.pf ppf "@[Ill-typed arguments supplied to assignment operator %a=: lhs \ has type %a and rhs has type %a.@ Available signatures for given \ lhs:@]@ %a" Operator.pp op UnsizedType.pp lt UnsizedType.pp rt - SignatureMismatch.pp_math_lib_assignmentoperator_sigs (lt, op) + SignatureMismatch.pp_assignmentoperator_sigs (lt, sigs) | IllTypedTernaryIf (UInt, ut, _) when UnsizedType.is_fun_type ut -> Fmt.pf ppf "Ternary expression cannot have a function type: %a" UnsizedType.pp ut @@ -125,13 +125,7 @@ module TypeError = struct Fmt.pf ppf "Condition in ternary expression must be primitive int; found type=%a" UnsizedType.pp ut1 - | IllTypedReduceSum (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedVariadic (name, arg_tys, args, error, return_type) -> + | IllTypedVariadicFn (name, arg_tys, args, error, return_type) -> SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys @@ -224,32 +218,25 @@ module TypeError = struct prefix suffix prefix prefix newsuffix | IllTypedFunctionApp (name, arg_tys, errors) -> SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, errors) - | IllTypedBinaryOperator (op, lt, rt) -> + | IllTypedBinaryOperator (op, lt, rt, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to infix operator %a. Available \ - signatures: %s@[Instead supplied arguments of incompatible type: \ - %a, %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp lt UnsizedType.pp rt - | IllTypedPrefixOperator (op, ut) -> + signatures: @[%a@.@]@[Instead supplied arguments of \ + incompatible type: %a, %a.@]" + Operator.pp op Std_library_utils.pp_signatures sigs UnsizedType.pp lt + UnsizedType.pp rt + | IllTypedPrefixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to prefix operator %a. Available \ - signatures: %s@[Instead supplied argument of incompatible type: \ - %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut - | IllTypedPostfixOperator (op, ut) -> + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" + Operator.pp op Std_library_utils.pp_signatures sigs UnsizedType.pp ut + | IllTypedPostfixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to postfix operator %a. Available \ - signatures: %s\n\ - Instead supplied argument of incompatible type: %a." Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" + Operator.pp op Std_library_utils.pp_signatures sigs UnsizedType.pp ut end module IdentifierError = struct @@ -532,8 +519,8 @@ let int_or_real_container_expected loc ut = let array_vector_rowvector_matrix_expected loc ut = TypeError (loc, TypeError.ArrayVectorRowVectorMatrixExpected ut) -let illtyped_assignment loc assignop lt rt = - TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt)) +let illtyped_assignment loc assignop lt rt sigs = + TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt, sigs)) let illtyped_ternary_if loc predt lt rt = TypeError (loc, TypeError.IllTypedTernaryIf (predt, lt, rt)) @@ -541,17 +528,9 @@ let illtyped_ternary_if loc predt lt rt = let returning_fn_expected_nonreturning_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonReturningFound name) -let illtyped_reduce_sum loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedReduceSum (name, arg_tys, args, error)) - -let illtyped_reduce_sum_generic loc name arg_tys expected_args error = +let illtyped_variadic_fn loc name arg_tys args error return_type = TypeError - ( loc - , TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) - ) - -let illtyped_variadic loc name arg_tys args fn_rt error = - TypeError (loc, TypeError.IllTypedVariadic (name, arg_tys, args, error, fn_rt)) + (loc, TypeError.IllTypedVariadicFn (name, arg_tys, args, error, return_type)) let ambiguous_function_promotion loc name arg_tys signatures = TypeError @@ -585,14 +564,14 @@ let nonreturning_fn_expected_undeclaredident_found loc name sug = let illtyped_fn_app loc name errors arg_tys = TypeError (loc, TypeError.IllTypedFunctionApp (name, arg_tys, errors)) -let illtyped_binary_op loc op lt rt = - TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt)) +let illtyped_binary_op loc op lt rt sigs = + TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt, sigs)) -let illtyped_prefix_op loc op ut = - TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut)) +let illtyped_prefix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut, sigs)) -let illtyped_postfix_op loc op ut = - TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut)) +let illtyped_postfix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut, sigs)) let not_indexable loc ut nidcs = TypeError (loc, TypeError.NotIndexable (ut, nidcs)) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index 34cd59eb16..7252a3d503 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -25,7 +25,12 @@ val array_vector_rowvector_matrix_expected : Location_span.t -> UnsizedType.t -> t val illtyped_assignment : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t val illtyped_ternary_if : Location_span.t -> UnsizedType.t -> UnsizedType.t -> UnsizedType.t -> t @@ -42,20 +47,13 @@ val returning_fn_expected_undeclared_dist_suffix_found : val returning_fn_expected_wrong_dist_suffix_found : Location_span.t -> string * string -> t -val illtyped_reduce_sum : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - -val illtyped_reduce_sum_generic : +val illtyped_variadic_fn : Location_span.t -> string -> UnsizedType.t list -> (UnsizedType.autodifftype * UnsizedType.t) list -> SignatureMismatch.function_mismatch + -> UnsizedType.t -> t val ambiguous_function_promotion : @@ -66,15 +64,6 @@ val ambiguous_function_promotion : list -> t -val illtyped_variadic : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> UnsizedType.t - -> SignatureMismatch.function_mismatch - -> t - val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t @@ -89,10 +78,27 @@ val illtyped_fn_app : -> t val illtyped_binary_op : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_prefix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_postfix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t -val illtyped_prefix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t -val illtyped_postfix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t val not_indexable : Location_span.t -> UnsizedType.t -> int -> t val ident_is_keyword : Location_span.t -> string -> t val ident_is_model_name : Location_span.t -> string -> t diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index 2719703798..91b68bad95 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -245,9 +245,6 @@ let matching_function env name args = UnsizedType.compare_returntype ret1 ret2 ) in find_compatible_rt function_types args -let matching_stanlib_function = - matching_function Environment.stan_math_environment - let check_variadic_args ~allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys fun_return args = let minimal_func_type = @@ -286,6 +283,20 @@ let check_variadic_args ~allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys | (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err | [] -> Error ([], ArgNumMismatch (List.length mandatory_arg_tys, 0)) +let find_matching_first_order_fn tenv matches (fname : Ast.identifier) = + let candidates = + Utils.stdlib_distribution_name fname.name + |> Environment.find tenv |> List.map ~f:matches in + let ok, errs = List.partition_map candidates ~f:Result.to_either in + match unique_minimum_promotion ok with + | Ok a -> UniqueMatch a + | Error (Some promotions) -> + List.filter_map promotions ~f:(function + | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) + | _ -> None ) + |> AmbiguousMatch + | Error None -> SignatureErrors (List.hd_exn errs) + let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = let open Fmt in let ctx = ref TypeMap.empty in @@ -383,10 +394,8 @@ let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = (list ~sep:cut pp_signature) sigs pp_omitted () -let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = +let pp_assignmentoperator_sigs ppf (lt, errors) = let signatures = - let errors = - Stan_math_signatures.make_assignmentoperator_stan_math_signatures op in let errors = List.filter ~f:(fun (_, args, _) -> @@ -398,7 +407,7 @@ let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = | errors, _ -> Some (errors, true) in let pp_sigs ppf (signatures, omitted) = Fmt.pf ppf "@[%a%a@]" - (Fmt.list ~sep:Fmt.cut Stan_math_signatures.pp_math_sig) + (Fmt.list ~sep:Fmt.cut Std_library_utils.pp_signature) signatures (if omitted then Fmt.pf else Fmt.nop) "@ (Additional signatures omitted)" in diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index 68fb9b7c76..53027485da 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -55,12 +55,6 @@ val matching_function : Requires a unique minimum option under type promotion *) -val matching_stanlib_function : - string -> (UnsizedType.autodifftype * UnsizedType.t) list -> match_result -(** Same as [matching_function] but requires specifically that the function - be from StanMath (uses [Environment.stan_math_environment]) -*) - val check_variadic_args : allow_lpdf:bool -> (UnsizedType.autodifftype * UnsizedType.t) list @@ -75,6 +69,15 @@ val check_variadic_args : If none is found, returns [Error] of the list of args and a function_mismatch. *) +val find_matching_first_order_fn : + Environment.t + -> (Environment.info -> (UnsizedType.t * Promotion.t list, 'a) result) + -> Ast.identifier + -> (UnsizedType.t * Promotion.t list, 'a) generic_match_result +(** Given a constraint function [matches], find any signature which exists + Returns the first [Ok] if any exist, or else [Error] +*) + val pp_signature_mismatch : Format.formatter -> string @@ -86,8 +89,8 @@ val pp_signature_mismatch : * bool ) -> unit -val pp_math_lib_assignmentoperator_sigs : - Format.formatter -> UnsizedType.t * Operator.t -> unit +val pp_assignmentoperator_sigs : + Format.formatter -> UnsizedType.t * Std_library_utils.signature list -> unit val compare_errors : function_mismatch -> function_mismatch -> int val compare_match_results : match_result -> match_result -> int diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml new file mode 100644 index 0000000000..3c6256b0a2 --- /dev/null +++ b/src/frontend/Std_library_utils.ml @@ -0,0 +1,33 @@ +open Middle +open Core_kernel + +(* Types for the module representing the standard library *) +type fun_arg = UnsizedType.autodifftype * UnsizedType.t +type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t + +type variadic_signature = + { return_type: UnsizedType.t + ; control_args: fun_arg list + ; required_fn_rt: UnsizedType.t + ; required_fn_args: fun_arg list } +[@@deriving create] + +type deprecation_info = + { replacement: string + ; version: string + ; extra_message: string + ; canonicalize_away: bool } +[@@deriving sexp] + +let pp_signature ppf ((rt, args, mem_pattern) : signature) = + UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) + +let pp_signatures ppf (sigs : signature list) = + (Fmt.list ~sep:Fmt.cut pp_signature) ppf sigs + +let dist_name_suffix (check : string -> bool) udf_names name = + let is_udf_name s = + List.exists ~f:(fun (n, _) -> String.equal s n) udf_names in + Utils.distribution_suffices + |> List.filter ~f:(fun sfx -> check (name ^ sfx) || is_udf_name (name ^ sfx)) + |> List.hd_exn diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index 3bd84d521c..7636d355cc 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -17,6 +17,7 @@ open Core_kernel open Core_kernel.Poly open Middle open Ast +open Typechecker_utils module Env = Environment (* we only allow errors raised by this function *) @@ -56,22 +57,6 @@ let context block = ; in_udf_dist_def= false ; loop_depth= 0 } -let calculate_autodifftype cf origin ut = - match origin with - | Env.(Param | TParam | Model | Functions) - when not (UnsizedType.is_int_type ut || cf.current_block = GQuant) -> - UnsizedType.AutoDiffable - | _ -> DataOnly - -let arg_type x = (x.emeta.ad_level, x.emeta.type_) -let get_arg_types = List.map ~f:arg_type -let type_of_expr_typed ue = ue.emeta.type_ -let has_int_type ue = ue.emeta.type_ = UInt -let has_int_array_type ue = ue.emeta.type_ = UArray UInt - -let has_int_or_real_type ue = - match ue.emeta.type_ with UInt | UReal -> true | _ -> false - (* -- General checks ---------------------------------------------- *) let reserved_keywords = [ "for"; "in"; "while"; "repeat"; "until"; "if"; "then"; "else"; "true" @@ -83,6 +68,11 @@ let reserved_keywords = ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" ; "static"; "auto" ] +let std_library_tenv : Env.t = Env.make_from_library Library.function_signatures + +let matching_library_function = + SignatureMismatch.matching_function std_library_tenv + let verify_identifier id : unit = if id.name = !model_name then Semantic_error.ident_is_model_name id.id_loc id.name |> error @@ -121,8 +111,7 @@ let verify_name_fresh_udf loc tenv name = if (* variadic functions are currently not in math sigs and aren't overloadable due to their separate typechecking *) - Stan_math_signatures.is_reduce_sum_fn name - || Stan_math_signatures.is_stan_math_variadic_function_name name + Library.is_not_overloadable name then Semantic_error.ident_is_stanmath_name loc name |> error else if Utils.is_unnormalized_distribution name then Semantic_error.udf_is_unnormalized_fn loc name |> error @@ -189,38 +178,33 @@ let match_to_rt_option = function | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt | _ -> None -let stan_math_return_type name arg_tys = - match - Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures name - with +let library_function_return_type name arg_tys = + match Hashtbl.find Library.variadic_signatures name with | Some {return_type; _} -> Some (UnsizedType.ReturnType return_type) - | None when Stan_math_signatures.is_reduce_sum_fn name -> - Some (UnsizedType.ReturnType UReal) - | None -> - SignatureMismatch.matching_stanlib_function name arg_tys - |> match_to_rt_option + | None when Library.is_special_function_name name -> + Library.special_function_returntype name + | None -> matching_library_function name arg_tys |> match_to_rt_option -let operator_stan_math_return_type op arg_tys = +let operator_return_type op arg_tys = match (op, arg_tys) with | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion]) | IntDivide, _ -> None | _ -> - Stan_math_signatures.operator_to_stan_math_fns op + Library.operator_to_function_names op |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_stanlib_function name arg_tys + matching_library_function name arg_tys |> function | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) | _ -> None ) |> List.hd -let assignmentoperator_stan_math_return_type assop arg_tys = +let assignmentoperator_return_type assop arg_tys = ( match assop with | Operator.Divide -> - SignatureMismatch.matching_stanlib_function "divide" arg_tys - |> match_to_rt_option + matching_library_function "divide" arg_tys |> match_to_rt_option | Plus | Minus | Times | EltTimes | EltDivide -> - operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst + operator_return_type assop arg_tys |> Option.map ~f:fst | _ -> None ) |> Option.bind ~f:(function | ReturnType rtype @@ -232,7 +216,7 @@ let assignmentoperator_stan_math_return_type assop arg_tys = | _ -> None ) let check_binop loc op le re = - let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in + let rt = [le; re] |> get_arg_types |> operator_return_type op in match rt with | Some (ReturnType type_, [p1; p2]) -> mk_typed_expression @@ -241,27 +225,34 @@ let check_binop loc op le re = ~type_ ~loc | _ -> Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ + (Library.get_operator_signatures op) |> error let check_prefixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in + let rt = operator_return_type op [arg_type te] in match rt with | Some (ReturnType type_, _) -> mk_typed_expression ~expr:(PrefixOp (op, te)) ~ad_level:(expr_ad_lub [te]) ~type_ ~loc - | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error + | _ -> + Semantic_error.illtyped_prefix_op loc op te.emeta.type_ + (Library.get_operator_signatures op) + |> error let check_postfixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in + let rt = operator_return_type op [arg_type te] in match rt with | Some (ReturnType type_, _) -> mk_typed_expression ~expr:(PostfixOp (te, op)) ~ad_level:(expr_ad_lub [te]) ~type_ ~loc - | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error + | _ -> + Semantic_error.illtyped_postfix_op loc op te.emeta.type_ + (Library.get_operator_signatures op) + |> error let check_id cf loc tenv id = match Env.find tenv (Utils.stdlib_distribution_name id.name) with @@ -270,7 +261,7 @@ let check_id cf loc tenv id = (Env.nearest_ident tenv id.name) |> error | {kind= `StanMath; _} :: _ -> - ( calculate_autodifftype cf MathLibrary UMathLibraryFunction + ( calculate_autodifftype cf.current_block MathLibrary UMathLibraryFunction , UnsizedType.UMathLibraryFunction ) | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ when cf.in_toplevel_decl -> @@ -282,16 +273,16 @@ let check_id cf loc tenv id = || cf.current_block = Model ) -> Semantic_error.invalid_unnormalized_fn loc |> error | {kind= `Variable {origin; _}; type_} :: _ -> - (calculate_autodifftype cf origin type_, type_) + (calculate_autodifftype cf.current_block origin type_, type_) | { kind= `UserDefined | `UserDeclared _ ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } :: _ -> let type_ = UnsizedType.UFun (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - (calculate_autodifftype cf Functions type_, type_) + (calculate_autodifftype cf.current_block Functions type_, type_) | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> - (calculate_autodifftype cf Functions type_, type_) + (calculate_autodifftype cf.current_block Functions type_, type_) let check_variable cf loc tenv id = let ad_level, type_ = check_id cf loc tenv id in @@ -409,7 +400,6 @@ let inferred_ad_type_of_indexed at uindices = uindices ) (* function checking *) - let verify_conddist_name loc id = if List.exists @@ -458,24 +448,17 @@ let verify_unnormalized cf loc id = && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) then Semantic_error.invalid_unnormalized_fn loc |> error -let mk_fun_app ~is_cond_dist (x, y, z) = - if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) - let check_normal_fn ~is_cond_dist loc tenv id es = match Env.find tenv (Utils.normalized_name id.name) with | {kind= `Variable _; _} :: _ (* variables can sometimes shadow stanlib functions, so we have to check this *) - when not - (Stan_math_signatures.is_stan_math_function_name - (Utils.normalized_name id.name) ) -> + when not (Library.is_stdlib_function_name (Utils.normalized_name id.name)) + -> Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error | [] -> ( match Utils.split_distribution_suffix id.name with | Some (prefix, suffix) -> ( - let known_families = - List.map - ~f:(fun (_, y, _, _) -> y) - Stan_math_signatures.distributions in + let known_families = Library.distribution_families in let is_known_family s = List.mem known_families s ~equal:String.equal in match suffix with @@ -532,113 +515,18 @@ let check_normal_fn ~is_cond_dist loc tenv id es = |> Semantic_error.illtyped_fn_app loc id.name (l, b) |> error ) -(** Given a constraint function [matches], find any signature which exists - Returns the first [Ok] if any exist, or else [Error] -*) -let find_matching_first_order_fn tenv matches fname = - let candidates = - Utils.stdlib_distribution_name fname.name - |> Env.find tenv |> List.map ~f:matches in - let ok, errs = List.partition_map candidates ~f:Result.to_either in - match SignatureMismatch.unique_minimum_promotion ok with - | Ok a -> SignatureMismatch.UniqueMatch a - | Error (Some promotions) -> - List.filter_map promotions ~f:(function - | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) - | _ -> None ) - |> AmbiguousMatch - | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs) - -let make_function_variable cf loc id = function - | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | UnsizedType.UFun _ as type_ -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | type_ -> - Common.FatalError.fatal_error_msg - [%message - "Attempting to create function variable out of " - (type_ : UnsizedType.t)] - let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) = - if Stan_math_signatures.is_stan_math_variadic_function_name id.name then - check_variadic ~is_cond_dist loc cf tenv id tes - else if Stan_math_signatures.is_reduce_sum_fn id.name then - check_reduce_sum ~is_cond_dist loc cf tenv id tes + if Library.is_special_function_name id.name then + Library.check_special_fn ~is_cond_dist loc cf.current_block tenv id tes + else if Library.is_variadic_function_name id.name then + check_variadic ~is_cond_dist loc cf.current_block tenv id tes else check_normal_fn ~is_cond_dist loc tenv id tes -(** Reduce sum is a special case, even compared to the other - variadic functions, because it is polymorphic in the type of the - first argument. The first, fourth, and fifth arguments must agree, - which is too complicated to be captured declaratively. *) -and check_reduce_sum ~is_cond_dist loc cf tenv id tes = - let basic_mismatch () = - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in - SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args - mandatory_fun_args UReal (get_arg_types tes) in - let matching remaining_es fn = - match fn with - | Env. - { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as - ftype - ; _ } - when List.mem Stan_math_signatures.reduce_sum_slice_types - sliced_arg_fun_type ~equal:( = ) -> - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args - mandatory_fun_args UReal arg_types - | _ -> basic_mismatch () in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - (* a valid signature exists *) - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> - let expected_args, err = - basic_mismatch () |> Result.error |> Option.value_exn in - Semantic_error.illtyped_reduce_sum_generic loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error - and check_variadic ~is_cond_dist loc cf tenv id tes = - let Stan_math_signatures. + let Std_library_utils. {control_args; required_fn_args; required_fn_rt; return_type} = - Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures id.name - in + Hashtbl.find_exn Library.variadic_signatures id.name in let matching remaining_es Env.{type_= ftype; _} = let arg_types = (calculate_autodifftype cf Functions ftype, ftype) @@ -647,7 +535,10 @@ and check_variadic ~is_cond_dist loc cf tenv id tes = required_fn_args required_fn_rt arg_types in match tes with | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with | SignatureMismatch.UniqueMatch (ftype, promotions) -> let tes = make_function_variable cf loc fname ftype :: remaining_es in mk_typed_expression @@ -659,18 +550,18 @@ and check_variadic ~is_cond_dist loc cf tenv id tes = Semantic_error.ambiguous_function_promotion loc fname.name None ps |> error | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic loc id.name + Semantic_error.illtyped_variadic_fn loc id.name (List.map ~f:type_of_expr_typed tes) - expected_args required_fn_rt err + expected_args err required_fn_rt |> error ) | _ -> let expected_args, err = SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args required_fn_args required_fn_rt (get_arg_types tes) |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic loc id.name + Semantic_error.illtyped_variadic_fn loc id.name (List.map ~f:type_of_expr_typed tes) - expected_args required_fn_rt err + expected_args err required_fn_rt |> error and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) = @@ -811,7 +702,8 @@ and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : Semantic_error.target_plusequals_outside_model_or_logprob loc |> error else mk_typed_expression ~expr:GetLP - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) ~type_:UReal ~loc | GetTarget -> (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) @@ -823,7 +715,8 @@ and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : Semantic_error.target_plusequals_outside_model_or_logprob loc |> error else mk_typed_expression ~expr:GetTarget - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) ~type_:UReal ~loc | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc | RowVectorExpr es -> es |> List.map ~f:ce |> check_rowvector loc @@ -874,7 +767,7 @@ let check_nrfn loc tenv id es = match Env.find tenv id.name with | {kind= `Variable _; _} :: _ (* variables can shadow stanlib functions, so we have to check this *) - when not (Stan_math_signatures.is_stan_math_function_name id.name) -> + when not (Library.is_stdlib_function_name id.name) -> Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error | [] -> Semantic_error.nonreturning_fn_expected_undeclaredident_found loc id.name @@ -934,9 +827,9 @@ let verify_assignment_non_function loc ut id = | _ -> () let check_assignment_operator loc assop lhs rhs = - let err op = + let err op sigs = Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ - in + sigs in match assop with | Assign | ArrowAssign -> ( match @@ -944,11 +837,13 @@ let check_assignment_operator loc assop lhs rhs = rhs.emeta.type_ with | Ok p -> Promotion.promote rhs p - | Error _ -> err Operator.Equals |> error ) + | Error _ -> err Operator.Equals [] |> error ) | OperatorAssign op -> ( let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in - let return_type = assignmentoperator_stan_math_return_type op args in - match return_type with Some Void -> rhs | _ -> err op |> error ) + let return_type = assignmentoperator_return_type op args in + match return_type with + | Some Void -> rhs + | _ -> err op (Library.get_assignment_operator_signatures op) |> error ) let check_lvalue cf tenv = function | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> @@ -1422,7 +1317,7 @@ and check_var_decl_initial_value loc cf tenv {identifier; initial_value} = | Ok p -> Ast.{identifier; initial_value= Some (Promotion.promote rhs p)} | Error _ -> Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ - rhs.emeta.type_ + rhs.emeta.type_ [] |> error ) | None -> Ast.{identifier; initial_value= None} @@ -1715,7 +1610,7 @@ let check_program_exn ; comments } as ast ) = warnings := [] ; (* create a new type environment which has only stan-math functions *) - let tenv = Env.stan_math_environment in + let tenv = std_library_tenv in let tenv, typed_fb = check_toplevel_block Functions tenv fb in verify_functions_have_defn tenv typed_fb ; let tenv, typed_db = check_toplevel_block Data tenv db in diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli index 07662dfe11..54cadc5bcc 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecker.mli @@ -19,7 +19,7 @@ val check_program_exn : untyped_program -> typed_program * Warnings.t list (** Type check a full Stan program. Can raise [Errors.SemanticError] -*) + *) val check_program : untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result @@ -27,21 +27,21 @@ val check_program : The safe version of [check_program_exn]. This catches all [Errors.SemanticError] exceptions and converts them into a [Result.t] -*) + *) -val operator_stan_math_return_type : +val operator_return_type : Middle.Operator.t -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> (Middle.UnsizedType.returntype * Promotion.t list) option -val stan_math_return_type : +val library_function_return_type : string -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> Middle.UnsizedType.returntype option val model_name : string ref (** A reference to hold the model name. Relevant for checking variable - clashes and used in code generation. *) + clashes and used in code generation. *) val check_that_all_functions_have_definition : bool ref (** A switch to determine whether we check that all functions have a definition *) diff --git a/src/frontend/Typechecker_utils.ml b/src/frontend/Typechecker_utils.ml new file mode 100644 index 0000000000..d0161fdba3 --- /dev/null +++ b/src/frontend/Typechecker_utils.ml @@ -0,0 +1,43 @@ +open Core_kernel +open Core_kernel.Poly +open Middle +open Ast +module Env = Environment + +let mk_fun_app ~is_cond_dist (kind, id, arguments) = + if is_cond_dist then CondDistApp (kind, id, arguments) + else FunApp (kind, id, arguments) + +let calculate_autodifftype current_block origin ut = + match origin with + | Env.(Param | TParam | Model | Functions) + when not (UnsizedType.is_int_type ut || current_block = Env.GQuant) -> + UnsizedType.AutoDiffable + | _ -> DataOnly + +let make_function_variable current_block loc id = function + | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | UnsizedType.UFun _ as type_ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | type_ -> + Common.FatalError.fatal_error_msg + [%message + "Attempting to create function variable out of " + (type_ : UnsizedType.t)] + +let arg_type x = (x.emeta.ad_level, x.emeta.type_) +let get_arg_types = List.map ~f:arg_type +let type_of_expr_typed ue = ue.emeta.type_ +let has_int_type ue = ue.emeta.type_ = UInt +let has_int_array_type ue = ue.emeta.type_ = UArray UInt + +let has_int_or_real_type ue = + match ue.emeta.type_ with UInt | UReal -> true | _ -> false diff --git a/src/frontend/dune b/src/frontend/dune index 0359a63c1d..8341170ebb 100644 --- a/src/frontend/dune +++ b/src/frontend/dune @@ -2,9 +2,12 @@ (name frontend) (public_name stanc.frontend) (libraries core_kernel re menhirLib fmt middle common yojson) - (inline_tests) + (virtual_modules library) + (inline_tests + ; we need to specify an implementation of the virtual library for testing + (libraries stan_math_library)) (preprocess - (pps ppx_jane ppx_deriving.fold ppx_deriving.map))) + (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.create))) (ocamllex lexer) diff --git a/src/middle/Stan_math_signatures.mli b/src/middle/Stan_math_signatures.mli deleted file mode 100644 index a1444bf08c..0000000000 --- a/src/middle/Stan_math_signatures.mli +++ /dev/null @@ -1,86 +0,0 @@ -(** This module stores a table of all signatures from the Stan - math C++ library which are exposed to Stan, and some helper - functions for dealing with those signatures. -*) - -open Core_kernel - -(** Function arguments are represented by their type an autodiff - type. This is [AutoDiffable] for everything except arguments - marked with the data keyword *) -type fun_arg = UnsizedType.autodifftype * UnsizedType.t - -(** Signatures consist of a return type, a list of arguments, and a flag - for whether or not those arguments can be Struct of Arrays objects *) -type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t - -val stan_math_signatures : (string, signature list) Hashtbl.t -(** Mapping from names to signature(s) of functions *) - -val is_stan_math_function_name : string -> bool -(** Equivalent to [Hashtbl.mem stan_math_signatures s]*) - -type variadic_signature = - { return_type: UnsizedType.t - ; control_args: fun_arg list - ; required_fn_rt: UnsizedType.t - ; required_fn_args: fun_arg list } - -val stan_math_variadic_signatures : (string, variadic_signature) Hashtbl.t -(** Mapping from names to description of a variadic function. - - Note that these function names cannot be overloaded, and usually require - customized code-gen in the backend. -*) - -val is_stan_math_variadic_function_name : string -> bool -(** Equivalent to [Hashtbl.mem stan_math_variadic_signatures s]*) - -(** Pretty printers *) - -val pp_math_sig : signature Fmt.t -val pretty_print_all_math_sigs : unit Fmt.t -val pretty_print_all_math_distributions : unit Fmt.t - -type dimensionality -type return_behavior - -type fkind = private - | Lpmf - | Lpdf - | Log - | Rng - | Cdf - | Ccdf - | UnaryVectorized of return_behavior -[@@deriving show {with_path= false}] - -val distributions : - (fkind list * string * dimensionality list * Mem_pattern.t) list -(** The distribution {e families} exposed by the math library *) - -val dist_name_suffix : (string * 'a) list -> string -> string - -(** Helpers for dealing with operators as signatures *) - -val operator_to_stan_math_fns : Operator.t -> string list -val string_operator_to_stan_math_fns : string -> string -val pretty_print_math_lib_operator_sigs : Operator.t -> string list -val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list - -(** Special functions for the variadic signatures exposed *) - -(* reduce_sum helpers *) -val is_reduce_sum_fn : string -> bool -val reduce_sum_slice_types : UnsizedType.t list - -(** These are only used in code-gen, typing is done via [stan_math_variadic_signatures] *) - -(* variadic ODE helpers *) -val is_variadic_ode_fn : string -> bool -val ode_tolerances_suffix : string -val variadic_ode_adjoint_fn : string - -(* variadic DAE helpers *) -val is_variadic_dae_fn : string -> bool -val dae_tolerances_suffix : string diff --git a/src/middle/dune b/src/middle/dune index 65e20b7aba..220ee14c6e 100644 --- a/src/middle/dune +++ b/src/middle/dune @@ -4,9 +4,4 @@ (libraries core_kernel str fmt common re) (inline_tests) (preprocess - (pps - ppx_jane - ppx_deriving.map - ppx_deriving.fold - ppx_deriving.create - ppx_deriving.show))) + (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.create))) diff --git a/src/stan_math_backend/Library_utils.ml b/src/stan_math_backend/Library_utils.ml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/middle/Stan_math_signatures.ml b/src/stan_math_backend/Stan_math_signatures.ml similarity index 94% rename from src/middle/Stan_math_signatures.ml rename to src/stan_math_backend/Stan_math_signatures.ml index 90a68b7ddc..b19fce54a7 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/stan_math_backend/Stan_math_signatures.ml @@ -1,7 +1,8 @@ (** The signatures of the Stan Math library, which are used for type checking *) -open Core_kernel +open Core_kernel open Core_kernel.Poly +open Middle (** The "dimensionality" (bad name?) is supposed to help us represent the vectorized nature of many Stan functions. It allows us to represent when @@ -90,20 +91,13 @@ type fkind = type fun_arg = UnsizedType.autodifftype * UnsizedType.t type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t -type variadic_signature = - { return_type: UnsizedType.t - ; control_args: fun_arg list - ; required_fn_rt: UnsizedType.t - ; required_fn_args: fun_arg list } -[@@deriving create] - let is_primitive = function | UnsizedType.UReal -> true | UInt -> true | _ -> false (** The signatures hash table *) -let (stan_math_signatures : (string, signature list) Hashtbl.t) = +let (function_signatures : (string, signature list) Hashtbl.t) = String.Table.create () (** All of the signatures that are added by hand, rather than the ones @@ -111,13 +105,6 @@ let (stan_math_signatures : (string, signature list) Hashtbl.t) = let (manual_stan_math_signatures : (string, signature list) Hashtbl.t) = String.Table.create () -(** The variadic signatures hash table - - These functions cannot be overloaded. -*) -let (stan_math_variadic_signatures : (string, variadic_signature) Hashtbl.t) = - String.Table.create () - (* XXX The correct word here isn't combination - what is it? *) let all_combinations xx = List.fold_right xx ~init:[[]] ~f:(fun x accum -> @@ -297,6 +284,9 @@ let distributions = ; ([Lpdf], "wishart_cholesky", [DMatrix; DReal; DMatrix], SoA) ; ([Lpdf; Log], "wishart", [DMatrix; DReal; DMatrix], SoA) ] +let distribution_families = + List.map ~f:(fun (_, name, _, _) -> name) distributions + let basic_vectorized = UnaryVectorized IntsToReals let math_sigs = @@ -366,27 +356,11 @@ let all_declarative_sigs = distributions @ math_sigs let declarative_fnsigs = List.concat_map ~f:mk_declarative_sig all_declarative_sigs -let is_stan_math_function_name name = +let is_stdlib_function_name name = let name = Utils.stdlib_distribution_name name in - Hashtbl.mem stan_math_signatures name - -let is_stan_math_variadic_function_name name = - Hashtbl.mem stan_math_variadic_signatures name + Hashtbl.mem function_signatures name -let dist_name_suffix udf_names name = - let is_udf_name s = List.exists ~f:(fun (n, _) -> n = s) udf_names in - match - Utils.distribution_suffices - |> List.filter ~f:(fun sfx -> - is_stan_math_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) - |> List.hd - with - | Some hd -> hd - | None -> - Common.FatalError.fatal_error_msg - [%message "Couldn't find distribution " name] - -let operator_to_stan_math_fns op = +let operator_to_function_names op = match op with | Operator.Plus -> ["add"] | PPlus -> ["plus"] @@ -412,21 +386,15 @@ let operator_to_stan_math_fns op = | PNot -> ["logical_negation"] | Transpose -> ["transpose"] -let int_divide_type = - UnsizedType. - ( ReturnType UInt - , [(AutoDiffable, UInt); (AutoDiffable, UInt)] - , Mem_pattern.AoS ) - -let get_sigs name = +let get_signatures name = let name = Utils.stdlib_distribution_name name in - Hashtbl.find_multi stan_math_signatures name |> List.sort ~compare + Hashtbl.find_multi function_signatures name |> List.sort ~compare -let make_assignmentoperator_stan_math_signatures assop = +let get_assignment_operator_signatures assop = ( match assop with | Operator.Divide -> ["divide"] - | assop -> operator_to_stan_math_fns assop ) - |> List.concat_map ~f:get_sigs + | assop -> operator_to_function_names assop ) + |> List.concat_map ~f:get_signatures |> List.concat_map ~f:(function | ReturnType rtype, [(ad1, lhs); (ad2, rhs)], _ when rtype = lhs @@ -439,15 +407,7 @@ let make_assignmentoperator_stan_math_signatures assop = else [(Void, [(ad1, lhs); (ad2, rhs)], SoA)] | _ -> [] ) -let pp_math_sig ppf (rt, args, mem_pattern) = - UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) - -let pp_math_sigs ppf name = - (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf (get_sigs name) - -let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs - -let string_operator_to_stan_math_fns str = +let string_operator_to_function_name str = match str with | "Plus__" -> "add" | "PPlus__" -> "plus" @@ -482,10 +442,10 @@ let pretty_print_all_math_sigs ppf () = (List.map ~f:snd args) UnsizedType.pp_returntype rt in let pp_sigs_for_name ppf name = (list ~sep:cut pp_sig) ppf - (List.map ~f:(fun t -> (name, t)) (get_sigs name)) in + (List.map ~f:(fun t -> (name, t)) (get_signatures name)) in pf ppf "@[%a@]" (list ~sep:cut pp_sigs_for_name) - (List.sort ~compare (Hashtbl.keys stan_math_signatures)) + (List.sort ~compare (Hashtbl.keys function_signatures)) let pretty_print_all_math_distributions ppf () = let open Fmt in @@ -495,16 +455,20 @@ let pretty_print_all_math_distributions ppf () = (List.map ~f:(Fn.compose String.lowercase show_fkind) kinds) in pf ppf "@[%a@]" (list ~sep:cut pp_dist) distributions -let pretty_print_math_lib_operator_sigs op = - if op = Operator.IntDivide then - [Fmt.str "@[@,%a@]" pp_math_sig int_divide_type] - else operator_to_stan_math_fns op |> List.map ~f:pretty_print_math_sigs +let int_divide_type = + UnsizedType. + ( ReturnType UInt + , [(AutoDiffable, UInt); (AutoDiffable, UInt)] + , Mem_pattern.AoS ) + +let get_operator_signatures op = + if op = Operator.IntDivide then [int_divide_type] + else operator_to_function_names op |> List.concat_map ~f:get_signatures (* -- Some helper definitions to populate stan_math_signatures -- *) let add_qualified (name, rt, argts, supports_soa) = - Hashtbl.add_multi stan_math_signatures ~key:name - ~data:(rt, argts, supports_soa) + Hashtbl.add_multi function_signatures ~key:name ~data:(rt, argts, supports_soa) let add_nullary name = add_unqualified (name, UnsizedType.ReturnType UReal, [], AoS) @@ -748,7 +712,7 @@ let for_vector_types s = List.iter ~f:s vector_types (* -- Start populating stan_math_signaturess -- *) let () = List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) -> - Hashtbl.add_multi stan_math_signatures ~key ~data:(rt, args, mem_pattern) ) ; + Hashtbl.add_multi function_signatures ~key ~data:(rt, args, mem_pattern) ) ; add_unqualified ("acos", ReturnType UComplex, [UComplex], AoS) ; add_unqualified ("acosh", ReturnType UComplex, [UComplex], AoS) ; List.iter @@ -2558,64 +2522,7 @@ let () = for type-checking *) Hashtbl.iteri manual_stan_math_signatures ~f:(fun ~key ~data -> List.iter data ~f:(fun data -> - Hashtbl.add_multi stan_math_signatures ~key ~data ) ) - -(* variadics *) - -let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7] - -let reduce_sum_slice_types = - let base_slice_type i = - [ bare_array_type (UnsizedType.UReal, i) - ; bare_array_type (UnsizedType.UInt, i) - ; bare_array_type (UnsizedType.UMatrix, i) - ; bare_array_type (UnsizedType.UVector, i) - ; bare_array_type (UnsizedType.URowVector, i) ] in - List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) - -(* Variadic ODE *) -let variadic_ode_adjoint_ctl_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal) - (* real relative_tolerance_forward *) - ; (DataOnly, UVector) (* vector absolute_tolerance_forward *) - ; (DataOnly, UReal) (* real relative_tolerance_backward *) - ; (DataOnly, UVector) (* real absolute_tolerance_backward *) - ; (DataOnly, UReal) (* real relative_tolerance_quadrature *) - ; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) - ; (DataOnly, UInt) (* int max_num_steps *) - ; (DataOnly, UInt) (* int num_steps_between_checkpoints *) - ; (DataOnly, UInt) (* int interpolation_polynomial *) - ; (DataOnly, UInt) (* int solver_forward *); (DataOnly, UInt) - (* int solver_backward *) ] - -let variadic_ode_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) - ; (DataOnly, UInt) ] - -let variadic_ode_mandatory_arg_types = - [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (AutoDiffable, UReal) - ; (AutoDiffable, UArray UReal) ] - -let variadic_ode_mandatory_fun_args = - [ (UnsizedType.AutoDiffable, UnsizedType.UReal) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] - -let variadic_ode_fun_return_type = UnsizedType.UVector -let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector - -let variadic_dae_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) - ; (DataOnly, UInt) ] - -let variadic_dae_mandatory_arg_types = - [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) - (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) - (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] - -let variadic_dae_mandatory_fun_args = - [ (UnsizedType.AutoDiffable, UnsizedType.UReal) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + Hashtbl.add_multi function_signatures ~key ~data ) ) let reduce_sum_functions = String.Set.of_list ["reduce_sum"; "reduce_sum_static"] let variadic_ode_adjoint_fn = "ode_adjoint_tol_ctl" @@ -2634,46 +2541,7 @@ let is_variadic_ode_fn f = let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"] let dae_tolerances_suffix = "_tol" let is_variadic_dae_fn f = Set.mem variadic_dae_fns f -let variadic_dae_fun_return_type = UnsizedType.UVector -let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector - -let add_variadic_fn name ~return_type ?control_args ~required_fn_rt - ?required_fn_args () = - Hashtbl.add_exn stan_math_variadic_signatures ~key:name - ~data: - (create_variadic_signature ~return_type ?control_args ?required_fn_args - ~required_fn_rt () ) - -let () = - (* DAEs *) - add_variadic_fn "dae" ~return_type:variadic_dae_return_type - ~control_args:variadic_dae_mandatory_arg_types - ~required_fn_args:variadic_dae_mandatory_fun_args - ~required_fn_rt:variadic_dae_fun_return_type () ; - add_variadic_fn "dae_tol" ~return_type:variadic_dae_return_type - ~control_args:(variadic_dae_mandatory_arg_types @ variadic_dae_tol_arg_types) - ~required_fn_args:variadic_dae_mandatory_fun_args - ~required_fn_rt:variadic_dae_fun_return_type () ; - (* non-adjoint ODES - same for all *) - let add_ode name = - add_variadic_fn name ~return_type:variadic_ode_return_type - ~control_args: - ( if String.is_suffix name ~suffix:ode_tolerances_suffix then - variadic_ode_mandatory_arg_types @ variadic_ode_tol_arg_types - else variadic_ode_mandatory_arg_types ) - ~required_fn_rt:variadic_ode_fun_return_type - ~required_fn_args:variadic_ode_mandatory_fun_args () in - Set.iter ~f:add_ode variadic_ode_nonadjoint_fns ; - (* Adjoint ODE function *) - add_variadic_fn variadic_ode_adjoint_fn ~return_type:variadic_ode_return_type - ~control_args: - (variadic_ode_mandatory_arg_types @ variadic_ode_adjoint_ctl_tol_arg_types) - ~required_fn_rt:variadic_ode_fun_return_type - ~required_fn_args:variadic_ode_mandatory_fun_args () - -let%expect_test "dist name suffix" = - dist_name_suffix [] "normal" |> print_endline ; - [%expect {| _lpdf |}] +let is_special_function_name = is_reduce_sum_fn let%expect_test "declarative distributions" = let special_suffixes = @@ -2684,7 +2552,7 @@ let%expect_test "declarative distributions" = distributions |> List.map ~f:(function _, n, _, _ -> n) |> String.Set.of_list in - Hashtbl.keys stan_math_signatures + Hashtbl.keys function_signatures |> List.filter ~f:(fun name -> match Utils.split_distribution_suffix name with | Some (name, suffix) diff --git a/src/stan_math_backend/Stan_math_signatures.mli b/src/stan_math_backend/Stan_math_signatures.mli new file mode 100644 index 0000000000..f3a320d6a7 --- /dev/null +++ b/src/stan_math_backend/Stan_math_signatures.mli @@ -0,0 +1,75 @@ +(** This module stores a table of all signatures from the Stan + math C++ library which are exposed to Stan, and some helper + functions for dealing with those signatures. + + It is used by + {!modules: + + Stan_math_library + } + to create a [Library] instance which can be consumed by the frontend of the compiler. +*) + +open Middle +open Core_kernel + +type fun_arg = UnsizedType.autodifftype * UnsizedType.t +type signature = UnsizedType.returntype * fun_arg list * Mem_pattern.t +type dimensionality +type return_behavior + +type fkind = private + | Lpmf + | Lpdf + | Log + | Rng + | Cdf + | Ccdf + | UnaryVectorized of return_behavior +[@@deriving show {with_path= false}] + +val function_signatures : (string, signature list) Hashtbl.t +(** Mapping from names to signature(s) of functions *) + +val distributions : + (fkind list * string * dimensionality list * Mem_pattern.t) list + +val distribution_families : string list + +val is_stdlib_function_name : string -> bool +(** Equivalent to [Hashtbl.mem function_signatures s]*) + +val get_signatures : string -> signature list +(** Equivalent to [Hashtbl.find_multi function_signatures s]*) + +val get_operator_signatures : Operator.t -> signature list +val get_assignment_operator_signatures : Operator.t -> signature list +val operator_to_function_names : Operator.t -> string list +val string_operator_to_function_name : string -> string + +(** These functions are used by the drivers to display + all available functions and distributions. They are + not part of the Library interface since different drivers + for different backends would likely want different behavior + here *) + +val pretty_print_all_math_sigs : unit Fmt.t +val pretty_print_all_math_distributions : unit Fmt.t + +(** These functions related to variadic functions + are specific to this backend and used + during code generation *) + +(* reduce_sum helpers *) +val is_reduce_sum_fn : string -> bool +val is_special_function_name : string -> bool + +(* variadic ODE helpers *) +val is_variadic_ode_fn : string -> bool +val ode_tolerances_suffix : string +val variadic_ode_adjoint_fn : string +val variadic_ode_nonadjoint_fns : String.Set.t + +(* variadic DAE helpers *) +val is_variadic_dae_fn : string -> bool +val dae_tolerances_suffix : string diff --git a/src/stan_math_backend/dune b/src/stan_math_backend/dune index d8dd800d2a..3878d06d27 100644 --- a/src/stan_math_backend/dune +++ b/src/stan_math_backend/dune @@ -12,4 +12,9 @@ statement_gen) (inline_tests) (preprocess - (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) + (pps + ppx_jane + ppx_deriving.map + ppx_deriving.fold + ppx_deriving.show + ppx_deriving.create))) diff --git a/src/stan_math_backend/stan_math_library/Library.ml b/src/stan_math_backend/stan_math_library/Library.ml new file mode 100644 index 0000000000..7933024d6c --- /dev/null +++ b/src/stan_math_backend/stan_math_library/Library.ml @@ -0,0 +1,62 @@ +(** This is the {e implementation} of the Library virtual module + for the Stan Math C++ backend + *) + +open Core_kernel +open Std_library_utils + +(** Many of the required functions are exposed in the backend specific file, so we include it *) +include Stan_math_backend.Stan_math_signatures + +include Stan_math_extras + +let deprecated_distributions = + List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> + List.filter_map fnkinds ~f:(function + | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") + | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") + | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") + | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") + | Rng | Log | UnaryVectorized _ -> None ) ) + |> List.map ~f:(fun (x, y) -> + ( x + , { replacement= y + ; version= "2.32.0" + ; extra_message= + "This can be automatically changed using the canonicalize flag \ + for stanc" + ; canonicalize_away= true } ) ) + |> String.Map.of_alist_exn + +let deprecated_functions = + let make extra_message version canonicalize_away replacement = + {extra_message; replacement; version; canonicalize_away} in + let ode = + make + "\n\ + The new interface is slightly different, see: \ + https://mc-stan.org/users/documentation/case-studies/convert_odes.html" + "3.0" false in + let std = + make + "This can be automatically changed using the canonicalize flag for stanc" + "2.32.0" true in + String.Map.of_alist_exn + [ ("multiply_log", std "lmultiply") + ; ("binomial_coefficient_log", std "lchoose") + ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) + ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") + ; ("integrate_ode_bdf", ode "ode_bdf") + ; ("integrate_ode_adams", ode "ode_adams") + ; ("if_else", std "the conditional operator (x ? y : z)") ] + +(** This function is responsible for typechecking varadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) +let check_special_fn = check_reduce_sum + +let is_special_function_name = is_reduce_sum_fn + +let special_function_returntype name = + if is_reduce_sum_fn name then Some (Middle.UnsizedType.ReturnType UReal) + else None diff --git a/src/stan_math_backend/stan_math_library/Stan_math_extras.ml b/src/stan_math_backend/stan_math_library/Stan_math_extras.ml new file mode 100644 index 0000000000..2b51ce3039 --- /dev/null +++ b/src/stan_math_backend/stan_math_library/Stan_math_extras.ml @@ -0,0 +1,182 @@ +(** This module serves as the backend-specific portion + of the typechecker. *) + +open Core_kernel +open Core_kernel.Poly +open Ast +open Typechecker_utils +open Middle +open Std_library_utils +open Stan_math_backend.Stan_math_signatures + +let error e = raise (Errors.SemanticError e) +let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7] + +let rec bare_array_type (t, i) = + match i with 0 -> t | j -> UnsizedType.UArray (bare_array_type (t, j - 1)) + +let reduce_sum_slice_types = + let base_slice_type i = + [ bare_array_type (UnsizedType.UReal, i) + ; bare_array_type (UnsizedType.UInt, i) + ; bare_array_type (UnsizedType.UMatrix, i) + ; bare_array_type (UnsizedType.UVector, i) + ; bare_array_type (UnsizedType.URowVector, i) ] in + List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) + +(* Variadic ODE *) +let variadic_ode_adjoint_ctl_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal) + (* real relative_tolerance_forward *) + ; (DataOnly, UVector) (* vector absolute_tolerance_forward *) + ; (DataOnly, UReal) (* real relative_tolerance_backward *) + ; (DataOnly, UVector) (* real absolute_tolerance_backward *) + ; (DataOnly, UReal) (* real relative_tolerance_quadrature *) + ; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) + ; (DataOnly, UInt) (* int max_num_steps *) + ; (DataOnly, UInt) (* int num_steps_between_checkpoints *) + ; (DataOnly, UInt) (* int interpolation_polynomial *) + ; (DataOnly, UInt) (* int solver_forward *); (DataOnly, UInt) + (* int solver_backward *) ] + +let variadic_ode_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_ode_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (AutoDiffable, UReal) + ; (AutoDiffable, UArray UReal) ] + +let variadic_ode_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_ode_fun_return_type = UnsizedType.UVector +let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector + +let variadic_dae_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_dae_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) + (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) + (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] + +let variadic_dae_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_dae_fun_return_type = UnsizedType.UVector +let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector + +let check_reduce_sum ~is_cond_dist loc current_block tenv id tes = + let basic_mismatch () = + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in + SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args + mandatory_fun_args UReal (get_arg_types tes) in + let fail () = + let expected_args, err = + basic_mismatch () |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error in + let matching remaining_es fn = + match fn with + | Environment. + { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as + ftype + ; _ } + when List.mem reduce_sum_slice_types sliced_arg_fun_type ~equal:( = ) -> + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args + mandatory_fun_args UReal arg_types + | _ -> basic_mismatch () in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + (* a valid signature exists *) + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error ) + | _ -> fail () + +(** The variadic signatures hash table + These functions cannot be overloaded. +*) +let (variadic_signatures : (string, variadic_signature) Hashtbl.t) = + let t = String.Table.create () in + let add_variadic_fn name ~return_type ?control_args ~required_fn_rt + ?required_fn_args () = + Hashtbl.add_exn t ~key:name + ~data: + (create_variadic_signature ~return_type ?control_args ?required_fn_args + ~required_fn_rt () ) in + let () = + (* DAEs *) + add_variadic_fn "dae" ~return_type:variadic_dae_return_type + ~control_args:variadic_dae_mandatory_arg_types + ~required_fn_args:variadic_dae_mandatory_fun_args + ~required_fn_rt:variadic_dae_fun_return_type () ; + add_variadic_fn "dae_tol" ~return_type:variadic_dae_return_type + ~control_args: + (variadic_dae_mandatory_arg_types @ variadic_dae_tol_arg_types) + ~required_fn_args:variadic_dae_mandatory_fun_args + ~required_fn_rt:variadic_dae_fun_return_type () ; + (* non-adjoint ODES - same for all *) + let add_ode name = + add_variadic_fn name ~return_type:variadic_ode_return_type + ~control_args: + ( if String.is_suffix name ~suffix:ode_tolerances_suffix then + variadic_ode_mandatory_arg_types @ variadic_ode_tol_arg_types + else variadic_ode_mandatory_arg_types ) + ~required_fn_rt:variadic_ode_fun_return_type + ~required_fn_args:variadic_ode_mandatory_fun_args () in + Set.iter ~f:add_ode variadic_ode_nonadjoint_fns ; + (* Adjoint ODE function *) + add_variadic_fn variadic_ode_adjoint_fn + ~return_type:variadic_ode_return_type + ~control_args: + ( variadic_ode_mandatory_arg_types + @ variadic_ode_adjoint_ctl_tol_arg_types ) + ~required_fn_rt:variadic_ode_fun_return_type + ~required_fn_args:variadic_ode_mandatory_fun_args () in + t + +let is_variadic_function_name name = Hashtbl.mem variadic_signatures name + +let is_not_overloadable name = + is_variadic_function_name name || is_special_function_name name + +let variadic_dae_fun_return_type = UnsizedType.UVector +let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector diff --git a/src/stan_math_backend/stan_math_library/dune b/src/stan_math_backend/stan_math_library/dune new file mode 100644 index 0000000000..c7aaa7abe0 --- /dev/null +++ b/src/stan_math_backend/stan_math_library/dune @@ -0,0 +1,8 @@ +(library + (name stan_math_library) + (public_name stanc.stan_math_library) + (libraries core_kernel middle stan_math_backend) + (implements frontend) + (private_modules stan_math_extras) + (preprocess + (pps ppx_jane ppx_deriving.fold ppx_deriving.map ppx_deriving.show))) diff --git a/src/stanc/dune b/src/stanc/dune index c3e8ab39e9..3763e7cdf7 100644 --- a/src/stanc/dune +++ b/src/stanc/dune @@ -1,6 +1,11 @@ (executable (name stanc) - (libraries frontend middle stan_math_backend analysis_and_optimization) + (libraries + frontend + middle + stan_math_backend + analysis_and_optimization + stan_math_library) (modules Stanc) (public_name stanc) (preprocess diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 4fa81128d5..19362a57ec 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -3,8 +3,8 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend open Analysis_and_optimization +open Stan_math_backend open Middle (** The main program. *) @@ -191,16 +191,12 @@ let options = ) ; ( "--include-paths" , Arg.String - (fun str -> - Preprocessor.include_paths := String.split_on_chars ~on:[','] str ) + (fun str -> Preprocessor.include_paths := String.split ~on:',' str) , " Takes a comma-separated list of directories that may contain a file \ in an #include directive (default = \"\")" ) ; ( "--include_paths" , Arg.String - (fun str -> - Preprocessor.include_paths := - !Preprocessor.include_paths @ String.split_on_chars ~on:[','] str - ) + (fun str -> Preprocessor.include_paths := String.split ~on:',' str) , " Deprecated. Same as --include-paths. Will be removed in Stan 2.32.0" ) ; ( "--use-opencl" diff --git a/src/stancjs/dune b/src/stancjs/dune index d983673834..e8bafe4d8b 100644 --- a/src/stancjs/dune +++ b/src/stancjs/dune @@ -5,7 +5,8 @@ frontend middle analysis_and_optimization - stan_math_backend) + stan_math_backend + stan_math_library) (preprocess (pps js_of_ocaml-ppx ppx_jane)) (modes js)) diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index 84f0ab6168..c5becf8f86 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -1,8 +1,8 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend open Analysis_and_optimization +open Stan_math_backend open Middle open Js_of_ocaml diff --git a/test/integration/bad/stanc.expected b/test/integration/bad/stanc.expected index 4e7be8d0f8..7f7aaafafa 100644 --- a/test/integration/bad/stanc.expected +++ b/test/integration/bad/stanc.expected @@ -1638,7 +1638,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (real, vector) => vector @@ -1670,7 +1669,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (real, vector) => vector diff --git a/test/integration/good/warning/pretty.expected b/test/integration/good/warning/pretty.expected index f23ca7735b..4a39e03849 100644 --- a/test/integration/good/warning/pretty.expected +++ b/test/integration/good/warning/pretty.expected @@ -10,9 +10,6 @@ model { y ~ normal(mu, 1); } -Warning in 'abs-deprecate.stan', line 3, column 7: fabs is deprecated and - will be removed in Stan 2.33.0. Use abs instead. This can be - automatically changed using the canonicalize flag for stanc $ ../../../../../install/default/bin/stanc --auto-format binomial_coefficient_log.stan data { int d_int; @@ -1750,58 +1747,58 @@ model { y_p ~ normal(0, 1); } -Warning in 'if_else.stan', line 9, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 10, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 11, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 12, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 21, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 22, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 23, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 24, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 26, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 27, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 28, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 29, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 30, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc +Warning in 'if_else.stan', line 9, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 10, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 11, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 12, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 21, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 22, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 23, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 24, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 26, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 27, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 28, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 29, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 30, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc $ ../../../../../install/default/bin/stanc --auto-format increment_log_prob.stan transformed data { int n; diff --git a/test/unit/Debug_data_generation_tests.ml b/test/unit/Debug_data_generation_tests.ml index 6ab94c65a8..f3d97521e7 100644 --- a/test/unit/Debug_data_generation_tests.ml +++ b/test/unit/Debug_data_generation_tests.ml @@ -1,9 +1,4 @@ -open Analysis_and_optimization open Core_kernel -open Frontend -open Debug_data_generation - -let print_data_prog ast = print_data_prog (Ast_to_Mir.gather_data ast) let%expect_test "whole program data generation check" = let ast = @@ -17,7 +12,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -48,7 +43,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -97,7 +92,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -134,7 +129,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -256,7 +251,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -515,7 +510,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -647,7 +642,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -666,7 +661,7 @@ let%expect_test "Complex numbers program" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| diff --git a/test/unit/Test_utils.ml b/test/unit/Test_utils.ml index 5522236e61..13cd4a4435 100644 --- a/test/unit/Test_utils.ml +++ b/test/unit/Test_utils.ml @@ -16,3 +16,7 @@ let typed_ast_of_string_exn s = |> Result.ok_or_failwith |> fst let mir_of_string s = typed_ast_of_string_exn s |> Ast_to_Mir.trans_prog "" + +let print_data_prog ast = + Analysis_and_optimization.Debug_data_generation.print_data_prog + (Ast_to_Mir.gather_data ast) diff --git a/test/unit/dune b/test/unit/dune index 37cffb6770..4fed56676c 100644 --- a/test/unit/dune +++ b/test/unit/dune @@ -4,6 +4,7 @@ core_kernel frontend stan_math_backend + stan_math_library middle analysis_and_optimization) (inline_tests)