@@ -698,9 +698,16 @@ namespace Protocol
698698 int rv = select (sock + 1 , &readfds, nullptr , nullptr , &tv);
699699 if (rv < 0 )
700700 {
701+ #ifdef _WIN32
702+ int sock_err = WSAGetLastError ();
703+ if (sock_err == WSAEINTR)
704+ continue ;
705+ Error () << " TLS select failed. Error code: " << sock_err;
706+ #else
701707 if (errno == EINTR)
702708 continue ;
703- perror (" TLS select" );
709+ Error () << " TLS select failed: " << strerror (errno);
710+ #endif
704711 if (stats)
705712 stats->bytes_in += BIO_number_read (SSL_get_rbio (ssl)) - before;
706713
@@ -959,7 +966,7 @@ namespace Protocol
959966 if (shift >= 28 )
960967 return -1 ;
961968
962- if (prev->read ((char *)&b, 1 ) != 1 )
969+ if (prev->read ((char *)&b, 1 , 5 , true ) != 1 )
963970 return -1 ;
964971
965972 length |= (b & 127 ) << shift;
@@ -1071,7 +1078,8 @@ namespace Protocol
10711078
10721079 performHandshake ();
10731080
1074- ProtocolBase::onConnect ();
1081+ if (connected)
1082+ ProtocolBase::onConnect ();
10751083 }
10761084
10771085 void MQTT::onDisconnect ()
@@ -1210,12 +1218,6 @@ namespace Protocol
12101218
12111219 case PacketType::PUBLISH:
12121220 {
1213- if (data_len < length)
1214- {
1215- Warning () << " MQTT: Buffer too small for incoming message. Required: " << length << " , provided: " << data_len;
1216- break ;
1217- }
1218-
12191221 // Validate minimum length for topic length field
12201222 if (length < 2 )
12211223 {
@@ -1236,15 +1238,22 @@ namespace Protocol
12361238 return -1 ;
12371239 }
12381240
1241+ int payload_len = length - header_size;
1242+ if (data_len < payload_len)
1243+ {
1244+ Warning () << " MQTT: Buffer too small for incoming message. Required: " << payload_len << " , provided: " << data_len;
1245+ break ;
1246+ }
1247+
12391248 if (q == 0 )
12401249 {
1241- data_returned = length - 2 - topic_len ;
1250+ data_returned = payload_len ;
12421251 memcpy (data, buffer.data () + i + 2 + topic_len, data_returned);
12431252 }
12441253 else if (q == 1 || q == 2 )
12451254 {
12461255 uint16_t packet_id = (buffer[i + 2 + topic_len] << 8 ) + buffer[i + 2 + topic_len + 1 ];
1247- data_returned = length - 4 - topic_len ;
1256+ data_returned = payload_len ;
12481257 memcpy (data, buffer.data () + i + 4 + topic_len, data_returned);
12491258
12501259 createPacket ((q == 1 ) ? PacketType::PUBACK : PacketType::PUBREC, 0 );
@@ -1390,16 +1399,51 @@ namespace Protocol
13901399 return false ;
13911400 }
13921401
1393- // Read the response
1394- char response[2048 ];
1395- int len = prev->read (response, sizeof (response), 5 , true );
1396- if (len <= 0 )
1402+ // Read the response — headers may arrive across multiple TCP segments,
1403+ // so accumulate until we see the end-of-headers marker.
1404+ std::string response_str;
1405+ const size_t MAX_HEADER_SIZE = 16384 ;
1406+ const std::string HEADER_END = " \r\n\r\n " ;
1407+ size_t header_end_pos = std::string::npos;
1408+
1409+ while (response_str.size () < MAX_HEADER_SIZE)
13971410 {
1398- Error () << " WebSocket: No response to handshake request." ;
1411+ char chunk[2048 ];
1412+ int len = prev->read (chunk, sizeof (chunk), 5 , true );
1413+ if (len <= 0 )
1414+ {
1415+ Error () << " WebSocket: No response to handshake request." ;
1416+ return false ;
1417+ }
1418+
1419+ // Track scan start so we don't re-scan the entire string each iteration,
1420+ // while still catching the marker if it straddles two reads.
1421+ size_t scan_from = response_str.size () >= 3 ? response_str.size () - 3 : 0 ;
1422+ response_str.append (chunk, len);
1423+
1424+ header_end_pos = response_str.find (HEADER_END, scan_from);
1425+ if (header_end_pos != std::string::npos)
1426+ break ;
1427+ }
1428+
1429+ if (header_end_pos == std::string::npos)
1430+ {
1431+ Error () << " WebSocket: Handshake response headers exceeded " << MAX_HEADER_SIZE << " bytes." ;
13991432 return false ;
14001433 }
14011434
1402- std::string response_str (response, len);
1435+ // Preserve any frame bytes that arrived after the headers for the frame parser.
1436+ size_t body_start = header_end_pos + HEADER_END.size ();
1437+ if (body_start < response_str.size ())
1438+ {
1439+ size_t extra = response_str.size () - body_start;
1440+ if (extra > buffer.size ())
1441+ buffer.resize (extra);
1442+ memcpy (buffer.data (), response_str.data () + body_start, extra);
1443+ buffer_ptr = (int )extra;
1444+ response_str.resize (body_start);
1445+ }
1446+
14031447 std::string rs_lower (response_str);
14041448
14051449 // Parse the response
@@ -1495,17 +1539,20 @@ namespace Protocol
14951539 }
14961540 // Generate a random masking key
14971541 uint8_t masking_key[4 ];
1542+ bool key_ok = false ;
14981543#ifdef HASOPENSSL
14991544 // Use cryptographically secure random from OpenSSL
1500- RAND_bytes (masking_key, 4 );
1501- #else
1502- // Fallback to std::random_device
1503- std::random_device rd;
1504- std::mt19937 gen (rd ());
1505- std::uniform_int_distribution<> dis (0 , 255 );
1506- for (int i = 0 ; i < 4 ; ++i)
1507- masking_key[i] = dis (gen);
1545+ key_ok = (RAND_bytes (masking_key, 4 ) == 1 );
15081546#endif
1547+ if (!key_ok)
1548+ {
1549+ // Fallback to std::random_device (also used when OpenSSL is unavailable)
1550+ std::random_device rd;
1551+ std::mt19937 gen (rd ());
1552+ std::uniform_int_distribution<int > dis (0 , 255 );
1553+ for (int i = 0 ; i < 4 ; ++i)
1554+ masking_key[i] = (uint8_t )dis (gen);
1555+ }
15091556
15101557 frame.insert (frame.end (), masking_key, masking_key + 4 );
15111558
@@ -1646,10 +1693,40 @@ namespace Protocol
16461693 break ;
16471694 }
16481695
1696+ // Unmask the incoming PING payload before echoing it back
1697+ if (mask)
1698+ {
1699+ for (int i = 0 ; i < length; ++i)
1700+ buffer[ptr + i] ^= masking_key[i % 4 ];
1701+ }
1702+
1703+ // Client-to-server frames must be masked (RFC 6455 §5.3)
1704+ uint8_t pong_key[4 ];
1705+ #ifdef HASOPENSSL
1706+ if (RAND_bytes (pong_key, 4 ) != 1 )
1707+ {
1708+ std::random_device rd;
1709+ std::mt19937 gen (rd ());
1710+ std::uniform_int_distribution<int > dis (0 , 255 );
1711+ for (int i = 0 ; i < 4 ; ++i)
1712+ pong_key[i] = (uint8_t )dis (gen);
1713+ }
1714+ #else
1715+ {
1716+ std::random_device rd;
1717+ std::mt19937 gen (rd ());
1718+ std::uniform_int_distribution<int > dis (0 , 255 );
1719+ for (int i = 0 ; i < 4 ; ++i)
1720+ pong_key[i] = (uint8_t )dis (gen);
1721+ }
1722+ #endif
1723+
16491724 frame.resize (0 );
16501725 frame.push_back (0x80 | (uint8_t )OPCODE::PONG);
16511726 frame.push_back (0x80 | length);
1652- frame.insert (frame.end (), buffer.data () + ptr, buffer.data () + ptr + length);
1727+ frame.insert (frame.end (), pong_key, pong_key + 4 );
1728+ for (int i = 0 ; i < length; ++i)
1729+ frame.push_back (buffer[ptr + i] ^ pong_key[i % 4 ]);
16531730
16541731 if (prev->send (frame.data (), frame.size ()) != (int )frame.size ())
16551732 {
0 commit comments