@@ -961,7 +961,8 @@ extension StmtTypeChecker {
961961 for (name, def) in columnsDefs {
962962 columns [ name. value] = typeFor (
963963 column: def,
964- tableColumns: columns
964+ tableColumns: columns,
965+ tableName: createTable. name. value
965966 )
966967 }
967968
@@ -1001,7 +1002,11 @@ extension StmtTypeChecker {
10011002 newColumns [ column. key == oldName. value ? newName. value : column. key] = column. value
10021003 }
10031004 case let . addColumn( column) :
1004- table. columns [ column. name. value] = typeFor ( column: column, tableColumns: table. columns)
1005+ table. columns [ column. name. value] = typeFor (
1006+ column: column,
1007+ tableColumns: table. columns,
1008+ tableName: table. name
1009+ )
10051010 case let . dropColumn( column) :
10061011 table. columns [ column. value] = nil
10071012 }
@@ -1040,7 +1045,8 @@ extension StmtTypeChecker {
10401045 /// Will figure out the final SQL column type from the syntax
10411046 private mutating func typeFor(
10421047 column: borrowing ColumnDefSyntax ,
1043- tableColumns: borrowing Columns
1048+ tableColumns: borrowing Columns ,
1049+ tableName: Substring
10441050 ) -> Type {
10451051 var isNotNullable = false
10461052 for constraint in column. constraints {
@@ -1059,7 +1065,23 @@ extension StmtTypeChecker {
10591065 _ = typeChecker. typeCheck ( expr)
10601066 }
10611067 case . foreignKey( let fk) :
1062- if schema [ fk. foreignTable. value] == nil {
1068+ if fk. foreignTable. value == tableName {
1069+ for foreignColumn in fk. foreignColumns {
1070+ // Column constraints can reference the column they are
1071+ // declared for so if its this table and this column then ignore it.
1072+ guard column. name. value != foreignColumn. value else { continue }
1073+
1074+ if tableColumns [ foreignColumn. value] == nil {
1075+ diagnostics. add ( . columnDoesNotExist( foreignColumn) )
1076+ }
1077+ }
1078+ } else if let table = schema [ fk. foreignTable. value] {
1079+ for foreignColumn in fk. foreignColumns {
1080+ if table. columns [ foreignColumn. value] == nil {
1081+ diagnostics. add ( . columnDoesNotExist( foreignColumn) )
1082+ }
1083+ }
1084+ } else {
10631085 diagnostics. add ( . tableDoesNotExist( fk. foreignTable) )
10641086 }
10651087 case . generated( let expr, _) :
@@ -1095,6 +1117,21 @@ extension StmtTypeChecker {
10951117 }
10961118 }
10971119
1120+ private mutating func typeCheck(
1121+ fk: ForeignKeyClauseSyntax ,
1122+ column: ColumnDefSyntax ,
1123+ tableColumns: Columns
1124+ ) {
1125+ for foreignColumn in fk. foreignColumns {
1126+ // Column constraints can oddly reference themselves
1127+ guard column. name. value != foreignColumn. value else { continue }
1128+
1129+ if tableColumns [ foreignColumn. value] == nil {
1130+ diagnostics. add ( . columnDoesNotExist( foreignColumn) )
1131+ }
1132+ }
1133+ }
1134+
10981135 /// Gets the column names of the primary key and validates them
10991136 private mutating func primaryKey(
11001137 of stmt: CreateTableStmtSyntax ,
0 commit comments