File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -29,9 +29,9 @@ inline auto log_softmax(T&& x) {
2929 = return_var_matrix_t <plain_type_t <decltype (x_arena.val ())>, T>;
3030 arena_t <return_t > res = log_softmax (x_arena.val ());
3131 reverse_pass_callback ([x_arena, res]() mutable {
32- const auto s = res.val ().array ().exp ().eval ();
3332 const auto & res_adj = to_ref (res.adj ());
34- x_arena.adj ().array () += res_adj.array () - res_adj.sum () * s;
33+ x_arena.adj ().array ()
34+ += res_adj.array () - res_adj.sum () * res.val ().array ().exp ();
3535 });
3636 return return_t (res);
3737}
Original file line number Diff line number Diff line change @@ -27,8 +27,7 @@ inline auto softmax(T&& x) {
2727 if (x_arena.size () == 0 ) {
2828 return return_t (x_arena);
2929 }
30- auto res_val = to_arena (softmax (x_arena.val ()));
31- arena_t <return_t > res = res_val;
30+ arena_t <return_t > res = softmax (x_arena.val ());
3231 reverse_pass_callback ([x_arena, res]() mutable {
3332 const auto & s = to_ref (res.val ());
3433 const auto & res_adj = to_ref (res.adj ());
You can’t perform that action at this time.
0 commit comments