@@ -393,12 +393,7 @@ static int header_value_cb(http_parser* parser, const char* at, size_t length) {
393393 auto inspector = static_cast <InspectorSocket*>(parser->data );
394394 auto state = inspector->http_parsing_state ;
395395 state->parsing_value = true ;
396- if (state->current_header .size () == sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 &&
397- node::StringEqualNoCaseN (state->current_header .data (),
398- SEC_WEBSOCKET_KEY_HEADER,
399- sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 )) {
400- state->ws_key .append (at, length);
401- }
396+ state->headers [state->current_header ].append (at, length);
402397 return 0 ;
403398}
404399
@@ -471,10 +466,59 @@ static void handshake_failed(InspectorSocket* inspector) {
471466// init_handshake references message_complete_cb
472467static void init_handshake (InspectorSocket* inspector);
473468
469+ static std::string TrimPort (const std::string& host) {
470+ size_t last_colon_pos = host.rfind (" :" );
471+ if (last_colon_pos == std::string::npos)
472+ return host;
473+ size_t bracket = host.rfind (" ]" );
474+ if (bracket == std::string::npos || last_colon_pos > bracket)
475+ return host.substr (0 , last_colon_pos);
476+ return host;
477+ }
478+
479+ static bool IsIPAddress (const std::string& host) {
480+ if (host.length () >= 4 && host[0 ] == ' [' && host[host.size () - 1 ] == ' ]' )
481+ return true ;
482+ int quads = 0 ;
483+ for (char c : host) {
484+ if (c == ' .' )
485+ quads++;
486+ else if (!isdigit (c))
487+ return false ;
488+ }
489+ return quads == 3 ;
490+ }
491+
492+ static std::string HeaderValue (const struct http_parsing_state_s * state,
493+ const std::string& header) {
494+ bool header_found = false ;
495+ std::string value;
496+ for (const auto & header_value : state->headers ) {
497+ if (node::StringEqualNoCaseN (header_value.first .data (), header.data (),
498+ header.length ())) {
499+ if (header_found)
500+ return " " ;
501+ value = header_value.second ;
502+ header_found = true ;
503+ }
504+ }
505+ return value;
506+ }
507+
508+ static bool IsAllowedHost (const std::string& host_with_port) {
509+ std::string host = TrimPort (host_with_port);
510+ return host.empty () || IsIPAddress (host)
511+ || node::StringEqualNoCase (host.data (), " localhost" )
512+ || node::StringEqualNoCase (host.data (), " localhost6" );
513+ }
514+
474515static int message_complete_cb (http_parser* parser) {
475516 InspectorSocket* inspector = static_cast <InspectorSocket*>(parser->data );
476517 struct http_parsing_state_s * state = inspector->http_parsing_state ;
477- if (parser->method != HTTP_GET) {
518+ state->ws_key = HeaderValue (state, " Sec-WebSocket-Key" );
519+
520+ if (!IsAllowedHost (HeaderValue (state, " Host" )) ||
521+ parser->method != HTTP_GET) {
478522 handshake_failed (inspector);
479523 } else if (!parser->upgrade ) {
480524 if (state->callback (inspector, kInspectorHandshakeHttpGet , state->path )) {
0 commit comments