Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
open Core_kernel
open Middle
open Frontend
open Ast

let rec transpose = function
Expand Down Expand Up @@ -32,8 +33,8 @@ let rec vect_to_mat l m =
let unwrap_num_exn m e =
let e = Ast_to_Mir.trans_expr e in
let m = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
let e = Analysis_and_optimization.Mir_utils.subst_expr m e in
let e = Analysis_and_optimization.Partial_evaluator.eval_expr e in
let e = Mir_utils.subst_expr m e in
let e = Partial_evaluator.eval_expr e in
let rec strip_promotions (e : Middle.Expr.Typed.t) =
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
in
Expand Down
1 change: 1 addition & 0 deletions src/analysis_and_optimization/Debug_data_generation.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
val print_data_prog : Frontend.Ast.typed_program -> string
26 changes: 25 additions & 1 deletion src/analysis_and_optimization/Mem_pattern.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
open Core_kernel
open Core_kernel.Poly
open Middle
open Middle.Expr

Expand Down Expand Up @@ -111,10 +112,33 @@ let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t)
| None -> false )
| _ -> false )

let query_stan_math_mem_pattern_support (name : string)
(args : (UnsizedType.autodifftype * UnsizedType.t) list) =
let open Stan_math_signatures in
match name with
| x when is_reduce_sum_fn x -> false
| x when is_variadic_ode_fn x -> false
| x when is_variadic_dae_fn x -> false
| _ ->
let name =
string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name)
in
let namematches = Hashtbl.find_multi stan_math_signatures name in
let filteredmatches =
List.filter
~f:(fun x ->
Frontend.SignatureMismatch.check_compatible_arguments_mod_conv
(snd3 x) args
|> Result.is_ok )
namematches in
let is_soa ((_ : UnsizedType.returntype), _, mem) =
mem = Common.Helpers.SoA in
List.exists ~f:is_soa filteredmatches

(*Validate whether a function can support SoA matrices*)
let is_fun_soa_supported name exprs =
let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in
Stan_math_signatures.query_stan_math_mem_pattern_support name fun_args
query_stan_math_mem_pattern_support name fun_args

(**
* Query to find the initial set of objects that cannot be SoA.
Expand Down
14 changes: 7 additions & 7 deletions src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ open Middle

exception Rejected of Location_span.t * string

let is_int i Expr.Fixed.{pattern; _} =
let nums = List.map ~f:(fun s -> string_of_int i ^ s) [""; "."; ".0"] in
let rec is_int query Expr.Fixed.{pattern; _} =
match pattern with
| (Lit (Int, i) | Lit (Real, i)) when List.mem nums i ~equal:String.equal ->
true
| Lit (Int, i) | Lit (Real, i) -> float_of_string i = float_of_int query
| Promotion (e, _, _) -> is_int query e
| _ -> false

let apply_prefix_operator_int (op : string) i =
Expand Down Expand Up @@ -107,10 +106,11 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
Operator.of_string_opt name
|> Option.value_map
~f:(fun op ->
Stan_math_signatures.operator_stan_math_return_type op
argument_types )
Frontend.Typechecker.operator_stan_math_return_type op
argument_types
|> Option.map ~f:fst )
~default:
(Stan_math_signatures.stan_math_returntype name
(Frontend.Typechecker.stan_math_return_type name
argument_types ) in
let try_partially_evaluate_stanlib e =
Expr.Fixed.Pattern.(
Expand Down
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/dune
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
(library
(name analysis_and_optimization)
(public_name stanc.analysis)
(libraries core_kernel str fmt common middle)
(libraries core_kernel str fmt common middle frontend)
Comment thread
nhuurre marked this conversation as resolved.
(inline_tests)
;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation
(modules_without_implementation monotone_framework_sigs)
Expand Down
31 changes: 10 additions & 21 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@ open Core_kernel
open Core_kernel.Poly
open Middle

(* XXX fix exn *)
let unwrap_return_exn = function
| Some (UnsizedType.ReturnType ut) -> ut
| x ->
Common.FatalError.fatal_error_msg
[%message "Unexpected return type " (x : UnsizedType.returntype option)]

let trans_fn_kind kind name =
let fname = Utils.stdlib_distribution_name name in
match kind with
Expand Down Expand Up @@ -36,14 +29,9 @@ let%expect_test "format_number1" =
format_number ".123_456" |> print_endline ;
[%expect ".123456"]

let rec op_to_funapp op args =
let argtypes =
List.map ~f:(fun x -> (x.Ast.emeta.Ast.ad_level, x.emeta.type_)) args in
let type_ =
Stan_math_signatures.operator_stan_math_return_type op argtypes
|> unwrap_return_exn
and loc = Ast.expr_loc_lub args
and adlevel = Ast.expr_ad_lub args in
let rec op_to_funapp op args type_ =
let loc = Ast.expr_loc_lub args in
let adlevel = Ast.expr_ad_lub args in
Expr.
{ Fixed.pattern=
FunApp (StanLib (Operator.to_string op, FnPlain, AoS), trans_exprs args)
Expand All @@ -61,8 +49,8 @@ and trans_expr {Ast.expr; Ast.emeta} =
| Ast.Paren x -> trans_expr x
| BinOp (lhs, And, rhs) -> EAnd (trans_expr lhs, trans_expr rhs) |> ewrap
| BinOp (lhs, Or, rhs) -> EOr (trans_expr lhs, trans_expr rhs) |> ewrap
| BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs]
| PrefixOp (op, e) | Ast.PostfixOp (e, op) -> op_to_funapp op [e]
| BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs] emeta.type_
| PrefixOp (op, e) | Ast.PostfixOp (e, op) -> op_to_funapp op [e] emeta.type_
| Ast.TernaryIf (cond, ifb, elseb) ->
Expr.Fixed.Pattern.TernaryIf
(trans_expr cond, trans_expr ifb, trans_expr elseb)
Expand Down Expand Up @@ -127,12 +115,12 @@ let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t =
{ Stmt.Fixed.meta= smeta
; pattern=
IfElse
( op_to_funapp cond_op [ast_obs; x]
( op_to_funapp cond_op [ast_obs; x] UInt
, {Stmt.Fixed.meta= smeta; pattern= TargetPE neg_inf}
, Some y ) } in
let targetme loc e =
{Stmt.Fixed.meta= loc; pattern= TargetPE (op_to_funapp Operator.PMinus [e])}
in
{ Stmt.Fixed.meta= loc
; pattern= TargetPE (op_to_funapp Operator.PMinus [e] e.emeta.type_) } in
let funapp meta kind name args =
{ Ast.emeta= meta
; expr= Ast.FunApp (kind, {name; id_loc= Location_span.empty}, args) } in
Expand Down Expand Up @@ -418,7 +406,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
let rhs =
match assign_op with
| Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs
| Ast.OperatorAssign op -> op_to_funapp op [assignee; assign_rhs] in
| Ast.OperatorAssign op ->
op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in
Assignment
( ( assign_identifier.Ast.name
, id_type_
Expand Down
3 changes: 2 additions & 1 deletion src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ let rec no_parens {expr; emeta} =

and keep_parens {expr; emeta} =
match expr with
| Paren {expr= Paren e; _} -> keep_parens e
| Promotion (e, ut, ad) -> {expr= Promotion (keep_parens e, ut, ad); emeta}
| Paren ({expr= Paren _; _} as e) -> keep_parens e
| Paren ({expr= BinOp _; _} as e)
|Paren ({expr= PrefixOp _; _} as e)
|Paren ({expr= PostfixOp _; _} as e)
Expand Down
1 change: 0 additions & 1 deletion src/frontend/Debug_data_generation.mli

This file was deleted.

2 changes: 1 addition & 1 deletion src/frontend/Environment.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type info =

type t = info list String.Map.t

let create () =
let stan_math_environment =
let functions =
Hashtbl.to_alist Stan_math_signatures.stan_math_signatures
|> List.map ~f:(fun (key, values) ->
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/Environment.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type info =

type t

val create : unit -> t
(** Return a new type environment which contains the Stan math library functions
val stan_math_environment : t
(** A type environment which contains the Stan math library functions
*)

val find : t -> string -> info list
Expand Down
20 changes: 13 additions & 7 deletions src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,8 @@ let unique_minimum_promotion promotion_options =
| [] -> Error None )
| None -> Error None

let matching_function env name args =
let find_compatible_rt function_types args =
(* NB: Variadic arguments are special-cased in the typechecker and not handled here *)
let name = Utils.stdlib_distribution_name name in
let function_types =
Environment.find env name
|> List.filter_map ~f:extract_function_types
|> List.sort ~compare:(fun (ret1, _, _, _) (ret2, _, _, _) ->
UnsizedType.compare_returntype ret1 ret2 ) in
let matches, errors =
List.partition_map function_types
~f:(fun (rt, tys, funkind_constructor, _) ->
Expand All @@ -225,6 +219,18 @@ let matching_function env name args =
let errors, omitted = List.split_n errors max_n_errors in
SignatureErrors (errors, not (List.is_empty omitted))

let matching_function env name args =
let name = Utils.stdlib_distribution_name name in
let function_types =
Environment.find env name
|> List.filter_map ~f:extract_function_types
|> List.sort ~compare:(fun (ret1, _, _, _) (ret2, _, _, _) ->
UnsizedType.compare_returntype ret1 ret2 ) in
find_compatible_rt function_types args

let matching_stanlib_function =
matching_function Environment.stan_math_environment

let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
fun_return args =
let minimal_func_type =
Expand Down
6 changes: 6 additions & 0 deletions src/frontend/SignatureMismatch.mli
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ val matching_function :
Requires a unique minimum option under type promotion
*)

val matching_stanlib_function :
string -> (UnsizedType.autodifftype * UnsizedType.t) list -> match_result
(** Same as [matching_function] but requires specifically that the function
be from StanMath (uses [Environment.stan_math_environment])
*)

val check_variadic_args :
bool
-> (UnsizedType.autodifftype * UnsizedType.t) list
Expand Down
Loading