|
1 | 1 | using System.Reflection; |
2 | 2 |
|
3 | | -namespace EntityFrameworkCore.Projectables.Extensions |
| 3 | +namespace EntityFrameworkCore.Projectables.Extensions; |
| 4 | + |
| 5 | +public static class TypeExtensions |
4 | 6 | { |
5 | | - public static class TypeExtensions |
| 7 | + public static Type[] GetNestedTypePath(this Type type) |
6 | 8 | { |
7 | | - public static string GetSimplifiedTypeName(this Type type) |
| 9 | + // First pass: count the nesting depth so we can size the array exactly. |
| 10 | + var depth = 0; |
| 11 | + var current = type; |
| 12 | + while (true) |
8 | 13 | { |
9 | | - var name = type.Name; |
10 | | - |
11 | | - var backtickIndex = name.IndexOf("`"); |
12 | | - if (backtickIndex != -1) |
| 14 | + depth++; |
| 15 | + if (!current.IsNested || current.DeclaringType is null) |
13 | 16 | { |
14 | | - name = name.Substring(0, backtickIndex); |
| 17 | + break; |
15 | 18 | } |
16 | 19 |
|
17 | | - return name; |
| 20 | + current = current.DeclaringType; |
18 | 21 | } |
19 | 22 |
|
20 | | - public static IEnumerable<Type> GetNestedTypePath(this Type type) |
| 23 | + // Second pass: fill the array outermost-first by walking back from the leaf. |
| 24 | + var path = new Type[depth]; |
| 25 | + current = type; |
| 26 | + for (var i = depth - 1; i >= 0; i--) |
21 | 27 | { |
22 | | - if (type.IsNested && type.DeclaringType is not null) |
23 | | - { |
24 | | - foreach (var containingType in type.DeclaringType.GetNestedTypePath()) |
25 | | - { |
26 | | - yield return containingType; |
27 | | - } |
28 | | - } |
29 | | - |
30 | | - yield return type; |
| 28 | + path[i] = current; |
| 29 | + current = current.DeclaringType!; |
31 | 30 | } |
32 | 31 |
|
33 | | - private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo) |
34 | | - { |
35 | | - // We only need to search for virtual instance methods who are not declared on the derivedType |
36 | | - if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual) |
37 | | - { |
38 | | - return false; |
39 | | - } |
| 32 | + return path; |
| 33 | + } |
40 | 34 |
|
41 | | - if (!derivedType.IsAssignableTo(methodInfo.DeclaringType)) |
42 | | - { |
43 | | - throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); |
44 | | - } |
| 35 | + private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo) |
| 36 | + { |
| 37 | + // We only need to search for virtual instance methods who are not declared on the derivedType |
| 38 | + if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual) |
| 39 | + { |
| 40 | + return false; |
| 41 | + } |
45 | 42 |
|
46 | | - return true; |
| 43 | + if (!derivedType.IsAssignableTo(methodInfo.DeclaringType)) |
| 44 | + { |
| 45 | + throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); |
47 | 46 | } |
48 | 47 |
|
49 | | - private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition) |
50 | | - => methodInfo.GetBaseDefinition() == baseDefinition; |
| 48 | + return true; |
| 49 | + } |
| 50 | + |
| 51 | + private static bool IsOverridingMethodOf(this MethodInfo methodInfo, MethodInfo baseDefinition) |
| 52 | + => methodInfo.GetBaseDefinition() == baseDefinition; |
51 | 53 |
|
52 | | - public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) |
| 54 | + public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo) |
| 55 | + { |
| 56 | + if (!derivedType.CanHaveOverridingMethod(methodInfo)) |
53 | 57 | { |
54 | | - if (!derivedType.CanHaveOverridingMethod(methodInfo)) |
55 | | - { |
56 | | - return methodInfo; |
57 | | - } |
| 58 | + return methodInfo; |
| 59 | + } |
58 | 60 |
|
59 | | - var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); |
| 61 | + var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); |
60 | 62 |
|
61 | | - MethodInfo? overridingMethod = null; |
62 | | - if (derivedMethods is { Length: > 0 }) |
63 | | - { |
64 | | - var baseDefinition = methodInfo.GetBaseDefinition(); |
65 | | - overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo |
66 | | - => derivedMethodInfo.IsOverridingMethodOf(baseDefinition)); |
67 | | - } |
68 | | - |
69 | | - return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo |
| 63 | + MethodInfo? overridingMethod = null; |
| 64 | + if (derivedMethods is { Length: > 0 }) |
| 65 | + { |
| 66 | + var baseDefinition = methodInfo.GetBaseDefinition(); |
| 67 | + overridingMethod = derivedMethods.FirstOrDefault(derivedMethodInfo |
| 68 | + => derivedMethodInfo.IsOverridingMethodOf(baseDefinition)); |
70 | 69 | } |
71 | 70 |
|
72 | | - public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) |
| 71 | + return overridingMethod ?? methodInfo; // If no derived methods were found, return the original methodInfo |
| 72 | + } |
| 73 | + |
| 74 | + private static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo) |
| 75 | + { |
| 76 | + var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod); |
| 77 | + if (accessor is null) |
73 | 78 | { |
74 | | - var accessor = propertyInfo.GetAccessors(true).FirstOrDefault(derivedType.CanHaveOverridingMethod); |
75 | | - if (accessor is null) |
76 | | - { |
77 | | - return propertyInfo; |
78 | | - } |
| 79 | + return propertyInfo; |
| 80 | + } |
79 | 81 |
|
80 | | - var isGetAccessor = propertyInfo.GetMethod == accessor; |
| 82 | + var isGetAccessor = propertyInfo.GetMethod == accessor; |
81 | 83 |
|
82 | | - var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); |
83 | | - |
84 | | - PropertyInfo? overridingProperty = null; |
85 | | - if (derivedProperties is { Length: > 0 }) |
86 | | - { |
87 | | - var baseDefinition = accessor.GetBaseDefinition(); |
88 | | - overridingProperty = derivedProperties.FirstOrDefault(p |
89 | | - => (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true); |
90 | | - } |
91 | | - |
92 | | - return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo |
93 | | - } |
| 84 | + var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); |
94 | 85 |
|
95 | | - public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) |
| 86 | + PropertyInfo? overridingProperty = null; |
| 87 | + if (derivedProperties is { Length: > 0 }) |
96 | 88 | { |
97 | | - var interfaceType = methodInfo.DeclaringType; |
98 | | - // We only need to search for interface methods |
99 | | - if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual) |
100 | | - { |
101 | | - return methodInfo; |
102 | | - } |
103 | | - |
104 | | - if (!derivedType.IsAssignableTo(interfaceType)) |
105 | | - { |
106 | | - throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); |
107 | | - } |
| 89 | + var baseDefinition = accessor.GetBaseDefinition(); |
| 90 | + overridingProperty = derivedProperties.FirstOrDefault(p |
| 91 | + => (isGetAccessor ? p.GetMethod : p.SetMethod)?.IsOverridingMethodOf(baseDefinition) == true); |
| 92 | + } |
108 | 93 |
|
109 | | - var interfaceMap = derivedType.GetInterfaceMap(interfaceType); |
110 | | - for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++) |
111 | | - { |
112 | | - if (interfaceMap.InterfaceMethods[i] == methodInfo) |
113 | | - { |
114 | | - return interfaceMap.TargetMethods[i]; |
115 | | - } |
116 | | - } |
| 94 | + return overridingProperty ?? propertyInfo; // If no derived methods were found, return the original methodInfo |
| 95 | + } |
117 | 96 |
|
118 | | - throw new ApplicationException( |
119 | | - $"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!"); |
| 97 | + private static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo) |
| 98 | + { |
| 99 | + var interfaceType = methodInfo.DeclaringType; |
| 100 | + // We only need to search for interface methods |
| 101 | + if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual) |
| 102 | + { |
| 103 | + return methodInfo; |
120 | 104 | } |
121 | 105 |
|
122 | | - public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo) |
| 106 | + if (!derivedType.IsAssignableTo(interfaceType)) |
123 | 107 | { |
124 | | - var accessor = propertyInfo.GetAccessors()[0]; |
| 108 | + throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo)); |
| 109 | + } |
125 | 110 |
|
126 | | - var implementingAccessor = derivedType.GetImplementingMethod(accessor); |
127 | | - if (implementingAccessor == accessor) |
| 111 | + var interfaceMap = derivedType.GetInterfaceMap(interfaceType); |
| 112 | + for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++) |
| 113 | + { |
| 114 | + if (interfaceMap.InterfaceMethods[i] == methodInfo) |
128 | 115 | { |
129 | | - return propertyInfo; |
| 116 | + return interfaceMap.TargetMethods[i]; |
130 | 117 | } |
| 118 | + } |
131 | 119 |
|
132 | | - var implementingType = implementingAccessor.DeclaringType |
133 | | - // This should only be null if it is a property accessor on the global module, |
134 | | - // which should never happen since we found it from derivedType |
135 | | - ?? throw new ApplicationException("The property accessor has no declaring type!"); |
| 120 | + throw new ApplicationException( |
| 121 | + $"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!"); |
| 122 | + } |
136 | 123 |
|
137 | | - var derivedProperties = implementingType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); |
| 124 | + public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo) |
| 125 | + { |
| 126 | + var accessor = propertyInfo.GetAccessors()[0]; |
138 | 127 |
|
139 | | - return derivedProperties.FirstOrDefault(propertyInfo.GetMethod == accessor |
140 | | - ? p => MethodInfosEqual(p.GetMethod, implementingAccessor) |
141 | | - : p => MethodInfosEqual(p.SetMethod, implementingAccessor)) ?? propertyInfo; |
| 128 | + var implementingAccessor = derivedType.GetImplementingMethod(accessor); |
| 129 | + if (implementingAccessor == accessor) |
| 130 | + { |
| 131 | + return propertyInfo; |
142 | 132 | } |
143 | 133 |
|
144 | | - /// <summary> |
145 | | - /// The built-in <see cref="MethodInfo.op_Equality(System.Reflection.MethodInfo?,System.Reflection.MethodInfo?)"/> |
146 | | - /// does not work if the <see cref="MemberInfo.ReflectedType"/>s don't agree. |
147 | | - /// </summary> |
148 | | - private static bool MethodInfosEqual(MethodInfo? first, MethodInfo second) |
149 | | - => first?.ReflectedType == second.ReflectedType |
150 | | - ? first == second |
151 | | - : first is not null |
152 | | - && first.DeclaringType == second.DeclaringType |
153 | | - && first.Name == second.Name |
154 | | - && first.GetParameters().Select(p => p.ParameterType) |
155 | | - .SequenceEqual(second.GetParameters().Select(p => p.ParameterType)) |
156 | | - && first.GetGenericArguments().SequenceEqual(second.GetGenericArguments()); |
157 | | - |
158 | | - public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo) |
159 | | - => methodInfo.DeclaringType?.IsInterface == true |
160 | | - ? derivedType.GetImplementingMethod(methodInfo) |
161 | | - : derivedType.GetOverridingMethod(methodInfo); |
162 | | - |
163 | | - public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo) |
164 | | - => propertyInfo.DeclaringType?.IsInterface == true |
165 | | - ? derivedType.GetImplementingProperty(propertyInfo) |
166 | | - : derivedType.GetOverridingProperty(propertyInfo); |
| 134 | + var implementingType = implementingAccessor.DeclaringType |
| 135 | + // This should only be null if it is a property accessor on the global module, |
| 136 | + // which should never happen since we found it from derivedType |
| 137 | + ?? throw new ApplicationException("The property accessor has no declaring type!"); |
| 138 | + |
| 139 | + var derivedProperties = implementingType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); |
| 140 | + |
| 141 | + return derivedProperties.FirstOrDefault(propertyInfo.GetMethod == accessor |
| 142 | + ? p => MethodInfosEqual(p.GetMethod, implementingAccessor) |
| 143 | + : p => MethodInfosEqual(p.SetMethod, implementingAccessor)) ?? propertyInfo; |
167 | 144 | } |
| 145 | + |
| 146 | + /// <summary> |
| 147 | + /// The built-in <see cref="MethodInfo.op_Equality(System.Reflection.MethodInfo?,System.Reflection.MethodInfo?)"/> |
| 148 | + /// does not work if the <see cref="MemberInfo.ReflectedType"/>s don't agree. |
| 149 | + /// </summary> |
| 150 | + private static bool MethodInfosEqual(MethodInfo? first, MethodInfo second) |
| 151 | + => first?.ReflectedType == second.ReflectedType |
| 152 | + ? first == second |
| 153 | + : first is not null |
| 154 | + && first.DeclaringType == second.DeclaringType |
| 155 | + && first.Name == second.Name |
| 156 | + && first.GetParameters().Select(p => p.ParameterType) |
| 157 | + .SequenceEqual(second.GetParameters().Select(p => p.ParameterType)) |
| 158 | + && first.GetGenericArguments().SequenceEqual(second.GetGenericArguments()); |
| 159 | + |
| 160 | + public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo) |
| 161 | + => methodInfo.DeclaringType?.IsInterface == true |
| 162 | + ? derivedType.GetImplementingMethod(methodInfo) |
| 163 | + : derivedType.GetOverridingMethod(methodInfo); |
| 164 | + |
| 165 | + public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo) |
| 166 | + => propertyInfo.DeclaringType?.IsInterface == true |
| 167 | + ? derivedType.GetImplementingProperty(propertyInfo) |
| 168 | + : derivedType.GetOverridingProperty(propertyInfo); |
168 | 169 | } |
0 commit comments