Skip to content

Commit 84ce190

Browse files
committed
wip
1 parent affde2e commit 84ce190

7 files changed

Lines changed: 83 additions & 25 deletions

File tree

Sources/Compiler/Parse/Parsers.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ enum Parsers {
10881088
location: state.location(from: start)
10891089
)
10901090
case let .identifier(table) where state.peek.kind == .dot && state.peek2.kind == .star:
1091-
let table = IdentifierSyntax(value: table, location: state.current.location)
1091+
let table = IdentifierSyntax(id: state.nextId(), value: table, location: state.current.location)
10921092
state.skip()
10931093
state.consume(.dot)
10941094
state.consume(.star)
@@ -2443,10 +2443,10 @@ enum Parsers {
24432443

24442444
guard case let .identifier(ident) = token.kind else {
24452445
state.diagnostics.add(.init("Expected identifier", at: token.location))
2446-
return IdentifierSyntax(value: "<<error>>", location: token.location)
2446+
return IdentifierSyntax(id: state.nextId(), value: "<<error>>", location: token.location)
24472447
}
24482448

2449-
return IdentifierSyntax(value: ident, location: token.location)
2449+
return IdentifierSyntax(id: state.nextId(), value: ident, location: token.location)
24502450
}
24512451

24522452
/// So this is to handle a weird edge case. SQLite apparently allows keywords
@@ -2457,7 +2457,7 @@ enum Parsers {
24572457
let token = state.take()
24582458

24592459
if case let .identifier(ident) = token.kind {
2460-
return IdentifierSyntax(value: ident, location: token.location)
2460+
return IdentifierSyntax(id: state.nextId(), value: ident, location: token.location)
24612461
}
24622462

24632463
// Since this is kind of edge casey instead of making all
@@ -2471,10 +2471,10 @@ enum Parsers {
24712471
// should not allow
24722472
guard isKeyword else {
24732473
state.diagnostics.add(.init("Expected identifier", at: token.location))
2474-
return IdentifierSyntax(value: "<<error>>", location: token.location)
2474+
return IdentifierSyntax(id: state.nextId(), value: "<<error>>", location: token.location)
24752475
}
24762476

2477-
return IdentifierSyntax(value: rawValue, location: token.location)
2477+
return IdentifierSyntax(id: state.nextId(), value: rawValue, location: token.location)
24782478
}
24792479

24802480
/// https://www.sqlite.org/syntax/numeric-literal.html

Sources/Compiler/Sema/Builtins.swift

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ enum Builtins {
7979
"likelihood": Function(.var(0), .real, returning: .var(0)),
8080
"likely": Function(.var(0), returning: .var(0)),
8181
"lower": Function(.text, returning: .text),
82-
"ltrim": Function(.text, .text, returning: .text),
82+
"ltrim": Function(
83+
.text,
84+
returning: .text,
85+
overloads: [Function.Overload(.text, .text, returning: .text)]
86+
),
8387
"max": Function(.var(0), returning: .var(0), variadic: true),
8488
"min": Function(.var(0), returning: .var(0), variadic: true),
8589
"nullif": Function(.var(0), .var(0), returning: .optional(.var(0))),
@@ -88,7 +92,11 @@ enum Builtins {
8892
"randomblob": Function(.integer, returning: .blob),
8993
"replace": Function(.text, .text, .text, returning: .text),
9094
"round": Function(.real, .integer, returning: .real),
91-
"rtrim": Function(.text, .text, returning: .text),
95+
"rtrim": Function(
96+
.text,
97+
returning: .text,
98+
overloads: [Function.Overload(.text, .text, returning: .text)]
99+
),
92100
"sign": Function(.var(.integer(0)), returning: .integer),
93101
"soundex": Function(.text, returning: .text),
94102
"substr": Function(.text, .integer, .integer, returning: .text),

Sources/Compiler/Sema/StmtTypeChecker.swift

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,11 @@ extension StmtTypeChecker {
516516
var index = 0
517517
let secondColumns = secondResult.allColumns.values
518518
return firstResult.mapTypes { type in
519-
inferenceState.unify(type, with: secondColumns[index], at: location)
519+
inferenceState.unify(
520+
type,
521+
with: inferenceState.solution(for: secondColumns[index]),
522+
at: location
523+
)
520524
index += 1
521525
return type
522526
}
@@ -691,20 +695,36 @@ extension StmtTypeChecker {
691695

692696
for cte in with.ctes {
693697
let table = inNewEnvironment { typeChecker in
694-
typeChecker.typeCheck(cte: cte)
698+
typeChecker.typeCheck(cte: cte, recursive: with.recursive)
695699
}
696700

697701
ctes[cte.table.value] = table
698702
}
699703
}
700704

701-
private mutating func typeCheck(cte: CommonTableExpressionSyntax) -> Table {
702-
let resultColumns = typeCheck(select: cte.select)
703-
704-
let columns: Columns
705+
private mutating func typeCheck(
706+
cte: CommonTableExpressionSyntax,
707+
recursive: Bool
708+
) -> Table {
709+
let cteName = QualifiedName(name: cte.table.value, schema: nil)
710+
705711
if cte.columns.isEmpty {
706-
columns = resultColumns.allColumns
712+
let resultColumns = typeCheck(select: cte.select)
713+
return Table(name: cteName, columns: resultColumns.allColumns, kind: .cte)
707714
} else {
715+
// CTE's can reference themselves so we need to create a table to
716+
// represent this CTE with all columns as type variables.
717+
let thisCte = Table(
718+
name: cteName,
719+
columns: cte.columns.reduce(into: [:]) { columns, name in
720+
columns.append(inferenceState.freshTyVar(for: name), for: name.value)
721+
},
722+
kind: .cte
723+
)
724+
725+
ctes[thisCte.name.name] = thisCte
726+
727+
let resultColumns = typeCheck(select: cte.select)
708728
let columnTypes = resultColumns.allColumns.values
709729
if columnTypes.count != cte.columns.count {
710730
diagnostics.add(.init(
@@ -713,16 +733,10 @@ extension StmtTypeChecker {
713733
))
714734
}
715735

716-
columns = (0 ..< min(columnTypes.count, cte.columns.count))
717-
.reduce(into: [:]) { $0.append(columnTypes[$1], for: cte.columns[$1].value) }
736+
// Simply return the table but getting the solution types so the substitution
737+
// map retains it's integrity.
738+
return thisCte.mapTypes { inferenceState.solution(for: $0) }
718739
}
719-
720-
return Table(
721-
name: QualifiedName(name: cte.table.value, schema: nil),
722-
columns: columns,
723-
primaryKey: [],
724-
kind: .cte
725-
)
726740
}
727741

728742
/// Will infer the core part of the select.

Sources/Compiler/Syntax/IdentifierSyntax.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
// Created by Wes Wickwire on 10/19/24.
66
//
77

8-
struct IdentifierSyntax: Sendable {
8+
struct IdentifierSyntax: Sendable, Syntax {
9+
let id: SyntaxId
910
private(set) var value: Substring
1011
private(set) var location: SourceLocation
1112
}

Sources/Compiler/Table.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ public struct Table: Sendable, Equatable {
2424
case subquery
2525
}
2626

27+
init(
28+
name: QualifiedName,
29+
columns: Columns,
30+
primaryKey: [Substring] = [],
31+
kind: Kind
32+
) {
33+
self.name = name
34+
self.columns = columns
35+
self.primaryKey = primaryKey
36+
self.kind = kind
37+
}
38+
2739
var type: Type {
2840
return .row(.fixed(columns.map(\.value)))
2941
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- CHECK: SIGNATURE
2+
-- CHECK: ...
3+
WITH RECURSIVE
4+
xaxis(x) AS (VALUES(-2.0) UNION ALL SELECT x+0.05 FROM xaxis WHERE x<1.2),
5+
yaxis(y) AS (VALUES(-1.0) UNION ALL SELECT y+0.1 FROM yaxis WHERE y<1.0),
6+
m(iter, cx, cy, x, y) AS (
7+
SELECT 0, x, y, 0.0, 0.0 FROM xaxis, yaxis
8+
UNION ALL
9+
SELECT iter+1, cx, cy, x*x-y*y + cx, 2.0*x*y + cy FROM m
10+
WHERE (x*x + y*y) < 4.0 AND iter<28
11+
),
12+
m2(iter, cx, cy) AS (
13+
SELECT max(iter), cx, cy FROM m GROUP BY cx, cy
14+
),
15+
a(t) AS (
16+
SELECT group_concat( substr(' .+*#', 1+min(iter/7,4), 1), '')
17+
FROM m2 GROUP BY cy
18+
)
19+
SELECT group_concat(rtrim(t),'0a') FROM a;

Tests/CompilerTests/CompilerTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class CompilerTests: XCTestCase {
8080
func testIndex() throws {
8181
try checkSchema(compile: "CompileIndex")
8282
}
83+
84+
func testCte() throws {
85+
try checkQueries(compile: "CompileCTE", dump: true)
86+
}
8387
}
8488

8589
struct CheckSignature: Checkable {

0 commit comments

Comments
 (0)