Skip to content

Commit 5f4f57c

Browse files
committed
Refactor Jet join SQL generation for correct grouping
Refactored JetQuerySqlGenerator to properly group and parenthesize tables and joins, ensuring Jet/Access-compliant SQL when mixing cross joins, inner joins, and subqueries. Added helper methods for join grouping and alias extraction. Updated Northwind test assertions to match new Jet-style SQL output, improving compatibility and correctness. Minor refactoring of nullable numeric and type conversion handling.
1 parent de25c73 commit 5f4f57c

4 files changed

Lines changed: 455 additions & 335 deletions

File tree

src/EFCore.Jet/Query/Sql/Internal/JetQuerySqlGenerator.cs

Lines changed: 207 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -206,108 +206,199 @@ protected override Expression VisitSelect(SelectExpression selectExpression)
206206
return selectExpression;
207207
}
208208

209-
private void VisitJetTables(IReadOnlyList<TableExpressionBase> Tables, bool addfromsql, out List<ColumnExpression> colexp)
209+
private void VisitJetTables(IReadOnlyList<TableExpressionBase> tables, bool addFromSql, out List<ColumnExpression> colexp)
210210
{
211211
colexp = [];
212-
// Implement Jet's non-standard JOIN syntax and DUAL table workaround.
213-
// TODO: This does not properly handle all cases (especially when cross joins are involved).
214-
if (Tables.Any())
212+
if (!tables.Any())
215213
{
216-
if (addfromsql)
217-
{
218-
Sql.AppendLine().Append("FROM ");
219-
}
214+
GeneratePseudoFromClause();
215+
return;
216+
}
220217

221-
const int maxTablesWithoutBrackets = 2;
218+
if (addFromSql)
219+
Sql.AppendLine().Append("FROM ");
222220

223-
var nonCrossTableCount = Tables.Count(t => t is not CrossJoinExpression and not CrossApplyExpression);
221+
if (tables.Any(t => t is CrossJoinExpression or CrossApplyExpression))
222+
VisitJetTablesGrouped(tables, colexp);
223+
else
224+
VisitJetTablesLinear(tables, colexp);
225+
}
224226

225-
Sql.Append(
226-
new string(
227-
'(',
228-
Math.Max(0, nonCrossTableCount - maxTablesWithoutBrackets)));
227+
// No cross joins: existing linear parenthesization algorithm.
228+
private void VisitJetTablesLinear(IReadOnlyList<TableExpressionBase> tables, List<ColumnExpression> colexp)
229+
{
230+
const int maxTablesWithoutBrackets = 2;
231+
var nonCrossTableCount = tables.Count(t => t is not CrossJoinExpression and not CrossApplyExpression);
232+
Sql.Append(new string('(', Math.Max(0, nonCrossTableCount - maxTablesWithoutBrackets)));
229233

230-
var nonCrossTablesSeen = 0;
234+
var nonCrossTablesSeen = 0;
235+
for (var index = 0; index < tables.Count; index++)
236+
{
237+
var tableExpression = tables[index];
238+
var isCrossExpression = tableExpression is CrossJoinExpression or CrossApplyExpression;
231239

232-
for (var index = 0; index < Tables.Count; index++)
240+
if (tableExpression is CrossApplyExpression or OuterApplyExpression)
241+
throw new UnreachableException();
242+
243+
if (!isCrossExpression)
244+
nonCrossTablesSeen++;
245+
246+
if (index > 0)
233247
{
234-
var tableExpression = Tables[index];
248+
if (isCrossExpression)
249+
Sql.Append(",");
250+
else if (nonCrossTablesSeen > maxTablesWithoutBrackets)
251+
Sql.Append(")");
235252

236-
var isApplyExpression = tableExpression is CrossApplyExpression or OuterApplyExpression;
237-
var isCrossExpression = tableExpression is CrossJoinExpression or CrossApplyExpression;
238-
var isNonCrossExpression = !isCrossExpression;
253+
Sql.AppendLine();
254+
}
239255

240-
if (isApplyExpression)
256+
if (tableExpression is InnerJoinExpression innerJoin)
257+
{
258+
var tempcolexp = innerJoin.JoinPredicate switch
241259
{
242-
throw new UnreachableException();
243-
}
260+
SqlBinaryExpression bin => ExtractColumnExpressions(bin),
261+
SqlUnaryExpression unary => ExtractColumnExpressions(unary),
262+
_ => (List<ColumnExpression>)[]
263+
};
244264

245-
if (isNonCrossExpression)
265+
if (tempcolexp.Any(col => col.TableAlias == tables[0].Alias))
246266
{
247-
nonCrossTablesSeen++;
267+
Visit(tableExpression);
248268
}
249-
250-
if (index > 0)
269+
else
251270
{
252-
if (isCrossExpression)
253-
{
254-
Sql.Append(",");
255-
}
256-
else if (nonCrossTablesSeen > maxTablesWithoutBrackets)
257-
{
258-
Sql.Append(")");
259-
}
260-
261-
Sql.AppendLine();
271+
colexp.AddRange(tempcolexp);
272+
Sql.Append("LEFT JOIN ");
273+
Visit(innerJoin.Table);
274+
Sql.Append(" ON ");
275+
Visit(innerJoin.JoinPredicate);
262276
}
277+
}
278+
else
279+
{
280+
Visit(tableExpression);
281+
}
282+
}
283+
}
263284

264-
List<ColumnExpression> tempcolexp;
265-
if (tableExpression is InnerJoinExpression expression)
266-
{
267-
if (expression.JoinPredicate is SqlBinaryExpression binaryJoin)
268-
{
269-
tempcolexp = ExtractColumnExpressions(binaryJoin);
270-
}
271-
else if (expression.JoinPredicate is SqlUnaryExpression unaryJoin)
285+
// Cross joins present: group each primary table with its associated joins so that
286+
// each group is parenthesized independently before being comma-cross-joined.
287+
// Jet rejects mixing comma (cross-join) and explicit JOIN at the same paren level.
288+
private void VisitJetTablesGrouped(IReadOnlyList<TableExpressionBase> tables, List<ColumnExpression> colexp)
289+
{
290+
var groups = new List<(TableExpressionBase Primary, List<PredicateJoinExpressionBase> Joins, HashSet<string> OwnedAliases)>();
291+
292+
foreach (var table in tables)
293+
{
294+
switch (table)
295+
{
296+
case CrossApplyExpression or OuterApplyExpression:
297+
throw new UnreachableException();
298+
case CrossJoinExpression cj:
299+
groups.Add((cj.Table, [], new HashSet<string> { cj.Table.Alias! }));
300+
break;
301+
case PredicateJoinExpressionBase join:
302+
var predicateAliases = ExtractTableAliases(join.JoinPredicate);
303+
var owner = groups.FirstOrDefault(g => g.OwnedAliases.Overlaps(predicateAliases));
304+
if (owner.Primary != null)
272305
{
273-
tempcolexp = ExtractColumnExpressions(unaryJoin);
306+
owner.Joins.Add(join);
307+
owner.OwnedAliases.Add(join.Table.Alias!);
274308
}
275309
else
276310
{
277-
tempcolexp = [];
311+
var last = groups[^1];
312+
last.Joins.Add(join);
313+
last.OwnedAliases.Add(join.Table.Alias!);
278314
}
315+
break;
316+
default:
317+
groups.Add((table, [], new HashSet<string> { table.Alias! }));
318+
break;
319+
}
320+
}
279321

280-
var referencesFirstTable = false;
281-
foreach (var col in tempcolexp)
282-
{
283-
if (col.TableAlias == Tables[0].Alias)
284-
{
285-
referencesFirstTable = true;
286-
break;
287-
}
288-
}
322+
for (var i = 0; i < groups.Count; i++)
323+
{
324+
if (i > 0)
325+
Sql.Append(",").AppendLine();
289326

290-
if (referencesFirstTable)
291-
{
292-
Visit(tableExpression);
293-
continue;
294-
}
327+
var (primary, joins, _) = groups[i];
328+
var groupTableCount = 1 + joins.Count;
295329

296-
colexp.AddRange(tempcolexp);
297-
Sql.Append("LEFT JOIN ");
298-
Visit(expression.Table);
299-
Sql.Append(" ON ");
300-
Visit(expression.JoinPredicate);
301-
}
302-
else
303-
{
304-
Visit(tableExpression);
305-
}
330+
if (joins.Count > 0)
331+
Sql.Append("(");
332+
333+
Sql.Append(new string('(', Math.Max(0, groupTableCount - 2)));
334+
335+
Visit(primary);
336+
337+
var nonSeen = 1;
338+
foreach (var join in joins)
339+
{
340+
nonSeen++;
341+
if (nonSeen > 2)
342+
Sql.Append(")");
343+
344+
Sql.AppendLine();
345+
EmitJoinInGroup(join, primary, colexp);
346+
}
347+
348+
if (joins.Count > 0)
349+
Sql.Append(")");
350+
}
351+
}
352+
353+
private void EmitJoinInGroup(PredicateJoinExpressionBase join, TableExpressionBase groupPrimary, List<ColumnExpression> colexp)
354+
{
355+
if (join is InnerJoinExpression innerJoin)
356+
{
357+
var tempcolexp = innerJoin.JoinPredicate switch
358+
{
359+
SqlBinaryExpression bin => ExtractColumnExpressions(bin),
360+
SqlUnaryExpression unary => ExtractColumnExpressions(unary),
361+
_ => (List<ColumnExpression>)[]
362+
};
363+
364+
if (tempcolexp.Any(col => col.TableAlias == groupPrimary.Alias))
365+
Visit(join);
366+
else
367+
{
368+
colexp.AddRange(tempcolexp);
369+
Sql.Append("LEFT JOIN ");
370+
Visit(innerJoin.Table);
371+
Sql.Append(" ON ");
372+
Visit(innerJoin.JoinPredicate);
306373
}
307374
}
308375
else
309376
{
310-
GeneratePseudoFromClause();
377+
Visit(join);
378+
}
379+
}
380+
381+
private HashSet<string> ExtractTableAliases(SqlExpression expression)
382+
{
383+
var result = new HashSet<string>();
384+
CollectTableAliases(expression, result);
385+
return result;
386+
}
387+
388+
private static void CollectTableAliases(SqlExpression expression, HashSet<string> result)
389+
{
390+
switch (expression)
391+
{
392+
case ColumnExpression col:
393+
result.Add(col.TableAlias);
394+
break;
395+
case SqlBinaryExpression bin:
396+
CollectTableAliases(bin.Left, result);
397+
CollectTableAliases(bin.Right, result);
398+
break;
399+
case SqlUnaryExpression unary:
400+
CollectTableAliases(unary.Operand, result);
401+
break;
311402
}
312403
}
313404

@@ -369,15 +460,25 @@ protected override Expression VisitColumn(ColumnExpression columnExpression)
369460
if (columnExpression.IsNullable && _nullNumerics.Contains(columnExpression.Name) && _convertMappings.TryGetValue(columnExpression.Type.Name, out var function))
370461
{
371462

463+
bool useValCStrCol = false;//columnExpression.Type.Name is nameof(Decimal) or nameof(Int64);
372464
if (parent.TryPeek(out var exp) && exp is SqlBinaryExpression)
373465
{
374466
Sql.Append("IIF(");
375467
base.VisitColumn(columnExpression);
376468
Sql.Append(" IS NULL, NULL, ");
377-
Sql.Append(function);
378-
Sql.Append("(");
379-
base.VisitColumn(columnExpression);
380-
Sql.Append(")");
469+
if (useValCStrCol)
470+
{
471+
Sql.Append("Val(CStr(");
472+
base.VisitColumn(columnExpression);
473+
Sql.Append("))");
474+
}
475+
else
476+
{
477+
Sql.Append(function);
478+
Sql.Append("(");
479+
base.VisitColumn(columnExpression);
480+
Sql.Append(")");
481+
}
381482
Sql.Append(")");
382483
return columnExpression;
383484
}
@@ -501,9 +602,18 @@ protected override Expression VisitSqlParameter(SqlParameterExpression sqlParame
501602
{
502603
if (_convertMappings.TryGetValue(sqlParameterExpression.Type.Name, out var conv))
503604
{
504-
Sql.Append($"{conv}(");
505-
base.VisitSqlParameter(sqlParameterExpression);
506-
Sql.Append(")");
605+
/*if (sqlParameterExpression.Type.Name is nameof(Decimal) or nameof(Int64))
606+
{
607+
Sql.Append("Val(CStr(");
608+
base.VisitSqlParameter(sqlParameterExpression);
609+
Sql.Append("))");
610+
}
611+
else*/
612+
{
613+
Sql.Append($"{conv}(");
614+
base.VisitSqlParameter(sqlParameterExpression);
615+
Sql.Append(")");
616+
}
507617
return sqlParameterExpression;
508618
}
509619
}
@@ -625,6 +735,21 @@ private Expression VisitJetConvertExpression(SqlUnaryExpression convertExpressio
625735
SqlExpression checksqlexp = convertExpression.Operand;
626736
SqlExpression? notnullsqlexp = null;
627737

738+
bool useValCStr = false;//convertExpression.Type.Name is nameof(Decimal) or nameof(Int64);
739+
740+
/*SqlFunctionExpression WrapConvert(SqlExpression inner) =>
741+
useValCStr
742+
? new SqlFunctionExpression("Val",
743+
[new SqlFunctionExpression("CStr", [inner], false, [false], typeof(string), null)],
744+
false, [false], typeMapping.ClrType, null)
745+
: new SqlFunctionExpression(function, [inner], false, [false], typeMapping.ClrType, null);*/
746+
747+
SqlFunctionExpression WrapConvert(SqlExpression inner) =>
748+
useValCStr
749+
? new SqlFunctionExpression("CVar", [inner],
750+
false, [false], typeMapping.ClrType, null)
751+
: new SqlFunctionExpression(function, [inner], false, [false], typeMapping.ClrType, null);
752+
628753
if (convertExpression.TypeMapping is ByteArrayTypeMapping)
629754
{
630755
notnullsqlexp = checksqlexp;
@@ -636,8 +761,7 @@ private Expression VisitJetConvertExpression(SqlUnaryExpression convertExpressio
636761
if (convertExpression.Type == typeof(bool))
637762
{
638763
// bool?bool: no flip needed, CBOOL(x) correctly returns true for any non-zero
639-
notnullsqlexp = new SqlFunctionExpression(function, [convertExpression.Operand],
640-
false, [false], typeMapping.ClrType, null);
764+
notnullsqlexp = WrapConvert(convertExpression.Operand);
641765
}
642766
else
643767
{
@@ -650,14 +774,12 @@ private Expression VisitJetConvertExpression(SqlUnaryExpression convertExpressio
650774
convertExpression.Operand.Type,
651775
convertExpression.Operand.TypeMapping);
652776

653-
notnullsqlexp = new SqlFunctionExpression(function, [flippedOperand],
654-
false, [false], typeMapping.ClrType, null);
777+
notnullsqlexp = WrapConvert(flippedOperand);
655778
}
656779
}
657780
else
658781
{
659-
notnullsqlexp = new SqlFunctionExpression(function, [convertExpression.Operand],
660-
false, [false], typeMapping.ClrType, null);
782+
notnullsqlexp = WrapConvert(convertExpression.Operand);
661783
}
662784
}
663785

test/EFCore.Jet.FunctionalTests/Query/NorthwindJoinQueryJetTest.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,14 +1039,14 @@ public override async Task GroupJoin_subquery_projection_outer_mixed(bool async)
10391039

10401040
AssertSql(
10411041
"""
1042-
SELECT [c].[CustomerID] AS [A], [t].[CustomerID] AS [B], [o0].[CustomerID] AS [C]
1043-
FROM [Customers] AS [c]
1044-
CROSS JOIN (
1045-
SELECT TOP(1) [o].[CustomerID]
1046-
FROM [Orders] AS [o]
1047-
ORDER BY [o].[OrderID]
1048-
) AS [t]
1049-
INNER JOIN [Orders] AS [o0] ON [c].[CustomerID] = [o0].[CustomerID]
1042+
SELECT `c`.`CustomerID` AS `A`, `o0`.`CustomerID` AS `B`, `o1`.`CustomerID` AS `C`
1043+
FROM (`Customers` AS `c`
1044+
INNER JOIN `Orders` AS `o1` ON `c`.`CustomerID` = `o1`.`CustomerID`),
1045+
(
1046+
SELECT TOP 1 `o`.`CustomerID`
1047+
FROM `Orders` AS `o`
1048+
ORDER BY `o`.`OrderID`
1049+
) AS `o0`
10501050
""");
10511051
}
10521052

0 commit comments

Comments
 (0)