Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Query.Internal;

public partial class NavigationExpandingExpressionVisitor
{
private NavigationExpansionExpression LiftSingleResultSubqueries(NavigationExpansionExpression source)
{
var selectorBody = source.PendingSelector;

var collector = new SingleResultMemberAccessCollector();
collector.Visit(selectorBody);

foreach (var subquery in collector.Liftable)
{
var collection = BuildSingleResultCollection(subquery);
var innerParameter = Expression.Parameter(collection.Type.GetSequenceType(), "e");
var rewrittenBody = new ReplacingExpressionVisitor([subquery], [innerParameter]).Visit(selectorBody);

// The collection already references the outer element via source.PendingSelector; this parameter is unused.
source = ProcessSelectMany(
source,
Expression.Lambda(collection, Expression.Parameter(source.SourceElementType, "o")),
Expression.Lambda(rewrittenBody, Expression.Parameter(source.SourceElementType, "o"), innerParameter));
selectorBody = source.PendingSelector;
}

return source;
}

private static Expression BuildSingleResultCollection(MethodCallExpression subqueryMethod)
{
var method = subqueryMethod.Method.GetGenericMethodDefinition();
var source = subqueryMethod.Arguments[0];
var elementType = source.Type.GetSequenceType();

Expression Where(Expression c)
=> Expression.Call(QueryableMethods.Where.MakeGenericMethod(elementType), c, subqueryMethod.Arguments[1]);
Expression Skip(Expression c)
=> Expression.Call(QueryableMethods.Skip.MakeGenericMethod(elementType), c, subqueryMethod.Arguments[1]);
Expression Reverse(Expression c)
=> Expression.Call(QueryableMethods.Reverse.MakeGenericMethod(elementType), c);

var oneRow = method switch
{
_ when method == QueryableMethods.FirstWithPredicate || method == QueryableMethods.FirstOrDefaultWithPredicate
|| method == QueryableMethods.SingleWithPredicate || method == QueryableMethods.SingleOrDefaultWithPredicate
=> Where(source),
_ when method == QueryableMethods.LastWithPredicate || method == QueryableMethods.LastOrDefaultWithPredicate
=> Reverse(Where(source)),
_ when method == QueryableMethods.LastWithoutPredicate || method == QueryableMethods.LastOrDefaultWithoutPredicate
=> Reverse(source),
_ when method == QueryableMethods.ElementAt || method == QueryableMethods.ElementAtOrDefault
=> Skip(source),
_ => source
};

var firstRow = Expression.Call(QueryableMethods.Take.MakeGenericMethod(elementType), oneRow, Expression.Constant(1));

return Expression.Call(QueryableMethods.DefaultIfEmptyWithoutArgument.MakeGenericMethod(elementType), firstRow);
}

private sealed class SingleResultMemberAccessCollector : ExpressionVisitor
{
private static readonly HashSet<MethodInfo> SingleResultMethods =
[
QueryableMethods.FirstWithPredicate, QueryableMethods.FirstWithoutPredicate,
QueryableMethods.FirstOrDefaultWithPredicate, QueryableMethods.FirstOrDefaultWithoutPredicate,
QueryableMethods.SingleWithPredicate, QueryableMethods.SingleWithoutPredicate,
QueryableMethods.SingleOrDefaultWithPredicate, QueryableMethods.SingleOrDefaultWithoutPredicate,
QueryableMethods.LastWithPredicate, QueryableMethods.LastWithoutPredicate,
QueryableMethods.LastOrDefaultWithPredicate, QueryableMethods.LastOrDefaultWithoutPredicate,
QueryableMethods.ElementAt, QueryableMethods.ElementAtOrDefault
];

private readonly Dictionary<MethodCallExpression, int> _memberAccessCount = new(ReferenceEqualityComparer.Instance);

public IEnumerable<MethodCallExpression> Liftable
=> _memberAccessCount.Where(e => e.Value > 1).Select(e => e.Key);

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (memberExpression.Expression is MethodCallExpression { Method.IsGenericMethod: true } subquery
&& SingleResultMethods.Contains(subquery.Method.GetGenericMethodDefinition()))
{
if (_memberAccessCount.TryGetValue(subquery, out var count))
{
_memberAccessCount[subquery] = count + 1;
}
else
{
// First encounter: count it and visit its arguments once (a later access duplicates the count)
_memberAccessCount[subquery] = 1;

foreach (var argument in subquery.Arguments)
{
Visit(argument);
}
}

return memberExpression;
}

return base.VisitMember(memberExpression);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,10 @@ when QueryableMethods.IsSumWithSelector(method):

case nameof(Queryable.Select)
when genericMethod == QueryableMethods.Select:
return ProcessSelect(
source,
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
return LiftSingleResultSubqueries(
ProcessSelect(
source,
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()));

case nameof(Queryable.Where)
when genericMethod == QueryableMethods.Where:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,13 @@ public override async Task Projection_in_a_subquery_should_be_liftable(bool asyn
public override Task Projection_containing_DateTime_subtraction(bool async)
=> Assert.ThrowsAsync<InvalidOperationException>(() => base.Projection_containing_DateTime_subtraction(async));

public override async Task Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(bool async, string method)
{
await AssertTranslationFailed(() => base.Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(async, method));

AssertSql();
}

public override async Task Project_single_element_from_collection_with_OrderBy_Take_and_FirstOrDefault(bool async)
{
// Cosmos client evaluation. Issue #17246.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2303,4 +2303,55 @@ from o2 in ss.Set<Order>()
}).Take(5),
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e.OrderIds, a.OrderIds, elementSorter: ee => ee));

public static TheoryData<bool, string> SingleResultMethodData()
=> new(
from async in new[] { true, false }
from method in new[]
{
nameof(Queryable.First), nameof(Queryable.FirstOrDefault),
nameof(Queryable.Single), nameof(Queryable.SingleOrDefault),
nameof(Queryable.Last), nameof(Queryable.LastOrDefault),
nameof(Queryable.ElementAt), nameof(Queryable.ElementAtOrDefault)
}
select (async, method));

[Theory, MemberData(nameof(SingleResultMethodData))]
public virtual Task Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(bool async, string method)
=> AssertQuery(
async,
ss =>
{
var customers = ss.Set<Customer>();
var orders = ss.Set<Order>().Where(o => o.CustomerID != null);
return method switch
{
nameof(Queryable.First) => from o in orders
let c = customers.First(c => c.CustomerID == o.CustomerID)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.FirstOrDefault) => from o in orders
let c = customers.FirstOrDefault(c => c.CustomerID == o.CustomerID)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.Single) => from o in orders
let c = customers.Single(c => c.CustomerID == o.CustomerID)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.SingleOrDefault) => from o in orders
let c = customers.SingleOrDefault(c => c.CustomerID == o.CustomerID)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.Last) => from o in orders
let c = customers.OrderBy(c => c.CustomerID).Last(c => c.CustomerID == o.CustomerID)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.LastOrDefault) => from o in orders
let c = customers.OrderBy(c => c.CustomerID).LastOrDefault(c => c.CustomerID == o.CustomerID)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.ElementAt) => from o in orders
let c = customers.Where(c => c.CustomerID == o.CustomerID).OrderBy(c => c.CustomerID).ElementAt(0)
select new { o.OrderID, c.City, c.Country, c.ContactName },
nameof(Queryable.ElementAtOrDefault) => from o in orders
let c = customers.Where(c => c.CustomerID == o.CustomerID).OrderBy(c => c.CustomerID).ElementAtOrDefault(0)
select new { o.OrderID, c.City, c.Country, c.ContactName },
_ => throw new InvalidOperationException(method)
};
},
elementSorter: e => e.OrderID);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2911,6 +2911,61 @@ ORDER BY [c0].[CustomerID]
""");
}

public override async Task Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(bool async, string method)
{
await base.Multiple_members_of_correlated_single_result_subquery_lift_to_single_join(async, method);

AssertSql(
method switch
{
nameof(Queryable.First) or
nameof(Queryable.FirstOrDefault) or
nameof(Queryable.Single) or
nameof(Queryable.SingleOrDefault) => """
SELECT [o].[OrderID], [c1].[City], [c1].[Country], [c1].[ContactName]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c0].[City], [c0].[ContactName], [c0].[Country], [c0].[CustomerID0]
FROM (
SELECT [c].[City], [c].[ContactName], [c].[Country], [c].[CustomerID] AS [CustomerID0], ROW_NUMBER() OVER(PARTITION BY [c].[CustomerID] ORDER BY [c].[CustomerID]) AS [row]
FROM [Customers] AS [c]
) AS [c0]
WHERE [c0].[row] <= 1
) AS [c1] ON [o].[CustomerID] = [c1].[CustomerID0]
WHERE [o].[CustomerID] IS NOT NULL
""",
nameof(Queryable.Last) or
nameof(Queryable.LastOrDefault) => """
SELECT [o].[OrderID], [c1].[City], [c1].[Country], [c1].[ContactName]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c0].[City], [c0].[ContactName], [c0].[Country], [c0].[CustomerID0]
FROM (
SELECT [c].[City], [c].[ContactName], [c].[Country], [c].[CustomerID] AS [CustomerID0], ROW_NUMBER() OVER(PARTITION BY [c].[CustomerID] ORDER BY [c].[CustomerID] DESC) AS [row]
FROM [Customers] AS [c]
) AS [c0]
WHERE [c0].[row] <= 1
) AS [c1] ON [o].[CustomerID] = [c1].[CustomerID0]
WHERE [o].[CustomerID] IS NOT NULL
""",
nameof(Queryable.ElementAt) or
nameof(Queryable.ElementAtOrDefault) => """
SELECT [o].[OrderID], [c1].[City], [c1].[Country], [c1].[ContactName]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c0].[City], [c0].[ContactName], [c0].[Country], [c0].[CustomerID0]
FROM (
SELECT [c].[City], [c].[ContactName], [c].[Country], [c].[CustomerID] AS [CustomerID0], ROW_NUMBER() OVER(PARTITION BY [c].[CustomerID] ORDER BY [c].[CustomerID]) AS [row]
FROM [Customers] AS [c]
) AS [c0]
WHERE 0 < [c0].[row] AND [c0].[row] <= 1
) AS [c1] ON [o].[CustomerID] = [c1].[CustomerID0]
WHERE [o].[CustomerID] IS NOT NULL
""",
_ => throw new InvalidOperationException(method)
});
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Loading