diff --git a/CelesteAnalyzer/CelesteAnalyzer/HookAnalyzer.cs b/CelesteAnalyzer/CelesteAnalyzer/HookAnalyzer.cs index 4c770a4..7772377 100644 --- a/CelesteAnalyzer/CelesteAnalyzer/HookAnalyzer.cs +++ b/CelesteAnalyzer/CelesteAnalyzer/HookAnalyzer.cs @@ -2,6 +2,8 @@ using System.Collections; using System.Collections.Immutable; using System.Linq; +using System.Linq.Expressions; +using System.Threading; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -41,16 +43,16 @@ private void AnalyzeEventAssignment(OperationAnalysisContext context) { // look for On/IL hook assignment to MonoMod.RuntimeDetour.HookGen - if (context.Operation.SemanticModel is not { } sem) - return; if (context.Operation is not IEventAssignmentOperation op) return; if (op.Syntax is not AssignmentExpressionSyntax assignment) return; + if (context.Operation.SemanticModel is not { } sem) + return; if (sem.GetOperation(assignment.Left) is not IEventReferenceOperation refOp) return; if (Utils.BottommostNamespace(refOp.Event)?.Name is "On" or "IL") - AnalyzeHookFromLambdaOrIdentifier(context, assignment.Right, sem); + AnalyzeHookFromLambdaOrIdentifier(context, assignment.Right, sem, context.CancellationToken); } private void AnalyzeObjectCreationOperation(OperationAnalysisContext context) @@ -74,17 +76,19 @@ private void AnalyzeObjectCreationOperation(OperationAnalysisContext context) var targetArg = creationSyntax.ArgumentList!.Arguments[1]; - AnalyzeHookFromLambdaOrIdentifier(context, targetArg.Expression, sem); + AnalyzeHookFromLambdaOrIdentifier(context, targetArg.Expression, sem, context.CancellationToken); } - private static void AnalyzeHookFromLambdaOrIdentifier(OperationAnalysisContext context, ExpressionSyntax targetArg, SemanticModel sem) + private static void AnalyzeHookFromLambdaOrIdentifier( + OperationAnalysisContext context, ExpressionSyntax targetArg, SemanticModel sem, + CancellationToken ct) { // check method references if (targetArg is IdentifierNameSyntax id) { if (Utils.GetMethodDeclarationSyntaxFromIdentifier(id, sem, out var methodRef) is { } syntax) { - AnalyzeHook(context, methodRef!.Method, syntax.Body, syntax.GetLocation()); + AnalyzeHook(context, methodRef!.Method, syntax.Body, syntax.GetLocation(), ct); } } @@ -99,12 +103,13 @@ private static void AnalyzeHookFromLambdaOrIdentifier(OperationAnalysisContext c .OfType() .FirstOrDefault() is { } syntax) { - AnalyzeHook(context, methodRef.Symbol, (SyntaxNode?)syntax.Block ?? syntax.ExpressionBody, syntax.GetLocation()); + AnalyzeHook(context, methodRef.Symbol, (SyntaxNode?)syntax.Block ?? syntax.ExpressionBody, syntax.GetLocation(), ct); } } } - private static void AnalyzeHook(OperationAnalysisContext context, IMethodSymbol methodSymbol, SyntaxNode? bodySyntax, Location loc) + private static void AnalyzeHook(OperationAnalysisContext context, IMethodSymbol methodSymbol, SyntaxNode? bodySyntax, + Location loc, CancellationToken ct) { var firstParam = methodSymbol.Parameters.First(); @@ -127,6 +132,8 @@ private static void AnalyzeHook(OperationAnalysisContext context, IMethodSymbol foreach (var st in bodySyntax.DescendantNodes()) { + if (ct.IsCancellationRequested) + break; if (IsOrig(st, firstParam)) { origCalled = true; @@ -150,8 +157,28 @@ private static void AnalyzeHook(OperationAnalysisContext context, IMethodSymbol static bool IsOrig(SyntaxNode? st, IParameterSymbol orig) { - return st is InvocationExpressionSyntax invocationExpressionSyntax && - invocationExpressionSyntax.Expression.ToString() == orig.Name; + // orig?.Invoke(...) + if (st is ConditionalAccessExpressionSyntax cond + && cond.WhenNotNull is InvocationExpressionSyntax notNullInvocationExpr + && notNullInvocationExpr.Expression.ToString() == ".Invoke" + && cond.Expression.ToString() == orig.Name) + { + return true; + } + + if (st is not InvocationExpressionSyntax invocationExpressionSyntax) + return false; + + // orig.Invoke(...) + if (invocationExpressionSyntax.Expression is MemberAccessExpressionSyntax memberAccessExpressionSyntax + && memberAccessExpressionSyntax.Name.ToString() == "Invoke" + && memberAccessExpressionSyntax.Expression.ToString() == orig.Name) + { + return true; + } + + // orig(...); + return invocationExpressionSyntax.Expression.ToString() == orig.Name; } }