Skip to content

Commit a20c5db

Browse files
committed
fix(analysis): improve enumerator detection and expand IL coverage in static-field tracing
- Unify call-site checks in CollectionElementLayer and EnumeratorLayer using Call|Callvirt + HasThis, fixing GetEnumerator/get_Current detection logic. - Make EnumeratorLayer.GetEnumeratorType and GetEnumerableType public for reuse across analysis modules. - Expand opcode handling in ParameterFlowAnalyzer and StaticFieldReferenceAnalyzer (Ldarga*, Ldloc*, Ldloca*, Stloc, Ldflda, Ldsflda). - Enhance StaticFieldTracingChain to infer element layers from arrays or IEnumerable<T> when encapsulation is empty or only has an enumerator layer. - Refine dedup behavior in MonoModExtensions.GetRuntimeMethods and add recursive GetAllInterfaces(TypeReference).
1 parent 4f8c770 commit a20c5db

7 files changed

Lines changed: 115 additions & 56 deletions

File tree

src/OTAPI.UnifiedServerProcess/Core/Analysis/DataModels/MemberAccess/CollectionElementLayer.cs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@ static void LazyInit(ModuleDefinition module) {
3636
}
3737
public static bool IsStoreElementsMethod(TypeInheritanceGraph graph, MethodReference caller, Instruction storeMethodCallInstruction) {
3838

39-
if (storeMethodCallInstruction.Operand is not MethodReference storeMethod) {
40-
return false;
41-
}
42-
43-
if (!storeMethod.HasThis) {
39+
if (storeMethodCallInstruction is not {
40+
OpCode.Code: Code.Call or Code.Callvirt,
41+
Operand: MethodReference { HasThis: true } storeMethod
42+
}) {
4443
return false;
4544
}
4645

@@ -55,13 +54,11 @@ public static bool IsStoreElementsMethod(TypeInheritanceGraph graph, MethodRefer
5554
}
5655

5756
public static bool IsStoreElementMethod(TypeInheritanceGraph graph, MethodReference caller, Instruction storeMethodCallInstruction, out int indexOfValueInParameters) {
58-
5957
indexOfValueInParameters = -1;
60-
if (storeMethodCallInstruction.Operand is not MethodReference storeMethod) {
61-
return false;
62-
}
6358

64-
if (!storeMethod.HasThis) {
59+
if (storeMethodCallInstruction is not {
60+
OpCode.Code: Code.Call or Code.Callvirt,
61+
Operand: MethodReference { HasThis: true } storeMethod }) {
6562
return false;
6663
}
6764

src/OTAPI.UnifiedServerProcess/Core/Analysis/DataModels/MemberAccess/EnumeratorLayer.cs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ public class EnumeratorLayer(TypeReference collectionType) : MemberAccessStep
1313
public override TypeReference MemberType => collectionType;
1414
static TypeReference? enumeratorType;
1515
static TypeReference? enumerableType;
16-
static TypeReference GetEnumeratorType(ModuleDefinition module) {
16+
public static TypeReference GetEnumeratorType(ModuleDefinition module) {
1717
return enumeratorType ??= module.ImportReference(typeof(IEnumerator<>));
1818
}
19-
static TypeReference GetEnumerableType(ModuleDefinition module) {
19+
public static TypeReference GetEnumerableType(ModuleDefinition module) {
2020
return enumerableType ??= module.ImportReference(typeof(IEnumerable<>));
2121
}
2222
public static bool IsEnumerator(TypeInheritanceGraph graph, TypeDefinition type) {
@@ -30,16 +30,16 @@ public static bool IsEnumerator(TypeInheritanceGraph graph, TypeDefinition type)
3030
return true;
3131
}
3232
public static bool IsGetEnumeratorMethod(TypeInheritanceGraph graph, MethodReference caller, Instruction getEnumeratorInstruction) {
33-
if (getEnumeratorInstruction.OpCode != OpCodes.Call || getEnumeratorInstruction.OpCode != OpCodes.Callvirt) {
33+
if (getEnumeratorInstruction is not {
34+
OpCode.Code: Code.Call or Code.Callvirt,
35+
Operand: MethodReference { HasThis: true } getEnumerator
36+
}) {
3437
return false;
3538
}
36-
if (getEnumeratorInstruction.Operand is not MethodReference methodRef) {
39+
if (getEnumerator.Name != "GetEnumerator") {
3740
return false;
3841
}
39-
if (methodRef.Name != "GetEnumerator") {
40-
return false;
41-
}
42-
var declaringType = methodRef.DeclaringType.TryResolve();
42+
var declaringType = getEnumerator.DeclaringType.TryResolve();
4343
if (declaringType is null) {
4444
return false;
4545
}
@@ -52,17 +52,17 @@ public static bool IsGetEnumeratorMethod(TypeInheritanceGraph graph, MethodRefer
5252
}
5353
return true;
5454
}
55-
public static bool IsGetCurrentMethod(TypeInheritanceGraph graph, MethodReference caller, Instruction getEnumeratorInstruction) {
56-
if (getEnumeratorInstruction.OpCode != OpCodes.Call && getEnumeratorInstruction.OpCode != OpCodes.Callvirt) {
57-
return false;
58-
}
59-
if (getEnumeratorInstruction.Operand is not MethodReference methodRef) {
55+
public static bool IsGetCurrentMethod(TypeInheritanceGraph graph, MethodReference caller, Instruction getCurrentInstruction) {
56+
if (getCurrentInstruction is not {
57+
OpCode.Code: Code.Call or Code.Callvirt,
58+
Operand: MethodReference { HasThis: true } getCurrent
59+
}) {
6060
return false;
6161
}
62-
if (methodRef.Name != "get_Current") {
62+
if (getCurrent.Name != "get_Current") {
6363
return false;
6464
}
65-
var declaringType = methodRef.DeclaringType.TryResolve();
65+
var declaringType = getCurrent.DeclaringType.TryResolve();
6666
if (declaringType is null) {
6767
return false;
6868
}

src/OTAPI.UnifiedServerProcess/Core/Analysis/MethodInheritanceGraph.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,6 @@ private static void ProcessInterfaces(
533533
Dictionary<string, Dictionary<string, MethodDefinition>> chains) {
534534
foreach (var interfaceImpl in type.Interfaces) {
535535
var interfaceDef = interfaceImpl.InterfaceType.Resolve();
536-
if (interfaceDef.Name.StartsWith("IEntryFilter")) {
537-
538-
}
539536

540537
foreach (var interfaceMethod in interfaceDef.Methods) {
541538
// Instantiate the interface method against the concrete interface type (handles generics).

src/OTAPI.UnifiedServerProcess/Core/Analysis/ParameterFlowAnalysis/ParameterFlowAnalyzer.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ private void ProcessMethod(
183183
case Code.Ldarg_3:
184184
case Code.Ldarg_S:
185185
case Code.Ldarg:
186+
case Code.Ldarga_S:
187+
case Code.Ldarga:
186188
HandleLoadArgument(instruction);
187189
break;
188190

@@ -191,6 +193,7 @@ private void ProcessMethod(
191193
case Code.Stloc_2:
192194
case Code.Stloc_3:
193195
case Code.Stloc_S:
196+
case Code.Stloc:
194197
HandleStoreLocal(instruction);
195198
break;
196199

@@ -199,10 +202,14 @@ private void ProcessMethod(
199202
case Code.Ldloc_2:
200203
case Code.Ldloc_3:
201204
case Code.Ldloc_S:
205+
case Code.Ldloc:
206+
case Code.Ldloca_S:
207+
case Code.Ldloca:
202208
HandleLoadLocal(instruction);
203209
break;
204210

205211
case Code.Ldfld:
212+
case Code.Ldflda:
206213
HandleLoadField(instruction);
207214
break;
208215

src/OTAPI.UnifiedServerProcess/Core/Analysis/StaticFieldReferenceAnalysis/StaticFieldReferenceAnalyzer.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ private void ProcessMethod(
176176
case Code.Stloc_2:
177177
case Code.Stloc_3:
178178
case Code.Stloc_S:
179+
case Code.Stloc:
179180
HandleStoreLocal(instruction);
180181
break;
181182

@@ -184,14 +185,19 @@ private void ProcessMethod(
184185
case Code.Ldloc_2:
185186
case Code.Ldloc_3:
186187
case Code.Ldloc_S:
188+
case Code.Ldloc:
189+
case Code.Ldloca_S:
190+
case Code.Ldloca:
187191
HandleLoadLocal(instruction);
188192
break;
189193

190194
case Code.Ldsfld:
195+
case Code.Ldsflda:
191196
HandleLoadStaticField(instruction);
192197
break;
193198

194199
case Code.Ldfld:
200+
case Code.Ldflda:
195201
HandleLoadField(instruction);
196202
break;
197203

src/OTAPI.UnifiedServerProcess/Core/Analysis/StaticFieldReferenceAnalysis/StaticFieldTracingChain.cs

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ private bool TryExtendComponentAccessPath(MemberAccessStep member, TypeFlowSccIn
229229

230230
var loopBaseType = loopState.ExactType ?? TracingStaticField.FieldType;
231231
if (!sccIndex.IsInSccIncludingBaseTypes(loopBaseType, loop.SccId)) {
232-
// A loop summary that doesn't apply to the current type would over-constrain the path; ignore it.
232+
// A loop summary that doesn't apply to the current typeRef would over-constrain the path; ignore it.
233233
result = this;
234234
return true;
235235
}
@@ -324,23 +324,47 @@ private static bool WouldRevisitTypeWithinScc(
324324
return sccIndex.IsInSccIncludingBaseTypes(nextType, sccId) && seen.Contains(nextType.FullName);
325325
}
326326
public bool TryTraceEnumeratorCurrent([NotNullWhen(true)] out StaticFieldTracingChain? result) {
327-
if (EncapsulationHierarchy.Length < 2) {
327+
if (EncapsulationHierarchy.Length < 1) {
328328
result = null;
329329
return false;
330330
}
331331
if (EncapsulationHierarchy[0] is not EnumeratorLayer) {
332332
result = null;
333333
return false;
334334
}
335-
if (EncapsulationHierarchy[1] is not ArrayElementLayer && EncapsulationHierarchy[1] is not CollectionElementLayer) {
336-
throw new NotSupportedException("Enumerator layer must be followed by ArrayElementLayer or CollectionElementLayer.");
335+
if (EncapsulationHierarchy.Length > 1 && EncapsulationHierarchy[1] is ArrayElementLayer or CollectionElementLayer) {
336+
result = new StaticFieldTracingChain(
337+
TracingStaticField,
338+
EncapsulationHierarchy.RemoveAt(0).RemoveAt(0),
339+
ComponentAccessPath
340+
);
341+
return true;
337342
}
338-
result = new StaticFieldTracingChain(
339-
TracingStaticField,
340-
EncapsulationHierarchy.RemoveAt(0).RemoveAt(0),
341-
ComponentAccessPath
342-
);
343-
return true;
343+
if (EncapsulationHierarchy.Length is 1) {
344+
var typeRef = TracingStaticField.FieldType;
345+
if (!ComponentAccessPath.IsEmpty) {
346+
typeRef = ComponentAccessPath.Last().MemberType;
347+
}
348+
if (typeRef is ArrayType at) {
349+
result = new StaticFieldTracingChain(
350+
TracingStaticField,
351+
EncapsulationHierarchy.RemoveAt(0),
352+
ComponentAccessPath.Add(new ArrayElementLayer(at))
353+
);
354+
return true;
355+
}
356+
var interfaces = typeRef.GetAllInterfaces().ToArray();
357+
var (idef, iref) = interfaces.FirstOrDefault(i => i.idef.FullName == EnumeratorLayer.GetEnumerableType(typeRef.Module).FullName);
358+
if (iref is GenericInstanceType git) {
359+
result = new StaticFieldTracingChain(
360+
TracingStaticField,
361+
EncapsulationHierarchy.RemoveAt(0),
362+
ComponentAccessPath.Add(new CollectionElementLayer(typeRef, git.GenericArguments.Last()))
363+
);
364+
return true;
365+
}
366+
}
367+
throw new NotSupportedException("Enumerator layer must be followed by ArrayElementLayer or CollectionElementLayer.");
344368
}
345369

346370
public StaticFieldTracingChain CreateEncapsulatedInstance(MemberReference storedIn)
@@ -443,11 +467,22 @@ private static ImmutableArray<MemberAccessStep> NormalizeEncapsulationHierarchyW
443467
}
444468
public StaticFieldTracingChain? CreateEncapsulatedEnumeratorInstance() {
445469
if (EncapsulationHierarchy.IsEmpty) {
446-
return null;
470+
var typeRef = TracingStaticField.FieldType;
471+
if (!ComponentAccessPath.IsEmpty) {
472+
typeRef = ComponentAccessPath.Last().MemberType;
473+
}
474+
var typeDef = typeRef.TryResolve();
475+
if (typeRef is not ArrayType && !typeRef.GetAllInterfaces()
476+
.Select(x => x.idef.FullName)
477+
.Contains(EnumeratorLayer.GetEnumerableType(typeRef.Module).FullName)) {
478+
return null;
479+
}
480+
481+
return new StaticFieldTracingChain(this, new EnumeratorLayer(typeRef));
447482
}
448483
if (EncapsulationHierarchy[0] is ArrayElementLayer or CollectionElementLayer) {
449-
var collectionEle = EncapsulationHierarchy[0];
450-
return new StaticFieldTracingChain(this, new EnumeratorLayer(collectionEle.DeclaringType));
484+
var collectionLayer = EncapsulationHierarchy[0];
485+
return new StaticFieldTracingChain(this, new EnumeratorLayer(collectionLayer.DeclaringType));
451486
}
452487
// if there has already enumerator, we don't need to create a nested one
453488
if (EncapsulationHierarchy[0] is EnumeratorLayer) {

src/OTAPI.UnifiedServerProcess/Extensions/MonoModExtensions.cs

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -239,33 +239,29 @@ public static MethodDefinition GetMethod(this TypeDefinition type, string name)
239239
return type.Methods.Single((MethodDefinition x) => x.Name == name);
240240
}
241241
public static IEnumerable<MethodDefinition> GetRuntimeMethods(this TypeDefinition type, bool includeInterf = false) {
242-
foreach (var md in type.Methods) {
243-
yield return md;
244-
}
245-
var baseType = type.BaseType?.TryResolve();
246-
if (baseType is not null) {
247-
foreach (var md in baseType.GetRuntimeMethods()) {
242+
HashSet<MethodDefinition> visited = [];
243+
foreach (var md in type.Methods)
244+
if (visited.Add(md))
248245
yield return md;
249-
}
250-
}
246+
var baseType = type.BaseType?.TryResolve();
247+
if (baseType is not null)
248+
foreach (var md in baseType.GetRuntimeMethods())
249+
if (visited.Add(md))
250+
yield return md;
251251
static IEnumerable<MethodDefinition> GetInterfaceMethods(TypeDefinition type) {
252-
HashSet<MethodDefinition> visited = [];
253252
if (type.IsInterface)
254253
foreach (var md in type.Methods)
255-
if (visited.Add(md))
256-
yield return md;
254+
yield return md;
257255
foreach (var interf in type.Interfaces) {
258256
var interfDef = interf.InterfaceType.TryResolve();
259257
if (interfDef is not null)
260258
foreach (var md in interfDef.Methods)
261-
if (visited.Add(md))
262-
yield return md;
259+
yield return md;
263260
}
264261
var baseType = type.BaseType?.TryResolve();
265262
if (baseType is not null) {
266263
foreach (var md in GetInterfaceMethods(baseType))
267-
if (visited.Add(md))
268-
yield return md;
264+
yield return md;
269265
}
270266
}
271267
if (includeInterf) {
@@ -274,6 +270,27 @@ static IEnumerable<MethodDefinition> GetInterfaceMethods(TypeDefinition type) {
274270
}
275271
}
276272
}
273+
public static IEnumerable<(TypeDefinition idef, TypeReference iref)> GetAllInterfaces(this TypeReference type) {
274+
HashSet<string> visited = [];
275+
var typeDef = type.TryResolve();
276+
if (typeDef is null) {
277+
yield break;
278+
}
279+
if (typeDef.IsInterface) {
280+
yield return (typeDef, type);
281+
}
282+
foreach (var interf in typeDef.Interfaces) {
283+
var interfDef = interf.InterfaceType.TryResolve();
284+
if (interfDef is not null)
285+
if (visited.Add(interfDef.FullName))
286+
yield return (interfDef, interf.InterfaceType);
287+
}
288+
var baseType = typeDef.BaseType?.TryResolve();
289+
if (baseType is not null)
290+
foreach (var interfDef in typeDef.BaseType!.GetAllInterfaces())
291+
if(visited.Add(interfDef.idef.FullName))
292+
yield return interfDef;
293+
}
277294

278295
public static EventDefinition GetEvent(this TypeDefinition type, string name) {
279296
return type.Events.Single((EventDefinition x) => x.Name == name);

0 commit comments

Comments
 (0)