diff --git a/ast/alter_server_configuration_statement.go b/ast/alter_server_configuration_statement.go index 4d6bfa82..385819e5 100644 --- a/ast/alter_server_configuration_statement.go +++ b/ast/alter_server_configuration_statement.go @@ -1,5 +1,22 @@ package ast +// AlterServerConfigurationStatement represents ALTER SERVER CONFIGURATION SET PROCESS AFFINITY statement +type AlterServerConfigurationStatement struct { + ProcessAffinity string // "CpuAuto", "Cpu", "NumaNode" + ProcessAffinityRanges []*ProcessAffinityRange // for Cpu or NumaNode +} + +func (a *AlterServerConfigurationStatement) node() {} +func (a *AlterServerConfigurationStatement) statement() {} + +// ProcessAffinityRange represents a CPU or NUMA node range +type ProcessAffinityRange struct { + From ScalarExpression // IntegerLiteral + To ScalarExpression // IntegerLiteral (optional) +} + +func (p *ProcessAffinityRange) node() {} + // AlterServerConfigurationSetSoftNumaStatement represents ALTER SERVER CONFIGURATION SET SOFTNUMA statement type AlterServerConfigurationSetSoftNumaStatement struct { Options []*AlterServerConfigurationSoftNumaOption diff --git a/ast/create_table_statement.go b/ast/create_table_statement.go index 1a7b1a06..7801d56a 100644 --- a/ast/create_table_statement.go +++ b/ast/create_table_statement.go @@ -83,11 +83,12 @@ type TableConstraint interface { // IndexDefinition represents an index definition within CREATE TABLE type IndexDefinition struct { - Name *Identifier - Columns []*ColumnWithSortOrder - Unique bool - IndexType *IndexType - IndexOptions []*IndexExpressionOption + Name *Identifier + Columns []*ColumnWithSortOrder + Unique bool + IndexType *IndexType + IndexOptions []*IndexExpressionOption + IncludeColumns []*ColumnReferenceExpression } func (i *IndexDefinition) node() {} @@ -108,3 +109,41 @@ const ( SortOrderAscending SortOrderDescending ) + +// CheckConstraintDefinition represents a CHECK constraint +type CheckConstraintDefinition struct { + ConstraintIdentifier *Identifier + CheckCondition BooleanExpression + NotForReplication bool +} + +func (c *CheckConstraintDefinition) node() {} +func (c *CheckConstraintDefinition) tableConstraint() {} +func (c *CheckConstraintDefinition) constraintDefinition() {} + +// UniqueConstraintDefinition represents a UNIQUE or PRIMARY KEY constraint +type UniqueConstraintDefinition struct { + ConstraintIdentifier *Identifier + Clustered bool + IsPrimaryKey bool + Columns []*ColumnWithSortOrder + IndexType *IndexType +} + +func (u *UniqueConstraintDefinition) node() {} +func (u *UniqueConstraintDefinition) tableConstraint() {} +func (u *UniqueConstraintDefinition) constraintDefinition() {} + +// ForeignKeyConstraintDefinition represents a FOREIGN KEY constraint +type ForeignKeyConstraintDefinition struct { + ConstraintIdentifier *Identifier + Columns []*Identifier + ReferenceTableName *SchemaObjectName + ReferencedColumns []*Identifier + DeleteAction string + UpdateAction string + NotForReplication bool +} + +func (f *ForeignKeyConstraintDefinition) node() {} +func (f *ForeignKeyConstraintDefinition) tableConstraint() {} diff --git a/ast/declare_variable_statement.go b/ast/declare_variable_statement.go index f0cb3242..b13a4830 100644 --- a/ast/declare_variable_statement.go +++ b/ast/declare_variable_statement.go @@ -16,6 +16,23 @@ type DeclareVariableElement struct { Nullable *NullableConstraintDefinition `json:"Nullable,omitempty"` } +// DeclareTableVariableStatement represents a DECLARE @var TABLE statement. +type DeclareTableVariableStatement struct { + Body *DeclareTableVariableBody `json:"Body,omitempty"` +} + +func (d *DeclareTableVariableStatement) node() {} +func (d *DeclareTableVariableStatement) statement() {} + +// DeclareTableVariableBody represents the body of a table variable declaration. +type DeclareTableVariableBody struct { + VariableName *Identifier `json:"VariableName,omitempty"` + AsDefined bool `json:"AsDefined,omitempty"` + Definition *TableDefinition `json:"Definition,omitempty"` +} + +func (d *DeclareTableVariableBody) node() {} + // SqlDataTypeReference represents a SQL data type. type SqlDataTypeReference struct { SqlDataTypeOption string `json:"SqlDataTypeOption,omitempty"` diff --git a/ast/drop_statements.go b/ast/drop_statements.go index 5a8cd00e..7c9781c5 100644 --- a/ast/drop_statements.go +++ b/ast/drop_statements.go @@ -152,6 +152,15 @@ type DropExternalResourcePoolStatement struct { func (s *DropExternalResourcePoolStatement) statement() {} func (s *DropExternalResourcePoolStatement) node() {} +// DropExternalModelStatement represents a DROP EXTERNAL MODEL statement +type DropExternalModelStatement struct { + IsIfExists bool + Name *SchemaObjectName +} + +func (s *DropExternalModelStatement) statement() {} +func (s *DropExternalModelStatement) node() {} + // DropWorkloadGroupStatement represents a DROP WORKLOAD GROUP statement type DropWorkloadGroupStatement struct { IsIfExists bool diff --git a/ast/set_variable_statement.go b/ast/set_variable_statement.go index f4935341..996c8f57 100644 --- a/ast/set_variable_statement.go +++ b/ast/set_variable_statement.go @@ -2,12 +2,14 @@ package ast // SetVariableStatement represents a SET @var = value statement. type SetVariableStatement struct { - Variable *VariableReference `json:"Variable,omitempty"` - Expression ScalarExpression `json:"Expression,omitempty"` - CursorDefinition *CursorDefinition `json:"CursorDefinition,omitempty"` - AssignmentKind string `json:"AssignmentKind,omitempty"` - SeparatorType string `json:"SeparatorType,omitempty"` + Variable *VariableReference `json:"Variable,omitempty"` + Expression ScalarExpression `json:"Expression,omitempty"` + CursorDefinition *CursorDefinition `json:"CursorDefinition,omitempty"` + AssignmentKind string `json:"AssignmentKind,omitempty"` + SeparatorType string `json:"SeparatorType,omitempty"` + Identifier *Identifier `json:"Identifier,omitempty"` FunctionCallExists bool `json:"FunctionCallExists,omitempty"` + Parameters []ScalarExpression `json:"Parameters,omitempty"` } func (s *SetVariableStatement) node() {} diff --git a/parser/marshal.go b/parser/marshal.go index 1c301d21..3e610c22 100644 --- a/parser/marshal.go +++ b/parser/marshal.go @@ -58,6 +58,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return deleteStatementToJSON(s) case *ast.DeclareVariableStatement: return declareVariableStatementToJSON(s) + case *ast.DeclareTableVariableStatement: + return declareTableVariableStatementToJSON(s) case *ast.SetVariableStatement: return setVariableStatementToJSON(s) case *ast.IfStatement: @@ -176,6 +178,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return dropExternalTableStatementToJSON(s) case *ast.DropExternalResourcePoolStatement: return dropExternalResourcePoolStatementToJSON(s) + case *ast.DropExternalModelStatement: + return dropExternalModelStatementToJSON(s) case *ast.DropWorkloadGroupStatement: return dropWorkloadGroupStatementToJSON(s) case *ast.DropWorkloadClassifierStatement: @@ -270,6 +274,8 @@ func statementToJSON(stmt ast.Statement) jsonNode { return alterXmlSchemaCollectionStatementToJSON(s) case *ast.AlterServerConfigurationSetSoftNumaStatement: return alterServerConfigurationSetSoftNumaStatementToJSON(s) + case *ast.AlterServerConfigurationStatement: + return alterServerConfigurationStatementToJSON(s) case *ast.AlterLoginAddDropCredentialStatement: return alterLoginAddDropCredentialStatementToJSON(s) case *ast.TryCatchStatement: @@ -801,6 +807,13 @@ func indexDefinitionToJSON(idx *ast.IndexDefinition) jsonNode { } node["Columns"] = cols } + if len(idx.IncludeColumns) > 0 { + cols := make([]jsonNode, len(idx.IncludeColumns)) + for i, c := range idx.IncludeColumns { + cols[i] = scalarExpressionToJSON(c) + } + node["IncludeColumns"] = cols + } return node } @@ -1969,6 +1982,30 @@ func declareVariableElementToJSON(elem *ast.DeclareVariableElement) jsonNode { return node } +func declareTableVariableStatementToJSON(s *ast.DeclareTableVariableStatement) jsonNode { + node := jsonNode{ + "$type": "DeclareTableVariableStatement", + } + if s.Body != nil { + node["Body"] = declareTableVariableBodyToJSON(s.Body) + } + return node +} + +func declareTableVariableBodyToJSON(body *ast.DeclareTableVariableBody) jsonNode { + node := jsonNode{ + "$type": "DeclareTableVariableBody", + } + if body.VariableName != nil { + node["VariableName"] = identifierToJSON(body.VariableName) + } + node["AsDefined"] = body.AsDefined + if body.Definition != nil { + node["Definition"] = tableDefinitionToJSON(body.Definition) + } + return node +} + func sqlDataTypeReferenceToJSON(dt *ast.SqlDataTypeReference) jsonNode { node := jsonNode{ "$type": "SqlDataTypeReference", @@ -1998,10 +2035,18 @@ func setVariableStatementToJSON(s *ast.SetVariableStatement) jsonNode { } if s.SeparatorType != "" { node["SeparatorType"] = s.SeparatorType - } else { - node["SeparatorType"] = "NotSpecified" + } + if s.Identifier != nil { + node["Identifier"] = identifierToJSON(s.Identifier) } node["FunctionCallExists"] = s.FunctionCallExists + if len(s.Parameters) > 0 { + params := make([]jsonNode, len(s.Parameters)) + for i, p := range s.Parameters { + params[i] = scalarExpressionToJSON(p) + } + node["Parameters"] = params + } if s.Expression != nil { node["Expression"] = scalarExpressionToJSON(s.Expression) } @@ -2294,17 +2339,90 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { col.IdentityOptions = identityOpts } - // Parse optional NULL/NOT NULL constraint - if p.curTok.Type == TokenNot { - p.nextToken() // consume NOT - if p.curTok.Type != TokenNull { - return nil, fmt.Errorf("expected NULL after NOT, got %s", p.curTok.Literal) + // Parse column constraints (NULL, NOT NULL, UNIQUE, PRIMARY KEY, DEFAULT, CHECK, CONSTRAINT) + for { + upperLit := strings.ToUpper(p.curTok.Literal) + + if p.curTok.Type == TokenNot { + p.nextToken() // consume NOT + if p.curTok.Type == TokenNull { + p.nextToken() // consume NULL + col.Constraints = append(col.Constraints, &ast.NullableConstraintDefinition{Nullable: false}) + } + } else if p.curTok.Type == TokenNull { + p.nextToken() // consume NULL + col.Constraints = append(col.Constraints, &ast.NullableConstraintDefinition{Nullable: true}) + } else if upperLit == "UNIQUE" { + p.nextToken() // consume UNIQUE + constraint := &ast.UniqueConstraintDefinition{ + IsPrimaryKey: false, + } + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + constraint.Clustered = true + constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + constraint.Clustered = false + constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + col.Constraints = append(col.Constraints, constraint) + } else if upperLit == "PRIMARY" { + p.nextToken() // consume PRIMARY + if p.curTok.Type == TokenKey { + p.nextToken() // consume KEY + } + constraint := &ast.UniqueConstraintDefinition{ + IsPrimaryKey: true, + } + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + constraint.Clustered = true + constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + constraint.Clustered = false + constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + col.Constraints = append(col.Constraints, constraint) + } else if p.curTok.Type == TokenDefault { + p.nextToken() // consume DEFAULT + defaultConstraint := &ast.DefaultConstraintDefinition{} + + // Parse the default expression + expr, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + defaultConstraint.Expression = expr + col.DefaultConstraint = defaultConstraint + } else if upperLit == "CHECK" { + p.nextToken() // consume CHECK + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + cond, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + col.Constraints = append(col.Constraints, &ast.CheckConstraintDefinition{ + CheckCondition: cond, + }) + } + } else if upperLit == "CONSTRAINT" { + p.nextToken() // skip CONSTRAINT + if p.curTok.Type == TokenIdent { + p.nextToken() // skip constraint name + } + // Continue to parse actual constraint in next iteration + continue + } else { + break } - p.nextToken() // consume NULL - col.Constraints = append(col.Constraints, &ast.NullableConstraintDefinition{Nullable: false}) - } else if p.curTok.Type == TokenNull { - p.nextToken() // consume NULL - col.Constraints = append(col.Constraints, &ast.NullableConstraintDefinition{Nullable: true}) } return col, nil @@ -2405,6 +2523,13 @@ func tableDefinitionToJSON(t *ast.TableDefinition) jsonNode { } node["ColumnDefinitions"] = cols } + if len(t.TableConstraints) > 0 { + constraints := make([]jsonNode, len(t.TableConstraints)) + for i, constraint := range t.TableConstraints { + constraints[i] = tableConstraintToJSON(constraint) + } + node["TableConstraints"] = constraints + } if len(t.Indexes) > 0 { indexes := make([]jsonNode, len(t.Indexes)) for i, idx := range t.Indexes { @@ -2415,6 +2540,53 @@ func tableDefinitionToJSON(t *ast.TableDefinition) jsonNode { return node } +func tableConstraintToJSON(c ast.TableConstraint) jsonNode { + switch constraint := c.(type) { + case *ast.UniqueConstraintDefinition: + return uniqueConstraintToJSON(constraint) + case *ast.CheckConstraintDefinition: + return checkConstraintToJSON(constraint) + case *ast.ForeignKeyConstraintDefinition: + return foreignKeyConstraintToJSON(constraint) + default: + return jsonNode{"$type": "UnknownTableConstraint"} + } +} + +func foreignKeyConstraintToJSON(c *ast.ForeignKeyConstraintDefinition) jsonNode { + node := jsonNode{ + "$type": "ForeignKeyConstraintDefinition", + "NotForReplication": c.NotForReplication, + } + if c.ConstraintIdentifier != nil { + node["ConstraintIdentifier"] = identifierToJSON(c.ConstraintIdentifier) + } + if c.ReferenceTableName != nil { + node["ReferenceTableName"] = schemaObjectNameToJSON(c.ReferenceTableName) + } + if len(c.Columns) > 0 { + cols := make([]jsonNode, len(c.Columns)) + for i, col := range c.Columns { + cols[i] = identifierToJSON(col) + } + node["Columns"] = cols + } + if len(c.ReferencedColumns) > 0 { + cols := make([]jsonNode, len(c.ReferencedColumns)) + for i, col := range c.ReferencedColumns { + cols[i] = identifierToJSON(col) + } + node["ReferencedColumns"] = cols + } + if c.DeleteAction != "" { + node["DeleteAction"] = c.DeleteAction + } + if c.UpdateAction != "" { + node["UpdateAction"] = c.UpdateAction + } + return node +} + func columnDefinitionToJSON(c *ast.ColumnDefinition) jsonNode { node := jsonNode{ "$type": "ColumnDefinition", @@ -2427,6 +2599,9 @@ func columnDefinitionToJSON(c *ast.ColumnDefinition) jsonNode { if c.IdentityOptions != nil { node["IdentityOptions"] = identityOptionsToJSON(c.IdentityOptions) } + if c.DefaultConstraint != nil { + node["DefaultConstraint"] = defaultConstraintToJSON(c.DefaultConstraint) + } if len(c.Constraints) > 0 { constraints := make([]jsonNode, len(c.Constraints)) for i, constraint := range c.Constraints { @@ -2440,6 +2615,20 @@ func columnDefinitionToJSON(c *ast.ColumnDefinition) jsonNode { return node } +func defaultConstraintToJSON(d *ast.DefaultConstraintDefinition) jsonNode { + node := jsonNode{ + "$type": "DefaultConstraintDefinition", + "WithValues": false, + } + if d.ConstraintIdentifier != nil { + node["ConstraintIdentifier"] = identifierToJSON(d.ConstraintIdentifier) + } + if d.Expression != nil { + node["Expression"] = scalarExpressionToJSON(d.Expression) + } + return node +} + func identityOptionsToJSON(i *ast.IdentityOptions) jsonNode { node := jsonNode{ "$type": "IdentityOptions", @@ -2461,11 +2650,51 @@ func constraintDefinitionToJSON(c ast.ConstraintDefinition) jsonNode { "$type": "NullableConstraintDefinition", "Nullable": constraint.Nullable, } + case *ast.UniqueConstraintDefinition: + return uniqueConstraintToJSON(constraint) + case *ast.CheckConstraintDefinition: + return checkConstraintToJSON(constraint) default: return jsonNode{"$type": "UnknownConstraint"} } } +func uniqueConstraintToJSON(c *ast.UniqueConstraintDefinition) jsonNode { + node := jsonNode{ + "$type": "UniqueConstraintDefinition", + "Clustered": c.Clustered, + "IsPrimaryKey": c.IsPrimaryKey, + } + if c.ConstraintIdentifier != nil { + node["ConstraintIdentifier"] = identifierToJSON(c.ConstraintIdentifier) + } + if c.IndexType != nil { + node["IndexType"] = indexTypeToJSON(c.IndexType) + } + if len(c.Columns) > 0 { + cols := make([]jsonNode, len(c.Columns)) + for i, col := range c.Columns { + cols[i] = columnWithSortOrderToJSON(col) + } + node["Columns"] = cols + } + return node +} + +func checkConstraintToJSON(c *ast.CheckConstraintDefinition) jsonNode { + node := jsonNode{ + "$type": "CheckConstraintDefinition", + "NotForReplication": c.NotForReplication, + } + if c.ConstraintIdentifier != nil { + node["ConstraintIdentifier"] = identifierToJSON(c.ConstraintIdentifier) + } + if c.CheckCondition != nil { + node["CheckCondition"] = booleanExpressionToJSON(c.CheckCondition) + } + return node +} + func dataTypeReferenceToJSON(d ast.DataTypeReference) jsonNode { switch dt := d.(type) { case *ast.SqlDataTypeReference: @@ -3231,6 +3460,36 @@ func onOffOptionValueToJSON(o *ast.OnOffOptionValue) jsonNode { } } +func alterServerConfigurationStatementToJSON(s *ast.AlterServerConfigurationStatement) jsonNode { + node := jsonNode{ + "$type": "AlterServerConfigurationStatement", + } + if s.ProcessAffinity != "" { + node["ProcessAffinity"] = s.ProcessAffinity + } + if len(s.ProcessAffinityRanges) > 0 { + ranges := make([]jsonNode, len(s.ProcessAffinityRanges)) + for i, r := range s.ProcessAffinityRanges { + ranges[i] = processAffinityRangeToJSON(r) + } + node["ProcessAffinityRanges"] = ranges + } + return node +} + +func processAffinityRangeToJSON(r *ast.ProcessAffinityRange) jsonNode { + node := jsonNode{ + "$type": "ProcessAffinityRange", + } + if r.From != nil { + node["From"] = scalarExpressionToJSON(r.From) + } + if r.To != nil { + node["To"] = scalarExpressionToJSON(r.To) + } + return node +} + func alterLoginAddDropCredentialStatementToJSON(s *ast.AlterLoginAddDropCredentialStatement) jsonNode { node := jsonNode{ "$type": "AlterLoginAddDropCredentialStatement", @@ -5173,6 +5432,17 @@ func dropExternalResourcePoolStatementToJSON(s *ast.DropExternalResourcePoolStat return node } +func dropExternalModelStatementToJSON(s *ast.DropExternalModelStatement) jsonNode { + node := jsonNode{ + "$type": "DropExternalModelStatement", + "IsIfExists": s.IsIfExists, + } + if s.Name != nil { + node["Name"] = schemaObjectNameToJSON(s.Name) + } + return node +} + func dropWorkloadGroupStatementToJSON(s *ast.DropWorkloadGroupStatement) jsonNode { node := jsonNode{ "$type": "DropWorkloadGroupStatement", diff --git a/parser/parse_ddl.go b/parser/parse_ddl.go index 9c26674f..fe0a4a34 100644 --- a/parser/parse_ddl.go +++ b/parser/parse_ddl.go @@ -134,11 +134,30 @@ func (p *Parser) parseDropExternalStatement() (ast.Statement, error) { return p.parseDropExternalTableStatement() case "RESOURCE": return p.parseDropExternalResourcePoolStatement() + case "MODEL": + return p.parseDropExternalModelStatement() } return nil, fmt.Errorf("unexpected token after EXTERNAL: %s", p.curTok.Literal) } +func (p *Parser) parseDropExternalModelStatement() (*ast.DropExternalModelStatement, error) { + // Consume MODEL + p.nextToken() + + stmt := &ast.DropExternalModelStatement{} + + // Parse model name + stmt.Name, _ = p.parseSchemaObjectName() + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + func (p *Parser) parseDropExternalLanguageStatement() (*ast.DropExternalLanguageStatement, error) { // Consume LANGUAGE p.nextToken() @@ -1741,11 +1760,14 @@ func (p *Parser) parseAlterServerConfigurationStatement() (ast.Statement, error) p.nextToken() // Check what type of SET it is - if strings.ToUpper(p.curTok.Literal) == "SOFTNUMA" { + switch strings.ToUpper(p.curTok.Literal) { + case "SOFTNUMA": return p.parseAlterServerConfigurationSetSoftNumaStatement() + case "PROCESS": + return p.parseAlterServerConfigurationSetProcessAffinityStatement() + default: + return nil, fmt.Errorf("unexpected token after SET: %s", p.curTok.Literal) } - - return nil, fmt.Errorf("unexpected token after SET: %s", p.curTok.Literal) } func (p *Parser) parseAlterServerConfigurationSetSoftNumaStatement() (*ast.AlterServerConfigurationSetSoftNumaStatement, error) { @@ -1777,6 +1799,97 @@ func (p *Parser) parseAlterServerConfigurationSetSoftNumaStatement() (*ast.Alter return stmt, nil } +func (p *Parser) parseAlterServerConfigurationSetProcessAffinityStatement() (*ast.AlterServerConfigurationStatement, error) { + // Consume PROCESS + p.nextToken() + + // Expect AFFINITY + if strings.ToUpper(p.curTok.Literal) != "AFFINITY" { + return nil, fmt.Errorf("expected AFFINITY after PROCESS, got %s", p.curTok.Literal) + } + p.nextToken() + + stmt := &ast.AlterServerConfigurationStatement{} + + // Parse CPU or NUMANODE + affinityType := strings.ToUpper(p.curTok.Literal) + switch affinityType { + case "CPU": + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + // Check for AUTO + if strings.ToUpper(p.curTok.Literal) == "AUTO" { + stmt.ProcessAffinity = "CpuAuto" + p.nextToken() + } else { + // Parse ranges + stmt.ProcessAffinity = "Cpu" + ranges, err := p.parseProcessAffinityRanges() + if err != nil { + return nil, err + } + stmt.ProcessAffinityRanges = ranges + } + } + case "NUMANODE": + p.nextToken() + if p.curTok.Type == TokenEquals { + p.nextToken() + stmt.ProcessAffinity = "NumaNode" + ranges, err := p.parseProcessAffinityRanges() + if err != nil { + return nil, err + } + stmt.ProcessAffinityRanges = ranges + } + default: + return nil, fmt.Errorf("expected CPU or NUMANODE after AFFINITY, got %s", p.curTok.Literal) + } + + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + + return stmt, nil +} + +func (p *Parser) parseProcessAffinityRanges() ([]*ast.ProcessAffinityRange, error) { + var ranges []*ast.ProcessAffinityRange + + for { + r := &ast.ProcessAffinityRange{} + + // Parse From value + if p.curTok.Type != TokenNumber { + return nil, fmt.Errorf("expected number in process affinity range, got %s", p.curTok.Literal) + } + r.From = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + + // Check for TO + if strings.ToUpper(p.curTok.Literal) == "TO" { + p.nextToken() + if p.curTok.Type != TokenNumber { + return nil, fmt.Errorf("expected number after TO, got %s", p.curTok.Literal) + } + r.To = &ast.IntegerLiteral{LiteralType: "Integer", Value: p.curTok.Literal} + p.nextToken() + } + + ranges = append(ranges, r) + + // Check for comma + if p.curTok.Type != TokenComma { + break + } + p.nextToken() + } + + return ranges, nil +} + func capitalizeFirst(s string) string { if len(s) == 0 { return s diff --git a/parser/parse_statements.go b/parser/parse_statements.go index bf6d9f0d..344c1569 100644 --- a/parser/parse_statements.go +++ b/parser/parse_statements.go @@ -9,25 +9,114 @@ import ( "github.com/kyleconroy/teesql/ast" ) -func (p *Parser) parseDeclareVariableStatement() (*ast.DeclareVariableStatement, error) { +func (p *Parser) parseDeclareVariableStatement() (ast.Statement, error) { // Consume DECLARE p.nextToken() + // Parse variable name + if p.curTok.Type != TokenIdent || !strings.HasPrefix(p.curTok.Literal, "@") { + return nil, fmt.Errorf("expected variable name, got %s", p.curTok.Literal) + } + varName := &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} + p.nextToken() + + // Skip optional AS + asDefined := false + if p.curTok.Type == TokenAs { + asDefined = true + p.nextToken() + } + + // Check if this is a TABLE variable + if p.curTok.Type == TokenTable { + return p.parseDeclareTableVariableStatement(varName, asDefined) + } + + // Regular variable declaration stmt := &ast.DeclareVariableStatement{} + elem := &ast.DeclareVariableElement{ + VariableName: varName, + } - for { + // Parse data type + dataType, err := p.parseDataType() + if err != nil { + return nil, err + } + elem.DataType = dataType + + // Check for NULL / NOT NULL + if p.curTok.Type == TokenNull { + elem.Nullable = &ast.NullableConstraintDefinition{Nullable: true} + p.nextToken() + } else if p.curTok.Type == TokenNot { + p.nextToken() + if p.curTok.Type == TokenNull { + elem.Nullable = &ast.NullableConstraintDefinition{Nullable: false} + p.nextToken() + } + } + + // Check for = initial value + if p.curTok.Type == TokenEquals { + p.nextToken() + val, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + elem.Value = val + } + + stmt.Declarations = append(stmt.Declarations, elem) + + // Handle additional declarations separated by comma + for p.curTok.Type == TokenComma { + p.nextToken() decl, err := p.parseDeclareVariableElement() if err != nil { return nil, err } stmt.Declarations = append(stmt.Declarations, decl) + } - if p.curTok.Type != TokenComma { - break - } + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { p.nextToken() } + return stmt, nil +} + +func (p *Parser) parseDeclareTableVariableStatement(varName *ast.Identifier, asDefined bool) (*ast.DeclareTableVariableStatement, error) { + // Consume TABLE + p.nextToken() + + stmt := &ast.DeclareTableVariableStatement{ + Body: &ast.DeclareTableVariableBody{ + VariableName: varName, + AsDefined: asDefined, + }, + } + + // Expect ( + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after TABLE, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse table definition + tableDef, err := p.parseTableDefinitionBody() + if err != nil { + return nil, err + } + stmt.Body.Definition = tableDef + + // Expect ) + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after table definition, got %s", p.curTok.Literal) + } + p.nextToken() + // Skip optional semicolon if p.curTok.Type == TokenSemicolon { p.nextToken() @@ -36,6 +125,288 @@ func (p *Parser) parseDeclareVariableStatement() (*ast.DeclareVariableStatement, return stmt, nil } +// parseTableDefinitionBody parses the body of a table definition (column definitions, constraints, indexes) +// between parentheses. The opening parenthesis should already be consumed. +func (p *Parser) parseTableDefinitionBody() (*ast.TableDefinition, error) { + tableDef := &ast.TableDefinition{} + + // Parse column definitions, table constraints, and indexes + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + // Check for table constraints (CHECK, CONSTRAINT, PRIMARY KEY, UNIQUE, FOREIGN KEY, INDEX) + upperLit := strings.ToUpper(p.curTok.Literal) + + if upperLit == "CHECK" { + constraint, err := p.parseCheckConstraintInTable() + if err != nil { + return nil, err + } + tableDef.TableConstraints = append(tableDef.TableConstraints, constraint) + } else if upperLit == "CONSTRAINT" { + p.nextToken() // skip CONSTRAINT + p.nextToken() // skip constraint name + // Parse actual constraint + continue + } else if upperLit == "PRIMARY" || upperLit == "UNIQUE" || upperLit == "FOREIGN" { + constraint, err := p.parseTableConstraint() + if err != nil { + return nil, err + } + if constraint != nil { + tableDef.TableConstraints = append(tableDef.TableConstraints, constraint) + } + } else if upperLit == "INDEX" { + indexDef, err := p.parseInlineIndexDefinition() + if err != nil { + return nil, err + } + tableDef.Indexes = append(tableDef.Indexes, indexDef) + } else { + // Column definition + colDef, err := p.parseColumnDefinition() + if err != nil { + return nil, err + } + tableDef.ColumnDefinitions = append(tableDef.ColumnDefinitions, colDef) + } + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + + return tableDef, nil +} + +// parseCheckConstraintInTable parses a CHECK constraint in a table definition +func (p *Parser) parseCheckConstraintInTable() (*ast.CheckConstraintDefinition, error) { + // Consume CHECK + p.nextToken() + + constraint := &ast.CheckConstraintDefinition{} + + // Expect ( + if p.curTok.Type != TokenLParen { + return nil, fmt.Errorf("expected ( after CHECK, got %s", p.curTok.Literal) + } + p.nextToken() + + // Parse the check condition + cond, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + constraint.CheckCondition = cond + + // Expect ) + if p.curTok.Type != TokenRParen { + return nil, fmt.Errorf("expected ) after check condition, got %s", p.curTok.Literal) + } + p.nextToken() + + return constraint, nil +} + +// parseTableConstraint parses PRIMARY KEY, UNIQUE, or FOREIGN KEY constraints +func (p *Parser) parseTableConstraint() (ast.TableConstraint, error) { + upperLit := strings.ToUpper(p.curTok.Literal) + + if upperLit == "PRIMARY" { + p.nextToken() // consume PRIMARY + if p.curTok.Type == TokenKey { + p.nextToken() // consume KEY + } + constraint := &ast.UniqueConstraintDefinition{ + IsPrimaryKey: true, + } + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + constraint.Clustered = true + constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + constraint.Clustered = false + constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + // Skip the column list + if p.curTok.Type == TokenLParen { + p.skipParenthesizedContent() + } + return constraint, nil + } else if upperLit == "UNIQUE" { + p.nextToken() // consume UNIQUE + constraint := &ast.UniqueConstraintDefinition{ + IsPrimaryKey: false, + } + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + constraint.Clustered = true + constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + constraint.Clustered = false + constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + // Skip the column list + if p.curTok.Type == TokenLParen { + p.skipParenthesizedContent() + } + return constraint, nil + } else if upperLit == "FOREIGN" { + p.nextToken() // consume FOREIGN + if p.curTok.Type == TokenKey { + p.nextToken() // consume KEY + } + // Skip the constraint body for now + if p.curTok.Type == TokenLParen { + p.skipParenthesizedContent() + } + // Skip REFERENCES + if strings.ToUpper(p.curTok.Literal) == "REFERENCES" { + p.skipToEndOfStatement() + } + return &ast.ForeignKeyConstraintDefinition{}, nil + } + + return nil, nil +} + +// parseInlineIndexDefinition parses an inline INDEX definition in a table variable +func (p *Parser) parseInlineIndexDefinition() (*ast.IndexDefinition, error) { + // Consume INDEX + p.nextToken() + + indexDef := &ast.IndexDefinition{} + + // Parse index name + if p.curTok.Type == TokenIdent { + quoteType := "NotQuoted" + if strings.HasPrefix(p.curTok.Literal, "[") && strings.HasSuffix(p.curTok.Literal, "]") { + quoteType = "SquareBracket" + } + indexDef.Name = &ast.Identifier{ + Value: p.curTok.Literal, + QuoteType: quoteType, + } + p.nextToken() + } + + // Parse optional UNIQUE + if strings.ToUpper(p.curTok.Literal) == "UNIQUE" { + indexDef.Unique = true + p.nextToken() + } + + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + indexDef.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + indexDef.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + + // Parse column list + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + quoteType := "NotQuoted" + if strings.HasPrefix(p.curTok.Literal, "[") && strings.HasSuffix(p.curTok.Literal, "]") { + quoteType = "SquareBracket" + } + col := &ast.ColumnWithSortOrder{ + Column: &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Count: 1, + Identifiers: []*ast.Identifier{ + {Value: p.curTok.Literal, QuoteType: quoteType}, + }, + }, + }, + SortOrder: ast.SortOrderNotSpecified, + } + p.nextToken() + + // Parse optional ASC/DESC + if strings.ToUpper(p.curTok.Literal) == "ASC" { + col.SortOrder = ast.SortOrderAscending + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "DESC" { + col.SortOrder = ast.SortOrderDescending + p.nextToken() + } + + indexDef.Columns = append(indexDef.Columns, col) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + + // Parse optional INCLUDE + if strings.ToUpper(p.curTok.Literal) == "INCLUDE" { + p.nextToken() // consume INCLUDE + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + quoteType := "NotQuoted" + if strings.HasPrefix(p.curTok.Literal, "[") && strings.HasSuffix(p.curTok.Literal, "]") { + quoteType = "SquareBracket" + } + includeCol := &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Count: 1, + Identifiers: []*ast.Identifier{ + {Value: p.curTok.Literal, QuoteType: quoteType}, + }, + }, + } + indexDef.IncludeColumns = append(indexDef.IncludeColumns, includeCol) + p.nextToken() + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() + } + } + } + + return indexDef, nil +} + +// skipParenthesizedContent skips content within parentheses, handling nested parens +func (p *Parser) skipParenthesizedContent() { + if p.curTok.Type != TokenLParen { + return + } + p.nextToken() // consume ( + depth := 1 + for depth > 0 && p.curTok.Type != TokenEOF { + if p.curTok.Type == TokenLParen { + depth++ + } else if p.curTok.Type == TokenRParen { + depth-- + } + p.nextToken() + } +} + func (p *Parser) parseDeclareVariableElement() (*ast.DeclareVariableElement, error) { elem := &ast.DeclareVariableElement{} @@ -356,6 +727,48 @@ func (p *Parser) parseSetVariableStatement() (ast.Statement, error) { stmt.Variable = &ast.VariableReference{Name: p.curTok.Literal} p.nextToken() + // Check for dot or double-colon separator (SET @a.b = ... or SET @a::b ...) + if p.curTok.Type == TokenDot { + stmt.SeparatorType = "Dot" + p.nextToken() + if p.curTok.Type == TokenIdent { + stmt.Identifier = &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} + p.nextToken() + } + } else if p.curTok.Type == TokenColonColon { + stmt.SeparatorType = "DoubleColon" + p.nextToken() // consume :: + if p.curTok.Type == TokenIdent { + stmt.Identifier = &ast.Identifier{Value: p.curTok.Literal, QuoteType: "NotQuoted"} + p.nextToken() + } + } + + // Check for function call: SET @a.b () or SET @a.b (params) + if p.curTok.Type == TokenLParen { + stmt.FunctionCallExists = true + p.nextToken() // consume ( + // Parse parameters + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + param, err := p.parseScalarExpression() + if err != nil { + return nil, err + } + stmt.Parameters = append(stmt.Parameters, param) + if p.curTok.Type == TokenComma { + p.nextToken() + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + // Skip optional semicolon + if p.curTok.Type == TokenSemicolon { + p.nextToken() + } + return stmt, nil + } + // Expect = if p.curTok.Type != TokenEquals { return nil, fmt.Errorf("expected =, got %s", p.curTok.Literal) @@ -1894,7 +2307,17 @@ func (p *Parser) parseSaveTransactionStatement() (*ast.SaveTransactionStatement, } // Optional transaction name or variable - if p.curTok.Type == TokenIdent && !isKeyword(p.curTok.Literal) && p.curTok.Literal[0] != '@' { + if p.curTok.Type == TokenIdent && p.curTok.Literal[0] == '@' { + // Variable reference + stmt.Name = &ast.IdentifierOrValueExpression{ + Value: p.curTok.Literal, + ValueExpression: &ast.VariableReference{ + Name: p.curTok.Literal, + }, + } + p.nextToken() + } else if p.curTok.Type == TokenIdent && !isKeyword(p.curTok.Literal) { + // Simple identifier stmt.Name = &ast.IdentifierOrValueExpression{ Value: p.curTok.Literal, Identifier: &ast.Identifier{ @@ -1903,14 +2326,16 @@ func (p *Parser) parseSaveTransactionStatement() (*ast.SaveTransactionStatement, }, } p.nextToken() - } else if p.curTok.Type == TokenIdent && p.curTok.Literal[0] == '@' { + } else if p.curTok.Type == TokenNumber || p.curTok.Type == TokenMinus { + // Legacy name format: [-]number:dotted.identifier + name := p.parseLegacyTransactionName() stmt.Name = &ast.IdentifierOrValueExpression{ - Value: p.curTok.Literal, - ValueExpression: &ast.VariableReference{ - Name: p.curTok.Literal, + Value: name, + Identifier: &ast.Identifier{ + Value: name, + QuoteType: "NotQuoted", }, } - p.nextToken() } // Skip optional semicolon @@ -1921,6 +2346,54 @@ func (p *Parser) parseSaveTransactionStatement() (*ast.SaveTransactionStatement, return stmt, nil } +// parseLegacyTransactionName parses legacy transaction names like "5:a.b" or "-100:[a].[b]" +func (p *Parser) parseLegacyTransactionName() string { + var parts []string + + // Optional minus sign + if p.curTok.Type == TokenMinus { + parts = append(parts, "-") + p.nextToken() + } + + // Number part + if p.curTok.Type == TokenNumber { + parts = append(parts, p.curTok.Literal) + p.nextToken() + } + + // Colon + if p.curTok.Type == TokenColon { + parts = append(parts, ":") + p.nextToken() + } + + // Dotted identifier part (e.g., "a.b" or "[a].[b]") + for { + if p.curTok.Type == TokenIdent { + // Check if it's a bracketed identifier + if strings.HasPrefix(p.curTok.Literal, "[") && strings.HasSuffix(p.curTok.Literal, "]") { + parts = append(parts, p.curTok.Literal) + } else { + parts = append(parts, p.curTok.Literal) + } + p.nextToken() + } else { + break + } + + // Check for dot continuation + if p.curTok.Type == TokenDot { + parts = append(parts, ".") + p.nextToken() + } else { + break + } + } + + return strings.Join(parts, "") +} + func (p *Parser) parseWaitForStatement() (*ast.WaitForStatement, error) { // Consume WAITFOR p.nextToken() diff --git a/parser/testdata/AlterServerConfigurationStatementTests/metadata.json b/parser/testdata/AlterServerConfigurationStatementTests/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/AlterServerConfigurationStatementTests/metadata.json +++ b/parser/testdata/AlterServerConfigurationStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/Baselines100_AlterServerConfigurationStatementTests/metadata.json b/parser/testdata/Baselines100_AlterServerConfigurationStatementTests/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/Baselines100_AlterServerConfigurationStatementTests/metadata.json +++ b/parser/testdata/Baselines100_AlterServerConfigurationStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/Baselines150_DeclareTableVariableTests150/metadata.json b/parser/testdata/Baselines150_DeclareTableVariableTests150/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/Baselines150_DeclareTableVariableTests150/metadata.json +++ b/parser/testdata/Baselines150_DeclareTableVariableTests150/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/Baselines90_SetVariableStatementTests90/metadata.json b/parser/testdata/Baselines90_SetVariableStatementTests90/metadata.json index 92f70877..49e9182b 100644 --- a/parser/testdata/Baselines90_SetVariableStatementTests90/metadata.json +++ b/parser/testdata/Baselines90_SetVariableStatementTests90/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": true} diff --git a/parser/testdata/BaselinesCommon_BigIntRowCountPageCountTests/metadata.json b/parser/testdata/BaselinesCommon_BigIntRowCountPageCountTests/metadata.json index 92f70877..49e9182b 100644 --- a/parser/testdata/BaselinesCommon_BigIntRowCountPageCountTests/metadata.json +++ b/parser/testdata/BaselinesCommon_BigIntRowCountPageCountTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": true} diff --git a/parser/testdata/BaselinesCommon_DeclareTableStatementTests/metadata.json b/parser/testdata/BaselinesCommon_DeclareTableStatementTests/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/BaselinesCommon_DeclareTableStatementTests/metadata.json +++ b/parser/testdata/BaselinesCommon_DeclareTableStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/BaselinesCommon_SaveTransactionStatementTests/metadata.json b/parser/testdata/BaselinesCommon_SaveTransactionStatementTests/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/BaselinesCommon_SaveTransactionStatementTests/metadata.json +++ b/parser/testdata/BaselinesCommon_SaveTransactionStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/BigIntRowCountPageCountTests/metadata.json b/parser/testdata/BigIntRowCountPageCountTests/metadata.json index 92f70877..49e9182b 100644 --- a/parser/testdata/BigIntRowCountPageCountTests/metadata.json +++ b/parser/testdata/BigIntRowCountPageCountTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": true} diff --git a/parser/testdata/CreateIndexStatementTests90/metadata.json b/parser/testdata/CreateIndexStatementTests90/metadata.json index 92f70877..49e9182b 100644 --- a/parser/testdata/CreateIndexStatementTests90/metadata.json +++ b/parser/testdata/CreateIndexStatementTests90/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": true} diff --git a/parser/testdata/DeclareTableStatementTests/metadata.json b/parser/testdata/DeclareTableStatementTests/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/DeclareTableStatementTests/metadata.json +++ b/parser/testdata/DeclareTableStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/DeclareTableVariableTests150/metadata.json b/parser/testdata/DeclareTableVariableTests150/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/DeclareTableVariableTests150/metadata.json +++ b/parser/testdata/DeclareTableVariableTests150/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false} diff --git a/parser/testdata/SaveTransactionStatementTests/metadata.json b/parser/testdata/SaveTransactionStatementTests/metadata.json index 92f70877..e27d63a6 100644 --- a/parser/testdata/SaveTransactionStatementTests/metadata.json +++ b/parser/testdata/SaveTransactionStatementTests/metadata.json @@ -1 +1 @@ -{"skip": true} \ No newline at end of file +{"skip": false}