Skip to content

Commit 51c0851

Browse files
authored
Merge pull request #1261 from nhuurre/bugfix/reject-array
Fix reject() codegen for container types
2 parents cf8fd0a + 5de6039 commit 51c0851

6 files changed

Lines changed: 67 additions & 34 deletions

File tree

src/stan_math_backend/Statement_gen.ml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ let rec pp_statement (ppf : Format.formatter) Stmt.Fixed.{pattern; meta} =
363363
pf ppf "if (pstream__) %a" pp_block (list ~sep:cut pp_arg, args)
364364
| NRFunApp (CompilerInternal FnReject, args) ->
365365
let err_strm = "errmsg_stream__" in
366-
let add_to_string ppf e = pf ppf "%s << %a;" err_strm pp_expr e in
366+
let add_to_string ppf e =
367+
pf ppf "stan::math::stan_print(&%s, %a);" err_strm pp_expr e in
367368
pf ppf "std::stringstream %s;@," err_strm ;
368369
pf ppf "%a@," (list ~sep:cut add_to_string) args ;
369370
pf ppf "throw std::domain_error(%s.str());" err_strm

test/integration/good/code-gen/complex_numbers/cpp.expected

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8857,7 +8857,7 @@ template <typename T0__,
88578857
try {
88588858
current_statement__ = 17;
88598859
std::stringstream errmsg_stream__;
8860-
errmsg_stream__ << "called the wrong foo";
8860+
stan::math::stan_print(&errmsg_stream__, "called the wrong foo");
88618861
throw std::domain_error(errmsg_stream__.str());
88628862
} catch (const std::exception& e) {
88638863
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
@@ -8897,7 +8897,7 @@ template <typename T0__,
88978897
try {
88988898
current_statement__ = 21;
88998899
std::stringstream errmsg_stream__;
8900-
errmsg_stream__ << "called the wrong foo";
8900+
stan::math::stan_print(&errmsg_stream__, "called the wrong foo");
89018901
throw std::domain_error(errmsg_stream__.str());
89028902
} catch (const std::exception& e) {
89038903
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
@@ -8937,7 +8937,7 @@ template <typename T0__,
89378937
try {
89388938
current_statement__ = 25;
89398939
std::stringstream errmsg_stream__;
8940-
errmsg_stream__ << "called the wrong foo";
8940+
stan::math::stan_print(&errmsg_stream__, "called the wrong foo");
89418941
throw std::domain_error(errmsg_stream__.str());
89428942
} catch (const std::exception& e) {
89438943
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
@@ -8975,7 +8975,7 @@ template <typename T0__,
89758975
try {
89768976
current_statement__ = 29;
89778977
std::stringstream errmsg_stream__;
8978-
errmsg_stream__ << "called the wrong foo";
8978+
stan::math::stan_print(&errmsg_stream__, "called the wrong foo");
89798979
throw std::domain_error(errmsg_stream__.str());
89808980
} catch (const std::exception& e) {
89818981
stan::lang::rethrow_located(e, locations_array__[current_statement__]);

test/integration/good/code-gen/cpp.expected

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3518,8 +3518,8 @@ static constexpr std::array<const char*, 787> locations_array__ =
35183518
" (in 'mother.stan', line 182, column 36 to line 184, column 3)",
35193519
" (in 'mother.stan', line 187, column 4 to column 24)",
35203520
" (in 'mother.stan', line 186, column 22 to line 188, column 3)",
3521-
" (in 'mother.stan', line 191, column 4 to column 42)",
3522-
" (in 'mother.stan', line 190, column 21 to line 192, column 3)",
3521+
" (in 'mother.stan', line 191, column 4 to column 54)",
3522+
" (in 'mother.stan', line 190, column 29 to line 192, column 3)",
35233523
" (in 'mother.stan', line 195, column 4 to column 18)",
35243524
" (in 'mother.stan', line 196, column 4 to column 19)",
35253525
" (in 'mother.stan', line 197, column 4 to column 26)",
@@ -3749,7 +3749,7 @@ struct foo_4_functor__ {
37493749
template <typename T0__,
37503750
stan::require_all_t<stan::is_stan_scalar<T0__>>* = nullptr>
37513751
void
3752-
operator()(const T0__& x, std::ostream* pstream__) const;
3752+
operator()(const std::vector<T0__>& x, std::ostream* pstream__) const;
37533753
};
37543754
struct f7_functor__ {
37553755
template <typename T3__, typename T4__, typename T5__, typename T6__,
@@ -4643,7 +4643,7 @@ template <bool propto__, typename T0__, typename T_lp__,
46434643
}
46444644
template <typename T0__,
46454645
stan::require_all_t<stan::is_stan_scalar<T0__>>* = nullptr> void
4646-
foo_4(const T0__& x, std::ostream* pstream__) {
4646+
foo_4(const std::vector<T0__>& x, std::ostream* pstream__) {
46474647
using local_scalar_t__ = stan::promote_args_t<T0__>;
46484648
int current_statement__ = 0;
46494649
static constexpr bool propto__ = true;
@@ -4653,8 +4653,11 @@ template <typename T0__,
46534653
try {
46544654
current_statement__ = 710;
46554655
std::stringstream errmsg_stream__;
4656-
errmsg_stream__ << "user-specified rejection";
4657-
errmsg_stream__ << x;
4656+
stan::math::stan_print(&errmsg_stream__, "user-specified rejection");
4657+
stan::math::stan_print(&errmsg_stream__, stan::model::rvalue(x, "x",
4658+
stan::model::index_uni(1)));
4659+
stan::math::stan_print(&errmsg_stream__, "; ");
4660+
stan::math::stan_print(&errmsg_stream__, x);
46584661
throw std::domain_error(errmsg_stream__.str());
46594662
} catch (const std::exception& e) {
46604663
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
@@ -4685,24 +4688,24 @@ template <typename T0__, typename T1__, typename T2__, typename T3__,
46854688
if (stan::math::logical_gt((abs_diff / avg_scale), max_)) {
46864689
current_statement__ = 716;
46874690
std::stringstream errmsg_stream__;
4688-
errmsg_stream__ << "user-specified rejection, difference above ";
4689-
errmsg_stream__ << max_;
4690-
errmsg_stream__ << " x:";
4691-
errmsg_stream__ << x;
4692-
errmsg_stream__ << " y:";
4693-
errmsg_stream__ << y;
4691+
stan::math::stan_print(&errmsg_stream__, "user-specified rejection, difference above ");
4692+
stan::math::stan_print(&errmsg_stream__, max_);
4693+
stan::math::stan_print(&errmsg_stream__, " x:");
4694+
stan::math::stan_print(&errmsg_stream__, x);
4695+
stan::math::stan_print(&errmsg_stream__, " y:");
4696+
stan::math::stan_print(&errmsg_stream__, y);
46944697
throw std::domain_error(errmsg_stream__.str());
46954698
}
46964699
current_statement__ = 719;
46974700
if (stan::math::logical_lt((abs_diff / avg_scale), min_)) {
46984701
current_statement__ = 718;
46994702
std::stringstream errmsg_stream__;
4700-
errmsg_stream__ << "user-specified rejection, difference below ";
4701-
errmsg_stream__ << min_;
4702-
errmsg_stream__ << " x:";
4703-
errmsg_stream__ << x;
4704-
errmsg_stream__ << " y:";
4705-
errmsg_stream__ << y;
4703+
stan::math::stan_print(&errmsg_stream__, "user-specified rejection, difference below ");
4704+
stan::math::stan_print(&errmsg_stream__, min_);
4705+
stan::math::stan_print(&errmsg_stream__, " x:");
4706+
stan::math::stan_print(&errmsg_stream__, x);
4707+
stan::math::stan_print(&errmsg_stream__, " y:");
4708+
stan::math::stan_print(&errmsg_stream__, y);
47064709
throw std::domain_error(errmsg_stream__.str());
47074710
}
47084711
current_statement__ = 720;
@@ -5783,8 +5786,9 @@ f3_functor__::operator()(const int& a1, const std::vector<int>& a2,
57835786
}
57845787

57855788
template <typename T0__, stan::require_all_t<stan::is_stan_scalar<T0__>>*>
5786-
void foo_4_functor__::operator()(const T0__& x, std::ostream* pstream__)
5787-
const
5789+
void
5790+
foo_4_functor__::operator()(const std::vector<T0__>& x,
5791+
std::ostream* pstream__) const
57885792
{
57895793
return foo_4(x, pstream__);
57905794
}
@@ -10139,7 +10143,7 @@ int foo_functor__::operator()(const int& n, std::ostream* pstream__) const
1013910143
stan::model::index_uni(1)))) {
1014010144
current_statement__ = 137;
1014110145
std::stringstream errmsg_stream__;
10142-
errmsg_stream__ << "indexing test 1 failed";
10146+
stan::math::stan_print(&errmsg_stream__, "indexing test 1 failed");
1014310147
throw std::domain_error(errmsg_stream__.str());
1014410148
}
1014510149
current_statement__ = 141;
@@ -10175,7 +10179,7 @@ int foo_functor__::operator()(const int& n, std::ostream* pstream__) const
1017510179
stan::model::index_uni(1)))) {
1017610180
current_statement__ = 143;
1017710181
std::stringstream errmsg_stream__;
10178-
errmsg_stream__ << "indexing test 2 failed";
10182+
stan::math::stan_print(&errmsg_stream__, "indexing test 2 failed");
1017910183
throw std::domain_error(errmsg_stream__.str());
1018010184
}
1018110185
current_statement__ = 148;
@@ -10220,7 +10224,7 @@ int foo_functor__::operator()(const int& n, std::ostream* pstream__) const
1022010224
stan::model::index_uni(1)))) {
1022110225
current_statement__ = 150;
1022210226
std::stringstream errmsg_stream__;
10223-
errmsg_stream__ << "indexing test 3 failed";
10227+
stan::math::stan_print(&errmsg_stream__, "indexing test 3 failed");
1022410228
throw std::domain_error(errmsg_stream__.str());
1022510229
}
1022610230
current_statement__ = 152;

test/integration/good/code-gen/mir.expected

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,16 +1255,30 @@
12551255
(meta <opaque>))))
12561256
(fdloc <opaque>))
12571257
((fdrt ()) (fdname foo_4) (fdsuffix FnPlain)
1258-
(fdargs ((AutoDiffable x UReal)))
1258+
(fdargs ((AutoDiffable x (UArray UReal))))
12591259
(fdbody
12601260
(((pattern
12611261
(Block
12621262
(((pattern
12631263
(NRFunApp (CompilerInternal FnReject)
12641264
(((pattern (Lit Str "user-specified rejection"))
12651265
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
1266+
((pattern
1267+
(Indexed
1268+
((pattern (Var x))
1269+
(meta
1270+
((type_ (UArray UReal)) (loc <opaque>)
1271+
(adlevel AutoDiffable))))
1272+
((Single
1273+
((pattern (Lit Int 1))
1274+
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))))))
1275+
(meta ((type_ UReal) (loc <opaque>) (adlevel AutoDiffable))))
1276+
((pattern (Lit Str "; "))
1277+
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
12661278
((pattern (Var x))
1267-
(meta ((type_ UReal) (loc <opaque>) (adlevel AutoDiffable)))))))
1279+
(meta
1280+
((type_ (UArray UReal)) (loc <opaque>)
1281+
(adlevel AutoDiffable)))))))
12681282
(meta <opaque>)))))
12691283
(meta <opaque>))))
12701284
(fdloc <opaque>))

test/integration/good/code-gen/mother.stan

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ functions {
187187
return x + target();
188188
}
189189

190-
void foo_4(real x) {
191-
reject("user-specified rejection", x);
190+
void foo_4(array[] real x) {
191+
reject("user-specified rejection", x[1], "; ", x);
192192
}
193193

194194
real relative_diff(real x, real y, real max_, real min_) {

test/integration/good/code-gen/transformed_mir.expected

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,16 +1273,30 @@
12731273
(meta <opaque>))))
12741274
(fdloc <opaque>))
12751275
((fdrt ()) (fdname foo_4) (fdsuffix FnPlain)
1276-
(fdargs ((AutoDiffable x UReal)))
1276+
(fdargs ((AutoDiffable x (UArray UReal))))
12771277
(fdbody
12781278
(((pattern
12791279
(Block
12801280
(((pattern
12811281
(NRFunApp (CompilerInternal FnReject)
12821282
(((pattern (Lit Str "user-specified rejection"))
12831283
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
1284+
((pattern
1285+
(Indexed
1286+
((pattern (Var x))
1287+
(meta
1288+
((type_ (UArray UReal)) (loc <opaque>)
1289+
(adlevel AutoDiffable))))
1290+
((Single
1291+
((pattern (Lit Int 1))
1292+
(meta ((type_ UInt) (loc <opaque>) (adlevel DataOnly))))))))
1293+
(meta ((type_ UReal) (loc <opaque>) (adlevel AutoDiffable))))
1294+
((pattern (Lit Str "; "))
1295+
(meta ((type_ UReal) (loc <opaque>) (adlevel DataOnly))))
12841296
((pattern (Var x))
1285-
(meta ((type_ UReal) (loc <opaque>) (adlevel AutoDiffable)))))))
1297+
(meta
1298+
((type_ (UArray UReal)) (loc <opaque>)
1299+
(adlevel AutoDiffable)))))))
12861300
(meta <opaque>)))))
12871301
(meta <opaque>))))
12881302
(fdloc <opaque>))

0 commit comments

Comments
 (0)