diff --git a/lib/mcp/server/transports/streamable_http_transport.rb b/lib/mcp/server/transports/streamable_http_transport.rb index 688e38c..b03b6ce 100644 --- a/lib/mcp/server/transports/streamable_http_transport.rb +++ b/lib/mcp/server/transports/streamable_http_transport.rb @@ -15,7 +15,7 @@ class StreamableHTTPTransport < Transport def initialize(server, stateless: false, session_idle_timeout: nil) super(server) - # Maps `session_id` to `{ stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`. + # Maps `session_id` to `{ get_sse_stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`. @sessions = {} @mutex = Mutex.new @@ -61,7 +61,7 @@ def close end removed_sessions.each do |session| - close_stream_safely(session[:stream]) + close_stream_safely(session[:get_sse_stream]) close_post_request_streams(session) end end @@ -113,7 +113,7 @@ def send_notification(method, params = nil, session_id: nil, related_request_id: failed_sessions = [] @sessions.each do |sid, session| - next unless (stream = session[:stream]) + next unless (stream = session[:get_sse_stream]) if session_expired?(session) failed_sessions << sid @@ -247,7 +247,7 @@ def reap_expired_sessions end removed_sessions.each do |session| - close_stream_safely(session[:stream]) + close_stream_safely(session[:get_sse_stream]) close_post_request_streams(session) end end @@ -334,7 +334,7 @@ def cleanup_session(session_id) end if session - close_stream_safely(session[:stream]) + close_stream_safely(session[:get_sse_stream]) close_post_request_streams(session) end end @@ -358,7 +358,7 @@ def cleanup_session_unsafe(session_id) def cleanup_and_collect_stream(session_id, streams_to_close) return unless (removed = cleanup_session_unsafe(session_id)) - streams_to_close << removed[:stream] + streams_to_close << removed[:get_sse_stream] removed[:post_request_streams]&.each_value { |stream| streams_to_close << stream } end @@ -449,7 +449,7 @@ def handle_initialization(body_string, body) @mutex.synchronize do @sessions[session_id] = { - stream: nil, + get_sse_stream: nil, server_session: server_session, last_active_at: Process.clock_gettime(Process::CLOCK_MONOTONIC), } @@ -543,7 +543,7 @@ def active_stream(session, related_request_id: nil) if related_request_id session.dig(:post_request_streams, related_request_id) else - session[:stream] + session[:get_sse_stream] end end @@ -572,7 +572,7 @@ def validate_and_touch_session(session_id) end if removed - close_stream_safely(removed[:stream]) + close_stream_safely(removed[:get_sse_stream]) removed[:post_request_streams]&.each_value do |stream| close_stream_safely(stream) @@ -583,7 +583,7 @@ def validate_and_touch_session(session_id) end def get_session_stream(session_id) - @mutex.synchronize { @sessions[session_id]&.fetch(:stream, nil) } + @mutex.synchronize { @sessions[session_id]&.fetch(:get_sse_stream, nil) } end def session_exists?(session_id) @@ -626,8 +626,8 @@ def create_sse_body(session_id) def store_stream_for_session(session_id, stream) @mutex.synchronize do session = @sessions[session_id] - if session && !session[:stream] - session[:stream] = stream + if session && !session[:get_sse_stream] + session[:get_sse_stream] = stream else # Either session was removed, or another request already established a stream. stream.close @@ -652,13 +652,13 @@ def start_keepalive_thread(session_id) end def session_active_with_stream?(session_id) - @mutex.synchronize { @sessions.key?(session_id) && @sessions[session_id][:stream] } + @mutex.synchronize { @sessions.key?(session_id) && @sessions[session_id][:get_sse_stream] } end def send_keepalive_ping(session_id) @mutex.synchronize do - if @sessions[session_id] && @sessions[session_id][:stream] - send_ping_to_stream(@sessions[session_id][:stream]) + if @sessions[session_id] && @sessions[session_id][:get_sse_stream] + send_ping_to_stream(@sessions[session_id][:get_sse_stream]) end end rescue *STREAM_WRITE_ERRORS => e diff --git a/test/mcp/server/transports/streamable_http_transport_test.rb b/test/mcp/server/transports/streamable_http_transport_test.rb index 7435289..473d849 100644 --- a/test/mcp/server/transports/streamable_http_transport_test.rb +++ b/test/mcp/server/transports/streamable_http_transport_test.rb @@ -346,7 +346,7 @@ def string # Simulate an active SSE stream by storing a stream object in the session mock_stream = StringIO.new - @transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream + @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = mock_stream # Attempt a second GET request for the same session get_request = create_rack_request( @@ -377,14 +377,14 @@ def string # Establish stream A stream_a = StringIO.new @transport.send(:store_stream_for_session, session_id, stream_a) - assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:stream] + assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] # Attempt to store stream B (simulating a racing request) stream_b = StringIO.new @transport.send(:store_stream_for_session, session_id, stream_b) # Stream A should still be the active stream - assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:stream] + assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] # Stream B should have been closed assert stream_b.closed? @@ -928,7 +928,7 @@ def string end end - @transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream + @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = mock_stream result = @transport.send_notification("test", { message: "test" }, session_id: session_id) @@ -973,7 +973,7 @@ def string # The broken request_stream should be removed. refute @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams].key?(related_id) # GET SSE stream should still be intact. - assert @transport.instance_variable_get(:@sessions)[session_id][:stream] + assert @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] end test "active_stream does not fall back to GET SSE when related_request_id is given but request_stream is missing" do @@ -2378,7 +2378,7 @@ def string mutex.unlock end end - transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream + transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = mock_stream sleep(0.02) # Wait for session to expire. @@ -2495,7 +2495,7 @@ def string # Attach a mock stream to the session stream = StringIO.new - transport.instance_variable_get(:@sessions)[session_id][:stream] = stream + transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = stream # Wait for the session to exceed the idle timeout (0.01s) sleep(0.02)