Skip to content

Commit fcc97b6

Browse files
committed
Add security middleware for logging, rate limiting, and concurrency control (issue #5)
Implements optional request logging, rate limiting, and concurrency limiting middleware to prevent abuse and DoS attacks on server routes. Features: - LoggingMiddleware: Optional audit trail for all route invocations - RateLimitMiddleware: Sliding window rate limiting per route - ConcurrencyLimitMiddleware: Limits parallel handler execution All middleware are optional and can be combined. Includes comprehensive tests covering logging, rate enforcement, window expiration, concurrency limiting, and middleware composition. Security audit issue: fastmcpp #5 (minimal logging/limits on server hooks)
1 parent 580728d commit fcc97b6

4 files changed

Lines changed: 634 additions & 0 deletions

File tree

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_library(fastmcpp_core
2525
src/server/server.cpp
2626
src/server/context.cpp
2727
src/server/middleware.cpp
28+
src/server/security_middleware.cpp
2829
src/server/http_server.cpp
2930
src/server/stdio_server.cpp
3031
src/server/sse_server.cpp
@@ -262,6 +263,10 @@ if(FASTMCPP_BUILD_TESTS)
262263
target_link_libraries(fastmcpp_server_auth_cors_security PRIVATE fastmcpp_core)
263264
add_test(NAME fastmcpp_server_auth_cors_security COMMAND fastmcpp_server_auth_cors_security)
264265

266+
add_executable(fastmcpp_server_security_middleware tests/server/security_middleware.cpp)
267+
target_link_libraries(fastmcpp_server_security_middleware PRIVATE fastmcpp_core)
268+
add_test(NAME fastmcpp_server_security_middleware COMMAND fastmcpp_server_security_middleware)
269+
265270
add_executable(fastmcpp_client_transports tests/client/transports.cpp)
266271
target_link_libraries(fastmcpp_client_transports PRIVATE fastmcpp_core)
267272
add_test(NAME fastmcpp_client_transports COMMAND fastmcpp_client_transports)
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#pragma once
2+
#include "fastmcpp/server/middleware.hpp"
3+
#include "fastmcpp/types.hpp"
4+
5+
#include <atomic>
6+
#include <chrono>
7+
#include <deque>
8+
#include <functional>
9+
#include <mutex>
10+
#include <string>
11+
#include <unordered_map>
12+
13+
namespace fastmcpp::server
14+
{
15+
16+
/// Log entry for a request
17+
struct RequestLogEntry
18+
{
19+
std::chrono::system_clock::time_point timestamp;
20+
std::string route;
21+
size_t payload_size;
22+
bool success;
23+
std::string error_message; // Empty if success
24+
};
25+
26+
/// Logging callback function type
27+
using LogCallback = std::function<void(const RequestLogEntry&)>;
28+
29+
/// Logging middleware for audit trail (v2.13.0+)
30+
///
31+
/// Provides optional request logging to track all route/tool invocations.
32+
/// Can be used as both BeforeHook and AfterHook for comprehensive logging.
33+
///
34+
/// Usage:
35+
/// ```cpp
36+
/// auto logger = std::make_shared<LoggingMiddleware>(
37+
/// [](const RequestLogEntry& entry) {
38+
/// std::cout << entry.timestamp << " " << entry.route << std::endl;
39+
/// });
40+
/// srv.add_before(logger->create_before_hook());
41+
/// srv.add_after(logger->create_after_hook());
42+
/// ```
43+
class LoggingMiddleware
44+
{
45+
public:
46+
explicit LoggingMiddleware(LogCallback callback) : callback_(std::move(callback))
47+
{
48+
}
49+
50+
/// Create a BeforeHook that logs incoming requests
51+
BeforeHook create_before_hook();
52+
53+
/// Create an AfterHook that logs completed requests
54+
AfterHook create_after_hook();
55+
56+
private:
57+
LogCallback callback_;
58+
std::mutex mutex_;
59+
std::unordered_map<std::string, size_t> request_sizes_; // Track sizes for after hook
60+
};
61+
62+
/// Rate limiting middleware for DoS prevention (v2.13.0+)
63+
///
64+
/// Enforces per-route request limits using a sliding window algorithm.
65+
/// Rejects requests that exceed the configured rate.
66+
///
67+
/// Usage:
68+
/// ```cpp
69+
/// auto limiter = std::make_shared<RateLimitMiddleware>(
70+
/// 100, // max requests
71+
/// std::chrono::minutes(1) // per time window
72+
/// );
73+
/// srv.add_before(limiter->create_hook());
74+
/// ```
75+
class RateLimitMiddleware
76+
{
77+
public:
78+
/// Construct rate limiter
79+
/// @param max_requests Maximum requests allowed in time window
80+
/// @param window Time window for rate limiting
81+
RateLimitMiddleware(size_t max_requests,
82+
std::chrono::steady_clock::duration window = std::chrono::minutes(1))
83+
: max_requests_(max_requests), window_(window)
84+
{
85+
}
86+
87+
/// Create a BeforeHook that enforces rate limits
88+
BeforeHook create_hook();
89+
90+
/// Get current request count for a route
91+
size_t get_request_count(const std::string& route);
92+
93+
/// Reset rate limit counters (for testing)
94+
void reset();
95+
96+
private:
97+
size_t max_requests_;
98+
std::chrono::steady_clock::duration window_;
99+
std::mutex mutex_;
100+
101+
struct RouteStats
102+
{
103+
std::deque<std::chrono::steady_clock::time_point> timestamps;
104+
};
105+
106+
std::unordered_map<std::string, RouteStats> stats_;
107+
108+
void cleanup_old_entries(RouteStats& stats);
109+
};
110+
111+
/// Concurrency limiting middleware for resource control (v2.13.0+)
112+
///
113+
/// Limits the number of concurrent route handler executions.
114+
/// Uses atomic counters for thread-safe tracking.
115+
///
116+
/// Usage:
117+
/// ```cpp
118+
/// auto limiter = std::make_shared<ConcurrencyLimitMiddleware>(10); // Max 10 parallel
119+
/// srv.add_before(limiter->create_before_hook());
120+
/// srv.add_after(limiter->create_after_hook());
121+
/// ```
122+
class ConcurrencyLimitMiddleware
123+
{
124+
public:
125+
/// Construct concurrency limiter
126+
/// @param max_concurrent Maximum number of concurrent handler executions
127+
explicit ConcurrencyLimitMiddleware(size_t max_concurrent) : max_concurrent_(max_concurrent)
128+
{
129+
}
130+
131+
/// Create a BeforeHook that checks concurrency limit
132+
BeforeHook create_before_hook();
133+
134+
/// Create an AfterHook that releases concurrency slot
135+
AfterHook create_after_hook();
136+
137+
/// Get current concurrent request count
138+
size_t get_current_count() const
139+
{
140+
return current_count_.load();
141+
}
142+
143+
private:
144+
size_t max_concurrent_;
145+
std::atomic<size_t> current_count_{0};
146+
};
147+
148+
} // namespace fastmcpp::server

src/server/security_middleware.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#include "fastmcpp/server/security_middleware.hpp"
2+
3+
#include "fastmcpp/exceptions.hpp"
4+
5+
#include <algorithm>
6+
#include <sstream>
7+
8+
namespace fastmcpp::server
9+
{
10+
11+
// LoggingMiddleware implementation
12+
13+
BeforeHook LoggingMiddleware::create_before_hook()
14+
{
15+
return [this](const std::string& route, const Json& payload) -> std::optional<Json>
16+
{
17+
std::lock_guard<std::mutex> lock(mutex_);
18+
19+
// Store payload size for correlation with after hook
20+
size_t payload_size = payload.dump().size();
21+
request_sizes_[route] = payload_size;
22+
23+
// Log the incoming request
24+
RequestLogEntry entry;
25+
entry.timestamp = std::chrono::system_clock::now();
26+
entry.route = route;
27+
entry.payload_size = payload_size;
28+
entry.success = true; // Will be updated in after hook if there's an error
29+
entry.error_message = "";
30+
31+
if (callback_)
32+
callback_(entry);
33+
34+
return std::nullopt; // Continue to normal handler
35+
};
36+
}
37+
38+
AfterHook LoggingMiddleware::create_after_hook()
39+
{
40+
return [this](const std::string& route, const Json& /*payload*/, Json& response)
41+
{
42+
std::lock_guard<std::mutex> lock(mutex_);
43+
44+
// Log the completed request
45+
RequestLogEntry entry;
46+
entry.timestamp = std::chrono::system_clock::now();
47+
entry.route = route;
48+
entry.payload_size = request_sizes_[route]; // Get stored size
49+
entry.success = !response.contains("error");
50+
entry.error_message =
51+
response.contains("error") ? response["error"].dump() : std::string();
52+
53+
if (callback_)
54+
callback_(entry);
55+
56+
// Clean up stored size
57+
request_sizes_.erase(route);
58+
};
59+
}
60+
61+
// RateLimitMiddleware implementation
62+
63+
void RateLimitMiddleware::cleanup_old_entries(RouteStats& stats)
64+
{
65+
auto now = std::chrono::steady_clock::now();
66+
auto cutoff = now - window_;
67+
68+
// Remove timestamps older than the window
69+
while (!stats.timestamps.empty() && stats.timestamps.front() < cutoff)
70+
{
71+
stats.timestamps.pop_front();
72+
}
73+
}
74+
75+
BeforeHook RateLimitMiddleware::create_hook()
76+
{
77+
return [this](const std::string& route, const Json& /*payload*/) -> std::optional<Json>
78+
{
79+
std::lock_guard<std::mutex> lock(mutex_);
80+
81+
auto& stats = stats_[route];
82+
cleanup_old_entries(stats);
83+
84+
// Check if rate limit exceeded
85+
if (stats.timestamps.size() >= max_requests_)
86+
{
87+
// Return rate limit error
88+
return Json{{"error",
89+
Json{{"code", -32000}, // JSON-RPC server error
90+
{"message", "Rate limit exceeded for route: " + route},
91+
{"data",
92+
Json{{"route", route},
93+
{"limit", max_requests_},
94+
{"window_seconds",
95+
std::chrono::duration_cast<std::chrono::seconds>(window_).count()},
96+
{"current_count", stats.timestamps.size()}}}}}};
97+
}
98+
99+
// Record this request
100+
stats.timestamps.push_back(std::chrono::steady_clock::now());
101+
102+
return std::nullopt; // Continue to normal handler
103+
};
104+
}
105+
106+
size_t RateLimitMiddleware::get_request_count(const std::string& route)
107+
{
108+
std::lock_guard<std::mutex> lock(mutex_);
109+
auto it = stats_.find(route);
110+
if (it == stats_.end())
111+
return 0;
112+
113+
cleanup_old_entries(it->second);
114+
return it->second.timestamps.size();
115+
}
116+
117+
void RateLimitMiddleware::reset()
118+
{
119+
std::lock_guard<std::mutex> lock(mutex_);
120+
stats_.clear();
121+
}
122+
123+
// ConcurrencyLimitMiddleware implementation
124+
125+
BeforeHook ConcurrencyLimitMiddleware::create_before_hook()
126+
{
127+
return [this](const std::string& route, const Json& /*payload*/) -> std::optional<Json>
128+
{
129+
size_t current = current_count_.fetch_add(1);
130+
131+
// Check if we exceeded the limit
132+
if (current >= max_concurrent_)
133+
{
134+
// Rollback the increment
135+
current_count_.fetch_sub(1);
136+
137+
// Return concurrency limit error
138+
return Json{{"error",
139+
Json{{"code", -32000}, // JSON-RPC server error
140+
{"message", "Concurrency limit exceeded"},
141+
{"data",
142+
Json{{"route", route},
143+
{"limit", max_concurrent_},
144+
{"current", current}}}}}};
145+
}
146+
147+
return std::nullopt; // Continue to normal handler
148+
};
149+
}
150+
151+
AfterHook ConcurrencyLimitMiddleware::create_after_hook()
152+
{
153+
return [this](const std::string& /*route*/, const Json& /*payload*/, Json& /*response*/)
154+
{
155+
// Decrement the counter when handler completes
156+
current_count_.fetch_sub(1);
157+
};
158+
}
159+
160+
} // namespace fastmcpp::server

0 commit comments

Comments
 (0)