Skip to content

Commit 86e5107

Browse files
committed
Allow column constraints to reference themselves
1 parent 5d6d66f commit 86e5107

6 files changed

Lines changed: 66 additions & 37 deletions

File tree

Sources/Compiler/Compiler.swift

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,46 +7,32 @@
77

88
public struct Compiler {
99
public var schema = Schema()
10-
public private(set) var queries: [Statement] = []
11-
public private(set) var migrations: [Statement] = []
12-
public private(set) var diagnostics = Diagnostics()
13-
1410
private var pragmas = PragmaAnalyzer()
1511

1612
public init() {}
1713

18-
public var hasDiagnostics: Bool {
19-
return !diagnostics.isEmpty
20-
}
21-
22-
public mutating func compile(migration: String) -> Diagnostics {
23-
let (stmts, diagnostics) = compile(
14+
public mutating func compile(migration: String) -> ([Statement], Diagnostics) {
15+
compile(
2416
source: migration,
2517
validator: IsValidForMigrations(),
2618
context: "migrations"
2719
)
28-
self.migrations.append(contentsOf: stmts)
29-
self.diagnostics.merge(diagnostics)
30-
return diagnostics
3120
}
3221

33-
public mutating func compile(queries: String) -> Diagnostics {
34-
let (stmts, diagnostics) = compile(
22+
public mutating func compile(queries: String) -> ([Statement], Diagnostics) {
23+
compile(
3524
source: queries,
3625
validator: IsValidForQueries(),
3726
context: "queries"
3827
)
39-
self.queries.append(contentsOf: stmts)
40-
self.diagnostics.merge(diagnostics)
41-
return diagnostics
4228
}
4329

4430
public mutating func compile(
4531
query: String,
4632
named name: String,
4733
inputType: String?,
4834
outputType: String?
49-
) -> Diagnostics {
35+
) -> (Statement?, Diagnostics) {
5036
var (stmts, diagnostics) = compile(
5137
source: query,
5238
validator: IsValidForQueries(),
@@ -56,8 +42,7 @@ public struct Compiler {
5642
guard let stmt = stmts.first else {
5743
let loc = SourceLocation(range: query.startIndex..<query.endIndex, line: 0, column: 0)
5844
diagnostics.add(.init("Query has no statements", at: loc))
59-
self.diagnostics.merge(diagnostics)
60-
return diagnostics
45+
return (nil, diagnostics)
6146
}
6247

6348
let stmtWithDef = stmt.with(
@@ -68,9 +53,7 @@ public struct Compiler {
6853
)
6954
)
7055

71-
self.queries.append(stmtWithDef)
72-
self.diagnostics.merge(diagnostics)
73-
return diagnostics
56+
return (stmtWithDef, diagnostics)
7457
}
7558

7659
mutating func compile<Validator>(

Sources/Compiler/Driver.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ public actor Driver {
118118
var compiler = Compiler()
119119
compiler.schema = currentSchema
120120

121-
let diagnostics = switch usage {
121+
let (statements, diagnostics) = switch usage {
122122
case .migration:
123123
compiler.compile(migration: fileContents)
124124
case .queries:
@@ -131,7 +131,7 @@ public actor Driver {
131131
fileName: file,
132132
usage: usage,
133133
diagnostics: diagnostics,
134-
statements: compiler.queries,
134+
statements: statements,
135135
schema: compiler.schema
136136
)
137137

Sources/Compiler/Sema/StmtTypeChecker.swift

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,8 @@ extension StmtTypeChecker {
961961
for (name, def) in columnsDefs {
962962
columns[name.value] = typeFor(
963963
column: def,
964-
tableColumns: columns
964+
tableColumns: columns,
965+
tableName: createTable.name.value
965966
)
966967
}
967968

@@ -1001,7 +1002,11 @@ extension StmtTypeChecker {
10011002
newColumns[column.key == oldName.value ? newName.value : column.key] = column.value
10021003
}
10031004
case let .addColumn(column):
1004-
table.columns[column.name.value] = typeFor(column: column, tableColumns: table.columns)
1005+
table.columns[column.name.value] = typeFor(
1006+
column: column,
1007+
tableColumns: table.columns,
1008+
tableName: table.name
1009+
)
10051010
case let .dropColumn(column):
10061011
table.columns[column.value] = nil
10071012
}
@@ -1040,7 +1045,8 @@ extension StmtTypeChecker {
10401045
/// Will figure out the final SQL column type from the syntax
10411046
private mutating func typeFor(
10421047
column: borrowing ColumnDefSyntax,
1043-
tableColumns: borrowing Columns
1048+
tableColumns: borrowing Columns,
1049+
tableName: Substring
10441050
) -> Type {
10451051
var isNotNullable = false
10461052
for constraint in column.constraints {
@@ -1059,7 +1065,23 @@ extension StmtTypeChecker {
10591065
_ = typeChecker.typeCheck(expr)
10601066
}
10611067
case .foreignKey(let fk):
1062-
if schema[fk.foreignTable.value] == nil {
1068+
if fk.foreignTable.value == tableName {
1069+
for foreignColumn in fk.foreignColumns {
1070+
// Column constraints can reference the column they are
1071+
// declared for so if its this table and this column then ignore it.
1072+
guard column.name.value != foreignColumn.value else { continue }
1073+
1074+
if tableColumns[foreignColumn.value] == nil {
1075+
diagnostics.add(.columnDoesNotExist(foreignColumn))
1076+
}
1077+
}
1078+
} else if let table = schema[fk.foreignTable.value] {
1079+
for foreignColumn in fk.foreignColumns {
1080+
if table.columns[foreignColumn.value] == nil {
1081+
diagnostics.add(.columnDoesNotExist(foreignColumn))
1082+
}
1083+
}
1084+
} else {
10631085
diagnostics.add(.tableDoesNotExist(fk.foreignTable))
10641086
}
10651087
case .generated(let expr, _):
@@ -1095,6 +1117,21 @@ extension StmtTypeChecker {
10951117
}
10961118
}
10971119

1120+
private mutating func typeCheck(
1121+
fk: ForeignKeyClauseSyntax,
1122+
column: ColumnDefSyntax,
1123+
tableColumns: Columns
1124+
) {
1125+
for foreignColumn in fk.foreignColumns {
1126+
// Column constraints can oddly reference themselves
1127+
guard column.name.value != foreignColumn.value else { continue }
1128+
1129+
if tableColumns[foreignColumn.value] == nil {
1130+
diagnostics.add(.columnDoesNotExist(foreignColumn))
1131+
}
1132+
}
1133+
}
1134+
10981135
/// Gets the column names of the primary key and validates them
10991136
private mutating func primaryKey(
11001137
of stmt: CreateTableStmtSyntax,

Sources/FeatherMacros/DatabaseMacro.swift

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,37 @@ extension DatabaseMacro: MemberMacro {
3434
}
3535

3636
var compiler = Compiler()
37+
var queries: [Statement] = []
3738

3839
for (migration, expr) in migrations {
39-
for diag in compiler.compile(migration: migration) {
40+
let (_, diagnostics) = compiler.compile(migration: migration)
41+
42+
for diag in diagnostics {
4043
context.addDiagnostics(from: diag, node: expr)
4144
}
4245
}
4346

4447
for (name, variable) in variables {
4548
guard let queryMacro = variable.queryMacroInputsIfIsQuery(in: context) else { continue }
4649

47-
for diag in compiler.compile(
50+
let (statement, diagnostics) = compiler.compile(
4851
query: queryMacro.source,
4952
named: name.removingQuerySuffix(),
5053
inputType: queryMacro.inputName,
5154
outputType: queryMacro.outputName
52-
) {
55+
)
56+
57+
for diag in diagnostics {
5358
context.addDiagnostics(from: diag, node: variable)
5459
}
60+
61+
if let statement {
62+
queries.append(statement)
63+
}
5564
}
5665

5766
let (generatedTables, generatedQueries) = try SwiftLanguage.assemble(
58-
queries: [(nil, compiler.queries)],
67+
queries: [(nil, queries)],
5968
schema: compiler.schema
6069
)
6170

Tests/CompilerTests/Compiler/CompileCreateTable.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ CREATE TABLE baz (
6060
-- CHECK: KIND normal
6161
-- CHECK-ERROR: Table 'qux' already has a primary key
6262
CREATE TABLE qux (
63-
foo TEXT PRIMARY KEY,
63+
foo TEXT PRIMARY KEY ON CONFLICT REPLACE REFERENCES qux (foo) ON DELETE CASCADE,
6464
bar INTEGER,
6565
PRIMARY KEY (bar)
6666
) STRICT;

Tests/CompilerTests/CompilerTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ class CompilerTests: XCTestCase {
5757
sqlFile: "CompileIsSingleResult",
5858
parse: { contents in
5959
var compiler = Compiler()
60-
_ = compiler.compile(queries: contents)
61-
return compiler.queries
60+
let (statements, _) = compiler.compile(queries: contents)
61+
return statements
6262
.filter{ !($0.syntax is CreateTableStmtSyntax) }
6363
.map { $0.outputCardinality.rawValue.uppercased() }
6464
}

0 commit comments

Comments
 (0)