Skip to content

Commit 347acfc

Browse files
authored
Reuse a single-result subquery across its projected members (#38502)
Fixes #7776
1 parent f2f2673 commit 347acfc

5 files changed

Lines changed: 225 additions & 3 deletions

File tree

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
namespace Microsoft.EntityFrameworkCore.Query.Internal;
5+
6+
public partial class NavigationExpandingExpressionVisitor
7+
{
8+
private NavigationExpansionExpression LiftSingleResultSubqueries(NavigationExpansionExpression source)
9+
{
10+
var selectorBody = source.PendingSelector;
11+
12+
var collector = new SingleResultMemberAccessCollector();
13+
collector.Visit(selectorBody);
14+
15+
foreach (var subquery in collector.Liftable)
16+
{
17+
var collection = BuildSingleResultCollection(subquery);
18+
var innerParameter = Expression.Parameter(collection.Type.GetSequenceType(), "e");
19+
var rewrittenBody = new ReplacingExpressionVisitor([subquery], [innerParameter]).Visit(selectorBody);
20+
21+
// The collection already references the outer element via source.PendingSelector; this parameter is unused.
22+
source = ProcessSelectMany(
23+
source,
24+
Expression.Lambda(collection, Expression.Parameter(source.SourceElementType, "o")),
25+
Expression.Lambda(rewrittenBody, Expression.Parameter(source.SourceElementType, "o"), innerParameter));
26+
selectorBody = source.PendingSelector;
27+
}
28+
29+
return source;
30+
}
31+
32+
private static Expression BuildSingleResultCollection(MethodCallExpression subqueryMethod)
33+
{
34+
var method = subqueryMethod.Method.GetGenericMethodDefinition();
35+
var source = subqueryMethod.Arguments[0];
36+
var elementType = source.Type.GetSequenceType();
37+
38+
Expression Where(Expression c)
39+
=> Expression.Call(QueryableMethods.Where.MakeGenericMethod(elementType), c, subqueryMethod.Arguments[1]);
40+
Expression Skip(Expression c)
41+
=> Expression.Call(QueryableMethods.Skip.MakeGenericMethod(elementType), c, subqueryMethod.Arguments[1]);
42+
Expression Reverse(Expression c)
43+
=> Expression.Call(QueryableMethods.Reverse.MakeGenericMethod(elementType), c);
44+
45+
var oneRow = method switch
46+
{
47+
_ when method == QueryableMethods.FirstWithPredicate || method == QueryableMethods.FirstOrDefaultWithPredicate
48+
|| method == QueryableMethods.SingleWithPredicate || method == QueryableMethods.SingleOrDefaultWithPredicate
49+
=> Where(source),
50+
_ when method == QueryableMethods.LastWithPredicate || method == QueryableMethods.LastOrDefaultWithPredicate
51+
=> Reverse(Where(source)),
52+
_ when method == QueryableMethods.LastWithoutPredicate || method == QueryableMethods.LastOrDefaultWithoutPredicate
53+
=> Reverse(source),
54+
_ when method == QueryableMethods.ElementAt || method == QueryableMethods.ElementAtOrDefault
55+
=> Skip(source),
56+
_ => source
57+
};
58+
59+
var firstRow = Expression.Call(QueryableMethods.Take.MakeGenericMethod(elementType), oneRow, Expression.Constant(1));
60+
61+
return Expression.Call(QueryableMethods.DefaultIfEmptyWithoutArgument.MakeGenericMethod(elementType), firstRow);
62+
}
63+
64+
private sealed class SingleResultMemberAccessCollector : ExpressionVisitor
65+
{
66+
private static readonly HashSet<MethodInfo> SingleResultMethods =
67+
[
68+
QueryableMethods.FirstWithPredicate, QueryableMethods.FirstWithoutPredicate,
69+
QueryableMethods.FirstOrDefaultWithPredicate, QueryableMethods.FirstOrDefaultWithoutPredicate,
70+
QueryableMethods.SingleWithPredicate, QueryableMethods.SingleWithoutPredicate,
71+
QueryableMethods.SingleOrDefaultWithPredicate, QueryableMethods.SingleOrDefaultWithoutPredicate,
72+
QueryableMethods.LastWithPredicate, QueryableMethods.LastWithoutPredicate,
73+
QueryableMethods.LastOrDefaultWithPredicate, QueryableMethods.LastOrDefaultWithoutPredicate,
74+
QueryableMethods.ElementAt, QueryableMethods.ElementAtOrDefault
75+
];
76+
77+
private readonly Dictionary<MethodCallExpression, int> _memberAccessCount = new(ReferenceEqualityComparer.Instance);
78+
79+
public IEnumerable<MethodCallExpression> Liftable
80+
=> _memberAccessCount.Where(e => e.Value > 1).Select(e => e.Key);
81+
82+
protected override Expression VisitMember(MemberExpression memberExpression)
83+
{
84+
if (memberExpression.Expression is MethodCallExpression { Method.IsGenericMethod: true } subquery
85+
&& SingleResultMethods.Contains(subquery.Method.GetGenericMethodDefinition()))
86+
{
87+
if (_memberAccessCount.TryGetValue(subquery, out var count))
88+
{
89+
_memberAccessCount[subquery] = count + 1;
90+
}
91+
else
92+
{
93+
// First encounter: count it and visit its arguments once (a later access duplicates the count)
94+
_memberAccessCount[subquery] = 1;
95+
96+
foreach (var argument in subquery.Arguments)
97+
{
98+
Visit(argument);
99+
}
100+
}
101+
102+
return memberExpression;
103+
}
104+
105+
return base.VisitMember(memberExpression);
106+
}
107+
}
108+
}

src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -671,9 +671,10 @@ when QueryableMethods.IsSumWithSelector(method):
671671

672672
case nameof(Queryable.Select)
673673
when genericMethod == QueryableMethods.Select:
674-
return ProcessSelect(
675-
source,
676-
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
674+
return LiftSingleResultSubqueries(
675+
ProcessSelect(
676+
source,
677+
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()));
677678

678679
case nameof(Queryable.Where)
679680
when genericMethod == QueryableMethods.Where:

test/EFCore.Cosmos.FunctionalTests/Query/NorthwindSelectQueryCosmosTest.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,13 @@ public override async Task Projection_in_a_subquery_should_be_liftable(bool asyn
695695
public override Task Projection_containing_DateTime_subtraction(bool async)
696696
=> Assert.ThrowsAsync<InvalidOperationException>(() => base.Projection_containing_DateTime_subtraction(async));
697697

698+
public override async Task Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(bool async, string method)
699+
{
700+
await AssertTranslationFailed(() => base.Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(async, method));
701+
702+
AssertSql();
703+
}
704+
698705
public override async Task Project_single_element_from_collection_with_OrderBy_Take_and_FirstOrDefault(bool async)
699706
{
700707
// Cosmos client evaluation. Issue #17246.

test/EFCore.Specification.Tests/Query/NorthwindSelectQueryTestBase.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2303,4 +2303,55 @@ from o2 in ss.Set<Order>()
23032303
}).Take(5),
23042304
assertOrder: true,
23052305
elementAsserter: (e, a) => AssertCollection(e.OrderIds, a.OrderIds, elementSorter: ee => ee));
2306+
2307+
public static TheoryData<bool, string> SingleResultMethodData()
2308+
=> new(
2309+
from async in new[] { true, false }
2310+
from method in new[]
2311+
{
2312+
nameof(Queryable.First), nameof(Queryable.FirstOrDefault),
2313+
nameof(Queryable.Single), nameof(Queryable.SingleOrDefault),
2314+
nameof(Queryable.Last), nameof(Queryable.LastOrDefault),
2315+
nameof(Queryable.ElementAt), nameof(Queryable.ElementAtOrDefault)
2316+
}
2317+
select (async, method));
2318+
2319+
[Theory, MemberData(nameof(SingleResultMethodData))]
2320+
public virtual Task Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(bool async, string method)
2321+
=> AssertQuery(
2322+
async,
2323+
ss =>
2324+
{
2325+
var customers = ss.Set<Customer>();
2326+
var orders = ss.Set<Order>().Where(o => o.CustomerID != null);
2327+
return method switch
2328+
{
2329+
nameof(Queryable.First) => from o in orders
2330+
let c = customers.First(c => c.CustomerID == o.CustomerID)
2331+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2332+
nameof(Queryable.FirstOrDefault) => from o in orders
2333+
let c = customers.FirstOrDefault(c => c.CustomerID == o.CustomerID)
2334+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2335+
nameof(Queryable.Single) => from o in orders
2336+
let c = customers.Single(c => c.CustomerID == o.CustomerID)
2337+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2338+
nameof(Queryable.SingleOrDefault) => from o in orders
2339+
let c = customers.SingleOrDefault(c => c.CustomerID == o.CustomerID)
2340+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2341+
nameof(Queryable.Last) => from o in orders
2342+
let c = customers.OrderBy(c => c.CustomerID).Last(c => c.CustomerID == o.CustomerID)
2343+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2344+
nameof(Queryable.LastOrDefault) => from o in orders
2345+
let c = customers.OrderBy(c => c.CustomerID).LastOrDefault(c => c.CustomerID == o.CustomerID)
2346+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2347+
nameof(Queryable.ElementAt) => from o in orders
2348+
let c = customers.Where(c => c.CustomerID == o.CustomerID).OrderBy(c => c.CustomerID).ElementAt(0)
2349+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2350+
nameof(Queryable.ElementAtOrDefault) => from o in orders
2351+
let c = customers.Where(c => c.CustomerID == o.CustomerID).OrderBy(c => c.CustomerID).ElementAtOrDefault(0)
2352+
select new { o.OrderID, c.City, c.Country, c.ContactName },
2353+
_ => throw new InvalidOperationException(method)
2354+
};
2355+
},
2356+
elementSorter: e => e.OrderID);
23062357
}

test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2911,6 +2911,61 @@ ORDER BY [c0].[CustomerID]
29112911
""");
29122912
}
29132913

2914+
public override async Task Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(bool async, string method)
2915+
{
2916+
await base.Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(async, method);
2917+
2918+
AssertSql(
2919+
method switch
2920+
{
2921+
nameof(Queryable.First) or
2922+
nameof(Queryable.FirstOrDefault) or
2923+
nameof(Queryable.Single) or
2924+
nameof(Queryable.SingleOrDefault) => """
2925+
SELECT [o].[OrderID], [c1].[City], [c1].[Country], [c1].[ContactName]
2926+
FROM [Orders] AS [o]
2927+
LEFT JOIN (
2928+
SELECT [c0].[City], [c0].[ContactName], [c0].[Country], [c0].[CustomerID0]
2929+
FROM (
2930+
SELECT [c].[City], [c].[ContactName], [c].[Country], [c].[CustomerID] AS [CustomerID0], ROW_NUMBER() OVER(PARTITION BY [c].[CustomerID] ORDER BY [c].[CustomerID]) AS [row]
2931+
FROM [Customers] AS [c]
2932+
) AS [c0]
2933+
WHERE [c0].[row] <= 1
2934+
) AS [c1] ON [o].[CustomerID] = [c1].[CustomerID0]
2935+
WHERE [o].[CustomerID] IS NOT NULL
2936+
""",
2937+
nameof(Queryable.Last) or
2938+
nameof(Queryable.LastOrDefault) => """
2939+
SELECT [o].[OrderID], [c1].[City], [c1].[Country], [c1].[ContactName]
2940+
FROM [Orders] AS [o]
2941+
LEFT JOIN (
2942+
SELECT [c0].[City], [c0].[ContactName], [c0].[Country], [c0].[CustomerID0]
2943+
FROM (
2944+
SELECT [c].[City], [c].[ContactName], [c].[Country], [c].[CustomerID] AS [CustomerID0], ROW_NUMBER() OVER(PARTITION BY [c].[CustomerID] ORDER BY [c].[CustomerID] DESC) AS [row]
2945+
FROM [Customers] AS [c]
2946+
) AS [c0]
2947+
WHERE [c0].[row] <= 1
2948+
) AS [c1] ON [o].[CustomerID] = [c1].[CustomerID0]
2949+
WHERE [o].[CustomerID] IS NOT NULL
2950+
""",
2951+
nameof(Queryable.ElementAt) or
2952+
nameof(Queryable.ElementAtOrDefault) => """
2953+
SELECT [o].[OrderID], [c1].[City], [c1].[Country], [c1].[ContactName]
2954+
FROM [Orders] AS [o]
2955+
LEFT JOIN (
2956+
SELECT [c0].[City], [c0].[ContactName], [c0].[Country], [c0].[CustomerID0]
2957+
FROM (
2958+
SELECT [c].[City], [c].[ContactName], [c].[Country], [c].[CustomerID] AS [CustomerID0], ROW_NUMBER() OVER(PARTITION BY [c].[CustomerID] ORDER BY [c].[CustomerID]) AS [row]
2959+
FROM [Customers] AS [c]
2960+
) AS [c0]
2961+
WHERE 0 < [c0].[row] AND [c0].[row] <= 1
2962+
) AS [c1] ON [o].[CustomerID] = [c1].[CustomerID0]
2963+
WHERE [o].[CustomerID] IS NOT NULL
2964+
""",
2965+
_ => throw new InvalidOperationException(method)
2966+
});
2967+
}
2968+
29142969
private void AssertSql(params string[] expected)
29152970
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
29162971

0 commit comments

Comments
 (0)