Skip to content

Commit 4610023

Browse files
authored
Robustify TCPServer implementation for thread safety and windows support (#458)
This PR hardens the TCPServer implementation in the following points: * Add thread safety between connect / disconnect, write / read and shutdown events. * Improved windows-specific error case handling
1 parent 0922712 commit 4610023

7 files changed

Lines changed: 497 additions & 68 deletions

File tree

include/ur_client_library/comm/socket_t.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#pragma once
1818

19+
#include <system_error>
1920
#ifdef _WIN32
2021

2122
# define NOMINMAX
@@ -33,6 +34,7 @@
3334

3435
typedef SOCKET socket_t;
3536
typedef SSIZE_T ssize_t;
37+
typedef int socklen_t;
3638

3739
static inline int ur_setsockopt(socket_t s, int level, int optname, const void* optval, unsigned int optlen)
3840
{
@@ -64,3 +66,29 @@ typedef int socket_t;
6466
# define ur_close close
6567

6668
#endif // _WIN32
69+
70+
#ifndef MSG_NOSIGNAL
71+
# define MSG_NOSIGNAL 0
72+
#endif
73+
74+
/*!
75+
* \brief Get the last socket error as an std::error_code
76+
*
77+
* On Windows, this will use WSAGetLastError and the system category, while on other platforms it
78+
* will use errno and the generic category.
79+
*
80+
* \return The last socket error
81+
*/
82+
inline std::error_code getLastSocketErrorCode()
83+
{
84+
#ifdef _WIN32
85+
return std::error_code(WSAGetLastError(), std::system_category());
86+
#else
87+
return std::error_code(errno, std::generic_category());
88+
#endif
89+
}
90+
91+
inline std::system_error makeSocketError(const std::string& message)
92+
{
93+
return std::system_error(getLastSocketErrorCode(), message);
94+
}

include/ur_client_library/comm/tcp_server.h

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <cstddef>
3535
#include <cstdint>
3636
#include <functional>
37+
#include <mutex>
3738
#include <thread>
3839
#include <vector>
3940

@@ -79,9 +80,13 @@ class TCPServer
7980
*
8081
* \param func Function handling the event information. The file descriptor created by the
8182
* connection event will be passed to the function.
83+
*
84+
* \note: The connection callback will be triggered with the socket being accepted. Hence, it
85+
* is possible to send data from the connection callback directly.
8286
*/
8387
void setConnectCallback(std::function<void(const socket_t)> func)
8488
{
89+
std::lock_guard<std::mutex> lk(callback_mutex_);
8590
new_connection_callback_ = func;
8691
}
8792

@@ -90,9 +95,13 @@ class TCPServer
9095
*
9196
* \param func Function handling the event information. The file descriptor created by the
9297
* connection event will be passed to the function.
98+
*
99+
* \note: The socket will already be closed when the disconnect callback is triggered, thus
100+
* trying to interact with the socket from the disconnect callback will fail.
93101
*/
94102
void setDisconnectCallback(std::function<void(const socket_t)> func)
95103
{
104+
std::lock_guard<std::mutex> lk(callback_mutex_);
96105
disconnect_callback_ = func;
97106
}
98107

@@ -104,6 +113,7 @@ class TCPServer
104113
*/
105114
void setMessageCallback(std::function<void(const socket_t, char*, int)> func)
106115
{
116+
std::lock_guard<std::mutex> lk(message_mutex_);
107117
message_callback_ = func;
108118
}
109119

@@ -116,9 +126,11 @@ class TCPServer
116126
void start();
117127

118128
/*!
119-
* \brief Shut down the event listener thread. After calling this, no events will be handled
120-
* anymore, but the socket will remain open and bound to the port. Call start() in order to
121-
* restart event handling.
129+
* \brief Shutdown the server and close all client connections.
130+
*
131+
* \note: This should not be called from within any of the registered callback functions, as
132+
* it will cause a deadlock. If you want to shutdown the server from a callback, you can e.g.
133+
* start a new thread that calls shutdown() from there.
122134
*/
123135
void shutdown();
124136

@@ -135,6 +147,21 @@ class TCPServer
135147
*/
136148
bool write(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written);
137149

150+
/*!
151+
* \brief Writes to a filedescriptor without verifying that it is a client or even a valid
152+
* filedescriptor. It is the caller's responsibility to ensure that the filedescriptor is valid
153+
* and belongs to a client.
154+
*
155+
* \param[in] fd File descriptor belonging to the client the data should be sent to. The file
156+
* descriptor will be given from the connection callback.
157+
* \param[in] buf Buffer of bytes to write
158+
* \param[in] buf_len Number of bytes in the buffer
159+
* \param[out] written Number of bytes actually written
160+
*
161+
* \returns True on success, false otherwise
162+
*/
163+
bool writeUnchecked(const socket_t fd, const uint8_t* buf, const size_t buf_len, size_t& written);
164+
138165
/*!
139166
* \brief Get the maximum number of clients allowed to connect to this server
140167
*
@@ -182,15 +209,15 @@ class TCPServer
182209
void handleDisconnect(const socket_t fd);
183210

184211
//! read data from socket
185-
void readData(const socket_t fd);
212+
bool readData(const socket_t fd);
186213

187214
//! Event handler. Blocks until activity on any client or connection attempt
188215
void spin();
189216

190217
//! Runs spin() as long as keep_running_ is set to true.
191218
void worker();
192219

193-
std::atomic<bool> keep_running_;
220+
std::atomic<bool> keep_running_{ false };
194221
std::thread worker_thread_;
195222

196223
std::atomic<socket_t> listen_fd_;
@@ -202,6 +229,10 @@ class TCPServer
202229

203230
uint32_t max_clients_allowed_;
204231
std::vector<socket_t> client_fds_;
232+
std::mutex clients_mutex_;
233+
std::mutex message_mutex_;
234+
std::mutex listen_fd_mutex_;
235+
std::mutex callback_mutex_;
205236

206237
static const int INPUT_BUFFER_SIZE = 4096;
207238
char input_buffer_[INPUT_BUFFER_SIZE];

include/ur_client_library/log.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,4 @@ void setLogLevel(LogLevel level);
8888
* \param fmt Format string
8989
*/
9090
void log(const char* file, int line, LogLevel level, const char* fmt, ...);
91-
9291
} // namespace urcl

0 commit comments

Comments
 (0)