@@ -397,12 +397,7 @@ static int header_value_cb(http_parser* parser, const char* at, size_t length) {
397
397
auto inspector = static_cast <InspectorSocket*>(parser->data );
398
398
auto state = inspector->http_parsing_state ;
399
399
state->parsing_value = true ;
400
- if (state->current_header .size () == sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 &&
401
- node::StringEqualNoCaseN (state->current_header .data (),
402
- SEC_WEBSOCKET_KEY_HEADER,
403
- sizeof (SEC_WEBSOCKET_KEY_HEADER) - 1 )) {
404
- state->ws_key .append (at, length);
405
- }
400
+ state->headers [state->current_header ].append (at, length);
406
401
return 0 ;
407
402
}
408
403
@@ -475,10 +470,59 @@ static void handshake_failed(InspectorSocket* inspector) {
475
470
// init_handshake references message_complete_cb
476
471
static void init_handshake (InspectorSocket* socket);
477
472
473
+ static std::string TrimPort (const std::string& host) {
474
+ size_t last_colon_pos = host.rfind (" :" );
475
+ if (last_colon_pos == std::string::npos)
476
+ return host;
477
+ size_t bracket = host.rfind (" ]" );
478
+ if (bracket == std::string::npos || last_colon_pos > bracket)
479
+ return host.substr (0 , last_colon_pos);
480
+ return host;
481
+ }
482
+
483
+ static bool IsIPAddress (const std::string& host) {
484
+ if (host.length () >= 4 && host.front () == ' [' && host.back () == ' ]' )
485
+ return true ;
486
+ int quads = 0 ;
487
+ for (char c : host) {
488
+ if (c == ' .' )
489
+ quads++;
490
+ else if (!isdigit (c))
491
+ return false ;
492
+ }
493
+ return quads == 3 ;
494
+ }
495
+
496
+ static std::string HeaderValue (const struct http_parsing_state_s * state,
497
+ const std::string& header) {
498
+ bool header_found = false ;
499
+ std::string value;
500
+ for (const auto & header_value : state->headers ) {
501
+ if (node::StringEqualNoCaseN (header_value.first .data (), header.data (),
502
+ header.length ())) {
503
+ if (header_found)
504
+ return " " ;
505
+ value = header_value.second ;
506
+ header_found = true ;
507
+ }
508
+ }
509
+ return value;
510
+ }
511
+
512
+ static bool IsAllowedHost (const std::string& host_with_port) {
513
+ std::string host = TrimPort (host_with_port);
514
+ return host.empty () || IsIPAddress (host)
515
+ || node::StringEqualNoCase (host.data (), " localhost" )
516
+ || node::StringEqualNoCase (host.data (), " localhost6" );
517
+ }
518
+
478
519
static int message_complete_cb (http_parser* parser) {
479
520
InspectorSocket* inspector = static_cast <InspectorSocket*>(parser->data );
480
521
struct http_parsing_state_s * state = inspector->http_parsing_state ;
481
- if (parser->method != HTTP_GET) {
522
+ state->ws_key = HeaderValue (state, " Sec-WebSocket-Key" );
523
+
524
+ if (!IsAllowedHost (HeaderValue (state, " Host" )) ||
525
+ parser->method != HTTP_GET) {
482
526
handshake_failed (inspector);
483
527
} else if (!parser->upgrade ) {
484
528
if (state->callback (inspector, kInspectorHandshakeHttpGet , state->path )) {
0 commit comments