Skip to content

Commit 29dda9c

Browse files
authored
Merge pull request #1115 from WardBrian/dedupe-unsized-signaturemismatch
Deduplicate logic in UnsizedType and Signaturemismatch
2 parents ba71588 + 2302027 commit 29dda9c

31 files changed

Lines changed: 383 additions & 303 deletions

src/frontend/Debug_data_generation.ml renamed to src/analysis_and_optimization/Debug_data_generation.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
open Core_kernel
22
open Middle
3+
open Frontend
34
open Ast
45

56
let rec transpose = function
@@ -32,8 +33,8 @@ let rec vect_to_mat l m =
3233
let unwrap_num_exn m e =
3334
let e = Ast_to_Mir.trans_expr e in
3435
let m = Map.Poly.map m ~f:Ast_to_Mir.trans_expr in
35-
let e = Analysis_and_optimization.Mir_utils.subst_expr m e in
36-
let e = Analysis_and_optimization.Partial_evaluator.eval_expr e in
36+
let e = Mir_utils.subst_expr m e in
37+
let e = Partial_evaluator.eval_expr e in
3738
let rec strip_promotions (e : Middle.Expr.Typed.t) =
3839
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
3940
in
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
val print_data_prog : Frontend.Ast.typed_program -> string

src/analysis_and_optimization/Mem_pattern.ml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
open Core_kernel
2+
open Core_kernel.Poly
23
open Middle
34
open Middle.Expr
45

@@ -111,10 +112,33 @@ let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t)
111112
| None -> false )
112113
| _ -> false )
113114

115+
let query_stan_math_mem_pattern_support (name : string)
116+
(args : (UnsizedType.autodifftype * UnsizedType.t) list) =
117+
let open Stan_math_signatures in
118+
match name with
119+
| x when is_reduce_sum_fn x -> false
120+
| x when is_variadic_ode_fn x -> false
121+
| x when is_variadic_dae_fn x -> false
122+
| _ ->
123+
let name =
124+
string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name)
125+
in
126+
let namematches = Hashtbl.find_multi stan_math_signatures name in
127+
let filteredmatches =
128+
List.filter
129+
~f:(fun x ->
130+
Frontend.SignatureMismatch.check_compatible_arguments_mod_conv
131+
(snd3 x) args
132+
|> Result.is_ok )
133+
namematches in
134+
let is_soa ((_ : UnsizedType.returntype), _, mem) =
135+
mem = Common.Helpers.SoA in
136+
List.exists ~f:is_soa filteredmatches
137+
114138
(*Validate whether a function can support SoA matrices*)
115139
let is_fun_soa_supported name exprs =
116140
let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in
117-
Stan_math_signatures.query_stan_math_mem_pattern_support name fun_args
141+
query_stan_math_mem_pattern_support name fun_args
118142

119143
(**
120144
* Query to find the initial set of objects that cannot be SoA.

src/analysis_and_optimization/Partial_evaluator.ml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ open Middle
66

77
exception Rejected of Location_span.t * string
88

9-
let is_int i Expr.Fixed.{pattern; _} =
10-
let nums = List.map ~f:(fun s -> string_of_int i ^ s) [""; "."; ".0"] in
9+
let rec is_int query Expr.Fixed.{pattern; _} =
1110
match pattern with
12-
| (Lit (Int, i) | Lit (Real, i)) when List.mem nums i ~equal:String.equal ->
13-
true
11+
| Lit (Int, i) | Lit (Real, i) -> float_of_string i = float_of_int query
12+
| Promotion (e, _, _) -> is_int query e
1413
| _ -> false
1514

1615
let apply_prefix_operator_int (op : string) i =
@@ -107,10 +106,11 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
107106
Operator.of_string_opt name
108107
|> Option.value_map
109108
~f:(fun op ->
110-
Stan_math_signatures.operator_stan_math_return_type op
111-
argument_types )
109+
Frontend.Typechecker.operator_stan_math_return_type op
110+
argument_types
111+
|> Option.map ~f:fst )
112112
~default:
113-
(Stan_math_signatures.stan_math_returntype name
113+
(Frontend.Typechecker.stan_math_return_type name
114114
argument_types ) in
115115
let try_partially_evaluate_stanlib e =
116116
Expr.Fixed.Pattern.(

src/analysis_and_optimization/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
(library
22
(name analysis_and_optimization)
33
(public_name stanc.analysis)
4-
(libraries core_kernel str fmt common middle)
4+
(libraries core_kernel str fmt common middle frontend)
55
(inline_tests)
66
;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation
77
(modules_without_implementation monotone_framework_sigs)

src/frontend/Ast_to_Mir.ml

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,6 @@ open Core_kernel
22
open Core_kernel.Poly
33
open Middle
44

5-
(* XXX fix exn *)
6-
let unwrap_return_exn = function
7-
| Some (UnsizedType.ReturnType ut) -> ut
8-
| x ->
9-
Common.FatalError.fatal_error_msg
10-
[%message "Unexpected return type " (x : UnsizedType.returntype option)]
11-
125
let trans_fn_kind kind name =
136
let fname = Utils.stdlib_distribution_name name in
147
match kind with
@@ -36,14 +29,9 @@ let%expect_test "format_number1" =
3629
format_number ".123_456" |> print_endline ;
3730
[%expect ".123456"]
3831

39-
let rec op_to_funapp op args =
40-
let argtypes =
41-
List.map ~f:(fun x -> (x.Ast.emeta.Ast.ad_level, x.emeta.type_)) args in
42-
let type_ =
43-
Stan_math_signatures.operator_stan_math_return_type op argtypes
44-
|> unwrap_return_exn
45-
and loc = Ast.expr_loc_lub args
46-
and adlevel = Ast.expr_ad_lub args in
32+
let rec op_to_funapp op args type_ =
33+
let loc = Ast.expr_loc_lub args in
34+
let adlevel = Ast.expr_ad_lub args in
4735
Expr.
4836
{ Fixed.pattern=
4937
FunApp (StanLib (Operator.to_string op, FnPlain, AoS), trans_exprs args)
@@ -61,8 +49,8 @@ and trans_expr {Ast.expr; Ast.emeta} =
6149
| Ast.Paren x -> trans_expr x
6250
| BinOp (lhs, And, rhs) -> EAnd (trans_expr lhs, trans_expr rhs) |> ewrap
6351
| BinOp (lhs, Or, rhs) -> EOr (trans_expr lhs, trans_expr rhs) |> ewrap
64-
| BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs]
65-
| PrefixOp (op, e) | Ast.PostfixOp (e, op) -> op_to_funapp op [e]
52+
| BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs] emeta.type_
53+
| PrefixOp (op, e) | Ast.PostfixOp (e, op) -> op_to_funapp op [e] emeta.type_
6654
| Ast.TernaryIf (cond, ifb, elseb) ->
6755
Expr.Fixed.Pattern.TernaryIf
6856
(trans_expr cond, trans_expr ifb, trans_expr elseb)
@@ -127,12 +115,12 @@ let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t =
127115
{ Stmt.Fixed.meta= smeta
128116
; pattern=
129117
IfElse
130-
( op_to_funapp cond_op [ast_obs; x]
118+
( op_to_funapp cond_op [ast_obs; x] UInt
131119
, {Stmt.Fixed.meta= smeta; pattern= TargetPE neg_inf}
132120
, Some y ) } in
133121
let targetme loc e =
134-
{Stmt.Fixed.meta= loc; pattern= TargetPE (op_to_funapp Operator.PMinus [e])}
135-
in
122+
{ Stmt.Fixed.meta= loc
123+
; pattern= TargetPE (op_to_funapp Operator.PMinus [e] e.emeta.type_) } in
136124
let funapp meta kind name args =
137125
{ Ast.emeta= meta
138126
; expr= Ast.FunApp (kind, {name; id_loc= Location_span.empty}, args) } in
@@ -418,7 +406,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
418406
let rhs =
419407
match assign_op with
420408
| Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs
421-
| Ast.OperatorAssign op -> op_to_funapp op [assignee; assign_rhs] in
409+
| Ast.OperatorAssign op ->
410+
op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in
422411
Assignment
423412
( ( assign_identifier.Ast.name
424413
, id_type_

src/frontend/Canonicalize.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ let rec no_parens {expr; emeta} =
163163

164164
and keep_parens {expr; emeta} =
165165
match expr with
166-
| Paren {expr= Paren e; _} -> keep_parens e
166+
| Promotion (e, ut, ad) -> {expr= Promotion (keep_parens e, ut, ad); emeta}
167+
| Paren ({expr= Paren _; _} as e) -> keep_parens e
167168
| Paren ({expr= BinOp _; _} as e)
168169
|Paren ({expr= PrefixOp _; _} as e)
169170
|Paren ({expr= PostfixOp _; _} as e)

src/frontend/Debug_data_generation.mli

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/frontend/Environment.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ type info =
2727

2828
type t = info list String.Map.t
2929

30-
let create () =
30+
let stan_math_environment =
3131
let functions =
3232
Hashtbl.to_alist Stan_math_signatures.stan_math_signatures
3333
|> List.map ~f:(fun (key, values) ->

src/frontend/Environment.mli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ type info =
2929

3030
type t
3131

32-
val create : unit -> t
33-
(** Return a new type environment which contains the Stan math library functions
32+
val stan_math_environment : t
33+
(** A type environment which contains the Stan math library functions
3434
*)
3535

3636
val find : t -> string -> info list

0 commit comments

Comments
 (0)