diff --git a/src/MongoDB.Driver/IAggregateFluentExtensions.cs b/src/MongoDB.Driver/IAggregateFluentExtensions.cs index 2dfa93993a2..a11f5c5f712 100644 --- a/src/MongoDB.Driver/IAggregateFluentExtensions.cs +++ b/src/MongoDB.Driver/IAggregateFluentExtensions.cs @@ -1135,6 +1135,70 @@ public static IAggregateFluent Unwind(this IAgg return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken); } + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The aggregate. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model)); + } + + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The type of the field. + /// The aggregate. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + Expression> path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model)); + } + + /// + /// Appends a $rerank stage. + /// + /// The type of the result. + /// The aggregate. + /// The rerank query. + /// The fields to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The fluent aggregate interface. + public static IAggregateFluent Rerank( + this IAggregateFluent aggregate, + RerankQuery query, + IEnumerable> paths, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(aggregate, nameof(aggregate)); + return aggregate.AppendStage(PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model)); + } + /// /// Appends a $vectorSearch stage. /// diff --git a/src/MongoDB.Driver/PipelineDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineDefinitionBuilder.cs index c086535fb18..8996e160118 100644 --- a/src/MongoDB.Driver/PipelineDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineDefinitionBuilder.cs @@ -1081,6 +1081,79 @@ public static PipelineDefinition RankFusion + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the field. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + Expression> path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model), + pipeline.OutputSerializer); + } + + /// + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model), + pipeline.OutputSerializer); + } + + /// + /// Appends a $rerank stage to the pipeline. + /// + /// The type of the input documents. + /// The type of the output documents. + /// The pipeline. + /// The rerank query. + /// The fields to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// A new pipeline with an additional stage. + public static PipelineDefinition Rerank( + this PipelineDefinition pipeline, + RerankQuery query, + IEnumerable> paths, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(pipeline, nameof(pipeline)); + return pipeline.AppendStage( + PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model), + pipeline.OutputSerializer); + } + /// /// Appends a $replaceRoot stage to the pipeline. /// diff --git a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs index e6338366e66..f95e9716c43 100644 --- a/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs +++ b/src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs @@ -1737,6 +1737,94 @@ public static PipelineStageDefinition ReplaceWith(newRoot)); } + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The type of the field. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + Expression> path, + int numDocsToRerank, + string model) + => Rerank( + query, + new ExpressionFieldDefinition(Ensure.IsNotNull(path, nameof(path))), + numDocsToRerank, + model); + + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The rerank query. + /// The field to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + FieldDefinition path, + int numDocsToRerank, + string model) + => Rerank( + query, + [Ensure.IsNotNull(path, nameof(path))], + numDocsToRerank, + model); + + /// + /// Creates a $rerank stage. + /// + /// The type of the input documents. + /// The rerank query. + /// The fields to send to the reranker. + /// The maximum number of documents to rerank. + /// The reranking model name. + /// The stage. + public static PipelineStageDefinition Rerank( + RerankQuery query, + IEnumerable> paths, + int numDocsToRerank, + string model) + { + Ensure.IsNotNull(query, nameof(query)); + Ensure.IsNotNullOrEmpty(paths, nameof(paths)); + Ensure.IsGreaterThanZero(numDocsToRerank, nameof(numDocsToRerank)); + Ensure.IsNotNull(model, nameof(model)); + + const string operatorName = "$rerank"; + var stage = new DelegatedPipelineStageDefinition( + operatorName, + args => + { + ClientSideProjectionHelper.ThrowIfClientSideProjection(args.DocumentSerializer, operatorName); + + var renderedPaths = paths.Select(p => p.Render(args).FieldName).ToList(); + BsonValue pathValue = renderedPaths.Count == 1 + ? renderedPaths[0] + : new BsonArray(renderedPaths); + + var rerankDocument = new BsonDocument + { + { "query", query.Render() }, + { "path", pathValue }, + { "numDocsToRerank", numDocsToRerank }, + { "model", model } + }; + + var document = new BsonDocument(operatorName, rerankDocument); + return new RenderedPipelineStageDefinition(operatorName, document, args.DocumentSerializer); + }); + + return stage; + } + /// /// Creates a $set stage. /// diff --git a/src/MongoDB.Driver/RerankQuery.cs b/src/MongoDB.Driver/RerankQuery.cs new file mode 100644 index 00000000000..fc928d787ec --- /dev/null +++ b/src/MongoDB.Driver/RerankQuery.cs @@ -0,0 +1,43 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using MongoDB.Bson; +using MongoDB.Driver.Core.Misc; + +namespace MongoDB.Driver; + +/// +/// Represents a query for the $rerank aggregation stage. +/// +/// +public sealed class RerankQuery +{ + private readonly string _text; + + private RerankQuery(string text) + { + _text = text; + } + + /// + /// Creates a text-based rerank query. + /// + /// The text to rerank against. + /// A text rerank query. + public static RerankQuery Text(string text) => + new RerankQuery(Ensure.IsNotNullOrEmpty(text, nameof(text))); + + internal BsonDocument Render() => new BsonDocument("text", _text); +} diff --git a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs index 6857a73faa4..add45c5c197 100644 --- a/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs @@ -323,6 +323,80 @@ public void RankFusion_should_throw_when_pipeline_is_null() }).ParamName.Should().Be("pipeline"); } + [Fact] + public void Rerank_with_single_string_path_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + + var result = pipeline.Rerank(RerankQuery.Text("machine learning"), "plot", 25, "rerank-2.5"); + + var stages = RenderStages(result, BsonDocumentSerializer.Instance); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "machine learning" }, + "path": "plot", + "numDocsToRerank": 25, + "model": "rerank-2.5" + } + } + """); + } + + [Fact] + public void Rerank_with_multiple_string_paths_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + + var result = pipeline.Rerank(RerankQuery.Text("machine learning"), ["plot", "title"], 100, "rerank-2.5-lite"); + + var stages = RenderStages(result, BsonDocumentSerializer.Instance); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "machine learning" }, + "path": ["plot", "title"], + "numDocsToRerank": 100, + "model": "rerank-2.5-lite" + } + } + """); + } + + [Fact] + public void Rerank_with_expression_path_should_render_expected_stage() + { + var pipeline = new EmptyPipelineDefinition(); + + var result = pipeline.Rerank(RerankQuery.Text("tutorials"), x => x.Title, 50, "rerank-2"); + + var stages = RenderStages(result, BsonSerializer.SerializerRegistry.GetSerializer()); + stages.Count.Should().Be(1); + stages[0].Should().Be(""" + { + $rerank: { + "query": { "text": "tutorials" }, + "path": "Title", + "numDocsToRerank": 50, + "model": "rerank-2" + } + } + """); + } + + [Fact] + public void Rerank_should_throw_when_pipeline_is_null() + { + PipelineDefinition pipeline = null; + + var exception = Record.Exception(() => pipeline.Rerank(RerankQuery.Text("query"), "field", 10, "rerank-2.5")); + + exception.Should().BeOfType() + .Which.ParamName.Should().Be("pipeline"); + } + [Theory] [InlineData(0)] [InlineData(15)] @@ -705,6 +779,8 @@ public void VectorSearch_should_add_expected_stage_with_parent_filters([Values(f private class MovieWithPlot { + public string Title { get; set; } + public string Synopsis { get; set; } public int Year { get; set; } public NestedPlot Plot { get; set; } public float[] NonNestedEmbedding { get; set; } diff --git a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs index fd7aca30204..0361442fd50 100644 --- a/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs +++ b/tests/MongoDB.Driver.Tests/PipelineStageDefinitionBuilderTests.cs @@ -653,6 +653,60 @@ public void RankFusion_with_incorrect_params_should_throw_expected_exception() }); } + [Fact] + public void Rerank_should_throw_when_model_is_null() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), "field", 10, null); + }).ParamName.Should().Be("model"); + } + + [Fact] + public void Rerank_should_throw_when_numDocsToRerank_is_not_greater_than_zero() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), "field", 0, "rerank-2.5"); + }).ParamName.Should().Be("numDocsToRerank"); + } + + [Fact] + public void Rerank_should_throw_when_path_is_null() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (FieldDefinition)null, 10, "rerank-2.5"); + }).ParamName.Should().Be("path"); + } + + [Fact] + public void Rerank_should_throw_when_paths_is_empty() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), Array.Empty>(), 10, "rerank-2.5"); + }).ParamName.Should().Be("paths"); + } + + [Fact] + public void Rerank_should_throw_when_paths_is_null() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(RerankQuery.Text("query"), (IEnumerable>)null, 10, "rerank-2.5"); + }).ParamName.Should().Be("paths"); + } + + [Fact] + public void Rerank_should_throw_when_query_is_null() + { + Assert.Throws(() => + { + PipelineStageDefinitionBuilder.Rerank(null, "field", 10, "rerank-2.5"); + }).ParamName.Should().Be("query"); + } + [Fact] public void Search_without_returnScope_should_throw_when_output_type_differs_from_input_type() { diff --git a/tests/MongoDB.Driver.Tests/RerankQueryTests.cs b/tests/MongoDB.Driver.Tests/RerankQueryTests.cs new file mode 100644 index 00000000000..83e010c9b03 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/RerankQueryTests.cs @@ -0,0 +1,51 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests; + +public class RerankQueryTests +{ + [Fact] + public void Text_should_create_query_with_expected_rendering() + { + var query = RerankQuery.Text("machine learning"); + + var rendered = query.Render(); + + rendered.Should().Be("{ text: 'machine learning' }"); + } + + [Fact] + public void Text_should_throw_when_text_is_null() + { + var exception = Record.Exception(() => RerankQuery.Text(null)); + + exception.Should().BeOfType() + .Which.ParamName.Should().Be("text"); + } + + [Fact] + public void Text_should_throw_when_text_is_empty() + { + var exception = Record.Exception(() => RerankQuery.Text("")); + + exception.Should().BeOfType() + .Which.ParamName.Should().Be("text"); + } +} diff --git a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs index cbd0d70c09c..0f3723471c6 100644 --- a/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs +++ b/tests/MongoDB.Driver.Tests/Search/AtlasSearchTests.cs @@ -1097,6 +1097,23 @@ public void ReturnScope_HasAncestor(bool useExpression) } } + [Fact] + public void Rerank() + { + const int numDocsToRerank = 10; + + var result = GetMoviesCollection() + .Aggregate() + .Search(Builders.Search.Text("plot", "apes")) + .Rerank(RerankQuery.Text("a movie about intelligent apes who take over civilization"), "plot", numDocsToRerank, "rerank-2.5-lite") + .Project(Builders.Projection + .MetaScore(m => m.Score)) + .ToList(); + + result.Count.Should().BeGreaterThan(0).And.BeLessOrEqualTo(numDocsToRerank); + result.Should().BeInDescendingOrder(m => m.Score); + } + [Fact] public void SearchSequenceToken() {