Skip to content

Commit 872dca4

Browse files
committed
Wire binops up to promotion
1 parent 4dcff5b commit 872dca4

12 files changed

Lines changed: 253 additions & 157 deletions

File tree

src/analysis_and_optimization/Partial_evaluator.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ open Middle
66

77
exception Rejected of Location_span.t * string
88

9-
let is_int i Expr.Fixed.{pattern; _} =
9+
let rec is_int i Expr.Fixed.{pattern; _} =
1010
let nums = List.map ~f:(fun s -> string_of_int i ^ s) [""; "."; ".0"] in
1111
match pattern with
1212
| (Lit (Int, i) | Lit (Real, i)) when List.mem nums i ~equal:String.equal ->
1313
true
14+
| Promotion (e, _, _) -> is_int i e
1415
| _ -> false
1516

1617
let apply_prefix_operator_int (op : string) i =
@@ -108,7 +109,8 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
108109
|> Option.value_map
109110
~f:(fun op ->
110111
Frontend.Typechecker.operator_stan_math_return_type op
111-
argument_types )
112+
argument_types
113+
|> Option.map ~f:fst )
112114
~default:
113115
(Frontend.Typechecker.stan_math_return_type name
114116
argument_types ) in

src/frontend/Canonicalize.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ let rec no_parens {expr; emeta} =
146146
| Variable _ | IntNumeral _ | RealNumeral _ | ImagNumeral _ | GetLP
147147
|GetTarget ->
148148
{expr; emeta}
149-
| TernaryIf _ | BinOp _ | PrefixOp _ | PostfixOp _ ->
149+
| TernaryIf _ | BinOp _ | PrefixOp _ | PostfixOp _ | Promotion _ ->
150150
{expr= map_expression keep_parens ident expr; emeta}
151151
| Indexed (e, l) ->
152152
{ expr=
@@ -158,7 +158,7 @@ let rec no_parens {expr; emeta} =
158158
| i -> map_index keep_parens i )
159159
l )
160160
; emeta }
161-
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ ->
161+
| ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ ->
162162
{expr= map_expression no_parens ident expr; emeta}
163163

164164
and keep_parens {expr; emeta} =

src/frontend/Typechecker.ml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,15 @@ let stan_math_return_type name arg_tys =
205205
let operator_stan_math_return_type op arg_tys =
206206
match (op, arg_tys) with
207207
| Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] ->
208-
Some UnsizedType.(ReturnType UInt)
208+
Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion])
209209
| IntDivide, _ -> None
210210
| _ ->
211211
Stan_math_signatures.operator_to_stan_math_fns op
212212
|> List.filter_map ~f:(fun name ->
213213
SignatureMismatch.matching_stanlib_function name arg_tys
214-
|> match_to_rt_option )
214+
|> function
215+
| SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p)
216+
| _ -> None )
215217
|> List.hd
216218

217219
let assignmentoperator_stan_math_return_type assop arg_tys =
@@ -220,7 +222,7 @@ let assignmentoperator_stan_math_return_type assop arg_tys =
220222
SignatureMismatch.matching_stanlib_function "divide" arg_tys
221223
|> match_to_rt_option
222224
| Plus | Minus | Times | EltTimes | EltDivide ->
223-
operator_stan_math_return_type assop arg_tys
225+
operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst
224226
| _ -> None )
225227
|> Option.bind ~f:(function
226228
| ReturnType rtype
@@ -234,9 +236,9 @@ let assignmentoperator_stan_math_return_type assop arg_tys =
234236
let check_binop loc op le re =
235237
let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in
236238
match rt with
237-
| Some (ReturnType type_) ->
239+
| Some (ReturnType type_, [p1; p2]) ->
238240
mk_typed_expression
239-
~expr:(BinOp (le, op, re))
241+
~expr:(BinOp (Promotion.promote le p1, op, Promotion.promote re p2))
240242
~ad_level:(expr_ad_lub [le; re])
241243
~type_ ~loc
242244
| _ ->
@@ -246,7 +248,7 @@ let check_binop loc op le re =
246248
let check_prefixop loc op te =
247249
let rt = operator_stan_math_return_type op [arg_type te] in
248250
match rt with
249-
| Some (ReturnType type_) ->
251+
| Some (ReturnType type_, _) ->
250252
mk_typed_expression
251253
~expr:(PrefixOp (op, te))
252254
~ad_level:(expr_ad_lub [te])
@@ -256,7 +258,7 @@ let check_prefixop loc op te =
256258
let check_postfixop loc op te =
257259
let rt = operator_stan_math_return_type op [arg_type te] in
258260
match rt with
259-
| Some (ReturnType type_) ->
261+
| Some (ReturnType type_, _) ->
260262
mk_typed_expression
261263
~expr:(PostfixOp (te, op))
262264
~ad_level:(expr_ad_lub [te])

src/frontend/Typechecker.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ val check_program :
3232
val operator_stan_math_return_type :
3333
Middle.Operator.t
3434
-> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list
35-
-> Middle.UnsizedType.returntype option
35+
-> (Middle.UnsizedType.returntype * Promotion.t list) option
3636

3737
val stan_math_return_type :
3838
string

src/stan_math_backend/Expression_gen.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ and pp_scalar_binary ppf op fn es =
229229
else pp_binary_f ppf fn es
230230

231231
and gen_operator_app op ppf es =
232+
let remove_basic_promotion (e : 'a Expr.Fixed.t) =
233+
match e.pattern with Promotion (e, UReal, _) -> e | _ -> e in
234+
let es = List.map ~f:remove_basic_promotion es in
232235
match op with
233236
| Operator.Plus -> pp_scalar_binary ppf "+" "stan::math::add" es
234237
| PMinus ->

src/stan_math_backend/Statement_gen.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ let pp_bool_expr ppf expr =
305305
| _ -> pp_expr ppf expr
306306

307307
let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.{pattern; meta} =
308-
(* ({stmt; smeta} : (mtype_loc_ad, 'a) stmt_with) = *)
309308
let remove_promotions (e : 'a Expr.Fixed.t) =
310309
(* assignment handles one level of promotion internally, don't do it twice *)
311310
match e.pattern with Promotion (e, _, _) -> e | _ -> e in

test/integration/good/code-gen/cpp.expected

Lines changed: 50 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2570,123 +2570,124 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
25702570
current_statement__ = 125;
25712571
gq_complex = (z + y);
25722572
current_statement__ = 126;
2573-
gq_complex = (z + gq_r);
2573+
gq_complex = (z + stan::math::to_complex(gq_r, 0));
25742574
current_statement__ = 127;
2575-
gq_complex = (gq_r + z);
2575+
gq_complex = (stan::math::to_complex(gq_r, 0) + z);
25762576
current_statement__ = 128;
2577-
gq_complex = (z + gq_i);
2577+
gq_complex = (z + stan::math::to_complex(gq_i, 0));
25782578
current_statement__ = 129;
2579-
gq_complex = (gq_i + z);
2579+
gq_complex = (stan::math::to_complex(gq_i, 0) + z);
25802580
current_statement__ = 130;
25812581
gq_complex = (d_complex + p_complex);
25822582
current_statement__ = 131;
2583-
gq_complex = (d_complex + d_r);
2583+
gq_complex = (d_complex + stan::math::to_complex(d_r, 0));
25842584
current_statement__ = 132;
2585-
gq_complex = (d_complex + p_r);
2585+
gq_complex = (d_complex + stan::math::to_complex(p_r, 0));
25862586
current_statement__ = 133;
2587-
gq_complex = (d_r + p_complex);
2587+
gq_complex = (stan::math::to_complex(d_r, 0) + p_complex);
25882588
current_statement__ = 134;
2589-
gq_complex = (p_complex + p_r);
2589+
gq_complex = (p_complex + stan::math::to_complex(p_r, 0));
25902590
current_statement__ = 135;
2591-
gq_complex = (d_complex + gq_i);
2591+
gq_complex = (d_complex + stan::math::to_complex(gq_i, 0));
25922592
current_statement__ = 136;
2593-
gq_complex = (gq_i + p_complex);
2593+
gq_complex = (stan::math::to_complex(gq_i, 0) + p_complex);
25942594
current_statement__ = 137;
25952595
gq_complex = (z - y);
25962596
current_statement__ = 138;
2597-
gq_complex = (z - gq_r);
2597+
gq_complex = (z - stan::math::to_complex(gq_r, 0));
25982598
current_statement__ = 139;
2599-
gq_complex = (gq_r - z);
2599+
gq_complex = (stan::math::to_complex(gq_r, 0) - z);
26002600
current_statement__ = 140;
2601-
gq_complex = (z - gq_i);
2601+
gq_complex = (z - stan::math::to_complex(gq_i, 0));
26022602
current_statement__ = 141;
2603-
gq_complex = (gq_i - z);
2603+
gq_complex = (stan::math::to_complex(gq_i, 0) - z);
26042604
current_statement__ = 142;
26052605
gq_complex = (d_complex - p_complex);
26062606
current_statement__ = 143;
2607-
gq_complex = (d_complex - d_r);
2607+
gq_complex = (d_complex - stan::math::to_complex(d_r, 0));
26082608
current_statement__ = 144;
2609-
gq_complex = (d_complex - p_r);
2609+
gq_complex = (d_complex - stan::math::to_complex(p_r, 0));
26102610
current_statement__ = 145;
2611-
gq_complex = (d_r - p_complex);
2611+
gq_complex = (stan::math::to_complex(d_r, 0) - p_complex);
26122612
current_statement__ = 146;
2613-
gq_complex = (p_complex - p_r);
2613+
gq_complex = (p_complex - stan::math::to_complex(p_r, 0));
26142614
current_statement__ = 147;
2615-
gq_complex = (d_complex - gq_i);
2615+
gq_complex = (d_complex - stan::math::to_complex(gq_i, 0));
26162616
current_statement__ = 148;
2617-
gq_complex = (gq_i - p_complex);
2617+
gq_complex = (stan::math::to_complex(gq_i, 0) - p_complex);
26182618
current_statement__ = 149;
26192619
gq_complex = (z * y);
26202620
current_statement__ = 150;
2621-
gq_complex = (z * gq_r);
2621+
gq_complex = (z * stan::math::to_complex(gq_r, 0));
26222622
current_statement__ = 151;
2623-
gq_complex = (gq_r * z);
2623+
gq_complex = (stan::math::to_complex(gq_r, 0) * z);
26242624
current_statement__ = 152;
2625-
gq_complex = (z * gq_i);
2625+
gq_complex = (z * stan::math::to_complex(gq_i, 0));
26262626
current_statement__ = 153;
2627-
gq_complex = (gq_i * z);
2627+
gq_complex = (stan::math::to_complex(gq_i, 0) * z);
26282628
current_statement__ = 154;
26292629
gq_complex = (d_complex * p_complex);
26302630
current_statement__ = 155;
2631-
gq_complex = (d_complex * d_r);
2631+
gq_complex = (d_complex * stan::math::to_complex(d_r, 0));
26322632
current_statement__ = 156;
2633-
gq_complex = (d_complex * p_r);
2633+
gq_complex = (d_complex * stan::math::to_complex(p_r, 0));
26342634
current_statement__ = 157;
2635-
gq_complex = (d_r * p_complex);
2635+
gq_complex = (stan::math::to_complex(d_r, 0) * p_complex);
26362636
current_statement__ = 158;
2637-
gq_complex = (p_complex * p_r);
2637+
gq_complex = (p_complex * stan::math::to_complex(p_r, 0));
26382638
current_statement__ = 159;
2639-
gq_complex = (d_complex * gq_i);
2639+
gq_complex = (d_complex * stan::math::to_complex(gq_i, 0));
26402640
current_statement__ = 160;
2641-
gq_complex = (gq_i * p_complex);
2641+
gq_complex = (stan::math::to_complex(gq_i, 0) * p_complex);
26422642
current_statement__ = 161;
26432643
gq_complex = (z / y);
26442644
current_statement__ = 162;
2645-
gq_complex = (z / gq_r);
2645+
gq_complex = (z / stan::math::to_complex(gq_r, 0));
26462646
current_statement__ = 163;
2647-
gq_complex = (gq_r / z);
2647+
gq_complex = (stan::math::to_complex(gq_r, 0) / z);
26482648
current_statement__ = 164;
2649-
gq_complex = (z / gq_i);
2649+
gq_complex = (z / stan::math::to_complex(gq_i, 0));
26502650
current_statement__ = 165;
2651-
gq_complex = (gq_i / z);
2651+
gq_complex = (stan::math::to_complex(gq_i, 0) / z);
26522652
current_statement__ = 166;
26532653
gq_complex = (d_complex / p_complex);
26542654
current_statement__ = 167;
2655-
gq_complex = (d_complex / d_r);
2655+
gq_complex = (d_complex / stan::math::to_complex(d_r, 0));
26562656
current_statement__ = 168;
2657-
gq_complex = (d_complex / p_r);
2657+
gq_complex = (d_complex / stan::math::to_complex(p_r, 0));
26582658
current_statement__ = 169;
2659-
gq_complex = (d_r / p_complex);
2659+
gq_complex = (stan::math::to_complex(d_r, 0) / p_complex);
26602660
current_statement__ = 170;
2661-
gq_complex = (p_complex / p_r);
2661+
gq_complex = (p_complex / stan::math::to_complex(p_r, 0));
26622662
current_statement__ = 171;
2663-
gq_complex = (d_complex / gq_i);
2663+
gq_complex = (d_complex / stan::math::to_complex(gq_i, 0));
26642664
current_statement__ = 172;
2665-
gq_complex = (gq_i / p_complex);
2665+
gq_complex = (stan::math::to_complex(gq_i, 0) / p_complex);
26662666
current_statement__ = 173;
26672667
gq_complex = stan::math::pow(z, y);
26682668
current_statement__ = 174;
26692669
gq_complex = stan::math::pow(z, gq_r);
26702670
current_statement__ = 175;
2671-
gq_complex = stan::math::pow(gq_r, z);
2671+
gq_complex = stan::math::pow(stan::math::to_complex(gq_r, 0), z);
26722672
current_statement__ = 176;
26732673
gq_complex = stan::math::pow(z, gq_i);
26742674
current_statement__ = 177;
2675-
gq_complex = stan::math::pow(gq_i, z);
2675+
gq_complex = stan::math::pow(stan::math::to_complex(gq_i, 0), z);
26762676
current_statement__ = 178;
26772677
gq_complex = stan::math::pow(d_complex, p_complex);
26782678
current_statement__ = 179;
26792679
gq_complex = stan::math::pow(d_complex, d_r);
26802680
current_statement__ = 180;
26812681
gq_complex = stan::math::pow(d_complex, p_r);
26822682
current_statement__ = 181;
2683-
gq_complex = stan::math::pow(d_r, p_complex);
2683+
gq_complex = stan::math::pow(stan::math::to_complex(d_r, 0), p_complex);
26842684
current_statement__ = 182;
26852685
gq_complex = stan::math::pow(p_complex, p_r);
26862686
current_statement__ = 183;
26872687
gq_complex = stan::math::pow(d_complex, gq_i);
26882688
current_statement__ = 184;
2689-
gq_complex = stan::math::pow(gq_i, p_complex);
2689+
gq_complex = stan::math::pow(stan::math::to_complex(gq_i, 0),
2690+
p_complex);
26902691
current_statement__ = 185;
26912692
gq_complex = -z;
26922693
current_statement__ = 186;
@@ -2724,7 +2725,8 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
27242725
current_statement__ = 202;
27252726
gq_i = stan::math::logical_eq(d_complex, d_r);
27262727
current_statement__ = 203;
2727-
gq_i = stan::math::logical_eq(p_r, d_complex);
2728+
gq_i = stan::math::logical_eq(stan::math::to_complex(p_r, 0),
2729+
d_complex);
27282730
current_statement__ = 204;
27292731
gq_i = stan::math::logical_eq(p_complex, d_r);
27302732
current_statement__ = 205;
@@ -2744,7 +2746,8 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
27442746
current_statement__ = 212;
27452747
gq_i = stan::math::logical_neq(d_complex, d_r);
27462748
current_statement__ = 213;
2747-
gq_i = stan::math::logical_neq(p_r, d_complex);
2749+
gq_i = stan::math::logical_neq(stan::math::to_complex(p_r, 0),
2750+
d_complex);
27482751
current_statement__ = 214;
27492752
gq_i = stan::math::logical_neq(p_complex, d_r);
27502753
current_statement__ = 215;
@@ -3017,7 +3020,7 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
30173020
std::complex<double>(std::numeric_limits<double>::quiet_NaN(),
30183021
std::numeric_limits<double>::quiet_NaN());
30193022
current_statement__ = 336;
3020-
zi = (1 + stan::math::to_complex(0, 3.14));
3023+
zi = (stan::math::to_complex(1, 0) + stan::math::to_complex(0, 3.14));
30213024
current_statement__ = 337;
30223025
zi = (zi * stan::math::to_complex(0, 0));
30233026
std::complex<double> yi =

0 commit comments

Comments
 (0)