Skip to content

Commit 7f3f917

Browse files
committed
proxy: reduce secondary replay mismatches
1 parent b751fd1 commit 7f3f917

4 files changed

Lines changed: 177 additions & 3 deletions

File tree

proxy/backend.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const (
1414
defaultDialTimeout = 5 * time.Second
1515
defaultReadTimeout = 3 * time.Second
1616
defaultWriteTimeout = 3 * time.Second
17+
respProtocolV2 = 2
1718
)
1819

1920
// Backend abstracts a Redis-protocol endpoint (real Redis or ElasticKV).
@@ -72,6 +73,7 @@ func NewRedisBackendWithOptions(addr string, name string, opts BackendOptions) *
7273
Addr: addr,
7374
DB: opts.DB,
7475
Password: opts.Password,
76+
Protocol: respProtocolV2,
7577
PoolSize: opts.PoolSize,
7678
DialTimeout: opts.DialTimeout,
7779
ReadTimeout: opts.ReadTimeout,

proxy/dualwrite.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ type DualWriter struct {
3232
writeSem chan struct{} // bounds concurrent secondary write goroutines
3333
shadowSem chan struct{} // bounds concurrent shadow read goroutines
3434

35-
wg sync.WaitGroup
36-
mu sync.Mutex // protects closed; held briefly to make wg.Add atomic with close check
37-
closed bool
35+
wg sync.WaitGroup
36+
mu sync.Mutex // protects closed; held briefly to make wg.Add atomic with close check
37+
closed bool
38+
scriptMu sync.RWMutex
39+
scripts map[string]string
3840
}
3941

4042
// NewDualWriter creates a DualWriter with the given backends.
@@ -48,6 +50,7 @@ func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMe
4850
logger: logger,
4951
writeSem: make(chan struct{}, maxWriteGoroutines),
5052
shadowSem: make(chan struct{}, maxShadowGoroutines),
53+
scripts: make(map[string]string),
5154
}
5255

5356
if cfg.Mode == ModeDualWriteShadow || cfg.Mode == ModeElasticKVPrimary {
@@ -190,6 +193,7 @@ func (d *DualWriter) Script(ctx context.Context, cmd string, args [][]byte) (any
190193
result := d.primary.Do(ctx, iArgs...)
191194
resp, err := result.Result()
192195
d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
196+
d.rememberScript(cmd, args)
193197

194198
if err != nil && !errors.Is(err, redis.Nil) {
195199
d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc()
@@ -211,6 +215,12 @@ func (d *DualWriter) writeSecondary(cmd string, iArgs []any) {
211215
start := time.Now()
212216
result := d.secondary.Do(sCtx, iArgs...)
213217
_, sErr := result.Result()
218+
if isNoScriptError(sErr) {
219+
if fallbackArgs, ok := d.evalFallbackArgs(cmd, iArgs); ok {
220+
result = d.secondary.Do(sCtx, fallbackArgs...)
221+
_, sErr = result.Result()
222+
}
223+
}
214224
d.metrics.CommandDuration.WithLabelValues(cmd, d.secondary.Name()).Observe(time.Since(start).Seconds())
215225

216226
if sErr != nil && !errors.Is(sErr, redis.Nil) {

proxy/proxy_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,15 @@ func TestDefaultBackendOptions(t *testing.T) {
768768
assert.Equal(t, 5*time.Second, opts.DialTimeout)
769769
}
770770

771+
func TestNewRedisBackend_UsesRESP2(t *testing.T) {
772+
backend := NewRedisBackend("127.0.0.1:6379", "test")
773+
t.Cleanup(func() {
774+
assert.NoError(t, backend.Close())
775+
})
776+
777+
assert.Equal(t, respProtocolV2, backend.client.Options().Protocol)
778+
}
779+
771780
// ========== Pipeline error handling tests ==========
772781

773782
func TestPipeline_TransportError(t *testing.T) {
@@ -781,6 +790,45 @@ func TestPipeline_TransportError(t *testing.T) {
781790
assert.Len(t, results, 3)
782791
}
783792

793+
func TestDualWriter_Script_CachesEvalForEvalSHAFallback(t *testing.T) {
794+
primary := newMockBackend("primary")
795+
primary.doFunc = makeCmd("OK", nil)
796+
797+
secondary := newMockBackend("secondary")
798+
script := "return ARGV[1]"
799+
sha := scriptSHA(script)
800+
var calls int
801+
secondary.doFunc = func(ctx context.Context, args ...any) *redis.Cmd {
802+
calls++
803+
cmd := redis.NewCmd(ctx, args...)
804+
switch calls {
805+
case 1:
806+
assert.Equal(t, []byte("EVALSHA"), args[0])
807+
assert.Equal(t, []byte(sha), args[1])
808+
cmd.SetErr(testRedisErr("NOSCRIPT No matching script. Please use EVAL."))
809+
case 2:
810+
assert.Equal(t, []byte("EVAL"), args[0])
811+
assert.Equal(t, []byte(script), args[1])
812+
cmd.SetVal("OK")
813+
default:
814+
t.Fatalf("unexpected secondary call %d", calls)
815+
}
816+
return cmd
817+
}
818+
819+
metrics := newTestMetrics()
820+
d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger)
821+
822+
_, err := d.Script(context.Background(), "EVAL", [][]byte{[]byte("EVAL"), []byte(script), []byte("0"), []byte("value")})
823+
assert.NoError(t, err)
824+
825+
d.cfg.Mode = ModeDualWrite
826+
d.writeSecondary("EVALSHA", []any{[]byte("EVALSHA"), []byte(sha), []byte("0"), []byte("value")})
827+
828+
assert.Equal(t, 2, calls)
829+
assert.InDelta(t, 0, testutil.ToFloat64(metrics.SecondaryWriteErrors), 0.001)
830+
}
831+
784832
// ========== writeRedisValue tests ==========
785833

786834
// testRedisErr satisfies the redis.Error interface for testing.

proxy/script_cache.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package proxy
2+
3+
import (
4+
"crypto/sha1" // #nosec G505 -- Redis EVALSHA specifies SHA1 script digests.
5+
"encoding/hex"
6+
"errors"
7+
"strings"
8+
9+
"github.com/redis/go-redis/v9"
10+
)
11+
12+
const (
13+
cmdEval = "EVAL"
14+
cmdEvalSHA = "EVALSHA"
15+
cmdScript = "SCRIPT"
16+
17+
minScriptSubcommandArgs = 2
18+
scriptLoadArgIndex = 2
19+
minEvalSHAArgs = 2
20+
)
21+
22+
func (d *DualWriter) rememberScript(cmd string, args [][]byte) {
23+
upper := strings.ToUpper(cmd)
24+
25+
switch upper {
26+
case cmdEval, "EVAL_RO":
27+
if len(args) > 1 {
28+
d.storeScript(string(args[1]))
29+
}
30+
case cmdScript:
31+
if len(args) < minScriptSubcommandArgs {
32+
return
33+
}
34+
switch strings.ToUpper(string(args[1])) {
35+
case "LOAD":
36+
if len(args) > scriptLoadArgIndex {
37+
d.storeScript(string(args[scriptLoadArgIndex]))
38+
}
39+
case "FLUSH":
40+
d.clearScripts()
41+
}
42+
}
43+
}
44+
45+
func (d *DualWriter) storeScript(script string) {
46+
sha := scriptSHA(script)
47+
48+
d.scriptMu.Lock()
49+
defer d.scriptMu.Unlock()
50+
d.scripts[sha] = script
51+
}
52+
53+
func (d *DualWriter) clearScripts() {
54+
d.scriptMu.Lock()
55+
defer d.scriptMu.Unlock()
56+
clear(d.scripts)
57+
}
58+
59+
func (d *DualWriter) lookupScript(sha string) (string, bool) {
60+
d.scriptMu.RLock()
61+
defer d.scriptMu.RUnlock()
62+
script, ok := d.scripts[strings.ToLower(sha)]
63+
return script, ok
64+
}
65+
66+
func (d *DualWriter) evalFallbackArgs(cmd string, iArgs []any) ([]any, bool) {
67+
upper := strings.ToUpper(cmd)
68+
if upper != cmdEvalSHA && upper != "EVALSHA_RO" {
69+
return nil, false
70+
}
71+
if len(iArgs) < minEvalSHAArgs {
72+
return nil, false
73+
}
74+
75+
sha := stringArg(iArgs[1])
76+
script, ok := d.lookupScript(sha)
77+
if !ok {
78+
return nil, false
79+
}
80+
81+
fallback := make([]any, len(iArgs))
82+
fallback[0] = []byte(cmdEval)
83+
fallback[1] = []byte(script)
84+
copy(fallback[2:], iArgs[2:])
85+
return fallback, true
86+
}
87+
88+
func isNoScriptError(err error) bool {
89+
if err == nil {
90+
return false
91+
}
92+
var redisErr redis.Error
93+
if errors.As(err, &redisErr) {
94+
return strings.HasPrefix(redisErr.Error(), "NOSCRIPT ")
95+
}
96+
return strings.HasPrefix(err.Error(), "NOSCRIPT ")
97+
}
98+
99+
func scriptSHA(script string) string {
100+
// #nosec G401 -- Redis EVALSHA uses SHA1 digests by protocol.
101+
sum := sha1.Sum([]byte(script))
102+
return hex.EncodeToString(sum[:])
103+
}
104+
105+
func stringArg(arg any) string {
106+
switch v := arg.(type) {
107+
case []byte:
108+
return strings.ToLower(string(v))
109+
case string:
110+
return strings.ToLower(v)
111+
default:
112+
return strings.ToLower(string(argsToBytes([]any{arg})[0]))
113+
}
114+
}

0 commit comments

Comments
 (0)