Skip to content

Commit 407efcc

Browse files
committed
perform checks for all exprs and check int division
1 parent bb8dcb6 commit 407efcc

8 files changed

Lines changed: 192 additions & 57 deletions

File tree

Sources/Compiler/Environment.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ struct Environment {
141141
func resolve(infix op: Operator) -> Function? {
142142
return switch op {
143143
case .in, .not(.in): Builtins.in
144-
case .plus, .minus, .multiply, .divide, .bitwuseOr,
144+
case .plus, .minus, .multiply, .bitwuseOr,
145145
.bitwiseAnd, .shl, .shr, .mod:
146146
Builtins.arithmetic
147+
case .divide: Builtins.divide
147148
case .eq, .eq2, .neq, .neq2, .lt, .gt, .lte, .gte, .is,
148149
.notNull, .notnull, .like, .isNot, .isDistinctFrom,
149150
.isNotDistinctFrom, .between, .and, .or, .isnull, .not:

Sources/Compiler/Function.swift

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,27 @@
55
// Created by Wes Wickwire on 6/7/25.
66
//
77

8+
/// A function that is callable from SQL
89
struct Function: Sendable {
10+
/// While SQL functions are not explicitly noted as generic we will treat them
11+
/// as such to retain the type as much as possible. Allows us to have one plus
12+
/// function that can do `INT + REAL = REAL` sort of things.
13+
///
14+
/// These values will be inferred upon initialization based off of
15+
/// the `params` and `result`
916
let genericTypes: [TypeVariable]
17+
/// The parameter types these take in
1018
let params: [Type]
19+
/// The return type
1120
let result: Type
21+
/// Any additional overloads for the function
1222
let overloads: [Overload]?
23+
/// Whether or not the function is variadic, meaning the last parameter type
24+
/// can be added on indefinitely.
1325
let variadic: Bool
14-
let check: (@Sendable ([ExpressionSyntax], SourceLocation, inout Diagnostics) -> Void)?
26+
/// A custom check to be performed during type checking. Allows us to put in
27+
/// custom error messages and linting if a function has odd usage.
28+
let check: (@Sendable ([Type], [ExpressionSyntax], SourceLocation, inout Diagnostics) -> Void)?
1529

1630
struct Overload: Sendable {
1731
let params: [Type]
@@ -28,7 +42,7 @@ struct Function: Sendable {
2842
returning result: Type,
2943
variadic: Bool = false,
3044
overloads: [Overload]? = nil,
31-
check: (@Sendable ([ExpressionSyntax], SourceLocation, inout Diagnostics) -> Void)? = nil
45+
check: (@Sendable ([Type], [ExpressionSyntax], SourceLocation, inout Diagnostics) -> Void)? = nil
3246
) {
3347
assert(!(variadic && (overloads?.count ?? 0) > 1), "Cannot have overloads and be variadic")
3448

@@ -37,8 +51,8 @@ struct Function: Sendable {
3751
genericTypes.append(result)
3852
}
3953

40-
// TODO: Remove ghetto distinct
41-
self.genericTypes = Array(Set(genericTypes))
54+
55+
self.genericTypes = genericTypes.distinct()
4256
self.params = params
4357
self.result = result
4458
self.overloads = overloads

Sources/Compiler/Sema/Builtins.swift

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,25 @@ enum Builtins {
1414
static let pos = Function(.var(0), returning: .var(0))
1515
static let between = Function(.var(0), .var(0), .var(0), returning: .integer)
1616
static let arithmetic = Function(.var(0), .var(0), returning: .var(0))
17-
static let divide = Function(.var(0), .var(0), returning: .var(0)) { _, _, _ in
18-
fatalError("check for integer division")
17+
static let divide = Function(.var(0), .var(0), returning: .var(0)) { types, exprs, location, diagnostics in
18+
func isInt(_ type: Type, expr: ExpressionSyntax) -> Bool {
19+
if type.root == .integer || type.root == .int { return true }
20+
if case let .numeric(_, isInt) = expr.literal?.kind { return isInt }
21+
return false
22+
}
23+
24+
// If both sides are integers than the output will always be an integer
25+
// and not be floating point so emit a warning.
26+
guard types.count == 2,
27+
exprs.count == 2,
28+
isInt(types[0], expr: exprs[0]),
29+
isInt(types[1], expr: exprs[1]) else { return }
30+
31+
diagnostics.add(.init(
32+
"Integer division, result will not be floating point. 'CAST' or add '.0'",
33+
level: .warning,
34+
at: location
35+
))
1936
}
2037
static let comparison = Function(.var(0), .var(0), returning: .integer)
2138
static let `in` = Function(.var(0), .row(.unknown(.var(0))), returning: .integer)
@@ -87,7 +104,23 @@ enum Builtins {
87104
// Datetime
88105
"unixepoch": Function(returning: .integer),
89106
"julianday": Function(returning: .real),
90-
"strftime": strftime,
107+
"strftime": Function(
108+
.text,
109+
returning: .text,
110+
variadic: true
111+
) { _, args, location, diagnostics in
112+
guard args.count == 2,
113+
case let .string(first) = args[0].literal?.kind,
114+
case let .string(second) = args[1].literal?.kind,
115+
first == "%s",
116+
second == "now" else { return }
117+
118+
diagnostics.add(.init(
119+
"Function returns the seconds as TEXT, not an INTEGER. Use unixepoch() instead",
120+
level: .warning,
121+
at: location
122+
))
123+
},
91124
"date": Function(.text, returning: .text, variadic: true),
92125
"time": Function(.text, returning: .text, variadic: true),
93126
"datetime": Function(.text, returning: .text, variadic: true),
@@ -96,35 +129,15 @@ enum Builtins {
96129
// Aggregate Functions
97130
"avg": Function(.var(.integer(0)), returning: .var(.integer(0))),
98131
"count": Function(.var(0), returning: .integer),
99-
"group_concat": groupConcat,
132+
"group_concat": Function(
133+
.text,
134+
returning: .text,
135+
overloads: [Function.Overload(.text, .text, returning: .text)]
136+
),
100137
"string_agg": Function(.text, .text, returning: .text),
101138
// 'max' and 'min' are added through the scalar functions and can be reused.
102139
// In the future we may need to separate these if we store them separately
103140
"sum": Function(.var(.integer(0)), returning: .var(.integer(0))),
104141
"total": Function(.var(.integer(0)), returning: .var(.integer(0))),
105142
]
106-
107-
static let groupConcat = Function(
108-
.text,
109-
returning: .text,
110-
overloads: [Function.Overload(.text, returning: .text)]
111-
)
112-
113-
static let strftime = Function(
114-
.text,
115-
returning: .text,
116-
variadic: true
117-
) { args, location, diagnostics in
118-
guard args.count == 2,
119-
case let .string(first) = args[0].literal?.kind,
120-
case let .string(second) = args[1].literal?.kind,
121-
first == "%s",
122-
second == "now" else { return }
123-
124-
diagnostics.add(.init(
125-
"Function returns the seconds as TEXT, not an INTEGER. Use unixepoch() instead",
126-
level: .warning,
127-
at: location
128-
))
129-
}
130143
}

Sources/Compiler/Sema/ExprTypeChecker.swift

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ struct ExprTypeChecker {
9191
return nil
9292
}
9393
}
94+
95+
/// Will instatiate the function's type scheme and perform any checks
96+
/// define by the function.
97+
private mutating func instanteAndCheck(
98+
fn: Function,
99+
argCount: Int,
100+
argTypes: @autoclosure () -> [Type],
101+
argExprs: @autoclosure () -> [ExpressionSyntax],
102+
location: SourceLocation
103+
) -> Type {
104+
let type = inferenceState.instantiate(fn, preferredArgCount: argCount)
105+
106+
guard let check = fn.check else { return type }
107+
check(argTypes(), argExprs(), location, &diagnostics)
108+
return type
109+
}
94110
}
95111

96112
extension ExprTypeChecker: ExprSyntaxVisitor {
@@ -178,12 +194,20 @@ extension ExprTypeChecker: ExprSyntaxVisitor {
178194
}
179195

180196
let tv = inferenceState.freshTyVar(for: expr)
181-
let fnType = inferenceState.instantiate(fn, preferredArgCount: 2)
197+
let fnType = instanteAndCheck(
198+
fn: fn,
199+
argCount: 2,
200+
argTypes: [lTy, rTy],
201+
argExprs: [expr.lhs, expr.rhs],
202+
location: expr.location
203+
)
204+
182205
inferenceState.unify(
183206
fnType,
184207
with: .fn(params: [inferenceState.solution(for: lTy), rTy], ret: tv),
185208
at: expr.location
186209
)
210+
187211
return inferenceState.solution(for: tv)
188212
}
189213

@@ -199,7 +223,13 @@ extension ExprTypeChecker: ExprSyntaxVisitor {
199223
}
200224

201225
let tv = inferenceState.freshTyVar(for: expr)
202-
let fnType = inferenceState.instantiate(fn, preferredArgCount: 1)
226+
let fnType = instanteAndCheck(
227+
fn: fn,
228+
argCount: 1,
229+
argTypes: [lhs],
230+
argExprs: [expr.lhs],
231+
location: expr.location
232+
)
203233
inferenceState.unify(fnType, with: .fn(params: [lhs], ret: tv), at: expr.location)
204234
return inferenceState.solution(for: tv)
205235
}
@@ -212,7 +242,13 @@ extension ExprTypeChecker: ExprSyntaxVisitor {
212242

213243
inferenceState.unify(all: allTypes, at: expr.location)
214244

215-
let between = inferenceState.instantiate(Builtins.between, preferredArgCount: 3)
245+
let between = instanteAndCheck(
246+
fn: Builtins.between,
247+
argCount: 3,
248+
argTypes: allTypes,
249+
argExprs: [expr.value, expr.lower, expr.upper],
250+
location: expr.location
251+
)
216252
inferenceState.unify(between, with: .fn(params: allTypes, ret: .integer), at: expr.location)
217253
return .integer
218254
}
@@ -226,13 +262,15 @@ extension ExprTypeChecker: ExprSyntaxVisitor {
226262
}
227263

228264
let tv = inferenceState.freshTyVar(for: expr)
229-
let fnType = inferenceState.instantiate(fn, preferredArgCount: argTys.count)
230-
inferenceState.unify(fnType, with: .fn(params: argTys, ret: tv), at: expr.location)
265+
let fnType = instanteAndCheck(
266+
fn: fn,
267+
argCount: argTys.count,
268+
argTypes: argTys,
269+
argExprs: expr.args,
270+
location: expr.location
271+
)
231272

232-
// If the function has any additional checks/validation to do perform it.
233-
if let check = fn.check {
234-
check(expr.args, expr.location, &diagnostics)
235-
}
273+
inferenceState.unify(fnType, with: .fn(params: argTys, ret: tv), at: expr.location)
236274

237275
return inferenceState.solution(for: tv)
238276
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//
2+
// Collection+Extensions.swift
3+
// Feather
4+
//
5+
// Created by Wes Wickwire on 6/8/25.
6+
//
7+
8+
extension RangeReplaceableCollection where Self: ExpressibleByArrayLiteral, Element: Hashable {
9+
/// Will return a new collection removing any duplicate items.
10+
/// While retaining the original order.
11+
func distinct() -> Self {
12+
// Cannot have duplicates if there is 1 or less elements.
13+
guard count > 1 else { return self }
14+
15+
// Technically we could skip the set creation for when the count
16+
// is 2 since there is only 1 possible value if there are dupelicates
17+
// but it is not worth it.
18+
var seen: Set<Element> = []
19+
var result: Self = []
20+
21+
// Given that there are probably duplicates, the require capacity
22+
// is probably lower so we can start out at half
23+
let expectedCapacity = count / 2
24+
25+
// Anything less than 2 isnt worth doing since the first append
26+
// will bump the capacity to 2.
27+
if expectedCapacity >= 2 {
28+
result.reserveCapacity(expectedCapacity)
29+
}
30+
31+
for element in self {
32+
guard !seen.contains(element) else { continue }
33+
seen.insert(element)
34+
result.append(element)
35+
}
36+
37+
return result
38+
}
39+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
CREATE TABLE foo (bar INTEGER, baz TEXT);
2+
3+
-- CHECK: SIGNATURE
4+
-- CHECK: OUTPUT_CHUNKS
5+
-- CHECK: CHUNK
6+
-- CHECK: OUTPUT
7+
-- CHECK: bar INTEGER?
8+
-- CHECK: baz TEXT?
9+
-- CHECK: OUTPUT_TABLE foo
10+
-- CHECK: TABLES
11+
-- CHECK: foo
12+
-- CHECK-ERROR: warn: Function returns the seconds as TEXT, not an INTEGER. Use unixepoch() instead
13+
-- CHECK-ERROR: Unable to unify types 'INTEGER?' and 'TEXT'
14+
SELECT * FROM foo WHERE bar = strftime('%s', 'now');
15+
16+
-- CHECK: SIGNATURE
17+
-- CHECK: OUTPUT_CHUNKS
18+
-- CHECK: CHUNK
19+
-- CHECK: OUTPUT
20+
-- CHECK: baz TEXT
21+
-- CHECK: baz TEXT
22+
-- CHECK: TABLES
23+
-- CHECK: foo
24+
SELECT GROUP_CONCAT(baz), GROUP_CONCAT(baz, ',') FROM foo;
25+
26+
-- CHECK: SIGNATURE
27+
-- CHECK: OUTPUT_CHUNKS
28+
-- CHECK: CHUNK
29+
-- CHECK: OUTPUT
30+
-- CHECK: column1 INTEGER
31+
-- CHECK: column2 REAL
32+
-- CHECK: column3 REAL
33+
-- CHECK: bar INTEGER?
34+
-- CHECK: TABLES
35+
-- CHECK: foo
36+
SELECT
37+
-- CHECK-ERROR: warn: Integer division, result will not be floating point. 'CAST' or add '.0'
38+
1 / 2,
39+
1.0 / 2,
40+
1 / 2.0,
41+
-- CHECK-ERROR: warn: Integer division, result will not be floating point. 'CAST' or add '.0'
42+
1 / bar
43+
FROM foo;

Tests/CompilerTests/Compiler/CompileLintChecks.sql

Lines changed: 0 additions & 13 deletions
This file was deleted.

Tests/CompilerTests/CompilerTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class CompilerTests: XCTestCase {
7373
try checkQueries(compile: "CompileTableOrSubqueries")
7474
}
7575

76-
func testLintChecks() throws {
77-
try checkQueries(compile: "CompileLintChecks", dump: true)
76+
func testFunctions() throws {
77+
try checkQueries(compile: "CompileFunctions")
7878
}
7979
}
8080

0 commit comments

Comments
 (0)