|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <mysql.h> |
| 4 | + |
| 5 | +#include <cstring> |
| 6 | +#include <expected> |
| 7 | +#include <memory> |
| 8 | +#include <optional> |
| 9 | +#include <string> |
| 10 | +#include <tuple> |
| 11 | +#include <vector> |
| 12 | + |
| 13 | +#include "ds_mysql/prepared_statement.hpp" |
| 14 | + |
| 15 | +namespace ds_mysql { |
| 16 | + |
| 17 | +// =================================================================== |
| 18 | +// server_cursor — streaming result set via MySQL server-side cursors. |
| 19 | +// |
| 20 | +// Unlike prepared_statement::query() which loads all rows into memory, |
| 21 | +// a server_cursor fetches rows one-at-a-time (or in prefetch batches) |
| 22 | +// from the server. Ideal for large result sets. |
| 23 | +// |
| 24 | +// Created via mysql_connection::open_cursor(): |
| 25 | +// auto cursor = conn.open_cursor<RowType>("SELECT * FROM t WHERE x = ?", 42u); |
| 26 | +// while (auto row = cursor->fetch()) { |
| 27 | +// process(*row); |
| 28 | +// } |
| 29 | +// |
| 30 | +// Prefetch hint (number of rows to buffer at once): |
| 31 | +// auto cursor = conn.open_cursor<RowType>(sql, prefetch_rows{100}, 42u); |
| 32 | +// =================================================================== |
| 33 | + |
| 34 | +struct prefetch_rows { |
| 35 | + unsigned long value = 1; |
| 36 | +}; |
| 37 | + |
| 38 | +template <typename RowType> |
| 39 | +class server_cursor { |
| 40 | +public: |
| 41 | + server_cursor(server_cursor const&) = delete; |
| 42 | + server_cursor& operator=(server_cursor const&) = delete; |
| 43 | + |
| 44 | + server_cursor(server_cursor&& other) noexcept |
| 45 | + : stmt_(std::move(other.stmt_)), |
| 46 | + result_binds_(std::move(other.result_binds_)), |
| 47 | + lengths_(std::move(other.lengths_)), |
| 48 | + nulls_(std::move(other.nulls_)), |
| 49 | + errors_(std::move(other.errors_)), |
| 50 | + string_bufs_(std::move(other.string_bufs_)), |
| 51 | + done_(other.done_) { |
| 52 | + other.done_ = true; |
| 53 | + } |
| 54 | + |
| 55 | + server_cursor& operator=(server_cursor&& other) noexcept { |
| 56 | + if (this != &other) { |
| 57 | + close(); |
| 58 | + stmt_ = std::move(other.stmt_); |
| 59 | + result_binds_ = std::move(other.result_binds_); |
| 60 | + lengths_ = std::move(other.lengths_); |
| 61 | + nulls_ = std::move(other.nulls_); |
| 62 | + errors_ = std::move(other.errors_); |
| 63 | + string_bufs_ = std::move(other.string_bufs_); |
| 64 | + done_ = other.done_; |
| 65 | + other.done_ = true; |
| 66 | + } |
| 67 | + return *this; |
| 68 | + } |
| 69 | + |
| 70 | + ~server_cursor() { |
| 71 | + close(); |
| 72 | + } |
| 73 | + |
| 74 | + /// Fetch the next row. Returns std::nullopt when no more rows. |
| 75 | + [[nodiscard]] std::expected<std::optional<RowType>, std::string> fetch() { |
| 76 | + if (done_) { |
| 77 | + return std::optional<RowType>{std::nullopt}; |
| 78 | + } |
| 79 | + |
| 80 | + int const status = mysql_stmt_fetch(stmt_.get()); |
| 81 | + if (status == MYSQL_NO_DATA) { |
| 82 | + done_ = true; |
| 83 | + return std::optional<RowType>{std::nullopt}; |
| 84 | + } |
| 85 | + if (status != 0 && status != MYSQL_DATA_TRUNCATED) { |
| 86 | + return std::unexpected(std::string(mysql_stmt_error(stmt_.get()))); |
| 87 | + } |
| 88 | + |
| 89 | + return extract_row(std::make_index_sequence<std::tuple_size_v<RowType>>{}, status == MYSQL_DATA_TRUNCATED); |
| 90 | + } |
| 91 | + |
| 92 | + /// True when all rows have been fetched. |
| 93 | + [[nodiscard]] bool done() const noexcept { |
| 94 | + return done_; |
| 95 | + } |
| 96 | + |
| 97 | +private: |
| 98 | + friend class mysql_connection; |
| 99 | + |
| 100 | + struct stmt_deleter { |
| 101 | + void operator()(MYSQL_STMT* s) const noexcept { |
| 102 | + if (s) |
| 103 | + mysql_stmt_close(s); |
| 104 | + } |
| 105 | + }; |
| 106 | + |
| 107 | + static constexpr auto N = std::tuple_size_v<RowType>; |
| 108 | + static constexpr std::size_t initial_string_buf_size = 256; |
| 109 | + |
| 110 | + explicit server_cursor(std::unique_ptr<MYSQL_STMT, stmt_deleter> stmt) |
| 111 | + : stmt_(std::move(stmt)), |
| 112 | + result_binds_(N), |
| 113 | + lengths_(N), |
| 114 | + nulls_(std::make_unique<bool[]>(N)), |
| 115 | + errors_(std::make_unique<bool[]>(N)), |
| 116 | + string_bufs_(N) { |
| 117 | + std::fill_n(nulls_.get(), N, false); |
| 118 | + std::fill_n(errors_.get(), N, false); |
| 119 | + } |
| 120 | + |
| 121 | + void close() { |
| 122 | + if (stmt_) { |
| 123 | + mysql_stmt_free_result(stmt_.get()); |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + std::expected<void, std::string> bind_results() { |
| 128 | + [this]<std::size_t... Is>(std::index_sequence<Is...>) { |
| 129 | + (setup_bind<std::tuple_element_t<Is, RowType>>(Is), ...); |
| 130 | + }(std::make_index_sequence<N>{}); |
| 131 | + |
| 132 | + if (mysql_stmt_bind_result(stmt_.get(), result_binds_.data())) { |
| 133 | + return std::unexpected(std::string(mysql_stmt_error(stmt_.get()))); |
| 134 | + } |
| 135 | + return {}; |
| 136 | + } |
| 137 | + |
| 138 | + template <typename T> |
| 139 | + void setup_bind(std::size_t idx) { |
| 140 | + auto& bind = result_binds_[idx]; |
| 141 | + std::memset(&bind, 0, sizeof(MYSQL_BIND)); |
| 142 | + bind.length = &lengths_[idx]; |
| 143 | + bind.is_null = &nulls_.get()[idx]; |
| 144 | + bind.error = &errors_.get()[idx]; |
| 145 | + |
| 146 | + using raw = stmt_detail::unwrap_param_type_t<T>; |
| 147 | + if constexpr (std::same_as<raw, std::string>) { |
| 148 | + string_bufs_[idx].resize(initial_string_buf_size); |
| 149 | + bind.buffer_type = MYSQL_TYPE_STRING; |
| 150 | + bind.buffer = string_bufs_[idx].data(); |
| 151 | + bind.buffer_length = static_cast<unsigned long>(string_bufs_[idx].size()); |
| 152 | + } else { |
| 153 | + bind.buffer_type = stmt_detail::mysql_type_traits<raw>::field_type; |
| 154 | + bind.is_unsigned = stmt_detail::mysql_type_traits<raw>::is_unsigned; |
| 155 | + string_bufs_[idx].resize(sizeof(raw)); |
| 156 | + bind.buffer = string_bufs_[idx].data(); |
| 157 | + bind.buffer_length = sizeof(raw); |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + template <std::size_t... Is> |
| 162 | + std::expected<std::optional<RowType>, std::string> extract_row(std::index_sequence<Is...>, bool truncated) { |
| 163 | + if (truncated) { |
| 164 | + (refetch_truncated<std::tuple_element_t<Is, RowType>>(Is), ...); |
| 165 | + } |
| 166 | + return std::optional<RowType>{ |
| 167 | + RowType{extract_value<std::tuple_element_t<Is, RowType>>(Is)...}}; |
| 168 | + } |
| 169 | + |
| 170 | + template <typename T> |
| 171 | + void refetch_truncated(std::size_t idx) { |
| 172 | + using raw = stmt_detail::unwrap_param_type_t<T>; |
| 173 | + if constexpr (std::same_as<raw, std::string>) { |
| 174 | + if (lengths_[idx] > result_binds_[idx].buffer_length) { |
| 175 | + string_bufs_[idx].resize(lengths_[idx]); |
| 176 | + result_binds_[idx].buffer = string_bufs_[idx].data(); |
| 177 | + result_binds_[idx].buffer_length = static_cast<unsigned long>(string_bufs_[idx].size()); |
| 178 | + mysql_stmt_fetch_column(stmt_.get(), &result_binds_[idx], static_cast<unsigned int>(idx), 0); |
| 179 | + } |
| 180 | + } |
| 181 | + } |
| 182 | + |
| 183 | + template <typename T> |
| 184 | + T extract_value(std::size_t idx) const { |
| 185 | + if constexpr (is_optional_v<T>) { |
| 186 | + if (nulls_.get()[idx]) |
| 187 | + return std::nullopt; |
| 188 | + using inner = unwrap_optional_t<T>; |
| 189 | + return extract_value<inner>(idx); |
| 190 | + } else if constexpr (std::same_as<stmt_detail::unwrap_param_type_t<T>, std::string>) { |
| 191 | + return std::string(string_bufs_[idx].data(), lengths_[idx]); |
| 192 | + } else { |
| 193 | + using raw = stmt_detail::unwrap_param_type_t<T>; |
| 194 | + raw val{}; |
| 195 | + std::memcpy(&val, string_bufs_[idx].data(), sizeof(raw)); |
| 196 | + return static_cast<T>(val); |
| 197 | + } |
| 198 | + } |
| 199 | + |
| 200 | + std::unique_ptr<MYSQL_STMT, stmt_deleter> stmt_; |
| 201 | + std::vector<MYSQL_BIND> result_binds_; |
| 202 | + std::vector<unsigned long> lengths_; |
| 203 | + std::unique_ptr<bool[]> nulls_; |
| 204 | + std::unique_ptr<bool[]> errors_; |
| 205 | + std::vector<std::vector<char>> string_bufs_; |
| 206 | + bool done_ = false; |
| 207 | +}; |
| 208 | + |
| 209 | +} // namespace ds_mysql |
0 commit comments