Skip to content

Commit 36e1f04

Browse files
committed
fix several bugs in Protocol.cpp (WS handshake, Win socket check, WS max frame drop, MQTT PUBREL, MQTT buffer growth)
1 parent 2bcdbc3 commit 36e1f04

1 file changed

Lines changed: 34 additions & 12 deletions

File tree

Source/IO/Protocol.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,11 @@ namespace Protocol
670670
else
671671
{
672672
SOCKET sock = getSocket();
673+
#ifdef _WIN32
674+
if (sock == INVALID_SOCKET)
675+
#else
673676
if (sock < 0)
677+
#endif
674678
{
675679
disconnect();
676680
return -1;
@@ -1194,6 +1198,17 @@ namespace Protocol
11941198

11951199
} while ((buffer[i++] & 128) != 0);
11961200

1201+
const int MAX_MQTT_PACKET_SIZE = 1024 * 1024;
1202+
if (length > MAX_MQTT_PACKET_SIZE)
1203+
{
1204+
Error() << "MQTT: Packet size " << length << " exceeds maximum " << MAX_MQTT_PACKET_SIZE;
1205+
disconnect();
1206+
return -1;
1207+
}
1208+
1209+
if ((int)buffer.size() < length + i)
1210+
buffer.resize(length + i);
1211+
11971212
if (buffer_ptr < length + i)
11981213
return 0;
11991214

@@ -1265,7 +1280,22 @@ namespace Protocol
12651280
pushVariableLength(2);
12661281
pushByte(buffer[i]);
12671282
pushByte(buffer[i + 1]);
1268-
if (prev->send(packet.data(), packet.size()) != packet.size())
1283+
if (prev->send(packet.data(), packet.size()) != (int)packet.size())
1284+
return -1;
1285+
1286+
break;
1287+
case PacketType::PUBREL:
1288+
if (length < 2)
1289+
{
1290+
Error() << "MQTT: PUBREL packet too short";
1291+
disconnect();
1292+
return -1;
1293+
}
1294+
createPacket(PacketType::PUBCOMP, 0);
1295+
pushVariableLength(2);
1296+
pushByte(buffer[i]);
1297+
pushByte(buffer[i + 1]);
1298+
if (prev->send(packet.data(), packet.size()) != (int)packet.size())
12691299
return -1;
12701300

12711301
break;
@@ -1406,7 +1436,6 @@ namespace Protocol
14061436
while (std::getline(resp_stream, line))
14071437
{
14081438
std::string lower(line);
1409-
;
14101439
Util::Convert::toLower(lower);
14111440

14121441
if (lower.find("sec-websocket-accept:") != std::string::npos)
@@ -1425,7 +1454,7 @@ namespace Protocol
14251454
}
14261455

14271456
// Compute the expected accept key
1428-
std::string expected_accept_key = Util::Convert::BASE64toString(sha1Hash(secWebSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").substr(0, 20).c_str());
1457+
std::string expected_accept_key = Util::Convert::BASE64toString(sha1Hash(secWebSocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
14291458

14301459
accept_key.erase(std::remove_if(accept_key.begin(), accept_key.end(), ::isspace), accept_key.end());
14311460
expected_accept_key.erase(std::remove_if(expected_accept_key.begin(), expected_accept_key.end(), ::isspace), expected_accept_key.end());
@@ -1604,7 +1633,7 @@ namespace Protocol
16041633
buffer[ptr + i] ^= masking_key[i % 4];
16051634
}
16061635

1607-
if (received_ptr + length < received.size() && length < MAX_PACKET_SIZE)
1636+
if (received_ptr + length <= received.size() && length <= MAX_PACKET_SIZE)
16081637
{
16091638
memmove(received.data() + received_ptr, buffer.data() + ptr, length);
16101639
received_ptr += length;
@@ -1653,14 +1682,7 @@ namespace Protocol
16531682

16541683
bool WebSocket::isConnected()
16551684
{
1656-
1657-
if (connected)
1658-
return true;
1659-
1660-
if (prev)
1661-
connected = prev->isConnected();
1662-
1663-
return false;
1685+
return connected && prev && prev->isConnected();
16641686
}
16651687

16661688
void WebSocket::onDisconnect()

0 commit comments

Comments
 (0)