55#include < stan/math/prim/functor/operands_and_partials.hpp>
66#include < stan/math/prim/meta.hpp>
77#include < stan/math/rev/core/var.hpp>
8+ #include < stan/math/opencl/rev/arena_matrix_cl.hpp>
89#include < stan/math/opencl/kernel_generator.hpp>
910#include < stan/math/opencl/rev/arena_type.hpp>
1011#include < stan/math/opencl/rev/to_arena.hpp>
@@ -17,7 +18,7 @@ template <typename Op>
1718class ops_partials_edge <double , var_value<Op>,
1819 require_kernel_expression_lhs_t <Op>> {
1920 public:
20- using partials_t = plain_type_t <Op >;
21+ using partials_t = arena_matrix_cl< value_type_t <Op> >;
2122 partials_t partials_; // For univariate use-cases
2223 broadcast_array<partials_t > partials_vec_; // For multivariate
2324 explicit ops_partials_edge (const var_value<Op>& ops)
@@ -28,17 +29,10 @@ class ops_partials_edge<double, var_value<Op>,
2829 private:
2930 template <typename , typename , typename , typename , typename , typename >
3031 friend class stan ::math::operands_and_partials;
31- const var_value<Op>& operands_;
32-
33- void dump_operands (vari** varis) {}
34- void dump_partials (double * partials) {}
35- int size () { return 0 ; }
36- std::tuple<var_value<Op>> container_operands () {
37- return std::make_tuple (operands_);
38- }
39- std::tuple<partials_t > container_partials () {
40- return std::make_tuple (partials_);
41- }
32+ var_value<Op> operands_;
33+ static constexpr int size () noexcept { return 0 ; }
34+ inline auto & operand () noexcept { return this ->operands_ ; }
35+ inline auto & partial () noexcept { return this ->partials_ ; }
4236};
4337
4438} // namespace internal
0 commit comments