Skip to content

Commit a6be63d

Browse files
authored
sqlreplay: fix readonly-mode replay still preparing DML statements (#998)
1 parent a15c5e0 commit a6be63d

7 files changed

Lines changed: 145 additions & 87 deletions

File tree

pkg/sqlreplay/conn/conn.go

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,9 @@ func (c *conn) Run(ctx context.Context) {
164164
if command == nil {
165165
break
166166
}
167-
if c.readonly {
168-
if !c.isReadOnly(command.Value) {
169-
c.replayStats.FilteredCmds.Add(1)
170-
continue
171-
}
167+
if c.readonly && !c.isReadOnly(command.Value) {
168+
c.replayStats.FilteredCmds.Add(1)
169+
continue
172170
}
173171
// Quit the connection in the next round no matter what exception happens (like disconnection).
174172
if command.Value.Type == pnet.ComQuit {
@@ -234,16 +232,14 @@ func (c *conn) Run(ctx context.Context) {
234232

235233
func (c *conn) isReadOnly(command *cmd.Command) bool {
236234
switch command.Type {
237-
case pnet.ComQuery:
235+
case pnet.ComQuery, pnet.ComStmtPrepare:
236+
// If the statement is not readonly, it won't be prepared.
238237
return lex.IsReadOnly(hack.String(command.Payload[1:]))
239-
case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch:
240-
stmtID := binary.LittleEndian.Uint32(command.Payload[1:5])
241-
ps := c.preparedStmts[stmtID]
242-
if len(ps.text) == 0 {
243-
// Maybe the connection is reconnected after disconnection and the prepared statements are lost.
244-
return false
245-
}
246-
return lex.IsReadOnly(ps.text)
238+
case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch, pnet.ComStmtClose:
239+
// If the statement is prepared successfully, then it's readonly.
240+
captureStmtID := binary.LittleEndian.Uint32(command.Payload[1:5])
241+
_, ok := c.psIDMapping[captureStmtID]
242+
return ok
247243
case pnet.ComCreateDB, pnet.ComDropDB, pnet.ComDelayedInsert:
248244
return false
249245
}
@@ -254,9 +250,7 @@ func (c *conn) isReadOnly(command *cmd.Command) bool {
254250
return true
255251
}
256252

257-
// maintain prepared statement info so that we can find its info when:
258-
// - Judge whether an EXECUTE command is readonly
259-
// - Get the error message when an EXECUTE command fails
253+
// Maintain prepared statement info so that the failed statement and its params can be looked up later.
260254
func (c *conn) updatePreparedStmts(capturedPsID uint32, request []byte, resp ExecuteResp) {
261255
switch request[0] {
262256
case pnet.ComStmtPrepare.Byte():

pkg/sqlreplay/conn/conn_test.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -223,27 +223,31 @@ func TestSkipReadOnly(t *testing.T) {
223223
readonly: false,
224224
},
225225
{
226-
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...)},
226+
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, CapturedPsID: 1, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...)},
227227
readonly: true,
228228
},
229229
{
230-
cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 1, 0, 0, 0, 0, 0, 0, 0}},
230+
cmd: &cmd.Command{Type: pnet.ComStmtExecute, CapturedPsID: 1, Payload: []byte{pnet.ComStmtExecute.Byte(), 1, 0, 0, 0, 0, 0, 0, 0}},
231231
readonly: true,
232232
},
233233
{
234-
cmd: &cmd.Command{Type: pnet.ComStmtFetch, Payload: []byte{pnet.ComStmtFetch.Byte(), 1, 0, 0, 0}},
234+
cmd: &cmd.Command{Type: pnet.ComStmtFetch, CapturedPsID: 1, Payload: []byte{pnet.ComStmtFetch.Byte(), 1, 0, 0, 0}},
235235
readonly: true,
236236
},
237237
{
238-
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)},
239-
readonly: true,
238+
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, CapturedPsID: 2, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)},
239+
readonly: false,
240+
},
241+
{
242+
cmd: &cmd.Command{Type: pnet.ComStmtExecute, CapturedPsID: 2, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}},
243+
readonly: false,
240244
},
241245
{
242-
cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}},
246+
cmd: &cmd.Command{Type: pnet.ComStmtSendLongData, CapturedPsID: 2, Payload: []byte{pnet.ComStmtFetch.Byte(), 2, 0, 0, 0, 0, 0, 0, 0}},
243247
readonly: false,
244248
},
245249
{
246-
cmd: &cmd.Command{Type: pnet.ComStmtSendLongData, Payload: []byte{pnet.ComStmtFetch.Byte(), 2, 0, 0, 0, 0, 0, 0, 0}},
250+
cmd: &cmd.Command{Type: pnet.ComStmtClose, CapturedPsID: 2, Payload: []byte{pnet.ComStmtClose.Byte(), 2, 0, 0, 0}},
247251
readonly: false,
248252
},
249253
{
@@ -315,7 +319,7 @@ func TestReadOnly(t *testing.T) {
315319
{
316320
cmd: pnet.ComStmtPrepare,
317321
stmt: "insert into t value(?)",
318-
readOnly: true,
322+
readOnly: false,
319323
},
320324
{
321325
cmd: pnet.ComStmtExecute,
@@ -327,10 +331,15 @@ func TestReadOnly(t *testing.T) {
327331
stmt: "insert into t value(?)",
328332
readOnly: false,
329333
},
334+
{
335+
cmd: pnet.ComStmtExecute,
336+
stmt: "",
337+
readOnly: false,
338+
},
330339
{
331340
cmd: pnet.ComStmtClose,
332341
stmt: "insert into t value(?)",
333-
readOnly: true,
342+
readOnly: false,
334343
},
335344
{
336345
cmd: pnet.ComQuit,
@@ -346,13 +355,19 @@ func TestReadOnly(t *testing.T) {
346355
backendConn := newMockBackendConn()
347356
conn.backendConn = backendConn
348357
for i, test := range tests {
358+
clear(conn.psIDMapping)
349359
var payload []byte
350360
switch test.cmd {
351-
case pnet.ComQuery:
361+
case pnet.ComQuery, pnet.ComStmtPrepare:
352362
payload = append([]byte{test.cmd.Byte()}, []byte(test.stmt)...)
353-
default:
354-
conn.preparedStmts[1] = preparedStmt{text: test.stmt}
363+
case pnet.ComStmtExecute, pnet.ComStmtClose, pnet.ComStmtFetch, pnet.ComStmtReset, pnet.ComStmtSendLongData:
364+
prepare := cmd.NewCommand(append([]byte{pnet.ComStmtPrepare.Byte()}, []byte(test.stmt)...), time.Time{}, 100)
365+
if conn.isReadOnly(prepare) {
366+
conn.psIDMapping[1] = 1
367+
}
355368
payload = []byte{test.cmd.Byte(), 1, 0, 0, 0}
369+
default:
370+
payload = []byte{test.cmd.Byte()}
356371
}
357372
command := cmd.NewCommand(payload, time.Time{}, 100)
358373
require.Equal(t, test.readOnly, conn.isReadOnly(command), "case %d", i)

pkg/sqlreplay/replay/dry_run.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/pingcap/tiproxy/pkg/sqlreplay/cmd"
1010
"github.com/pingcap/tiproxy/pkg/sqlreplay/conn"
1111
"github.com/pingcap/tiproxy/pkg/sqlreplay/report"
12+
"github.com/pingcap/tiproxy/pkg/util/waitgroup"
1213
)
1314

1415
type nopConn struct {
@@ -32,24 +33,41 @@ func (c *nopConn) Stop() {
3233
c.closeCh <- c.connID
3334
}
3435

35-
var _ report.Report = (*mockReport)(nil)
36+
var _ report.Report = (*nopReport)(nil)
3637

37-
type mockReport struct {
38+
type nopReport struct {
3839
exceptionCh chan conn.Exception
40+
wg waitgroup.WaitGroup
41+
cancel context.CancelFunc
3942
}
4043

41-
func newMockReport(exceptionCh chan conn.Exception) *mockReport {
42-
return &mockReport{
44+
func newMockReport(exceptionCh chan conn.Exception) *nopReport {
45+
return &nopReport{
4346
exceptionCh: exceptionCh,
4447
}
4548
}
4649

47-
func (mr *mockReport) Start(ctx context.Context, cfg report.ReportConfig) error {
50+
func (mr *nopReport) Start(ctx context.Context, cfg report.ReportConfig) error {
51+
childCtx, cancel := context.WithCancel(ctx)
52+
mr.cancel = cancel
53+
mr.wg.RunWithRecover(func() { mr.loop(childCtx) }, nil, nil)
4854
return nil
4955
}
5056

51-
func (mr *mockReport) Stop(err error) {
57+
func (mr *nopReport) loop(ctx context.Context) {
58+
for ctx.Err() == nil {
59+
select {
60+
case <-ctx.Done():
61+
return
62+
case <-mr.exceptionCh:
63+
}
64+
}
5265
}
5366

54-
func (mr *mockReport) Close() {
67+
func (mr *nopReport) Close() {
68+
if mr.cancel != nil {
69+
mr.cancel()
70+
mr.cancel = nil
71+
}
72+
mr.wg.Wait()
5573
}

pkg/sqlreplay/replay/replay.go

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ func (r *replay) Start(cfg ReplayConfig, backendTLSConfig *tls.Config, hsHandler
360360
}
361361
r.report = cfg.report
362362
if r.report == nil {
363-
if cfg.DryRun {
364-
r.report = &mockReport{exceptionCh: r.exceptionCh}
363+
if cfg.DryRun || cfg.ReadOnly {
364+
r.report = &nopReport{exceptionCh: r.exceptionCh}
365365
} else {
366366
backendConnCreator := func() conn.BackendConn {
367367
return conn.NewBackendConn(r.lg.Named("be"), r.idMgr.NewID(), hsHandler, bcConfig, backendTLSConfig, r.cfg.Username, r.cfg.Password)
@@ -513,8 +513,9 @@ func (r *replay) readCommands(ctx context.Context) {
513513
zap.Int("alive_conns", connCount),
514514
zap.Time("last_cmd_start_ts", time.Unix(0, r.replayStats.CurCmdTs.Load())),
515515
zap.Time("last_cmd_end_ts", time.Unix(0, r.replayStats.CurCmdEndTs.Load())),
516-
zap.NamedError("ctx_err", ctx.Err()),
517-
zap.Bool("graceful_stop", r.gracefulStop.Load()))
516+
zap.Bool("graceful_stop", r.gracefulStop.Load()),
517+
zap.Error(err),
518+
zap.NamedError("ctx_err", ctx.Err()))
518519

519520
// Notify the connections that the commands are finished.
520521
for _, conn := range conns {
@@ -816,24 +817,22 @@ func (r *replay) saveCheckpointLoop(ctx context.Context) {
816817
}
817818
defer file.Close()
818819

819-
for {
820-
// Add an interval here to avoid printing too many logs when error occurs.
821-
if err != nil {
822-
time.Sleep(stateSaveRetryInterval)
823-
}
824-
820+
for ctx.Err() == nil {
825821
select {
826822
case <-ctx.Done():
827-
return
823+
break
828824
case <-ticker.C:
829825
err = r.saveCheckpointToFile(file)
830826
if err != nil {
831827
r.lg.Error("save current checkpoint failed", zap.Error(err))
828+
// Add an interval here to avoid printing too many logs when an error occurs.
832829
time.Sleep(stateSaveRetryInterval)
833-
continue
834830
}
835831
}
836832
}
833+
if err = r.saveCheckpointToFile(file); err != nil {
834+
r.lg.Error("save current state failed on close", zap.Error(err))
835+
}
837836
}
838837

839838
func (r *replay) saveCheckpointToFile(file *os.File) error {
@@ -855,16 +854,6 @@ func (r *replay) saveCheckpointToFile(file *os.File) error {
855854
return nil
856855
}
857856

858-
func (r *replay) saveCurrentStateToFilePath(filePath string) error {
859-
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644)
860-
if err != nil {
861-
return errors.Wrapf(err, "open state file %s", filePath)
862-
}
863-
defer file.Close()
864-
865-
return r.saveCheckpointToFile(file)
866-
}
867-
868857
func (r *replay) fetchCurrentCheckpoint() replayCheckpoint {
869858
return replayCheckpoint{
870859
CurCmdTs: r.replayStats.CurCmdTs.Load(),
@@ -923,6 +912,10 @@ func (r *replay) stop(err error) {
923912
r.cancel = nil
924913
}
925914
close(r.execInfoCh)
915+
if r.report != nil {
916+
r.report.Close()
917+
r.report = nil
918+
}
926919
r.endTime = time.Now()
927920
// decodedCmds - pendingCmds may be greater than replayedCmds because if a connection is closed unexpectedly,
928921
// the pending commands of that connection are discarded. We calculate the progress based on decodedCmds - pendingCmds.
@@ -1006,17 +999,6 @@ func (r *replay) Stop(err error, graceful bool) {
1006999

10071000
func (r *replay) Close() {
10081001
r.Stop(errors.New("shutting down"), false)
1009-
if r.report != nil {
1010-
r.report.Close()
1011-
}
1012-
// at this time, the save checkpoint loop and replay loop have exited. It's safe to update the latest
1013-
// checkpoint file.
1014-
if len(r.cfg.CheckPointFilePath) > 0 {
1015-
err := r.saveCurrentStateToFilePath(r.cfg.CheckPointFilePath)
1016-
if err != nil {
1017-
r.lg.Error("save current state failed on close", zap.Error(err))
1018-
}
1019-
}
10201002
}
10211003

10221004
func getDirForInput(input string) (string, error) {

pkg/util/lex/filter.go

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33

44
package lex
55

6+
import (
7+
"strings"
8+
9+
"github.com/pingcap/tidb/pkg/parser"
10+
)
11+
612
func startsWithKeyword(sql string, keywords [][]string) bool {
713
lexer := NewLexer(sql)
814
tokens := make([]string, 0, 2)
@@ -47,24 +53,51 @@ func IsSensitiveSQL(sql string) bool {
4753
// include SELECT FOR UPDATE because it doesn't require write privilege
4854
// include SET because SET SESSION_STATES and SET session variables should be executed
4955
// include BEGIN / COMMIT in case the user sets autocommit to false, either in SET SESSION_STATES or SET @@autocommit
50-
var readOnlyKeywords = [][]string{
51-
{"SELECT"},
52-
{"SHOW"},
53-
{"WITH"},
54-
{"SET"},
55-
{"USE"},
56-
{"DESC"},
57-
{"DESCRIBE"},
58-
{"TABLE"},
59-
{"DO"},
60-
{"BEGIN"},
61-
{"COMMIT"},
62-
{"ROLLBACK"},
63-
{"START", "TRANSACTION"},
64-
}
65-
6656
func IsReadOnly(sql string) bool {
67-
return startsWithKeyword(sql, readOnlyKeywords)
57+
lexer := NewLexer(sql)
58+
switch lexer.NextToken() {
59+
case "SELECT":
60+
for {
61+
token := lexer.NextToken()
62+
if token == "" {
63+
break
64+
}
65+
if token == "FOR" && lexer.NextToken() == "UPDATE" {
66+
return false
67+
}
68+
}
69+
return true
70+
case "SHOW", "WITH", "USE", "DESC", "DESCRIBE", "TABLE", "DO", "BEGIN", "COMMIT", "ROLLBACK":
71+
return true
72+
case "START":
73+
return lexer.NextToken() == "TRANSACTION"
74+
case "SET":
75+
// Filter `set global`, `set @@global.`, `set password`, and other unknown statements.
76+
normalized := parser.Normalize(sql, "ON")
77+
switch {
78+
case strings.HasPrefix(normalized, "set session_states "):
79+
return true
80+
case strings.HasPrefix(normalized, "set session "):
81+
return true
82+
case strings.HasPrefix(normalized, "set names "):
83+
return true
84+
case strings.HasPrefix(normalized, "set char "):
85+
return true
86+
case strings.HasPrefix(normalized, "set charset "):
87+
return true
88+
case strings.HasPrefix(normalized, "set character "):
89+
return true
90+
case strings.HasPrefix(normalized, "set transaction "):
91+
return true
92+
case strings.HasPrefix(normalized, "set @@global."):
93+
return false
94+
case strings.HasPrefix(normalized, "set @"):
95+
return true
96+
}
97+
return false
98+
99+
}
100+
return false
68101
}
69102

70103
var startTxnKeywords = [][]string{

pkg/util/lex/filter_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,17 @@ func TestReadOnlySQL(t *testing.T) {
3737
{`SELECT ? FROM table_name`, true},
3838
{`(select * from t1) union (select * from t2)`, true},
3939
{`WITH cte AS (SELECT 1, 2) SELECT * FROM cte t1, cte t2`, true},
40+
{`SELECT ? FROM table_name for update`, false},
41+
{`SELECT "for update"`, true},
4042
{`SET session_States ''`, true},
4143
{`SET @@session_variable=true`, true},
42-
{`set GLOBAL variable=false`, true},
44+
{`SET @@global.variable=true`, false},
45+
{`set GLOBAL variable=false`, false},
46+
{`set password = 'hello'`, false},
47+
{`set NAMES utf8`, true},
48+
{`set character utf8`, true},
49+
{`set transaction isolation_level = 'read committed`, true},
50+
{`SET @variable=true`, true},
4351
{`insert into table t value(1)`, false},
4452
{`desc table t`, true},
4553
{`describe select * from t`, true},

0 commit comments

Comments
 (0)