Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 55 additions & 35 deletions lib/mcp/server/transports/streamable_http_transport.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ def close
@reaper_thread&.kill
@reaper_thread = nil

@mutex.synchronize do
@sessions.each_key { |session_id| cleanup_session_unsafe(session_id) }
removed_sessions = @mutex.synchronize do
@sessions.each_key.filter_map { |session_id| cleanup_session_unsafe(session_id) }
end

removed_sessions.each do |session|
close_stream_safely(session[:stream])
end
end

Expand All @@ -65,15 +69,17 @@ def send_notification(method, params = nil, session_id: nil)
}
notification[:params] = params if params

@mutex.synchronize do
streams_to_close = []

result = @mutex.synchronize do
if session_id
# Send to specific session
session = @sessions[session_id]
return false unless session && session[:stream]
next false unless session && session[:stream]

if session_expired?(session)
cleanup_session_unsafe(session_id)
return false
cleanup_and_collect_stream(session_id, streams_to_close)
next false
end

begin
Expand All @@ -84,7 +90,7 @@ def send_notification(method, params = nil, session_id: nil)
e,
{ session_id: session_id, error: "Failed to send notification" },
)
cleanup_session_unsafe(session_id)
cleanup_and_collect_stream(session_id, streams_to_close)
false
end
else
Expand Down Expand Up @@ -113,11 +119,17 @@ def send_notification(method, params = nil, session_id: nil)
end

# Clean up failed sessions
failed_sessions.each { |sid| cleanup_session_unsafe(sid) }
failed_sessions.each { |sid| cleanup_and_collect_stream(sid, streams_to_close) }

sent_count
end
end

streams_to_close.each do |stream|
close_stream_safely(stream)
end

result
end

private
Expand All @@ -136,22 +148,16 @@ def start_reaper_thread
def reap_expired_sessions
return unless @session_idle_timeout

expired_streams = @mutex.synchronize do
@sessions.each_with_object([]) do |(session_id, session), streams|
next unless session_expired?(session)
removed_sessions = @mutex.synchronize do
@sessions.each_key.filter_map do |session_id|
next unless session_expired?(@sessions[session_id])

streams << session[:stream] if session[:stream]
@sessions.delete(session_id)
cleanup_session_unsafe(session_id)
end
end

expired_streams.each do |stream|
# Closing outside the mutex is safe because expired sessions are already
# removed from `@sessions` above, so other threads will not find them
# and will not attempt to close the same stream.
stream.close
rescue StandardError
# Ignore close-related errors from already closed/broken streams.
removed_sessions.each do |session|
close_stream_safely(session[:stream])
end
end

Expand Down Expand Up @@ -228,23 +234,32 @@ def handle_delete(request)
end

def cleanup_session(session_id)
@mutex.synchronize do
session = @mutex.synchronize do
cleanup_session_unsafe(session_id)
end

close_stream_safely(session[:stream]) if session
end

# Removes a session from `@sessions` and returns it. Does not close the stream.
# Callers must close the stream outside the mutex to avoid holding the lock during
# potentially blocking I/O.
def cleanup_session_unsafe(session_id)
session = @sessions[session_id]
return unless session

begin
session[:stream]&.close
rescue StandardError
# Ignore close-related errors from already closed/broken streams.
end
@sessions.delete(session_id)
end

def cleanup_and_collect_stream(session_id, streams_to_close)
return unless (removed = cleanup_session_unsafe(session_id))

streams_to_close << removed[:stream]
end

def close_stream_safely(stream)
stream&.close
rescue StandardError
# Ignore close-related errors from already closed/broken streams.
end

def extract_session_id(request)
request.env["HTTP_MCP_SESSION_ID"]
end
Expand Down Expand Up @@ -357,19 +372,24 @@ def handle_regular_request(body_string, session_id)
end

def validate_and_touch_session(session_id)
@mutex.synchronize do
return session_not_found_response unless (session = @sessions[session_id])
return unless @session_idle_timeout
removed = nil

response = @mutex.synchronize do
next session_not_found_response unless (session = @sessions[session_id])
next unless @session_idle_timeout

if session_expired?(session)
cleanup_session_unsafe(session_id)
return session_not_found_response
removed = cleanup_session_unsafe(session_id)
next session_not_found_response
end

session[:last_active_at] = Process.clock_gettime(Process::CLOCK_MONOTONIC)
nil
end

nil
close_stream_safely(removed[:stream]) if removed

response
end

def get_session_stream(session_id)
Expand Down
66 changes: 66 additions & 0 deletions test/mcp/server/transports/streamable_http_transport_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,37 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
assert_not @transport.instance_variable_get(:@sessions).key?(session_id)
end

test "send_notification closes stream outside mutex on write error" do
init_request = create_rack_request(
"POST",
"/",
{ "CONTENT_TYPE" => "application/json" },
{ jsonrpc: "2.0", method: "initialize", id: "123" }.to_json,
)
init_response = @transport.handle_request(init_request)
session_id = init_response[1]["Mcp-Session-Id"]

# Use a mock stream that verifies mutex is NOT held during close.
mutex = @transport.instance_variable_get(:@mutex)
closed_outside_mutex = false
mock_stream = Object.new
mock_stream.define_singleton_method(:write) { |_data| raise Errno::EPIPE }
mock_stream.define_singleton_method(:close) do
if mutex.try_lock
closed_outside_mutex = true
mutex.unlock
end
end

@transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream

result = @transport.send_notification("test", { message: "test" }, session_id: session_id)

refute result
assert closed_outside_mutex, "Stream should be closed outside the mutex"
assert_not @transport.instance_variable_get(:@sessions).key?(session_id)
end

test "send_notification broadcast continues when one session raises Errno::ECONNRESET" do
# Create two sessions.
init_request1 = create_rack_request(
Expand Down Expand Up @@ -1613,6 +1644,41 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase
transport.close
end

test "reap_expired_sessions closes stream outside mutex" do
transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01)

init_request = create_rack_request(
"POST",
"/",
{ "CONTENT_TYPE" => "application/json" },
{ jsonrpc: "2.0", method: "initialize", id: "init" }.to_json,
)
init_response = transport.handle_request(init_request)
session_id = init_response[1]["Mcp-Session-Id"]

# Replace the stream with one that verifies mutex is NOT held during close.
mutex = transport.instance_variable_get(:@mutex)
closed_outside_mutex = false
mock_stream = Object.new
mock_stream.define_singleton_method(:close) do
# If stream.close runs outside the mutex, try_lock succeeds.
if mutex.try_lock
closed_outside_mutex = true
mutex.unlock
end
end
transport.instance_variable_get(:@sessions)[session_id][:stream] = mock_stream

sleep(0.02) # Wait for session to expire.

transport.send(:reap_expired_sessions)

assert(closed_outside_mutex, "Stream should be closed outside the mutex")
assert_empty(transport.instance_variable_get(:@sessions))
ensure
transport.close
end

test "close stops the reaper thread" do
transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 3600)
reaper_thread = transport.instance_variable_get(:@reaper_thread)
Expand Down