Skip to content

Commit 4d471b6

Browse files
committed
adapter: fix pubsub data races and nil relay guard
- Add nil guard for relay in RelayPublish to prevent panic - Fix sc.chans data race: capture channel count under ps.mu before releasing the lock, pass count to writeSubscribeReply - Handle Publish flush errors: mark connection as closed and clean up instead of silently ignoring write failures
1 parent 0dbfa69 commit 4d471b6

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

adapter/internal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (i *Internal) Forward(_ context.Context, req *pb.ForwardRequest) (*pb.Forwa
6161
}
6262

6363
func (i *Internal) RelayPublish(_ context.Context, req *pb.RelayPublishRequest) (*pb.RelayPublishResponse, error) {
64-
if req == nil {
64+
if req == nil || i.relay == nil {
6565
return &pb.RelayPublishResponse{}, nil
6666
}
6767
return &pb.RelayPublishResponse{

adapter/redis_pubsub.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ func (ps *redisPubSub) Subscribe(conn redcon.Conn, channel string) {
4848
ps.mu.Lock()
4949
sc, isNew := ps.getOrCreate(conn)
5050
ps.addEntry(sc, channel)
51+
count := len(sc.chans)
5152
ps.mu.Unlock()
5253

53-
ps.writeSubscribeReply(sc, channel)
54+
ps.writeSubscribeReply(sc, channel, count)
5455

5556
if isNew {
5657
go sc.bgrunner(ps)
@@ -77,8 +78,12 @@ func (ps *redisPubSub) Publish(channel, message string) int {
7778
sc.dconn.WriteBulkString("message")
7879
sc.dconn.WriteBulkString(channel)
7980
sc.dconn.WriteBulkString(message)
80-
_ = sc.dconn.Flush()
81-
sent++
81+
if err := sc.dconn.Flush(); err != nil {
82+
sc.closed = true
83+
_ = sc.dconn.Close()
84+
} else {
85+
sent++
86+
}
8287
}
8388
sc.mu.Unlock()
8489
}
@@ -143,13 +148,13 @@ func (ps *redisPubSub) removeAll(sc *pubsubConn) {
143148
sc.chans = nil
144149
}
145150

146-
func (ps *redisPubSub) writeSubscribeReply(sc *pubsubConn, channel string) {
151+
func (ps *redisPubSub) writeSubscribeReply(sc *pubsubConn, channel string, count int) {
147152
sc.mu.Lock()
148153
defer sc.mu.Unlock()
149154
sc.dconn.WriteArray(respArrayMessage)
150155
sc.dconn.WriteBulkString("subscribe")
151156
sc.dconn.WriteBulkString(channel)
152-
sc.dconn.WriteInt(len(sc.chans))
157+
sc.dconn.WriteInt(count)
153158
_ = sc.dconn.Flush()
154159
}
155160

@@ -174,17 +179,21 @@ func (sc *pubsubConn) handleSubscribe(ps *redisPubSub, args [][]byte) {
174179
sc.mu.Unlock()
175180
return
176181
}
177-
channels := make([]string, 0, len(args)-1)
182+
type subInfo struct {
183+
channel string
184+
count int
185+
}
186+
subs := make([]subInfo, 0, len(args)-1)
178187
ps.mu.Lock()
179188
for i := 1; i < len(args); i++ {
180189
ch := string(args[i])
181190
ps.addEntry(sc, ch)
182-
channels = append(channels, ch)
191+
subs = append(subs, subInfo{channel: ch, count: len(sc.chans)})
183192
}
184193
ps.mu.Unlock()
185194

186-
for _, ch := range channels {
187-
ps.writeSubscribeReply(sc, ch)
195+
for _, s := range subs {
196+
ps.writeSubscribeReply(sc, s.channel, s.count)
188197
}
189198
}
190199

0 commit comments

Comments
 (0)