Skip to content

Commit 1ac4aec

Browse files
author
Jachym.Barvinek
committed
minor polishing
1 parent dd4c98d commit 1ac4aec

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

stan/math/rev/fun/log_softmax.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ inline auto log_softmax(T&& x) {
2929
= return_var_matrix_t<plain_type_t<decltype(x_arena.val())>, T>;
3030
arena_t<return_t> res = log_softmax(x_arena.val());
3131
reverse_pass_callback([x_arena, res]() mutable {
32-
const auto s = res.val().array().exp().eval();
3332
const auto& res_adj = to_ref(res.adj());
34-
x_arena.adj().array() += res_adj.array() - res_adj.sum() * s;
33+
x_arena.adj().array()
34+
+= res_adj.array() - res_adj.sum() * res.val().array().exp();
3535
});
3636
return return_t(res);
3737
}

stan/math/rev/fun/softmax.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ inline auto softmax(T&& x) {
2727
if (x_arena.size() == 0) {
2828
return return_t(x_arena);
2929
}
30-
auto res_val = to_arena(softmax(x_arena.val()));
31-
arena_t<return_t> res = res_val;
30+
arena_t<return_t> res = softmax(x_arena.val());
3231
reverse_pass_callback([x_arena, res]() mutable {
3332
const auto& s = to_ref(res.val());
3433
const auto& res_adj = to_ref(res.adj());

0 commit comments

Comments
 (0)