@@ -20,7 +20,7 @@ struct StmtTypeChecker {
2020 private( set) var schema : Schema
2121 /// Any CTE that was declared with the statement.
2222 /// Keeping these separate from the schema so they don't get passed to the next statement
23- private( set) var ctes : [ Substring : Table ] = [ : ]
23+ private( set) var ctes : [ Substring : Table ]
2424 /// Any diagnostics that are emitted during compilation
2525 private( set) var diagnostics = Diagnostics ( )
2626 /// Inferrer for any bind parameter names
@@ -41,11 +41,13 @@ struct StmtTypeChecker {
4141 init (
4242 env: Environment = Environment ( ) ,
4343 schema: Schema ,
44+ ctes: [ Substring : Table ] = [ : ] ,
4445 inferenceState: InferenceState = InferenceState ( ) ,
4546 pragmas: FeatherPragmas
4647 ) {
4748 self . env = env
4849 self . schema = schema
50+ self . ctes = ctes
4951 self . inferenceState = inferenceState
5052 self . pragmas = pragmas
5153 }
@@ -67,6 +69,7 @@ struct StmtTypeChecker {
6769 inferenceState: inferenceState,
6870 env: env,
6971 schema: schema,
72+ ctes: ctes,
7073 pragmas: pragmas
7174 )
7275 let type = exprTypeChecker. typeCheck ( expr)
@@ -516,6 +519,7 @@ extension StmtTypeChecker {
516519 var index = 0
517520 let secondColumns = secondResult. allColumns. values
518521 return firstResult. mapTypes { type in
522+ let type = inferenceState. solution ( for: type)
519523 inferenceState. unify (
520524 type,
521525 with: inferenceState. solution ( for: secondColumns [ index] ) ,
@@ -690,6 +694,9 @@ extension StmtTypeChecker {
690694 return ResultColumns ( columns: resultColumns, table: nil )
691695 }
692696
697+ /// Type checks the beginning of 1 or more CTE declarations.
698+ /// Optional to help with the ease of the API since any time its
699+ /// used its optionally at the beginning of some statements.
693700 private mutating func typeCheck( with: WithSyntax ? ) {
694701 guard let with else { return }
695702
@@ -702,40 +709,50 @@ extension StmtTypeChecker {
702709 }
703710 }
704711
712+ /// Type checks the CTE expression. Will return the resultant table
713+ /// representing the CTE.
705714 private mutating func typeCheck(
706715 cte: CommonTableExpressionSyntax ,
707716 recursive: Bool
708717 ) -> Table {
709718 let cteName = QualifiedName ( name: cte. table. value, schema: nil )
710719
711- if cte. columns. isEmpty {
712- let resultColumns = typeCheck ( select: cte. select)
713- return Table ( name: cteName, columns: resultColumns. allColumns, kind: . cte)
720+ // Recursive CTE's can reference themselves so we need to create a table to
721+ // represent this CTE with all columns as type variables.
722+ let recursiveCte : Table ?
723+ if recursive {
724+ recursiveCte = Table (
725+ name: cteName,
726+ columns: cte. columns. reduce ( into: [ : ] ) { columns, name in
727+ columns. append ( inferenceState. freshTyVar ( for: name) , for: name. value)
728+ } ,
729+ kind: . cte
730+ )
731+
732+ ctes [ cteName. name] = recursiveCte
714733 } else {
715- // CTE's can reference themselves so we need to create a table to
716- // represent this CTE with all columns as type variables.
717- let thisCte = Table (
718- name: cteName,
719- columns: cte. columns. reduce ( into: [ : ] ) { columns, name in
720- columns. append ( inferenceState. freshTyVar ( for: name) , for: name. value)
721- } ,
722- kind: . cte
723- )
724-
725- ctes [ thisCte. name. name] = thisCte
726-
727- let resultColumns = typeCheck ( select: cte. select)
728- let columnTypes = resultColumns. allColumns. values
729- if columnTypes. count != cte. columns. count {
734+ recursiveCte = nil
735+ }
736+
737+ let resultColumns = typeCheck ( select: cte. select)
738+
739+ // If the CTE defined columns make sure that they have the same amount of columns.
740+ if !cte. columns. isEmpty {
741+ if resultColumns. count != cte. columns. count {
730742 diagnostics. add ( . init(
731- " CTE expected \( cte. columns. count) columns, but got \( columnTypes . count) " ,
743+ " CTE expected \( cte. columns. count) columns, but got \( resultColumns . count) " ,
732744 at: cte. location
733745 ) )
734746 }
735-
747+ }
748+
749+ if let recursiveCte {
736750 // Simply return the table but getting the solution types so the substitution
737751 // map retains it's integrity.
738- return thisCte. mapTypes { inferenceState. solution ( for: $0) }
752+ return recursiveCte. mapTypes { inferenceState. solution ( for: $0) }
753+ } else {
754+ // No recursive CTE, create a new table from the result columns.
755+ return Table ( name: cteName, columns: resultColumns. allColumns, kind: . cte)
739756 }
740757 }
741758
0 commit comments