Skip to content

Commit 31323d4

Browse files
committed
Fix WebSocket/MQTT protocol bugs and Logger circular buffer/mutex issues
1 parent f9cec75 commit 31323d4

2 files changed

Lines changed: 112 additions & 31 deletions

File tree

Source/IO/Protocol.cpp

Lines changed: 103 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -698,9 +698,16 @@ namespace Protocol
698698
int rv = select(sock + 1, &readfds, nullptr, nullptr, &tv);
699699
if (rv < 0)
700700
{
701+
#ifdef _WIN32
702+
int sock_err = WSAGetLastError();
703+
if (sock_err == WSAEINTR)
704+
continue;
705+
Error() << "TLS select failed. Error code: " << sock_err;
706+
#else
701707
if (errno == EINTR)
702708
continue;
703-
perror("TLS select");
709+
Error() << "TLS select failed: " << strerror(errno);
710+
#endif
704711
if (stats)
705712
stats->bytes_in += BIO_number_read(SSL_get_rbio(ssl)) - before;
706713

@@ -959,7 +966,7 @@ namespace Protocol
959966
if (shift >= 28)
960967
return -1;
961968

962-
if (prev->read((char *)&b, 1) != 1)
969+
if (prev->read((char *)&b, 1, 5, true) != 1)
963970
return -1;
964971

965972
length |= (b & 127) << shift;
@@ -1071,7 +1078,8 @@ namespace Protocol
10711078

10721079
performHandshake();
10731080

1074-
ProtocolBase::onConnect();
1081+
if (connected)
1082+
ProtocolBase::onConnect();
10751083
}
10761084

10771085
void MQTT::onDisconnect()
@@ -1210,12 +1218,6 @@ namespace Protocol
12101218

12111219
case PacketType::PUBLISH:
12121220
{
1213-
if (data_len < length)
1214-
{
1215-
Warning() << "MQTT: Buffer too small for incoming message. Required: " << length << ", provided: " << data_len;
1216-
break;
1217-
}
1218-
12191221
// Validate minimum length for topic length field
12201222
if (length < 2)
12211223
{
@@ -1236,15 +1238,22 @@ namespace Protocol
12361238
return -1;
12371239
}
12381240

1241+
int payload_len = length - header_size;
1242+
if (data_len < payload_len)
1243+
{
1244+
Warning() << "MQTT: Buffer too small for incoming message. Required: " << payload_len << ", provided: " << data_len;
1245+
break;
1246+
}
1247+
12391248
if (q == 0)
12401249
{
1241-
data_returned = length - 2 - topic_len;
1250+
data_returned = payload_len;
12421251
memcpy(data, buffer.data() + i + 2 + topic_len, data_returned);
12431252
}
12441253
else if (q == 1 || q == 2)
12451254
{
12461255
uint16_t packet_id = (buffer[i + 2 + topic_len] << 8) + buffer[i + 2 + topic_len + 1];
1247-
data_returned = length - 4 - topic_len;
1256+
data_returned = payload_len;
12481257
memcpy(data, buffer.data() + i + 4 + topic_len, data_returned);
12491258

12501259
createPacket((q == 1) ? PacketType::PUBACK : PacketType::PUBREC, 0);
@@ -1390,16 +1399,51 @@ namespace Protocol
13901399
return false;
13911400
}
13921401

1393-
// Read the response
1394-
char response[2048];
1395-
int len = prev->read(response, sizeof(response), 5, true);
1396-
if (len <= 0)
1402+
// Read the response — headers may arrive across multiple TCP segments,
1403+
// so accumulate until we see the end-of-headers marker.
1404+
std::string response_str;
1405+
const size_t MAX_HEADER_SIZE = 16384;
1406+
const std::string HEADER_END = "\r\n\r\n";
1407+
size_t header_end_pos = std::string::npos;
1408+
1409+
while (response_str.size() < MAX_HEADER_SIZE)
13971410
{
1398-
Error() << "WebSocket: No response to handshake request.";
1411+
char chunk[2048];
1412+
int len = prev->read(chunk, sizeof(chunk), 5, true);
1413+
if (len <= 0)
1414+
{
1415+
Error() << "WebSocket: No response to handshake request.";
1416+
return false;
1417+
}
1418+
1419+
// Track scan start so we don't re-scan the entire string each iteration,
1420+
// while still catching the marker if it straddles two reads.
1421+
size_t scan_from = response_str.size() >= 3 ? response_str.size() - 3 : 0;
1422+
response_str.append(chunk, len);
1423+
1424+
header_end_pos = response_str.find(HEADER_END, scan_from);
1425+
if (header_end_pos != std::string::npos)
1426+
break;
1427+
}
1428+
1429+
if (header_end_pos == std::string::npos)
1430+
{
1431+
Error() << "WebSocket: Handshake response headers exceeded " << MAX_HEADER_SIZE << " bytes.";
13991432
return false;
14001433
}
14011434

1402-
std::string response_str(response, len);
1435+
// Preserve any frame bytes that arrived after the headers for the frame parser.
1436+
size_t body_start = header_end_pos + HEADER_END.size();
1437+
if (body_start < response_str.size())
1438+
{
1439+
size_t extra = response_str.size() - body_start;
1440+
if (extra > buffer.size())
1441+
buffer.resize(extra);
1442+
memcpy(buffer.data(), response_str.data() + body_start, extra);
1443+
buffer_ptr = (int)extra;
1444+
response_str.resize(body_start);
1445+
}
1446+
14031447
std::string rs_lower(response_str);
14041448

14051449
// Parse the response
@@ -1495,17 +1539,20 @@ namespace Protocol
14951539
}
14961540
// Generate a random masking key
14971541
uint8_t masking_key[4];
1542+
bool key_ok = false;
14981543
#ifdef HASOPENSSL
14991544
// Use cryptographically secure random from OpenSSL
1500-
RAND_bytes(masking_key, 4);
1501-
#else
1502-
// Fallback to std::random_device
1503-
std::random_device rd;
1504-
std::mt19937 gen(rd());
1505-
std::uniform_int_distribution<> dis(0, 255);
1506-
for (int i = 0; i < 4; ++i)
1507-
masking_key[i] = dis(gen);
1545+
key_ok = (RAND_bytes(masking_key, 4) == 1);
15081546
#endif
1547+
if (!key_ok)
1548+
{
1549+
// Fallback to std::random_device (also used when OpenSSL is unavailable)
1550+
std::random_device rd;
1551+
std::mt19937 gen(rd());
1552+
std::uniform_int_distribution<int> dis(0, 255);
1553+
for (int i = 0; i < 4; ++i)
1554+
masking_key[i] = (uint8_t)dis(gen);
1555+
}
15091556

15101557
frame.insert(frame.end(), masking_key, masking_key + 4);
15111558

@@ -1646,10 +1693,40 @@ namespace Protocol
16461693
break;
16471694
}
16481695

1696+
// Unmask the incoming PING payload before echoing it back
1697+
if (mask)
1698+
{
1699+
for (int i = 0; i < length; ++i)
1700+
buffer[ptr + i] ^= masking_key[i % 4];
1701+
}
1702+
1703+
// Client-to-server frames must be masked (RFC 6455 §5.3)
1704+
uint8_t pong_key[4];
1705+
#ifdef HASOPENSSL
1706+
if (RAND_bytes(pong_key, 4) != 1)
1707+
{
1708+
std::random_device rd;
1709+
std::mt19937 gen(rd());
1710+
std::uniform_int_distribution<int> dis(0, 255);
1711+
for (int i = 0; i < 4; ++i)
1712+
pong_key[i] = (uint8_t)dis(gen);
1713+
}
1714+
#else
1715+
{
1716+
std::random_device rd;
1717+
std::mt19937 gen(rd());
1718+
std::uniform_int_distribution<int> dis(0, 255);
1719+
for (int i = 0; i < 4; ++i)
1720+
pong_key[i] = (uint8_t)dis(gen);
1721+
}
1722+
#endif
1723+
16491724
frame.resize(0);
16501725
frame.push_back(0x80 | (uint8_t)OPCODE::PONG);
16511726
frame.push_back(0x80 | length);
1652-
frame.insert(frame.end(), buffer.data() + ptr, buffer.data() + ptr + length);
1727+
frame.insert(frame.end(), pong_key, pong_key + 4);
1728+
for (int i = 0; i < length; ++i)
1729+
frame.push_back(buffer[ptr + i] ^ pong_key[i % 4]);
16531730

16541731
if (prev->send(frame.data(), frame.size()) != (int)frame.size())
16551732
{

Source/Library/Logger.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ std::vector<LogMessage> Logger::getLastMessages(int n)
216216
return result;
217217
}
218218

219-
int ptr = (buffer_position_ + 1 + message_buffer_.size() - n) % message_buffer_.size();
219+
int ptr = (buffer_position_ + message_buffer_.size() - n) % message_buffer_.size();
220220

221221
while (ptr != buffer_position_)
222222
{
@@ -257,16 +257,20 @@ void Logger::removeLogListener(int id)
257257

258258
void Logger::notifyListeners(const LogMessage &msg)
259259
{
260+
// Prevent unbounded recursion if a callback itself calls log().
260261
thread_local bool in_notify = false;
261-
262262
if (in_notify) return;
263263

264-
in_notify = true;
265-
std::lock_guard<std::mutex> lock(mutex_);
264+
std::vector<LogListener> snapshot;
265+
{
266+
std::lock_guard<std::mutex> lock(mutex_);
267+
snapshot = log_listeners_;
268+
}
266269

270+
in_notify = true;
267271
try
268272
{
269-
for (const auto &listener : log_listeners_)
273+
for (const auto &listener : snapshot)
270274
{
271275
listener.callback(msg);
272276
}

0 commit comments

Comments
 (0)