Skip to content

Commit 704dc9b

Browse files
committed
Codegen without promotion for operators on scalars
1 parent 689a0de commit 704dc9b

4 files changed

Lines changed: 55 additions & 57 deletions

File tree

src/stan_math_backend/Expression_gen.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ and pp_scalar_binary ppf op fn es =
230230

231231
and gen_operator_app op ppf es =
232232
let remove_basic_promotion (e : 'a Expr.Fixed.t) =
233-
match e.pattern with Promotion (e, UReal, _) -> e | _ -> e in
233+
match e.pattern with Promotion (e, _, _) when is_scalar e -> e | _ -> e
234+
in
234235
let es = List.map ~f:remove_basic_promotion es in
235236
match op with
236237
| Operator.Plus -> pp_scalar_binary ppf "+" "stan::math::add" es

test/integration/good/code-gen/cpp.expected

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2570,124 +2570,123 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
25702570
current_statement__ = 125;
25712571
gq_complex = (z + y);
25722572
current_statement__ = 126;
2573-
gq_complex = (z + stan::math::to_complex(gq_r, 0));
2573+
gq_complex = (z + gq_r);
25742574
current_statement__ = 127;
2575-
gq_complex = (stan::math::to_complex(gq_r, 0) + z);
2575+
gq_complex = (gq_r + z);
25762576
current_statement__ = 128;
2577-
gq_complex = (z + stan::math::to_complex(gq_i, 0));
2577+
gq_complex = (z + gq_i);
25782578
current_statement__ = 129;
2579-
gq_complex = (stan::math::to_complex(gq_i, 0) + z);
2579+
gq_complex = (gq_i + z);
25802580
current_statement__ = 130;
25812581
gq_complex = (d_complex + p_complex);
25822582
current_statement__ = 131;
2583-
gq_complex = (d_complex + stan::math::to_complex(d_r, 0));
2583+
gq_complex = (d_complex + d_r);
25842584
current_statement__ = 132;
2585-
gq_complex = (d_complex + stan::math::to_complex(p_r, 0));
2585+
gq_complex = (d_complex + p_r);
25862586
current_statement__ = 133;
2587-
gq_complex = (stan::math::to_complex(d_r, 0) + p_complex);
2587+
gq_complex = (d_r + p_complex);
25882588
current_statement__ = 134;
2589-
gq_complex = (p_complex + stan::math::to_complex(p_r, 0));
2589+
gq_complex = (p_complex + p_r);
25902590
current_statement__ = 135;
2591-
gq_complex = (d_complex + stan::math::to_complex(gq_i, 0));
2591+
gq_complex = (d_complex + gq_i);
25922592
current_statement__ = 136;
2593-
gq_complex = (stan::math::to_complex(gq_i, 0) + p_complex);
2593+
gq_complex = (gq_i + p_complex);
25942594
current_statement__ = 137;
25952595
gq_complex = (z - y);
25962596
current_statement__ = 138;
2597-
gq_complex = (z - stan::math::to_complex(gq_r, 0));
2597+
gq_complex = (z - gq_r);
25982598
current_statement__ = 139;
2599-
gq_complex = (stan::math::to_complex(gq_r, 0) - z);
2599+
gq_complex = (gq_r - z);
26002600
current_statement__ = 140;
2601-
gq_complex = (z - stan::math::to_complex(gq_i, 0));
2601+
gq_complex = (z - gq_i);
26022602
current_statement__ = 141;
2603-
gq_complex = (stan::math::to_complex(gq_i, 0) - z);
2603+
gq_complex = (gq_i - z);
26042604
current_statement__ = 142;
26052605
gq_complex = (d_complex - p_complex);
26062606
current_statement__ = 143;
2607-
gq_complex = (d_complex - stan::math::to_complex(d_r, 0));
2607+
gq_complex = (d_complex - d_r);
26082608
current_statement__ = 144;
2609-
gq_complex = (d_complex - stan::math::to_complex(p_r, 0));
2609+
gq_complex = (d_complex - p_r);
26102610
current_statement__ = 145;
2611-
gq_complex = (stan::math::to_complex(d_r, 0) - p_complex);
2611+
gq_complex = (d_r - p_complex);
26122612
current_statement__ = 146;
2613-
gq_complex = (p_complex - stan::math::to_complex(p_r, 0));
2613+
gq_complex = (p_complex - p_r);
26142614
current_statement__ = 147;
2615-
gq_complex = (d_complex - stan::math::to_complex(gq_i, 0));
2615+
gq_complex = (d_complex - gq_i);
26162616
current_statement__ = 148;
2617-
gq_complex = (stan::math::to_complex(gq_i, 0) - p_complex);
2617+
gq_complex = (gq_i - p_complex);
26182618
current_statement__ = 149;
26192619
gq_complex = (z * y);
26202620
current_statement__ = 150;
2621-
gq_complex = (z * stan::math::to_complex(gq_r, 0));
2621+
gq_complex = (z * gq_r);
26222622
current_statement__ = 151;
2623-
gq_complex = (stan::math::to_complex(gq_r, 0) * z);
2623+
gq_complex = (gq_r * z);
26242624
current_statement__ = 152;
2625-
gq_complex = (z * stan::math::to_complex(gq_i, 0));
2625+
gq_complex = (z * gq_i);
26262626
current_statement__ = 153;
2627-
gq_complex = (stan::math::to_complex(gq_i, 0) * z);
2627+
gq_complex = (gq_i * z);
26282628
current_statement__ = 154;
26292629
gq_complex = (d_complex * p_complex);
26302630
current_statement__ = 155;
2631-
gq_complex = (d_complex * stan::math::to_complex(d_r, 0));
2631+
gq_complex = (d_complex * d_r);
26322632
current_statement__ = 156;
2633-
gq_complex = (d_complex * stan::math::to_complex(p_r, 0));
2633+
gq_complex = (d_complex * p_r);
26342634
current_statement__ = 157;
2635-
gq_complex = (stan::math::to_complex(d_r, 0) * p_complex);
2635+
gq_complex = (d_r * p_complex);
26362636
current_statement__ = 158;
2637-
gq_complex = (p_complex * stan::math::to_complex(p_r, 0));
2637+
gq_complex = (p_complex * p_r);
26382638
current_statement__ = 159;
2639-
gq_complex = (d_complex * stan::math::to_complex(gq_i, 0));
2639+
gq_complex = (d_complex * gq_i);
26402640
current_statement__ = 160;
2641-
gq_complex = (stan::math::to_complex(gq_i, 0) * p_complex);
2641+
gq_complex = (gq_i * p_complex);
26422642
current_statement__ = 161;
26432643
gq_complex = (z / y);
26442644
current_statement__ = 162;
2645-
gq_complex = (z / stan::math::to_complex(gq_r, 0));
2645+
gq_complex = (z / gq_r);
26462646
current_statement__ = 163;
2647-
gq_complex = (stan::math::to_complex(gq_r, 0) / z);
2647+
gq_complex = (gq_r / z);
26482648
current_statement__ = 164;
2649-
gq_complex = (z / stan::math::to_complex(gq_i, 0));
2649+
gq_complex = (z / gq_i);
26502650
current_statement__ = 165;
2651-
gq_complex = (stan::math::to_complex(gq_i, 0) / z);
2651+
gq_complex = (gq_i / z);
26522652
current_statement__ = 166;
26532653
gq_complex = (d_complex / p_complex);
26542654
current_statement__ = 167;
2655-
gq_complex = (d_complex / stan::math::to_complex(d_r, 0));
2655+
gq_complex = (d_complex / d_r);
26562656
current_statement__ = 168;
2657-
gq_complex = (d_complex / stan::math::to_complex(p_r, 0));
2657+
gq_complex = (d_complex / p_r);
26582658
current_statement__ = 169;
2659-
gq_complex = (stan::math::to_complex(d_r, 0) / p_complex);
2659+
gq_complex = (d_r / p_complex);
26602660
current_statement__ = 170;
2661-
gq_complex = (p_complex / stan::math::to_complex(p_r, 0));
2661+
gq_complex = (p_complex / p_r);
26622662
current_statement__ = 171;
2663-
gq_complex = (d_complex / stan::math::to_complex(gq_i, 0));
2663+
gq_complex = (d_complex / gq_i);
26642664
current_statement__ = 172;
2665-
gq_complex = (stan::math::to_complex(gq_i, 0) / p_complex);
2665+
gq_complex = (gq_i / p_complex);
26662666
current_statement__ = 173;
26672667
gq_complex = stan::math::pow(z, y);
26682668
current_statement__ = 174;
26692669
gq_complex = stan::math::pow(z, gq_r);
26702670
current_statement__ = 175;
2671-
gq_complex = stan::math::pow(stan::math::to_complex(gq_r, 0), z);
2671+
gq_complex = stan::math::pow(gq_r, z);
26722672
current_statement__ = 176;
26732673
gq_complex = stan::math::pow(z, gq_i);
26742674
current_statement__ = 177;
2675-
gq_complex = stan::math::pow(stan::math::to_complex(gq_i, 0), z);
2675+
gq_complex = stan::math::pow(gq_i, z);
26762676
current_statement__ = 178;
26772677
gq_complex = stan::math::pow(d_complex, p_complex);
26782678
current_statement__ = 179;
26792679
gq_complex = stan::math::pow(d_complex, d_r);
26802680
current_statement__ = 180;
26812681
gq_complex = stan::math::pow(d_complex, p_r);
26822682
current_statement__ = 181;
2683-
gq_complex = stan::math::pow(stan::math::to_complex(d_r, 0), p_complex);
2683+
gq_complex = stan::math::pow(d_r, p_complex);
26842684
current_statement__ = 182;
26852685
gq_complex = stan::math::pow(p_complex, p_r);
26862686
current_statement__ = 183;
26872687
gq_complex = stan::math::pow(d_complex, gq_i);
26882688
current_statement__ = 184;
2689-
gq_complex = stan::math::pow(stan::math::to_complex(gq_i, 0),
2690-
p_complex);
2689+
gq_complex = stan::math::pow(gq_i, p_complex);
26912690
current_statement__ = 185;
26922691
gq_complex = -z;
26932692
current_statement__ = 186;
@@ -2725,8 +2724,7 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
27252724
current_statement__ = 202;
27262725
gq_i = stan::math::logical_eq(d_complex, d_r);
27272726
current_statement__ = 203;
2728-
gq_i = stan::math::logical_eq(stan::math::to_complex(p_r, 0),
2729-
d_complex);
2727+
gq_i = stan::math::logical_eq(p_r, d_complex);
27302728
current_statement__ = 204;
27312729
gq_i = stan::math::logical_eq(p_complex, d_r);
27322730
current_statement__ = 205;
@@ -2746,8 +2744,7 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
27462744
current_statement__ = 212;
27472745
gq_i = stan::math::logical_neq(d_complex, d_r);
27482746
current_statement__ = 213;
2749-
gq_i = stan::math::logical_neq(stan::math::to_complex(p_r, 0),
2750-
d_complex);
2747+
gq_i = stan::math::logical_neq(p_r, d_complex);
27512748
current_statement__ = 214;
27522749
gq_i = stan::math::logical_neq(p_complex, d_r);
27532750
current_statement__ = 215;
@@ -3020,7 +3017,7 @@ class complex_scalar_model final : public model_base_crtp<complex_scalar_model>
30203017
std::complex<double>(std::numeric_limits<double>::quiet_NaN(),
30213018
std::numeric_limits<double>::quiet_NaN());
30223019
current_statement__ = 336;
3023-
zi = (stan::math::to_complex(1, 0) + stan::math::to_complex(0, 3.14));
3020+
zi = (1 + stan::math::to_complex(0, 3.14));
30243021
current_statement__ = 337;
30253022
zi = (zi * stan::math::to_complex(0, 0));
30263023
std::complex<double> yi =

test/integration/good/compiler-optimizations/cpp.expected

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2139,8 +2139,8 @@ class ad_level_failing_model final : public model_base_crtp<ad_level_failing_mod
21392139
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
21402140
(void) DUMMY_VAR__; // suppress unused var warning
21412141
try {
2142-
2143-
2142+
2143+
21442144
int pos__;
21452145
pos__ = 1;
21462146
current_statement__ = 16;
@@ -5982,7 +5982,7 @@ class expr_prop_experiment2_model final : public model_base_crtp<expr_prop_exper
59825982
(void) DUMMY_VAR__; // suppress unused var warning
59835983
try {
59845984

5985-
5985+
59865986
int pos__;
59875987
pos__ = 1;
59885988
current_statement__ = 1;
@@ -20763,7 +20763,7 @@ class lcm_experiment_model final : public model_base_crtp<lcm_experiment_model>
2076320763
(void) DUMMY_VAR__; // suppress unused var warning
2076420764
try {
2076520765

20766-
20766+
2076720767
int pos__;
2076820768
pos__ = 1;
2076920769
current_statement__ = 1;
@@ -25588,7 +25588,7 @@ using namespace stan::math;
2558825588

2558925589

2559025590
stan::math::profile_map profiles__;
25591-
static constexpr std::array<const char*, 91> locations_array__ =
25591+
static constexpr std::array<const char*, 91> locations_array__ =
2559225592
{" (found before start of program)",
2559325593
" (in 'optimizations.stan', line 20, column 4 to column 15)",
2559425594
" (in 'optimizations.stan', line 21, column 4 to column 13)",

test/integration/good/compiler-optimizations/cppO1.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16669,7 +16669,7 @@ using namespace stan::math;
1666916669

1667016670

1667116671
stan::math::profile_map profiles__;
16672-
static constexpr std::array<const char*, 109> locations_array__ =
16672+
static constexpr std::array<const char*, 109> locations_array__ =
1667316673
{" (found before start of program)",
1667416674
" (in 'optimizations.stan', line 20, column 4 to column 15)",
1667516675
" (in 'optimizations.stan', line 21, column 4 to column 13)",

0 commit comments

Comments
 (0)