Skip to content

Commit 75eb978

Browse files
fix: add defensive reflection checks with ReflectionMethodCache (#57)
- Centralize reflection method lookups in ReflectionMethodCache for DRY - Cache method lookups for performance - Add descriptive error messages when lookups fail (includes issue URL) - Add 12 regression tests
1 parent 96e5012 commit 75eb978

7 files changed

Lines changed: 402 additions & 107 deletions

File tree

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
using JsonApiToolkit.Helpers;
2+
3+
namespace JsonApiToolkit.Tests.Helpers;
4+
5+
public class ReflectionMethodCacheTests
6+
{
7+
[Fact]
8+
public void GetEnumerableAnyWithPredicate_ReturnsCorrectMethod()
9+
{
10+
// Act
11+
var method = ReflectionMethodCache.GetEnumerableAnyWithPredicate(typeof(int));
12+
13+
// Assert
14+
Assert.NotNull(method);
15+
Assert.Equal("Any", method.Name);
16+
Assert.Equal(2, method.GetParameters().Length);
17+
Assert.True(method.IsGenericMethod);
18+
}
19+
20+
[Fact]
21+
public void GetEnumerableAnyWithPredicate_IsCached()
22+
{
23+
// Act
24+
var method1 = ReflectionMethodCache.GetEnumerableAnyWithPredicate(typeof(int));
25+
var method2 = ReflectionMethodCache.GetEnumerableAnyWithPredicate(typeof(int));
26+
27+
// Assert - both should be the same instance (cached base method, different generic instantiation)
28+
Assert.Equal(method1, method2);
29+
}
30+
31+
[Fact]
32+
public void GetEnumerableContains_ReturnsCorrectMethod()
33+
{
34+
// Act
35+
var method = ReflectionMethodCache.GetEnumerableContains(typeof(string));
36+
37+
// Assert
38+
Assert.NotNull(method);
39+
Assert.Equal("Contains", method.Name);
40+
Assert.Equal(2, method.GetParameters().Length);
41+
Assert.True(method.IsGenericMethod);
42+
}
43+
44+
[Fact]
45+
public void GetEnumerableWhere_ReturnsCorrectMethod()
46+
{
47+
// Act
48+
var method = ReflectionMethodCache.GetEnumerableWhere(typeof(int));
49+
50+
// Assert
51+
Assert.NotNull(method);
52+
Assert.Equal("Where", method.Name);
53+
Assert.Equal(2, method.GetParameters().Length);
54+
Assert.True(method.IsGenericMethod);
55+
}
56+
57+
[Theory]
58+
[InlineData("OrderBy")]
59+
[InlineData("OrderByDescending")]
60+
[InlineData("ThenBy")]
61+
[InlineData("ThenByDescending")]
62+
public void GetQueryableOrderingMethod_ReturnsCorrectMethod(string methodName)
63+
{
64+
// Act
65+
var method = ReflectionMethodCache.GetQueryableOrderingMethod(
66+
methodName,
67+
typeof(TestEntity),
68+
typeof(string)
69+
);
70+
71+
// Assert
72+
Assert.NotNull(method);
73+
Assert.Equal(methodName, method.Name);
74+
Assert.Equal(2, method.GetParameters().Length);
75+
Assert.True(method.IsGenericMethod);
76+
}
77+
78+
[Fact]
79+
public void GetQueryableOrderingMethod_WithInvalidMethodName_ThrowsInvalidOperationException()
80+
{
81+
// Act & Assert
82+
var ex = Assert.Throws<InvalidOperationException>(() =>
83+
ReflectionMethodCache.GetQueryableOrderingMethod(
84+
"NonExistentMethod",
85+
typeof(TestEntity),
86+
typeof(string)
87+
)
88+
);
89+
90+
Assert.Contains("Could not find Queryable.NonExistentMethod", ex.Message);
91+
Assert.Contains("report this issue", ex.Message);
92+
}
93+
94+
[Fact]
95+
public void GetEfCoreIncludeMethod_ReturnsCorrectMethod()
96+
{
97+
// Act
98+
var method = ReflectionMethodCache.GetEfCoreIncludeMethod(
99+
typeof(TestEntity),
100+
typeof(string)
101+
);
102+
103+
// Assert
104+
Assert.NotNull(method);
105+
Assert.Equal("Include", method.Name);
106+
Assert.Equal(2, method.GetParameters().Length);
107+
Assert.True(method.IsGenericMethod);
108+
}
109+
110+
[Fact]
111+
public void GetEfCoreThenIncludeMethod_ForCollectionNavigation_ReturnsCorrectMethod()
112+
{
113+
// Act
114+
var method = ReflectionMethodCache.GetEfCoreThenIncludeMethod(
115+
isPreviousCollection: true,
116+
entityType: typeof(TestEntity),
117+
previousPropertyType: typeof(TestRelated),
118+
newPropertyType: typeof(string)
119+
);
120+
121+
// Assert
122+
Assert.NotNull(method);
123+
Assert.Equal("ThenInclude", method.Name);
124+
Assert.True(method.IsGenericMethod);
125+
}
126+
127+
[Fact]
128+
public void GetEfCoreThenIncludeMethod_ForReferenceNavigation_ReturnsCorrectMethod()
129+
{
130+
// Act
131+
var method = ReflectionMethodCache.GetEfCoreThenIncludeMethod(
132+
isPreviousCollection: false,
133+
entityType: typeof(TestEntity),
134+
previousPropertyType: typeof(TestRelated),
135+
newPropertyType: typeof(string)
136+
);
137+
138+
// Assert
139+
Assert.NotNull(method);
140+
Assert.Equal("ThenInclude", method.Name);
141+
Assert.True(method.IsGenericMethod);
142+
}
143+
144+
private class TestEntity
145+
{
146+
public int Id { get; set; }
147+
public string Name { get; set; } = string.Empty;
148+
}
149+
150+
private class TestRelated
151+
{
152+
public int Id { get; set; }
153+
}
154+
}

JsonApiToolkit/Extensions/Querying/Filtering/NestedPropertyNavigator.cs

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Linq.Expressions;
22
using System.Reflection;
33
using System.Text.RegularExpressions;
4+
using JsonApiToolkit.Helpers;
45
using JsonApiToolkit.Models.Querying.Filtering;
56
using Microsoft.Extensions.Logging;
67

@@ -219,14 +220,7 @@ private static string SanitizeForLog(string? value)
219220
LambdaExpression predicate = Expression.Lambda(innerExpression, itemParam);
220221

221222
// Get the Enumerable.Any<T>(IEnumerable<T>, Func<T, bool>) method
222-
MethodInfo anyMethod = typeof(Enumerable)
223-
.GetMethods()
224-
.First(m =>
225-
m.Name == "Any"
226-
&& m.GetParameters().Length == 2
227-
&& m.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Func<,>)
228-
)
229-
.MakeGenericMethod(elementType);
223+
MethodInfo anyMethod = ReflectionMethodCache.GetEnumerableAnyWithPredicate(elementType);
230224

231225
// Build: collection.Any(item => predicate)
232226
return Expression.Call(anyMethod, collectionAccess, predicate);
@@ -274,15 +268,9 @@ private static string SanitizeForLog(string? value)
274268

275269
LambdaExpression predicate = Expression.Lambda(containsCall, itemParam);
276270

277-
MethodInfo anyMethod = typeof(Enumerable)
278-
.GetMethods()
279-
.First(m =>
280-
m.Name == "Any"
281-
&& m.GetParameters().Length == 2
282-
&& m.GetParameters()[1].ParameterType.GetGenericTypeDefinition()
283-
== typeof(Func<,>)
284-
)
285-
.MakeGenericMethod(elementType);
271+
MethodInfo anyMethod = ReflectionMethodCache.GetEnumerableAnyWithPredicate(
272+
elementType
273+
);
286274

287275
return Expression.Call(anyMethod, collectionAccess, predicate);
288276
}
@@ -300,10 +288,9 @@ private static string SanitizeForLog(string? value)
300288
}
301289

302290
// Get Contains method on IEnumerable<T> (via Enumerable.Contains)
303-
MethodInfo containsMethodInfo = typeof(Enumerable)
304-
.GetMethods()
305-
.First(m => m.Name == "Contains" && m.GetParameters().Length == 2)
306-
.MakeGenericMethod(elementType);
291+
MethodInfo containsMethodInfo = ReflectionMethodCache.GetEnumerableContains(
292+
elementType
293+
);
307294

308295
return Expression.Call(
309296
containsMethodInfo,
@@ -326,10 +313,9 @@ private static string SanitizeForLog(string? value)
326313
return null;
327314
}
328315

329-
MethodInfo containsMethodInfo = typeof(Enumerable)
330-
.GetMethods()
331-
.First(m => m.Name == "Contains" && m.GetParameters().Length == 2)
332-
.MakeGenericMethod(elementType);
316+
MethodInfo containsMethodInfo = ReflectionMethodCache.GetEnumerableContains(
317+
elementType
318+
);
333319

334320
return Expression.Not(
335321
Expression.Call(

JsonApiToolkit/Extensions/Querying/Handlers/SortingHandler.cs

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Linq.Expressions;
22
using System.Reflection;
3+
using JsonApiToolkit.Helpers;
34
using JsonApiToolkit.Models.Querying;
45
using Microsoft.Extensions.Logging;
56

@@ -46,31 +47,23 @@ public static IQueryable<T> ApplySorting<T>(
4647
if (i == 0)
4748
{
4849
methodName = sortParam.IsDescending ? "OrderByDescending" : "OrderBy";
49-
orderedQuery = (IOrderedQueryable<T>?)
50-
typeof(Queryable)
51-
.GetMethods()
52-
.Single(method =>
53-
method.Name == methodName
54-
&& method.IsGenericMethodDefinition
55-
&& method.GetParameters().Length == 2
56-
)
57-
.MakeGenericMethod(entityType, property.PropertyType)
58-
.Invoke(null, [query, lambda]);
50+
var orderMethod = ReflectionMethodCache.GetQueryableOrderingMethod(
51+
methodName,
52+
entityType,
53+
property.PropertyType
54+
);
55+
orderedQuery = (IOrderedQueryable<T>?)orderMethod.Invoke(null, [query, lambda]);
5956
}
6057
else
6158
{
6259
methodName = sortParam.IsDescending ? "ThenByDescending" : "ThenBy";
63-
60+
var thenMethod = ReflectionMethodCache.GetQueryableOrderingMethod(
61+
methodName,
62+
entityType,
63+
property.PropertyType
64+
);
6465
orderedQuery = (IOrderedQueryable<T>?)
65-
typeof(Queryable)
66-
.GetMethods()
67-
.Single(method =>
68-
method.Name == methodName
69-
&& method.IsGenericMethodDefinition
70-
&& method.GetParameters().Length == 2
71-
)
72-
.MakeGenericMethod(entityType, property.PropertyType)
73-
.Invoke(null, [orderedQuery, lambda]);
66+
thenMethod.Invoke(null, [orderedQuery, lambda]);
7467
}
7568
}
7669

JsonApiToolkit/Extensions/Querying/Includes/EfCoreIncludeExpressions.cs

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Linq.Expressions;
22
using System.Reflection;
3+
using JsonApiToolkit.Helpers;
34
using Microsoft.EntityFrameworkCore;
45

56
namespace JsonApiToolkit.Extensions.Querying;
@@ -13,42 +14,12 @@ internal static MethodInfo GetThenIncludeMethod(
1314
Type newPropertyType
1415
)
1516
{
16-
var thenIncludeMethods = typeof(EntityFrameworkQueryableExtensions)
17-
.GetMethods()
18-
.Where(m => m.Name == "ThenInclude" && m.GetGenericArguments().Length == 3)
19-
.ToList();
20-
21-
foreach (var method in thenIncludeMethods)
22-
{
23-
var parameters = method.GetParameters();
24-
if (parameters.Length != 2)
25-
continue;
26-
27-
var firstParamType = parameters[0].ParameterType;
28-
if (
29-
!firstParamType.IsGenericType
30-
|| firstParamType.GetGenericTypeDefinition().Name != "IIncludableQueryable`2"
31-
)
32-
continue;
33-
34-
var genericArgs = firstParamType.GetGenericArguments();
35-
if (genericArgs.Length != 2)
36-
continue;
37-
38-
var secondGenericArg = genericArgs[1];
39-
40-
bool isCollectionOverload =
41-
secondGenericArg.IsGenericType
42-
&& secondGenericArg.GetGenericTypeDefinition() == typeof(IEnumerable<>);
43-
44-
if (isCollectionOverload == isPreviousCollection)
45-
return method.MakeGenericMethod(entityType, previousPropertyType, newPropertyType);
46-
}
47-
48-
return typeof(EntityFrameworkQueryableExtensions)
49-
.GetMethods()
50-
.First(m => m.Name == "ThenInclude" && m.GetGenericArguments().Length == 3)
51-
.MakeGenericMethod(entityType, previousPropertyType, newPropertyType);
17+
return ReflectionMethodCache.GetEfCoreThenIncludeMethod(
18+
isPreviousCollection,
19+
entityType,
20+
previousPropertyType,
21+
newPropertyType
22+
);
5223
}
5324

5425
internal static IQueryable<T> ApplyIncludeExpression<T>(
@@ -65,15 +36,7 @@ internal static IQueryable<T> ApplyIncludeExpression<T>(
6536
{
6637
var returnType = lambdaType.GetGenericArguments()[1];
6738

68-
var includeMethod = typeof(EntityFrameworkQueryableExtensions)
69-
.GetMethods()
70-
.First(m =>
71-
m.Name == "Include"
72-
&& m.GetParameters().Length == 2
73-
&& m.GetParameters()[1].ParameterType.GetGenericTypeDefinition()
74-
== typeof(Expression<>)
75-
)
76-
.MakeGenericMethod(typeof(T), returnType);
39+
var includeMethod = ReflectionMethodCache.GetEfCoreIncludeMethod(typeof(T), returnType);
7740

7841
return (IQueryable<T>)
7942
includeMethod.Invoke(null, new object[] { query, includeExpression })!;

0 commit comments

Comments
 (0)