Skip to content

Commit 4308912

Browse files
committed
Enhance pubsub with idempotent subscription handling
1 parent c005cf0 commit 4308912

4 files changed

Lines changed: 636 additions & 39 deletions

File tree

proxy/proxy.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnStat
163163
case cmdMulti:
164164
conn.WriteError("ERR MULTI calls can not be nested")
165165
default:
166+
// NOTE: Commands are queued locally and always return QUEUED without
167+
// upstream validation. Real Redis validates queued commands immediately
168+
// (e.g., wrong arity returns an error). Full compatibility would require
169+
// pinning a dedicated upstream connection for the MULTI..EXEC lifetime.
166170
state.txnQueue = append(state.txnQueue, args)
167171
conn.WriteString("QUEUED")
168172
}
@@ -235,24 +239,26 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args
235239
dconn := conn.Detach()
236240

237241
session := &pubsubSession{
238-
dconn: dconn,
239-
upstream: upstream,
240-
proxy: p,
241-
logger: p.logger,
242+
dconn: dconn,
243+
upstream: upstream,
244+
proxy: p,
245+
logger: p.logger,
246+
channelSet: make(map[string]struct{}),
247+
patternSet: make(map[string]struct{}),
242248
}
243249

244250
// Write initial subscription confirmations.
245251
kind := strings.ToLower(cmdName)
246-
for i, ch := range channels {
252+
for _, ch := range channels {
247253
dconn.WriteArray(pubsubArrayReply)
248254
dconn.WriteBulkString(kind)
249255
dconn.WriteBulkString(ch)
250256
if cmdName == cmdSubscribe {
251-
session.channels = i + 1
257+
session.channelSet[ch] = struct{}{}
252258
} else {
253-
session.patterns = i + 1
259+
session.patternSet[ch] = struct{}{}
254260
}
255-
dconn.WriteInt(session.channels + session.patterns)
261+
dconn.WriteInt(session.subCount())
256262
}
257263
if err := dconn.Flush(); err != nil {
258264
dconn.Close()

proxy/pubsub.go

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ type pubsubSession struct {
4141
logger *slog.Logger
4242
closed bool
4343

44-
// Track subscription counts for RESP replies.
45-
channels int
46-
patterns int
44+
// Track subscribed channels/patterns in sets for idempotent subscribe/unsubscribe
45+
// and correct subscription count tracking (matching Redis behavior).
46+
channelSet map[string]struct{}
47+
patternSet map[string]struct{}
4748

4849
// fwdDone is closed when the current forwardMessages goroutine exits.
4950
fwdDone chan struct{}
@@ -53,6 +54,11 @@ type pubsubSession struct {
5354
txnQueue [][][]byte
5455
}
5556

57+
// subCount returns the total number of active subscriptions (channels + patterns).
58+
func (s *pubsubSession) subCount() int {
59+
return len(s.channelSet) + len(s.patternSet)
60+
}
61+
5662
// run starts the session. It blocks until the client disconnects or sends QUIT.
5763
func (s *pubsubSession) run() {
5864
defer s.cleanup()
@@ -135,7 +141,7 @@ func (s *pubsubSession) commandLoop() {
135141
name := strings.ToUpper(string(args[0]))
136142

137143
s.mu.Lock()
138-
inPubSub := s.channels > 0 || s.patterns > 0
144+
inPubSub := s.subCount() > 0
139145
s.mu.Unlock()
140146

141147
if inPubSub {
@@ -154,7 +160,7 @@ func (s *pubsubSession) commandLoop() {
154160
func (s *pubsubSession) shouldExitPubSub() bool {
155161
s.mu.Lock()
156162
defer s.mu.Unlock()
157-
return s.upstream != nil && s.channels == 0 && s.patterns == 0
163+
return s.upstream != nil && s.subCount() == 0
158164
}
159165

160166
func (s *pubsubSession) exitPubSubMode() {
@@ -333,14 +339,14 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) {
333339
s.mu.Lock()
334340
for _, ch := range channels {
335341
if cmdName == cmdSubscribe {
336-
s.channels++
342+
s.channelSet[ch] = struct{}{}
337343
} else {
338-
s.patterns++
344+
s.patternSet[ch] = struct{}{}
339345
}
340346
s.dconn.WriteArray(pubsubArrayReply)
341347
s.dconn.WriteBulkString(kind)
342348
s.dconn.WriteBulkString(ch)
343-
s.dconn.WriteInt(s.channels + s.patterns)
349+
s.dconn.WriteInt(s.subCount())
344350
}
345351
_ = s.dconn.Flush()
346352
s.mu.Unlock()
@@ -399,11 +405,12 @@ func (s *pubsubSession) handleSubscribe(args [][]byte) {
399405
}
400406
s.mu.Lock()
401407
for _, ch := range channels {
402-
s.channels++
408+
// Idempotent: Redis treats re-subscribe as a no-op for counting.
409+
s.channelSet[ch] = struct{}{}
403410
s.dconn.WriteArray(pubsubArrayReply)
404411
s.dconn.WriteBulkString("subscribe")
405412
s.dconn.WriteBulkString(ch)
406-
s.dconn.WriteInt(s.channels + s.patterns)
413+
s.dconn.WriteInt(s.subCount())
407414
}
408415
_ = s.dconn.Flush()
409416
s.mu.Unlock()
@@ -421,11 +428,12 @@ func (s *pubsubSession) handlePSubscribe(args [][]byte) {
421428
}
422429
s.mu.Lock()
423430
for _, p := range pats {
424-
s.patterns++
431+
// Idempotent: Redis treats re-subscribe as a no-op for counting.
432+
s.patternSet[p] = struct{}{}
425433
s.dconn.WriteArray(pubsubArrayReply)
426434
s.dconn.WriteBulkString("psubscribe")
427435
s.dconn.WriteBulkString(p)
428-
s.dconn.WriteInt(s.channels + s.patterns)
436+
s.dconn.WriteInt(s.subCount())
429437
}
430438
_ = s.dconn.Flush()
431439
s.mu.Unlock()
@@ -442,21 +450,12 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) {
442450
}
443451

444452
if len(args) < pubsubMinArgs {
445-
// Unsubscribe all
453+
// Unsubscribe all: emit per-channel reply (matching Redis behavior).
446454
if err := unsubFn(context.Background()); err != nil {
447455
s.logger.Warn("upstream "+kind+" failed", "err", err)
448456
}
449457
s.mu.Lock()
450-
if isPattern {
451-
s.patterns = 0
452-
} else {
453-
s.channels = 0
454-
}
455-
s.dconn.WriteArray(pubsubArrayReply)
456-
s.dconn.WriteBulkString(kind)
457-
s.dconn.WriteNull()
458-
s.dconn.WriteInt(s.channels + s.patterns)
459-
_ = s.dconn.Flush()
458+
s.writeUnsubAll(kind, isPattern)
460459
s.mu.Unlock()
461460
return
462461
}
@@ -468,23 +467,57 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) {
468467
s.mu.Lock()
469468
for _, n := range names {
470469
if isPattern {
471-
if s.patterns > 0 {
472-
s.patterns--
473-
}
470+
delete(s.patternSet, n)
474471
} else {
475-
if s.channels > 0 {
476-
s.channels--
477-
}
472+
delete(s.channelSet, n)
478473
}
479474
s.dconn.WriteArray(pubsubArrayReply)
480475
s.dconn.WriteBulkString(kind)
481476
s.dconn.WriteBulkString(n)
482-
s.dconn.WriteInt(s.channels + s.patterns)
477+
s.dconn.WriteInt(s.subCount())
483478
}
484479
_ = s.dconn.Flush()
485480
s.mu.Unlock()
486481
}
487482

483+
// writeUnsubAll emits per-channel/pattern unsubscribe replies and clears the set.
484+
// Must be called with s.mu held.
485+
func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) {
486+
set := s.channelSet
487+
if isPattern {
488+
set = s.patternSet
489+
}
490+
491+
if len(set) == 0 {
492+
// No subscriptions: single reply with null channel (matching Redis).
493+
s.dconn.WriteArray(pubsubArrayReply)
494+
s.dconn.WriteBulkString(kind)
495+
s.dconn.WriteNull()
496+
s.dconn.WriteInt(s.subCount())
497+
_ = s.dconn.Flush()
498+
return
499+
}
500+
501+
// Collect names before clearing so we can emit per-channel replies.
502+
names := make([]string, 0, len(set))
503+
for n := range set {
504+
names = append(names, n)
505+
}
506+
if isPattern {
507+
s.patternSet = make(map[string]struct{})
508+
} else {
509+
s.channelSet = make(map[string]struct{})
510+
}
511+
512+
for _, n := range names {
513+
s.dconn.WriteArray(pubsubArrayReply)
514+
s.dconn.WriteBulkString(kind)
515+
s.dconn.WriteBulkString(n)
516+
s.dconn.WriteInt(s.subCount())
517+
}
518+
_ = s.dconn.Flush()
519+
}
520+
488521
// --- Ping handlers ---
489522

490523
func (s *pubsubSession) handlePubSubPing(args [][]byte) {

0 commit comments

Comments
 (0)