Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_library(fastmcpp_core
src/server/server.cpp
src/server/context.cpp
src/server/middleware.cpp
src/server/security_middleware.cpp
src/server/http_server.cpp
src/server/stdio_server.cpp
src/server/sse_server.cpp
Expand Down Expand Up @@ -250,10 +251,40 @@ if(FASTMCPP_BUILD_TESTS)
target_link_libraries(fastmcpp_server_context_meta PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_server_context_meta COMMAND fastmcpp_server_context_meta)

add_executable(fastmcpp_server_security_limits tests/server/security_limits.cpp)
target_link_libraries(fastmcpp_server_security_limits PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_server_security_limits COMMAND fastmcpp_server_security_limits)

add_executable(fastmcpp_server_sse_session_security tests/server/sse_session_security.cpp)
target_link_libraries(fastmcpp_server_sse_session_security PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_server_sse_session_security COMMAND fastmcpp_server_sse_session_security)

# SSE session security with fastmcpp::client::HttpTransport (not raw httplib)
add_executable(fastmcpp_client_sse_session_client tests/client/sse_session_client.cpp)
target_link_libraries(fastmcpp_client_sse_session_client PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_client_sse_session_client COMMAND fastmcpp_client_sse_session_client)

# SSE + HTTP integration (real network, not LoopbackTransport)
add_executable(fastmcpp_server_sse_http_integration tests/server/sse_http_integration.cpp)
target_link_libraries(fastmcpp_server_sse_http_integration PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_server_sse_http_integration COMMAND fastmcpp_server_sse_http_integration)

add_executable(fastmcpp_server_auth_cors_security tests/server/auth_cors_security.cpp)
target_link_libraries(fastmcpp_server_auth_cors_security PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_server_auth_cors_security COMMAND fastmcpp_server_auth_cors_security)

add_executable(fastmcpp_server_security_middleware tests/server/security_middleware.cpp)
target_link_libraries(fastmcpp_server_security_middleware PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_server_security_middleware COMMAND fastmcpp_server_security_middleware)

add_executable(fastmcpp_client_transports tests/client/transports.cpp)
target_link_libraries(fastmcpp_client_transports PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_client_transports COMMAND fastmcpp_client_transports)

add_executable(fastmcpp_client_http_client_security tests/client/http_client_security.cpp)
target_link_libraries(fastmcpp_client_http_client_security PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_client_http_client_security COMMAND fastmcpp_client_http_client_security)

add_executable(fastmcpp_client_api_basic tests/client/api_basic.cpp)
target_link_libraries(fastmcpp_client_api_basic PRIVATE fastmcpp_core)
add_test(NAME fastmcpp_client_api_basic COMMAND fastmcpp_client_api_basic)
Expand Down
50 changes: 49 additions & 1 deletion examples/streaming_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ int main()
std::vector<int> seen;
std::mutex m;
std::atomic<bool> sse_connected{false};
std::string session_id;

httplib::Client cli("127.0.0.1", port);
cli.set_connection_timeout(std::chrono::seconds(10));
Expand All @@ -53,6 +54,28 @@ int main()
{
sse_connected = true;
std::string chunk(data, len);

// Parse SSE endpoint event to extract session_id
if (chunk.find("event: endpoint") != std::string::npos)
{
size_t data_pos = chunk.find("data: ");
if (data_pos != std::string::npos)
{
size_t start = data_pos + 6;
size_t end = chunk.find_first_of("\n\r", start);
std::string endpoint_url = chunk.substr(start, end - start);

size_t sid_pos = endpoint_url.find("session_id=");
if (sid_pos != std::string::npos)
{
size_t sid_start = sid_pos + 11;
size_t sid_end = endpoint_url.find_first_of("&\n\r", sid_start);
std::lock_guard<std::mutex> lock(m);
session_id = endpoint_url.substr(sid_start, sid_end - sid_start);
}
}
}

if (chunk.find("data: ") == 0)
{
size_t start = 6;
Expand Down Expand Up @@ -102,11 +125,36 @@ int main()
return 1;
}

// Wait for session_id to be extracted
for (int i = 0; i < 100; ++i)
{
std::lock_guard<std::mutex> lock(m);
if (!session_id.empty())
break;
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}

std::string sid;
{
std::lock_guard<std::mutex> lock(m);
sid = session_id;
}

if (sid.empty())
{
server->stop();
if (sse_thread.joinable())
sse_thread.join();
std::cerr << "Failed to extract session_id" << std::endl;
return 1;
}

httplib::Client post("127.0.0.1", port);
for (int i = 1; i <= 3; ++i)
{
Json j = Json{{"n", i}};
auto res = post.Post("/messages", j.dump(), "application/json");
std::string post_url = "/messages?session_id=" + sid;
auto res = post.Post(post_url, j.dump(), "application/json");
if (!res || res->status != 200)
{
server->stop();
Expand Down
16 changes: 15 additions & 1 deletion include/fastmcpp/server/http_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@ namespace fastmcpp::server
class HttpServerWrapper
{
public:
/**
* Construct an HTTP server with a core Server instance.
*
* @param core Shared pointer to the core Server (routes handler)
* @param host Host address to bind to (default: "127.0.0.1" for localhost)
* @param port Port to listen on (default: 18080)
* @param auth_token Optional auth token for Bearer authentication (empty = no auth required)
* @param cors_origin Optional CORS origin to allow (empty = no CORS header, use "*" for
* wildcard)
*/
HttpServerWrapper(std::shared_ptr<Server> core, std::string host = "127.0.0.1",
int port = 18080);
int port = 18080, std::string auth_token = "", std::string cors_origin = "");
~HttpServerWrapper();

bool start();
Expand All @@ -37,9 +47,13 @@ class HttpServerWrapper
}

private:
bool check_auth(const std::string& auth_header) const;

std::shared_ptr<Server> core_;
std::string host_;
int port_;
std::string auth_token_; // Optional Bearer token for authentication
std::string cors_origin_; // Optional CORS origin (empty = no CORS)
std::unique_ptr<httplib::Server> svr_;
std::thread thread_;
std::atomic<bool> running_{false};
Expand Down
144 changes: 144 additions & 0 deletions include/fastmcpp/server/security_middleware.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#pragma once
#include "fastmcpp/server/middleware.hpp"
#include "fastmcpp/types.hpp"

#include <atomic>
#include <chrono>
#include <deque>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>

namespace fastmcpp::server
{

/// Log entry for a request
struct RequestLogEntry
{
std::chrono::system_clock::time_point timestamp;
std::string route;
size_t payload_size;
bool success;
std::string error_message; // Empty if success
};

/// Logging callback function type
using LogCallback = std::function<void(const RequestLogEntry&)>;

/// Logging middleware for audit trail (v2.13.0+)
///
/// Provides optional request logging to track all route/tool invocations.
/// Can be used as both BeforeHook and AfterHook for comprehensive logging.
///
/// Usage:
/// ```cpp
/// auto logger = std::make_shared<LoggingMiddleware>(
/// [](const RequestLogEntry& entry) {
/// std::cout << entry.timestamp << " " << entry.route << std::endl;
/// });
/// srv.add_before(logger->create_before_hook());
/// srv.add_after(logger->create_after_hook());
/// ```
class LoggingMiddleware
{
public:
explicit LoggingMiddleware(LogCallback callback) : callback_(std::move(callback)) {}

/// Create a BeforeHook that logs incoming requests
BeforeHook create_before_hook();

/// Create an AfterHook that logs completed requests
AfterHook create_after_hook();

private:
LogCallback callback_;
std::mutex mutex_;
std::unordered_map<std::string, size_t> request_sizes_; // Track sizes for after hook
};

/// Rate limiting middleware for DoS prevention (v2.13.0+)
///
/// Enforces per-route request limits using a sliding window algorithm.
/// Rejects requests that exceed the configured rate.
///
/// Usage:
/// ```cpp
/// auto limiter = std::make_shared<RateLimitMiddleware>(
/// 100, // max requests
/// std::chrono::minutes(1) // per time window
/// );
/// srv.add_before(limiter->create_hook());
/// ```
class RateLimitMiddleware
{
public:
/// Construct rate limiter
/// @param max_requests Maximum requests allowed in time window
/// @param window Time window for rate limiting
RateLimitMiddleware(size_t max_requests,
std::chrono::steady_clock::duration window = std::chrono::minutes(1))
: max_requests_(max_requests), window_(window)
{
}

/// Create a BeforeHook that enforces rate limits
BeforeHook create_hook();

/// Get current request count for a route
size_t get_request_count(const std::string& route);

/// Reset rate limit counters (for testing)
void reset();

private:
size_t max_requests_;
std::chrono::steady_clock::duration window_;
std::mutex mutex_;

struct RouteStats
{
std::deque<std::chrono::steady_clock::time_point> timestamps;
};

std::unordered_map<std::string, RouteStats> stats_;

void cleanup_old_entries(RouteStats& stats);
};

/// Concurrency limiting middleware for resource control (v2.13.0+)
///
/// Limits the number of concurrent route handler executions.
/// Uses atomic counters for thread-safe tracking.
///
/// Usage:
/// ```cpp
/// auto limiter = std::make_shared<ConcurrencyLimitMiddleware>(10); // Max 10 parallel
/// srv.add_before(limiter->create_before_hook());
/// srv.add_after(limiter->create_after_hook());
/// ```
class ConcurrencyLimitMiddleware
{
public:
/// Construct concurrency limiter
/// @param max_concurrent Maximum number of concurrent handler executions
explicit ConcurrencyLimitMiddleware(size_t max_concurrent) : max_concurrent_(max_concurrent) {}

/// Create a BeforeHook that checks concurrency limit
BeforeHook create_before_hook();

/// Create an AfterHook that releases concurrency slot
AfterHook create_after_hook();

/// Get current concurrent request count
size_t get_current_count() const
{
return current_count_.load();
}

private:
size_t max_concurrent_;
std::atomic<size_t> current_count_{0};
};

} // namespace fastmcpp::server
25 changes: 20 additions & 5 deletions include/fastmcpp/server/sse_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>

namespace fastmcpp::server
{
Expand Down Expand Up @@ -50,10 +51,13 @@ class SseServerWrapper
* @param port Port to listen on (default: 18080)
* @param sse_path Path for SSE GET endpoint (default: "/sse")
* @param message_path Path for POST message endpoint (default: "/messages")
* @param auth_token Optional auth token for Bearer authentication (empty = no auth required)
* @param cors_origin Optional CORS origin to allow (empty = no CORS header, use "*" for
* wildcard)
*/
explicit SseServerWrapper(McpHandler handler, std::string host = "127.0.0.1", int port = 18080,
std::string sse_path = "/sse",
std::string message_path = "/messages");
std::string sse_path = "/sse", std::string message_path = "/messages",
std::string auth_token = "", std::string cors_origin = "");

~SseServerWrapper();

Expand Down Expand Up @@ -118,29 +122,40 @@ class SseServerWrapper
private:
void run_server();
void send_event_to_all_clients(const fastmcpp::Json& event);
void send_event_to_session(const std::string& session_id, const fastmcpp::Json& event);
std::string generate_session_id();
bool check_auth(const std::string& auth_header) const;

McpHandler handler_;
std::string host_;
int port_;
std::string sse_path_;
std::string message_path_;
std::string auth_token_; // Optional Bearer token for authentication
std::string cors_origin_; // Optional CORS origin (empty = no CORS)

std::unique_ptr<httplib::Server> svr_;
std::thread thread_;
std::atomic<bool> running_{false};

// Security limits
static constexpr size_t MAX_CONNECTIONS = 100;
static constexpr size_t MAX_QUEUE_SIZE = 1000;

struct ConnectionState
{
std::string session_id;
std::deque<fastmcpp::Json> queue;
std::mutex m;
std::condition_variable cv;
bool alive{true};
};

void handle_sse_connection(httplib::DataSink& sink, std::shared_ptr<ConnectionState> conn);
void handle_sse_connection(httplib::DataSink& sink, std::shared_ptr<ConnectionState> conn,
const std::string& session_id);

// Active SSE connections (per-connection queues)
std::vector<std::shared_ptr<ConnectionState>> connections_;
// Active SSE connections mapped by session ID
std::unordered_map<std::string, std::shared_ptr<ConnectionState>> connections_;
std::mutex conns_mutex_;
};

Expand Down
Loading