Skip to content

Commit 22bc617

Browse files
committed
Validate constraints
1 parent a7687fa commit 22bc617

3 files changed

Lines changed: 125 additions & 10 deletions

File tree

Sources/Compiler/Sema/StmtTypeChecker.swift

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,12 @@ extension StmtTypeChecker {
890890
}
891891
}
892892

893+
private mutating func insertColumnsIntoEnv(columns: borrowing Columns) {
894+
for column in columns {
895+
env.insert(column.key, ty: column.value)
896+
}
897+
}
898+
893899
mutating func typeCheck(createTable: CreateTableStmtSyntax) {
894900
if pragmas.contains(.requireStrictTables)
895901
&& !createTable.options.kind.contains(.strict) {
@@ -910,11 +916,16 @@ extension StmtTypeChecker {
910916
primaryKey: primaryKey(of: createTable, columns: columns),
911917
kind: .normal
912918
)
913-
case let .columns(columns):
914-
let columns: Columns = columns.reduce(into: [:]) {
915-
$0[$1.value.name.value] = typeFor(column: $1.value)
919+
case let .columns(columnsDefs):
920+
var columns: Columns = [:]
921+
for (name, def) in columnsDefs {
922+
columns[name.value] = typeFor(
923+
column: def,
924+
tableColumns: columns
925+
)
916926
}
917927

928+
validateTableConstraints(of: createTable, columns: columns)
918929
schema[createTable.name.value] = Table(
919930
name: createTable.name.value,
920931
columns: columns,
@@ -950,7 +961,7 @@ extension StmtTypeChecker {
950961
newColumns[column.key == oldName.value ? newName.value : column.key] = column.value
951962
}
952963
case let .addColumn(column):
953-
table.columns[column.name.value] = typeFor(column: column)
964+
table.columns[column.name.value] = typeFor(column: column, tableColumns: table.columns)
954965
case let .dropColumn(column):
955966
table.columns[column.value] = nil
956967
}
@@ -987,11 +998,39 @@ extension StmtTypeChecker {
987998
}
988999

9891000
/// Will figure out the final SQL column type from the syntax
990-
private mutating func typeFor(column: borrowing ColumnDefSyntax) -> Type {
991-
// Technically you can have a NULL primary key but I don't
992-
// think people actually do that...
993-
let isNotNullable = column.constraints
994-
.contains { $0.isPkConstraint || $0.isNotNullConstraint }
1001+
private mutating func typeFor(
1002+
column: borrowing ColumnDefSyntax,
1003+
tableColumns: borrowing Columns
1004+
) -> Type {
1005+
var isNotNullable = false
1006+
for constraint in column.constraints {
1007+
switch constraint.kind {
1008+
case .primaryKey, .notNull:
1009+
// Technically you can have a NULL primary key but I don't
1010+
// think people actually do that...
1011+
isNotNullable = true
1012+
case .check(let expr):
1013+
inNewEnvironment { typeChecker in
1014+
typeChecker.insertColumnsIntoEnv(columns: tableColumns)
1015+
_ = typeChecker.typeCheck(expr)
1016+
}
1017+
case .default(let expr):
1018+
inNewEnvironment { typeChecker in
1019+
_ = typeChecker.typeCheck(expr)
1020+
}
1021+
case .foreignKey(let fk):
1022+
if schema[fk.foreignTable.value] == nil {
1023+
diagnostics.add(.tableDoesNotExist(fk.foreignTable))
1024+
}
1025+
case .generated(let expr, _):
1026+
inNewEnvironment { typeChecker in
1027+
typeChecker.insertColumnsIntoEnv(columns: tableColumns)
1028+
_ = typeChecker.typeCheck(expr)
1029+
}
1030+
case .unique, .collate:
1031+
break
1032+
}
1033+
}
9951034

9961035
// Validate it is an actual SQLite type since SQlite doesnt care.
9971036
if !Type.validTypeNames.contains(column.type.name.value) {
@@ -1076,6 +1115,41 @@ extension StmtTypeChecker {
10761115
}
10771116
}
10781117

1118+
private mutating func validateTableConstraints(
1119+
of stmt: CreateTableStmtSyntax,
1120+
columns: Columns
1121+
) {
1122+
for constraint in stmt.constraints {
1123+
switch constraint.kind {
1124+
case .check(let expr):
1125+
inNewEnvironment { typeChecker in
1126+
typeChecker.insertColumnsIntoEnv(columns: columns)
1127+
_ = typeChecker.typeCheck(expr)
1128+
}
1129+
case .foreignKey(let fkColumns, let fkClause):
1130+
// Make sure listed columns exist
1131+
for column in fkColumns {
1132+
guard columns[column.value] == nil else { continue }
1133+
diagnostics.add(.columnDoesNotExist(column))
1134+
}
1135+
1136+
// Make sure referenced table exists
1137+
guard let foreignTable = schema[fkClause.foreignTable.value] else {
1138+
diagnostics.add(.tableDoesNotExist(fkClause.foreignTable))
1139+
return
1140+
}
1141+
1142+
// Make sure referenced columns exist
1143+
for column in fkClause.foreignColumns {
1144+
guard foreignTable.columns[column.value] == nil else { continue }
1145+
diagnostics.add(.columnDoesNotExist(column))
1146+
}
1147+
case .primaryKey, .unique:
1148+
break
1149+
}
1150+
}
1151+
}
1152+
10791153
mutating func typeCheck(fts5Table: borrowing CreateVirtualTableStmtSyntax) {
10801154
var columns: Columns = [:]
10811155

Tests/CompilerTests/Compiler/CompileCreateTable.sql

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,44 @@ CREATE TABLE allValidTypes (
9898
blob BLOB,
9999
any ANY
100100
) STRICT;
101+
102+
-- CHECK: TABLE
103+
-- CHECK: NAME hasGenerated
104+
-- CHECK: COLUMNS
105+
-- CHECK: KEY foo
106+
-- CHECK: VALUE INTEGER?
107+
-- CHECK: KEY bar
108+
-- CHECK: VALUE INTEGER?
109+
-- CHECK: KEY baz
110+
-- CHECK: VALUE INTEGER?
111+
-- CHECK: KEY ref
112+
-- CHECK: VALUE INTEGER?
113+
-- CHECK: KIND normal
114+
-- CHECK-ERROR: Column 'qux' does not exist
115+
-- CHECK-ERROR: Table 'dne' does not exist
116+
CREATE TABLE hasGenerated (
117+
foo INTEGER,
118+
bar INTEGER GENERATED ALWAYS AS (foo + 1),
119+
baz INTEGER GENERATED ALWAYS AS (qux + 1),
120+
ref INTEGER REFERENCES dne (value)
121+
) STRICT;
122+
123+
-- CHECK: TABLE
124+
-- CHECK: NAME hasTableCheck
125+
-- CHECK: COLUMNS
126+
-- CHECK: KEY foo
127+
-- CHECK: VALUE INTEGER?
128+
-- CHECK: KEY bar
129+
-- CHECK: VALUE INTEGER?
130+
-- CHECK: KIND normal
131+
-- CHECK-ERROR: Column 'foooooo' does not exist
132+
-- CHECK-ERROR: Column 'typo' does not exist
133+
-- CHECK-ERROR: Table 'doesNotExist' does not exist
134+
CREATE TABLE hasTableCheck (
135+
foo INTEGER,
136+
bar INTEGER,
137+
CHECK (foo + bar > 1),
138+
CHECK (foooooo + bar > 1),
139+
FOREIGN KEY (typo) REFERENCES doesNotExist (meh),
140+
FOREIGN KEY (foo) REFERENCES hasGenerated (foo)
141+
) STRICT;

Tests/CompilerTests/CompilerTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import XCTest
1111

1212
class CompilerTests: XCTestCase {
1313
func testCheckSimpleSelects() throws {
14-
try checkQueries(compile: "CompileSimpleSelects", dump: true)
14+
try checkQueries(compile: "CompileSimpleSelects")
1515
}
1616

1717
func testSelectWithJoins() throws {

0 commit comments

Comments
 (0)