Skip to content

Commit a993c12

Browse files
committed
Outlandish CTEs
1 parent 84ce190 commit a993c12

7 files changed

Lines changed: 136 additions & 49 deletions

File tree

Sources/Compiler/Sema/Builtins.swift

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ enum Builtins {
7575
"instr": Function(.text, .text, returning: .integer),
7676
"last_insert_rowid": Function(returning: .integer),
7777
"length": Function(.text, returning: .integer),
78-
"like": Function(.text, .text, returning: .integer),
78+
"like": Function(
79+
.text, .text,
80+
returning: .integer,
81+
overloads: [Function.Overload(.text, .text, .text, returning: .integer)]
82+
),
7983
"likelihood": Function(.var(0), .real, returning: .var(0)),
8084
"likely": Function(.var(0), returning: .var(0)),
8185
"lower": Function(.text, returning: .text),
@@ -91,19 +95,39 @@ enum Builtins {
9195
"random": Function(returning: .integer),
9296
"randomblob": Function(.integer, returning: .blob),
9397
"replace": Function(.text, .text, .text, returning: .text),
94-
"round": Function(.real, .integer, returning: .real),
98+
"round": Function(
99+
.real,
100+
returning: .real,
101+
overloads: [Function.Overload(.real, .integer, returning: .real)]
102+
),
95103
"rtrim": Function(
96104
.text,
97105
returning: .text,
98106
overloads: [Function.Overload(.text, .text, returning: .text)]
99107
),
100108
"sign": Function(.var(.integer(0)), returning: .integer),
101109
"soundex": Function(.text, returning: .text),
102-
"substr": Function(.text, .integer, .integer, returning: .text),
103-
"substring": Function(.text, .integer, .integer, returning: .text),
104-
"trim": Function(.text, .text, returning: .text),
110+
"substr": Function(
111+
.text, .integer,
112+
returning: .text,
113+
overloads: [Function.Overload(.text, .integer, .integer, returning: .text)]
114+
),
115+
"substring": Function(
116+
.text, .integer,
117+
returning: .text,
118+
overloads: [Function.Overload(.text, .integer, .integer, returning: .text)]
119+
),
120+
"trim": Function(
121+
.text,
122+
returning: .text,
123+
overloads: [Function.Overload(.text, .text, returning: .text)]
124+
),
105125
"typeof": Function(.var(0), returning: .text),
106-
"unhex": Function(.text, returning: .blob),
126+
"unhex": Function(
127+
.text,
128+
returning: .blob,
129+
overloads: [Function.Overload(.text, .text, returning: .blob)]
130+
),
107131
"unicode": Function(.text, returning: .integer),
108132
"unlikely": Function(.var(0), returning: .var(0)),
109133
"upper": Function(.text, returning: .text),

Sources/Compiler/Sema/ExprTypeChecker.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ struct ExprTypeChecker {
1616
private(set) var env: Environment
1717
/// The entire database schema
1818
private let schema: Schema
19+
/// Any CTEs available to the expression
20+
private let ctes: [Substring: Table]
1921
/// Any diagnostics that are emitted during compilation
2022
private(set) var diagnostics = Diagnostics()
2123
/// Any table that is used
@@ -27,11 +29,13 @@ struct ExprTypeChecker {
2729
inferenceState: InferenceState,
2830
env: Environment,
2931
schema: Schema,
32+
ctes: [Substring: Table],
3033
pragmas: FeatherPragmas
3134
) {
3235
self.inferenceState = inferenceState
3336
self.env = env
3437
self.schema = schema
38+
self.ctes = ctes
3539
self.pragmas = pragmas
3640
}
3741

@@ -54,6 +58,7 @@ struct ExprTypeChecker {
5458
var typeChecker = StmtTypeChecker(
5559
env: Environment(parent: env),
5660
schema: schema,
61+
ctes: ctes,
5762
inferenceState: inferenceState,
5863
pragmas: pragmas
5964
)

Sources/Compiler/Sema/StmtTypeChecker.swift

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct StmtTypeChecker {
2020
private(set) var schema: Schema
2121
/// Any CTE that was declared with the statement.
2222
/// Keeping these separate from the schema so they don't get passed to the next statement
23-
private(set) var ctes: [Substring: Table] = [:]
23+
private(set) var ctes: [Substring: Table]
2424
/// Any diagnostics that are emitted during compilation
2525
private(set) var diagnostics = Diagnostics()
2626
/// Inferrer for any bind parameter names
@@ -41,11 +41,13 @@ struct StmtTypeChecker {
4141
init(
4242
env: Environment = Environment(),
4343
schema: Schema,
44+
ctes: [Substring: Table] = [:],
4445
inferenceState: InferenceState = InferenceState(),
4546
pragmas: FeatherPragmas
4647
) {
4748
self.env = env
4849
self.schema = schema
50+
self.ctes = ctes
4951
self.inferenceState = inferenceState
5052
self.pragmas = pragmas
5153
}
@@ -67,6 +69,7 @@ struct StmtTypeChecker {
6769
inferenceState: inferenceState,
6870
env: env,
6971
schema: schema,
72+
ctes: ctes,
7073
pragmas: pragmas
7174
)
7275
let type = exprTypeChecker.typeCheck(expr)
@@ -516,6 +519,7 @@ extension StmtTypeChecker {
516519
var index = 0
517520
let secondColumns = secondResult.allColumns.values
518521
return firstResult.mapTypes { type in
522+
let type = inferenceState.solution(for: type)
519523
inferenceState.unify(
520524
type,
521525
with: inferenceState.solution(for: secondColumns[index]),
@@ -690,6 +694,9 @@ extension StmtTypeChecker {
690694
return ResultColumns(columns: resultColumns, table: nil)
691695
}
692696

697+
/// Type checks the beginning of 1 or more CTE declarations.
698+
/// The argument is optional to ease use of the API, since any time it's
699+
/// used it's optionally at the beginning of some statement.
693700
private mutating func typeCheck(with: WithSyntax?) {
694701
guard let with else { return }
695702

@@ -702,40 +709,50 @@ extension StmtTypeChecker {
702709
}
703710
}
704711

712+
/// Type checks the CTE expression. Will return the resultant table
713+
/// representing the CTE.
705714
private mutating func typeCheck(
706715
cte: CommonTableExpressionSyntax,
707716
recursive: Bool
708717
) -> Table {
709718
let cteName = QualifiedName(name: cte.table.value, schema: nil)
710719

711-
if cte.columns.isEmpty {
712-
let resultColumns = typeCheck(select: cte.select)
713-
return Table(name: cteName, columns: resultColumns.allColumns, kind: .cte)
720+
// Recursive CTEs can reference themselves so we need to create a table to
721+
// represent this CTE with all columns as type variables.
722+
let recursiveCte: Table?
723+
if recursive {
724+
recursiveCte = Table(
725+
name: cteName,
726+
columns: cte.columns.reduce(into: [:]) { columns, name in
727+
columns.append(inferenceState.freshTyVar(for: name), for: name.value)
728+
},
729+
kind: .cte
730+
)
731+
732+
ctes[cteName.name] = recursiveCte
714733
} else {
715-
// CTE's can reference themselves so we need to create a table to
716-
// represent this CTE with all columns as type variables.
717-
let thisCte = Table(
718-
name: cteName,
719-
columns: cte.columns.reduce(into: [:]) { columns, name in
720-
columns.append(inferenceState.freshTyVar(for: name), for: name.value)
721-
},
722-
kind: .cte
723-
)
724-
725-
ctes[thisCte.name.name] = thisCte
726-
727-
let resultColumns = typeCheck(select: cte.select)
728-
let columnTypes = resultColumns.allColumns.values
729-
if columnTypes.count != cte.columns.count {
734+
recursiveCte = nil
735+
}
736+
737+
let resultColumns = typeCheck(select: cte.select)
738+
739+
// If the CTE declares columns, make sure the select produces the same number of columns.
740+
if !cte.columns.isEmpty {
741+
if resultColumns.count != cte.columns.count {
730742
diagnostics.add(.init(
731-
"CTE expected \(cte.columns.count) columns, but got \(columnTypes.count)",
743+
"CTE expected \(cte.columns.count) columns, but got \(resultColumns.count)",
732744
at: cte.location
733745
))
734746
}
735-
747+
}
748+
749+
if let recursiveCte {
736750
// Simply return the table, but with each type replaced by its solution so the substitution
737751
// map retains its integrity.
738-
return thisCte.mapTypes { inferenceState.solution(for: $0) }
752+
return recursiveCte.mapTypes { inferenceState.solution(for: $0) }
753+
} else {
754+
// No recursive CTE, create a new table from the result columns.
755+
return Table(name: cteName, columns: resultColumns.allColumns, kind: .cte)
739756
}
740757
}
741758

Tests/CompilerTests/Compiler/CompileCTE.sql

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
This contains some examples from SQLite's own documentation on
3+
https://sqlite.org/lang_with.html
4+
5+
They are labeled as "Outlandish" in the docs title, and it was figured they would
6+
be some great tests to really test out the behaviour of CTEs to
7+
make sure they match SQLite even in the most extreme use cases.
8+
*/
9+
10+
-- CHECK: SIGNATURE
11+
-- CHECK: ...
12+
WITH RECURSIVE
13+
xaxis(x) AS (VALUES(-2.0) UNION ALL SELECT x+0.05 FROM xaxis WHERE x<1.2),
14+
yaxis(y) AS (VALUES(-1.0) UNION ALL SELECT y+0.1 FROM yaxis WHERE y<1.0),
15+
m(iter, cx, cy, x, y) AS (
16+
SELECT 0, x, y, 0.0, 0.0 FROM xaxis, yaxis
17+
UNION ALL
18+
SELECT iter+1, cx, cy, x*x-y*y + cx, 2.0*x*y + cy FROM m
19+
WHERE (x*x + y*y) < 4.0 AND iter<28
20+
),
21+
m2(iter, cx, cy) AS (
22+
SELECT max(iter), cx, cy FROM m GROUP BY cx, cy
23+
),
24+
a(t) AS (
25+
SELECT group_concat( substr(' .+*#', 1+min(iter/7,4), 1), '')
26+
FROM m2 GROUP BY cy
27+
)
28+
SELECT group_concat(rtrim(t),'0a') FROM a;
29+
30+
-- CHECK: SIGNATURE
31+
-- CHECK: ...
32+
WITH RECURSIVE
33+
input(sud) AS (
34+
VALUES('53..7....6..195....98....6.8...6...34..8.3..17...2...6.6....28....419..5....8..79')
35+
),
36+
digits(z, lp) AS (
37+
VALUES('1', 1)
38+
UNION ALL SELECT
39+
CAST(lp+1 AS TEXT), lp+1 FROM digits WHERE lp<9
40+
),
41+
x(s, ind) AS (
42+
SELECT sud, instr(sud, '.') FROM input
43+
UNION ALL
44+
SELECT
45+
substr(s, 1, ind-1) || z || substr(s, ind+1),
46+
instr( substr(s, 1, ind-1) || z || substr(s, ind+1), '.' )
47+
FROM x, digits AS z
48+
WHERE ind>0
49+
AND NOT EXISTS (
50+
SELECT 1
51+
FROM digits AS lp
52+
WHERE z.z = substr(s, ((ind-1)/9)*9 + lp, 1)
53+
OR z.z = substr(s, ((ind-1)%9) + (lp-1)*9 + 1, 1)
54+
OR z.z = substr(s, (((ind-1)/3) % 3) * 3
55+
+ ((ind-1)/27) * 27 + lp
56+
+ ((lp-1) / 3) * 6, 1)
57+
)
58+
)
59+
SELECT s FROM x WHERE ind=0;

Tests/CompilerTests/CompilerTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ class CompilerTests: XCTestCase {
8181
try checkSchema(compile: "CompileIndex")
8282
}
8383

84-
func testCte() throws {
85-
try checkQueries(compile: "CompileCTE", dump: true)
84+
func testOutlandishCte() throws {
85+
try checkQueries(compile: "CompileOutlandishCTE", dump: true)
8686
}
8787
}
8888

Tests/CompilerTests/TypeCheckerTests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ class TypeCheckerTests: XCTestCase {
240240
inferenceState: InferenceState(),
241241
env: scope,
242242
schema: Schema(),
243+
ctes: [:],
243244
pragmas: []
244245
)
245246
var nameInferrer = NameInferrer()

0 commit comments

Comments
 (0)