Skip to content

Commit 534a39b

Browse files
committed
Fix race condition, error normalization, and security concerns
- Fix startForwarding race: capture upstream under lock before goroutine start - Reject SUBSCRIBE/PSUBSCRIBE during MULTI transaction to prevent state corruption - Normalize non-redis.Error to "ERR ..." prefix in writeRedisError for valid RESP - Truncate divergence values in Sentry reports to prevent data leakage
1 parent be85bca commit 534a39b

3 files changed

Lines changed: 37 additions & 10 deletions

File tree

proxy/proxy.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,16 @@ func writeResponse(w respWriter, resp any, err error) {
378378
writeRedisValue(w, resp)
379379
}
380380

381-
// writeRedisError writes an upstream error without double-prefixing.
382-
// Redis errors already contain their prefix (e.g. "ERR ...", "WRONGTYPE ...").
381+
// writeRedisError writes an upstream error to the client.
382+
// go-redis redis.Error values already carry the Redis prefix (e.g. "ERR ...", "WRONGTYPE ...").
383+
// Other errors (timeouts, dial failures) are normalized to "ERR ..." to produce valid RESP.
383384
func writeRedisError(w respWriter, err error) {
384-
msg := err.Error()
385-
// go-redis errors are already formatted with prefix; pass through as-is.
386-
w.WriteError(msg)
385+
var redisErr redis.Error
386+
if errors.As(err, &redisErr) {
387+
w.WriteError(redisErr.Error())
388+
return
389+
}
390+
w.WriteError("ERR " + err.Error())
387391
}
388392

389393
// writeRedisValue writes a go-redis response value to a respWriter.

proxy/pubsub.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,24 @@ func (s *pubsubSession) cleanup() {
7575
}
7676

7777
func (s *pubsubSession) startForwarding() {
78+
// Capture upstream under lock to avoid race with exitPubSubMode.
79+
s.mu.Lock()
80+
upstream := s.upstream
81+
s.mu.Unlock()
82+
if upstream == nil {
83+
return
84+
}
85+
ch := upstream.Channel()
7886
s.fwdDone = make(chan struct{})
7987
go func() {
8088
defer close(s.fwdDone)
81-
s.forwardMessages()
89+
s.forwardMessages(ch)
8290
}()
8391
}
8492

8593
// forwardMessages reads from the upstream go-redis PubSub channel and writes
8694
// messages to the detached client connection.
87-
func (s *pubsubSession) forwardMessages() {
88-
ch := s.upstream.Channel()
95+
func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) {
8996
for msg := range ch {
9097
s.mu.Lock()
9198
if s.closed {
@@ -198,6 +205,10 @@ func (s *pubsubSession) dispatchNormalCommand(name string, args [][]byte) bool {
198205
return true
199206
}
200207
if name == cmdSubscribe || name == cmdPSubscribe {
208+
if s.inTxn {
209+
s.writeError("ERR Command not allowed inside a transaction")
210+
return true
211+
}
201212
s.reenterPubSub(name, args)
202213
return true
203214
}

proxy/sentry.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import (
1111

1212
const (
1313
defaultReportCooldown = 60 * time.Second
14+
// maxSentryValueLen limits the length of values attached to Sentry events
15+
// to prevent data leakage and oversized events.
16+
maxSentryValueLen = 256
1417
// maxReportEntries caps the lastReport map to prevent unbounded growth.
1518
maxReportEntries = 10000
1619
)
@@ -80,8 +83,8 @@ func (r *SentryReporter) CaptureDivergence(div Divergence) {
8083
scope.SetTag("command", div.Command)
8184
scope.SetTag("key", div.Key)
8285
scope.SetTag("kind", div.Kind.String())
83-
scope.SetExtra("primary", fmt.Sprintf("%v", div.Primary))
84-
scope.SetExtra("secondary", fmt.Sprintf("%v", div.Secondary))
86+
scope.SetExtra("primary", truncateValue(div.Primary))
87+
scope.SetExtra("secondary", truncateValue(div.Secondary))
8588
scope.SetFingerprint([]string{"divergence", div.Kind.String(), div.Command})
8689
scope.SetLevel(sentry.LevelWarning)
8790
r.hub.CaptureMessage(fmt.Sprintf("data divergence: %s %s (%s)", div.Kind, div.Command, div.Key))
@@ -129,3 +132,12 @@ func cmdNameFromArgs(args [][]byte) string {
129132
}
130133
return unknownStr
131134
}
135+
136+
// truncateValue formats a value for Sentry, truncating to avoid data leakage and oversized events.
137+
func truncateValue(v any) string {
138+
s := fmt.Sprintf("%v", v)
139+
if len(s) > maxSentryValueLen {
140+
return s[:maxSentryValueLen] + "...(truncated)"
141+
}
142+
return s
143+
}

0 commit comments

Comments
 (0)