Skip to content

Commit c24ee00

Browse files
authored
CSHARP-5847: Support Select/SelectMany/Where index overloads in LINQ provider (#1949)
1 parent e242566 commit c24ee00

16 files changed

Lines changed: 372 additions & 15 deletions

File tree

src/MongoDB.Driver/Core/Misc/Feature.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public class Feature
3535
private static readonly Feature __aggregateOutToDifferentDatabase = new Feature("AggregateOutToDifferentDatabase", WireVersion.Server44);
3636
private static readonly Feature __aggregateToString = new Feature("AggregateToString", WireVersion.Server40);
3737
private static readonly Feature __aggregateUnionWith = new Feature("AggregateUnionWith", WireVersion.Server44);
38+
private static readonly Feature __arrayIndexAs = new Feature("ArrayIndexAs", WireVersion.Server83);
3839
private static readonly Feature __bitwiseOperators = new Feature("BitwiseOperators", WireVersion.Server63);
3940
private static readonly Feature __changeStreamAllChangesForCluster = new Feature("ChangeStreamAllChangesForCluster", WireVersion.Server40);
4041
private static readonly Feature __changeStreamForDatabase = new Feature("ChangeStreamForDatabase", WireVersion.Server40);
@@ -161,6 +162,11 @@ public class Feature
161162
/// </summary>
162163
public static Feature AggregateUnionWith => __aggregateUnionWith;
163164

165+
/// <summary>
166+
/// Gets the arrayIndexAs feature for $map, $filter and $reduce.
167+
/// </summary>
168+
public static Feature ArrayIndexAs => __arrayIndexAs;
169+
164170
/// <summary>
165171
/// Gets the bitwise operators feature.
166172
/// </summary>

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ public static AstFieldPathExpression FieldPath(string path)
428428
return new AstFieldPathExpression(path);
429429
}
430430

431-
public static AstExpression Filter(AstExpression input, AstExpression cond, string @as, AstExpression limit = null)
431+
public static AstExpression Filter(AstExpression input, AstExpression cond, string @as, AstExpression limit = null, string arrayIndexAs = null)
432432
{
433-
return new AstFilterExpression(input, cond, @as, limit);
433+
return new AstFilterExpression(input, cond, @as, limit, arrayIndexAs);
434434
}
435435

436436
public static AstExpression First(AstExpression array)
@@ -592,9 +592,9 @@ public static AstExpression LTrim(AstExpression input, AstExpression chars = nul
592592
return new AstLTrimExpression(input, chars);
593593
}
594594

595-
public static AstExpression Map(AstExpression input, AstVarExpression @as, AstExpression @in)
595+
public static AstExpression Map(AstExpression input, AstVarExpression @as, AstExpression @in, AstVarExpression arrayIndexAs = null)
596596
{
597-
return new AstMapExpression(input, @as, @in);
597+
return new AstMapExpression(input, @as, @in, arrayIndexAs);
598598
}
599599

600600
public static AstExpression Max(AstExpression array)

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstFilterExpression.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions
2121
{
2222
internal sealed class AstFilterExpression : AstExpression
2323
{
24+
private readonly string _arrayIndexAs;
2425
private readonly string _as;
2526
private readonly AstExpression _cond;
2627
private readonly AstExpression _input;
@@ -30,14 +31,17 @@ public AstFilterExpression(
3031
AstExpression input,
3132
AstExpression cond,
3233
string @as = null,
33-
AstExpression limit = null)
34+
AstExpression limit = null,
35+
string arrayIndexAs = null)
3436
{
3537
_input = Ensure.IsNotNull(input, nameof(input));
3638
_cond = Ensure.IsNotNull(cond, nameof(cond));
3739
_as = @as;
3840
_limit = limit;
41+
_arrayIndexAs = arrayIndexAs;
3942
}
4043

44+
public string ArrayIndexAs => _arrayIndexAs;
4145
public string As => _as;
4246
public new AstExpression Cond => _cond;
4347
public AstExpression Input => _input;
@@ -57,6 +61,7 @@ public override BsonValue Render()
5761
{
5862
{ "input", _input.Render() },
5963
{ "as", _as, _as != null },
64+
{ "arrayIndexAs", _arrayIndexAs, _arrayIndexAs != null },
6065
{ "cond", _cond.Render() },
6166
{ "limit", () => _limit.Render(), _limit != null }
6267
}
@@ -69,12 +74,12 @@ public AstFilterExpression Update(
6974
AstExpression cond,
7075
AstExpression limit)
7176
{
72-
if (input == _input && cond == _cond)
77+
if (input == _input && cond == _cond && limit == _limit)
7378
{
7479
return this;
7580
}
7681

77-
return new AstFilterExpression(input, cond, _as, limit);
82+
return new AstFilterExpression(input, cond, _as, limit, _arrayIndexAs);
7883
}
7984
}
8085
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstMapExpression.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,24 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions
2121
{
2222
internal sealed class AstMapExpression : AstExpression
2323
{
24+
private readonly AstVarExpression _arrayIndexAs;
2425
private readonly AstVarExpression _as;
2526
private readonly AstExpression _in;
2627
private readonly AstExpression _input;
2728

2829
public AstMapExpression(
2930
AstExpression input,
3031
AstVarExpression @as,
31-
AstExpression @in)
32+
AstExpression @in,
33+
AstVarExpression arrayIndexAs = null)
3234
{
3335
_input = Ensure.IsNotNull(input, nameof(input));
3436
_as = @as;
3537
_in = Ensure.IsNotNull(@in, nameof(@in));
38+
_arrayIndexAs = arrayIndexAs;
3639
}
3740

41+
public AstVarExpression ArrayIndexAs => _arrayIndexAs;
3842
public AstVarExpression As => _as;
3943
public new AstExpression In => _in;
4044
public AstExpression Input => _input;
@@ -53,6 +57,7 @@ public override BsonValue Render()
5357
{
5458
{ "input", _input.Render() },
5559
{ "as", _as?.Name, _as != null },
60+
{ "arrayIndexAs", _arrayIndexAs?.Name, _arrayIndexAs != null },
5661
{ "in", _in.Render() }
5762
}
5863
}
@@ -62,14 +67,15 @@ public override BsonValue Render()
6267
public AstMapExpression Update(
6368
AstExpression input,
6469
AstVarExpression @as,
70+
AstVarExpression arrayIndexAs,
6571
AstExpression @in)
6672
{
67-
if (input == _input && @as == _as && @in == _in)
73+
if (input == _input && @as == _as && arrayIndexAs == _arrayIndexAs && @in == _in)
6874
{
6975
return this;
7076
}
7177

72-
return new AstMapExpression(input, @as, @in);
78+
return new AstMapExpression(input, @as, @in, arrayIndexAs);
7379
}
7480
}
7581
}

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Visitors/AstNodeVisitor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ public virtual AstNode VisitLTrimExpression(AstLTrimExpression node)
496496

497497
public virtual AstNode VisitMapExpression(AstMapExpression node)
498498
{
499-
return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.As), VisitAndConvert(node.In));
499+
return node.Update(VisitAndConvert(node.Input), VisitAndConvert(node.As), VisitAndConvert(node.ArrayIndexAs), VisitAndConvert(node.In));
500500
}
501501

502502
public virtual AstNode VisitMatchesEverythingFilter(AstMatchesEverythingFilter node)

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ static EnumerableMethod()
568568
public static MethodInfo Select => __select;
569569
public static MethodInfo SelectManyWithSelector => __selectManyWithSelector;
570570
public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector;
571-
public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector;
571+
public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; // TODO CSHARP-5978: not yet supported in nested expressions
572572
public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex;
573573
public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex;
574574
public static MethodInfo SequenceEqual => __sequenceEqual;

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableOrQueryableMethod.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ internal static class EnumerableOrQueryableMethod
5757
private static readonly IReadOnlyMethodInfoSet __select;
5858
private static readonly IReadOnlyMethodInfoSet __selectManyWithCollectionSelectorAndResultSelector;
5959
private static readonly IReadOnlyMethodInfoSet __selectManyWithSelector;
60+
private static readonly IReadOnlyMethodInfoSet __selectManyWithSelectorTakingIndex;
61+
private static readonly IReadOnlyMethodInfoSet __selectWithSelectorTakingIndex;
6062
private static readonly IReadOnlyMethodInfoSet __single;
6163
private static readonly IReadOnlyMethodInfoSet __singleOrDefault;
6264
private static readonly IReadOnlyMethodInfoSet __singleWithPredicate;
@@ -69,6 +71,7 @@ internal static class EnumerableOrQueryableMethod
6971
private static readonly IReadOnlyMethodInfoSet __thenByDescending;
7072
private static readonly IReadOnlyMethodInfoSet __union;
7173
private static readonly IReadOnlyMethodInfoSet __where;
74+
private static readonly IReadOnlyMethodInfoSet __whereWithPredicateTakingIndex;
7275
private static readonly IReadOnlyMethodInfoSet __zip;
7376

7477
// sets of methods
@@ -350,6 +353,18 @@ static EnumerableOrQueryableMethod()
350353
QueryableMethod.SelectManyWithSelector
351354
]);
352355

356+
__selectManyWithSelectorTakingIndex = MethodInfoSet.Create(
357+
[
358+
EnumerableMethod.SelectManyWithSelectorTakingIndex,
359+
QueryableMethod.SelectManyWithSelectorTakingIndex
360+
]);
361+
362+
__selectWithSelectorTakingIndex = MethodInfoSet.Create(
363+
[
364+
EnumerableMethod.SelectWithSelectorTakingIndex,
365+
QueryableMethod.SelectWithSelectorTakingIndex
366+
]);
367+
353368
__single = MethodInfoSet.Create(
354369
[
355370
EnumerableMethod.Single,
@@ -422,6 +437,12 @@ static EnumerableOrQueryableMethod()
422437
QueryableMethod.Where,
423438
]);
424439

440+
__whereWithPredicateTakingIndex = MethodInfoSet.Create(
441+
[
442+
EnumerableMethod.WhereWithPredicateTakingIndex,
443+
QueryableMethod.WhereWithPredicateTakingIndex
444+
]);
445+
425446
__zip = MethodInfoSet.Create(
426447
[
427448
EnumerableMethod.Zip,
@@ -899,6 +920,8 @@ static EnumerableOrQueryableMethod()
899920
public static IReadOnlyMethodInfoSet Select => __select;
900921
public static IReadOnlyMethodInfoSet SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector;
901922
public static IReadOnlyMethodInfoSet SelectManyWithSelector => __selectManyWithSelector;
923+
public static IReadOnlyMethodInfoSet SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex;
924+
public static IReadOnlyMethodInfoSet SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex;
902925
public static IReadOnlyMethodInfoSet Single => __single;
903926
public static IReadOnlyMethodInfoSet SingleOrDefault => __singleOrDefault;
904927
public static IReadOnlyMethodInfoSet SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate;
@@ -909,6 +932,7 @@ static EnumerableOrQueryableMethod()
909932
public static IReadOnlyMethodInfoSet TakeWhile => __takeWhile;
910933
public static IReadOnlyMethodInfoSet Union => __union;
911934
public static IReadOnlyMethodInfoSet Where => __where;
935+
public static IReadOnlyMethodInfoSet WhereWithPredicateTakingIndex => __whereWithPredicateTakingIndex;
912936
public static IReadOnlyMethodInfoSet Zip => __zip;
913937

914938
// sets of methods

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ static QueryableMethod()
511511
public static MethodInfo Select => __select;
512512
public static MethodInfo SelectManyWithSelector => __selectManyWithSelector;
513513
public static MethodInfo SelectManyWithCollectionSelectorAndResultSelector => __selectManyWithCollectionSelectorAndResultSelector;
514-
public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector;
514+
public static MethodInfo SelectManyWithCollectionSelectorTakingIndexAndResultSelector => __selectManyWithCollectionSelectorTakingIndexAndResultSelector; // TODO CSHARP-5978: not yet supported in nested expressions
515515
public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex;
516516
public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex;
517517
public static MethodInfo SequenceEqual => __sequenceEqual;

src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,6 +2354,19 @@ void DeduceSelectMethodSerializers()
23542354
DeduceItemAndCollectionSerializers(selectorParameter, sourceExpression);
23552355
DeduceCollectionAndItemSerializers(node, selectorLambda.Body);
23562356
}
2357+
else if (method.IsOneOf(EnumerableOrQueryableMethod.SelectWithSelectorTakingIndex))
2358+
{
2359+
var sourceExpression = arguments[0];
2360+
var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
2361+
var itemParameter = selectorLambda.Parameters[0];
2362+
var indexParameter = selectorLambda.Parameters[1];
2363+
DeduceItemAndCollectionSerializers(itemParameter, sourceExpression);
2364+
if (IsNotKnown(indexParameter))
2365+
{
2366+
AddNodeSerializer(indexParameter, Int32Serializer.Instance);
2367+
}
2368+
DeduceCollectionAndItemSerializers(node, selectorLambda.Body);
2369+
}
23572370
else
23582371
{
23592372
DeduceUnknownMethodSerializer();
@@ -2390,6 +2403,19 @@ void DeduceSelectManySerializers()
23902403
DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body);
23912404
}
23922405
}
2406+
else if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithSelectorTakingIndex))
2407+
{
2408+
var sourceExpression = arguments[0];
2409+
var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
2410+
var itemParameter = selectorLambda.Parameters[0];
2411+
var indexParameter = selectorLambda.Parameters[1];
2412+
DeduceItemAndCollectionSerializers(itemParameter, sourceExpression);
2413+
if (IsNotKnown(indexParameter))
2414+
{
2415+
AddNodeSerializer(indexParameter, Int32Serializer.Instance);
2416+
}
2417+
DeduceCollectionAndCollectionSerializers(node, selectorLambda.Body);
2418+
}
23932419
else
23942420
{
23952421
DeduceUnknownMethodSerializer();
@@ -2820,6 +2846,19 @@ void DeduceWhereSerializers()
28202846
DeduceItemAndCollectionSerializers(predicateParameter, sourceExpression);
28212847
DeduceCollectionAndCollectionSerializers(node, sourceExpression);
28222848
}
2849+
else if (method.IsOneOf(EnumerableOrQueryableMethod.WhereWithPredicateTakingIndex))
2850+
{
2851+
var sourceExpression = arguments[0];
2852+
var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
2853+
var itemParameter = predicateLambda.Parameters[0];
2854+
var indexParameter = predicateLambda.Parameters[1];
2855+
DeduceItemAndCollectionSerializers(itemParameter, sourceExpression);
2856+
if (IsNotKnown(indexParameter))
2857+
{
2858+
AddNodeSerializer(indexParameter, Int32Serializer.Instance);
2859+
}
2860+
DeduceCollectionAndCollectionSerializers(node, sourceExpression);
2861+
}
28232862
else
28242863
{
28252864
DeduceUnknownMethodSerializer();

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/SelectManyMethodToAggregationExpressionTranslator.cs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,38 @@ public static TranslatedExpression Translate(TranslationContext context, MethodC
5555
initialValue: new BsonArray(),
5656
@in: AstExpression.ConcatArrays(valueVar, thisVar));
5757

58-
var ienumerableSerializer = IEnumerableSerializer.Create(itemSerializer);
59-
return new TranslatedExpression(expression, ast, ienumerableSerializer);
58+
var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
59+
return new TranslatedExpression(expression, ast, serializer);
60+
}
61+
62+
if (method.IsOneOf(EnumerableOrQueryableMethod.SelectManyWithSelectorTakingIndex))
63+
{
64+
var sourceExpression = arguments[0];
65+
var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
66+
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
67+
68+
var selectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
69+
var itemParameter = selectorLambda.Parameters[0];
70+
var indexParameter = selectorLambda.Parameters[1];
71+
var itemSymbol = context.CreateSymbol(itemParameter, context.GetSerializer(itemParameter));
72+
var indexSymbol = context.CreateSymbol(indexParameter, context.GetSerializer(indexParameter));
73+
var selectorContext = context.WithSymbols(itemSymbol, indexSymbol);
74+
var selectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(selectorContext, selectorLambda.Body);
75+
var resultItemSerializer = ArraySerializerHelper.GetItemSerializer(selectorTranslation.Serializer);
76+
77+
var valueVar = AstExpression.Var("value");
78+
var thisVar = AstExpression.Var("this");
79+
var ast = AstExpression.Reduce(
80+
input: AstExpression.Map(
81+
input: sourceTranslation.Ast,
82+
@as: itemSymbol.Var,
83+
@in: selectorTranslation.Ast,
84+
arrayIndexAs: indexSymbol.Var),
85+
initialValue: new BsonArray(),
86+
@in: AstExpression.ConcatArrays(valueVar, thisVar));
87+
88+
var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, resultItemSerializer);
89+
return new TranslatedExpression(expression, ast, serializer);
6090
}
6191

6292
throw new ExpressionNotSupportedException(expression);

0 commit comments

Comments
 (0)