Skip to content

Commit b5efe84

Browse files
committed
CSHARP-5828: Add Rerank stage builder
1 parent 29a8892 commit b5efe84

8 files changed

Lines changed: 317 additions & 0 deletions

src/MongoDB.Driver/AggregateFluent.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,15 @@ public override IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<
379379
return WithPipeline(_pipeline.Unwind(field, options));
380380
}
381381

382+
public override IAggregateFluent<TResult> Rerank(
383+
RerankQuery query,
384+
FieldDefinition<TResult> path,
385+
int numDocsToRerank,
386+
string model)
387+
{
388+
return WithPipeline(_pipeline.Rerank(query, path, numDocsToRerank, model));
389+
}
390+
382391
public override IAggregateFluent<TResult> VectorSearch(
383392
FieldDefinition<TResult> field,
384393
QueryVector queryVector,

src/MongoDB.Driver/AggregateFluentBase.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,16 @@ public virtual IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<T
331331
throw new NotImplementedException();
332332
}
333333

334+
/// <inheritdoc />
335+
public virtual IAggregateFluent<TResult> Rerank(
336+
RerankQuery query,
337+
FieldDefinition<TResult> path,
338+
int numDocsToRerank,
339+
string model)
340+
{
341+
throw new NotImplementedException();
342+
}
343+
334344
/// <inheritdoc />
335345
public virtual IAggregateFluent<TResult> VectorSearch(
336346
FieldDefinition<TResult> field,

src/MongoDB.Driver/IAggregateFluent.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,20 @@ IAggregateFluent<TResult> UnionWith<TWith>(
545545
/// <returns>The fluent aggregate interface.</returns>
546546
IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<TResult> field, AggregateUnwindOptions<TNewResult> options = null);
547547

548+
/// <summary>
549+
/// Appends a $rerank stage.
550+
/// </summary>
551+
/// <param name="query">The rerank query.</param>
552+
/// <param name="path">The field to send to the reranker.</param>
553+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
554+
/// <param name="model">The reranking model name.</param>
555+
/// <returns>The fluent aggregate interface.</returns>
556+
IAggregateFluent<TResult> Rerank(
557+
RerankQuery query,
558+
FieldDefinition<TResult> path,
559+
int numDocsToRerank,
560+
string model);
561+
548562
/// <summary>
549563
/// Appends a vector search stage.
550564
/// </summary>

src/MongoDB.Driver/IAggregateFluentExtensions.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,29 @@ public static IAggregateFluent<TNewResult> Unwind<TResult, TNewResult>(this IAgg
10411041
return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken);
10421042
}
10431043

1044+
/// <summary>
1045+
/// Appends a $rerank stage.
1046+
/// </summary>
1047+
/// <typeparam name="TResult">The type of the result.</typeparam>
1048+
/// <typeparam name="TField">The type of the field.</typeparam>
1049+
/// <param name="aggregate">The aggregate.</param>
1050+
/// <param name="query">The rerank query.</param>
1051+
/// <param name="path">The field to send to the reranker.</param>
1052+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1053+
/// <param name="model">The reranking model name.</param>
1054+
/// <returns>The fluent aggregate interface.</returns>
1055+
public static IAggregateFluent<TResult> Rerank<TResult, TField>(
1056+
this IAggregateFluent<TResult> aggregate,
1057+
RerankQuery query,
1058+
Expression<Func<TResult, TField>> path,
1059+
int numDocsToRerank,
1060+
string model)
1061+
{
1062+
Ensure.IsNotNull(aggregate, nameof(aggregate));
1063+
1064+
return aggregate.Rerank(query, new ExpressionFieldDefinition<TResult>(path), numDocsToRerank, model);
1065+
}
1066+
10441067
/// <summary>
10451068
/// Appends a $vectorSearch stage.
10461069
/// </summary>

src/MongoDB.Driver/PipelineDefinitionBuilder.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,79 @@ public static PipelineDefinition<TInput, TOutput> RankFusion<TInput, TIntermedia
10811081
return pipeline.AppendStage(PipelineStageDefinitionBuilder.RankFusion(pipelinesWithWeights, options));
10821082
}
10831083

1084+
/// <summary>
1085+
/// Appends a $rerank stage to the pipeline.
1086+
/// </summary>
1087+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1088+
/// <typeparam name="TField">The type of the field.</typeparam>
1089+
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
1090+
/// <param name="pipeline">The pipeline.</param>
1091+
/// <param name="query">The rerank query.</param>
1092+
/// <param name="path">The field to send to the reranker.</param>
1093+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1094+
/// <param name="model">The reranking model name.</param>
1095+
/// <returns>A new pipeline with an additional stage.</returns>
1096+
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TField, TOutput>(
1097+
this PipelineDefinition<TInput, TOutput> pipeline,
1098+
RerankQuery query,
1099+
Expression<Func<TOutput, TField>> path,
1100+
int numDocsToRerank,
1101+
string model)
1102+
{
1103+
Ensure.IsNotNull(pipeline, nameof(pipeline));
1104+
return pipeline.AppendStage(
1105+
PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model),
1106+
pipeline.OutputSerializer);
1107+
}
1108+
1109+
/// <summary>
1110+
/// Appends a $rerank stage to the pipeline.
1111+
/// </summary>
1112+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1113+
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
1114+
/// <param name="pipeline">The pipeline.</param>
1115+
/// <param name="query">The rerank query.</param>
1116+
/// <param name="path">The field to send to the reranker.</param>
1117+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1118+
/// <param name="model">The reranking model name.</param>
1119+
/// <returns>A new pipeline with an additional stage.</returns>
1120+
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
1121+
this PipelineDefinition<TInput, TOutput> pipeline,
1122+
RerankQuery query,
1123+
FieldDefinition<TOutput> path,
1124+
int numDocsToRerank,
1125+
string model)
1126+
{
1127+
Ensure.IsNotNull(pipeline, nameof(pipeline));
1128+
return pipeline.AppendStage(
1129+
PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model),
1130+
pipeline.OutputSerializer);
1131+
}
1132+
1133+
/// <summary>
1134+
/// Appends a $rerank stage to the pipeline.
1135+
/// </summary>
1136+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1137+
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
1138+
/// <param name="pipeline">The pipeline.</param>
1139+
/// <param name="query">The rerank query.</param>
1140+
/// <param name="paths">The fields to send to the reranker.</param>
1141+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1142+
/// <param name="model">The reranking model name.</param>
1143+
/// <returns>A new pipeline with an additional stage.</returns>
1144+
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
1145+
this PipelineDefinition<TInput, TOutput> pipeline,
1146+
RerankQuery query,
1147+
IEnumerable<FieldDefinition<TOutput>> paths,
1148+
int numDocsToRerank,
1149+
string model)
1150+
{
1151+
Ensure.IsNotNull(pipeline, nameof(pipeline));
1152+
return pipeline.AppendStage(
1153+
PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model),
1154+
pipeline.OutputSerializer);
1155+
}
1156+
10841157
/// <summary>
10851158
/// Appends a $replaceRoot stage to the pipeline.
10861159
/// </summary>

src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,94 @@ public static PipelineStageDefinition<TInput, TOutput> RankFusion<TInput, TOutpu
16041604
return RankFusion(pipelinesMap, weightsMap, options);
16051605
}
16061606

1607+
/// <summary>
1608+
/// Creates a $rerank stage.
1609+
/// </summary>
1610+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1611+
/// <typeparam name="TField">The type of the field.</typeparam>
1612+
/// <param name="query">The rerank query.</param>
1613+
/// <param name="path">The field to send to the reranker.</param>
1614+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1615+
/// <param name="model">The reranking model name.</param>
1616+
/// <returns>The stage.</returns>
1617+
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput, TField>(
1618+
RerankQuery query,
1619+
Expression<Func<TInput, TField>> path,
1620+
int numDocsToRerank,
1621+
string model)
1622+
=> Rerank(
1623+
query,
1624+
new ExpressionFieldDefinition<TInput>(path),
1625+
numDocsToRerank,
1626+
model);
1627+
1628+
/// <summary>
1629+
/// Creates a $rerank stage.
1630+
/// </summary>
1631+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1632+
/// <param name="query">The rerank query.</param>
1633+
/// <param name="path">The field to send to the reranker.</param>
1634+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1635+
/// <param name="model">The reranking model name.</param>
1636+
/// <returns>The stage.</returns>
1637+
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
1638+
RerankQuery query,
1639+
FieldDefinition<TInput> path,
1640+
int numDocsToRerank,
1641+
string model)
1642+
=> Rerank(
1643+
query,
1644+
[path],
1645+
numDocsToRerank,
1646+
model);
1647+
1648+
/// <summary>
1649+
/// Creates a $rerank stage.
1650+
/// </summary>
1651+
/// <typeparam name="TInput">The type of the input documents.</typeparam>
1652+
/// <param name="query">The rerank query.</param>
1653+
/// <param name="paths">The fields to send to the reranker.</param>
1654+
/// <param name="numDocsToRerank">The maximum number of documents to rerank (1–1000).</param>
1655+
/// <param name="model">The reranking model name.</param>
1656+
/// <returns>The stage.</returns>
1657+
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
1658+
RerankQuery query,
1659+
IEnumerable<FieldDefinition<TInput>> paths,
1660+
int numDocsToRerank,
1661+
string model)
1662+
{
1663+
Ensure.IsNotNull(query, nameof(query));
1664+
Ensure.IsNotNull(paths, nameof(paths));
1665+
Ensure.IsGreaterThanZero(numDocsToRerank, nameof(numDocsToRerank));
1666+
Ensure.IsNotNull(model, nameof(model));
1667+
1668+
const string operatorName = "$rerank";
1669+
var stage = new DelegatedPipelineStageDefinition<TInput, TInput>(
1670+
operatorName,
1671+
args =>
1672+
{
1673+
ClientSideProjectionHelper.ThrowIfClientSideProjection(args.DocumentSerializer, operatorName);
1674+
1675+
var renderedPaths = paths.Select(p => p.Render(args).FieldName).ToList();
1676+
BsonValue pathValue = renderedPaths.Count == 1
1677+
? renderedPaths[0]
1678+
: new BsonArray(renderedPaths);
1679+
1680+
var rerankDocument = new BsonDocument
1681+
{
1682+
{ "query", query.Render() },
1683+
{ "path", pathValue },
1684+
{ "numDocsToRerank", numDocsToRerank },
1685+
{ "model", model }
1686+
};
1687+
1688+
var document = new BsonDocument(operatorName, rerankDocument);
1689+
return new RenderedPipelineStageDefinition<TInput>(operatorName, document, args.DocumentSerializer);
1690+
});
1691+
1692+
return stage;
1693+
}
1694+
16071695
/// <summary>
16081696
/// Creates a $replaceRoot stage.
16091697
/// </summary>

tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,81 @@ public void RankFusion_should_throw_when_pipeline_is_null()
323323
}).ParamName.Should().Be("pipeline");
324324
}
325325

326+
[Fact]
327+
public void Rerank_with_single_string_path_should_render_expected_stage()
328+
{
329+
var pipeline = new EmptyPipelineDefinition<BsonDocument>();
330+
331+
var result = pipeline.Rerank(RerankQuery.Text("machine learning"), "plot", 25, "rerank-2.5");
332+
333+
var stages = RenderStages(result, BsonDocumentSerializer.Instance);
334+
stages.Count.Should().Be(1);
335+
stages[0].Should().Be("""
336+
{
337+
$rerank: {
338+
"query": { "text": "machine learning" },
339+
"path": "plot",
340+
"numDocsToRerank": 25,
341+
"model": "rerank-2.5"
342+
}
343+
}
344+
""");
345+
}
346+
347+
[Fact]
348+
public void Rerank_with_multiple_string_paths_should_render_expected_stage()
349+
{
350+
var pipeline = new EmptyPipelineDefinition<BsonDocument>();
351+
FieldDefinition<BsonDocument>[] paths = ["plot", "title"];
352+
353+
var result = pipeline.Rerank(RerankQuery.Text("machine learning"), paths, 100, "rerank-2.5-lite");
354+
355+
var stages = RenderStages(result, BsonDocumentSerializer.Instance);
356+
stages.Count.Should().Be(1);
357+
stages[0].Should().Be("""
358+
{
359+
$rerank: {
360+
"query": { "text": "machine learning" },
361+
"path": ["plot", "title"],
362+
"numDocsToRerank": 100,
363+
"model": "rerank-2.5-lite"
364+
}
365+
}
366+
""");
367+
}
368+
369+
[Fact]
370+
public void Rerank_with_expression_path_should_render_expected_stage()
371+
{
372+
var pipeline = new EmptyPipelineDefinition<MovieWithPlot>();
373+
374+
var result = pipeline.Rerank(RerankQuery.Text("tutorials"), x => x.Title, 50, "rerank-2");
375+
376+
var stages = RenderStages(result, BsonSerializer.SerializerRegistry.GetSerializer<MovieWithPlot>());
377+
stages.Count.Should().Be(1);
378+
stages[0].Should().Be("""
379+
{
380+
$rerank: {
381+
"query": { "text": "tutorials" },
382+
"path": "Title",
383+
"numDocsToRerank": 50,
384+
"model": "rerank-2"
385+
}
386+
}
387+
""");
388+
}
389+
390+
[Fact]
391+
public void Rerank_should_throw_when_pipeline_is_null()
392+
{
393+
PipelineDefinition<BsonDocument, BsonDocument> pipeline = null;
394+
395+
var exception = Record.Exception(() => pipeline.Rerank(RerankQuery.Text("query"), "field", 10, "rerank-2.5"));
396+
397+
exception.Should().BeOfType<ArgumentNullException>()
398+
.Which.ParamName.Should().Be("pipeline");
399+
}
400+
326401
[Theory]
327402
[InlineData(0)]
328403
[InlineData(15)]
@@ -705,6 +780,7 @@ public void VectorSearch_should_add_expected_stage_with_parent_filters([Values(f
705780

706781
private class MovieWithPlot
707782
{
783+
public string Title { get; set; }
708784
public int Year { get; set; }
709785
public NestedPlot Plot { get; set; }
710786
public float[] NonNestedEmbedding { get; set; }

tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,30 @@ public void Lookup_with_empty_required_params_should_throw_expected_exception()
614614
});
615615
}
616616

617+
[Fact]
618+
public void Rerank_with_incorrect_params_should_throw_expected_exception()
619+
{
620+
Assert.Throws<ArgumentNullException>(() =>
621+
{
622+
PipelineStageDefinitionBuilder.Rerank<BsonDocument>(null, "field", 10, "rerank-2.5");
623+
}).ParamName.Should().Be("query");
624+
625+
Assert.Throws<ArgumentNullException>(() =>
626+
{
627+
PipelineStageDefinitionBuilder.Rerank<BsonDocument>(RerankQuery.Text("query"), (IEnumerable<FieldDefinition<BsonDocument>>)null, 10, "rerank-2.5");
628+
}).ParamName.Should().Be("paths");
629+
630+
Assert.Throws<ArgumentNullException>(() =>
631+
{
632+
PipelineStageDefinitionBuilder.Rerank<BsonDocument>(RerankQuery.Text("query"), "field", 10, null);
633+
}).ParamName.Should().Be("model");
634+
635+
Assert.Throws<ArgumentOutOfRangeException>(() =>
636+
{
637+
PipelineStageDefinitionBuilder.Rerank<BsonDocument>(RerankQuery.Text("query"), "field", 0, "rerank-2.5");
638+
}).ParamName.Should().Be("numDocsToRerank");
639+
}
640+
617641
[Fact]
618642
public void RankFusion_with_incorrect_params_should_throw_expected_exception()
619643
{

0 commit comments

Comments
 (0)