77// / - Middleware base class with virtual hooks
88// / - Built-in implementations: Logging, Timing, Caching, RateLimiting, ErrorHandling
99
10- #include " fastmcpp/types.hpp"
11-
12- #include < chrono>
13- #include < functional>
14- #include < iostream>
15- #include < memory>
16- #include < mutex>
17- #include < optional>
18- #include < string>
19- #include < unordered_map>
20- #include < vector>
10+ #include " fastmcpp/server/session.hpp"
11+ #include " fastmcpp/types.hpp"
12+
13+ #include < atomic>
14+ #include < chrono>
15+ #include < condition_variable>
16+ #include < functional>
17+ #include < cstdint>
18+ #include < iostream>
19+ #include < memory>
20+ #include < mutex>
21+ #include < optional>
22+ #include < stdexcept>
23+ #include < string>
24+ #include < thread>
25+ #include < unordered_map>
26+ #include < unordered_set>
27+ #include < vector>
2128
2229namespace fastmcpp ::server
2330{
@@ -32,11 +39,12 @@ struct MiddlewareContext
3239 std::string method; // /< MCP method name (e.g., "tools/call")
3340 std::string source{" client" }; // /< Origin: "client" or "server"
3441 std::string type{" request" }; // /< Message type: "request" or "notification"
35- std::chrono::steady_clock::time_point timestamp; // /< Request timestamp
36- std::optional<std::string> request_id; // /< Request ID if available
37- std::optional<std::string> tool_name; // /< Tool name for tools/call
38- std::optional<std::string> resource_uri; // /< Resource URI for resources/read
39- std::optional<std::string> prompt_name; // /< Prompt name for prompts/get
42+ std::chrono::steady_clock::time_point timestamp; // /< Request timestamp
43+ std::optional<std::string> request_id; // /< Request ID if available
44+ std::shared_ptr<ServerSession> session; // /< ServerSession for this request (optional)
45+ std::optional<std::string> tool_name; // /< Tool name for tools/call
46+ std::optional<std::string> resource_uri; // /< Resource URI for resources/read
47+ std::optional<std::string> prompt_name; // /< Prompt name for prompts/get
4048
4149 // / Create a copy with modified fields
4250 MiddlewareContext copy () const
@@ -502,8 +510,8 @@ class RateLimitingMiddleware : public Middleware
502510};
503511
504512// / Error handling middleware - catches exceptions and converts to MCP errors
505- class ErrorHandlingMiddleware : public Middleware
506- {
513+ class ErrorHandlingMiddleware : public Middleware
514+ {
507515 public:
508516 using ErrorCallback = std::function<void (const std::string& method, const std::exception& e)>;
509517
@@ -570,7 +578,125 @@ class ErrorHandlingMiddleware : public Middleware
570578 ErrorCallback callback_;
571579 bool include_trace_;
572580 mutable std::mutex mutex_;
573- std::unordered_map<std::string, size_t > error_counts_;
574- };
575-
576- } // namespace fastmcpp::server
581+ std::unordered_map<std::string, size_t > error_counts_;
582+ };
583+
584+ // / Ping middleware - sends periodic pings to keep client connections alive
585+ class PingMiddleware : public Middleware
586+ {
587+ public:
588+ explicit PingMiddleware (std::chrono::milliseconds interval = std::chrono::milliseconds(30000 ))
589+ : interval_(interval)
590+ {
591+ if (interval_.count () <= 0 )
592+ throw std::invalid_argument (" interval must be positive" );
593+ }
594+
595+ explicit PingMiddleware (int interval_ms)
596+ : PingMiddleware(std::chrono::milliseconds(interval_ms))
597+ {
598+ }
599+
600+ ~PingMiddleware () override
601+ {
602+ stop ();
603+ }
604+
605+ Json operator ()(const MiddlewareContext& ctx, CallNext call_next) override
606+ {
607+ if (ctx.session )
608+ ensure_session (ctx.session );
609+ return call_next (ctx);
610+ }
611+
612+ private:
613+ void ensure_session (const std::shared_ptr<ServerSession>& session)
614+ {
615+ const std::string key = session_key (session);
616+ if (key.empty ())
617+ return ;
618+
619+ bool should_start = false ;
620+ {
621+ std::lock_guard<std::mutex> lock (mutex_);
622+ if (active_sessions_.insert (key).second )
623+ should_start = true ;
624+ }
625+
626+ if (!should_start)
627+ return ;
628+
629+ std::weak_ptr<ServerSession> weak_session = session;
630+ std::thread worker ([this , weak_session, key]() { ping_loop (weak_session, key); });
631+ {
632+ std::lock_guard<std::mutex> lock (mutex_);
633+ threads_.push_back (std::move (worker));
634+ }
635+ }
636+
637+ void ping_loop (std::weak_ptr<ServerSession> weak_session, const std::string& key)
638+ {
639+ while (true )
640+ {
641+ {
642+ std::unique_lock<std::mutex> lock (mutex_);
643+ if (cv_.wait_for (lock, interval_, [this ]() { return stop_.load (); }))
644+ break ;
645+ }
646+
647+ if (stop_.load ())
648+ break ;
649+
650+ auto session = weak_session.lock ();
651+ if (!session)
652+ break ;
653+
654+ try
655+ {
656+ session->send_ping (interval_);
657+ }
658+ catch (const std::exception&)
659+ {
660+ break ;
661+ }
662+ }
663+
664+ std::lock_guard<std::mutex> lock (mutex_);
665+ active_sessions_.erase (key);
666+ }
667+
668+ void stop ()
669+ {
670+ stop_.store (true );
671+ cv_.notify_all ();
672+
673+ std::vector<std::thread> threads;
674+ {
675+ std::lock_guard<std::mutex> lock (mutex_);
676+ threads.swap (threads_);
677+ }
678+
679+ for (auto & t : threads)
680+ if (t.joinable ())
681+ t.join ();
682+ }
683+
684+ static std::string session_key (const std::shared_ptr<ServerSession>& session)
685+ {
686+ if (!session)
687+ return {};
688+ auto key = session->session_id ();
689+ if (!key.empty ())
690+ return key;
691+ return " session@" + std::to_string (reinterpret_cast <std::uintptr_t >(session.get ()));
692+ }
693+
694+ std::chrono::milliseconds interval_;
695+ std::mutex mutex_;
696+ std::condition_variable cv_;
697+ std::unordered_set<std::string> active_sessions_;
698+ std::vector<std::thread> threads_;
699+ std::atomic<bool > stop_{false };
700+ };
701+
702+ } // namespace fastmcpp::server
0 commit comments