diff --git a/include/dpp/socketengine.h b/include/dpp/socketengine.h index 9579bbd5a7..b7f831bbb2 100644 --- a/include/dpp/socketengine.h +++ b/include/dpp/socketengine.h @@ -230,7 +230,7 @@ struct DPP_EXPORT socket_engine_base { /** * @brief Default destructor */ - virtual ~socket_engine_base() = default; + virtual ~socket_engine_base(); /** * @brief Should be called repeatedly in a loop. diff --git a/src/dpp/socketengine.cpp b/src/dpp/socketengine.cpp index a16f8fb84a..60a1a77c8e 100644 --- a/src/dpp/socketengine.cpp +++ b/src/dpp/socketengine.cpp @@ -75,6 +75,12 @@ socket_engine_base::socket_engine_base(cluster* creator) : owner(creator) { #endif } +socket_engine_base::~socket_engine_base() { +#ifdef _WIN32 + WSACleanup(); +#endif +} + time_t last_time = time(nullptr); socket_events* socket_engine_base::get_fd(dpp::socket fd) { diff --git a/src/dpp/socketengines/poll.cpp b/src/dpp/socketengines/poll.cpp index 763e5d304b..2256712c32 100644 --- a/src/dpp/socketengines/poll.cpp +++ b/src/dpp/socketengines/poll.cpp @@ -23,9 +23,12 @@ #include #include #include +#include +#include #include #include #include +#include namespace dpp { @@ -45,6 +48,8 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { const int poll_delay = 1000; prune(); + /* Save count of tracked sockets while mutex is held, just in case */ + size_t fd_count = 0; { std::shared_lock lock(poll_set_mutex); if (poll_set.empty()) { @@ -55,6 +60,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { if (poll_set.size() > FD_SETSIZE) { throw dpp::connection_exception("poll() does not support more than FD_SETSIZE active sockets at once!"); } + fd_count = poll_set.size(); /** * We must make a copy of the poll_set, because it would cause thread locking/contention * issues if we had it locked for read during poll/iteration of the returned set. @@ -63,17 +69,24 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { } } - int i = dpp::compat::poll(out_set, static_cast(poll_set.size()), poll_delay); + int i = dpp::compat::poll(out_set, static_cast(fd_count), poll_delay); int processed = 0; - for (size_t index = 0; index < poll_set.size() && processed < i; index++) { - const int fd = out_set[index].fd; + for (size_t index = 0; index < fd_count && processed < i; index++) { + const dpp::socket fd = out_set[index].fd; const short revents = out_set[index].revents; if (revents > 0) { processed++; } + if (fd == wake_read.fd) { + if ((revents & POLLIN) != 0) { + drain_wakeup(); + } + continue; + } + socket_events *eh = get_fd(fd); if (eh == nullptr) { continue; @@ -123,12 +136,6 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { } } -#if _WIN32 - ~socket_engine_poll() override { - WSACleanup(); - } -#endif - bool register_socket(const socket_events& e) final { bool r = socket_engine_base::register_socket(e); if (r) { @@ -144,6 +151,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { } poll_set.push_back(fd_info); } + refresh_poll(); return r; } @@ -166,15 +174,21 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { break; } } + refresh_poll(); return r; } explicit socket_engine_poll(cluster* creator) : socket_engine_base(creator) { stats.engine_type = "poll"; + init_wakeup_socket(); }; protected: + /* Poll wakeup mechanism: UDP loopback socket pair */ + dpp::raii_socket wake_read{dpp::rst_udp}; + dpp::raii_socket wake_write{dpp::rst_udp}; + bool remove_socket(dpp::socket fd) final { std::unique_lock lock(poll_set_mutex); for (auto i = poll_set.begin(); i != poll_set.end(); ++i) { @@ -185,11 +199,70 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base { event.fd = fd; owner->on_socket_close.call(event); } + refresh_poll(); return true; } } return false; } + + void init_wakeup_socket() { + if (!wake_read.bind(dpp::address_t("127.0.0.1", 0))) { + throw dpp::connection_exception("Failed to bind refresh_poll read socket"); + } + + if (!set_nonblocking(wake_read.fd, true)) { + throw dpp::connection_exception("Failed to set refresh_poll read socket non-blocking"); + } + + dpp::address_t tmp; + uint16_t port = tmp.get_port(wake_read.fd); + dpp::address_t dest("127.0.0.1", port); + if (::connect(wake_write.fd, dest.get_socket_address(), static_cast(dest.size())) != 0) { + throw dpp::connection_exception("Failed to connect refresh_poll write socket"); + } + + { + std::unique_lock lock(poll_set_mutex); + pollfd fd_info{}; + fd_info.fd = wake_read.fd; + fd_info.events = POLLIN; + poll_set.push_back(fd_info); + } + } + + void drain_wakeup() { + char buf[256]; + while (true) { +#if _WIN32 + int r = ::recv(wake_read.fd, buf, sizeof(buf), 0); + if (r <= 0) { + int e = WSAGetLastError(); + if (e == WSAEWOULDBLOCK || e == WSAEINTR) { + break; + } + break; + } +#else + ssize_t r = ::recv(wake_read.fd, buf, sizeof(buf), MSG_DONTWAIT); + if (r < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } + break; + } + if (r == 0) { + break; + } +#endif + } + } + + void refresh_poll() const { + if (wake_write.fd == INVALID_SOCKET) return; + static const char one = 1; + (void)::send(wake_write.fd, &one, 1, 0); + } }; DPP_EXPORT std::unique_ptr create_socket_engine(cluster* creator) {