diff --git a/core/functional_tests/websocket_client/service.cpp b/core/functional_tests/websocket_client/service.cpp index 1a6fadcbe209..d75f8ecc0c10 100644 --- a/core/functional_tests/websocket_client/service.cpp +++ b/core/functional_tests/websocket_client/service.cpp @@ -51,6 +51,28 @@ class WebSocketUnauth final : public server::handlers::WebsocketHandlerBase { void Handle(websocket::WebSocketConnection&, server::request::RequestContext&) const override {} }; +class WebSocketLargeMessage final : public server::handlers::WebsocketHandlerBase { +public: + static constexpr std::string_view kName = "websocket-large-message-handler"; + + using WebsocketHandlerBase::WebsocketHandlerBase; + + void Handle(websocket::WebSocketConnection& ws, server::request::RequestContext&) const override { + constexpr std::size_t kMessageSize = 70000; + + ws.SendText(std::string(kMessageSize, 'X')); + + websocket::Message message; + while (!engine::current_task::ShouldCancel()) { + ws.Recv(message); + if (message.close_status) { + ws.Close(*message.close_status); + break; + } + } + } +}; + // HTTP handler for testing C++ WebSocket client class TestClientHandler final : public server::handlers::HttpHandlerBase { public: @@ -84,6 +106,10 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase { return TestUnauth(request); } else if (test_name == "connection_already_extracted") { return TestConnectionAlreadyExtracted(request); + } else if (test_name == "large_server_message_default_limit") { + return TestLargeServerMessageDefaultLimit(request); + } else if (test_name == "large_server_message_custom_limit") { + return TestLargeServerMessageCustomLimit(request); } return "Unknown test"; } catch (const std::exception& e) { @@ -247,6 +273,35 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase { return "OK"; } + std::string TestLargeServerMessageDefaultLimit(const server::http::HttpRequest& request) const { + auto conn = PerformWebSocket(request, "/large-message").MakeWebSocketConnection(); + + websocket::Message msg; + conn->Recv(msg); + + if (!msg.close_status || *msg.close_status != websocket::CloseStatus::kTooBigData) { + return fmt::format( + "FAIL: close status is {}", + msg.close_status ? static_cast(*msg.close_status) : -1 + ); + } + + return "OK"; + } + + std::string TestLargeServerMessageCustomLimit(const server::http::HttpRequest& request) const { + auto ws_response = PerformWebSocket(request, "/large-message"); + websocket::Config config; + config.max_remote_payload = 80000; + auto conn = ws_response.MakeWebSocketConnectionWithConfig(config); + + websocket::Message msg; + conn->Recv(msg); + conn->Close(websocket::CloseStatus::kNormal); + + return (msg.is_text && msg.data == std::string(70000, 'X')) ? "OK" : "FAIL"; + } + clients::http::Client& client_; }; @@ -257,6 +312,7 @@ int main(int argc, char* argv[]) { .Append() .Append() .Append() + .Append() .Append() .Append(); return utils::DaemonMain(argc, argv, component_list); diff --git a/core/functional_tests/websocket_client/static_config.yaml b/core/functional_tests/websocket_client/static_config.yaml index 7b9a90cb6b64..d89502b22468 100644 --- a/core/functional_tests/websocket_client/static_config.yaml +++ b/core/functional_tests/websocket_client/static_config.yaml @@ -36,6 +36,13 @@ components_manager: max-remote-payload: 100000 fragment-size: 65536 + websocket-large-message-handler: + path: /large-message + method: GET + task_processor: main-task-processor + max-remote-payload: 100000 + fragment-size: 65536 + # HTTP handler for C++ client tests test-client: path: /test-client diff --git a/core/functional_tests/websocket_client/tests/test_websocket_client.py b/core/functional_tests/websocket_client/tests/test_websocket_client.py index 167e26379b85..2a62058d8a92 100644 --- a/core/functional_tests/websocket_client/tests/test_websocket_client.py +++ b/core/functional_tests/websocket_client/tests/test_websocket_client.py @@ -13,6 +13,8 @@ 'nonblocking_write', 'unauth', 'connection_already_extracted', + 'large_server_message_default_limit', + 'large_server_message_custom_limit', ], ) async def test_client(service_client, service_port, test_name): diff --git a/core/include/userver/clients/http/websocket_response.hpp b/core/include/userver/clients/http/websocket_response.hpp index 9c40a86ad787..8c1d3fc901fb 100644 --- a/core/include/userver/clients/http/websocket_response.hpp +++ b/core/include/userver/clients/http/websocket_response.hpp @@ -10,6 +10,7 @@ USERVER_NAMESPACE_BEGIN namespace websocket { class WebSocketConnection; +struct Config; } // namespace websocket namespace clients::http { @@ -43,6 +44,11 @@ class WebSocketResponse final { /// @brief Create a WebSocket connection from this response std::shared_ptr MakeWebSocketConnection(); + /// @brief Create a WebSocket connection from this response with custom websocket config + std::shared_ptr MakeWebSocketConnectionWithConfig( + const websocket::Config& config + ); + private: std::shared_ptr handshake_response_; fs::blocking::FileDescriptor socket_; diff --git a/core/src/clients/http/websocket_response.cpp b/core/src/clients/http/websocket_response.cpp index 587822924f91..664b59efcfff 100644 --- a/core/src/clients/http/websocket_response.cpp +++ b/core/src/clients/http/websocket_response.cpp @@ -19,6 +19,12 @@ bool WebSocketResponse::IsProtocolUpgraded() const { } std::shared_ptr WebSocketResponse::MakeWebSocketConnection() { + return MakeWebSocketConnectionWithConfig(websocket::Config{}); +} + +std::shared_ptr WebSocketResponse::MakeWebSocketConnectionWithConfig( + const websocket::Config& config +) { if (!IsProtocolUpgraded()) { throw std::runtime_error("Protocol is not upgraded to WebSocket"); } @@ -29,7 +35,6 @@ std::shared_ptr WebSocketResponse::MakeWebSocket auto socket = std::make_unique(socket_.GetNative()); auto addr = socket->Getsockname(); - auto config = websocket::Config{}; std::move(socket_).Release();