Skip to content

Commit 36a15e2

Browse files
committed
Merge commit 'c02589c1513c57593661ec7b21f17923b55cd9f3' into HEAD
2 parents cb602ad + c02589c commit 36a15e2

3 files changed

Lines changed: 97 additions & 43 deletions

File tree

stan/math/rev/core/grad.hpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,12 @@ namespace math {
2828
* derivative propagation.
2929
*/
3030
static void grad(vari* vi) {
31-
// simple reference implementation (intended as doc):
32-
// vi->init_dependent();
33-
// size_t end = var_stack_.size();
34-
// size_t begin = empty_nested() ? 0 : end - nested_size();
35-
// for (size_t i = end; --i > begin; )
36-
// var_stack_[i]->chain();
37-
38-
using it_t = std::vector<vari*>::reverse_iterator;
3931
vi->init_dependent();
40-
it_t begin = ChainableStack::instance_->var_stack_.rbegin();
41-
it_t end = empty_nested() ? ChainableStack::instance_->var_stack_.rend()
42-
: begin + nested_size();
43-
for (it_t it = begin; it < end; ++it) {
44-
(*it)->chain();
32+
std::vector<vari*>& var_stack = ChainableStack::instance_->var_stack_;
33+
size_t end = var_stack.size();
34+
size_t beginning = empty_nested() ? 0 : end - nested_size();
35+
for (size_t i = end; i-- > beginning;) {
36+
var_stack[i]->chain();
4537
}
4638
}
4739

test/unit/math/rev/core/agrad_test.cpp

Lines changed: 0 additions & 30 deletions
This file was deleted.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#include <gtest/gtest.h>
2+
#include <test/unit/math/rev/fun/util.hpp>
3+
#include <stan/math/rev/core.hpp>
4+
#include <stan/math/rev/fun/sin.hpp>
5+
#include <vector>
6+
7+
TEST(AgradRev, multiple_grads) {
8+
for (int i = 0; i < 100; ++i) {
9+
AVAR a = 2.0;
10+
AVAR b = 3.0 * a;
11+
AVAR c = sin(a) * b;
12+
// fixes warning regarding unused variable
13+
c = c;
14+
15+
AVAR nothing;
16+
}
17+
18+
AVAR d = 2.0;
19+
AVAR e = 3.0;
20+
AVAR f = d * e;
21+
22+
AVEC x = createAVEC(d, e);
23+
VEC grad_f;
24+
f.grad(x, grad_f);
25+
26+
EXPECT_FLOAT_EQ(3.0, d.adj());
27+
EXPECT_FLOAT_EQ(2.0, e.adj());
28+
29+
EXPECT_FLOAT_EQ(3.0, grad_f[0]);
30+
EXPECT_FLOAT_EQ(2.0, grad_f[1]);
31+
}
32+
33+
TEST(AgradRev, ensure_first_vari_chained) {
34+
using stan::math::var;
35+
36+
// Make sure there aren't any varis on stack
37+
stan::math::recover_memory();
38+
39+
int N = 10;
40+
std::vector<var> vars;
41+
42+
var total = 0.0;
43+
for (int i = 0; i < N; ++i) {
44+
vars.push_back(0.0);
45+
total += vars.back();
46+
}
47+
48+
total.grad();
49+
50+
EXPECT_FLOAT_EQ(0.0, total.val());
51+
for (int i = 0; i < N; ++i) {
52+
EXPECT_FLOAT_EQ(1.0, vars[i].adj());
53+
}
54+
}
55+
56+
namespace stan {
57+
namespace math {
58+
59+
class test_vari : public vari {
60+
public:
61+
test_vari() : vari(0.0) {}
62+
63+
virtual void chain() {
64+
stan::math::nested_rev_autodiff nested;
65+
66+
// Add enough vars to make the the var_stack_ vector reallocate
67+
int N_new_vars = ChainableStack::instance_->var_stack_.capacity() + 1;
68+
69+
var total = 0.0;
70+
for (int i = 0; i < N_new_vars; ++i) {
71+
total += i;
72+
}
73+
74+
total.grad();
75+
}
76+
};
77+
78+
} // namespace math
79+
} // namespace stan
80+
81+
TEST(AgradRev, nested_grad_during_chain) {
82+
using stan::math::var;
83+
84+
var total = 0.0;
85+
for (int i = 0; i < 2; ++i) {
86+
total += i;
87+
}
88+
89+
var test_var(new stan::math::test_vari());
90+
91+
test_var.grad();
92+
}

0 commit comments

Comments
 (0)