From ddf378089cb2f9487195ab18c0b93d2dedc342a7 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Fri, 3 Apr 2026 13:47:11 -0400 Subject: [PATCH 1/8] CSHARP-5828: Add Rerank stage builder --- src/MongoDB.Driver/AggregateFluent.cs | 9 ++ src/MongoDB.Driver/AggregateFluentBase.cs | 10 ++ src/MongoDB.Driver/IAggregateFluent.cs | 14 +++ .../IAggregateFluentExtensions.cs | 46 +++++++ .../PipelineDefinitionBuilder.cs | 98 +++++++++++++++ .../PipelineStageDefinitionBuilder.cs | 112 ++++++++++++++++++ src/MongoDB.Driver/RerankQuery.cs | 43 +++++++ .../PipelineDefinitionBuilderTests.cs | 98 +++++++++++++++ .../PipelineStageDefinitionBuilderTests.cs | 39 ++++++ 9 files changed, 469 insertions(+) create mode 100644 src/MongoDB.Driver/RerankQuery.cs diff --git a/src/MongoDB.Driver/AggregateFluent.cs b/src/MongoDB.Driver/AggregateFluent.cs index 90bd9823ce1..24e69ca4265 100644 --- a/src/MongoDB.Driver/AggregateFluent.cs +++ b/src/MongoDB.Driver/AggregateFluent.cs @@ -379,6 +379,15 @@ public override IAggregateFluent Unwind(FieldDefinition< return WithPipeline(_pipeline.Unwind(field, options)); } + public override IAggregateFluent Rerank( + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + return WithPipeline(_pipeline.Rerank(query, path, numDocsToRerank, model)); + } + public override IAggregateFluent VectorSearch( FieldDefinition field, QueryVector queryVector, diff --git a/src/MongoDB.Driver/AggregateFluentBase.cs b/src/MongoDB.Driver/AggregateFluentBase.cs index 4aad9f8a97e..f19892aaeb6 100644 --- a/src/MongoDB.Driver/AggregateFluentBase.cs +++ b/src/MongoDB.Driver/AggregateFluentBase.cs @@ -325,6 +325,16 @@ public virtual IAggregateFluent Unwind(FieldDefinition + public virtual IAggregateFluent Rerank( + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + throw new NotImplementedException(); + } + /// public virtual IAggregateFluent VectorSearch( FieldDefinition field, diff --git a/src/MongoDB.Driver/IAggregateFluent.cs b/src/MongoDB.Driver/IAggregateFluent.cs index e8d825c66e6..ab4b679b4e5 100644 --- a/src/MongoDB.Driver/IAggregateFluent.cs +++ b/src/MongoDB.Driver/IAggregateFluent.cs @@ -545,6 +545,20 @@ IAggregateFluent UnionWith( /// The fluent aggregate interface. IAggregateFluent Unwind(FieldDefinition field, AggregateUnwindOptions options = null); + /// + /// Appends a $rerank stage. + /// + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + IAggregateFluent Rerank( + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model); + /// /// Appends a vector search stage. /// diff --git a/src/MongoDB.Driver/IAggregateFluentExtensions.cs b/src/MongoDB.Driver/IAggregateFluentExtensions.cs index 2dfa93993a2..682273ba99e 100644 --- a/src/MongoDB.Driver/IAggregateFluentExtensions.cs +++ b/src/MongoDB.Driver/IAggregateFluentExtensions.cs @@ -1135,6 +1135,52 @@ public static IAggregateFluent Unwind(this IAgg return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken); } + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The type of the field. + /// The aggregate. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + Expression> path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + + return aggregate.Rerank(query, new ExpressionFieldDefinition(path), numDocsToRerank, model); + } + + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The type of the field. + /// The aggregate. + /// The rerank query. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fields to send to the reranker. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + int numDocsToRerank, + string model, + params Expression>[] paths) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths)); + } + /// /// Appends a $vectorSearch stage. /// diff --git a/src/MongoDB.Driver/PipelineDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineDefinitionBuilder.cs index c086535fb18..b5c209a1d52 100644 --- a/src/MongoDB.Driver/PipelineDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineDefinitionBuilder.cs @@ -1081,6 +1081,104 @@ public static PipelineDefinition RankFusion + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the field. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + Expression> path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model), + pipeline.OutputSerializer); + } + + /// + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the field. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fields to send to the reranker. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + int numDocsToRerank, + string model, + params Expression>[] paths) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths), + pipeline.OutputSerializer); + } + + /// + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model), + pipeline.OutputSerializer); + } + + /// + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The fields to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + IEnumerable> paths, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model), + pipeline.OutputSerializer); + } + /// /// Appends a $replaceRoot stage to the pipeline. /// diff --git a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs index e6338366e66..1f901aba4c6 100644 --- a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs @@ -1737,6 +1737,118 @@ public static PipelineStageDefinition ReplaceWith(newRoot)); } + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The type of the field. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + Expression> path, + int numDocsToRerank, + string model) + => Rerank( + query, + new ExpressionFieldDefinition(path), + numDocsToRerank, + model); + + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The type of the field. + /// The rerank query. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fields to send to the reranker. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + int numDocsToRerank, + string model, + params Expression>[] paths) + => Rerank( + query, + paths.Select(p => (FieldDefinition)new ExpressionFieldDefinition(p)), + numDocsToRerank, + model); + + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(path, nameof(path)); + return Rerank( + query, + [path], + numDocsToRerank, + model); + } + + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The rerank query. + /// The fields to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + IEnumerable> paths, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(query, nameof(query)); + Ensure.IsNotNullOrEmpty(paths, nameof(paths)); + Ensure.IsGreaterThanZero(numDocsToRerank, nameof(numDocsToRerank)); + Ensure.IsNotNull(model, nameof(model)); + + const string operatorName = "$rerank"; + var stage = new DelegatedPipelineStageDefinition( + operatorName, + args => + { + ClientSideProjectionHelper.ThrowIfClientSideProjection(args.DocumentSerializer, operatorName); + + var renderedPaths = paths.Select(p => p.Render(args).FieldName).ToList(); + BsonValue pathValue = renderedPaths.Count == 1 + ? renderedPaths[0] + : new BsonArray(renderedPaths); + + var rerankDocument = new BsonDocument + { + { "query", query.Render() }, + { "path", pathValue }, + { "numDocsToRerank", numDocsToRerank }, + { "model", model } + }; + + var document = new BsonDocument(operatorName, rerankDocument); + return new RenderedPipelineStageDefinition(operatorName, document, args.DocumentSerializer); + }); + + return stage; + } + /// /// Creates a $set stage. /// diff --git a/src/MongoDB.Driver/RerankQuery.cs b/src/MongoDB.Driver/RerankQuery.cs new file mode 100644 index 00000000000..beb4ccf0468 --- /dev/null +++ b/src/MongoDB.Driver/RerankQuery.cs @@ -0,0 +1,43 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; + +namespace MongoDB.Driver +{ + /// + /// Represents a query for the $rerank aggregation stage. + /// + public sealed class RerankQuery + { + private readonly BsonDocument _rendered; + + private RerankQuery(BsonDocument rendered) + { + _rendered = rendered; + } + + /// + /// Creates a text-based rerank query. + /// + /// The text to rerank against. + /// A text rerank query. + public static RerankQuery Text(string text) => + new RerankQuery(new BsonDocument("text", Ensure.IsNotNull(text, nameof(text)))); + + internal BsonDocument Render() => _rendered; + } +} diff --git a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs index 6857a73faa4..908c6421a4c 100644 --- a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs @@ -323,6 +323,102 @@ public void RankFusion_should_throw_when_pipeline_is_null() }).ParamName.Should().Be("pipeline"); } + [Fact] + public void Rerank_with_single_string_path_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + + var result = pipeline.Rerank(RerankQuery.Text("machine learning"), "plot", 25, "rerank-2.5"); + + var stages = RenderStages(result, BsonDocumentSerializer.Instance); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "machine learning" }, + "path": "plot", + "numDocsToRerank": 25, + "model": "rerank-2.5" + } + } + """); + } + + [Fact] + public void Rerank_with_multiple_string_paths_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + FieldDefinition[] paths = ["plot", "title"]; + + var result = pipeline.Rerank(RerankQuery.Text("machine learning"), paths, 100, "rerank-2.5-lite"); + + var stages = RenderStages(result, BsonDocumentSerializer.Instance); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "machine learning" }, + "path": ["plot", "title"], + "numDocsToRerank": 100, + "model": "rerank-2.5-lite" + } + } + """); + } + + [Fact] + public void Rerank_with_expression_path_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + + var result = pipeline.Rerank(RerankQuery.Text("tutorials"), x => x.Title, 50, "rerank-2"); + + var stages = RenderStages(result, BsonSerializer.SerializerRegistry.GetSerializer()); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "tutorials" }, + "path": "Title", + "numDocsToRerank": 50, + "model": "rerank-2" + } + } + """); + } + + [Fact] + public void Rerank_with_params_expression_paths_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + + var result = pipeline.Rerank(RerankQuery.Text("machine learning"), 25, "rerank-2.5", x => x.Title, x => x.Synopsis); + + var stages = RenderStages(result, BsonSerializer.SerializerRegistry.GetSerializer()); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "machine learning" }, + "path": ["Title", "Synopsis"], + "numDocsToRerank": 25, + "model": "rerank-2.5" + } + } + """); + } + + [Fact] + public void Rerank_should_throw_when_pipeline_is_null() + { + PipelineDefinition pipeline = null; + + var exception = Record.Exception(() => pipeline.Rerank(RerankQuery.Text("query"), "field", 10, "rerank-2.5")); + + exception.Should().BeOfType() + .Which.ParamName.Should().Be("pipeline"); + } + [Theory] [InlineData(0)] [InlineData(15)] @@ -705,6 +801,8 @@ public void VectorSearch_should_add_expected_stage_with_parent_filters([Values(f private class MovieWithPlot { + public string Title { get; set; } + public string Synopsis { get; set; } public int Year { get; set; } public NestedPlot Plot { get; set; } public float[] NonNestedEmbedding { get; set; } diff --git a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs index fd7aca30204..42f9596a784 100644 --- a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs @@ -653,6 +653,45 @@ public void RankFusion_with_incorrect_params_should_throw_expected_exception() }); } + [Fact] + public void Rerank_with_incorrect_params_should_throw_expected_exception() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(null, "field", 10, "rerank-2.5"); + }).ParamName.Should().Be("query"); + + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (IEnumerable>)null, 10, "rerank-2.5"); + }).ParamName.Should().Be("paths"); + + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), "field", 10, null); + }).ParamName.Should().Be("model"); + + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), "field", 0, "rerank-2.5"); + }).ParamName.Should().Be("numDocsToRerank"); + + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(null, 10, "rerank-2.5"); + }).ParamName.Should().Be("query"); + + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (FieldDefinition)null, 10, "rerank-2.5"); + }).ParamName.Should().Be("path"); + + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), Array.Empty>(), 10, "rerank-2.5"); + }).ParamName.Should().Be("paths"); + } + [Fact] public void Search_without_returnScope_should_throw_when_output_type_differs_from_input_type() { From abee37fdc32067492772af8dc6d700dac25f0922 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Mon, 6 Apr 2026 23:39:29 -0400 Subject: [PATCH 2/8] fix api compat issues --- src/MongoDB.Driver/AggregateFluent.cs | 9 ------- src/MongoDB.Driver/AggregateFluentBase.cs | 10 -------- src/MongoDB.Driver/IAggregateFluent.cs | 14 ----------- .../IAggregateFluentExtensions.cs | 25 ++++++++++++++++--- 4 files changed, 22 insertions(+), 36 deletions(-) diff --git a/src/MongoDB.Driver/AggregateFluent.cs b/src/MongoDB.Driver/AggregateFluent.cs index 24e69ca4265..90bd9823ce1 100644 --- a/src/MongoDB.Driver/AggregateFluent.cs +++ b/src/MongoDB.Driver/AggregateFluent.cs @@ -379,15 +379,6 @@ public override IAggregateFluent Unwind(FieldDefinition< return WithPipeline(_pipeline.Unwind(field, options)); } - public override IAggregateFluent Rerank( - RerankQuery query, - FieldDefinition path, - int numDocsToRerank, - string model) - { - return WithPipeline(_pipeline.Rerank(query, path, numDocsToRerank, model)); - } - public override IAggregateFluent VectorSearch( FieldDefinition field, QueryVector queryVector, diff --git a/src/MongoDB.Driver/AggregateFluentBase.cs b/src/MongoDB.Driver/AggregateFluentBase.cs index f19892aaeb6..4aad9f8a97e 100644 --- a/src/MongoDB.Driver/AggregateFluentBase.cs +++ b/src/MongoDB.Driver/AggregateFluentBase.cs @@ -325,16 +325,6 @@ public virtual IAggregateFluent Unwind(FieldDefinition - public virtual IAggregateFluent Rerank( - RerankQuery query, - FieldDefinition path, - int numDocsToRerank, - string model) - { - throw new NotImplementedException(); - } - /// public virtual IAggregateFluent VectorSearch( FieldDefinition field, diff --git a/src/MongoDB.Driver/IAggregateFluent.cs b/src/MongoDB.Driver/IAggregateFluent.cs index ab4b679b4e5..e8d825c66e6 100644 --- a/src/MongoDB.Driver/IAggregateFluent.cs +++ b/src/MongoDB.Driver/IAggregateFluent.cs @@ -545,20 +545,6 @@ IAggregateFluent UnionWith( /// The fluent aggregate interface. IAggregateFluent Unwind(FieldDefinition field, AggregateUnwindOptions options = null); - /// - /// Appends a $rerank stage. - /// - /// The rerank query. - /// The field to send to the reranker. - /// The maximum number of documents to rerank. - /// The reranking model name. - /// The fluent aggregate interface. - IAggregateFluent Rerank( - RerankQuery query, - FieldDefinition path, - int numDocsToRerank, - string model); - /// /// Appends a vector search stage. /// diff --git a/src/MongoDB.Driver/IAggregateFluentExtensions.cs b/src/MongoDB.Driver/IAggregateFluentExtensions.cs index 682273ba99e..b7a4028773f 100644 --- a/src/MongoDB.Driver/IAggregateFluentExtensions.cs +++ b/src/MongoDB.Driver/IAggregateFluentExtensions.cs @@ -1135,6 +1135,27 @@ public static IAggregateFluent Unwind(this IAgg return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken); } + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The aggregate. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model)); + } + /// /// Appends a $rerank stage. /// @@ -1154,8 +1175,7 @@ public static IAggregateFluent Rerank( string model) { Ensure.IsNotNull(aggregate, nameof(aggregate)); - - return aggregate.Rerank(query, new ExpressionFieldDefinition(path), numDocsToRerank, model); + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model)); } /// @@ -1177,7 +1197,6 @@ public static IAggregateFluent Rerank( params Expression>[] paths) { Ensure.IsNotNull(aggregate, nameof(aggregate)); - return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths)); } From 44cb7872c40d633b3972dea1622cf4b87d15c828 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 9 Apr 2026 11:13:10 -0400 Subject: [PATCH 3/8] address code review comments --- .../IAggregateFluentExtensions.cs | 21 +++++++ src/MongoDB.Driver/RerankQuery.cs | 10 ++-- .../PipelineDefinitionBuilderTests.cs | 3 +- .../PipelineStageDefinitionBuilderTests.cs | 56 +++++++++++++------ 4 files changed, 67 insertions(+), 23 deletions(-) diff --git a/src/MongoDB.Driver/IAggregateFluentExtensions.cs b/src/MongoDB.Driver/IAggregateFluentExtensions.cs index b7a4028773f..773e708cf24 100644 --- a/src/MongoDB.Driver/IAggregateFluentExtensions.cs +++ b/src/MongoDB.Driver/IAggregateFluentExtensions.cs @@ -1178,6 +1178,27 @@ public static IAggregateFluent Rerank( return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model)); } + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The aggregate. + /// The rerank query. + /// The fields to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + IEnumerable> paths, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model)); + } + /// /// Appends a $rerank stage. /// diff --git a/src/MongoDB.Driver/RerankQuery.cs b/src/MongoDB.Driver/RerankQuery.cs index beb4ccf0468..ac49a4ed4ed 100644 --- a/src/MongoDB.Driver/RerankQuery.cs +++ b/src/MongoDB.Driver/RerankQuery.cs @@ -23,11 +23,11 @@ namespace MongoDB.Driver /// public sealed class RerankQuery { - private readonly BsonDocument _rendered; + private readonly string _text; - private RerankQuery(BsonDocument rendered) + private RerankQuery(string text) { - _rendered = rendered; + _text = text; } /// @@ -36,8 +36,8 @@ private RerankQuery(BsonDocument rendered) /// The text to rerank against. /// A text rerank query. public static RerankQuery Text(string text) => - new RerankQuery(new BsonDocument("text", Ensure.IsNotNull(text, nameof(text)))); + new RerankQuery(Ensure.IsNotNullOrEmpty(text, nameof(text))); - internal BsonDocument Render() => _rendered; + internal BsonDocument Render() => new BsonDocument("text", _text); } } diff --git a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs index 908c6421a4c..8a178a31af1 100644 --- a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs @@ -348,9 +348,8 @@ public void Rerank_with_single_string_path_should_render_expected_stage() public void Rerank_with_multiple_string_paths_should_render_expected_stage() { var pipeline = new EmptyPipelineDefinition(); - FieldDefinition[] paths = ["plot", "title"]; - var result = pipeline.Rerank(RerankQuery.Text("machine learning"), paths, 100, "rerank-2.5-lite"); + var result = pipeline.Rerank(RerankQuery.Text("machine learning"), ["plot", "title"], 100, "rerank-2.5-lite"); var stages = RenderStages(result, BsonDocumentSerializer.Instance); stages.Count.Should().Be(1); diff --git a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs index 42f9596a784..bba96ac27a4 100644 --- a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs @@ -654,44 +654,68 @@ public void RankFusion_with_incorrect_params_should_throw_expected_exception() } [Fact] - public void Rerank_with_incorrect_params_should_throw_expected_exception() + public void Rerank_should_throw_when_model_is_null() { - Assert.Throws(() => - { - PipelineStageDefinitionBuilder.Rerank(null, "field", 10, "rerank-2.5"); - }).ParamName.Should().Be("query"); - - Assert.Throws(() => - { - PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (IEnumerable>)null, 10, "rerank-2.5"); - }).ParamName.Should().Be("paths"); - Assert.Throws(() => { PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), "field", 10, null); }).ParamName.Should().Be("model"); + } + [Fact] + public void Rerank_should_throw_when_numDocsToRerank_is_not_greater_than_zero() + { Assert.Throws(() => { PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), "field", 0, "rerank-2.5"); }).ParamName.Should().Be("numDocsToRerank"); + } - Assert.Throws(() => - { - PipelineStageDefinitionBuilder.Rerank(null, 10, "rerank-2.5"); - }).ParamName.Should().Be("query"); - + [Fact] + public void Rerank_should_throw_when_path_is_null() + { Assert.Throws(() => { PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (FieldDefinition)null, 10, "rerank-2.5"); }).ParamName.Should().Be("path"); + } + [Fact] + public void Rerank_should_throw_when_paths_is_empty() + { Assert.Throws(() => { PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), Array.Empty>(), 10, "rerank-2.5"); }).ParamName.Should().Be("paths"); } + [Fact] + public void Rerank_should_throw_when_paths_is_null() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (IEnumerable>)null, 10, "rerank-2.5"); + }).ParamName.Should().Be("paths"); + } + + [Fact] + public void Rerank_should_throw_when_query_is_null() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(null, "field", 10, "rerank-2.5"); + }).ParamName.Should().Be("query"); + } + + [Fact] + public void Rerank_should_throw_when_query_is_null_for_params_overload() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(null, 10, "rerank-2.5"); + }).ParamName.Should().Be("query"); + } + [Fact] public void Search_without_returnScope_should_throw_when_output_type_differs_from_input_type() { From 6c0d2c351f19198c863a8c747175f5e132f47fb4 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 16 Apr 2026 10:07:52 -0400 Subject: [PATCH 4/8] address PR feedback --- .../PipelineStageDefinitionBuilder.cs | 7 ++- src/MongoDB.Driver/RerankQuery.cs | 40 +++++++-------- .../MongoDB.Driver.Tests/RerankQueryTests.cs | 51 +++++++++++++++++++ 3 files changed, 76 insertions(+), 22 deletions(-) create mode 100644 tests/MongoDB.Driver.Tests/RerankQueryTests.cs diff --git a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs index 1f901aba4c6..363bcfd3216 100644 --- a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs @@ -1752,11 +1752,14 @@ public static PipelineStageDefinition Rerank( Expression> path, int numDocsToRerank, string model) - => Rerank( + { + Ensure.IsNotNull(path, nameof(path)); + return Rerank( query, new ExpressionFieldDefinition(path), numDocsToRerank, model); + } /// /// Creates a $rerank stage. @@ -1775,7 +1778,7 @@ public static PipelineStageDefinition Rerank( params Expression>[] paths) => Rerank( query, - paths.Select(p => (FieldDefinition)new ExpressionFieldDefinition(p)), + paths.Select(FieldDefinition (p) => new ExpressionFieldDefinition(p)), numDocsToRerank, model); diff --git a/src/MongoDB.Driver/RerankQuery.cs b/src/MongoDB.Driver/RerankQuery.cs index ac49a4ed4ed..fc928d787ec 100644 --- a/src/MongoDB.Driver/RerankQuery.cs +++ b/src/MongoDB.Driver/RerankQuery.cs @@ -16,28 +16,28 @@ using MongoDB.Bson; using MongoDB.Driver.Core.Misc; -namespace MongoDB.Driver +namespace MongoDB.Driver; + +/// +/// Represents a query for the $rerank aggregation stage. +/// +/// +public sealed class RerankQuery { - /// - /// Represents a query for the $rerank aggregation stage. - /// - public sealed class RerankQuery - { - private readonly string _text; + private readonly string _text; - private RerankQuery(string text) - { - _text = text; - } + private RerankQuery(string text) + { + _text = text; + } - /// - /// Creates a text-based rerank query. - /// - /// The text to rerank against. - /// A text rerank query. - public static RerankQuery Text(string text) => - new RerankQuery(Ensure.IsNotNullOrEmpty(text, nameof(text))); + /// + /// Creates a text-based rerank query. + /// + /// The text to rerank against. + /// A text rerank query. + public static RerankQuery Text(string text) => + new RerankQuery(Ensure.IsNotNullOrEmpty(text, nameof(text))); - internal BsonDocument Render() => new BsonDocument("text", _text); - } + internal BsonDocument Render() => new BsonDocument("text", _text); } diff --git a/tests/MongoDB.Driver.Tests/RerankQueryTests.cs b/tests/MongoDB.Driver.Tests/RerankQueryTests.cs new file mode 100644 index 00000000000..83e010c9b03 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/RerankQueryTests.cs @@ -0,0 +1,51 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests; + +public class RerankQueryTests +{ + [Fact] + public void Text_should_create_query_with_expected_rendering() + { + var query = RerankQuery.Text("machine learning"); + + var rendered = query.Render(); + + rendered.Should().Be("{ text: 'machine learning' }"); + } + + [Fact] + public void Text_should_throw_when_text_is_null() + { + var exception = Record.Exception(() => RerankQuery.Text(null)); + + exception.Should().BeOfType() + .Which.ParamName.Should().Be("text"); + } + + [Fact] + public void Text_should_throw_when_text_is_empty() + { + var exception = Record.Exception(() => RerankQuery.Text("")); + + exception.Should().BeOfType() + .Which.ParamName.Should().Be("text"); + } +} From 83f9a3b4d9add7308974724b1648cae6f9618657 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 16 Apr 2026 12:44:25 -0400 Subject: [PATCH 5/8] remove params overload add some small improvements --- .../IAggregateFluentExtensions.cs | 22 ------------- .../PipelineDefinitionBuilder.cs | 25 -------------- .../PipelineStageDefinitionBuilder.cs | 33 ++----------------- .../PipelineDefinitionBuilderTests.cs | 21 ------------ .../PipelineStageDefinitionBuilderTests.cs | 9 ----- 5 files changed, 3 insertions(+), 107 deletions(-) diff --git a/src/MongoDB.Driver/IAggregateFluentExtensions.cs b/src/MongoDB.Driver/IAggregateFluentExtensions.cs index 773e708cf24..a11f5c5f712 100644 --- a/src/MongoDB.Driver/IAggregateFluentExtensions.cs +++ b/src/MongoDB.Driver/IAggregateFluentExtensions.cs @@ -1199,28 +1199,6 @@ public static IAggregateFluent Rerank( return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model)); } - /// - /// Appends a $rerank stage. - /// - /// The type of the result. - /// The type of the field. - /// The aggregate. - /// The rerank query. - /// The maximum number of documents to rerank. - /// The reranking model name. - /// The fields to send to the reranker. - /// The fluent aggregate interface. - public static IAggregateFluent Rerank( - this IAggregateFluent aggregate, - RerankQuery query, - int numDocsToRerank, - string model, - params Expression>[] paths) - { - Ensure.IsNotNull(aggregate, nameof(aggregate)); - return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths)); - } - /// /// Appends a $vectorSearch stage. /// diff --git a/src/MongoDB.Driver/PipelineDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineDefinitionBuilder.cs index b5c209a1d52..8996e160118 100644 --- a/src/MongoDB.Driver/PipelineDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineDefinitionBuilder.cs @@ -1106,31 +1106,6 @@ public static PipelineDefinition Rerank - /// Appends a $rerank stage to the pipeline. - /// - /// The type of the input documents. - /// The type of the field. - /// The type of the output documents. - /// The pipeline. - /// The rerank query. - /// The maximum number of documents to rerank. - /// The reranking model name. - /// The fields to send to the reranker. - /// A new pipeline with an additional stage. - public static PipelineDefinition Rerank( - this PipelineDefinition pipeline, - RerankQuery query, - int numDocsToRerank, - string model, - params Expression>[] paths) - { - Ensure.IsNotNull(pipeline, nameof(pipeline)); - return pipeline.AppendStage( - PipelineStageDefinitionBuilder.Rerank(query, numDocsToRerank, model, paths), - pipeline.OutputSerializer); - } - /// /// Appends a $rerank stage to the pipeline. /// diff --git a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs index 363bcfd3216..f95e9716c43 100644 --- a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs @@ -1752,33 +1752,9 @@ public static PipelineStageDefinition Rerank( Expression> path, int numDocsToRerank, string model) - { - Ensure.IsNotNull(path, nameof(path)); - return Rerank( - query, - new ExpressionFieldDefinition(path), - numDocsToRerank, - model); - } - - /// - /// Creates a $rerank stage. - /// - /// The type of the input documents. - /// The type of the field. - /// The rerank query. - /// The maximum number of documents to rerank. - /// The reranking model name. - /// The fields to send to the reranker. - /// The stage. - public static PipelineStageDefinition Rerank( - RerankQuery query, - int numDocsToRerank, - string model, - params Expression>[] paths) => Rerank( query, - paths.Select(FieldDefinition (p) => new ExpressionFieldDefinition(p)), + new ExpressionFieldDefinition(Ensure.IsNotNull(path, nameof(path))), numDocsToRerank, model); @@ -1796,14 +1772,11 @@ public static PipelineStageDefinition Rerank( FieldDefinition path, int numDocsToRerank, string model) - { - Ensure.IsNotNull(path, nameof(path)); - return Rerank( + => Rerank( query, - [path], + [Ensure.IsNotNull(path, nameof(path))], numDocsToRerank, model); - } /// /// Creates a $rerank stage. diff --git a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs index 8a178a31af1..add45c5c197 100644 --- a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs @@ -386,27 +386,6 @@ public void Rerank_with_expression_path_should_render_expected_stage() """); } - [Fact] - public void Rerank_with_params_expression_paths_should_render_expected_stage() - { - var pipeline = new EmptyPipelineDefinition(); - - var result = pipeline.Rerank(RerankQuery.Text("machine learning"), 25, "rerank-2.5", x => x.Title, x => x.Synopsis); - - var stages = RenderStages(result, BsonSerializer.SerializerRegistry.GetSerializer()); - stages.Count.Should().Be(1); - stages[0].Should().Be(""" - { - $rerank: { - "query": { "text": "machine learning" }, - "path": ["Title", "Synopsis"], - "numDocsToRerank": 25, - "model": "rerank-2.5" - } - } - """); - } - [Fact] public void Rerank_should_throw_when_pipeline_is_null() { diff --git a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs index bba96ac27a4..0361442fd50 100644 --- a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs @@ -707,15 +707,6 @@ public void Rerank_should_throw_when_query_is_null() }).ParamName.Should().Be("query"); } - [Fact] - public void Rerank_should_throw_when_query_is_null_for_params_overload() - { - Assert.Throws(() => - { - PipelineStageDefinitionBuilder.Rerank(null, 10, "rerank-2.5"); - }).ParamName.Should().Be("query"); - } - [Fact] public void Search_without_returnScope_should_throw_when_output_type_differs_from_input_type() { From 9b6a730319095754f4fa326037a61e1a4fb021da Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 16 Apr 2026 17:20:15 -0400 Subject: [PATCH 6/8] add integration test --- .../Search/AtlasSearchTests.cs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs index cbd0d70c09c..0f3723471c6 100644 --- a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs +++ b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs @@ -1097,6 +1097,23 @@ public void ReturnScope_HasAncestor(bool useExpression) } } + [Fact] + public void Rerank() + { + const int numDocsToRerank = 10; + + var result = GetMoviesCollection() + .Aggregate() + .Search(Builders.Search.Text("plot", "apes")) + .Rerank(RerankQuery.Text("a movie about intelligent apes who take over civilization"), "plot", numDocsToRerank, "rerank-2.5-lite") + .Project(Builders.Projection + .MetaScore(m => m.Score)) + .ToList(); + + result.Count.Should().BeGreaterThan(0).And.BeLessOrEqualTo(numDocsToRerank); + result.Should().BeInDescendingOrder(m => m.Score); + } + [Fact] public void SearchSequenceToken() { From a0c2e39483a2042e9992abd859c80b99da518ec2 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 16 Apr 2026 17:21:09 -0400 Subject: [PATCH 7/8] remove unused using directives --- tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs index 0f3723471c6..7095a361493 100644 --- a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs +++ b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs @@ -16,7 +16,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using FluentAssertions; using MongoDB.Bson; using MongoDB.Bson.Serialization; From ca2b24705eb0d0060e530c46f73a19026245e914 Mon Sep 17 00:00:00 2001 From: adelinowona Date: Thu, 16 Apr 2026 17:53:34 -0400 Subject: [PATCH 8/8] readd accidentally removed thread library --- tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs index 7095a361493..0f3723471c6 100644 --- a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs +++ b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using FluentAssertions; using MongoDB.Bson; using MongoDB.Bson.Serialization;