Skip to content

Commit 7b5f277

Browse files
committed
CSHARP-6017: Implement support for LeftJoin
1 parent 8baaf79 commit 7b5f277

8 files changed

Lines changed: 630 additions & 0 deletions

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
namespace MongoDB.Driver.Linq
17+
{
18+
/// <summary>
19+
/// The result of a LeftJoin operation.
20+
/// </summary>
21+
/// <typeparam name="TOuter">The type of the outer documents.</typeparam>
22+
/// <typeparam name="TInner">The type of the inner documents.</typeparam>
23+
public struct LeftJoinResult<TOuter, TInner>
24+
{
25+
/// <summary>
26+
/// The outer document.
27+
/// </summary>
28+
public TOuter Outer { get; init; }
29+
30+
/// <summary>
31+
/// The inner document (null when no matching inner document exists).
32+
/// </summary>
33+
public TInner Inner { get; init; }
34+
}
35+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MongoQueryableMethod.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ internal static class MongoQueryableMethod
5959
private static readonly MethodInfo __firstOrDefaultAsync;
6060
private static readonly MethodInfo __firstOrDefaultWithPredicateAsync;
6161
private static readonly MethodInfo __firstWithPredicateAsync;
62+
private static readonly MethodInfo __leftJoin;
6263
private static readonly MethodInfo __longCountAsync;
6364
private static readonly MethodInfo __longCountWithPredicateAsync;
6465
private static readonly MethodInfo __lookupWithDocumentsAndLocalFieldAndForeignField;
@@ -242,6 +243,7 @@ static MongoQueryableMethod()
242243
__firstOrDefaultAsync = ReflectionInfo.Method((IQueryable<object> source, CancellationToken cancellationToken) => source.FirstOrDefaultAsync(cancellationToken));
243244
__firstOrDefaultWithPredicateAsync = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate, CancellationToken cancellationToken) => source.FirstOrDefaultAsync(predicate, cancellationToken));
244245
__firstWithPredicateAsync = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate, CancellationToken cancellationToken) => source.FirstAsync(predicate, cancellationToken));
246+
__leftJoin = ReflectionInfo.Method((IQueryable<object> outer, IQueryable<object> inner, Expression<Func<object, object>> outerKeySelector, Expression<Func<object, object>> innerKeySelector, Expression<Func<object, object, object>> resultSelector) => MongoQueryable.LeftJoin(outer, inner, outerKeySelector, innerKeySelector, resultSelector));
245247
__longCountAsync = ReflectionInfo.Method((IQueryable<object> source, CancellationToken cancellationToken) => source.LongCountAsync(cancellationToken));
246248
__longCountWithPredicateAsync = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate, CancellationToken cancellationToken) => source.LongCountAsync(predicate, cancellationToken));
247249
__lookupWithDocumentsAndLocalFieldAndForeignField = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, IEnumerable<object>>> documents, Expression<Func<object, object>> localField, Expression<Func<object, object>> foreignField) => source.Lookup(documents, localField, foreignField));
@@ -798,6 +800,7 @@ static MongoQueryableMethod()
798800
public static MethodInfo FirstOrDefaultAsync => __firstOrDefaultAsync;
799801
public static MethodInfo FirstOrDefaultWithPredicateAsync => __firstOrDefaultWithPredicateAsync;
800802
public static MethodInfo FirstWithPredicateAsync => __firstWithPredicateAsync;
803+
public static MethodInfo LeftJoin => __leftJoin;
801804
public static MethodInfo LongCountAsync => __longCountAsync;
802805
public static MethodInfo LongCountWithPredicateAsync => __longCountWithPredicateAsync;
803806
public static MethodInfo LookupWithDocumentsAndLocalFieldAndForeignField => __lookupWithDocumentsAndLocalFieldAndForeignField;

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/QueryableMethod.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ internal static class QueryableMethod
7575
private static readonly MethodInfo __groupJoin;
7676
private static readonly MethodInfo __intersect;
7777
private static readonly MethodInfo __join;
78+
private static readonly MethodInfo __leftJoin;
7879
private static readonly MethodInfo __last;
7980
private static readonly MethodInfo __lastOrDefault;
8081
private static readonly MethodInfo __lastOrDefaultWithPredicate;
@@ -210,6 +211,8 @@ static QueryableMethod()
210211
__groupJoin = ReflectionInfo.Method((IQueryable<object> outer, IEnumerable<object> inner, Expression<Func<object, object>> outerKeySelector, Expression<Func<object, object>> innerKeySelector, Expression<Func<object, IEnumerable<object>, object>> resultSelector) => outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector));
211212
__intersect = ReflectionInfo.Method((IQueryable<object> source1, IEnumerable<object> source2) => source1.Intersect(source2));
212213
__join = ReflectionInfo.Method((IQueryable<object> outer, IEnumerable<object> inner, Expression<Func<object, object>> outerKeySelector, Expression<Func<object, object>> innerKeySelector, Expression<Func<object, object, object>> resultSelector) => outer.Join(inner, outerKeySelector, innerKeySelector, resultSelector));
214+
__leftJoin = typeof(Queryable).GetMethods()
215+
.FirstOrDefault(m => m.Name == "LeftJoin" && m.GetParameters().Length == 5);
213216
__last = ReflectionInfo.Method((IQueryable<object> source) => source.Last());
214217
__lastOrDefault = ReflectionInfo.Method((IQueryable<object> source) => source.LastOrDefault());
215218
__lastOrDefaultWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.LastOrDefault(predicate));
@@ -497,6 +500,7 @@ static QueryableMethod()
497500
public static MethodInfo GroupJoin => __groupJoin;
498501
public static MethodInfo Intersect => __intersect;
499502
public static MethodInfo Join => __join;
503+
public static MethodInfo LeftJoin => __leftJoin;
500504
public static MethodInfo Last => __last;
501505
public static MethodInfo LastOrDefault => __lastOrDefault;
502506
public static MethodInfo LastOrDefaultWithPredicate => __lastOrDefaultWithPredicate;

src/MongoDB.Driver/Linq/Linq3Implementation/SerializerFinders/SerializerFinderVisitMethodCall.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ void DeduceMethodCallSerializers()
127127
case "IsMatch": DeduceIsMatchMethodSerializers(); break;
128128
case "IsSubsetOf": DeduceIsSubsetOfMethodSerializers(); break;
129129
case "Join": DeduceJoinMethodSerializers(); break;
130+
case "LeftJoin": DeduceLeftJoinMethodSerializers(); break;
130131
case "Locf": DeduceLocfMethodSerializers(); break;
131132
case "Lookup": DeduceLookupMethodSerializers(); break;
132133
case "OfType": DeduceOfTypeMethodSerializers(); break;
@@ -1760,6 +1761,32 @@ void DeduceJoinMethodSerializers()
17601761
}
17611762
}
17621763

1764+
void DeduceLeftJoinMethodSerializers()
1765+
{
1766+
if (method.IsOneOf(MongoQueryableMethod.LeftJoin, QueryableMethod.LeftJoin))
1767+
{
1768+
var outerExpression = arguments[0];
1769+
var innerExpression = arguments[1];
1770+
var outerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]);
1771+
var outerKeySelectorItemParameter = outerKeySelectorLambda.Parameters.Single();
1772+
var innerKeySelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]);
1773+
var innerKeySelectorItemParameter = innerKeySelectorLambda.Parameters.Single();
1774+
var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[4]);
1775+
var resultSelectorOuterItemParameter = resultSelectorLambda.Parameters[0];
1776+
var resultSelectorInnerItemParameter = resultSelectorLambda.Parameters[1];
1777+
1778+
DeduceItemAndCollectionSerializers(outerKeySelectorItemParameter, outerExpression);
1779+
DeduceItemAndCollectionSerializers(innerKeySelectorItemParameter, innerExpression);
1780+
DeduceItemAndCollectionSerializers(resultSelectorOuterItemParameter, outerExpression);
1781+
DeduceItemAndCollectionSerializers(resultSelectorInnerItemParameter, innerExpression);
1782+
DeduceCollectionAndItemSerializers(node, resultSelectorLambda.Body);
1783+
}
1784+
else
1785+
{
1786+
DeduceUnknownMethodSerializer();
1787+
}
1788+
}
1789+
17631790
void DeduceIsNullOrEmptyOrIsNullOrWhiteSpaceMethodSerializers()
17641791
{
17651792
if (method.IsOneOf(StringMethod.IsNullOrEmpty, StringMethod.IsNullOrWhiteSpace))

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToPipelineTranslators/ExpressionToPipelineTranslator.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public static TranslatedPipeline Translate(TranslationContext context, Expressio
5252
return GroupJoinMethodToPipelineTranslator.Translate(context, methodCallExpression);
5353
case "Join":
5454
return JoinMethodToPipelineTranslator.Translate(context, methodCallExpression);
55+
case "LeftJoin":
56+
return LeftJoinMethodToPipelineTranslator.Translate(context, methodCallExpression);
5557
case "Lookup":
5658
return LookupMethodToPipelineTranslator.Translate(context, methodCallExpression);
5759
case "OfType":
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Collections.Generic;
17+
using System.Linq;
18+
using System.Linq.Expressions;
19+
using MongoDB.Bson.Serialization;
20+
using MongoDB.Driver.Linq.Linq3Implementation.Ast;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
22+
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;
23+
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
24+
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
25+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
26+
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
27+
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators;
28+
29+
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToPipelineTranslators
30+
{
31+
internal static class LeftJoinMethodToPipelineTranslator
32+
{
33+
// public static methods
34+
public static TranslatedPipeline Translate(TranslationContext context, MethodCallExpression expression)
35+
{
36+
var method = expression.Method;
37+
var arguments = expression.Arguments;
38+
39+
if (method.IsOneOf(MongoQueryableMethod.LeftJoin, QueryableMethod.LeftJoin))
40+
{
41+
var outerExpression = arguments[0];
42+
var innerExpression = arguments[1];
43+
var outerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[2]);
44+
var innerKeySelectorLambda = ExpressionHelper.UnquoteLambda(arguments[3]);
45+
var resultSelectorLambda = ExpressionHelper.UnquoteLambda(arguments[4]);
46+
47+
var pipeline = ExpressionToPipelineTranslator.Translate(context, outerExpression);
48+
ClientSideProjectionHelper.ThrowIfClientSideProjection(expression, pipeline, method);
49+
50+
AstExpression outerAst;
51+
IBsonSerializer outerSerializer;
52+
if (pipeline.OutputSerializer is IWrappedValueSerializer pipelineOutputWrappedSerializer)
53+
{
54+
outerAst = AstExpression.GetField(AstExpression.RootVar, pipelineOutputWrappedSerializer.FieldName);
55+
outerSerializer = pipelineOutputWrappedSerializer.ValueSerializer;
56+
}
57+
else
58+
{
59+
outerAst = AstExpression.RootVar;
60+
outerSerializer = pipeline.OutputSerializer;
61+
}
62+
63+
ThrowIfReservedFieldNames(expression, outerSerializer);
64+
65+
var wrapOuterStage = AstStage.Project(
66+
AstProject.Set("_outer", outerAst),
67+
AstProject.Exclude("_id"));
68+
var wrappedOuterSerializer = WrappedValueSerializer.Create("_outer", outerSerializer);
69+
70+
var (innerCollectionName, innerSerializer) = innerExpression.GetCollectionInfoFromQueryable(containerExpression: expression);
71+
var localField = outerKeySelectorLambda.TranslateToDottedFieldName(context, wrappedOuterSerializer);
72+
var foreignField = innerKeySelectorLambda.TranslateToDottedFieldName(context, innerSerializer);
73+
74+
var lookupStage = AstStage.Lookup(
75+
from: innerCollectionName,
76+
localField,
77+
foreignField,
78+
@as: "_inner");
79+
80+
var unwindStage = AstStage.Unwind("_inner", preserveNullAndEmptyArrays: true);
81+
82+
var outerParameter = resultSelectorLambda.Parameters[0];
83+
var outerField = AstExpression.GetField(AstExpression.RootVar, "_outer");
84+
var outerSymbol = context.CreateSymbol(outerParameter, outerField, outerSerializer);
85+
var innerParameter = resultSelectorLambda.Parameters[1];
86+
var innerField = AstExpression.GetField(AstExpression.RootVar, "_inner");
87+
var innerSymbol = context.CreateSymbol(innerParameter, innerField, innerSerializer);
88+
var resultSelectorContext = context.WithSymbols(outerSymbol, innerSymbol);
89+
var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body);
90+
var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(resultSelectorTranslation);
91+
92+
pipeline = pipeline.AddStages(
93+
wrapOuterStage,
94+
lookupStage,
95+
unwindStage,
96+
projectStage,
97+
projectSerializer);
98+
99+
return pipeline;
100+
}
101+
102+
throw new ExpressionNotSupportedException(expression);
103+
}
104+
105+
private static readonly HashSet<string> __reservedFieldNames = new HashSet<string> { "_outer", "_inner" };
106+
107+
private static void ThrowIfReservedFieldNames(MethodCallExpression expression, IBsonSerializer serializer)
108+
{
109+
if (serializer is not IBsonDocumentSerializer documentSerializer)
110+
return;
111+
112+
var conflicting = new List<string>();
113+
foreach (var member in serializer.ValueType.GetMembers())
114+
{
115+
if (documentSerializer.TryGetMemberSerializationInfo(member.Name, out var info) &&
116+
__reservedFieldNames.Contains(info.ElementName))
117+
{
118+
conflicting.Add(info.ElementName);
119+
}
120+
}
121+
122+
if (conflicting.Count > 0)
123+
{
124+
throw new ExpressionNotSupportedException(expression,
125+
because: $"the outer document type uses reserved field name(s) {string.Join(", ", conflicting.Select(n => $"'{n}'"))} which are used internally by LeftJoin");
126+
}
127+
}
128+
}
129+
}

0 commit comments

Comments
 (0)