Skip to content

Commit 10f345d

Browse files
authored
CSHARP-5828: Add Rerank stage builder (#1936)
1 parent c24ee00 commit 10f345d

8 files changed

Lines changed: 466 additions & 0 deletions

src/MongoDB.Driver/IAggregateFluentExtensions.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,70 @@ public static IAggregateFluent<TNewResult> Unwind<TResult, TNewResult>(this IAgg
11351135
return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken);
11361136
}
11371137

1138+
/// <summary>
1139+
/// Appends a $rerank stage.
1140+
/// </summary>
1141+
/// <typeparam name="TResult">The type of the result.</typeparam>
1142+
/// <param name="aggregate">The aggregate.</param>
1143+
/// <param name="query">The rerank query.</param>
1144+
/// <param name="path">The field to send to the reranker.</param>
1145+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1146+
/// <param name="model">The reranking model name.</param>
1147+
/// <returns>The fluent aggregate interface.</returns>
1148+
public static IAggregateFluent<TResult> Rerank<TResult>(
1149+
this IAggregateFluent<TResult> aggregate,
1150+
RerankQuery query,
1151+
FieldDefinition<TResult> path,
1152+
int numDocsToRerank,
1153+
string model)
1154+
{
1155+
Ensure.IsNotNull(aggregate, nameof(aggregate));
1156+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model));
1157+
}
1158+
1159+
/// <summary>
1160+
/// Appends a $rerank stage.
1161+
/// </summary>
1162+
/// <typeparam name="TResult">The type of the result.</typeparam>
1163+
/// <typeparam name="TField">The type of the field.</typeparam>
1164+
/// <param name="aggregate">The aggregate.</param>
1165+
/// <param name="query">The rerank query.</param>
1166+
/// <param name="path">The field to send to the reranker.</param>
1167+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1168+
/// <param name="model">The reranking model name.</param>
1169+
/// <returns>The fluent aggregate interface.</returns>
1170+
public static IAggregateFluent<TResult> Rerank<TResult, TField>(
1171+
this IAggregateFluent<TResult> aggregate,
1172+
RerankQuery query,
1173+
Expression<Func<TResult, TField>> path,
1174+
int numDocsToRerank,
1175+
string model)
1176+
{
1177+
Ensure.IsNotNull(aggregate, nameof(aggregate));
1178+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model));
1179+
}
1180+
1181+
/// <summary>
1182+
/// Appends a $rerank stage.
1183+
/// </summary>
1184+
/// <typeparam name="TResult">The type of the result.</typeparam>
1185+
/// <param name="aggregate">The aggregate.</param>
1186+
/// <param name="query">The rerank query.</param>
1187+
/// <param name="paths">The fields to send to the reranker.</param>
1188+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1189+
/// <param name="model">The reranking model name.</param>
1190+
/// <returns>The fluent aggregate interface.</returns>
1191+
public static IAggregateFluent<TResult> Rerank<TResult>(
1192+
this IAggregateFluent<TResult> aggregate,
1193+
RerankQuery query,
1194+
IEnumerable<FieldDefinition<TResult>> paths,
1195+
int numDocsToRerank,
1196+
string model)
1197+
{
1198+
Ensure.IsNotNull(aggregate, nameof(aggregate));
1199+
return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model));
1200+
}
1201+
11381202
/// <summary>
11391203
/// Appends a $vectorSearch stage.
11401204
/// </summary>

src/MongoDB.Driver/PipelineDefinitionBuilder.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,79 @@ public static PipelineDefinition<TInput, TOutput> RankFusion<TInput, TIntermedia
10811081
return pipeline.AppendStage(PipelineStageDefinitionBuilder.RankFusion(pipelinesWithWeights, options));
10821082
}
10831083

1084+
/// <summary>
1085+
/// Appends a $rerank stage to the pipeline.
1086+
/// </summary>
1087+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1088+
/// <typeparam name="TField">The type of the field.</typeparam>
1089+
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
1090+
/// <param name="pipeline">The pipeline.</param>
1091+
/// <param name="query">The rerank query.</param>
1092+
/// <param name="path">The field to send to the reranker.</param>
1093+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1094+
/// <param name="model">The reranking model name.</param>
1095+
/// <returns>A new pipeline with an additional stage.</returns>
1096+
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TField, TOutput>(
1097+
this PipelineDefinition<TInput, TOutput> pipeline,
1098+
RerankQuery query,
1099+
Expression<Func<TOutput, TField>> path,
1100+
int numDocsToRerank,
1101+
string model)
1102+
{
1103+
Ensure.IsNotNull(pipeline, nameof(pipeline));
1104+
return pipeline.AppendStage(
1105+
PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model),
1106+
pipeline.OutputSerializer);
1107+
}
1108+
1109+
/// <summary>
1110+
/// Appends a $rerank stage to the pipeline.
1111+
/// </summary>
1112+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1113+
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
1114+
/// <param name="pipeline">The pipeline.</param>
1115+
/// <param name="query">The rerank query.</param>
1116+
/// <param name="path">The field to send to the reranker.</param>
1117+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1118+
/// <param name="model">The reranking model name.</param>
1119+
/// <returns>A new pipeline with an additional stage.</returns>
1120+
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
1121+
this PipelineDefinition<TInput, TOutput> pipeline,
1122+
RerankQuery query,
1123+
FieldDefinition<TOutput> path,
1124+
int numDocsToRerank,
1125+
string model)
1126+
{
1127+
Ensure.IsNotNull(pipeline, nameof(pipeline));
1128+
return pipeline.AppendStage(
1129+
PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model),
1130+
pipeline.OutputSerializer);
1131+
}
1132+
1133+
/// <summary>
1134+
/// Appends a $rerank stage to the pipeline.
1135+
/// </summary>
1136+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1137+
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
1138+
/// <param name="pipeline">The pipeline.</param>
1139+
/// <param name="query">The rerank query.</param>
1140+
/// <param name="paths">The fields to send to the reranker.</param>
1141+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1142+
/// <param name="model">The reranking model name.</param>
1143+
/// <returns>A new pipeline with an additional stage.</returns>
1144+
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
1145+
this PipelineDefinition<TInput, TOutput> pipeline,
1146+
RerankQuery query,
1147+
IEnumerable<FieldDefinition<TOutput>> paths,
1148+
int numDocsToRerank,
1149+
string model)
1150+
{
1151+
Ensure.IsNotNull(pipeline, nameof(pipeline));
1152+
return pipeline.AppendStage(
1153+
PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model),
1154+
pipeline.OutputSerializer);
1155+
}
1156+
10841157
/// <summary>
10851158
/// Appends a $replaceRoot stage to the pipeline.
10861159
/// </summary>

src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,6 +1737,94 @@ public static PipelineStageDefinition<TInput, TOutput> ReplaceWith<TInput, TOutp
17371737
return ReplaceWith(new ExpressionAggregateExpressionDefinition<TInput, TOutput>(newRoot));
17381738
}
17391739

1740+
/// <summary>
1741+
/// Creates a $rerank stage.
1742+
/// </summary>
1743+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1744+
/// <typeparam name="TField">The type of the field.</typeparam>
1745+
/// <param name="query">The rerank query.</param>
1746+
/// <param name="path">The field to send to the reranker.</param>
1747+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1748+
/// <param name="model">The reranking model name.</param>
1749+
/// <returns>The stage.</returns>
1750+
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput, TField>(
1751+
RerankQuery query,
1752+
Expression<Func<TInput, TField>> path,
1753+
int numDocsToRerank,
1754+
string model)
1755+
=> Rerank(
1756+
query,
1757+
new ExpressionFieldDefinition<TInput>(Ensure.IsNotNull(path, nameof(path))),
1758+
numDocsToRerank,
1759+
model);
1760+
1761+
/// <summary>
1762+
/// Creates a $rerank stage.
1763+
/// </summary>
1764+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1765+
/// <param name="query">The rerank query.</param>
1766+
/// <param name="path">The field to send to the reranker.</param>
1767+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1768+
/// <param name="model">The reranking model name.</param>
1769+
/// <returns>The stage.</returns>
1770+
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
1771+
RerankQuery query,
1772+
FieldDefinition<TInput> path,
1773+
int numDocsToRerank,
1774+
string model)
1775+
=> Rerank(
1776+
query,
1777+
[Ensure.IsNotNull(path, nameof(path))],
1778+
numDocsToRerank,
1779+
model);
1780+
1781+
/// <summary>
1782+
/// Creates a $rerank stage.
1783+
/// </summary>
1784+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1785+
/// <param name="query">The rerank query.</param>
1786+
/// <param name="paths">The fields to send to the reranker.</param>
1787+
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
1788+
/// <param name="model">The reranking model name.</param>
1789+
/// <returns>The stage.</returns>
1790+
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
1791+
RerankQuery query,
1792+
IEnumerable<FieldDefinition<TInput>> paths,
1793+
int numDocsToRerank,
1794+
string model)
1795+
{
1796+
Ensure.IsNotNull(query, nameof(query));
1797+
Ensure.IsNotNullOrEmpty(paths, nameof(paths));
1798+
Ensure.IsGreaterThanZero(numDocsToRerank, nameof(numDocsToRerank));
1799+
Ensure.IsNotNull(model, nameof(model));
1800+
1801+
const string operatorName = "$rerank";
1802+
var stage = new DelegatedPipelineStageDefinition<TInput, TInput>(
1803+
operatorName,
1804+
args =>
1805+
{
1806+
ClientSideProjectionHelper.ThrowIfClientSideProjection(args.DocumentSerializer, operatorName);
1807+
1808+
var renderedPaths = paths.Select(p => p.Render(args).FieldName).ToList();
1809+
BsonValue pathValue = renderedPaths.Count == 1
1810+
? renderedPaths[0]
1811+
: new BsonArray(renderedPaths);
1812+
1813+
var rerankDocument = new BsonDocument
1814+
{
1815+
{ "query", query.Render() },
1816+
{ "path", pathValue },
1817+
{ "numDocsToRerank", numDocsToRerank },
1818+
{ "model", model }
1819+
};
1820+
1821+
var document = new BsonDocument(operatorName, rerankDocument);
1822+
return new RenderedPipelineStageDefinition<TInput>(operatorName, document, args.DocumentSerializer);
1823+
});
1824+
1825+
return stage;
1826+
}
1827+
17401828
/// <summary>
17411829
/// Creates a $set stage.
17421830
/// </summary>

src/MongoDB.Driver/RerankQuery.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using MongoDB.Bson;
17+
using MongoDB.Driver.Core.Misc;
18+
19+
namespace MongoDB.Driver;
20+
21+
/// <summary>
22+
/// Represents a query for the $rerank aggregation stage.
23+
/// </summary>
24+
/// <seealso cref="PipelineStageDefinitionBuilder.Rerank{TInput}(RerankQuery, FieldDefinition{TInput}, int, string)"/>
25+
public sealed class RerankQuery
26+
{
27+
private readonly string _text;
28+
29+
private RerankQuery(string text)
30+
{
31+
_text = text;
32+
}
33+
34+
/// <summary>
35+
/// Creates a text-based rerank query.
36+
/// </summary>
37+
/// <param name="text">The text to rerank against.</param>
38+
/// <returns>A text rerank query.</returns>
39+
public static RerankQuery Text(string text) =>
40+
new RerankQuery(Ensure.IsNotNullOrEmpty(text, nameof(text)));
41+
42+
internal BsonDocument Render() => new BsonDocument("text", _text);
43+
}

tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,80 @@ public void RankFusion_should_throw_when_pipeline_is_null()
323323
}).ParamName.Should().Be("pipeline");
324324
}
325325

326+
[Fact]
327+
public void Rerank_with_single_string_path_should_render_expected_stage()
328+
{
329+
var pipeline = new EmptyPipelineDefinition<BsonDocument>();
330+
331+
var result = pipeline.Rerank(RerankQuery.Text("machine learning"), "plot", 25, "rerank-2.5");
332+
333+
var stages = RenderStages(result, BsonDocumentSerializer.Instance);
334+
stages.Count.Should().Be(1);
335+
stages[0].Should().Be("""
336+
{
337+
$rerank: {
338+
"query": { "text": "machine learning" },
339+
"path": "plot",
340+
"numDocsToRerank": 25,
341+
"model": "rerank-2.5"
342+
}
343+
}
344+
""");
345+
}
346+
347+
[Fact]
348+
public void Rerank_with_multiple_string_paths_should_render_expected_stage()
349+
{
350+
var pipeline = new EmptyPipelineDefinition<BsonDocument>();
351+
352+
var result = pipeline.Rerank(RerankQuery.Text("machine learning"), ["plot", "title"], 100, "rerank-2.5-lite");
353+
354+
var stages = RenderStages(result, BsonDocumentSerializer.Instance);
355+
stages.Count.Should().Be(1);
356+
stages[0].Should().Be("""
357+
{
358+
$rerank: {
359+
"query": { "text": "machine learning" },
360+
"path": ["plot", "title"],
361+
"numDocsToRerank": 100,
362+
"model": "rerank-2.5-lite"
363+
}
364+
}
365+
""");
366+
}
367+
368+
[Fact]
369+
public void Rerank_with_expression_path_should_render_expected_stage()
370+
{
371+
var pipeline = new EmptyPipelineDefinition<MovieWithPlot>();
372+
373+
var result = pipeline.Rerank(RerankQuery.Text("tutorials"), x => x.Title, 50, "rerank-2");
374+
375+
var stages = RenderStages(result, BsonSerializer.SerializerRegistry.GetSerializer<MovieWithPlot>());
376+
stages.Count.Should().Be(1);
377+
stages[0].Should().Be("""
378+
{
379+
$rerank: {
380+
"query": { "text": "tutorials" },
381+
"path": "Title",
382+
"numDocsToRerank": 50,
383+
"model": "rerank-2"
384+
}
385+
}
386+
""");
387+
}
388+
389+
[Fact]
390+
public void Rerank_should_throw_when_pipeline_is_null()
391+
{
392+
PipelineDefinition<BsonDocument, BsonDocument> pipeline = null;
393+
394+
var exception = Record.Exception(() => pipeline.Rerank(RerankQuery.Text("query"), "field", 10, "rerank-2.5"));
395+
396+
exception.Should().BeOfType<ArgumentNullException>()
397+
.Which.ParamName.Should().Be("pipeline");
398+
}
399+
326400
[Theory]
327401
[InlineData(0)]
328402
[InlineData(15)]
@@ -705,6 +779,8 @@ public void VectorSearch_should_add_expected_stage_with_parent_filters([Values(f
705779

706780
private class MovieWithPlot
707781
{
782+
public string Title { get; set; }
783+
public string Synopsis { get; set; }
708784
public int Year { get; set; }
709785
public NestedPlot Plot { get; set; }
710786
public float[] NonNestedEmbedding { get; set; }

0 commit comments

Comments
 (0)