Skip to content

Commit a3dbf52

Browse files
committed
Add IDENTITY function call and IDENTITYCOL/ROWGUIDCOL support
- Add IdentityFunctionCall AST type for IDENTITY(data_type [, seed, increment]) - Add parsing for IDENTITY function calls - Add handling for IDENTITYCOL and ROWGUIDCOL column types in expressions - Handle multi-part identifiers with empty parts (e.g., master..t1.IDENTITYCOL) - Fix national string handling in column aliases (AS N'alias') - Add JSON marshaling for IdentityFunctionCall This enables the BaselinesCommon_SelectExpressionTests test.
1 parent d12e922 commit a3dbf52

4 files changed

Lines changed: 243 additions & 10 deletions

File tree

ast/function_call.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,13 @@ type TryConvertCall struct {
9797

9898
func (*TryConvertCall) node() {}
9999
func (*TryConvertCall) scalarExpression() {}
100+
101+
// IdentityFunctionCall represents an IDENTITY function call: IDENTITY(data_type [, seed, increment])
102+
type IdentityFunctionCall struct {
103+
DataType DataTypeReference `json:"DataType,omitempty"`
104+
Seed ScalarExpression `json:"Seed,omitempty"`
105+
Increment ScalarExpression `json:"Increment,omitempty"`
106+
}
107+
108+
func (*IdentityFunctionCall) node() {}
109+
func (*IdentityFunctionCall) scalarExpression() {}

parser/marshal.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,20 @@ func scalarExpressionToJSON(expr ast.ScalarExpression) jsonNode {
14881488
node["Collation"] = identifierToJSON(e.Collation)
14891489
}
14901490
return node
1491+
case *ast.IdentityFunctionCall:
1492+
node := jsonNode{
1493+
"$type": "IdentityFunctionCall",
1494+
}
1495+
if e.DataType != nil {
1496+
node["DataType"] = dataTypeReferenceToJSON(e.DataType)
1497+
}
1498+
if e.Seed != nil {
1499+
node["Seed"] = scalarExpressionToJSON(e.Seed)
1500+
}
1501+
if e.Increment != nil {
1502+
node["Increment"] = scalarExpressionToJSON(e.Increment)
1503+
}
1504+
return node
14911505
case *ast.BinaryExpression:
14921506
node := jsonNode{
14931507
"$type": "BinaryExpression",

parser/parse_select.go

Lines changed: 218 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,42 @@ func (p *Parser) parseSelectElement() (ast.SelectElement, error) {
418418
}
419419

420420
// Not an assignment, treat as regular scalar expression starting with variable
421-
// We need to "un-consume" the variable and let parseScalarExpression handle it
422-
// Create the variable reference and use it as the expression
423421
varRef := &ast.VariableReference{Name: varName}
424-
sse := &ast.SelectScalarExpression{Expression: varRef}
422+
423+
// Check if next token is a binary operator - if so, continue parsing the expression
424+
var expr ast.ScalarExpression = varRef
425+
for p.curTok.Type == TokenPlus || p.curTok.Type == TokenMinus ||
426+
p.curTok.Type == TokenStar || p.curTok.Type == TokenSlash ||
427+
p.curTok.Type == TokenPercent || p.curTok.Type == TokenDoublePipe {
428+
// We have a variable followed by a binary operator, continue parsing
429+
var opType string
430+
switch p.curTok.Type {
431+
case TokenPlus:
432+
opType = "Add"
433+
case TokenMinus:
434+
opType = "Subtract"
435+
case TokenStar:
436+
opType = "Multiply"
437+
case TokenSlash:
438+
opType = "Divide"
439+
case TokenPercent:
440+
opType = "Modulo"
441+
case TokenDoublePipe:
442+
opType = "Add" // String concatenation
443+
}
444+
p.nextToken() // consume operator
445+
right, err := p.parsePrimaryExpression()
446+
if err != nil {
447+
return nil, err
448+
}
449+
expr = &ast.BinaryExpression{
450+
FirstExpression: expr,
451+
SecondExpression: right,
452+
BinaryExpressionType: opType,
453+
}
454+
}
455+
456+
sse := &ast.SelectScalarExpression{Expression: expr}
425457

426458
// Check for column alias
427459
if p.curTok.Type == TokenIdent && p.curTok.Literal[0] == '[' {
@@ -512,6 +544,13 @@ func (p *Parser) parseSelectElement() (ast.SelectElement, error) {
512544
Value: str.Value,
513545
ValueExpression: str,
514546
}
547+
} else if p.curTok.Type == TokenNationalString {
548+
// National string literal alias: AS N'alias'
549+
str, _ := p.parseNationalStringFromToken()
550+
sse.ColumnName = &ast.IdentifierOrValueExpression{
551+
Value: str.Value,
552+
ValueExpression: str,
553+
}
515554
} else {
516555
alias := p.parseIdentifier()
517556
sse.ColumnName = &ast.IdentifierOrValueExpression{
@@ -735,6 +774,17 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) {
735774
if upper == "TRY_CONVERT" && p.peekTok.Type == TokenLParen {
736775
return p.parseTryConvertCall()
737776
}
777+
if upper == "IDENTITY" && p.peekTok.Type == TokenLParen {
778+
return p.parseIdentityFunctionCall()
779+
}
780+
if upper == "IDENTITYCOL" {
781+
p.nextToken()
782+
return &ast.ColumnReferenceExpression{ColumnType: "IdentityCol"}, nil
783+
}
784+
if upper == "ROWGUIDCOL" {
785+
p.nextToken()
786+
return &ast.ColumnReferenceExpression{ColumnType: "RowGuidCol"}, nil
787+
}
738788
return p.parseColumnReferenceOrFunctionCall()
739789
case TokenNumber:
740790
val := p.curTok.Literal
@@ -785,6 +835,13 @@ func (p *Parser) parsePrimaryExpression() (ast.ScalarExpression, error) {
785835
// Wildcard column reference (e.g., * in count(*))
786836
p.nextToken()
787837
return &ast.ColumnReferenceExpression{ColumnType: "Wildcard"}, nil
838+
case TokenDot:
839+
// Multi-part identifier starting with empty parts (e.g., ..t1.c1)
840+
return p.parseColumnReferenceWithLeadingDots()
841+
case TokenMaster, TokenDatabase, TokenKey, TokenTable, TokenIndex,
842+
TokenSchema, TokenUser, TokenView:
843+
// Keywords that can be used as identifiers in column/table references
844+
return p.parseColumnReferenceOrFunctionCall()
788845
default:
789846
return nil, fmt.Errorf("unexpected token in expression: %s", p.curTok.Literal)
790847
}
@@ -1059,20 +1116,43 @@ func (p *Parser) parseNationalStringFromToken() (*ast.StringLiteral, error) {
10591116
}, nil
10601117
}
10611118

1119+
func (p *Parser) isIdentifierToken() bool {
1120+
switch p.curTok.Type {
1121+
case TokenIdent, TokenMaster, TokenDatabase, TokenKey, TokenTable, TokenIndex,
1122+
TokenSchema, TokenUser, TokenView, TokenDefault:
1123+
return true
1124+
default:
1125+
return false
1126+
}
1127+
}
1128+
10621129
func (p *Parser) parseColumnReferenceOrFunctionCall() (ast.ScalarExpression, error) {
10631130
var identifiers []*ast.Identifier
1131+
colType := "Regular"
10641132

10651133
for {
1066-
if p.curTok.Type != TokenIdent {
1134+
if !p.isIdentifierToken() {
10671135
break
10681136
}
10691137

10701138
quoteType := "NotQuoted"
10711139
literal := p.curTok.Literal
1140+
upper := strings.ToUpper(literal)
1141+
10721142
// Handle bracketed identifiers
10731143
if len(literal) >= 2 && literal[0] == '[' && literal[len(literal)-1] == ']' {
10741144
quoteType = "SquareBracket"
10751145
literal = literal[1 : len(literal)-1]
1146+
} else if upper == "IDENTITYCOL" || upper == "ROWGUIDCOL" {
1147+
// IDENTITYCOL/ROWGUIDCOL at end of multi-part identifier sets column type
1148+
// and is not included in the identifier list
1149+
if upper == "IDENTITYCOL" {
1150+
colType = "IdentityCol"
1151+
} else {
1152+
colType = "RowGuidCol"
1153+
}
1154+
p.nextToken()
1155+
break
10761156
}
10771157

10781158
id := &ast.Identifier{
@@ -1091,6 +1171,12 @@ func (p *Parser) parseColumnReferenceOrFunctionCall() (ast.ScalarExpression, err
10911171
break
10921172
}
10931173
p.nextToken() // consume dot
1174+
1175+
// Handle consecutive dots (empty parts in multi-part identifier)
1176+
for p.curTok.Type == TokenDot {
1177+
identifiers = append(identifiers, &ast.Identifier{Value: "", QuoteType: "NotQuoted"})
1178+
p.nextToken() // consume dot
1179+
}
10941180
}
10951181

10961182
// Check for :: (user-defined type method call or property access): a.b::func() or a::prop
@@ -1169,12 +1255,21 @@ func (p *Parser) parseColumnReferenceOrFunctionCall() (ast.ScalarExpression, err
11691255
return p.parseFunctionCallFromIdentifiers(identifiers)
11701256
}
11711257

1258+
// If we have identifiers, build a column reference with them
1259+
if len(identifiers) > 0 {
1260+
return &ast.ColumnReferenceExpression{
1261+
ColumnType: colType,
1262+
MultiPartIdentifier: &ast.MultiPartIdentifier{
1263+
Count: len(identifiers),
1264+
Identifiers: identifiers,
1265+
},
1266+
}, nil
1267+
}
1268+
1269+
// No identifiers means just IDENTITYCOL or ROWGUIDCOL (already handled in parsePrimaryExpression)
1270+
// but handle the case anyway
11721271
return &ast.ColumnReferenceExpression{
1173-
ColumnType: "Regular",
1174-
MultiPartIdentifier: &ast.MultiPartIdentifier{
1175-
Count: len(identifiers),
1176-
Identifiers: identifiers,
1177-
},
1272+
ColumnType: colType,
11781273
}, nil
11791274
}
11801275

@@ -1190,6 +1285,71 @@ func (p *Parser) parseColumnReference() (*ast.ColumnReferenceExpression, error)
11901285
return nil, fmt.Errorf("expected column reference, got function call")
11911286
}
11921287

1288+
func (p *Parser) parseColumnReferenceWithLeadingDots() (ast.ScalarExpression, error) {
1289+
// Handle multi-part identifiers starting with dots like ..t1.c1 or .db..t1.c1
1290+
var identifiers []*ast.Identifier
1291+
1292+
// Add empty identifiers for leading dots
1293+
for p.curTok.Type == TokenDot {
1294+
identifiers = append(identifiers, &ast.Identifier{Value: "", QuoteType: "NotQuoted"})
1295+
p.nextToken() // consume dot
1296+
}
1297+
1298+
// Now parse the remaining identifiers
1299+
for p.isIdentifierToken() {
1300+
quoteType := "NotQuoted"
1301+
literal := p.curTok.Literal
1302+
// Handle special column types
1303+
upper := strings.ToUpper(literal)
1304+
if upper == "IDENTITYCOL" || upper == "ROWGUIDCOL" {
1305+
// Return with the proper column type
1306+
colType := "IdentityCol"
1307+
if upper == "ROWGUIDCOL" {
1308+
colType = "RowGuidCol"
1309+
}
1310+
p.nextToken()
1311+
return &ast.ColumnReferenceExpression{
1312+
ColumnType: colType,
1313+
MultiPartIdentifier: &ast.MultiPartIdentifier{
1314+
Count: len(identifiers),
1315+
Identifiers: identifiers,
1316+
},
1317+
}, nil
1318+
}
1319+
// Handle bracketed identifiers
1320+
if len(literal) >= 2 && literal[0] == '[' && literal[len(literal)-1] == ']' {
1321+
quoteType = "SquareBracket"
1322+
literal = literal[1 : len(literal)-1]
1323+
}
1324+
1325+
id := &ast.Identifier{
1326+
Value: literal,
1327+
QuoteType: quoteType,
1328+
}
1329+
identifiers = append(identifiers, id)
1330+
p.nextToken()
1331+
1332+
if p.curTok.Type != TokenDot {
1333+
break
1334+
}
1335+
// Check for qualified star
1336+
if p.peekTok.Type == TokenStar {
1337+
break
1338+
}
1339+
p.nextToken() // consume dot
1340+
}
1341+
1342+
// Don't consume .* here - let the caller (parseSelectElement) handle qualified stars
1343+
1344+
return &ast.ColumnReferenceExpression{
1345+
ColumnType: "Regular",
1346+
MultiPartIdentifier: &ast.MultiPartIdentifier{
1347+
Count: len(identifiers),
1348+
Identifiers: identifiers,
1349+
},
1350+
}, nil
1351+
}
1352+
11931353
func (p *Parser) parseFunctionCallFromIdentifiers(identifiers []*ast.Identifier) (ast.ScalarExpression, error) {
11941354
fc := &ast.FunctionCall{
11951355
UniqueRowFilter: "NotSpecified",
@@ -3098,6 +3258,55 @@ func (p *Parser) parseTryConvertCall() (ast.ScalarExpression, error) {
30983258
return convert, nil
30993259
}
31003260

3261+
// parseIdentityFunctionCall parses an IDENTITY function call: IDENTITY(data_type [, seed, increment])
3262+
func (p *Parser) parseIdentityFunctionCall() (ast.ScalarExpression, error) {
3263+
p.nextToken() // consume IDENTITY
3264+
if p.curTok.Type != TokenLParen {
3265+
return nil, fmt.Errorf("expected ( after IDENTITY, got %s", p.curTok.Literal)
3266+
}
3267+
p.nextToken() // consume (
3268+
3269+
// Parse the data type
3270+
dt, err := p.parseDataTypeReference()
3271+
if err != nil {
3272+
return nil, err
3273+
}
3274+
3275+
identity := &ast.IdentityFunctionCall{
3276+
DataType: dt,
3277+
}
3278+
3279+
// Check for optional seed and increment
3280+
if p.curTok.Type == TokenComma {
3281+
p.nextToken() // consume ,
3282+
seed, err := p.parseScalarExpression()
3283+
if err != nil {
3284+
return nil, err
3285+
}
3286+
identity.Seed = seed
3287+
3288+
// Expect comma before increment
3289+
if p.curTok.Type != TokenComma {
3290+
return nil, fmt.Errorf("expected , before increment in IDENTITY, got %s", p.curTok.Literal)
3291+
}
3292+
p.nextToken() // consume ,
3293+
3294+
increment, err := p.parseScalarExpression()
3295+
if err != nil {
3296+
return nil, err
3297+
}
3298+
identity.Increment = increment
3299+
}
3300+
3301+
// Expect )
3302+
if p.curTok.Type != TokenRParen {
3303+
return nil, fmt.Errorf("expected ) in IDENTITY, got %s", p.curTok.Literal)
3304+
}
3305+
p.nextToken() // consume )
3306+
3307+
return identity, nil
3308+
}
3309+
31013310
// parsePredictTableReference parses PREDICT(...) in FROM clause
31023311
// PREDICT(MODEL = expression, DATA = table AS alias, RUNTIME=ident) WITH (columns) AS alias
31033312
func (p *Parser) parsePredictTableReference() (*ast.PredictTableReference, error) {
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"todo": true}
1+
{}

0 commit comments

Comments
 (0)