Skip to content

Commit eb66c0b

Browse files
committed
compound select
1 parent 89dcef7 commit eb66c0b

7 files changed

Lines changed: 258 additions & 60 deletions

File tree

Sources/Compiler/Parse/Parsers.swift

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -781,24 +781,56 @@ enum Parsers {
781781
cteRecursive: Bool,
782782
cte: CommonTableExpressionSyntax?
783783
) throws -> SelectStmtSyntax {
784-
let selects: [SelectCoreSyntax]? = state.current.kind == .select || state.current.kind == .values
785-
? try commaDelimited(state: &state, element: selectCore)
786-
: nil
787-
784+
let selects = try selects(state: &state)
788785
let orderBy = try orderingTerms(state: &state)
789786
let limit = try limit(state: &state)
790787

791788
return SelectStmtSyntax(
792789
id: state.nextId(),
793790
cte: cte.map(Indirect.init),
794791
cteRecursive: cteRecursive,
795-
selects: .init(.single(selects!.first!)), // TODO: Fix this and do it properly
792+
selects: .init(selects),
796793
orderBy: orderBy,
797794
limit: limit,
798795
location: state.location(from: start)
799796
)
800797
}
801798

799+
static func selects(state: inout ParserState) throws -> SelectStmtSyntax.Selects {
800+
let core = try selectCore(state: &state)
801+
802+
return if let op = compoundOperator(state: &state) {
803+
try .compound(core, op, selects(state: &state))
804+
} else {
805+
.single(core)
806+
}
807+
}
808+
809+
static func compoundOperator(state: inout ParserState) -> CompoundOperatorSyntax? {
810+
let start = state.location
811+
let kind: CompoundOperatorSyntax.Kind
812+
switch (state.current.kind, state.peek.kind) {
813+
case (.union, .all): kind = .unionAll
814+
case (.union, _): kind = .union
815+
case (.intersect, _): kind = .intersect
816+
case (.except, _): kind = .except
817+
default: return nil
818+
}
819+
820+
state.skip()
821+
if kind == .unionAll {
822+
state.skip()
823+
}
824+
825+
return CompoundOperatorSyntax(
826+
id: state.nextId(),
827+
kind: kind,
828+
location: state.location(
829+
from: start
830+
)
831+
)
832+
}
833+
802834
static func orderingTerms(state: inout ParserState) throws -> [OrderingTermSyntax] {
803835
guard state.take(if: .order) else { return [] }
804836
state.consume(.by)

Sources/Compiler/Sema/NameInferrer.swift

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -87,49 +87,7 @@ struct NameInferrer {
8787
infer(select: cte.select)
8888
}
8989

90-
switch select.selects.value {
91-
case .single(let s):
92-
switch s {
93-
case .select(let select):
94-
for column in select.columns {
95-
switch column.kind {
96-
case .expr(let e, _):
97-
_ = e.accept(visitor: &self)
98-
default:
99-
break
100-
}
101-
}
102-
103-
switch select.from {
104-
case .join(let join):
105-
_ = infer(tableOrSubquery: join.tableOrSubquery)
106-
case .tableOrSubqueries(let tableOrSubqueries):
107-
for tableOrSubquery in tableOrSubqueries {
108-
_ = infer(tableOrSubquery: tableOrSubquery)
109-
}
110-
case nil:
111-
break
112-
}
113-
114-
if let whereExpr = select.where {
115-
_ = whereExpr.accept(visitor: &self)
116-
}
117-
118-
if let groupBy = select.groupBy {
119-
for expr in groupBy.expressions {
120-
_ = expr.accept(visitor: &self)
121-
}
122-
}
123-
case .values(let groups):
124-
for group in groups {
125-
for value in group {
126-
_ = value.accept(visitor: &self)
127-
}
128-
}
129-
}
130-
case .compound:
131-
fatalError()
132-
}
90+
infer(selects: select.selects.value)
13391

13492
for orderBy in select.orderBy {
13593
_ = orderBy.expr.accept(visitor: &self)
@@ -140,6 +98,61 @@ struct NameInferrer {
14098
}
14199
}
142100

101+
private mutating func infer(selects: SelectStmtSyntax.Selects) {
102+
switch selects {
103+
case let .single(select):
104+
infer(select: select)
105+
case let .compound(first, _, second):
106+
infer(select: first)
107+
infer(selects: second)
108+
}
109+
}
110+
111+
private mutating func infer(select: SelectCoreSyntax) {
112+
switch select {
113+
case .select(let select):
114+
for column in select.columns {
115+
switch column.kind {
116+
case let .expr(e, alias):
117+
let eName = e.accept(visitor: &self)
118+
119+
if let alias {
120+
_ = unify(names: eName, with: .some(alias.identifier.value))
121+
}
122+
default:
123+
break
124+
}
125+
}
126+
127+
switch select.from {
128+
case .join(let join):
129+
_ = infer(tableOrSubquery: join.tableOrSubquery)
130+
case .tableOrSubqueries(let tableOrSubqueries):
131+
for tableOrSubquery in tableOrSubqueries {
132+
_ = infer(tableOrSubquery: tableOrSubquery)
133+
}
134+
case nil:
135+
break
136+
}
137+
138+
if let whereExpr = select.where {
139+
_ = whereExpr.accept(visitor: &self)
140+
}
141+
142+
if let groupBy = select.groupBy {
143+
for expr in groupBy.expressions {
144+
_ = expr.accept(visitor: &self)
145+
}
146+
}
147+
case .values(let groups):
148+
for group in groups {
149+
for value in group {
150+
_ = value.accept(visitor: &self)
151+
}
152+
}
153+
}
154+
}
155+
143156
private mutating func infer(tableOrSubquery: TableOrSubquerySyntax) -> Name {
144157
switch tableOrSubquery.kind {
145158
case .table:

Sources/Compiler/Sema/StmtTypeChecker.swift

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ struct StmtTypeChecker {
9898
let result = action(&inferrer)
9999
diagnostics = inferrer.diagnostics
100100
nameInferrer = inferrer.nameInferrer
101+
inferenceState = inferrer.inferenceState
101102
return result
102103
}
103104
}
@@ -239,16 +240,11 @@ extension StmtTypeChecker {
239240
typeCheck(cte: cte)
240241
}
241242

242-
let resultColumns = switch select.selects.value {
243-
case let .single(selectCore):
244-
typeCheck(
245-
select: selectCore,
246-
at: select.location,
247-
potentialNames: potentialNames
248-
)
249-
case .compound:
250-
fatalError()
251-
}
243+
let resultColumns = typeCheck(
244+
selects: select.selects.value,
245+
at: select.location,
246+
potentialNames: potentialNames
247+
)
252248

253249
for term in select.orderBy {
254250
_ = typeCheck(term.expr)
@@ -261,6 +257,54 @@ extension StmtTypeChecker {
261257
return resultColumns
262258
}
263259

260+
mutating func typeCheck(
261+
selects: SelectStmtSyntax.Selects,
262+
at location: SourceLocation,
263+
potentialNames: [IdentifierSyntax]? = nil
264+
) -> ResultColumns {
265+
switch selects {
266+
case let .single(selectCore):
267+
return typeCheck(
268+
select: selectCore,
269+
at: location,
270+
potentialNames: potentialNames
271+
)
272+
case let .compound(first, op, second):
273+
// SQLite:
274+
// * Does not care about types
275+
// * Uses names of first
276+
// * Cares about # of columns
277+
278+
let firstResult = inNewEnvironment { typeChecker in
279+
typeChecker.typeCheck(
280+
select: first,
281+
at: location,
282+
potentialNames: potentialNames
283+
)
284+
}
285+
286+
let secondResult = inNewEnvironment { typeChecker in
287+
typeChecker.typeCheck(selects: second, at: location)
288+
}
289+
290+
guard firstResult.count == secondResult.count else {
291+
diagnostics.add(.init(
292+
"SELECTs for \(op.kind) do not have the same number of columns (\(firstResult.count) and \(secondResult.count))",
293+
at: op.location
294+
))
295+
return firstResult
296+
}
297+
298+
var index = 0
299+
let secondColumns = secondResult.allColumns.values
300+
return firstResult.mapTypes { type in
301+
inferenceState.unify(type, with: secondColumns[index], at: location)
302+
index += 1
303+
return type
304+
}
305+
}
306+
}
307+
264308
mutating func typeCheck(insert: InsertStmtSyntax) -> ResultColumns {
265309
if let cte = insert.cte {
266310
typeCheck(cte: cte)
@@ -579,6 +623,7 @@ extension StmtTypeChecker {
579623

580624
if let name = alias?.identifier.value ?? names.proposedName {
581625
columns[name] = type
626+
nameInferrer.suggest(name: name, for: names)
582627
} else {
583628
diagnostics.add(.nameRequired(at: expr.location))
584629
}

Sources/Compiler/Syntax/CompoundOperatorSyntax.swift

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,19 @@ struct CompoundOperatorSyntax: Syntax {
1010
let kind: Kind
1111
let location: SourceLocation
1212

13-
enum Kind {
13+
enum Kind: CustomStringConvertible {
1414
case union
1515
case unionAll
1616
case intersect
1717
case except
18+
19+
var description: String {
20+
switch self {
21+
case .union: "UNION"
22+
case .unionAll: "UNION ALL"
23+
case .intersect: "INTERSECT"
24+
case .except: "EXCEPT"
25+
}
26+
}
1827
}
1928
}

Sources/Compiler/Syntax/Statements/SelectStmtSyntax.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ struct SelectStmtSyntax: StmtSyntax {
1616

1717
enum Selects {
1818
case single(SelectCoreSyntax)
19-
indirect case compound(Selects, CompoundOperatorSyntax, SelectCoreSyntax)
19+
indirect case compound(SelectCoreSyntax, CompoundOperatorSyntax, Selects)
2020
}
2121

2222
struct Limit {

Tests/CompilerTests/Compiler/CompileSimpleSelects.sql

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
CREATE TABLE foo (id INTEGER PRIMARY KEY, bar INTEGER AS Bool, baz TEXT NOT NULL);
2+
CREATE TABLE bar (id INTEGER PRIMARY KEY, qux INTEGER AS Bool);
23

34
-- CHECK: SIGNATURE
45
-- CHECK: PARAMETERS
@@ -87,3 +88,66 @@ SELECT foo.*, foo.baz || 'postfix' AS bazWithPostfix FROM foo;
8788
-- CHECK: baz TEXT
8889
-- CHECK: OUTPUT_TABLE foo
8990
SELECT foo.baz AS bazButOnItsOwn, foo.* FROM foo;
91+
92+
-- CHECK: SIGNATURE
93+
-- CHECK: OUTPUT_CHUNKS
94+
-- CHECK: CHUNK
95+
-- CHECK: OUTPUT
96+
-- CHECK: id INTEGER
97+
-- CHECK: bar (INTEGER AS Bool)?
98+
SELECT id, bar FROM foo
99+
UNION
100+
SELECT id, qux FROM bar;
101+
102+
-- CHECK: SIGNATURE
103+
-- CHECK: OUTPUT_CHUNKS
104+
-- CHECK: CHUNK
105+
-- CHECK: OUTPUT
106+
-- CHECK: id INTEGER
107+
-- CHECK: baz TEXT
108+
-- CHECK-ERROR: Unable to unify types 'TEXT' and '(INTEGER AS Bool)?'
109+
SELECT id, baz FROM foo
110+
UNION
111+
SELECT id, qux FROM bar;
112+
113+
-- CHECK: SIGNATURE
114+
-- CHECK: OUTPUT_CHUNKS
115+
-- CHECK: CHUNK
116+
-- CHECK: OUTPUT
117+
-- CHECK: id INTEGER
118+
-- CHECK: bar (INTEGER AS Bool)?
119+
-- CHECK: baz TEXT
120+
-- CHECK-ERROR: SELECTs for UNION do not have the same number of columns (3 and 2)
121+
SELECT id, bar, baz FROM foo
122+
UNION
123+
SELECT id, qux FROM bar;
124+
125+
-- CHECK: SIGNATURE
126+
-- CHECK: PARAMETERS
127+
-- CHECK: PARAMETER
128+
-- CHECK: TYPE (INTEGER AS Bool)?
129+
-- CHECK: INDEX 1
130+
-- CHECK: NAME param
131+
-- CHECK: OUTPUT_CHUNKS
132+
-- CHECK: CHUNK
133+
-- CHECK: OUTPUT
134+
-- CHECK: id INTEGER
135+
-- CHECK: param (INTEGER AS Bool)?
136+
SELECT id, ? AS param FROM foo
137+
UNION
138+
SELECT id, qux FROM bar;
139+
140+
-- CHECK: SIGNATURE
141+
-- CHECK: PARAMETERS
142+
-- CHECK: PARAMETER
143+
-- CHECK: TYPE (INTEGER AS Bool)?
144+
-- CHECK: INDEX 1
145+
-- CHECK: NAME value
146+
-- CHECK: OUTPUT_CHUNKS
147+
-- CHECK: CHUNK
148+
-- CHECK: OUTPUT
149+
-- CHECK: id INTEGER
150+
-- CHECK: bar (INTEGER AS Bool)?
151+
SELECT id, bar FROM foo
152+
UNION
153+
SELECT id, ? AS value FROM bar;

0 commit comments

Comments
 (0)