@@ -41,9 +41,10 @@ type pubsubSession struct {
4141 logger * slog.Logger
4242 closed bool
4343
44- // Track subscription counts for RESP replies.
45- channels int
46- patterns int
44+ // Track subscribed channels/patterns in sets for idempotent subscribe/unsubscribe
45+ // and correct subscription count tracking (matching Redis behavior).
46+ channelSet map [string ]struct {}
47+ patternSet map [string ]struct {}
4748
4849 // fwdDone is closed when the current forwardMessages goroutine exits.
4950 fwdDone chan struct {}
@@ -53,6 +54,11 @@ type pubsubSession struct {
5354 txnQueue [][][]byte
5455}
5556
57+ // subCount returns the total number of active subscriptions (channels + patterns).
58+ func (s * pubsubSession ) subCount () int {
59+ return len (s .channelSet ) + len (s .patternSet )
60+ }
61+
5662// run starts the session. It blocks until the client disconnects or sends QUIT.
5763func (s * pubsubSession ) run () {
5864 defer s .cleanup ()
@@ -135,7 +141,7 @@ func (s *pubsubSession) commandLoop() {
135141 name := strings .ToUpper (string (args [0 ]))
136142
137143 s .mu .Lock ()
138- inPubSub := s .channels > 0 || s . patterns > 0
144+ inPubSub := s .subCount () > 0
139145 s .mu .Unlock ()
140146
141147 if inPubSub {
@@ -154,7 +160,7 @@ func (s *pubsubSession) commandLoop() {
154160func (s * pubsubSession ) shouldExitPubSub () bool {
155161 s .mu .Lock ()
156162 defer s .mu .Unlock ()
157- return s .upstream != nil && s .channels == 0 && s . patterns == 0
163+ return s .upstream != nil && s .subCount () == 0
158164}
159165
160166func (s * pubsubSession ) exitPubSubMode () {
@@ -333,14 +339,14 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) {
333339 s .mu .Lock ()
334340 for _ , ch := range channels {
335341 if cmdName == cmdSubscribe {
336- s .channels ++
342+ s .channelSet [ ch ] = struct {}{}
337343 } else {
338- s .patterns ++
344+ s .patternSet [ ch ] = struct {}{}
339345 }
340346 s .dconn .WriteArray (pubsubArrayReply )
341347 s .dconn .WriteBulkString (kind )
342348 s .dconn .WriteBulkString (ch )
343- s .dconn .WriteInt (s .channels + s . patterns )
349+ s .dconn .WriteInt (s .subCount () )
344350 }
345351 _ = s .dconn .Flush ()
346352 s .mu .Unlock ()
@@ -399,11 +405,12 @@ func (s *pubsubSession) handleSubscribe(args [][]byte) {
399405 }
400406 s .mu .Lock ()
401407 for _ , ch := range channels {
402- s .channels ++
408+ // Idempotent: Redis treats re-subscribe as a no-op for counting.
409+ s .channelSet [ch ] = struct {}{}
403410 s .dconn .WriteArray (pubsubArrayReply )
404411 s .dconn .WriteBulkString ("subscribe" )
405412 s .dconn .WriteBulkString (ch )
406- s .dconn .WriteInt (s .channels + s . patterns )
413+ s .dconn .WriteInt (s .subCount () )
407414 }
408415 _ = s .dconn .Flush ()
409416 s .mu .Unlock ()
@@ -421,11 +428,12 @@ func (s *pubsubSession) handlePSubscribe(args [][]byte) {
421428 }
422429 s .mu .Lock ()
423430 for _ , p := range pats {
424- s .patterns ++
431+ // Idempotent: Redis treats re-subscribe as a no-op for counting.
432+ s .patternSet [p ] = struct {}{}
425433 s .dconn .WriteArray (pubsubArrayReply )
426434 s .dconn .WriteBulkString ("psubscribe" )
427435 s .dconn .WriteBulkString (p )
428- s .dconn .WriteInt (s .channels + s . patterns )
436+ s .dconn .WriteInt (s .subCount () )
429437 }
430438 _ = s .dconn .Flush ()
431439 s .mu .Unlock ()
@@ -442,21 +450,12 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) {
442450 }
443451
444452 if len (args ) < pubsubMinArgs {
445- // Unsubscribe all
453+ // Unsubscribe all: emit per-channel reply (matching Redis behavior).
446454 if err := unsubFn (context .Background ()); err != nil {
447455 s .logger .Warn ("upstream " + kind + " failed" , "err" , err )
448456 }
449457 s .mu .Lock ()
450- if isPattern {
451- s .patterns = 0
452- } else {
453- s .channels = 0
454- }
455- s .dconn .WriteArray (pubsubArrayReply )
456- s .dconn .WriteBulkString (kind )
457- s .dconn .WriteNull ()
458- s .dconn .WriteInt (s .channels + s .patterns )
459- _ = s .dconn .Flush ()
458+ s .writeUnsubAll (kind , isPattern )
460459 s .mu .Unlock ()
461460 return
462461 }
@@ -468,23 +467,57 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) {
468467 s .mu .Lock ()
469468 for _ , n := range names {
470469 if isPattern {
471- if s .patterns > 0 {
472- s .patterns --
473- }
470+ delete (s .patternSet , n )
474471 } else {
475- if s .channels > 0 {
476- s .channels --
477- }
472+ delete (s .channelSet , n )
478473 }
479474 s .dconn .WriteArray (pubsubArrayReply )
480475 s .dconn .WriteBulkString (kind )
481476 s .dconn .WriteBulkString (n )
482- s .dconn .WriteInt (s .channels + s . patterns )
477+ s .dconn .WriteInt (s .subCount () )
483478 }
484479 _ = s .dconn .Flush ()
485480 s .mu .Unlock ()
486481}
487482
483+ // writeUnsubAll emits per-channel/pattern unsubscribe replies and clears the set.
484+ // Must be called with s.mu held.
485+ func (s * pubsubSession ) writeUnsubAll (kind string , isPattern bool ) {
486+ set := s .channelSet
487+ if isPattern {
488+ set = s .patternSet
489+ }
490+
491+ if len (set ) == 0 {
492+ // No subscriptions: single reply with null channel (matching Redis).
493+ s .dconn .WriteArray (pubsubArrayReply )
494+ s .dconn .WriteBulkString (kind )
495+ s .dconn .WriteNull ()
496+ s .dconn .WriteInt (s .subCount ())
497+ _ = s .dconn .Flush ()
498+ return
499+ }
500+
501+ // Collect names before clearing so we can emit per-channel replies.
502+ names := make ([]string , 0 , len (set ))
503+ for n := range set {
504+ names = append (names , n )
505+ }
506+ if isPattern {
507+ s .patternSet = make (map [string ]struct {})
508+ } else {
509+ s .channelSet = make (map [string ]struct {})
510+ }
511+
512+ for _ , n := range names {
513+ s .dconn .WriteArray (pubsubArrayReply )
514+ s .dconn .WriteBulkString (kind )
515+ s .dconn .WriteBulkString (n )
516+ s .dconn .WriteInt (s .subCount ())
517+ }
518+ _ = s .dconn .Flush ()
519+ }
520+
488521// --- Ping handlers ---
489522
490523func (s * pubsubSession ) handlePubSubPing (args [][]byte ) {
0 commit comments