@@ -761,6 +761,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
761761 . method ( Method :: POST )
762762 . header ( "Accept" , "application/json, text/event-stream" )
763763 . header ( CONTENT_TYPE , "application/json" )
764+ . header ( "Host" , "localhost:8080" )
764765 . body ( Full :: new ( Bytes :: from ( init_body. to_string ( ) ) ) )
765766 . unwrap ( ) ;
766767
@@ -785,6 +786,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
785786 . method ( Method :: POST )
786787 . header ( "Accept" , "application/json, text/event-stream" )
787788 . header ( CONTENT_TYPE , "application/json" )
789+ . header ( "Host" , "localhost:8080" )
788790 . header ( "mcp-session-id" , & session_id)
789791 . header ( "mcp-protocol-version" , "2025-03-26" )
790792 . body ( Full :: new ( Bytes :: from ( initialized_body. to_string ( ) ) ) )
@@ -802,6 +804,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
802804 . method ( Method :: POST )
803805 . header ( "Accept" , "application/json, text/event-stream" )
804806 . header ( CONTENT_TYPE , "application/json" )
807+ . header ( "Host" , "localhost:8080" )
805808 . header ( "mcp-session-id" , & session_id)
806809 . header ( "mcp-protocol-version" , "2025-03-26" )
807810 . body ( Full :: new ( Bytes :: from ( valid_body. to_string ( ) ) ) )
@@ -823,6 +826,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
823826 . method ( Method :: POST )
824827 . header ( "Accept" , "application/json, text/event-stream" )
825828 . header ( CONTENT_TYPE , "application/json" )
829+ . header ( "Host" , "localhost:8080" )
826830 . header ( "mcp-session-id" , & session_id)
827831 . header ( "mcp-protocol-version" , "9999-01-01" )
828832 . body ( Full :: new ( Bytes :: from ( invalid_body. to_string ( ) ) ) )
@@ -844,6 +848,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
844848 . method ( Method :: POST )
845849 . header ( "Accept" , "application/json, text/event-stream" )
846850 . header ( CONTENT_TYPE , "application/json" )
851+ . header ( "Host" , "localhost:8080" )
847852 . header ( "mcp-session-id" , & session_id)
848853 . body ( Full :: new ( Bytes :: from ( no_version_body. to_string ( ) ) ) )
849854 . unwrap ( ) ;
@@ -870,3 +875,156 @@ fn test_protocol_version_utilities() {
870875 assert ! ( ProtocolVersion :: KNOWN_VERSIONS . contains( & ProtocolVersion :: V_2025_03_26 ) ) ;
871876 assert ! ( ProtocolVersion :: KNOWN_VERSIONS . contains( & ProtocolVersion :: V_2025_06_18 ) ) ;
872877}
878+
879+ /// Integration test: Verify server validates only the Host header for DNS rebinding protection
880+ #[ tokio:: test]
881+ #[ cfg( all( feature = "transport-streamable-http-server" , feature = "server" , ) ) ]
882+ async fn test_server_validates_host_header_for_dns_rebinding_protection ( ) {
883+ use std:: sync:: Arc ;
884+
885+ use bytes:: Bytes ;
886+ use http:: { Method , Request , header:: CONTENT_TYPE } ;
887+ use http_body_util:: Full ;
888+ use rmcp:: {
889+ handler:: server:: ServerHandler ,
890+ model:: { ServerCapabilities , ServerInfo } ,
891+ transport:: streamable_http_server:: {
892+ StreamableHttpServerConfig , StreamableHttpService , session:: local:: LocalSessionManager ,
893+ } ,
894+ } ;
895+ use serde_json:: json;
896+
897+ #[ derive( Clone ) ]
898+ struct TestHandler ;
899+
900+ impl ServerHandler for TestHandler {
901+ fn get_info ( & self ) -> ServerInfo {
902+ ServerInfo :: new ( ServerCapabilities :: builder ( ) . build ( ) )
903+ }
904+ }
905+
906+ let service = StreamableHttpService :: new (
907+ || Ok ( TestHandler ) ,
908+ Arc :: new ( LocalSessionManager :: default ( ) ) ,
909+ StreamableHttpServerConfig :: default ( ) ,
910+ ) ;
911+
912+ let init_body = json ! ( {
913+ "jsonrpc" : "2.0" ,
914+ "id" : 1 ,
915+ "method" : "initialize" ,
916+ "params" : {
917+ "protocolVersion" : "2025-03-26" ,
918+ "capabilities" : { } ,
919+ "clientInfo" : {
920+ "name" : "test-client" ,
921+ "version" : "1.0.0"
922+ }
923+ }
924+ } ) ;
925+
926+ let allowed_request = Request :: builder ( )
927+ . method ( Method :: POST )
928+ . header ( "Accept" , "application/json, text/event-stream" )
929+ . header ( CONTENT_TYPE , "application/json" )
930+ . header ( "Host" , "localhost:8080" )
931+ . header ( "Origin" , "http://localhost:8080" )
932+ . body ( Full :: new ( Bytes :: from ( init_body. to_string ( ) ) ) )
933+ . unwrap ( ) ;
934+
935+ let response = service. handle ( allowed_request) . await ;
936+ assert_eq ! ( response. status( ) , http:: StatusCode :: OK ) ;
937+
938+ let bad_host_request = Request :: builder ( )
939+ . method ( Method :: POST )
940+ . header ( "Accept" , "application/json, text/event-stream" )
941+ . header ( CONTENT_TYPE , "application/json" )
942+ . header ( "Host" , "attacker.example" )
943+ . body ( Full :: new ( Bytes :: from ( init_body. to_string ( ) ) ) )
944+ . unwrap ( ) ;
945+
946+ let response = service. handle ( bad_host_request) . await ;
947+ assert_eq ! ( response. status( ) , http:: StatusCode :: FORBIDDEN ) ;
948+
949+ let ignored_origin_request = Request :: builder ( )
950+ . method ( Method :: POST )
951+ . header ( "Accept" , "application/json, text/event-stream" )
952+ . header ( CONTENT_TYPE , "application/json" )
953+ . header ( "Host" , "localhost:8080" )
954+ . header ( "Origin" , "http://attacker.example" )
955+ . body ( Full :: new ( Bytes :: from ( init_body. to_string ( ) ) ) )
956+ . unwrap ( ) ;
957+
958+ let response = service. handle ( ignored_origin_request) . await ;
959+ assert_eq ! ( response. status( ) , http:: StatusCode :: OK ) ;
960+ }
961+
962+ /// Integration test: Verify server can enforce an allowed Host port when configured
963+ #[ tokio:: test]
964+ #[ cfg( all( feature = "transport-streamable-http-server" , feature = "server" , ) ) ]
965+ async fn test_server_validates_host_header_port_for_dns_rebinding_protection ( ) {
966+ use std:: sync:: Arc ;
967+
968+ use bytes:: Bytes ;
969+ use http:: { Method , Request , header:: CONTENT_TYPE } ;
970+ use http_body_util:: Full ;
971+ use rmcp:: {
972+ handler:: server:: ServerHandler ,
973+ model:: { ServerCapabilities , ServerInfo } ,
974+ transport:: streamable_http_server:: {
975+ StreamableHttpServerConfig , StreamableHttpService , session:: local:: LocalSessionManager ,
976+ } ,
977+ } ;
978+ use serde_json:: json;
979+
980+ #[ derive( Clone ) ]
981+ struct TestHandler ;
982+
983+ impl ServerHandler for TestHandler {
984+ fn get_info ( & self ) -> ServerInfo {
985+ ServerInfo :: new ( ServerCapabilities :: builder ( ) . build ( ) )
986+ }
987+ }
988+
989+ let service = StreamableHttpService :: new (
990+ || Ok ( TestHandler ) ,
991+ Arc :: new ( LocalSessionManager :: default ( ) ) ,
992+ StreamableHttpServerConfig :: default ( ) . with_allowed_hosts ( [ "localhost:8080" ] ) ,
993+ ) ;
994+
995+ let init_body = json ! ( {
996+ "jsonrpc" : "2.0" ,
997+ "id" : 1 ,
998+ "method" : "initialize" ,
999+ "params" : {
1000+ "protocolVersion" : "2025-03-26" ,
1001+ "capabilities" : { } ,
1002+ "clientInfo" : {
1003+ "name" : "test-client" ,
1004+ "version" : "1.0.0"
1005+ }
1006+ }
1007+ } ) ;
1008+
1009+ let allowed_request = Request :: builder ( )
1010+ . method ( Method :: POST )
1011+ . header ( "Accept" , "application/json, text/event-stream" )
1012+ . header ( CONTENT_TYPE , "application/json" )
1013+ . header ( "Host" , "localhost:8080" )
1014+ . body ( Full :: new ( Bytes :: from ( init_body. to_string ( ) ) ) )
1015+ . unwrap ( ) ;
1016+
1017+ let response = service. handle ( allowed_request) . await ;
1018+ assert_eq ! ( response. status( ) , http:: StatusCode :: OK ) ;
1019+
1020+ let wrong_port_request = Request :: builder ( )
1021+ . method ( Method :: POST )
1022+ . header ( "Accept" , "application/json, text/event-stream" )
1023+ . header ( CONTENT_TYPE , "application/json" )
1024+ . header ( "Host" , "localhost:3000" )
1025+ . body ( Full :: new ( Bytes :: from ( init_body. to_string ( ) ) ) )
1026+ . unwrap ( ) ;
1027+
1028+ let response = service. handle ( wrong_port_request) . await ;
1029+ assert_eq ! ( response. status( ) , http:: StatusCode :: FORBIDDEN ) ;
1030+ }
0 commit comments