Skip to content

Commit e67d2eb

Browse files
committed
Add aggregate query support
1 parent e2e686e commit e67d2eb

6 files changed

Lines changed: 562 additions & 84 deletions

File tree

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ IoTSharp.Data.JsonDB is a lightweight in-memory relational database engine that
88
## Features
99

1010
- SQL `SELECT`, `INSERT`, `UPDATE`, `DELETE` over JSON arrays and objects
11-
- `WHERE`, `ORDER BY` with `asc`, `desc`, `ascnum`, `descnum`, and `LIMIT`
11+
- `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY` with `asc`, `desc`, `ascnum`, `descnum`, and `LIMIT`
12+
- SQLite-style basic aggregate functions: `COUNT`, `SUM`, `TOTAL`, `AVG`, `MIN`, `MAX`, `GROUP_CONCAT`, `STRING_AGG`
1213
- Arithmetic and boolean expressions with custom function registration
1314
- Path-style field access such as `profile.name` and `metrics.score`
1415
- Standard ADO.NET types: `DbConnection`, `DbCommand`, `DbDataReader`, `DbParameter`, `DbDataAdapter`
@@ -70,6 +71,12 @@ WHERE status = "active"
7071
ORDER BY score DESCNUM
7172
LIMIT 0, 10
7273

74+
SELECT category, COUNT(*) AS count, AVG(score) AS averageScore
75+
FROM input
76+
GROUP BY category
77+
HAVING count > 1
78+
ORDER BY averageScore DESCNUM
79+
7380
INSERT INTO input SET id = 3, name = "Cora"
7481
UPDATE input SET score = score + 1 WHERE id = 3
7582
DELETE FROM input WHERE id = 3

src/IoTSharp.Data.JsonDB/Internal/JsonSqlQueryExecutor.cs

Lines changed: 281 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,56 @@ public static string ExecuteOnNode(
7373
};
7474
}
7575

76-
private static JsonArray ExecuteSelect(SqlSelectStatement statement, SqlExecutionContext context)
77-
{
78-
var rows = FilterAndSortRows(statement.Where, statement.OrderBy, statement.Limit, context);
79-
var result = new JsonArray();
80-
81-
foreach (var row in rows)
82-
{
83-
result.Add(ProjectRow(row, statement.Items, context));
76+
private static JsonArray ExecuteSelect(SqlSelectStatement statement, SqlExecutionContext context)
77+
{
78+
var result = new JsonArray();
79+
80+
if (statement.Items.Any(static item => item.AggregateFunction != SqlAggregateFunction.None) ||
81+
statement.GroupBy.Count > 0)
82+
{
83+
return ExecuteAggregateSelect(statement, context);
84+
}
85+
86+
var rows = FilterAndSortRows(statement.Where, statement.OrderBy, statement.Limit, context);
87+
foreach (var row in rows)
88+
{
89+
result.Add(ProjectRow(row, statement.Items, context));
8490
}
8591

86-
return result;
87-
}
92+
return result;
93+
}
94+
95+
private static JsonArray ExecuteAggregateSelect(SqlSelectStatement statement, SqlExecutionContext context)
96+
{
97+
var rows = FilterAndSortRows(statement.Where, Array.Empty<SqlOrderByItem>(), null, context);
98+
var groups = BuildGroups(rows, statement.GroupBy, context);
99+
var projected = new List<SqlRowContext>(groups.Count);
100+
101+
foreach (var group in groups)
102+
{
103+
var node = ProjectAggregateRow(group.Rows, group.FirstRow, statement.Items, context);
104+
var row = new SqlRowContext(node, -1);
105+
if (statement.Having is null || EvaluateBoolean(statement.Having, row, context))
106+
{
107+
projected.Add(row);
108+
}
109+
}
110+
111+
if (statement.OrderBy.Count > 0)
112+
{
113+
projected.Sort((left, right) => CompareRows(left, right, statement.OrderBy, context));
114+
}
115+
116+
projected = ApplyLimit(projected, statement.Limit);
117+
118+
var result = new JsonArray();
119+
foreach (var row in projected)
120+
{
121+
result.Add(row.Node.DeepClone());
122+
}
123+
124+
return result;
125+
}
88126

89127
private static JsonNode ExecuteUpdate(SqlUpdateStatement statement, SqlExecutionContext context)
90128
{
@@ -178,8 +216,8 @@ private static List<SqlRowContext> EnumerateRows(JsonNode root)
178216
return new List<SqlRowContext> { new(root, -1) };
179217
}
180218

181-
private static JsonNode ProjectRow(SqlRowContext row, IReadOnlyList<SqlSelectItem> items, SqlExecutionContext context)
182-
{
219+
private static JsonNode ProjectRow(SqlRowContext row, IReadOnlyList<SqlSelectItem> items, SqlExecutionContext context)
220+
{
183221
if (items.Count == 1 && items[0].IsWildcard)
184222
{
185223
return row.Node.DeepClone();
@@ -188,8 +226,8 @@ private static JsonNode ProjectRow(SqlRowContext row, IReadOnlyList<SqlSelectIte
188226
JsonObject result = new();
189227
foreach (var item in items)
190228
{
191-
if (item.IsWildcard)
192-
{
229+
if (item.IsWildcard)
230+
{
193231
if (row.Node is not JsonObject sourceObject)
194232
{
195233
continue;
@@ -200,15 +238,225 @@ private static JsonNode ProjectRow(SqlRowContext row, IReadOnlyList<SqlSelectIte
200238
result[property.Key] = property.Value?.DeepClone();
201239
}
202240

203-
continue;
204-
}
205-
206-
var value = SqlExpressionEvaluator.Evaluate(item.Expression!, row, context);
207-
result[item.Alias] = ConvertToJsonNode(value);
208-
}
209-
210-
return result;
211-
}
241+
continue;
242+
}
243+
244+
if (item.AggregateFunction != SqlAggregateFunction.None)
245+
{
246+
throw new InvalidOperationException("Aggregate functions cannot be mixed with row projection.");
247+
}
248+
249+
var value = SqlExpressionEvaluator.Evaluate(item.Expression!, row, context);
250+
result[item.Alias] = ConvertToJsonNode(value);
251+
}
252+
253+
return result;
254+
}
255+
256+
private static JsonNode ProjectAggregateRow(
257+
IReadOnlyList<SqlRowContext> rows,
258+
SqlRowContext? firstRow,
259+
IReadOnlyList<SqlSelectItem> items,
260+
SqlExecutionContext context)
261+
{
262+
JsonObject result = new();
263+
foreach (var item in items)
264+
{
265+
if (item.AggregateFunction == SqlAggregateFunction.None)
266+
{
267+
if (firstRow is null)
268+
{
269+
result[item.Alias] = null;
270+
continue;
271+
}
272+
273+
var groupValue = SqlExpressionEvaluator.Evaluate(item.Expression!, firstRow.Value, context);
274+
result[item.Alias] = ConvertToJsonNode(groupValue);
275+
continue;
276+
}
277+
278+
var value = EvaluateAggregate(rows, item, context);
279+
result[item.Alias] = ConvertToJsonNode(value);
280+
}
281+
282+
return result;
283+
}
284+
285+
private static object? EvaluateAggregate(
286+
IReadOnlyList<SqlRowContext> rows,
287+
SqlSelectItem item,
288+
SqlExecutionContext context)
289+
{
290+
return item.AggregateFunction switch
291+
{
292+
SqlAggregateFunction.Count => CountRows(rows, item.AggregateArguments.FirstOrDefault(), context),
293+
SqlAggregateFunction.Sum => SumRows(rows, item.AggregateArguments.FirstOrDefault(), context, nullWhenEmpty: true),
294+
SqlAggregateFunction.Total => SumRows(rows, item.AggregateArguments.FirstOrDefault(), context, nullWhenEmpty: false),
295+
SqlAggregateFunction.Avg => AverageRows(rows, item.AggregateArguments.FirstOrDefault(), context),
296+
SqlAggregateFunction.Min => MinMaxRows(rows, item.AggregateArguments.FirstOrDefault(), context, findMax: false),
297+
SqlAggregateFunction.Max => MinMaxRows(rows, item.AggregateArguments.FirstOrDefault(), context, findMax: true),
298+
SqlAggregateFunction.GroupConcat => ConcatRows(rows, item.AggregateArguments, context),
299+
SqlAggregateFunction.StringAgg => ConcatRows(rows, item.AggregateArguments, context),
300+
_ => throw new InvalidOperationException("Unsupported aggregate function.")
301+
};
302+
}
303+
304+
private static long CountRows(
305+
IReadOnlyList<SqlRowContext> rows,
306+
SqlExpression? expression,
307+
SqlExecutionContext context)
308+
{
309+
if (expression is null)
310+
{
311+
return rows.Count;
312+
}
313+
314+
return rows.LongCount(row => SqlExpressionEvaluator.Evaluate(expression, row, context) is not null);
315+
}
316+
317+
private static object? SumRows(
318+
IReadOnlyList<SqlRowContext> rows,
319+
SqlExpression? expression,
320+
SqlExecutionContext context,
321+
bool nullWhenEmpty)
322+
{
323+
decimal total = 0;
324+
var count = 0;
325+
foreach (var value in EvaluateAggregateValues(rows, expression, context))
326+
{
327+
if (TryConvertToDecimal(value, out var number))
328+
{
329+
total += number;
330+
count++;
331+
}
332+
}
333+
334+
return count == 0 && nullWhenEmpty ? null : total;
335+
}
336+
337+
private static object? AverageRows(
338+
IReadOnlyList<SqlRowContext> rows,
339+
SqlExpression? expression,
340+
SqlExecutionContext context)
341+
{
342+
decimal total = 0;
343+
var count = 0;
344+
foreach (var value in EvaluateAggregateValues(rows, expression, context))
345+
{
346+
if (TryConvertToDecimal(value, out var number))
347+
{
348+
total += number;
349+
count++;
350+
}
351+
}
352+
353+
return count == 0 ? null : total / count;
354+
}
355+
356+
private static object? MinMaxRows(
357+
IReadOnlyList<SqlRowContext> rows,
358+
SqlExpression? expression,
359+
SqlExecutionContext context,
360+
bool findMax)
361+
{
362+
object? best = null;
363+
var hasBest = false;
364+
foreach (var value in EvaluateAggregateValues(rows, expression, context))
365+
{
366+
if (value is null)
367+
{
368+
continue;
369+
}
370+
371+
if (!hasBest)
372+
{
373+
best = value;
374+
hasBest = true;
375+
continue;
376+
}
377+
378+
var comparison = CompareValues(value, best, numericOnly: false);
379+
if (findMax ? comparison > 0 : comparison < 0)
380+
{
381+
best = value;
382+
}
383+
}
384+
385+
return hasBest ? best : null;
386+
}
387+
388+
private static object? ConcatRows(
389+
IReadOnlyList<SqlRowContext> rows,
390+
IReadOnlyList<SqlExpression> arguments,
391+
SqlExecutionContext context)
392+
{
393+
if (arguments.Count == 0)
394+
{
395+
return null;
396+
}
397+
398+
var separator = arguments.Count > 1
399+
? Convert.ToString(SqlExpressionEvaluator.Evaluate(arguments[1], rows.FirstOrDefault(), context), CultureInfo.InvariantCulture) ?? ","
400+
: ",";
401+
var values = EvaluateAggregateValues(rows, arguments[0], context)
402+
.Where(value => value is not null)
403+
.Select(value => Convert.ToString(value, CultureInfo.InvariantCulture))
404+
.Where(value => !string.IsNullOrEmpty(value))
405+
.ToArray();
406+
407+
return values.Length == 0 ? null : string.Join(separator, values);
408+
}
409+
410+
private static IEnumerable<object?> EvaluateAggregateValues(
411+
IReadOnlyList<SqlRowContext> rows,
412+
SqlExpression? expression,
413+
SqlExecutionContext context)
414+
{
415+
if (expression is null)
416+
{
417+
return rows.Select(static row => (object?)row.Node);
418+
}
419+
420+
return rows.Select(row => SqlExpressionEvaluator.Evaluate(expression, row, context));
421+
}
422+
423+
private static List<SqlGroupContext> BuildGroups(
424+
IReadOnlyList<SqlRowContext> rows,
425+
IReadOnlyList<SqlExpression> groupBy,
426+
SqlExecutionContext context)
427+
{
428+
if (groupBy.Count == 0)
429+
{
430+
return [new SqlGroupContext(string.Empty, rows.ToList(), rows.Count == 0 ? null : rows[0])];
431+
}
432+
433+
var groups = new Dictionary<string, SqlGroupContext>(StringComparer.Ordinal);
434+
foreach (var row in rows)
435+
{
436+
var key = BuildGroupKey(row, groupBy, context);
437+
if (!groups.TryGetValue(key, out var group))
438+
{
439+
group = new SqlGroupContext(key, [], row);
440+
groups.Add(key, group);
441+
}
442+
443+
group.Rows.Add(row);
444+
}
445+
446+
return groups.Values.ToList();
447+
}
448+
449+
private static string BuildGroupKey(
450+
SqlRowContext row,
451+
IReadOnlyList<SqlExpression> groupBy,
452+
SqlExecutionContext context)
453+
{
454+
return string.Join('\u001f', groupBy.Select(expression =>
455+
{
456+
var value = SqlExpressionEvaluator.Evaluate(expression, row, context);
457+
return value is null ? "\u0000" : ToComparableString(value);
458+
}));
459+
}
212460

213461
private static List<SqlRowContext> ApplyLimit(List<SqlRowContext> rows, SqlLimit? limit)
214462
{
@@ -490,13 +738,18 @@ internal static IEnumerable<string> SplitPath(string path)
490738
};
491739
}
492740

493-
private static string ToComparableString(object value)
494-
{
741+
private static string ToComparableString(object value)
742+
{
495743
return value switch
496744
{
497745
JsonNode jsonNode => jsonNode.ToJsonString(),
498746
_ => Convert.ToString(value, CultureInfo.InvariantCulture) ?? string.Empty
499747
};
500-
}
501-
}
502-
}
748+
}
749+
750+
private sealed record SqlGroupContext(
751+
string Key,
752+
List<SqlRowContext> Rows,
753+
SqlRowContext? FirstRow);
754+
}
755+
}

0 commit comments

Comments
 (0)