Skip to content

Commit 845ca69

Browse files
authored
Inline parser driver into ast package (#13)
1 parent 7914c6b commit 845ca69

17 files changed

Lines changed: 83 additions & 147 deletions

BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ go_test(
6060
"//pkg/parser/mysql",
6161
"//pkg/parser/opcode",
6262
"//pkg/parser/terror",
63-
"//pkg/parser/test_driver",
6463
"@com_github_pingcap_errors//:errors",
6564
"@com_github_stretchr_testify//require",
6665
"@org_uber_go_goleak//:goleak",

ast/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@ go_library(
55
srcs = [
66
"ast.go",
77
"base.go",
8+
"datum.go",
9+
"datum_helper.go",
810
"ddl.go",
911
"dml.go",
1012
"expressions.go",
1113
"flag.go",
1214
"functions.go",
1315
"misc.go",
1416
"model.go",
17+
"mydecimal.go",
1518
"procedure.go",
1619
"sem.go",
1720
"stats.go",
1821
"util.go",
22+
"value_expr.go",
1923
],
2024
importpath = "github.com/sqlc-dev/marino/ast",
2125
visibility = ["//visibility:public"],
@@ -60,7 +64,6 @@ go_test(
6064
"//pkg/parser/charset",
6165
"//pkg/parser/format",
6266
"//pkg/parser/mysql",
63-
"//pkg/parser/test_driver",
6467
"@com_github_stretchr_testify//require",
6568
],
6669
)

ast/base.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,6 @@ type exprNode struct {
410410
flag uint64
411411
}
412412

413-
// TexprNode is exported for parser driver.
414-
type TexprNode = exprNode
415-
416413
// SetType implements ExprNode interface.
417414
func (en *exprNode) SetType(tp *types.FieldType) {
418415
en.Type = *tp
Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
//go:build !codes
1515

16-
package test_driver
16+
package ast
1717

1818
import (
1919
"bytes"
@@ -155,12 +155,12 @@ func (d *Datum) SetNull() {
155155
}
156156

157157
// GetBinaryLiteral gets Bit value
158-
func (d *Datum) GetBinaryLiteral() BinaryLiteral {
158+
func (d *Datum) GetBinaryLiteral() BinaryLit {
159159
return d.b
160160
}
161161

162162
// SetBinaryLiteral sets Bit value
163-
func (d *Datum) SetBinaryLiteral(b BinaryLiteral) {
163+
func (d *Datum) SetBinaryLiteral(b BinaryLit) {
164164
d.k = KindBinaryLiteral
165165
d.b = b
166166
}
@@ -227,12 +227,12 @@ func (d *Datum) SetValue(val any) {
227227
d.SetBytes(x)
228228
case *MyDecimal:
229229
d.SetMysqlDecimal(x)
230-
case BinaryLiteral:
230+
case BinaryLit:
231231
d.SetBinaryLiteral(x)
232-
case BitLiteral: // Store as BinaryLiteral for Bit and Hex literals
233-
d.SetBinaryLiteral(BinaryLiteral(x))
232+
case BitLiteral: // Store as BinaryLit for Bit and Hex literals
233+
d.SetBinaryLiteral(BinaryLit(x))
234234
case HexLiteral:
235-
d.SetBinaryLiteral(BinaryLiteral(x))
235+
d.SetBinaryLiteral(BinaryLit(x))
236236
default:
237237
d.SetInterface(x)
238238
}
@@ -270,33 +270,33 @@ func MakeDatums(args ...any) []Datum {
270270
return datums
271271
}
272272

273-
// BinaryLiteral is the internal type for storing bit / hex literal type.
274-
type BinaryLiteral []byte
273+
// BinaryLit is the internal type for storing bit / hex literal type.
274+
type BinaryLit []byte
275275

276276
// BitLiteral is the bit literal type.
277-
type BitLiteral BinaryLiteral
277+
type BitLiteral BinaryLit
278278

279279
// HexLiteral is the hex literal type.
280-
type HexLiteral BinaryLiteral
280+
type HexLiteral BinaryLit
281281

282-
// ZeroBinaryLiteral is a BinaryLiteral literal with zero value.
283-
var ZeroBinaryLiteral = BinaryLiteral{}
282+
// ZeroBinaryLit is a BinaryLit literal with zero value.
283+
var ZeroBinaryLit = BinaryLit{}
284284

285285
// String implements fmt.Stringer interface.
286-
func (b BinaryLiteral) String() string {
286+
func (b BinaryLit) String() string {
287287
if len(b) == 0 {
288288
return ""
289289
}
290290
return "0x" + hex.EncodeToString(b)
291291
}
292292

293293
// ToString returns the string representation for the literal.
294-
func (b BinaryLiteral) ToString() string {
294+
func (b BinaryLit) ToString() string {
295295
return string(b)
296296
}
297297

298298
// ToBitLiteralString returns the bit literal representation for the literal.
299-
func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string {
299+
func (b BinaryLit) ToBitLiteralString(trimLeadingZero bool) string {
300300
if len(b) == 0 {
301301
return "b''"
302302
}
@@ -317,7 +317,7 @@ func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string {
317317
// ParseBitStr parses bit string.
318318
// The string format can be b'val', B'val' or 0bval, val must be 0 or 1.
319319
// See https://dev.mysql.com/doc/refman/5.7/en/bit-value-literals.html
320-
func ParseBitStr(s string) (BinaryLiteral, error) {
320+
func ParseBitStr(s string) (BinaryLit, error) {
321321
if len(s) == 0 {
322322
return nil, fmt.Errorf("invalid empty string for parsing bit type")
323323
}
@@ -333,7 +333,7 @@ func ParseBitStr(s string) (BinaryLiteral, error) {
333333
}
334334

335335
if len(s) == 0 {
336-
return ZeroBinaryLiteral, nil
336+
return ZeroBinaryLit, nil
337337
}
338338

339339
alignedLength := (len(s) + 7) &^ 7
@@ -362,14 +362,14 @@ func NewBitLiteral(s string) (BitLiteral, error) {
362362
return BitLiteral(b), nil
363363
}
364364

365-
// ToString implement ast.BinaryLiteral interface
365+
// ToString implement BinaryLiteral interface
366366
func (b BitLiteral) ToString() string {
367-
return BinaryLiteral(b).ToString()
367+
return BinaryLit(b).ToString()
368368
}
369369

370370
// ParseHexStr parses hexadecimal string literal.
371371
// See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html
372-
func ParseHexStr(s string) (BinaryLiteral, error) {
372+
func ParseHexStr(s string) (BinaryLit, error) {
373373
if len(s) == 0 {
374374
return nil, fmt.Errorf("invalid empty string for parsing hexadecimal literal")
375375
}
@@ -388,7 +388,7 @@ func ParseHexStr(s string) (BinaryLiteral, error) {
388388
}
389389

390390
if len(s) == 0 {
391-
return ZeroBinaryLiteral, nil
391+
return ZeroBinaryLit, nil
392392
}
393393

394394
if len(s)%2 != 0 {
@@ -410,9 +410,9 @@ func NewHexLiteral(s string) (HexLiteral, error) {
410410
return HexLiteral(h), nil
411411
}
412412

413-
// ToString implement ast.BinaryLiteral interface
413+
// ToString implement BinaryLiteral interface
414414
func (b HexLiteral) ToString() string {
415-
return BinaryLiteral(b).ToString()
415+
return BinaryLit(b).ToString()
416416
}
417417

418418
// SetBinChsClnFlag sets charset, collation as 'binary' and adds binaryFlag to FieldType.
@@ -491,7 +491,7 @@ func DefaultTypeForValue(value any, tp *types.FieldType, charset string, collate
491491
tp.SetDecimal(0)
492492
tp.AddFlag(mysql.UnsignedFlag)
493493
SetBinChsClnFlag(tp)
494-
case BinaryLiteral:
494+
case BinaryLit:
495495
tp.SetType(mysql.TypeBit)
496496
tp.SetFlen(len(x) * 8)
497497
tp.SetDecimal(0)
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
//go:build !codes
1515

16-
package test_driver
16+
package ast
1717

1818
import (
1919
"math"
@@ -38,7 +38,7 @@ func pow10(x int) int32 {
3838
return int32(math.Pow10(x))
3939
}
4040

41-
func Abs(n int64) int64 {
41+
func absInt64(n int64) int64 {
4242
y := n >> 63
4343
return (n ^ y) - y
4444
}
@@ -68,5 +68,5 @@ func StrLenOfInt64Fast(x int64) int {
6868
if x < 0 {
6969
size = 1 // add "-" sign on the length count
7070
}
71-
return size + StrLenOfUint64Fast(uint64(Abs(x)))
71+
return size + StrLenOfUint64Fast(uint64(absInt64(x)))
7272
}

ast/expressions.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,6 @@ type ValueExpr interface {
6464
SetProjectionOffset(offset int)
6565
}
6666

67-
// NewValueExpr creates a ValueExpr with value, and sets default field type.
68-
var NewValueExpr func(value any, charset string, collate string) ValueExpr
69-
70-
// NewParamMarkerExpr creates a ParamMarkerExpr.
71-
var NewParamMarkerExpr func(offset int) ParamMarkerExpr
72-
7367
// BetweenExpr is for "between and" or "not between and" expression.
7468
type BetweenExpr struct {
7569
exprNode

ast/functions_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"github.com/sqlc-dev/marino/format"
2222
"github.com/sqlc-dev/marino/mysql"
2323
"github.com/sqlc-dev/marino/parser"
24-
"github.com/sqlc-dev/marino/test_driver"
2524

2625
"reflect"
2726
)
@@ -182,7 +181,7 @@ func TestConvert(t *testing.T) {
182181

183182
st := stmt.(*SelectStmt)
184183
expr := st.Fields.Fields[0].Expr.(*FuncCallExpr)
185-
charsetArg := expr.Args[1].(*test_driver.ValueExpr)
184+
charsetArg := expr.Args[1].(*ValueExprBase)
186185
if !reflect.DeepEqual(testCase.CharsetName, charsetArg.GetString()) {
187186
t.Fatalf("got %v, want %v", charsetArg.GetString(), testCase.CharsetName)
188187
}
@@ -217,7 +216,7 @@ func TestChar(t *testing.T) {
217216

218217
st := stmt.(*SelectStmt)
219218
expr := st.Fields.Fields[0].Expr.(*FuncCallExpr)
220-
charsetArg := expr.Args[1].(*test_driver.ValueExpr)
219+
charsetArg := expr.Args[1].(*ValueExprBase)
221220
if !reflect.DeepEqual(testCase.CharsetName, charsetArg.GetString()) {
222221
t.Fatalf("got %v, want %v", charsetArg.GetString(), testCase.CharsetName)
223222
}

ast/misc.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4226,19 +4226,12 @@ type TextString struct {
42264226
IsBinaryLiteral bool
42274227
}
42284228

4229+
// BinaryLiteral abstracts over the concrete bit/hex literal types so the
4230+
// parser can stringify any of them without a type switch.
42294231
type BinaryLiteral interface {
42304232
ToString() string
42314233
}
42324234

4233-
// NewDecimal creates a types.Decimal value, it's provided by parser driver.
4234-
var NewDecimal func(string) (any, error)
4235-
4236-
// NewHexLiteral creates a types.HexLiteral value, it's provided by parser driver.
4237-
var NewHexLiteral func(string) (any, error)
4238-
4239-
// NewBitLiteral creates a types.BitLiteral value, it's provided by parser driver.
4240-
var NewBitLiteral func(string) (any, error)
4241-
42424235
// SetResourceGroupStmt is a statement to set the resource group name for current session.
42434236
type SetResourceGroupStmt struct {
42444237
stmtNode
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
//go:build !codes
1515

16-
package test_driver
16+
package ast
1717

1818
const panicInfo = "This branch is not implemented. " +
1919
"This is because you are trying to test something specific to TiDB's MyDecimal implementation. " +

ast/util_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
. "github.com/sqlc-dev/marino/format"
2323
"github.com/sqlc-dev/marino/mysql"
2424
"github.com/sqlc-dev/marino/parser"
25-
"github.com/sqlc-dev/marino/test_driver"
2625

2726
"reflect"
2827
)
@@ -194,7 +193,7 @@ func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) {
194193
node.FnName.O = strings.ToLower(node.FnName.O)
195194
switch node.FnName.L {
196195
case "convert":
197-
node.Args[1].(*test_driver.ValueExpr).Datum.SetBytes(nil)
196+
node.Args[1].(*ValueExprBase).Datum.SetBytes(nil)
198197
}
199198
case *AggregateFuncExpr:
200199
node.F = strings.ToLower(node.F)

0 commit comments

Comments
 (0)