forked from CommunityToolkit/Datasync
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathExpressionExtensions.cs
More file actions
237 lines (209 loc) · 10.1 KB
/
ExpressionExtensions.cs
File metadata and controls
237 lines (209 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// Reflection and LINQ require a lot of null manipulation, so nullable reference-type
// annotations are disabled for this file (#nullable disable) to allow that.
#nullable disable
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;
namespace CommunityToolkit.Datasync.Client.Query.Linq;
/// <summary>
/// Extension methods for inspecting and rewriting <see cref="Expression"/> trees,
/// which are used extensively when parsing LINQ queries.
/// </summary>
internal static class ExpressionExtensions
{
    /// <summary>
    /// The generic method definition of <c>Enumerable.Contains&lt;T&gt;(IEnumerable&lt;T&gt;, T)</c>.
    /// </summary>
    private static readonly MethodInfo EnumerableContains;

    /// <summary>
    /// The generic method definition of <c>Enumerable.SequenceEqual&lt;T&gt;(IEnumerable&lt;T&gt;, IEnumerable&lt;T&gt;)</c>.
    /// </summary>
    private static readonly MethodInfo EnumerableSequenceEqual;

    /// <summary>
    /// Resolves the <see cref="Enumerable"/> method definitions that span-based calls are rewritten to
    /// by <see cref="RemoveSpanImplicitCast(Expression)"/>.
    /// </summary>
    static ExpressionExtensions()
    {
        Dictionary<string, List<MethodInfo>> enumerableMethodGroups = typeof(Enumerable)
            .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
            .GroupBy(mi => mi.Name)
            .ToDictionary(g => g.Key, g => g.ToList());

        // Selects the single overload with the given name, generic arity and parameter types. The expected
        // parameter types are built from each candidate's own generic arguments, so open generic signatures
        // such as IEnumerable<T> can be compared directly.
        MethodInfo GetMethod(string name, int genericParameterCount, Func<Type[], Type[]> parameterGenerator)
            => enumerableMethodGroups[name].Single(mi =>
                ((genericParameterCount == 0 && !mi.IsGenericMethod)
                    || (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount))
                && mi.GetParameters().Select(p => p.ParameterType).SequenceEqual(
                    parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : [])));

        EnumerableContains = GetMethod(
            nameof(Enumerable.Contains), 1,
            types => [typeof(IEnumerable<>).MakeGenericType(types[0]), types[0]]);
        EnumerableSequenceEqual = GetMethod(
            nameof(Enumerable.SequenceEqual), 1,
            types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]);
    }

    /// <summary>
    /// Walks the expression and collects every subtree that does not depend on any of the
    /// expression's parameters.
    /// </summary>
    /// <remarks>
    /// A node is considered dependent when it is a <see cref="ExpressionType.Parameter"/> node, when it is a
    /// <see cref="ConstantExpression"/> holding an <see cref="IQueryable"/> (the query root), or when anything
    /// beneath it is dependent. A node is recorded only after its children have been visited.
    /// </remarks>
    /// <param name="expression">The expression to analyze.</param>
    /// <returns>All subtrees that are independent of the expression parameters.</returns>
    internal static List<Expression> FindIndependentSubtrees(this Expression expression)
    {
        List<Expression> subtrees = [];

        // These flags carry state between levels of the recursive visitor. 'dependent' reports whether
        // anything visited so far under the current parent depends on a parameter. 'isMemberInit' tells a
        // child whether the node that just started visiting its children is a MemberInitExpression.
        // NOTE(review): isMemberInit is never restored after recursion, so it keeps the value last assigned
        // by a descendant; it is only reliably true for the first node visited under a MemberInitExpression.
        bool dependent = false;
        bool isMemberInit = false;

        _ = VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
        {
            if (expr is null)
            {
                return expr;
            }

            // Remember the parent's state, then reset it for this node's children.
            bool parentIsDependent = dependent;
            bool parentIsMemberInit = isMemberInit;
            dependent = false;
            isMemberInit = expr is MemberInitExpression;

            _ = recurse(expr);

            if (!dependent)
            {
                // A NewExpression looks independent on its own, but under a MemberInitExpression it cannot be
                // evaluated separately; the MemberInitExpression decides whether the whole thing is dependent.
                if (expr is NewExpression && parentIsMemberInit)
                {
                    return expr;
                }

                // Parameter references and the constant query root are dependent; anything else whose
                // subtree is free of them can be evaluated ahead of time.
                if (expr.NodeType == ExpressionType.Parameter || expr is ConstantExpression { Value: IQueryable })
                {
                    dependent = true;
                }
                else
                {
                    subtrees.Add(expr);
                }
            }

            // Merge with the state accumulated before this node so that, once all siblings have been
            // visited, 'dependent' reports whether any of them depends on a parameter.
            dependent |= parentIsDependent;
            return expr;
        });

        return subtrees;
    }

    /// <summary>
    /// Collects the <see cref="MemberExpression"/> nodes reported by <c>VisitorHelper.VisitMembers</c>
    /// for <paramref name="expression"/>.
    /// </summary>
    /// <param name="expression">The expression to search.</param>
    /// <returns>The member expressions, in the order the visitor reported them.</returns>
    internal static IEnumerable<MemberExpression> GetMemberExpressions(this Expression expression)
    {
        List<MemberExpression> members = [];
        _ = VisitorHelper.VisitMembers(expression, (MemberExpression expr, Func<MemberExpression, Expression> recurse) =>
        {
            // Record the member access, then let the visitor continue past it.
            members.Add(expr);
            return recurse(expr);
        });
        return members;
    }

    /// <summary>
    /// Determines whether the second argument of a method call is a <see cref="LambdaExpression"/>,
    /// after removing any surrounding <see cref="ExpressionType.Quote"/> node.
    /// </summary>
    /// <param name="expression">The method call to inspect; may be <c>null</c>.</param>
    /// <param name="lambdaExpression">The lambda when found; otherwise <c>null</c>.</param>
    /// <returns><c>true</c> if the call has at least two arguments and the second one is a lambda.</returns>
    internal static bool IsValidLambdaExpression(this MethodCallExpression expression, [NotNullWhen(true)] out LambdaExpression lambdaExpression)
    {
        if (expression is { Arguments.Count: >= 2 } && expression.Arguments[1].StripQuote() is LambdaExpression lambda)
        {
            lambdaExpression = lambda;
            return true;
        }

        lambdaExpression = null;
        return false;
    }

    /// <summary>
    /// Evaluates every subtree that does not depend on the expression's parameters and replaces it with a
    /// <see cref="ConstantExpression"/> holding the computed value.
    /// </summary>
    /// <remarks>
    /// Span-based <c>Contains</c>/<c>SequenceEqual</c> calls are first rewritten via
    /// <see cref="RemoveSpanImplicitCast(Expression)"/>. Each independent, non-constant subtree is compiled
    /// and invoked at this point, so any side effects it has occur during this call.
    /// </remarks>
    /// <param name="expression">The expression to evaluate.</param>
    /// <returns>The partially evaluated expression.</returns>
    internal static Expression PartiallyEvaluate(this Expression expression)
    {
        expression = expression.RemoveSpanImplicitCast();

        // A HashSet makes the per-node membership test O(1) instead of an O(n) List scan. Both use the
        // default equality comparer, so exactly the same nodes match.
        HashSet<Expression> independentSubtrees = new(expression.FindIndependentSubtrees());

        return VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
        {
            // Constants are already values; only non-constant independent subtrees need evaluating.
            if (expr is not null && expr.NodeType != ExpressionType.Constant && independentSubtrees.Contains(expr))
            {
                object value = Expression.Lambda(expr).Compile().DynamicInvoke();
                return Expression.Constant(value, expr.Type);
            }

            return recurse(expr);
        });
    }

    /// <summary>
    /// Rewrites generic <see cref="MemoryExtensions"/> <c>Contains</c> and <c>SequenceEqual</c> calls whose span
    /// arguments are implicit conversions (<c>op_Implicit</c> on <see cref="Span{T}"/> or
    /// <see cref="ReadOnlySpan{T}"/>) into the equivalent <see cref="Enumerable"/> calls on the unconverted operands.
    /// </summary>
    /// <remarks>
    /// Presumably these span calls appear when the compiler binds e.g. <c>array.Contains(x)</c> to the span overload.
    /// The overloads that take a trailing comparer are rewritten only when that comparer is a constant <c>null</c>,
    /// which selects the default equality comparer. Any other call is left unchanged.
    /// </remarks>
    /// <param name="expression">The expression to rewrite.</param>
    /// <returns>The rewritten expression.</returns>
    internal static Expression RemoveSpanImplicitCast(this Expression expression)
    {
        return VisitorHelper.VisitAll(expression, (Expression expr, Func<Expression, Expression> recurse) =>
        {
            if (expr is MethodCallExpression { Method: { IsGenericMethod: true } method } methodCall
                && method.DeclaringType == typeof(MemoryExtensions))
            {
                switch (method.Name)
                {
                    case nameof(MemoryExtensions.Contains)
                        when HasDefaultComparer(methodCall, 2)
                        && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out Expression source):
                    {
                        MethodInfo contains = EnumerableContains.MakeGenericMethod(method.GetGenericArguments()[0]);
                        return recurse(Expression.Call(contains, source, methodCall.Arguments[1]));
                    }

                    case nameof(MemoryExtensions.SequenceEqual)
                        when HasDefaultComparer(methodCall, 2)
                        && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out Expression first)
                        && TryUnwrapSpanImplicitCast(methodCall.Arguments[1], out Expression second):
                    {
                        MethodInfo sequenceEqual = EnumerableSequenceEqual.MakeGenericMethod(method.GetGenericArguments()[0]);
                        return recurse(Expression.Call(sequenceEqual, first, second));
                    }
                }
            }

            return recurse(expr);
        });

        // True when the call has exactly 'requiredArgumentCount' arguments, or one extra trailing
        // IEqualityComparer<T> argument that is a constant null (the default comparer).
        static bool HasDefaultComparer(MethodCallExpression call, int requiredArgumentCount)
        {
            if (call.Arguments.Count == requiredArgumentCount)
            {
                return true;
            }

            if (call.Arguments.Count != requiredArgumentCount + 1
                || call.Arguments[requiredArgumentCount] is not ConstantExpression { Value: null })
            {
                return false;
            }

            Type comparerType = call.Method.GetParameters()[requiredArgumentCount].ParameterType;
            return comparerType.IsGenericType && comparerType.GetGenericTypeDefinition() == typeof(IEqualityComparer<>);
        }

        // Unwraps a single-argument op_Implicit call declared on Span<T> or ReadOnlySpan<T>, returning its operand.
        static bool TryUnwrapSpanImplicitCast(Expression candidate, out Expression result)
        {
            if (candidate is MethodCallExpression
                {
                    Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType },
                    Arguments: [Expression unwrapped]
                }
                && implicitCastDeclaringType.GetGenericTypeDefinition() is Type genericTypeDefinition
                && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)))
            {
                result = unwrapped;
                return true;
            }

            result = null;
            return false;
        }
    }

    /// <summary>
    /// Removes the quote from a quoted expression.
    /// </summary>
    /// <param name="expression">The expression to check; must not be <c>null</c>.</param>
    /// <returns>The operand of a <see cref="ExpressionType.Quote"/> node, or <paramref name="expression"/> unchanged otherwise.</returns>
    internal static Expression StripQuote(this Expression expression)
        => expression.NodeType == ExpressionType.Quote ? ((UnaryExpression)expression).Operand : expression;
}