Skip to content

Commit 4545f7a

Browse files
authored
Fix Oracle SQL*Plus command splitting (#138)
fix oracle sqlplus command splitting
1 parent 757a020 commit 4545f7a

2 files changed

Lines changed: 267 additions & 18 deletions

File tree

oracle/parser/split.go

Lines changed: 118 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ func Split(sql string) []Segment {
6767

6868
if !state.inPLSQL {
6969
if cmd, ok := sqlPlusCommandAtLineStart(sql, tok); ok {
70+
prefixEmpty := onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc)
7071
if cmd.flush {
71-
if onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) {
72+
if prefixEmpty {
7273
if tok.Type == '/' && len(segments) > 0 {
7374
stmtStart = lineEndBeforeBreak(sql, tok.End)
7475
} else {
@@ -78,20 +79,25 @@ func Split(sql string) []Segment {
7879
segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc))
7980
stmtStart = lineEndBeforeBreak(sql, tok.End)
8081
}
82+
lexer.pos = lineEndAfterBreak(sql, tok.End)
83+
state.reset()
84+
continue
8185
} else {
82-
lineEnd := lineEndBeforeBreak(sql, tok.End)
83-
nextStart := lineEndAfterBreak(sql, tok.End)
84-
commandStart := stmtStart
85-
if !onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) {
86-
segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc))
87-
commandStart = lineStartOffset(sql, tok.Loc)
86+
if prefixEmpty || cmd.terminatesBufferedSQL {
87+
lineEnd := lineEndBeforeBreak(sql, tok.End)
88+
nextStart := lineEndAfterBreak(sql, tok.End)
89+
commandStart := stmtStart
90+
if !prefixEmpty {
91+
segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc))
92+
commandStart = lineStartOffset(sql, tok.Loc)
93+
}
94+
segments = appendSegmentWithKind(segments, sql, commandStart, lineEnd, SegmentSQLPlusCommand)
95+
stmtStart = nextStart
96+
lexer.pos = lineEndAfterBreak(sql, tok.End)
97+
state.reset()
98+
continue
8899
}
89-
segments = appendSegmentWithKind(segments, sql, commandStart, lineEnd, SegmentSQLPlusCommand)
90-
stmtStart = nextStart
91100
}
92-
lexer.pos = lineEndAfterBreak(sql, tok.End)
93-
state.reset()
94-
continue
95101
}
96102
}
97103

@@ -431,7 +437,8 @@ func (s *splitState) canStartNestedSubprogram(tok Token) bool {
431437
}
432438

433439
// sqlPlusCommand describes how Split should treat a line that
// sqlPlusCommandAtLineStart recognised as a SQL*Plus command.
type sqlPlusCommand struct {
434-
flush bool
440+
// flush is set for the slash delimiter line and for the words accepted by
// isSQLPlusFlushCommand. Split handles these in a separate branch that
// closes the text buffered before the command.
flush bool
441+
// terminatesBufferedSQL matters for non-flush commands that appear after
// non-ignorable statement text: Split only ends the buffered statement and
// emits a SegmentSQLPlusCommand when this is true (or when nothing but
// ignorable text precedes the command).
terminatesBufferedSQL bool
435442
}
436443

437444
func sqlPlusCommandAtLineStart(sql string, tok Token) (sqlPlusCommand, bool) {
@@ -442,21 +449,24 @@ func sqlPlusCommandAtLineStart(sql string, tok Token) (sqlPlusCommand, bool) {
442449
}
443450

444451
if tok.Type == '/' && isSlashDelimiterLine(sql, tok.Loc, tok.End) {
445-
return sqlPlusCommand{flush: true}, true
452+
return sqlPlusCommand{flush: true, terminatesBufferedSQL: true}, true
446453
}
447454
if tok.Type == '@' || tok.Type == '!' {
448-
return sqlPlusCommand{}, true
455+
return sqlPlusCommand{terminatesBufferedSQL: true}, true
449456
}
450457

451458
word := splitTokenWord(tok)
452459
if word == "" {
453460
return sqlPlusCommand{}, false
454461
}
462+
if isOracleSetStatement(word, sql, tok.End) {
463+
return sqlPlusCommand{}, false
464+
}
455465
if isSQLPlusFlushCommand(word) {
456-
return sqlPlusCommand{flush: true}, true
466+
return sqlPlusCommand{flush: true, terminatesBufferedSQL: true}, true
457467
}
458468
if isSQLPlusLineCommand(word) {
459-
return sqlPlusCommand{}, true
469+
return sqlPlusCommand{terminatesBufferedSQL: isSQLPlusLineCommandThatTerminatesBufferedSQL(word, sql, tok.End)}, true
460470
}
461471
return sqlPlusCommand{}, false
462472
}
@@ -468,6 +478,44 @@ func splitTokenWord(tok Token) string {
468478
return ""
469479
}
470480

481+
func isOracleSetStatement(word, sql string, pos int) bool {
482+
if word != "SET" {
483+
return false
484+
}
485+
next := nextWordOnLine(sql, pos)
486+
switch next {
487+
case "TRANSACTION", "ROLE", "CONSTRAINT", "CONSTRAINTS":
488+
return true
489+
default:
490+
return false
491+
}
492+
}
493+
494+
func nextWordOnLine(sql string, pos int) string {
495+
pos = skipHorizontalSpace(sql, pos)
496+
start := pos
497+
for pos < len(sql) {
498+
c := sql[pos]
499+
if c == '\n' || c == '\r' || !(c == '_' || c >= '0' && c <= '9' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z') {
500+
break
501+
}
502+
pos++
503+
}
504+
if pos == start {
505+
return ""
506+
}
507+
508+
buf := make([]byte, pos-start)
509+
for i := start; i < pos; i++ {
510+
c := sql[i]
511+
if c >= 'a' && c <= 'z' {
512+
c -= 'a' - 'A'
513+
}
514+
buf[i-start] = c
515+
}
516+
return string(buf)
517+
}
518+
471519
func isSQLPlusFlushCommand(word string) bool {
472520
switch word {
473521
case "RUN", "R":
@@ -505,6 +553,59 @@ func isSQLPlusLineCommand(word string) bool {
505553
}
506554
}
507555

556+
func isSQLPlusLineCommandThatTerminatesBufferedSQL(word, sql string, pos int) bool {
557+
switch word {
558+
case "ACC", "ACCEPT",
559+
"BRE", "BREAK", "BTI", "BTITLE",
560+
"COL", "COLUMN", "COMP", "COMPUTE",
561+
"DEF", "DEFINE",
562+
"HO", "HOST",
563+
"PRI", "PRINT", "PRO", "PROMPT",
564+
"REM", "REMARK",
565+
"SHO", "SHOW", "SPO", "SPOOL",
566+
"TTI", "TTITLE",
567+
"UNDEF", "UNDEFINE",
568+
"VAR", "VARIABLE",
569+
"WHENEVER":
570+
return true
571+
case "CONN", "CONNECT":
572+
return isSQLPlusConnectCommandThatTerminatesBufferedSQL(sql, pos)
573+
case "STA", "START":
574+
return nextWordOnLine(sql, pos) != "WITH"
575+
case "SET":
576+
return isSQLPlusSetCommandThatTerminatesBufferedSQL(sql, pos)
577+
default:
578+
return false
579+
}
580+
}
581+
582+
func isSQLPlusConnectCommandThatTerminatesBufferedSQL(sql string, pos int) bool {
583+
switch nextWordOnLine(sql, pos) {
584+
case "BY", "TO":
585+
return false
586+
default:
587+
return true
588+
}
589+
}
590+
591+
func isSQLPlusSetCommandThatTerminatesBufferedSQL(sql string, pos int) bool {
592+
switch nextWordOnLine(sql, pos) {
593+
case "APPINFO", "ARRAYSIZE", "AUTOCOMMIT", "AUTOPRINT", "AUTORECOVERY", "AUTOTRACE",
594+
"BLOCKTERMINATOR", "CMDSEP", "COLINVISIBLE", "COLSEP", "CONCAT", "COPYCOMMIT",
595+
"COPYTYPECHECK", "DEF", "DEFINE", "DESCRIBE", "ECHO", "EDITFILE", "EMBEDDED", "ERRORLOGGING",
596+
"ESCAPE", "ESCCHAR", "EXITCOMMIT", "FEEDBACK", "FLAGGER", "FLUSH", "HEADING",
597+
"HEADSEP", "INSTANCE", "LINESIZE", "LOBOFFSET", "LOGSOURCE", "LONG", "LONGCHUNKSIZE",
598+
"MARKUP", "NEWPAGE", "NULL", "NUMFORMAT", "NUMWIDTH", "PAGESIZE", "PAUSE",
599+
"RECSEP", "RECSEPCHAR", "SCAN", "SERVEROUT", "SERVEROUTPUT", "SHIFTINOUT", "SHOWMODE", "SQLBLANKLINES",
600+
"SQLCASE", "SQLCONTINUE", "SQLNUMBER", "SQLPLUSCOMPATIBILITY", "SQLPREFIX",
601+
"SQLPROMPT", "SQLTERMINATOR", "SUFFIX", "TAB", "TERMOUT", "TIME", "TIMING",
602+
"TRIMOUT", "TRIMSPOOL", "UNDERLINE", "VERIFY", "WRAP":
603+
return true
604+
default:
605+
return false
606+
}
607+
}
608+
508609
func onlyIgnorableSQLPlusPrefix(sql string, start, end int) bool {
509610
seg := Segment{Text: sql[start:end], ByteStart: start, ByteEnd: end}
510611
return seg.Empty()

oracle/parser/split_test.go

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package parser
22

3-
import "testing"
3+
import (
4+
"os"
5+
"path/filepath"
6+
"testing"
7+
)
48

59
func TestSplitOrdinarySQL(t *testing.T) {
610
tests := []struct {
@@ -616,6 +620,150 @@ func TestSplitClassifiesSQLPlusCommands(t *testing.T) {
616620
}
617621
}
618622

623+
func TestSplitClassifiesOracleSetStatementsAsSQL(t *testing.T) {
624+
sql := "SET DEFINE OFF\n" +
625+
"SET TRANSACTION READ ONLY;\n" +
626+
"SET ROLE app_role;\n" +
627+
"SET CONSTRAINTS ALL IMMEDIATE;"
628+
got := Split(sql)
629+
wantTexts := []string{
630+
"SET DEFINE OFF",
631+
"SET TRANSACTION READ ONLY",
632+
"\nSET ROLE app_role",
633+
"\nSET CONSTRAINTS ALL IMMEDIATE",
634+
}
635+
wantKinds := []SegmentKind{
636+
SegmentSQLPlusCommand,
637+
SegmentSQL,
638+
SegmentSQL,
639+
SegmentSQL,
640+
}
641+
if len(got) != len(wantKinds) {
642+
t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantKinds))
643+
}
644+
for i := range wantKinds {
645+
if got[i].Text != wantTexts[i] {
646+
t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i])
647+
}
648+
if got[i].Kind != wantKinds[i] {
649+
t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, wantKinds[i])
650+
}
651+
}
652+
}
653+
654+
func TestSplitDoesNotClassifySQLContinuationLinesAsSQLPlus(t *testing.T) {
655+
sql := "SELECT employee_id\n" +
656+
"FROM employees\n" +
657+
"START WITH manager_id IS NULL\n" +
658+
"CONNECT BY PRIOR employee_id = manager_id;\n" +
659+
"CREATE DATABASE LINK remote_db\n" +
660+
"CONNECT TO remote_user IDENTIFIED BY remote_pass\n" +
661+
"USING 'remote_tns';\n" +
662+
"CREATE DATABASE mydb\n" +
663+
"SET DEFAULT BIGFILE TABLESPACE;"
664+
got := Split(sql)
665+
wantTexts := []string{
666+
"SELECT employee_id\nFROM employees\nSTART WITH manager_id IS NULL\nCONNECT BY PRIOR employee_id = manager_id",
667+
"\nCREATE DATABASE LINK remote_db\nCONNECT TO remote_user IDENTIFIED BY remote_pass\nUSING 'remote_tns'",
668+
"\nCREATE DATABASE mydb\nSET DEFAULT BIGFILE TABLESPACE",
669+
}
670+
if len(got) != len(wantTexts) {
671+
t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantTexts))
672+
}
673+
for i := range wantTexts {
674+
if got[i].Text != wantTexts[i] {
675+
t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i])
676+
}
677+
if got[i].Kind != SegmentSQL {
678+
t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, SegmentSQL)
679+
}
680+
}
681+
}
682+
683+
func TestSplitClassifiesUnambiguousSQLPlusCommandsAfterBufferedSQL(t *testing.T) {
684+
sql := "SELECT 1 FROM dual\n" +
685+
"PROMPT running next query\n" +
686+
"SPOOL install.log\n" +
687+
"SELECT 2 FROM dual\n" +
688+
"SET DEFINE OFF\n" +
689+
"SELECT 3 FROM dual\n" +
690+
"SET DEF OFF\n" +
691+
"SELECT 4 FROM dual\n" +
692+
"SET SERVEROUT ON\n" +
693+
"SELECT 3 FROM dual\n" +
694+
"CONNECT scott/tiger@db\n" +
695+
"SELECT 2 FROM dual;"
696+
got := Split(sql)
697+
wantTexts := []string{
698+
"SELECT 1 FROM dual",
699+
"PROMPT running next query",
700+
"SPOOL install.log",
701+
"SELECT 2 FROM dual",
702+
"SET DEFINE OFF",
703+
"SELECT 3 FROM dual",
704+
"SET DEF OFF",
705+
"SELECT 4 FROM dual",
706+
"SET SERVEROUT ON",
707+
"SELECT 3 FROM dual",
708+
"CONNECT scott/tiger@db",
709+
"SELECT 2 FROM dual",
710+
}
711+
wantKinds := []SegmentKind{
712+
SegmentSQL,
713+
SegmentSQLPlusCommand,
714+
SegmentSQLPlusCommand,
715+
SegmentSQL,
716+
SegmentSQLPlusCommand,
717+
SegmentSQL,
718+
SegmentSQLPlusCommand,
719+
SegmentSQL,
720+
SegmentSQLPlusCommand,
721+
SegmentSQL,
722+
SegmentSQLPlusCommand,
723+
SegmentSQL,
724+
}
725+
if len(got) != len(wantTexts) {
726+
t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantTexts))
727+
}
728+
for i := range wantTexts {
729+
if got[i].Text != wantTexts[i] {
730+
t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i])
731+
}
732+
if got[i].Kind != wantKinds[i] {
733+
t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, wantKinds[i])
734+
}
735+
}
736+
}
737+
738+
func TestSplitDoesNotClassifyValidCorpusStatementsAsSQLPlus(t *testing.T) {
739+
corpusDir := filepath.Join("..", "quality", "corpus")
740+
entries, err := os.ReadDir(corpusDir)
741+
if err != nil {
742+
corpusDir = filepath.Join("oracle", "quality", "corpus")
743+
entries, err = os.ReadDir(corpusDir)
744+
if err != nil {
745+
t.Fatalf("cannot read corpus directory: %v", err)
746+
}
747+
}
748+
749+
for _, entry := range entries {
750+
if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" {
751+
continue
752+
}
753+
path := filepath.Join(corpusDir, entry.Name())
754+
for _, stmt := range loadCorpusFile(t, path) {
755+
if stmt.valid != "true" {
756+
continue
757+
}
758+
for _, seg := range Split(stmt.sql) {
759+
if seg.Kind == SegmentSQLPlusCommand {
760+
t.Fatalf("%s/%s classified valid SQL as SQL*Plus command: %q", entry.Name(), stmt.name, seg.Text)
761+
}
762+
}
763+
}
764+
}
765+
}
766+
619767
func splitTexts(segs []Segment) []string {
620768
if len(segs) == 0 {
621769
return nil

0 commit comments

Comments
 (0)