@@ -58,39 +58,40 @@ class log_softmax_elt_vari : public vari {
5858 */
5959template <typename T, require_container_st<is_var, T>* = nullptr >
6060inline auto log_softmax (const T& x) {
61- return apply_vector_unary<ref_type_t <T>>::apply (to_ref (x), [&](const auto & alpha) {
62- const int a_size = alpha.size ();
61+ return apply_vector_unary<ref_type_t <T>>::apply (
62+ to_ref (x), [&](const auto & alpha) {
63+ const int a_size = alpha.size ();
6364
64- check_nonzero_size (" log_softmax" , " alpha" , alpha);
65+ check_nonzero_size (" log_softmax" , " alpha" , alpha);
6566
66- vari** alpha_vi_array
67- = ChainableStack::instance_->memalloc_ .alloc_array <vari*>(a_size);
68- Eigen::Map<vector_vi>(alpha_vi_array, a_size) = alpha.vi ();
67+ vari** alpha_vi_array
68+ = ChainableStack::instance_->memalloc_ .alloc_array <vari*>(a_size);
69+ Eigen::Map<vector_vi>(alpha_vi_array, a_size) = alpha.vi ();
6970
70- vector_d alpha_d = alpha.val ();
71+ vector_d alpha_d = alpha.val ();
7172
72- // fold logic of math::softmax() and math::log_softmax()
73- // to save computations
73+ // fold logic of math::softmax() and math::log_softmax()
74+ // to save computations
7475
75- vector_d diff = (alpha_d.array () - alpha_d.maxCoeff ());
76- vector_d softmax_alpha_d = diff.array ().exp ();
77- double sum = softmax_alpha_d.sum ();
78- vector_d log_softmax_alpha_d = diff.array () - std::log (sum);
76+ vector_d diff = (alpha_d.array () - alpha_d.maxCoeff ());
77+ vector_d softmax_alpha_d = diff.array ().exp ();
78+ double sum = softmax_alpha_d.sum ();
79+ vector_d log_softmax_alpha_d = diff.array () - std::log (sum);
7980
80- // end fold
81- double * softmax_alpha_d_array
82- = ChainableStack::instance_->memalloc_ .alloc_array <double >(a_size);
83- Eigen::Map<vector_d>(softmax_alpha_d_array, a_size)
84- = softmax_alpha_d.array () / sum;
81+ // end fold
82+ double * softmax_alpha_d_array
83+ = ChainableStack::instance_->memalloc_ .alloc_array <double >(a_size);
84+ Eigen::Map<vector_d>(softmax_alpha_d_array, a_size)
85+ = softmax_alpha_d.array () / sum;
8586
86- vector_v log_softmax_alpha (a_size);
87- for (int k = 0 ; k < a_size; ++k) {
88- log_softmax_alpha (k) = var (new internal::log_softmax_elt_vari (
89- log_softmax_alpha_d[k], alpha_vi_array, softmax_alpha_d_array, a_size ,
90- k));
91- }
92- return log_softmax_alpha;
93- });
87+ vector_v log_softmax_alpha (a_size);
88+ for (int k = 0 ; k < a_size; ++k) {
89+ log_softmax_alpha (k) = var (new internal::log_softmax_elt_vari (
90+ log_softmax_alpha_d[k], alpha_vi_array, softmax_alpha_d_array,
91+ a_size, k));
92+ }
93+ return log_softmax_alpha;
94+ });
9495}
9596
9697} // namespace math
0 commit comments