Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/dpp/socketengine.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ struct DPP_EXPORT socket_engine_base {
/**
* @brief Default destructor
*/
virtual ~socket_engine_base() = default;
virtual ~socket_engine_base();

/**
* @brief Should be called repeatedly in a loop.
Expand Down
6 changes: 6 additions & 0 deletions src/dpp/socketengine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ socket_engine_base::socket_engine_base(cluster* creator) : owner(creator) {
#endif
}

socket_engine_base::~socket_engine_base() {
#ifdef _WIN32
WSACleanup();
#endif
}

time_t last_time = time(nullptr);

socket_events* socket_engine_base::get_fd(dpp::socket fd) {
Expand Down
91 changes: 82 additions & 9 deletions src/dpp/socketengines/poll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
#include <dpp/compat.h>
#include <dpp/socketengine.h>
#include <dpp/exception.h>
#include <dpp/socket.h>
#include <dpp/sslconnection.h>
#include <vector>
#include <shared_mutex>
#include <memory>
#include <cerrno>

namespace dpp {

Expand All @@ -45,6 +48,8 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
const int poll_delay = 1000;

prune();
/* Save count of tracked sockets while mutex is held, just in case */
size_t fd_count = 0;
{
std::shared_lock lock(poll_set_mutex);
if (poll_set.empty()) {
Expand All @@ -55,6 +60,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
if (poll_set.size() > FD_SETSIZE) {
throw dpp::connection_exception("poll() does not support more than FD_SETSIZE active sockets at once!");
}
fd_count = poll_set.size();
/**
* We must make a copy of the poll_set, because it would cause thread locking/contention
* issues if we had it locked for read during poll/iteration of the returned set.
Expand All @@ -63,17 +69,24 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
}
}

int i = dpp::compat::poll(out_set, static_cast<unsigned int>(poll_set.size()), poll_delay);
int i = dpp::compat::poll(out_set, static_cast<unsigned int>(fd_count), poll_delay);
int processed = 0;

for (size_t index = 0; index < poll_set.size() && processed < i; index++) {
const int fd = out_set[index].fd;
for (size_t index = 0; index < fd_count && processed < i; index++) {
const dpp::socket fd = out_set[index].fd;
const short revents = out_set[index].revents;

if (revents > 0) {
processed++;
}

if (fd == wake_read.fd) {
if ((revents & POLLIN) != 0) {
drain_wakeup();
}
continue;
}

socket_events *eh = get_fd(fd);
if (eh == nullptr) {
continue;
Expand Down Expand Up @@ -123,12 +136,6 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
}
}

#if _WIN32
~socket_engine_poll() override {
WSACleanup();
}
#endif

bool register_socket(const socket_events& e) final {
bool r = socket_engine_base::register_socket(e);
if (r) {
Expand All @@ -144,6 +151,7 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
}
poll_set.push_back(fd_info);
}
refresh_poll();
return r;
}

Expand All @@ -166,15 +174,21 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
break;
}
}
refresh_poll();
return r;
}

explicit socket_engine_poll(cluster* creator) : socket_engine_base(creator) {
stats.engine_type = "poll";
init_wakeup_socket();
};

protected:

/* Poll wakeup mechanism: UDP loopback socket pair */
dpp::raii_socket wake_read{dpp::rst_udp};
dpp::raii_socket wake_write{dpp::rst_udp};

bool remove_socket(dpp::socket fd) final {
std::unique_lock lock(poll_set_mutex);
for (auto i = poll_set.begin(); i != poll_set.end(); ++i) {
Expand All @@ -185,11 +199,70 @@ struct DPP_EXPORT socket_engine_poll : public socket_engine_base {
event.fd = fd;
owner->on_socket_close.call(event);
}
refresh_poll();
return true;
}
}
return false;
}

void init_wakeup_socket() {
if (!wake_read.bind(dpp::address_t("127.0.0.1", 0))) {
throw dpp::connection_exception("Failed to bind refresh_poll read socket");
}

if (!set_nonblocking(wake_read.fd, true)) {
throw dpp::connection_exception("Failed to set refresh_poll read socket non-blocking");
}

dpp::address_t tmp;
uint16_t port = tmp.get_port(wake_read.fd);
dpp::address_t dest("127.0.0.1", port);
if (::connect(wake_write.fd, dest.get_socket_address(), static_cast<int>(dest.size())) != 0) {
throw dpp::connection_exception("Failed to connect refresh_poll write socket");
}

{
std::unique_lock lock(poll_set_mutex);
pollfd fd_info{};
fd_info.fd = wake_read.fd;
fd_info.events = POLLIN;
poll_set.push_back(fd_info);
}
}

void drain_wakeup() {
char buf[256];
while (true) {
#if _WIN32
int r = ::recv(wake_read.fd, buf, sizeof(buf), 0);
if (r <= 0) {
int e = WSAGetLastError();
if (e == WSAEWOULDBLOCK || e == WSAEINTR) {
break;
}
break;
}
#else
ssize_t r = ::recv(wake_read.fd, buf, sizeof(buf), MSG_DONTWAIT);
if (r < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
break;
}
break;
}
if (r == 0) {
break;
}
#endif
}
}

void refresh_poll() const {
if (wake_write.fd == INVALID_SOCKET) return;
static const char one = 1;
(void)::send(wake_write.fd, &one, 1, 0);
}
};

DPP_EXPORT std::unique_ptr<socket_engine_base> create_socket_engine(cluster* creator) {
Expand Down
Loading