3434
3535#include < sstream>
3636#include < cstring>
37+ #include " ur_client_library/comm/socket_t.h"
3738#include < fcntl.h>
3839
3940namespace urcl
@@ -80,11 +81,23 @@ void TCPServer::init()
8081
8182void TCPServer::shutdown ()
8283{
84+ std::unique_lock<std::mutex> listen_lk (listen_fd_mutex_, std::try_to_lock);
8385 if (listen_fd_ == INVALID_SOCKET )
8486 {
85- URCL_LOG_DEBUG ( " TCPServer is already shut down ." );
87+ URCL_LOG_INFO ( " Listen FD already closed by another thread. Nothing to do here ." );
8688 return ;
8789 }
90+ if (!listen_lk.owns_lock ())
91+ {
92+ URCL_LOG_WARN (" Could not acquire lock for listen FD when shutting down. Is there another thread shutting the "
93+ " server down already? Waiting for lock to be released." );
94+ listen_lk.lock ();
95+ if (listen_fd_ == INVALID_SOCKET )
96+ {
97+ URCL_LOG_INFO (" Listen FD already closed by another thread. Nothing to do here." );
98+ return ;
99+ }
100+ }
88101 keep_running_ = false ;
89102
90103 socket_t shutdown_socket = ::socket (AF_INET , SOCK_STREAM , 0 );
@@ -147,9 +160,9 @@ void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds
147160 err = ::bind (listen_fd_, (struct sockaddr *)&server_addr, sizeof (server_addr));
148161 if (err == -1 )
149162 {
163+ auto error_code = getLastSocketErrorCode ();
150164 std::ostringstream ss;
151- ss << " Failed to bind socket for port " << port_
152- << " to address. Reason: " << strerrorPortable (getLastSocketError ());
165+ ss << " Failed to bind socket for port " << port_ << " to address. Reason: " << error_code.message ();
153166
154167 if (connection_counter++ < max_num_tries || max_num_tries == 0 )
155168 {
@@ -160,7 +173,7 @@ void TCPServer::bind(const size_t max_num_tries, const std::chrono::milliseconds
160173 }
161174 else
162175 {
163- throw makeSocketError ( ss.str ());
176+ throw std::system_error (error_code, ss.str ());
164177 }
165178 }
166179 } while (err == -1 && (connection_counter <= max_num_tries || max_num_tries == 0 ));
@@ -185,7 +198,7 @@ void TCPServer::startListen()
185198 if (getsockname (listen_fd_, (struct sockaddr *)&sin, &len) == -1 )
186199 {
187200 URCL_LOG_ERROR (" getsockname() failed to get port number for listening socket: %s" ,
188- strerrorPortable ( getLastSocketError () ).c_str ());
201+ getLastSocketErrorCode (). message ( ).c_str ());
189202 }
190203
191204 else
@@ -202,12 +215,26 @@ void TCPServer::handleConnect()
202215 socket_t client_fd = accept (listen_fd_, (struct sockaddr *)&client_addr, &addrlen);
203216 if (client_fd == INVALID_SOCKET )
204217 {
205- std::ostringstream ss;
206- ss << " Failed to accept connection request on port " << port_ ;
207- throw makeSocketError (ss. str ()) ;
218+ URCL_LOG_ERROR ( " Failed to accept connection request on port %d. Reason: %s " , port_,
219+ getLastSocketErrorCode (). message (). c_str ()) ;
220+ return ;
208221 }
209222
210- if (client_fd >= FD_SETSIZE )
223+ auto set_size_exceeded = [this , &client_fd]() {
224+ #ifdef _WIN32
225+ (void )client_fd; // Avoid unused variable warning, since on Windows we only check the number of
226+ // clients, not the client FD itself.
227+ return client_fds_.size () >= FD_SETSIZE - 1 ; // -1 because listen_fd_ also occupies one
228+ // slot in masterfds_
229+ #else
230+ (void )this ; // Avoid unused variable warning, since on Unix-like systems we only check the
231+ // client FD itself, not the number of clients.
232+ return client_fd >= FD_SETSIZE ; // On Unix-like systems, the client FD itself must be less than
233+ // FD_SETSIZE, otherwise it cannot be added to the fd_set.
234+ #endif
235+ };
236+
237+ if (set_size_exceeded ())
211238 {
212239 URCL_LOG_ERROR (" Accepted client FD %d exceeds FD_SETSIZE (%d). Closing connection." , (int )client_fd, FD_SETSIZE );
213240 ur_close (client_fd);
@@ -267,14 +294,17 @@ void TCPServer::spin()
267294
268295 std::vector<socket_t > disconnected_clients;
269296
270- for (const auto & client_fd : client_fds_)
271297 {
272- if (FD_ISSET (client_fd, &tempfds_))
298+ std::lock_guard<std::mutex> lk (clients_mutex_);
299+ for (const auto & client_fd : client_fds_)
273300 {
274- URCL_LOG_DEBUG (" Activity on client FD %d" , (int )client_fd);
275- if (!readData (client_fd))
301+ if (FD_ISSET (client_fd, &tempfds_))
276302 {
277- disconnected_clients.push_back (client_fd);
303+ URCL_LOG_DEBUG (" Activity on client FD %d" , (int )client_fd);
304+ if (!readData (client_fd))
305+ {
306+ disconnected_clients.push_back (client_fd);
307+ }
278308 }
279309 }
280310 }
@@ -388,13 +418,20 @@ bool TCPServer::write(const socket_t fd, const uint8_t* buf, const size_t buf_le
388418 }
389419 }
390420
421+ // We don't use a lock around the send call here, since writing on a closed socket would raise
422+ // an error anyway, and the client FD is only removed from client_fds_ after the socket is
423+ // closed. Thus, even if the client gets disconnected right after the check, the send call will
424+ // just fail and return false, which is the expected behavior.
425+ return writeUnchecked (fd, buf, buf_len, written);
426+ }
427+
428+ bool TCPServer::writeUnchecked (const socket_t fd, const uint8_t * buf, const size_t buf_len, size_t & written)
429+ {
391430 size_t remaining = buf_len;
392431
393432 // handle partial sends
394433 while (written < buf_len)
395434 {
396- // We don't use a lock around the send call here, since writing on a closed socket would raise
397- // an error anyway.
398435 ssize_t sent =
399436 ::send (fd, reinterpret_cast <const char *>(buf + written), static_cast<socklen_t>(remaining), MSG_NOSIGNAL);
400437
0 commit comments