Skip to content

Commit 0aa84fc

Browse files
authored
Merge pull request #180 from koenbeuk/feature/optimize-closures
Feature/optimize closures
2 parents eb618fe + 9a2ea46 commit 0aa84fc

24 files changed

Lines changed: 451 additions & 41 deletions

File tree

src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,22 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
2222
private IEntityType? _entityType;
2323

2424
// Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain)
25-
private static readonly MethodInfo _select =
25+
private readonly static MethodInfo _select =
2626
((MethodCallExpression)((Expression<Func<IQueryable<object>, IQueryable<object>>>)
2727
(q => q.Select(x => x))).Body).Method.GetGenericMethodDefinition();
2828

29-
private static readonly MethodInfo _where =
29+
private readonly static MethodInfo _where =
3030
((MethodCallExpression)((Expression<Func<IQueryable<object>, IQueryable<object>>>)
3131
(q => q.Where(x => true))).Body).Method.GetGenericMethodDefinition();
3232

33+
// Static caches — keyed by CLR type, shared across all instances for the AppDomain lifetime.
34+
// ConditionalWeakTable uses "ephemeron" semantics: the Type key is not kept alive by the
35+
// cache entry, so types from collectible AssemblyLoadContexts can still be unloaded.
36+
private readonly static ConditionalWeakTable<Type, StrongBox<bool>> _compilerGeneratedClosureCache = new();
37+
private readonly static ConditionalWeakTable<Type, PropertyInfo[]> _projectablePropertiesCache = new();
38+
private readonly static ConditionalWeakTable<Type, MethodInfo> _closedSelectCache = new();
39+
private readonly static ConditionalWeakTable<Type, MethodInfo> _closedWhereCache = new();
40+
3341
public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false)
3442
{
3543
_trackingByDefault = trackByDefault;
@@ -84,7 +92,6 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
8492
// // case of a first()
8593
// return obj.MyMap(x => new Obj {});
8694
// }
87-
8895

8996
if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
9097
{
@@ -101,7 +108,8 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
101108
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
102109
// as .Where(where).Select(x => ...).First()
103110

104-
var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
111+
var whereMethod = _closedWhereCache.GetValue(_entityType.ClrType, t => _where.MakeGenericMethod(t));
112+
var where = Expression.Call(null, whereMethod, call.Arguments);
105113
// The call instance is based on the wrong polymorphied method.
106114
var first = call.Method.DeclaringType?.GetMethods()
107115
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
@@ -138,18 +146,27 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
138146
protected override Expression VisitMethodCall(MethodCallExpression node)
139147
{
140148
// Replace MethodGroup arguments with their reflected expressions.
141-
// Note that MethodCallExpression.Update returns the original Expression if argument values have not changed.
142-
node = node.Update(node.Object, node.Arguments.Select(arg => arg switch {
143-
UnaryExpression {
144-
NodeType: ExpressionType.Convert,
145-
Operand: MethodCallExpression {
146-
NodeType: ExpressionType.Call,
147-
Method: { Name: nameof(MethodInfo.CreateDelegate), DeclaringType.Name: nameof(MethodInfo) },
148-
Object: ConstantExpression { Value: MethodInfo methodInfo }
149-
}
150-
} => TryGetReflectedExpression(methodInfo, out var expressionArg) ? expressionArg : arg,
151-
_ => arg
152-
}));
149+
// No-alloc fast-path: scan args without allocating; only copy the array and call
150+
// Update() when a replacement is actually found (method-group arguments are rare).
151+
Expression[]? updatedArgs = null;
152+
for (var i = 0; i < node.Arguments.Count; i++)
153+
{
154+
if (node.Arguments[i] is UnaryExpression {
155+
NodeType: ExpressionType.Convert,
156+
Operand: MethodCallExpression {
157+
NodeType: ExpressionType.Call,
158+
Method: { Name: nameof(MethodInfo.CreateDelegate), DeclaringType.Name: nameof(MethodInfo) },
159+
Object: ConstantExpression { Value: MethodInfo capturedMethodInfo }
160+
}
161+
} && TryGetReflectedExpression(capturedMethodInfo, out var expressionArg))
162+
{
163+
(updatedArgs ??= [.. node.Arguments])[i] = expressionArg;
164+
}
165+
}
166+
if (updatedArgs is not null)
167+
{
168+
node = node.Update(node.Object, updatedArgs);
169+
}
153170

154171
// Get the overriding methodInfo based on te type of the received of this expression
155172
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
@@ -172,7 +189,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
172189
{
173190
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
174191
{
175-
var parameterExpession = reflectedExpression.Parameters[parameterIndex];
192+
var parameterExpression = reflectedExpression.Parameters[parameterIndex];
176193
var mappedArgumentExpression = (parameterIndex, node.Object) switch {
177194
(0, not null) => node.Object,
178195
(_, not null) => node.Arguments[parameterIndex - 1],
@@ -181,7 +198,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
181198

182199
if (mappedArgumentExpression is not null)
183200
{
184-
_expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpession, mappedArgumentExpression);
201+
_expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpression, mappedArgumentExpression);
185202
}
186203
}
187204

@@ -232,19 +249,35 @@ protected override Expression VisitMember(MemberExpression node)
232249
{
233250
// Evaluate captured variables in closures that contain EF queries to inline them into the main query
234251
if (node.Expression is ConstantExpression constant &&
235-
constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) &&
236-
Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true))
252+
IsCompilerGeneratedClosure(constant.Type))
237253
{
238254
try
239255
{
240-
var value = Expression
241-
.Lambda<Func<object>>(Expression.Convert(node, typeof(object)))
242-
.Compile()
243-
.Invoke();
256+
// Cheap type check first: only call GetValue() when the declared type
257+
// could possibly hold an IQueryable at runtime. We use IEnumerable as
258+
// the gate (rather than IQueryable) because a variable legitimately
259+
// declared as IEnumerable<T> may hold an EF Core IQueryable<T> at
260+
// runtime — both interfaces share the same assignability chain.
261+
// FieldType / PropertyType are free property reads on already-
262+
// materialised MemberInfo objects, so this check is cheap.
263+
var memberType = node.Member switch {
264+
FieldInfo field => field.FieldType,
265+
PropertyInfo prop => prop.PropertyType,
266+
_ => null
267+
};
244268

245-
if (value is IQueryable queryable && ReferenceEquals(queryable.Provider, _currentQueryProvider))
269+
if (memberType is not null && typeof(IEnumerable).IsAssignableFrom(memberType))
246270
{
247-
return Visit(queryable.Expression);
271+
var value = node.Member switch {
272+
FieldInfo field => field.GetValue(constant.Value),
273+
PropertyInfo prop => prop.GetValue(constant.Value),
274+
_ => null
275+
};
276+
277+
if (value is IQueryable queryable && ReferenceEquals(queryable.Provider, _currentQueryProvider))
278+
{
279+
return Visit(queryable.Expression);
280+
}
248281
}
249282
}
250283
catch
@@ -275,16 +308,10 @@ PropertyInfo property when nodeExpression is not null
275308
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
276309
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
277310

278-
return base.Visit(
279-
updatedBody
280-
);
281-
}
282-
else
283-
{
284-
return base.Visit(
285-
reflectedExpression.Body
286-
);
311+
return base.Visit(updatedBody);
287312
}
313+
314+
return base.Visit(reflectedExpression.Body);
288315
}
289316

290317
return base.VisitMember(node);
@@ -303,12 +330,13 @@ protected override Expression VisitExtension(Expression node)
303330

304331
private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
305332
{
306-
var projectableProperties = entityType.ClrType.GetProperties()
307-
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
308-
.Where(x => x.CanWrite)
309-
.ToList();
333+
var projectableProperties = _projectablePropertiesCache.GetValue(
334+
entityType.ClrType,
335+
static t => t.GetProperties()
336+
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false) && x.CanWrite)
337+
.ToArray());
310338

311-
if (!projectableProperties.Any())
339+
if (projectableProperties.Length == 0)
312340
{
313341
return node;
314342
}
@@ -327,7 +355,7 @@ private Expression _AddProjectableSelect(Expression node, IEntityType entityType
327355
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));
328356

329357
// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
330-
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
358+
var select = _closedSelectCache.GetValue(entityType.ClrType, t => _select.MakeGenericMethod(t, t));
331359
var xParam = Expression.Parameter(entityType.ClrType);
332360
return Expression.Call(
333361
null,
@@ -354,5 +382,12 @@ private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
354382
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
355383
return base.Visit(updatedBody);
356384
}
385+
386+
private static bool IsCompilerGeneratedClosure(Type type) =>
387+
// TypeAttributes.NestedPrivate is a cheap flag check that rules out most types before
388+
// touching the attribute cache.
389+
type.Attributes.HasFlag(TypeAttributes.NestedPrivate) &&
390+
_compilerGeneratedClosureCache.GetValue(type, static t =>
391+
new StrongBox<bool>(Attribute.IsDefined(t, typeof(CompilerGeneratedAttribute), inherit: true))).Value;
357392
}
358393
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT [e].[Id], [e].[Name]
2+
FROM [Entity] AS [e]
3+
WHERE EXISTS (
4+
SELECT 1
5+
FROM [Entity] AS [e0]
6+
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT [e].[Id], [e].[Name]
2+
FROM [Entity] AS [e]
3+
WHERE EXISTS (
4+
SELECT 1
5+
FROM [Entity] AS [e0]
6+
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT [e].[Id], [e].[Name]
2+
FROM [Entity] AS [e]
3+
WHERE EXISTS (
4+
SELECT 1
5+
FROM [Entity] AS [e0]
6+
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SELECT [e].[Id], (
2+
SELECT COUNT(*)
3+
FROM [Entity] AS [e0]
4+
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
5+
FROM [Entity] AS [e]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SELECT [e].[Id], (
2+
SELECT COUNT(*)
3+
FROM [Entity] AS [e0]
4+
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
5+
FROM [Entity] AS [e]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SELECT [e].[Id], (
2+
SELECT COUNT(*)
3+
FROM [Entity] AS [e0]
4+
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
5+
FROM [Entity] AS [e]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT [e].[Id], [e].[Name]
2+
FROM [Entity] AS [e]
3+
WHERE EXISTS (
4+
SELECT 1
5+
FROM [Entity] AS [e0]
6+
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT [e].[Id], [e].[Name]
2+
FROM [Entity] AS [e]
3+
WHERE EXISTS (
4+
SELECT 1
5+
FROM [Entity] AS [e0]
6+
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
SELECT [e].[Id], [e].[Name]
2+
FROM [Entity] AS [e]
3+
WHERE EXISTS (
4+
SELECT 1
5+
FROM [Entity] AS [e0]
6+
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])

0 commit comments

Comments
 (0)