Skip to content

Commit 411cf06

Browse files
committed
Support normal command mode after pub/sub unsubscribe
When all subscriptions are removed, the detached connection transitions to normal command mode instead of being stuck in pub/sub-only mode. Clients can then execute regular Redis commands (GET, SET, transactions, etc.) or re-enter pub/sub mode with a new SUBSCRIBE/PSUBSCRIBE. Key changes: - Add respWriter interface so writeResponse works with both Conn and DetachedConn - Refactor pubsubSession with commandLoop that handles both pub/sub and normal modes - Support transactions (MULTI/EXEC/DISCARD) in normal mode on detached connections - Track forwardMessages goroutine lifecycle for clean pub/sub mode transitions - Extract command name constants (MULTI, EXEC, DISCARD, PING, QUIT)
1 parent 4bc099d commit 411cf06

2 files changed

Lines changed: 341 additions & 51 deletions

File tree

proxy/proxy.go

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@ import (
1515
// txnCommandsOverhead is the number of extra commands (MULTI + EXEC) wrapping queued commands.
1616
const txnCommandsOverhead = 2
1717

18+
// respWriter is the subset of redcon.Conn and redcon.DetachedConn used for writing RESP responses.
19+
// Both connection types satisfy this interface, enabling shared response-writing logic
20+
// between the main event loop and detached pub/sub sessions.
21+
type respWriter interface {
22+
WriteError(msg string)
23+
WriteString(msg string)
24+
WriteBulk(b []byte)
25+
WriteBulkString(msg string)
26+
WriteInt64(num int64)
27+
WriteArray(count int)
28+
WriteNull()
29+
}
30+
1831
// proxyConnState tracks per-connection state (transactions, PubSub).
1932
type proxyConnState struct {
2033
inTxn bool
@@ -143,11 +156,11 @@ func (p *ProxyServer) dispatchCommand(conn redcon.Conn, state *proxyConnState, n
143156

144157
func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnState, name string, args [][]byte) {
145158
switch name {
146-
case "EXEC":
159+
case cmdExec:
147160
p.execTxn(conn, state)
148-
case "DISCARD":
161+
case cmdDiscard:
149162
p.discardTxn(conn, state)
150-
case "MULTI":
163+
case cmdMulti:
151164
conn.WriteError("ERR MULTI calls can not be nested")
152165
default:
153166
state.txnQueue = append(state.txnQueue, args)
@@ -224,6 +237,7 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args
224237
session := &pubsubSession{
225238
dconn: dconn,
226239
upstream: upstream,
240+
proxy: p,
227241
logger: p.logger,
228242
}
229243

@@ -253,7 +267,7 @@ func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) {
253267
name := strings.ToUpper(string(args[0]))
254268

255269
// Handle PING locally for speed.
256-
if name == "PING" {
270+
if name == cmdPing {
257271
if len(args) > 1 {
258272
conn.WriteBulk(args[1])
259273
} else {
@@ -263,7 +277,7 @@ func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) {
263277
}
264278

265279
// Handle QUIT locally.
266-
if name == "QUIT" {
280+
if name == cmdQuit {
267281
conn.WriteString("OK")
268282
conn.Close()
269283
return
@@ -282,21 +296,21 @@ func (p *ProxyServer) handleScript(conn redcon.Conn, args [][]byte) {
282296

283297
func (p *ProxyServer) handleTxnCommand(conn redcon.Conn, state *proxyConnState, name string) {
284298
switch name {
285-
case "MULTI":
299+
case cmdMulti:
286300
if state.inTxn {
287301
conn.WriteError("ERR MULTI calls can not be nested")
288302
return
289303
}
290304
state.inTxn = true
291305
state.txnQueue = nil
292306
conn.WriteString("OK")
293-
case "EXEC":
307+
case cmdExec:
294308
if !state.inTxn {
295309
conn.WriteError("ERR EXEC without MULTI")
296310
return
297311
}
298312
p.execTxn(conn, state)
299-
case "DISCARD":
313+
case cmdDiscard:
300314
if !state.inTxn {
301315
conn.WriteError("ERR DISCARD without MULTI")
302316
return
@@ -350,57 +364,54 @@ func (p *ProxyServer) discardTxn(conn redcon.Conn, state *proxyConnState) {
350364
conn.WriteString("OK")
351365
}
352366

353-
// writeResponse handles the common pattern of writing a go-redis response
354-
// to a redcon connection, correctly handling redis.Nil and upstream errors.
355-
func writeResponse(conn redcon.Conn, resp interface{}, err error) {
367+
// writeResponse handles the common pattern of writing a go-redis response,
368+
// correctly handling redis.Nil and upstream errors.
369+
func writeResponse(w respWriter, resp interface{}, err error) {
356370
if err != nil {
357371
if errors.Is(err, redis.Nil) {
358-
conn.WriteNull()
372+
w.WriteNull()
359373
return
360374
}
361-
writeRedisError(conn, err)
375+
writeRedisError(w, err)
362376
return
363377
}
364-
writeRedisValue(conn, resp)
378+
writeRedisValue(w, resp)
365379
}
366380

367381
// writeRedisError writes an upstream error without double-prefixing.
368382
// Redis errors already contain their prefix (e.g. "ERR ...", "WRONGTYPE ...").
369-
func writeRedisError(conn redcon.Conn, err error) {
383+
func writeRedisError(w respWriter, err error) {
370384
msg := err.Error()
371385
// go-redis errors are already formatted with prefix; pass through as-is.
372-
conn.WriteError(msg)
386+
w.WriteError(msg)
373387
}
374388

375-
// writeRedisValue writes a go-redis response value to a redcon connection.
376-
func writeRedisValue(conn redcon.Conn, val interface{}) {
389+
// writeRedisValue writes a go-redis response value to a respWriter.
390+
func writeRedisValue(w respWriter, val interface{}) {
377391
if val == nil {
378-
conn.WriteNull()
392+
w.WriteNull()
379393
return
380394
}
381395
switch v := val.(type) {
382396
case string:
383-
// go-redis flattens Status and Bulk strings into Go strings.
384-
// Use WriteString (status reply) for known status responses,
385-
// WriteBulkString (bulk reply) for data values.
386397
if isStatusResponse(v) {
387-
conn.WriteString(v)
398+
w.WriteString(v)
388399
} else {
389-
conn.WriteBulkString(v)
400+
w.WriteBulkString(v)
390401
}
391402
case int64:
392-
conn.WriteInt64(v)
403+
w.WriteInt64(v)
393404
case []interface{}:
394-
conn.WriteArray(len(v))
405+
w.WriteArray(len(v))
395406
for _, item := range v {
396-
writeRedisValue(conn, item)
407+
writeRedisValue(w, item)
397408
}
398409
case []byte:
399-
conn.WriteBulk(v)
410+
w.WriteBulk(v)
400411
case redis.Error:
401-
conn.WriteError(v.Error())
412+
w.WriteError(v.Error())
402413
default:
403-
conn.WriteBulkString(fmt.Sprintf("%v", v))
414+
w.WriteBulkString(fmt.Sprintf("%v", v))
404415
}
405416
}
406417

0 commit comments

Comments
 (0)