Skip to content

Commit 8de5788

Browse files
committed
feat: add tool timeouts, context transport, ping middleware
1 parent e590e3d commit 8de5788

17 files changed

Lines changed: 501 additions & 48 deletions

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ if(FASTMCPP_BUILD_TESTS)
226226
target_link_libraries(fastmcpp_tools_manager PRIVATE fastmcpp_core)
227227
add_test(NAME fastmcpp_tools_manager COMMAND fastmcpp_tools_manager)
228228

229+
add_executable(fastmcpp_tools_timeout tests/tools/test_tool_timeout.cpp)
230+
target_link_libraries(fastmcpp_tools_timeout PRIVATE fastmcpp_core)
231+
add_test(NAME fastmcpp_tools_timeout COMMAND fastmcpp_tools_timeout)
232+
229233
add_executable(fastmcpp_integration tests/integration.cpp)
230234
target_link_libraries(fastmcpp_integration PRIVATE fastmcpp_core)
231235
add_test(NAME fastmcpp_integration COMMAND fastmcpp_integration)

include/fastmcpp/app.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "fastmcpp/server/server.hpp"
88
#include "fastmcpp/tools/manager.hpp"
99

10+
#include <chrono>
1011
#include <memory>
1112
#include <optional>
1213
#include <string>
@@ -65,6 +66,7 @@ class FastMCP
6566
std::vector<std::string> exclude_args;
6667
TaskSupport task_support{TaskSupport::Forbidden};
6768
Json output_schema{Json::object()};
69+
std::optional<std::chrono::milliseconds> timeout;
6870
};
6971

7072
struct PromptOptions
@@ -250,7 +252,7 @@ class FastMCP
250252
// =========================================================================
251253

252254
/// Invoke a tool by name (handles prefixed routing)
253-
Json invoke_tool(const std::string& name, const Json& args) const;
255+
Json invoke_tool(const std::string& name, const Json& args, bool enforce_timeout = true) const;
254256

255257
/// Read a resource by URI (handles prefixed routing)
256258
resources::ResourceContent read_resource(const std::string& uri,

include/fastmcpp/exceptions.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ struct ValidationError : public Error
2020
using Error::Error;
2121
};
2222

23+
struct ToolTimeoutError : public Error
24+
{
25+
using Error::Error;
26+
};
27+
2328
struct TransportError : public Error
2429
{
2530
using Error::Error;

include/fastmcpp/proxy.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ class ProxyApp
107107

108108
/// Invoke a tool by name
109109
/// Tries local tools first, falls back to remote
110-
client::CallToolResult invoke_tool(const std::string& name, const Json& args) const;
110+
client::CallToolResult invoke_tool(const std::string& name, const Json& args,
111+
bool enforce_timeout = true) const;
111112

112113
/// Read a resource by URI
113114
/// Tries local resources first, falls back to remote

include/fastmcpp/server/context.hpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ enum class LogLevel
3636
Error
3737
};
3838

39+
enum class TransportType
40+
{
41+
Stdio,
42+
Sse,
43+
StreamableHttp
44+
};
45+
3946
// ============================================================================
4047
// Sampling types (for Context.sample())
4148
// ============================================================================
@@ -146,6 +153,21 @@ inline std::string to_string(LogLevel level)
146153
}
147154
}
148155

156+
inline std::string to_string(TransportType transport)
157+
{
158+
switch (transport)
159+
{
160+
case TransportType::Stdio:
161+
return "stdio";
162+
case TransportType::Sse:
163+
return "sse";
164+
case TransportType::StreamableHttp:
165+
return "streamable-http";
166+
default:
167+
return "unknown";
168+
}
169+
}
170+
149171
using LogCallback = std::function<void(LogLevel, const std::string&, const std::string&)>;
150172
using ProgressCallback =
151173
std::function<void(const std::string&, double, double, const std::string&)>;
@@ -158,7 +180,8 @@ class Context
158180
Context(const resources::ResourceManager& rm, const prompts::PromptManager& pm,
159181
std::optional<fastmcpp::Json> request_meta,
160182
std::optional<std::string> request_id = std::nullopt,
161-
std::optional<std::string> session_id = std::nullopt);
183+
std::optional<std::string> session_id = std::nullopt,
184+
std::optional<TransportType> transport = std::nullopt);
162185

163186
std::vector<resources::Resource> list_resources() const;
164187
std::vector<prompts::Prompt> list_prompts() const;
@@ -177,6 +200,16 @@ class Context
177200
{
178201
return session_id_;
179202
}
203+
std::optional<std::string> transport() const
204+
{
205+
if (!transport_.has_value())
206+
return std::nullopt;
207+
return to_string(*transport_);
208+
}
209+
std::optional<TransportType> transport_type() const
210+
{
211+
return transport_;
212+
}
180213

181214
std::optional<std::string> client_id() const
182215
{
@@ -398,6 +431,7 @@ class Context
398431
std::optional<fastmcpp::Json> request_meta_;
399432
std::optional<std::string> request_id_;
400433
std::optional<std::string> session_id_;
434+
std::optional<TransportType> transport_;
401435
mutable std::unordered_map<std::string, std::any> state_;
402436
LogCallback log_callback_;
403437
ProgressCallback progress_callback_;

include/fastmcpp/server/middleware_pipeline.hpp

Lines changed: 148 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,24 @@
77
/// - Middleware base class with virtual hooks
88
/// - Built-in implementations: Logging, Timing, Caching, RateLimiting, ErrorHandling
99

10-
#include "fastmcpp/types.hpp"
11-
12-
#include <chrono>
13-
#include <functional>
14-
#include <iostream>
15-
#include <memory>
16-
#include <mutex>
17-
#include <optional>
18-
#include <string>
19-
#include <unordered_map>
20-
#include <vector>
10+
#include "fastmcpp/server/session.hpp"
11+
#include "fastmcpp/types.hpp"
12+
13+
#include <atomic>
14+
#include <chrono>
15+
#include <condition_variable>
16+
#include <functional>
17+
#include <cstdint>
18+
#include <iostream>
19+
#include <memory>
20+
#include <mutex>
21+
#include <optional>
22+
#include <stdexcept>
23+
#include <string>
24+
#include <thread>
25+
#include <unordered_map>
26+
#include <unordered_set>
27+
#include <vector>
2128

2229
namespace fastmcpp::server
2330
{
@@ -32,11 +39,12 @@ struct MiddlewareContext
3239
std::string method; ///< MCP method name (e.g., "tools/call")
3340
std::string source{"client"}; ///< Origin: "client" or "server"
3441
std::string type{"request"}; ///< Message type: "request" or "notification"
35-
std::chrono::steady_clock::time_point timestamp; ///< Request timestamp
36-
std::optional<std::string> request_id; ///< Request ID if available
37-
std::optional<std::string> tool_name; ///< Tool name for tools/call
38-
std::optional<std::string> resource_uri; ///< Resource URI for resources/read
39-
std::optional<std::string> prompt_name; ///< Prompt name for prompts/get
42+
std::chrono::steady_clock::time_point timestamp; ///< Request timestamp
43+
std::optional<std::string> request_id; ///< Request ID if available
44+
std::shared_ptr<ServerSession> session; ///< ServerSession for this request (optional)
45+
std::optional<std::string> tool_name; ///< Tool name for tools/call
46+
std::optional<std::string> resource_uri; ///< Resource URI for resources/read
47+
std::optional<std::string> prompt_name; ///< Prompt name for prompts/get
4048

4149
/// Create a copy with modified fields
4250
MiddlewareContext copy() const
@@ -502,8 +510,8 @@ class RateLimitingMiddleware : public Middleware
502510
};
503511

504512
/// Error handling middleware - catches exceptions and converts to MCP errors
505-
class ErrorHandlingMiddleware : public Middleware
506-
{
513+
class ErrorHandlingMiddleware : public Middleware
514+
{
507515
public:
508516
using ErrorCallback = std::function<void(const std::string& method, const std::exception& e)>;
509517

@@ -570,7 +578,125 @@ class ErrorHandlingMiddleware : public Middleware
570578
ErrorCallback callback_;
571579
bool include_trace_;
572580
mutable std::mutex mutex_;
573-
std::unordered_map<std::string, size_t> error_counts_;
574-
};
575-
576-
} // namespace fastmcpp::server
581+
std::unordered_map<std::string, size_t> error_counts_;
582+
};
583+
584+
/// Ping middleware - sends periodic pings to keep client connections alive
585+
class PingMiddleware : public Middleware
586+
{
587+
public:
588+
explicit PingMiddleware(std::chrono::milliseconds interval = std::chrono::milliseconds(30000))
589+
: interval_(interval)
590+
{
591+
if (interval_.count() <= 0)
592+
throw std::invalid_argument("interval must be positive");
593+
}
594+
595+
explicit PingMiddleware(int interval_ms)
596+
: PingMiddleware(std::chrono::milliseconds(interval_ms))
597+
{
598+
}
599+
600+
~PingMiddleware() override
601+
{
602+
stop();
603+
}
604+
605+
Json operator()(const MiddlewareContext& ctx, CallNext call_next) override
606+
{
607+
if (ctx.session)
608+
ensure_session(ctx.session);
609+
return call_next(ctx);
610+
}
611+
612+
private:
613+
void ensure_session(const std::shared_ptr<ServerSession>& session)
614+
{
615+
const std::string key = session_key(session);
616+
if (key.empty())
617+
return;
618+
619+
bool should_start = false;
620+
{
621+
std::lock_guard<std::mutex> lock(mutex_);
622+
if (active_sessions_.insert(key).second)
623+
should_start = true;
624+
}
625+
626+
if (!should_start)
627+
return;
628+
629+
std::weak_ptr<ServerSession> weak_session = session;
630+
std::thread worker([this, weak_session, key]() { ping_loop(weak_session, key); });
631+
{
632+
std::lock_guard<std::mutex> lock(mutex_);
633+
threads_.push_back(std::move(worker));
634+
}
635+
}
636+
637+
void ping_loop(std::weak_ptr<ServerSession> weak_session, const std::string& key)
638+
{
639+
while (true)
640+
{
641+
{
642+
std::unique_lock<std::mutex> lock(mutex_);
643+
if (cv_.wait_for(lock, interval_, [this]() { return stop_.load(); }))
644+
break;
645+
}
646+
647+
if (stop_.load())
648+
break;
649+
650+
auto session = weak_session.lock();
651+
if (!session)
652+
break;
653+
654+
try
655+
{
656+
session->send_ping(interval_);
657+
}
658+
catch (const std::exception&)
659+
{
660+
break;
661+
}
662+
}
663+
664+
std::lock_guard<std::mutex> lock(mutex_);
665+
active_sessions_.erase(key);
666+
}
667+
668+
void stop()
669+
{
670+
stop_.store(true);
671+
cv_.notify_all();
672+
673+
std::vector<std::thread> threads;
674+
{
675+
std::lock_guard<std::mutex> lock(mutex_);
676+
threads.swap(threads_);
677+
}
678+
679+
for (auto& t : threads)
680+
if (t.joinable())
681+
t.join();
682+
}
683+
684+
static std::string session_key(const std::shared_ptr<ServerSession>& session)
685+
{
686+
if (!session)
687+
return {};
688+
auto key = session->session_id();
689+
if (!key.empty())
690+
return key;
691+
return "session@" + std::to_string(reinterpret_cast<std::uintptr_t>(session.get()));
692+
}
693+
694+
std::chrono::milliseconds interval_;
695+
std::mutex mutex_;
696+
std::condition_variable cv_;
697+
std::unordered_set<std::string> active_sessions_;
698+
std::vector<std::thread> threads_;
699+
std::atomic<bool> stop_{false};
700+
};
701+
702+
} // namespace fastmcpp::server

include/fastmcpp/server/session.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,18 @@ class ServerSession
291291
send_callback_(notification);
292292
}
293293

294+
/**
295+
* Send a ping request to the client and wait for a response.
296+
*
297+
* @param timeout How long to wait for response
298+
* @throws RequestTimeoutError if timeout exceeded
299+
* @throws ClientError if client returns an error
300+
*/
301+
void send_ping(std::chrono::milliseconds timeout = DEFAULT_TIMEOUT)
302+
{
303+
(void)send_request("ping", Json::object(), timeout);
304+
}
305+
294306
/**
295307
* Send a progress notification to the client.
296308
*

include/fastmcpp/tools/manager.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ class ToolManager
1919
{
2020
return tools_.at(name);
2121
}
22-
fastmcpp::Json invoke(const std::string& name, const fastmcpp::Json& input) const
22+
fastmcpp::Json invoke(const std::string& name, const fastmcpp::Json& input,
23+
bool enforce_timeout = true) const
2324
{
2425
auto it = tools_.find(name);
2526
if (it == tools_.end())
2627
throw fastmcpp::NotFoundError("tool not found: " + name);
27-
return it->second.invoke(input);
28+
return it->second.invoke(input, enforce_timeout);
2829
}
2930

3031
std::vector<std::string> list_names() const

0 commit comments

Comments
 (0)