Skip to content

Commit d899108

Browse files
committed
Further improve thread safety
1 parent 9a12588 commit d899108

5 files changed

Lines changed: 318 additions & 43 deletions

File tree

include/ur_client_library/comm/socket_t.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,24 @@ typedef int socket_t;
7171
# define MSG_NOSIGNAL 0
7272
#endif
7373

74-
inline std::system_error makeSocketError(const std::string& message)
74+
/*!
75+
* \brief Get the last socket error as an std::error_code
76+
*
77+
* On Windows, this will use WSAGetLastError and the system category, while on other platforms it
78+
* will use errno and the generic category.
79+
*
80+
* \return The last socket error
81+
*/
82+
inline std::error_code getLastSocketErrorCode()
7583
{
7684
#ifdef _WIN32
77-
return std::system_error(std::error_code(WSAGetLastError(), std::system_category()), message);
85+
return std::error_code(WSAGetLastError(), std::system_category());
7886
#else
79-
return std::system_error(std::error_code(errno, std::generic_category()), message);
87+
return std::error_code(errno, std::generic_category());
8088
#endif
8189
}
8290

83-
inline int getLastSocketError()
91+
inline std::system_error makeSocketError(const std::string& message)
8492
{
85-
#ifdef _WIN32
86-
return WSAGetLastError();
87-
#else
88-
return errno;
89-
#endif
93+
return std::system_error(getLastSocketErrorCode(), message);
9094
}

include/ur_client_library/comm/tcp_server.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ class TCPServer
8282
* \note: For thread safety, this will block when there is an active callback around connection
8383
* (connection, disconnection, shutdown). Thus, trying to call setConnectCallback e.g. by a
8484
* handler function registered as disconnectCallback will result in a deadlock.
85+
*
86+
* \note: Calling write() from within the callback will cause a deadlock. If you want to write from the connection
87+
* callback, you can use writeUnchecked(), as the callback will be triggered from a thread-protected context.
8588
*/
8689
void setConnectCallback(std::function<void(const socket_t)> func)
8790
{
@@ -98,6 +101,9 @@ class TCPServer
98101
* \note: For thread safety, this will block when there is an active callback around connection
99102
* (connection, disconnection, shutdown). Thus, trying to call setDisconnectCallback e.g. by a
100103
* handler function registered as connectCallback will result in a deadlock.
104+
*
105+
* \note: The socket will already be closed when the disconnect callback is triggered, thus
106+
* trying to write to the socket from the disconnect callback will fail.
101107
*/
102108
void setDisconnectCallback(std::function<void(const socket_t)> func)
103109
{
@@ -113,6 +119,9 @@ class TCPServer
113119
*
114120
* \note: For thread safety, this will block when there is an active message callback. Thus, trying to call
115121
* setMessageCallback e.g. from a handler function registered as messageCallback will result in a deadlock.
122+
*
123+
* \note: Calling write() from within the callback will cause a deadlock. If you want to write from the message
124+
* callback, you can use writeUnchecked(), as the callback will be triggered from a thread-protected context.
116125
*/
117126
void setMessageCallback(std::function<void(const socket_t, char*, int)> func)
118127
{
@@ -129,9 +138,7 @@ class TCPServer
129138
void start();
130139

131140
/*!
132-
* \brief Shut down the event listener thread. After calling this, no events will be handled
133-
* anymore, but the socket will remain open and bound to the port. Call start() in order to
134-
* restart event handling.
141+
* \brief Shut down the server and close all client connections.
135142
*/
136143
void shutdown();
137144

@@ -148,6 +155,21 @@ class TCPServer
148155
*/
149156
bool write(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written);
150157

158+
/*!
159+
* \brief Writes to a file descriptor without verifying that it belongs to a client or is even a valid
160+
* file descriptor. It is the caller's responsibility to ensure that the file descriptor is valid
161+
* and belongs to a client.
162+
*
163+
* \param[in] fd File descriptor belonging to the client the data should be sent to. The file
164+
* descriptor is the one passed to the connect callback.
165+
* \param[in] buf Buffer of bytes to write
166+
* \param[in] buf_len Number of bytes in the buffer
167+
* \param[out] written Number of bytes actually written
168+
*
169+
* \returns True on success, false otherwise
170+
*/
171+
bool writeUnchecked(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written);
172+
151173
/*!
152174
* \brief Get the maximum number of clients allowed to connect to this server
153175
*
@@ -217,6 +239,7 @@ class TCPServer
217239
std::vector<socket_t> client_fds_;
218240
std::mutex clients_mutex_;
219241
std::mutex message_mutex_;
242+
std::mutex listen_fd_mutex_;
220243

221244
static const int INPUT_BUFFER_SIZE = 4096;
222245
char input_buffer_[INPUT_BUFFER_SIZE];

include/ur_client_library/log.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,4 @@ void setLogLevel(LogLevel level);
8989
* \param fmt Format string
9090
*/
9191
void log(const char* file, int line, LogLevel level, const char* fmt, ...);
92-
93-
/*!
94-
* \brief Cross-platform replacement for strerror.
95-
*
96-
* On MSVC, strerror triggers C4996 (deprecated). This function uses std::error_code instead,
97-
* which is portable and thread-safe.
98-
*
99-
* \param errnum Error number (on POSIX typically errno)
100-
* \returns Human-readable error message
101-
*/
102-
inline std::string strerrorPortable(int errnum)
103-
{
104-
return std::error_code(errnum, std::system_category()).message();
105-
}
106-
10792
} // namespace urcl

src/comm/tcp_server.cpp

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#include <sstream>
3636
#include <cstring>
37+
#include "ur_client_library/comm/socket_t.h"
3738
#include <fcntl.h>
3839

3940
namespace urcl
@@ -80,11 +81,23 @@ void TCPServer::init()
8081

8182
void TCPServer::shutdown()
8283
{
84+
std::unique_lock<std::mutex> listen_lk(listen_fd_mutex_, std::try_to_lock);
8385
if (listen_fd_ == INVALID_SOCKET)
8486
{
85-
URCL_LOG_DEBUG("TCPServer is already shut down.");
87+
URCL_LOG_INFO("Listen FD already closed by another thread. Nothing to do here.");
8688
return;
8789
}
90+
if (!listen_lk.owns_lock())
91+
{
92+
URCL_LOG_WARN("Could not acquire lock for listen FD when shutting down. Is there another thread shutting the "
93+
"server down already? Waiting for lock to be released.");
94+
listen_lk.lock();
95+
if (listen_fd_ == INVALID_SOCKET)
96+
{
97+
URCL_LOG_INFO("Listen FD already closed by another thread. Nothing to do here.");
98+
return;
99+
}
100+
}
88101
keep_running_ = false;
89102

90103
socket_t shutdown_socket = ::socket(AF_INET, SOCK_STREAM, 0);
@@ -147,9 +160,9 @@ void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds
147160
err = ::bind(listen_fd_, (struct sockaddr*)&server_addr, sizeof(server_addr));
148161
if (err == -1)
149162
{
163+
auto error_code = getLastSocketErrorCode();
150164
std::ostringstream ss;
151-
ss << "Failed to bind socket for port " << port_
152-
<< " to address. Reason: " << strerrorPortable(getLastSocketError());
165+
ss << "Failed to bind socket for port " << port_ << " to address. Reason: " << error_code.message();
153166

154167
if (connection_counter++ < max_num_tries || max_num_tries == 0)
155168
{
@@ -160,7 +173,7 @@ void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds
160173
}
161174
else
162175
{
163-
throw makeSocketError(ss.str());
176+
throw std::system_error(error_code, ss.str());
164177
}
165178
}
166179
} while (err == -1 && (connection_counter <= max_num_tries || max_num_tries == 0));
@@ -185,7 +198,7 @@ void TCPServer::startListen()
185198
if (getsockname(listen_fd_, (struct sockaddr*)&sin, &len) == -1)
186199
{
187200
URCL_LOG_ERROR("getsockname() failed to get port number for listening socket: %s",
188-
strerrorPortable(getLastSocketError()).c_str());
201+
getLastSocketErrorCode().message().c_str());
189202
}
190203

191204
else
@@ -202,12 +215,26 @@ void TCPServer::handleConnect()
202215
socket_t client_fd = accept(listen_fd_, (struct sockaddr*)&client_addr, &addrlen);
203216
if (client_fd == INVALID_SOCKET)
204217
{
205-
std::ostringstream ss;
206-
ss << "Failed to accept connection request on port " << port_;
207-
throw makeSocketError(ss.str());
218+
URCL_LOG_ERROR("Failed to accept connection request on port %d. Reason: %s", port_,
219+
getLastSocketErrorCode().message().c_str());
220+
return;
208221
}
209222

210-
if (client_fd >= FD_SETSIZE)
223+
auto set_size_exceeded = [this, &client_fd]() {
224+
#ifdef _WIN32
225+
(void)client_fd; // Avoid unused variable warning, since on Windows we only check the number of
226+
// clients, not the client FD itself.
227+
return client_fds_.size() >= FD_SETSIZE - 1; // -1 because listen_fd_ also occupies one
228+
// slot in masterfds_
229+
#else
230+
(void)this; // Avoid unused variable warning, since on Unix-like systems we only check the
231+
// client FD itself, not the number of clients.
232+
return client_fd >= FD_SETSIZE; // On Unix-like systems, the client FD itself must be less than
233+
// FD_SETSIZE, otherwise it cannot be added to the fd_set.
234+
#endif
235+
};
236+
237+
if (set_size_exceeded())
211238
{
212239
URCL_LOG_ERROR("Accepted client FD %d exceeds FD_SETSIZE (%d). Closing connection.", (int)client_fd, FD_SETSIZE);
213240
ur_close(client_fd);
@@ -267,14 +294,17 @@ void TCPServer::spin()
267294

268295
std::vector<socket_t> disconnected_clients;
269296

270-
for (const auto& client_fd : client_fds_)
271297
{
272-
if (FD_ISSET(client_fd, &tempfds_))
298+
std::lock_guard<std::mutex> lk(clients_mutex_);
299+
for (const auto& client_fd : client_fds_)
273300
{
274-
URCL_LOG_DEBUG("Activity on client FD %d", (int)client_fd);
275-
if (!readData(client_fd))
301+
if (FD_ISSET(client_fd, &tempfds_))
276302
{
277-
disconnected_clients.push_back(client_fd);
303+
URCL_LOG_DEBUG("Activity on client FD %d", (int)client_fd);
304+
if (!readData(client_fd))
305+
{
306+
disconnected_clients.push_back(client_fd);
307+
}
278308
}
279309
}
280310
}
@@ -388,13 +418,20 @@ bool TCPServer::write(const socket_t fd, const uint8_t* buf, const size_t buf_le
388418
}
389419
}
390420

421+
// We don't use a lock around the send call here, since writing on a closed socket would raise
422+
// an error anyway, and the client FD is only removed from client_fds_ after the socket is
423+
// closed. Thus, even if the client gets disconnected right after the check, the send call will
424+
// just fail and return false, which is the expected behavior.
425+
return writeUnchecked(fd, buf, buf_len, written);
426+
}
427+
428+
bool TCPServer::writeUnchecked(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written)
429+
{
391430
size_t remaining = buf_len;
392431

393432
// handle partial sends
394433
while (written < buf_len)
395434
{
396-
// We don't use a lock around the send call here, since writing on a closed socket would raise
397-
// an error anyway.
398435
ssize_t sent =
399436
::send(fd, reinterpret_cast<const char*>(buf + written), static_cast<socklen_t>(remaining), MSG_NOSIGNAL);
400437

0 commit comments

Comments
 (0)