Skip to content

Commit 0125013

Browse files
committed
CSHARP-5828: Add Rerank stage builder
1 parent 29a8892 commit 0125013

File tree

9 files changed

+469
-0
lines changed

9 files changed

+469
-0
lines changed

src/MongoDB.Driver/AggregateFluent.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,15 @@ public override IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<
379379
return WithPipeline(_pipeline.Unwind(field, options));
380380
}
381381

382+
// Appends a $rerank stage to the wrapped pipeline and hands the extended pipeline to WithPipeline.
public override IAggregateFluent<TResult> Rerank(
    RerankQuery query,
    FieldDefinition<TResult> path,
    int numDocsToRerank,
    string model)
{
    var rerankedPipeline = _pipeline.Rerank(query, path, numDocsToRerank, model);
    return WithPipeline(rerankedPipeline);
}
390+
382391
public override IAggregateFluent<TResult> VectorSearch(
383392
FieldDefinition<TResult> field,
384393
QueryVector queryVector,

src/MongoDB.Driver/AggregateFluentBase.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,16 @@ public virtual IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<T
331331
throw new NotImplementedException();
332332
}
333333

334+
/// <inheritdoc />
public virtual IAggregateFluent<TResult> Rerank(
    RerankQuery query,
    FieldDefinition<TResult> path,
    int numDocsToRerank,
    string model)
    // The base class supplies no implementation; this overload always throws.
    => throw new NotImplementedException();
343+
334344
/// <inheritdoc />
335345
public virtual IAggregateFluent<TResult> VectorSearch(
336346
FieldDefinition<TResult> field,

src/MongoDB.Driver/IAggregateFluent.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,20 @@ IAggregateFluent<TResult> UnionWith<TWith>(
545545
/// <returns>The fluent aggregate interface.</returns>
546546
IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<TResult> field, AggregateUnwindOptions<TNewResult> options = null);
547547

548+
/// <summary>
/// Appends a $rerank stage to the pipeline.
/// </summary>
/// <param name="query">The query the documents are reranked against.</param>
/// <param name="path">The field that is sent to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The name of the reranking model.</param>
/// <returns>The fluent aggregate interface.</returns>
IAggregateFluent<TResult> Rerank(
    RerankQuery query,
    FieldDefinition<TResult> path,
    int numDocsToRerank,
    string model);
561+
548562
/// <summary>
549563
/// Appends a vector search stage.
550564
/// </summary>

src/MongoDB.Driver/IAggregateFluentExtensions.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,52 @@ public static IAggregateFluent<TNewResult> Unwind<TResult, TNewResult>(this IAgg
10411041
return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken);
10421042
}
10431043

1044+
/// <summary>
/// Appends a $rerank stage that sends a single field, selected by an expression, to the reranker.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">An expression selecting the field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TResult> Rerank<TResult, TField>(
    this IAggregateFluent<TResult> aggregate,
    RerankQuery query,
    Expression<Func<TResult, TField>> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(aggregate, nameof(aggregate));

    // Wrap the expression in a field definition and use the FieldDefinition-based overload.
    var pathDefinition = new ExpressionFieldDefinition<TResult>(path);
    return aggregate.Rerank(query, pathDefinition, numDocsToRerank, model);
}
1066+
1067+
/// <summary>
/// Appends a $rerank stage that sends one or more fields, selected by expressions, to the reranker.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="query">The rerank query.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <param name="paths">Expressions selecting the fields to send to the reranker.</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TResult> Rerank<TResult, TField>(
    this IAggregateFluent<TResult> aggregate,
    RerankQuery query,
    int numDocsToRerank,
    string model,
    params Expression<Func<TResult, TField>>[] paths)
{
    Ensure.IsNotNull(aggregate, nameof(aggregate));

    // Build the stage with the stage builder, then append it to the fluent pipeline.
    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths);
    return aggregate.AppendStage(rerankStage);
}
1089+
10441090
/// <summary>
10451091
/// Appends a $vectorSearch stage.
10461092
/// </summary>

src/MongoDB.Driver/PipelineDefinitionBuilder.cs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,104 @@ public static PipelineDefinition<TInput, TOutput> RankFusion<TInput, TIntermedia
10811081
return pipeline.AppendStage(PipelineStageDefinitionBuilder.RankFusion(pipelinesWithWeights, options));
10821082
}
10831083

1084+
/// <summary>
/// Appends a $rerank stage, reranking on a single field selected by an expression, to the pipeline.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">An expression selecting the field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TField, TOutput>(
    this PipelineDefinition<TInput, TOutput> pipeline,
    RerankQuery query,
    Expression<Func<TOutput, TField>> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(pipeline, nameof(pipeline));

    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model);

    // The stage keeps the document type, so the pipeline's current output serializer is passed along.
    return pipeline.AppendStage(rerankStage, pipeline.OutputSerializer);
}
1108+
1109+
/// <summary>
/// Appends a $rerank stage, reranking on one or more fields selected by expressions, to the pipeline.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <param name="paths">Expressions selecting the fields to send to the reranker.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TField, TOutput>(
    this PipelineDefinition<TInput, TOutput> pipeline,
    RerankQuery query,
    int numDocsToRerank,
    string model,
    params Expression<Func<TOutput, TField>>[] paths)
{
    Ensure.IsNotNull(pipeline, nameof(pipeline));

    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths);

    // The stage keeps the document type, so the pipeline's current output serializer is passed along.
    return pipeline.AppendStage(rerankStage, pipeline.OutputSerializer);
}
1133+
1134+
/// <summary>
/// Appends a $rerank stage, reranking on a single field definition, to the pipeline.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
    this PipelineDefinition<TInput, TOutput> pipeline,
    RerankQuery query,
    FieldDefinition<TOutput> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(pipeline, nameof(pipeline));

    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model);

    // The stage keeps the document type, so the pipeline's current output serializer is passed along.
    return pipeline.AppendStage(rerankStage, pipeline.OutputSerializer);
}
1157+
1158+
/// <summary>
/// Appends a $rerank stage, reranking on one or more field definitions, to the pipeline.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="paths">The fields to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
    this PipelineDefinition<TInput, TOutput> pipeline,
    RerankQuery query,
    IEnumerable<FieldDefinition<TOutput>> paths,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(pipeline, nameof(pipeline));

    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model);

    // The stage keeps the document type, so the pipeline's current output serializer is passed along.
    return pipeline.AppendStage(rerankStage, pipeline.OutputSerializer);
}
1181+
10841182
/// <summary>
10851183
/// Appends a $replaceRoot stage to the pipeline.
10861184
/// </summary>

src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,6 +1684,118 @@ public static PipelineStageDefinition<TInput, TOutput> ReplaceWith<TInput, TOutp
16841684
return ReplaceWith(new ExpressionAggregateExpressionDefinition<TInput, TOutput>(newRoot));
16851685
}
16861686

1687+
/// <summary>
/// Creates a $rerank stage that sends a single field, selected by an expression, to the reranker.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="path">An expression selecting the field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The stage.</returns>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput, TField>(
    RerankQuery query,
    Expression<Func<TInput, TField>> path,
    int numDocsToRerank,
    string model)
{
    // Wrap the expression in a field definition and defer to the FieldDefinition-based overload.
    var field = new ExpressionFieldDefinition<TInput>(path);
    return Rerank(query, field, numDocsToRerank, model);
}
1707+
1708+
/// <summary>
/// Creates a $rerank stage that sends one or more fields, selected by expressions, to the reranker.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <param name="paths">Expressions selecting the fields to send to the reranker. Must not be null.</param>
/// <returns>The stage.</returns>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput, TField>(
    RerankQuery query,
    int numDocsToRerank,
    string model,
    params Expression<Func<TInput, TField>>[] paths)
{
    // Guard here so a null array is reported against "paths" rather than against
    // the LINQ operator's own parameter name.
    Ensure.IsNotNull(paths, nameof(paths));

    // Materialize once: a lazy Select would be re-run (creating new field definitions)
    // every time the sequence is enumerated for validation or rendering, and any failure
    // constructing a field definition would be deferred until then.
    var fieldDefinitions = paths
        .Select(p => (FieldDefinition<TInput>)new ExpressionFieldDefinition<TInput>(p))
        .ToArray();

    return Rerank(query, fieldDefinitions, numDocsToRerank, model);
}
1728+
1729+
/// <summary>
/// Creates a $rerank stage that sends a single field to the reranker.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The stage.</returns>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
    RerankQuery query,
    FieldDefinition<TInput> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(path, nameof(path));

    // Treat the single field as a one-element sequence and reuse the multi-path overload.
    IEnumerable<FieldDefinition<TInput>> singlePath = [path];
    return Rerank(query, singlePath, numDocsToRerank, model);
}
1751+
1752+
/// <summary>
/// Creates a $rerank stage that sends one or more fields to the reranker.
/// </summary>
/// <remarks>
/// When exactly one path is supplied the rendered "path" value is a string; otherwise it is an
/// array of strings. The supplied paths are copied when the stage is created, so later changes
/// to the caller's collection do not affect the stage.
/// </remarks>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="paths">The fields to send to the reranker. Must contain at least one non-null field.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank. Must be greater than zero.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The stage.</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="paths"/> contains a null element.</exception>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
    RerankQuery query,
    IEnumerable<FieldDefinition<TInput>> paths,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(query, nameof(query));
    Ensure.IsNotNull(paths, nameof(paths));

    // Snapshot the sequence exactly once. The render delegate below runs later (possibly many
    // times), so holding on to the caller's enumerable would re-enumerate it on every render
    // and pick up any mutation made after this method returned.
    var pathList = paths.ToList();
    Ensure.IsNotNullOrEmpty(pathList, nameof(paths));
    if (pathList.Any(p => p is null))
    {
        // Fail at construction time instead of with a NullReferenceException during rendering.
        throw new ArgumentException("Paths cannot contain null.", nameof(paths));
    }

    Ensure.IsGreaterThanZero(numDocsToRerank, nameof(numDocsToRerank));
    Ensure.IsNotNull(model, nameof(model));

    const string operatorName = "$rerank";
    var stage = new DelegatedPipelineStageDefinition<TInput, TInput>(
        operatorName,
        args =>
        {
            ClientSideProjectionHelper.ThrowIfClientSideProjection(args.DocumentSerializer, operatorName);

            // A single path is rendered as a plain string, several paths as an array.
            var renderedPaths = pathList.Select(p => p.Render(args).FieldName).ToList();
            BsonValue pathValue = renderedPaths.Count == 1
                ? renderedPaths[0]
                : new BsonArray(renderedPaths);

            var rerankDocument = new BsonDocument
            {
                { "query", query.Render() },
                { "path", pathValue },
                { "numDocsToRerank", numDocsToRerank },
                { "model", model }
            };

            var document = new BsonDocument(operatorName, rerankDocument);
            return new RenderedPipelineStageDefinition<TInput>(operatorName, document, args.DocumentSerializer);
        });

    return stage;
}
1798+
16871799
/// <summary>
16881800
/// Creates a $set stage.
16891801
/// </summary>

src/MongoDB.Driver/RerankQuery.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using MongoDB.Bson;
17+
using MongoDB.Driver.Core.Misc;
18+
19+
namespace MongoDB.Driver
20+
{
21+
/// <summary>
/// Represents a query for the $rerank aggregation stage.
/// </summary>
/// <remarks>
/// Instances are created through the static factory methods; the constructor is private.
/// </remarks>
public sealed class RerankQuery
{
    // The BSON form of the query, built once by the factory method.
    private readonly BsonDocument _rendered;

    private RerankQuery(BsonDocument rendered)
    {
        _rendered = rendered;
    }

    /// <summary>
    /// Creates a text-based rerank query.
    /// </summary>
    /// <param name="text">The text to rerank against.</param>
    /// <returns>A text rerank query.</returns>
    public static RerankQuery Text(string text) =>
        new RerankQuery(new BsonDocument("text", Ensure.IsNotNull(text, nameof(text))));

    // Returns a copy rather than the private document: BsonDocument is mutable, and handing
    // out the same instance would let a change made to one rendered stage leak into this
    // query and into every other stage rendered from it.
    internal BsonDocument Render() => (BsonDocument)_rendered.DeepClone();
}
43+
}

0 commit comments

Comments
 (0)