Skip to content
141 changes: 141 additions & 0 deletions stan/math/opencl/kernels/multinomial_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#ifndef STAN_MATH_OPENCL_KERNELS_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_KERNELS_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_cl.hpp>
#include <string>

namespace stan {
namespace math {
namespace opencl_kernels {

// \cond
static constexpr const char* multinomial_logit_glm_kernel_code = STRINGIFY(
// \endcond
/** \ingroup opencl_kernels
* GPU kernel for the Generalized Linear Model (GLM) with multinomial
* distribution and softmax (logit) link function.
*
* Must be run with at least N_instances threads and local size LOCAL_SIZE_.
* Each thread handles one instance n = gid.
*
* The kernel performs two passes over the K classes for instance n:
* 1. find max(eta[n,:]) for numerical stability,
* 2. accumulate sum_exp, S_n, and logp using shifted eta
* (eta[n,k] - max) to avoid catastrophic cancellation; skips
* y_nk=0 terms to implement the 0*log(0)=0 convention;
* if need_delta, stash exp(eta[n,k] - max) into delta_global.
* A final loop normalizes delta (if need_delta) and subtracts
* lgamma(y_nk+1) terms (if need_logp_gamma), reading only y_global
* and delta_global, without re-reading x_beta_global or alpha_global.
*
* @param[out] logp_global partial logp sums, one per work group
* @param[out] delta_global residual matrix N_instances x N_classes
* (col-major)
* @param[in] y_global outcome counts, N_instances x N_classes (col-major)
* @param[in] x_beta_global product x*beta, N_instances x N_classes
* (col-major)
* @param[in] alpha_global intercepts: K values if is_alpha_vector, else
* N_instances x N_classes (col-major)
* @param N_instances number of instances
* @param N_classes number of outcome classes
* @param is_alpha_vector 1 if alpha is shared 1xK row, 0 if NxK
* @param need_delta 1 if delta_global should be computed and written
* @param need_logp_gamma 1 if lgamma terms should be included in logp
*/
__kernel void multinomial_logit_glm(
__global double* logp_global, __global double* delta_global,
const __global int* y_global, const __global double* x_beta_global,
const __global double* alpha_global, const int N_instances,
const int N_classes, const int is_alpha_vector, const int need_delta,
const int need_logp_gamma) {
const int gid = get_global_id(0);
const int lid = get_local_id(0);
const int lsize = get_local_size(0);
const int wg_id = get_group_id(0);

__local double local_storage[LOCAL_SIZE_];

// Threads with gid >= N_instances skip the body and contribute 0 to the
// work-group reduction below.
double logp = 0;
if (gid < N_instances) {
// Pass 1: row-wise max of eta for numerical stability.
// eta[n,k] = x_beta[n,k] + alpha[k or (n,k)].
double eta_max = -INFINITY;
for (int k = 0; k < N_classes; k++) {
int nk = k * N_instances + gid;
int alpha_idx = is_alpha_vector ? k : nk;
double eta_k = x_beta_global[nk] + alpha_global[alpha_idx];
if (eta_k > eta_max)
eta_max = eta_k;
}

// Pass 2: sum_exp, S_n, logp; if need_delta stash exp_k in
// delta_global.
double sum_exp = 0;
int S_n = 0;
for (int k = 0; k < N_classes; k++) {
int nk = k * N_instances + gid;
int alpha_idx = is_alpha_vector ? k : nk;
double shifted_eta_k
= x_beta_global[nk] + alpha_global[alpha_idx] - eta_max;
double exp_k = exp(shifted_eta_k);
sum_exp += exp_k;
int y_nk = y_global[nk];
S_n += y_nk;
// Skip y_nk == 0 so that -inf shifted_eta contributes 0, not NaN.
if (y_nk != 0)
logp += y_nk * shifted_eta_k;
if (need_delta)
delta_global[nk] = exp_k;
}
// log(sum_exp) = logsumexp(eta[n,:]) - eta_max; the -S_n*eta_max here
// cancels the shift already folded into logp above, so the result is
// sum_k y_nk*eta[n,k] - S_n*logsumexp(eta[n,:]).
logp -= S_n * log(sum_exp);

// lgamma(S_n + 1) is part of the multinomial normalizing constant
// log(S_n! / prod_k y_nk!); the -lgamma(y_nk+1) half is below.
if (need_logp_gamma)
logp += lgamma(S_n + 1.0);

// Normalize delta and/or subtract lgamma(y_nk+1) in one pass.
if (need_delta || need_logp_gamma) {
for (int k = 0; k < N_classes; k++) {
int nk = k * N_instances + gid;
int y_nk = y_global[nk];
if (need_logp_gamma)
logp -= lgamma(y_nk + 1.0);
// delta[n,k] = y_nk - S_n * softmax(eta[n,:])[k] = dlogp/deta[n,k].
if (need_delta)
delta_global[nk] = y_nk - S_n * delta_global[nk] / sum_exp;
}
}
}

// Work-group reduction of logp. The tree covers all lsize entries because
// lsize (LOCAL_SIZE_ = 64) is a power of REDUCTION_STEP_SIZE (= 4); see
// the kernel options in the kernel_cl instantiation.
local_storage[lid] = logp;
barrier(CLK_LOCAL_MEM_FENCE);
for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
step /= REDUCTION_STEP_SIZE) {
if (lid < step) {
for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
local_storage[lid] += local_storage[lid + step * i];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (lid == 0) {
logp_global[wg_id] = local_storage[0];
}
}
// \cond
);
// \endcond

/** \ingroup opencl_kernels
* See the docs for \link kernels/multinomial_logit_glm_lpmf.hpp
* multinomial_logit_glm() \endlink
*/
// Buffer/scalar template arguments must match the kernel's parameter list in
// order: logp_global (out), delta_global (out), y_global (in),
// x_beta_global (in), alpha_global (in), then the five int scalars.
// REDUCTION_STEP_SIZE and LOCAL_SIZE_ are compile-time options substituted
// into the kernel source; LOCAL_SIZE_ (64) is a power of
// REDUCTION_STEP_SIZE (4), which the kernel's reduction loop relies on to
// cover every local_storage slot.
const kernel_cl<out_buffer, out_buffer, in_buffer, in_buffer, in_buffer, int,
int, int, int, int>
multinomial_logit_glm("multinomial_logit_glm",
{multinomial_logit_glm_kernel_code},
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

} // namespace opencl_kernels
} // namespace math
} // namespace stan
#endif
#endif
1 change: 1 addition & 0 deletions stan/math/opencl/prim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
#include <stan/math/opencl/prim/neg_binomial_2_lpmf.hpp>
#include <stan/math/opencl/prim/neg_binomial_2_log_lpmf.hpp>
#include <stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp>
#include <stan/math/opencl/prim/multinomial_logit_glm_lpmf.hpp>
#include <stan/math/opencl/prim/normal_id_glm_lpdf.hpp>
#include <stan/math/opencl/prim/normal_cdf.hpp>
#include <stan/math/opencl/prim/normal_lccdf.hpp>
Expand Down
146 changes: 146 additions & 0 deletions stan/math/opencl/prim/multinomial_logit_glm_lpmf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#ifndef STAN_MATH_OPENCL_PRIM_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_PRIM_MULTINOMIAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/prim/size.hpp>
#include <stan/math/opencl/rev/operands_and_partials.hpp>
#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/copy.hpp>
#include <stan/math/opencl/prim/multiply.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/opencl/kernels/multinomial_logit_glm_lpmf.hpp>

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/eval.hpp>
#include <stan/math/prim/fun/sum.hpp>
#include <stan/math/prim/fun/Eigen.hpp>

#include <vector>

namespace stan {
namespace math {

/** \ingroup opencl
 * Returns the log PMF of the Generalized Linear Model (GLM)
 * with multinomial distribution and softmax (logit) link function.
 * This is an OpenCL overload of
 * `prim/prob/multinomial_logit_glm_lpmf.hpp`.
 * Alpha can be either a shared 1×K row vector or an N×K per-instance matrix.
 *
 * @tparam T_x type of the design matrix (N×M kernel expression)
 * @tparam T_alpha type of the intercept (1×K or N×K kernel expression)
 * @tparam T_beta type of the weight matrix (M×K kernel expression)
 * @param y outcome count vectors: `y[n]` is a length-K vector of non-negative
 * integer counts for instance n
 * @param x design matrix (N×M) on OpenCL device
 * @param alpha intercept: 1×K broadcast row or N×K per-instance matrix
 * @param beta weight matrix (M×K) on OpenCL device
 * @return log sum of multinomial log PMFs over all N instances
 * @throw std::domain_error if any element of x or beta is infinite or NaN,
 * or if alpha contains `+inf` or NaN (`-inf` forces the corresponding softmax
 * probability to zero and is allowed)
 * @throw std::domain_error if any count in y is negative
 * @throw std::invalid_argument if container sizes mismatch
 */
template <bool propto, typename T_x, typename T_alpha, typename T_beta,
          require_all_prim_or_rev_kernel_expression_t<T_x, T_alpha,
                                                      T_beta>* = nullptr>
inline return_type_t<T_x, T_alpha, T_beta> multinomial_logit_glm_lpmf(
    const std::vector<std::vector<int>>& y, const T_x& x, const T_alpha& alpha,
    const T_beta& beta) {
  using T_partials_return = partials_return_t<T_x, T_alpha, T_beta>;
  static constexpr const char* function = "multinomial_logit_glm_lpmf";

  const int N_instances = x.rows();
  const int N_classes = beta.cols();

  check_size_match(function, "Rows of", "x", N_instances, "size of", "y",
                   y.size());
  check_size_match(function, "Columns of", "beta", N_classes, "columns of",
                   "alpha", alpha.cols());
  check_size_match(function, "Columns of", "x", x.cols(), "rows of", "beta",
                   beta.rows());

  const int alpha_rows = alpha.rows();
  const bool is_alpha_vector = alpha_rows == 1;
  if (!is_alpha_vector) {
    check_size_match(function, "Rows of", "alpha", alpha_rows, "rows of", "x",
                     N_instances);
  }

  if (N_instances == 0) {
    return 0;
  }
  for (int n = 0; n < N_instances; ++n) {
    check_size_match(function, "Size of outcome vector", y[n].size(),
                     "number of classes", N_classes);
    check_nonnegative(function, "outcome counts", y[n]);
  }

  // Degenerate case: zero outcome classes. Every y[n] is empty (checked
  // above) and the log PMF is an empty sum. Returning here also keeps the
  // kernel from computing S_n * log(sum_exp) = 0 * log(0) = NaN.
  if (N_classes == 0) {
    return 0;
  }

  if constexpr (!include_summand<propto, T_x, T_alpha, T_beta>::value) {
    return 0;
  }

  // Flatten nested y into an N×K matrix for upload to the device.
  Eigen::MatrixXi y_mat(N_instances, N_classes);
  for (int n = 0; n < N_instances; ++n)
    for (int k = 0; k < N_classes; ++k)
      y_mat(n, k) = y[n][k];
  matrix_cl<int> y_cl(y_mat);

  const auto& x_val = eval(value_of(x));
  const auto& alpha_val = eval(value_of(alpha));
  const auto& beta_val = eval(value_of(beta));

  // eta (without intercept) = x * beta; the kernel adds alpha itself.
  matrix_cl<double> x_beta_cl = x_val * beta_val;

  const int local_size
      = opencl_kernels::multinomial_logit_glm.get_option("LOCAL_SIZE_");
  const int wgs = (N_instances + local_size - 1) / local_size;

  // The residual matrix delta is only needed when gradients will be
  // propagated; otherwise pass an empty placeholder the kernel never writes.
  constexpr bool need_delta = is_any_autodiff_v<T_x, T_alpha, T_beta>;

  matrix_cl<double> logp_cl(wgs, 1);
  matrix_cl<double> delta_cl(0, 0);
  if constexpr (need_delta)
    delta_cl = matrix_cl<double>(N_instances, N_classes);

  try {
    opencl_kernels::multinomial_logit_glm(
        cl::NDRange(local_size * wgs), cl::NDRange(local_size), logp_cl,
        delta_cl, y_cl, x_beta_cl, alpha_val, N_instances, N_classes,
        is_alpha_vector, need_delta, !propto);
  } catch (const cl::Error& e) {
    check_opencl_error(function, e);
  }

  // Sum the one-per-work-group partial sums on the host.
  T_partials_return logp = sum(from_matrix_cl(logp_cl));

  // Only run the (expensive) elementwise finiteness checks when the result
  // is already known to be non-finite, to produce a precise error message.
  if (!std::isfinite(logp)) {
    check_cl(function, "Design matrix", x_val, "finite") = isfinite(x_val);
    check_cl(function, "Intercept", alpha_val, "finite") = isfinite(alpha_val);
    check_cl(function, "Weight matrix", beta_val, "finite")
        = isfinite(beta_val);
  }

  auto ops_partials = make_partials_propagator(x, alpha, beta);
  if constexpr (need_delta) {
    // dlogp/deta[n,k] = delta[n,k]; the chain rule gives
    // dx = delta * beta^T, dbeta = x^T * delta, and dalpha = delta
    // (column sums when alpha is a shared 1xK row).
    if constexpr (is_autodiff_v<T_x>) {
      partials<0>(ops_partials) = delta_cl * transpose(beta_val);
    }
    if constexpr (is_autodiff_v<T_alpha>) {
      // Explicit branches instead of a conditional expression: the two
      // alternatives have different kernel-expression types, so a ternary
      // would have to force them to a common type; separate assignments let
      // each side use its own assignment overload.
      if (is_alpha_vector) {
        partials<1>(ops_partials) = colwise_sum(delta_cl);
      } else {
        partials<1>(ops_partials) = delta_cl;
      }
    }
    if constexpr (is_autodiff_v<T_beta>) {
      partials<2>(ops_partials) = transpose(x_val) * delta_cl;
    }
  }
  return ops_partials.build(logp);
}

} // namespace math
} // namespace stan

#endif
#endif
1 change: 1 addition & 0 deletions stan/math/prim/prob.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
#include <stan/math/prim/prob/multi_student_t_cholesky_rng.hpp>
#include <stan/math/prim/prob/multi_student_t_lpdf.hpp>
#include <stan/math/prim/prob/multi_student_t_rng.hpp>
#include <stan/math/prim/prob/multinomial_logit_glm_lpmf.hpp>
#include <stan/math/prim/prob/multinomial_logit_lpmf.hpp>
#include <stan/math/prim/prob/multinomial_logit_rng.hpp>
#include <stan/math/prim/prob/multinomial_lpmf.hpp>
Expand Down
Loading
Loading