Skip to content

Commit 5d11534

Browse files
CopilotPhenX
andcommitted
Fix null conditional rewrite for nullable value types (add .Value accessor)
Co-authored-by: PhenX <42170+PhenX@users.noreply.github.com>
1 parent 285201c commit 5d11534

6 files changed

Lines changed: 160 additions & 8 deletions

src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.NullConditionalRewrite.cs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ public partial class ExpressionSyntaxRewriter
99
public override SyntaxNode? VisitConditionalAccessExpression(ConditionalAccessExpressionSyntax node)
1010
{
1111
var targetExpression = (ExpressionSyntax)Visit(node.Expression);
12+
var targetType = _semanticModel.GetTypeInfo(node.Expression).Type;
1213

13-
_conditionalAccessExpressionsStack.Push(targetExpression);
14+
_conditionalAccessExpressionsStack.Push((targetExpression, targetType));
1415

1516
if (_nullConditionalRewriteSupport == NullConditionalRewriteSupport.None)
1617
{
@@ -60,11 +61,16 @@ public partial class ExpressionSyntaxRewriter
6061
{
6162
if (_conditionalAccessExpressionsStack.Count > 0)
6263
{
63-
var targetExpression = _conditionalAccessExpressionsStack.Pop();
64+
var (targetExpression, targetType) = _conditionalAccessExpressionsStack.Pop();
65+
66+
// When the target is a Nullable<T> value type, we need .Value to access members on the underlying type
67+
var accessExpression = IsNullableValueType(targetType)
68+
? SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, SyntaxFactory.IdentifierName("Value"))
69+
: targetExpression;
6470

6571
return _nullConditionalRewriteSupport switch {
66-
NullConditionalRewriteSupport.Ignore => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, node.Name),
67-
NullConditionalRewriteSupport.Rewrite => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, node.Name),
72+
NullConditionalRewriteSupport.Ignore => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, accessExpression, node.Name),
73+
NullConditionalRewriteSupport.Rewrite => SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, accessExpression, node.Name),
6874
_ => node
6975
};
7076
}
@@ -76,15 +82,26 @@ public partial class ExpressionSyntaxRewriter
7682
{
7783
if (_conditionalAccessExpressionsStack.Count > 0)
7884
{
79-
var targetExpression = _conditionalAccessExpressionsStack.Pop();
85+
var (targetExpression, targetType) = _conditionalAccessExpressionsStack.Pop();
86+
87+
// When the target is a Nullable<T> value type, we need .Value to access indexer on the underlying type
88+
var accessExpression = IsNullableValueType(targetType)
89+
? SyntaxFactory.MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, targetExpression, SyntaxFactory.IdentifierName("Value"))
90+
: targetExpression;
8091

8192
return _nullConditionalRewriteSupport switch {
82-
NullConditionalRewriteSupport.Ignore => SyntaxFactory.ElementAccessExpression(targetExpression, node.ArgumentList),
83-
NullConditionalRewriteSupport.Rewrite => SyntaxFactory.ElementAccessExpression(targetExpression, node.ArgumentList),
93+
NullConditionalRewriteSupport.Ignore => SyntaxFactory.ElementAccessExpression(accessExpression, node.ArgumentList),
94+
NullConditionalRewriteSupport.Rewrite => SyntaxFactory.ElementAccessExpression(accessExpression, node.ArgumentList),
8495
_ => Visit(node)
8596
};
8697
}
8798

8899
return base.VisitElementBindingExpression(node);
89100
}
101+
102+
private static bool IsNullableValueType(ITypeSymbol? type)
103+
{
104+
return type is { IsValueType: true } &&
105+
type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T;
106+
}
90107
}

src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public partial class ExpressionSyntaxRewriter : CSharpSyntaxRewriter
1212
readonly NullConditionalRewriteSupport _nullConditionalRewriteSupport;
1313
readonly bool _expandEnumMethods;
1414
readonly SourceProductionContext _context;
15-
readonly Stack<ExpressionSyntax> _conditionalAccessExpressionsStack = new();
15+
readonly Stack<(ExpressionSyntax Expression, ITypeSymbol? Type)> _conditionalAccessExpressionsStack = new();
1616
readonly string? _extensionParameterName;
1717

1818
public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullConditionalRewriteSupport nullConditionalRewriteSupport, bool expandEnumMethods, SemanticModel semanticModel, SourceProductionContext context, string? extensionParameterName = null)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// <auto-generated/>
2+
#nullable disable
3+
using System;
4+
using EntityFrameworkCore.Projectables;
5+
using Foo;
6+
7+
namespace EntityFrameworkCore.Projectables.Generated
8+
{
9+
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
10+
static class Foo_C_GetX_P0_Foo_Point_
11+
{
12+
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Point?, double>> Expression()
13+
{
14+
return (global::Foo.Point? point) => (point != null ? (point.Value.X) : ( double ? )null) ?? 0.0;
15+
}
16+
}
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// <auto-generated/>
2+
#nullable disable
3+
using System;
4+
using EntityFrameworkCore.Projectables;
5+
using Foo;
6+
7+
namespace EntityFrameworkCore.Projectables.Generated
8+
{
9+
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
10+
static class Foo_C_GetX_P0_Foo_Point_
11+
{
12+
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Point?, double?>> Expression()
13+
{
14+
return (global::Foo.Point? point) => point.Value.X;
15+
}
16+
}
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// <auto-generated/>
2+
#nullable disable
3+
using System;
4+
using EntityFrameworkCore.Projectables;
5+
using Foo;
6+
7+
namespace EntityFrameworkCore.Projectables.Generated
8+
{
9+
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
10+
static class Foo_C_GetX_P0_Foo_Point_
11+
{
12+
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Point?, double?>> Expression()
13+
{
14+
return (global::Foo.Point? point) => (point != null ? (point.Value.X) : ( double ? )null);
15+
}
16+
}
17+
}

tests/EntityFrameworkCore.Projectables.Generator.Tests/NullableTests.cs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,4 +489,88 @@ class Foo {
489489

490490
return Verifier.Verify(result.GeneratedTrees[0].ToString());
491491
}
492+
493+
[Fact]
494+
public Task NullableValueType_MemberAccess_WithRewriteSupport_IsBeingRewritten()
495+
{
496+
var compilation = CreateCompilation(@"
497+
using System;
498+
using EntityFrameworkCore.Projectables;
499+
500+
namespace Foo {
501+
public struct Point {
502+
public double X { get; set; }
503+
public double Y { get; set; }
504+
}
505+
506+
static class C {
507+
[Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)]
508+
public static double? GetX(this Point? point) => point?.X;
509+
}
510+
}
511+
");
512+
513+
var result = RunGenerator(compilation);
514+
515+
Assert.Empty(result.Diagnostics);
516+
Assert.Single(result.GeneratedTrees);
517+
518+
return Verifier.Verify(result.GeneratedTrees[0].ToString());
519+
}
520+
521+
[Fact]
522+
public Task NullableValueType_MemberAccess_WithIgnoreSupport_IsBeingRewritten()
523+
{
524+
var compilation = CreateCompilation(@"
525+
using System;
526+
using EntityFrameworkCore.Projectables;
527+
528+
namespace Foo {
529+
public struct Point {
530+
public double X { get; set; }
531+
public double Y { get; set; }
532+
}
533+
534+
static class C {
535+
[Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Ignore)]
536+
public static double? GetX(this Point? point) => point?.X;
537+
}
538+
}
539+
");
540+
541+
var result = RunGenerator(compilation);
542+
543+
Assert.Empty(result.Diagnostics);
544+
Assert.Single(result.GeneratedTrees);
545+
546+
return Verifier.Verify(result.GeneratedTrees[0].ToString());
547+
}
548+
549+
[Fact]
550+
public Task NullableValueType_MemberAccessWithCoalesce_WithRewriteSupport_IsBeingRewritten()
551+
{
552+
var compilation = CreateCompilation(@"
553+
using System;
554+
using EntityFrameworkCore.Projectables;
555+
556+
namespace Foo {
557+
public struct Point {
558+
public double X { get; set; }
559+
public double Y { get; set; }
560+
}
561+
562+
static class C {
563+
[Projectable(NullConditionalRewriteSupport = NullConditionalRewriteSupport.Rewrite)]
564+
public static double GetX(this Point? point) => point?.X ?? 0.0;
565+
}
566+
}
567+
");
568+
569+
var result = RunGenerator(compilation);
570+
571+
Assert.Empty(result.Diagnostics);
572+
Assert.Single(result.GeneratedTrees);
573+
574+
return Verifier.Verify(result.GeneratedTrees[0].ToString());
575+
}
492576
}

0 commit comments

Comments
 (0)