|
1 | 1 | using System; |
2 | 2 | using System.Linq; |
| 3 | +using System.Threading; |
3 | 4 | using System.Threading.Tasks; |
4 | 5 | using ICSharpCode.CodeConverter.Util; |
5 | 6 | using Microsoft.CodeAnalysis; |
| 7 | +using Microsoft.CodeAnalysis.CSharp; |
| 8 | +using Microsoft.CodeAnalysis.Operations; |
6 | 9 | using Microsoft.CodeAnalysis.Simplification; |
7 | 10 | using VBasic = Microsoft.CodeAnalysis.VisualBasic; |
8 | 11 | using VBSyntax = Microsoft.CodeAnalysis.VisualBasic.Syntax; |
@@ -37,28 +40,91 @@ public static async Task<Document> WithExpandedRootAsync(this Document document) |
37 | 40 | var shouldExpand = document.Project.Language == LanguageNames.VisualBasic |
38 | 41 | ? (Func<SemanticModel, SyntaxNode, bool>)ShouldExpandVbNode |
39 | 42 | : ShouldExpandCsNode; |
40 | | - var root = (VBasic.VisualBasicSyntaxNode) await document.GetSyntaxRootAsync(); |
41 | | - root = await ExpandVbAsync(document, root, shouldExpand); |
42 | | - return await UndoBadVbExpansionsAsync(document, root); |
| 43 | + document = await WorkaroundBugsInExpandVbAsync(document, shouldExpand); |
| 44 | + document = await ExpandVbAsync(document, shouldExpand); |
| 45 | + return await UndoBadVbExpansionsAsync(document); |
43 | 46 | } |
44 | 47 |
|
45 | | - private static async Task<VBasic.VisualBasicSyntaxNode> ExpandVbAsync(Document document, |
46 | | - VBasic.VisualBasicSyntaxNode root, Func<SemanticModel, SyntaxNode, bool> shouldExpand) |
| 48 | + private static async Task<Document> WorkaroundBugsInExpandVbAsync(Document document, Func<SemanticModel, SyntaxNode, bool> shouldExpand) |
47 | 49 | { |
48 | 50 | var semanticModel = await document.GetSemanticModelAsync(); |
49 | | - var workspace = document.Project.Solution.Workspace; |
| 51 | + var root = (VBasic.VisualBasicSyntaxNode)await document.GetSyntaxRootAsync(); |
| 52 | + |
| 53 | + try { |
| 54 | + var newRoot = root.ReplaceNodes(root.DescendantNodes(n => !shouldExpand(semanticModel, n)).Where(n => shouldExpand(semanticModel, n)), |
| 55 | + (node, rewrittenNode) => { |
| 56 | + var symbol = semanticModel.GetSymbolInfo(node).Symbol; |
| 57 | + if (rewrittenNode is VBSyntax.SimpleNameSyntax sns && IsMyBaseBug(semanticModel, root, node, symbol) && semanticModel.GetOperation(node) is IMemberReferenceOperation mro) { |
| 58 | + return VBasic.SyntaxFactory.MemberAccessExpression(VBasic.SyntaxKind.SimpleMemberAccessExpression, |
| 59 | + (VBSyntax.ExpressionSyntax) mro.Instance.Syntax, |
| 60 | + VBasic.SyntaxFactory.Token(VBasic.SyntaxKind.DotToken), |
| 61 | + sns); |
| 62 | + }; |
| 63 | + return rewrittenNode; |
| 64 | + }); |
| 65 | + return document.WithSyntaxRoot(newRoot); |
| 66 | + } catch (Exception) { |
| 67 | + return document.WithSyntaxRoot(root); |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + /// <returns>True iff calling Expand would qualify with MyBase when the symbol isn't in the base type |
| 72 | + /// See https://github.com/dotnet/roslyn/blob/97123b393c3a5a91cc798b329db0d7fc38634784/src/Workspaces/VisualBasic/Portable/Simplification/VisualBasicSimplificationService.Expander.vb#L657</returns> |
| 73 | + private static bool IsMyBaseBug(SemanticModel semanticModel, VBasic.VisualBasicSyntaxNode root, SyntaxNode node, |
| 74 | + ISymbol symbol) |
| 75 | + { |
| 76 | + if (symbol?.IsStatic == false && (symbol.Kind == SymbolKind.Method || symbol.Kind == |
| 77 | + SymbolKind.Field || symbol.Kind == SymbolKind.Property)) |
| 78 | + { |
| 79 | + INamedTypeSymbol nodeEnclosingNamedType = GetEnclosingNamedType(semanticModel, root, node.SpanStart); |
| 80 | + if (!Equals(nodeEnclosingNamedType, symbol.ContainingType)) { |
| 81 | + return !Equals(nodeEnclosingNamedType, symbol.ContainingType?.BaseType); |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + return false; |
| 86 | + } |
| 87 | + |
| 88 | + /// <summary> |
| 89 | + /// Pasted from AbstractGenerateFromMembersCodeRefactoringProvider |
| 90 | + /// Gets the enclosing named type for the specified position. We can't use |
| 91 | + /// <see cref="SemanticModel.GetEnclosingSymbol"/> because that doesn't return |
| 92 | + /// the type you're current on if you're on the header of a class/interface. |
| 93 | + /// </summary> |
| 94 | + private static INamedTypeSymbol GetEnclosingNamedType( |
| 95 | + SemanticModel semanticModel, SyntaxNode root, int start, CancellationToken cancellationToken = default(CancellationToken)) |
| 96 | + { |
| 97 | + var token = root.FindToken(start); |
| 98 | + if (token == ((ICompilationUnitSyntax)root).EndOfFileToken) { |
| 99 | + token = token.GetPreviousToken(); |
| 100 | + } |
| 101 | + |
| 102 | + for (var node = token.Parent; node != null; node = node.Parent) { |
| 103 | + if (semanticModel.GetDeclaredSymbol(node) is INamedTypeSymbol declaration) { |
| 104 | + return declaration; |
| 105 | + } |
| 106 | + } |
50 | 107 |
|
| 108 | + return null; |
| 109 | + } |
| 110 | + |
| 111 | + private static async Task<Document> ExpandVbAsync(Document document, Func<SemanticModel, SyntaxNode, bool> shouldExpand) |
| 112 | + { |
| 113 | + var semanticModel = await document.GetSemanticModelAsync(); |
| 114 | + var workspace = document.Project.Solution.Workspace; |
| 115 | + var root = (VBasic.VisualBasicSyntaxNode) await document.GetSyntaxRootAsync(); |
51 | 116 | try { |
52 | | - return root.ReplaceNodes(root.DescendantNodes(n => !shouldExpand(semanticModel, n)).Where(n => shouldExpand(semanticModel, n)), |
| 117 | + var newRoot = root.ReplaceNodes(root.DescendantNodes(n => !shouldExpand(semanticModel, n)).Where(n => shouldExpand(semanticModel, n)), |
53 | 118 | (node, rewrittenNode) => TryExpandNode(node, semanticModel, workspace) |
54 | 119 | ); |
| 120 | + return document.WithSyntaxRoot(newRoot); |
55 | 121 | } catch (Exception) { |
56 | | - return root; |
| 122 | + return document.WithSyntaxRoot(root); |
57 | 123 | } |
58 | 124 | } |
59 | | - private static async Task<Document> UndoBadVbExpansionsAsync(Document document, |
60 | | - VBasic.VisualBasicSyntaxNode root) |
| 125 | + private static async Task<Document> UndoBadVbExpansionsAsync(Document document) |
61 | 126 | { |
| 127 | + var root = (VBasic.VisualBasicSyntaxNode)await document.GetSyntaxRootAsync(); |
62 | 128 | var toSimplify = root.DescendantNodes() |
63 | 129 | .Where(n => n.IsKind(VBasic.SyntaxKind.PredefinedCastExpression, VBasic.SyntaxKind.CTypeExpression, VBasic.SyntaxKind.DirectCastExpression)) |
64 | 130 | .Where(n => n.HasAnnotation(Simplifier.Annotation)); |
|
0 commit comments