Skip to content

Commit a37a37c

Browse files
authored
Merge pull request #3211 from stan-dev/student-t-qf
Add quantile function for Student-T distribution
2 parents 105bfcc + 3eac00e commit a37a37c

12 files changed

Lines changed: 502 additions & 23 deletions

File tree

stan/math/fwd/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
#include <stan/math/fwd/fun/Eigen_NumTraits.hpp>
66

77
#include <stan/math/fwd/prob/std_normal_log_qf.hpp>
8+
#include <stan/math/fwd/prob/student_t_qf.hpp>
89

910
#endif
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#ifndef STAN_MATH_FWD_PROB_STUDENT_T_QF_HPP
2+
#define STAN_MATH_FWD_PROB_STUDENT_T_QF_HPP
3+
4+
#include <stan/math/fwd/meta.hpp>
5+
#include <stan/math/fwd/fun/digamma.hpp>
6+
#include <stan/math/fwd/fun/exp.hpp>
7+
#include <stan/math/fwd/fun/hypergeometric_2F1.hpp>
8+
#include <stan/math/fwd/fun/hypergeometric_pFq.hpp>
9+
#include <stan/math/fwd/fun/inv_inc_beta.hpp>
10+
#include <stan/math/fwd/fun/log.hpp>
11+
#include <stan/math/fwd/fun/sqrt.hpp>
12+
#include <stan/math/fwd/fun/value_of.hpp>
13+
#include <stan/math/fwd/fun/value_of_rec.hpp>
14+
#include <stan/math/prim/meta.hpp>
15+
#include <stan/math/prim/prob/student_t_lpdf.hpp>
16+
17+
namespace stan {
18+
namespace math {
19+
20+
template <typename T_p, typename T_nu, typename T_mu, typename T_sigma,
21+
require_all_stan_scalar_t<T_p, T_mu, T_sigma, T_nu>* = nullptr,
22+
require_any_fvar_t<T_p, T_nu, T_mu, T_sigma>* = nullptr>
23+
inline auto student_t_qf(const T_p& p, const T_nu& nu, const T_mu& mu,
24+
const T_sigma& sigma) {
25+
static constexpr const char* function = "student_t_qf";
26+
using FvarT = return_type_t<T_p, T_mu, T_sigma, T_nu>;
27+
using T_partials = partials_type_t<FvarT>;
28+
29+
auto p_val = value_of(p);
30+
auto nu_val = value_of(nu);
31+
auto mu_val = value_of(mu);
32+
auto sigma_val = value_of(sigma);
33+
34+
check_nonnegative(function, "Degrees of freedom parameter", nu_val);
35+
check_positive(function, "Scale parameter", sigma_val);
36+
check_bounded(function, "Probability parameter", p_val, 0.0, 1.0);
37+
38+
if (unlikely(p_val == 0.0)) {
39+
return FvarT{NEGATIVE_INFTY, 0.0};
40+
} else if (unlikely(p_val == 1.0)) {
41+
return FvarT{INFTY, 0.0};
42+
} else if (unlikely(p_val == 0.5)) {
43+
return FvarT{mu_val, 0.0};
44+
}
45+
46+
const auto p_val_flip = p_val < 0.5 ? p_val : 1.0 - p_val;
47+
const double p_sign = value_of_rec(p_val) < 0.5 ? -1.0 : 1.0;
48+
auto sqrt_nu_val = sqrt(nu_val);
49+
auto ibeta_arg = inv_inc_beta(0.5 * nu_val, 0.5, 2.0 * p_val_flip);
50+
auto rtn_val
51+
= mu_val
52+
+ p_sign * sigma_val * sqrt_nu_val * sqrt(-1.0 + 1.0 / ibeta_arg);
53+
54+
FvarT rtn(rtn_val, 0.0);
55+
56+
if constexpr (is_autodiff_v<T_p>) {
57+
rtn.d_ += p.d_ * exp(-student_t_lpdf(rtn_val, nu_val, mu_val, sigma_val));
58+
}
59+
60+
if constexpr (is_autodiff_v<T_nu>) {
61+
const auto half_nu = nu_val / 2.0;
62+
Eigen::Matrix<T_partials, -1, 1> hyper_arg_a{{0.5, half_nu, half_nu}};
63+
Eigen::Matrix<T_partials, -1, 1> hyper_arg_b{
64+
{1.0 + half_nu, 1.0 + half_nu}};
65+
const auto hyper_arg
66+
= hypergeometric_pFq(hyper_arg_a, hyper_arg_b, ibeta_arg);
67+
const auto hyper2f1 = hypergeometric_2F1(1.0, (1.0 + nu_val) / 2.0,
68+
(2.0 + nu_val) / 2.0, ibeta_arg);
69+
const auto digamma_a1 = digamma(half_nu);
70+
const auto digamma_a2 = digamma((1.0 + nu_val) / 2.0);
71+
72+
const auto arg_1 = (4.0 * hyper_arg * sqrt(1.0 - ibeta_arg)) / nu_val;
73+
const auto arg_2 = -2.0 * hyper2f1 * (-1.0 + ibeta_arg)
74+
* (log(ibeta_arg) - digamma_a1 + digamma_a2);
75+
76+
const auto num1 = sigma_val * (-2.0 + (2.0 - arg_1 + arg_2) / ibeta_arg);
77+
const auto den1 = 4.0 * sqrt_nu_val * sqrt(-1.0 + 1.0 / ibeta_arg);
78+
rtn.d_ += nu.d_ * p_sign * num1 / den1;
79+
}
80+
81+
if constexpr (is_autodiff_v<T_mu>) {
82+
rtn.d_ += mu.d_;
83+
}
84+
85+
if constexpr (is_autodiff_v<T_sigma>) {
86+
rtn.d_ += sigma.d_ * p_sign * sqrt_nu_val * sqrt(-1.0 + 1.0 / ibeta_arg);
87+
}
88+
89+
return rtn;
90+
}
91+
} // namespace math
92+
} // namespace stan
93+
94+
#endif

stan/math/prim/fun/hypergeometric_pFq.hpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err/check_not_nan.hpp>
66
#include <stan/math/prim/err/check_finite.hpp>
7+
#include <stan/math/prim/fun/to_ref.hpp>
78
#include <stan/math/prim/fun/to_row_vector.hpp>
89
#include <boost/math/special_functions/hypergeometric_pFq.hpp>
910

@@ -25,34 +26,53 @@ namespace math {
2526
// Generalized hypergeometric function pFq for vector arguments `a` and `b`
// with arithmetic scalars and an arithmetic scalar `z`.
// This span is a diff hunk: a line following an `NN-` marker belongs to the
// removed implementation, a line following an `NN+` marker to its replacement.
// Both versions run check_finite and check_not_nan on a, b and z, throw
// std::domain_error when the size/|z| convergence conditions below fail, and
// delegate the evaluation to boost::math::hypergeometric_pFq.
template <typename Ta, typename Tb, typename Tz,
2627
require_all_vector_st<std::is_arithmetic, Ta, Tb>* = nullptr,
2728
require_arithmetic_t<Tz>* = nullptr>
28-
inline return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
29-
const Tz& z) {
30-
plain_type_t<Ta> a_ref = a;
31-
plain_type_t<Tb> b_ref = b;
29+
inline return_type_t<Ta, Tb, Tz> hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
30+
decltype(auto) a_ref = to_ref(std::forward<Ta>(a));
31+
decltype(auto) b_ref = to_ref(std::forward<Tb>(b));
3232
check_finite("hypergeometric_pFq", "a", a_ref);
3333
check_finite("hypergeometric_pFq", "b", b_ref);
3434
check_finite("hypergeometric_pFq", "z", z);
35-
3635
check_not_nan("hypergeometric_pFq", "a", a_ref);
3736
check_not_nan("hypergeometric_pFq", "b", b_ref);
3837
check_not_nan("hypergeometric_pFq", "z", z);
3938

40-
bool condition_1 = (a_ref.size() > (b_ref.size() + 1)) && (z != 0);
41-
bool condition_2 = (a_ref.size() == (b_ref.size() + 1)) && (std::fabs(z) > 1);
39+
// condition_1: more than size(b) + 1 entries in `a` with a non-zero z.
const bool condition_1 = (a_ref.size() > (b_ref.size() + 1)) && (z != 0);
40+
// condition_2: exactly size(b) + 1 entries in `a` with |z| > 1.
const bool condition_2
41+
= (a_ref.size() == (b_ref.size() + 1)) && (std::fabs(z) > 1);
4242

4343
if (condition_1 || condition_2) {
44-
std::stringstream msg;
45-
msg << "hypergeometric function pFq does not meet convergence "
46-
<< "conditions with given arguments. "
47-
<< "a: " << to_row_vector(a_ref) << ", "
48-
<< "b: " << to_row_vector(b_ref) << ", "
49-
<< "z: " << z;
50-
throw std::domain_error(msg.str());
44+
// The message is built and thrown inside an immediately invoked lambda
// annotated with STAN_COLD_PATH.
[&]() STAN_COLD_PATH {
45+
std::stringstream msg;
46+
msg << "hypergeometric function pFq does not meet convergence "
47+
"conditions with given arguments. "
48+
"a: "
49+
<< to_row_vector(a_ref) << ", "
50+
<< "b: " << to_row_vector(b_ref) << ", "
51+
<< "z: " << z;
52+
throw std::domain_error(msg.str());
53+
}();
54+
}
55+
// For plain vectors, we can use Eigen's Map to avoid unnecessary copies
56+
using a_ref_t = decltype(a_ref);
57+
using b_ref_t = decltype(b_ref);
58+
constexpr bool is_a_plain_vec
59+
= std::is_same_v<std::decay_t<a_ref_t>, plain_type_t<a_ref_t>>;
60+
constexpr bool is_b_plain_vec
61+
= std::is_same_v<std::decay_t<b_ref_t>, plain_type_t<b_ref_t>>;
62+
if constexpr (is_a_plain_vec && is_b_plain_vec) {
63+
// Wrap the existing storage in an Eigen::Map so that no hard copy is made
// here.
// NOTE(review): const_cast<double*> assumes the element type is double;
// vectors of another arithmetic type (e.g. int) would presumably not
// compile on this branch - confirm against callers.
64+
using map_t = Eigen::Map<Eigen::VectorXd>;
65+
auto map_a = map_t(const_cast<double*>(a_ref.data()), a_ref.size());
66+
auto map_b = map_t(const_cast<double*>(b_ref.data()), b_ref.size());
67+
return boost::math::hypergeometric_pFq(map_a, map_b, z);
68+
} else {
69+
// We need pointers to `a` and `b`'s data here so we hard evaluate.
70+
decltype(auto) a_eval = eval(a_ref);
71+
decltype(auto) b_eval = eval(b_ref);
72+
return boost::math::hypergeometric_pFq(
73+
std::vector<double>(a_eval.data(), a_eval.data() + a_eval.size()),
74+
std::vector<double>(b_eval.data(), b_eval.data() + b_eval.size()), z);
5175
}
52-
53-
return boost::math::hypergeometric_pFq(
54-
std::vector<double>(a_ref.data(), a_ref.data() + a_ref.size()),
55-
std::vector<double>(b_ref.data(), b_ref.data() + b_ref.size()), z);
5676
}
5777
} // namespace math
5878
} // namespace stan

stan/math/prim/meta.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
#include <stan/math/prim/meta/ad_promotable.hpp>
7272
#include <stan/math/prim/meta/append_return_type.hpp>
7373
#include <stan/math/prim/meta/base_type.hpp>
74+
#include <stan/math/prim/meta/common_container_type.hpp>
7475
#include <stan/math/prim/meta/contains_std_vector.hpp>
7576
#include <stan/math/prim/meta/contains_tuple.hpp>
7677
#include <stan/math/prim/meta/error_index.hpp>
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#ifndef STAN_MATH_PRIM_META_COMMON_CONTAINER_TYPE_HPP
2+
#define STAN_MATH_PRIM_META_COMMON_CONTAINER_TYPE_HPP
3+
4+
#include <stan/math/prim/meta/is_container.hpp>
5+
#include <stan/math/prim/meta/is_tuple.hpp>
6+
#include <stan/math/prim/meta/is_detected.hpp>
7+
#include <stan/math/prim/meta/is_stan_scalar.hpp>
8+
#include <stan/math/prim/meta/is_var_matrix.hpp>
9+
#include <stan/math/prim/meta/plain_type.hpp>
10+
#include <stan/math/prim/meta/return_type.hpp>
11+
#include <stan/math/prim/meta/promote_scalar_type.hpp>
12+
#include <type_traits>
13+
14+
namespace stan {
15+
namespace internal {
16+
template <typename T1, typename T2, typename = void, typename = void>
17+
struct common_container_type_impl;
18+
19+
template <typename T1, typename T2>
20+
struct common_container_type_impl<T1, T2, require_stan_scalar_t<T1>,
21+
require_stan_scalar_t<T2>> {
22+
using type = return_type_t<T1, T2>;
23+
};
24+
25+
template <typename T1, typename T2>
26+
struct common_container_type_impl<T1, T2, require_container_t<T1>,
27+
require_container_t<T2>> {
28+
using return_t = return_type_t<T1, T2>;
29+
using container_type_1 = math::promote_scalar_t<return_t, plain_type_t<T1>>;
30+
using container_type_2 = math::promote_scalar_t<return_t, plain_type_t<T2>>;
31+
using type = std::conditional_t<
32+
std::is_same<container_type_1, container_type_2>::value, container_type_1,
33+
void>;
34+
};
35+
36+
template <typename T1, typename T2>
37+
struct common_container_type_impl<T1, T2, require_stan_scalar_t<T1>,
38+
require_container_t<T2>> {
39+
using type = math::promote_scalar_t<return_type_t<T1, T2>, plain_type_t<T2>>;
40+
};
41+
42+
template <typename T1, typename T2>
43+
struct common_container_type_impl<T1, T2, require_container_t<T1>,
44+
require_stan_scalar_t<T2>> {
45+
using type = math::promote_scalar_t<return_type_t<T1, T2>, plain_type_t<T1>>;
46+
};
47+
} // namespace internal
48+
49+
template <typename... Ts>
50+
struct common_container_type;
51+
52+
template <typename T>
53+
struct common_container_type<T> {
54+
using type = typename internal::common_container_type_impl<
55+
T, double>::type; // Use double for base case
56+
};
57+
58+
/**
59+
* Determine the common container type for a variadic list of types.
60+
* If all types are scalars, then the common scalar type is returned.
61+
* If all container types the same, but not necessarily the same scalar type,
62+
* the common container type with the common scalar type is returned.
63+
*
64+
* If different container types are present, the result is `void`.
65+
*/
66+
template <typename T1, typename... Ts>
67+
struct common_container_type<T1, Ts...> {
68+
using type = typename internal::common_container_type_impl<
69+
T1, typename common_container_type<Ts...>::type>::type;
70+
};
71+
72+
template <typename... Ts>
73+
using common_container_t = typename common_container_type<Ts...>::type;
74+
75+
} // namespace stan
76+
77+
#endif // STAN_MATH_PRIM_META_COMMON_CONTAINER_TYPE_HPP

stan/math/prim/prob.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@
287287
#include <stan/math/prim/prob/student_t_lccdf.hpp>
288288
#include <stan/math/prim/prob/student_t_lcdf.hpp>
289289
#include <stan/math/prim/prob/student_t_lpdf.hpp>
290+
#include <stan/math/prim/prob/student_t_qf.hpp>
290291
#include <stan/math/prim/prob/student_t_rng.hpp>
291292
#include <stan/math/prim/prob/uniform_ccdf_log.hpp>
292293
#include <stan/math/prim/prob/uniform_cdf.hpp>
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#ifndef STAN_MATH_PRIM_PROB_STUDENT_T_QF_HPP
2+
#define STAN_MATH_PRIM_PROB_STUDENT_T_QF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/sqrt.hpp>
7+
#include <stan/math/prim/fun/inv_inc_beta.hpp>
8+
#include <stan/math/prim/fun/max_size.hpp>
9+
10+
namespace stan {
11+
namespace math {
12+
13+
/**
14+
* The quantile function of the Student's t-distribution.
15+
*
16+
* @tparam T_p type of the probability parameter
17+
* @tparam T_nu type of the degrees of freedom parameter
18+
* @tparam T_mu type of the location parameter
19+
* @tparam T_sigma type of the scale parameter
20+
* @param p Probability in the range [0, 1].
21+
* @param nu Degrees of freedom, must be non-negative.
22+
* @param mu Location parameter.
23+
* @param sigma Scale parameter, must be positive.
24+
* @return Quantile function value.
25+
* @throw std::domain_error if `nu` is negative or `sigma` is not positive,
26+
* or if `p` is not in [0, 1].
27+
*/
28+
template <typename T_p, typename T_nu, typename T_mu, typename T_sigma,
29+
require_all_stan_scalar_t<T_p, T_nu, T_mu, T_sigma>* = nullptr,
30+
require_all_arithmetic_t<T_p, T_nu, T_mu, T_sigma>* = nullptr>
31+
inline double student_t_qf(const T_p& p, const T_nu& nu, const T_mu& mu,
32+
const T_sigma& sigma) {
33+
static constexpr const char* function = "student_t_qf";
34+
check_nonnegative(function, "Degrees of freedom parameter", nu);
35+
check_positive(function, "Scale parameter", sigma);
36+
check_bounded(function, "Probability parameter", p, 0.0, 1.0);
37+
38+
if (p == 0.0) {
39+
return NEGATIVE_INFTY;
40+
} else if (p == 1.0) {
41+
return INFTY;
42+
} else if (p == 0.5) {
43+
return mu;
44+
}
45+
46+
const double p_val_flip = p < 0.5 ? p : 1.0 - p;
47+
const double p_sign = p < 0.5 ? -1.0 : 1.0;
48+
const auto ibeta_arg = inv_inc_beta(0.5 * nu, 0.5, 2 * p_val_flip);
49+
50+
return mu + p_sign * sigma * sqrt(nu) * sqrt(-1.0 + 1.0 / ibeta_arg);
51+
}
52+
53+
/**
54+
* A vectorized version of the Student's t quantile function that accepts
55+
* std::vectors, Eigen Matrix/Array objects, or expressions, and containers of
56+
* these.
57+
*
58+
* @tparam T_p type of the probability parameter
59+
* @tparam T_nu type of the degrees of freedom parameter
60+
* @tparam T_mu type of the location parameter
61+
* @tparam T_sigma type of the scale parameter
62+
* @tparam T_container type of the container to hold results
63+
* @param p Probability in the range [0, 1].
64+
* @param nu Degrees of freedom, must be non-negative.
65+
* @param mu Location parameter.
66+
* @param sigma Scale parameter, must be positive.
67+
* @return Container with quantile function values for each input.
68+
*/
69+
template <typename T_p, typename T_nu, typename T_mu, typename T_sigma,
70+
require_any_vector_t<T_p, T_nu, T_mu, T_sigma>* = nullptr>
71+
inline auto student_t_qf(const T_p& p, const T_nu& nu, const T_mu& mu,
72+
const T_sigma& sigma) {
73+
using T_container = common_container_t<T_p, T_nu, T_mu, T_sigma>;
74+
static constexpr const char* function = "student_t_qf";
75+
const size_t max_size_all = max_size(p, nu, mu, sigma);
76+
T_container result(max_size_all);
77+
78+
ref_type_t<T_p> p_ref = p;
79+
ref_type_t<T_nu> nu_ref = nu;
80+
ref_type_t<T_mu> mu_ref = mu;
81+
ref_type_t<T_sigma> sigma_ref = sigma;
82+
83+
scalar_seq_view<ref_type_t<T_p>> p_vec(p_ref);
84+
scalar_seq_view<ref_type_t<T_nu>> nu_vec(nu_ref);
85+
scalar_seq_view<ref_type_t<T_mu>> mu_vec(mu_ref);
86+
scalar_seq_view<ref_type_t<T_sigma>> sigma_vec(sigma_ref);
87+
88+
for (size_t i = 0; i < max_size_all; ++i) {
89+
result[i] = student_t_qf(p_vec[i], nu_vec[i], mu_vec[i], sigma_vec[i]);
90+
}
91+
92+
return result;
93+
}
94+
95+
} // namespace math
96+
} // namespace stan
97+
98+
#endif

stan/math/rev/fun/hypergeometric_pFq.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ namespace math {
2323
* @return Generalized hypergeometric function
2424
*/
2525
template <typename Ta, typename Tb, typename Tz,
26-
bool grad_a = is_autodiff_v<Ta>, bool grad_b = is_autodiff_v<Tb>,
27-
bool grad_z = is_autodiff_v<Tz>,
2826
require_all_vector_t<Ta, Tb>* = nullptr,
2927
require_return_type_t<is_var, Ta, Tb, Tz>* = nullptr>
3028
inline var hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
31-
auto&& arena_a = to_arena(as_column_vector_or_scalar(std::forward<Ta>(a)));
32-
auto&& arena_b = to_arena(as_column_vector_or_scalar(std::forward<Tb>(b)));
33-
auto pfq_val = hypergeometric_pFq(arena_a.val(), arena_b.val(), value_of(z));
29+
constexpr bool grad_a = is_autodiff_v<Ta>;
30+
constexpr bool grad_b = is_autodiff_v<Tb>;
31+
constexpr bool grad_z = is_autodiff_v<Tz>;
32+
auto arena_a = to_arena(as_column_vector_or_scalar(std::forward<Ta>(a)));
33+
auto arena_b = to_arena(as_column_vector_or_scalar(std::forward<Tb>(b)));
34+
auto pfq_val
35+
= hypergeometric_pFq(value_of(arena_a), value_of(arena_b), value_of(z));
3436
return make_callback_var(
3537
pfq_val, [arena_a, arena_b, z, pfq_val](auto& vi) mutable {
3638
auto grad_tuple = grad_pFq<grad_a, grad_b, grad_z>(

0 commit comments

Comments
 (0)