Skip to content

Commit cfc9d20

Browse files
committed
fix(server): sync disconnect callbacks (coder#176)
Cherry-picked from upstream coder/claudecode.nvim with conflict resolution. Preserves on_disconnect_cleanup for expected disconnection errors while adopting centralized _disconnect_client pattern.
1 parent 5867bcd commit cfc9d20

7 files changed

Lines changed: 219 additions & 33 deletions

File tree

lua/claudecode/server/init.lua

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@ local M = {}
1313
---@field server table|nil The TCP server instance
1414
---@field port number|nil The port server is running on
1515
---@field auth_token string|nil The authentication token for validating connections
16-
---@field clients table<string, WebSocketClient> A list of connected clients
1716
---@field handlers table Message handlers by method name
1817
---@field ping_timer table|nil Timer for sending pings
1918
M.state = {
2019
server = nil,
2120
port = nil,
2221
auth_token = nil,
23-
clients = {},
2422
handlers = {},
2523
ping_timer = nil,
2624
}
@@ -54,8 +52,6 @@ function M.start(config, auth_token)
5452
M._handle_message(client, message)
5553
end,
5654
on_connect = function(client)
57-
M.state.clients[client.id] = client
58-
5955
-- Log connection with auth status
6056
if M.state.auth_token then
6157
logger.debug("server", "Authenticated WebSocket client connected:", client.id)
@@ -84,8 +80,6 @@ function M.start(config, auth_token)
8480
on_disconnect = function(client, code, reason)
8581
-- Unbind client from session before removing
8682
session_manager.unbind_client(client.id)
87-
88-
M.state.clients[client.id] = nil
8983
logger.debug(
9084
"server",
9185
"WebSocket client disconnected:",
@@ -142,8 +136,6 @@ function M.stop()
142136
M.state.server = nil
143137
M.state.port = nil
144138
M.state.auth_token = nil
145-
M.state.clients = {}
146-
147139
return true
148140
end
149141

@@ -231,8 +223,6 @@ end
231223
local module_instance_id = math.random(10000, 99999)
232224
logger.debug("server", "Server module loaded with instance ID:", module_instance_id)
233225

234-
-- Note: debug_deferred_table function removed as deferred_responses table is no longer used
235-
236226
function M._setup_deferred_response(deferred_info)
237227
local co = deferred_info.coroutine
238228

@@ -436,7 +426,7 @@ function M.send_to_session(session_id, method, params)
436426
return false
437427
end
438428

439-
local client = M.state.clients[session.client_id]
429+
local client = M.state.server and M.state.server.clients[session.client_id]
440430
if not client then
441431
logger.debug("server", "Cannot send to session", session_id, "- client not found")
442432
return false

lua/claudecode/server/mock.lua

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ local tools = require("claudecode.tools.init")
1212
M.state = {
1313
server = nil,
1414
port = nil,
15-
clients = {},
1615
handlers = {},
1716
messages = {}, -- Store messages for testing
1817
}
@@ -74,7 +73,6 @@ function M.stop()
7473
-- Reset state
7574
M.state.server = nil
7675
M.state.port = nil
77-
M.state.clients = {}
7876
M.state.messages = {}
7977

8078
return true
@@ -101,29 +99,36 @@ end
10199
---@param client_id string A unique client identifier
102100
---@return table client The client object
103101
function M.add_client(client_id)
102+
assert(type(client_id) == "string", "Expected client_id to be a string")
104103
if not M.state.server then
105104
error("Server not running")
106105
end
106+
assert(type(M.state.server.clients) == "table", "Expected mock server.clients to be a table")
107107

108108
local client = {
109109
id = client_id,
110110
connected = true,
111111
messages = {},
112112
}
113113

114-
M.state.clients[client_id] = client
114+
M.state.server.clients[client_id] = client
115115
return client
116116
end
117117

118118
---Remove a client from the server
119119
---@param client_id string The client identifier
120120
---@return boolean success Whether removal was successful
121121
function M.remove_client(client_id)
122-
if not M.state.server or not M.state.clients[client_id] then
122+
assert(type(client_id) == "string", "Expected client_id to be a string")
123+
if not M.state.server or type(M.state.server.clients) ~= "table" then
123124
return false
124125
end
125126

126-
M.state.clients[client_id] = nil
127+
if not M.state.server.clients[client_id] then
128+
return false
129+
end
130+
131+
M.state.server.clients[client_id] = nil
127132
return true
128133
end
129134

@@ -136,7 +141,10 @@ function M.send(client, method, params)
136141
local client_obj
137142

138143
if type(client) == "string" then
139-
client_obj = M.state.clients[client]
144+
if not M.state.server or type(M.state.server.clients) ~= "table" then
145+
return false
146+
end
147+
client_obj = M.state.server.clients[client]
140148
else
141149
client_obj = client
142150
end
@@ -172,7 +180,10 @@ function M.send_response(client, id, result, error)
172180
local client_obj
173181

174182
if type(client) == "string" then
175-
client_obj = M.state.clients[client]
183+
if not M.state.server or type(M.state.server.clients) ~= "table" then
184+
return false
185+
end
186+
client_obj = M.state.server.clients[client]
176187
else
177188
client_obj = client
178189
end
@@ -208,9 +219,13 @@ end
208219
---@param params table The parameters to send
209220
---@return boolean success Whether broadcasting was successful
210221
function M.broadcast(method, params)
222+
if not M.state.server or type(M.state.server.clients) ~= "table" then
223+
return false
224+
end
225+
211226
local success = true
212227

213-
for client_id, _ in pairs(M.state.clients) do
228+
for client_id, _ in pairs(M.state.server.clients) do
214229
local send_success = M.send(client_id, method, params)
215230
success = success and send_success
216231
end
@@ -223,7 +238,12 @@ end
223238
---@param message table The message to process
224239
---@return table|nil response The response if any
225240
function M.simulate_message(client_id, message)
226-
local client = M.state.clients[client_id]
241+
assert(type(client_id) == "string", "Expected client_id to be a string")
242+
if not M.state.server or type(M.state.server.clients) ~= "table" then
243+
return nil
244+
end
245+
246+
local client = M.state.server.clients[client_id]
227247

228248
if not client then
229249
return nil
@@ -255,7 +275,11 @@ end
255275
function M.clear_messages()
256276
M.state.messages = {}
257277

258-
for _, client in pairs(M.state.clients) do
278+
if not M.state.server or type(M.state.server.clients) ~= "table" then
279+
return
280+
end
281+
282+
for _, client in pairs(M.state.server.clients) do
259283
client.messages = {}
260284
end
261285
end

lua/claudecode/server/tcp.lua

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,36 +127,72 @@ function M._handle_new_connection(server)
127127
if err then
128128
-- ECONNRESET, EOF, EPIPE are expected when terminal closes - don't treat as errors
129129
if err:match("ECONNRESET") or err:match("EOF") or err:match("EPIPE") then
130-
server.on_disconnect_cleanup(client, err)
130+
if server.on_disconnect_cleanup then
131+
server.on_disconnect_cleanup(client, err)
132+
end
131133
else
132134
server.on_error("Client read error: " .. err)
133135
end
134-
M._remove_client(server, client)
136+
M._disconnect_client(server, client, 1006, "Client read error: " .. err)
135137
return
136138
end
137139

138140
if not data then
139141
-- EOF - client disconnected
140-
M._remove_client(server, client)
142+
M._disconnect_client(server, client, 1006, "EOF")
141143
return
142144
end
143145

144146
-- Process incoming data
145147
client_manager.process_data(client, data, function(cl, message)
146148
server.on_message(cl, message)
147149
end, function(cl, code, reason)
148-
server.on_disconnect(cl, code, reason)
149-
M._remove_client(server, cl)
150+
M._disconnect_client(server, cl, code, reason)
150151
end, function(cl, error_msg)
151152
server.on_error("Client " .. cl.id .. " error: " .. error_msg)
152-
M._remove_client(server, cl)
153+
M._disconnect_client(server, cl, 1006, "Client error: " .. error_msg)
153154
end, server.auth_token)
154155
end)
155156

156157
-- Notify about new connection
157158
server.on_connect(client)
158159
end
159160

161+
---Disconnect a client and remove it from the server.
162+
---This ensures `server.on_disconnect` is invoked for every disconnect path
163+
---(EOF, read errors, protocol errors, timeouts), and only once per client.
164+
---@param server TCPServer The server object
165+
---@param client WebSocketClient The client to disconnect
166+
---@param code number|nil WebSocket close code
167+
---@param reason string|nil WebSocket close reason
168+
function M._disconnect_client(server, client, code, reason)
169+
assert(type(server) == "table", "Expected server to be a table")
170+
local on_disconnect_type = type(server.on_disconnect)
171+
local on_disconnect_mt = on_disconnect_type == "table" and getmetatable(server.on_disconnect) or nil
172+
assert(
173+
on_disconnect_type == "function" or (on_disconnect_mt ~= nil and type(on_disconnect_mt.__call) == "function"),
174+
"Expected server.on_disconnect to be callable"
175+
)
176+
assert(type(server.clients) == "table", "Expected server.clients to be a table")
177+
assert(type(client) == "table", "Expected client to be a table")
178+
assert(type(client.id) == "string", "Expected client.id to be a string")
179+
if code ~= nil then
180+
assert(type(code) == "number", "Expected code to be a number")
181+
end
182+
if reason ~= nil then
183+
assert(type(reason) == "string", "Expected reason to be a string")
184+
end
185+
186+
-- Idempotency: a client can hit multiple disconnect paths (e.g. CLOSE frame
187+
-- followed by a TCP EOF). Only notify/remove once.
188+
if not server.clients[client.id] then
189+
return
190+
end
191+
192+
server.on_disconnect(client, code, reason)
193+
M._remove_client(server, client)
194+
end
195+
160196
---Remove a client from the server
161197
---@param server TCPServer The server object
162198
---@param client WebSocketClient The client to remove
@@ -299,7 +335,7 @@ function M.start_ping_timer(server, interval)
299335
string.format("Client %s keepalive timeout (%ds idle), closing connection", client.id, time_since_pong)
300336
)
301337
client_manager.close_client(client, 1006, "Connection timeout")
302-
M._remove_client(server, client)
338+
M._disconnect_client(server, client, 1006, "Connection timeout")
303339
end
304340
end
305341
end

tests/mocks/vim.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,7 @@ local vim = {
881881
return true
882882
end,
883883
read_start = function(self, callback)
884+
self._read_cb = callback
884885
return true
885886
end,
886887
write = function(self, data, callback)

tests/server_test.lua

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ describe("Server module", function()
226226
assert(type(server.state) == "table")
227227
assert(server.state.server == nil)
228228
assert(server.state.port == nil)
229-
assert(type(server.state.clients) == "table")
230229
assert(type(server.state.handlers) == "table")
231230
end)
232231

@@ -259,8 +258,11 @@ describe("Server module", function()
259258
assert(stop_success == true)
260259
assert(server.state.server == nil)
261260
assert(server.state.port == nil)
262-
assert(type(server.state.clients) == "table")
263-
assert(0 == #server.state.clients)
261+
262+
local status = server.get_status()
263+
assert(status.running == false)
264+
assert(status.port == nil)
265+
assert(status.client_count == 0)
264266
end)
265267

266268
it("should not stop the server if not running", function()

0 commit comments

Comments
 (0)