Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions lib/mcp/server/transports/streamable_http_transport.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class StreamableHTTPTransport < Transport

def initialize(server, stateless: false, session_idle_timeout: nil)
super(server)
# Maps `session_id` to `{ stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`.
# Maps `session_id` to `{ get_sse_stream: stream_object, server_session: ServerSession, last_active_at: float_from_monotonic_clock }`.
@sessions = {}
@mutex = Mutex.new

Expand Down Expand Up @@ -61,7 +61,7 @@ def close
end

removed_sessions.each do |session|
close_stream_safely(session[:stream])
close_stream_safely(session[:get_sse_stream])
close_post_request_streams(session)
end
end
Expand Down Expand Up @@ -113,7 +113,7 @@ def send_notification(method, params = nil, session_id: nil, related_request_id:
failed_sessions = []

@sessions.each do |sid, session|
next unless (stream = session[:stream])
next unless (stream = session[:get_sse_stream])

if session_expired?(session)
failed_sessions << sid
Expand Down Expand Up @@ -247,7 +247,7 @@ def reap_expired_sessions
end

removed_sessions.each do |session|
close_stream_safely(session[:stream])
close_stream_safely(session[:get_sse_stream])
close_post_request_streams(session)
end
end
Expand Down Expand Up @@ -334,7 +334,7 @@ def cleanup_session(session_id)
end

if session
close_stream_safely(session[:stream])
close_stream_safely(session[:get_sse_stream])
close_post_request_streams(session)
end
end
Expand All @@ -358,7 +358,7 @@ def cleanup_session_unsafe(session_id)
def cleanup_and_collect_stream(session_id, streams_to_close)
return unless (removed = cleanup_session_unsafe(session_id))

streams_to_close << removed[:stream]
streams_to_close << removed[:get_sse_stream]
removed[:post_request_streams]&.each_value { |stream| streams_to_close << stream }
end

Expand Down Expand Up @@ -449,7 +449,7 @@ def handle_initialization(body_string, body)

@mutex.synchronize do
@sessions[session_id] = {
stream: nil,
get_sse_stream: nil,
server_session: server_session,
last_active_at: Process.clock_gettime(Process::CLOCK_MONOTONIC),
}
Expand Down Expand Up @@ -543,7 +543,7 @@ def active_stream(session, related_request_id: nil)
if related_request_id
session.dig(:post_request_streams, related_request_id)
else
session[:stream]
session[:get_sse_stream]
end
end

Expand Down Expand Up @@ -572,7 +572,7 @@ def validate_and_touch_session(session_id)
end

if removed
close_stream_safely(removed[:stream])
close_stream_safely(removed[:get_sse_stream])

removed[:post_request_streams]&.each_value do |stream|
close_stream_safely(stream)
Expand All @@ -583,7 +583,7 @@ def validate_and_touch_session(session_id)
end

def get_session_stream(session_id)
@mutex.synchronize { @sessions[session_id]&.fetch(:stream, nil) }
@mutex.synchronize { @sessions[session_id]&.fetch(:get_sse_stream, nil) }
end

def session_exists?(session_id)
Expand Down Expand Up @@ -626,8 +626,8 @@ def create_sse_body(session_id)
def store_stream_for_session(session_id, stream)
@mutex.synchronize do
session = @sessions[session_id]
if session && !session[:stream]
session[:stream] = stream
if session && !session[:get_sse_stream]
session[:get_sse_stream] = stream
else
# Either session was removed, or another request already established a stream.
stream.close
Expand All @@ -652,13 +652,13 @@ def start_keepalive_thread(session_id)
end

def session_active_with_stream?(session_id)
@mutex.synchronize { @sessions.key?(session_id) && @sessions[session_id][:stream] }
@mutex.synchronize { @sessions.key?(session_id) && @sessions[session_id][:get_sse_stream] }
end

def send_keepalive_ping(session_id)
@mutex.synchronize do
if @sessions[session_id] && @sessions[session_id][:stream]
send_ping_to_stream(@sessions[session_id][:stream])
if @sessions[session_id] && @sessions[session_id][:get_sse_stream]
send_ping_to_stream(@sessions[session_id][:get_sse_stream])
end
end
rescue *STREAM_WRITE_ERRORS => e
Expand Down
14 changes: 7 additions & 7 deletions test/mcp/server/transports/streamable_http_transport_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def string

# Simulate an active SSE stream by storing a stream object in the session
mock_stream = StringIO.new
@transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream
@transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = mock_stream

# Attempt a second GET request for the same session
get_request = create_rack_request(
Expand Down Expand Up @@ -377,14 +377,14 @@ def string
# Establish stream A
stream_a = StringIO.new
@transport.send(:store_stream_for_session, session_id, stream_a)
assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:stream]
assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream]

# Attempt to store stream B (simulating a racing request)
stream_b = StringIO.new
@transport.send(:store_stream_for_session, session_id, stream_b)

# Stream A should still be the active stream
assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:stream]
assert_equal stream_a, @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream]

# Stream B should have been closed
assert stream_b.closed?
Expand Down Expand Up @@ -928,7 +928,7 @@ def string
end
end

@transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream
@transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = mock_stream

result = @transport.send_notification("test", { message: "test" }, session_id: session_id)

Expand Down Expand Up @@ -973,7 +973,7 @@ def string
# The broken request_stream should be removed.
refute @transport.instance_variable_get(:@sessions)[session_id][:post_request_streams].key?(related_id)
# GET SSE stream should still be intact.
assert @transport.instance_variable_get(:@sessions)[session_id][:stream]
assert @transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream]
end

test "active_stream does not fall back to GET SSE when related_request_id is given but request_stream is missing" do
Expand Down Expand Up @@ -2378,7 +2378,7 @@ def string
mutex.unlock
end
end
transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream
transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = mock_stream

sleep(0.02) # Wait for session to expire.

Expand Down Expand Up @@ -2495,7 +2495,7 @@ def string

# Attach a mock stream to the session
stream = StringIO.new
transport.instance_variable_get(:@sessions)[session_id][:stream] = stream
transport.instance_variable_get(:@sessions)[session_id][:get_sse_stream] = stream

# Wait for the session to exceed the idle timeout (0.01s)
sleep(0.02)
Expand Down