Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/MongoDB.Driver/IAggregateFluentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,70 @@ public static IAggregateFluent<TNewResult> Unwind<TResult, TNewResult>(this IAgg
return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken);
}

/// <summary>
/// Appends a $rerank stage that sends a single field to the reranker.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TResult> Rerank<TResult>(
    this IAggregateFluent<TResult> aggregate,
    RerankQuery query,
    FieldDefinition<TResult> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(aggregate, nameof(aggregate));

    // All remaining arguments are forwarded unchanged to the stage builder.
    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model);
    return aggregate.AppendStage(rerankStage);
}

/// <summary>
/// Appends a $rerank stage whose field is selected with a lambda expression.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TResult> Rerank<TResult, TField>(
    this IAggregateFluent<TResult> aggregate,
    RerankQuery query,
    Expression<Func<TResult, TField>> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(aggregate, nameof(aggregate));

    // The expression is passed as-is; the stage builder overload taking an expression handles it.
    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model);
    return aggregate.AppendStage(rerankStage);
}

/// <summary>
/// Appends a $rerank stage that sends several fields to the reranker.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="query">The rerank query.</param>
/// <param name="paths">The fields to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TResult> Rerank<TResult>(
    this IAggregateFluent<TResult> aggregate,
    RerankQuery query,
    IEnumerable<FieldDefinition<TResult>> paths,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(aggregate, nameof(aggregate));

    // All remaining arguments are forwarded unchanged to the stage builder.
    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model);
    return aggregate.AppendStage(rerankStage);
}

/// <summary>
/// Appends a $vectorSearch stage.
/// </summary>
Expand Down
73 changes: 73 additions & 0 deletions src/MongoDB.Driver/PipelineDefinitionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,79 @@ public static PipelineDefinition<TInput, TOutput> RankFusion<TInput, TIntermedia
return pipeline.AppendStage(PipelineStageDefinitionBuilder.RankFusion(pipelinesWithWeights, options));
}

/// <summary>
/// Appends a $rerank stage to the pipeline, selecting the field with a lambda expression.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TField, TOutput>(
    this PipelineDefinition<TInput, TOutput> pipeline,
    RerankQuery query,
    Expression<Func<TOutput, TField>> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(pipeline, nameof(pipeline));

    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model);

    // The stage is appended together with the pipeline's current output serializer.
    return pipeline.AppendStage(rerankStage, pipeline.OutputSerializer);
}

/// <summary>
/// Appends a $rerank stage to the pipeline that sends a single field to the reranker.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
    this PipelineDefinition<TInput, TOutput> pipeline,
    RerankQuery query,
    FieldDefinition<TOutput> path,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(pipeline, nameof(pipeline));

    var rerankStage = PipelineStageDefinitionBuilder.Rerank(query, path, numDocsToRerank, model);

    // The stage is appended together with the pipeline's current output serializer.
    return pipeline.AppendStage(rerankStage, pipeline.OutputSerializer);
}

/// <summary>
/// Appends a $rerank stage to the pipeline that sends one or more fields to the reranker.
/// </summary>
/// <remarks>
/// The stage itself is created by <c>PipelineStageDefinitionBuilder.Rerank</c>; this method only
/// checks <paramref name="pipeline"/> for null and appends the stage using the pipeline's
/// current output serializer.
/// </remarks>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TOutput">The type of the output documents.</typeparam>
/// <param name="pipeline">The pipeline.</param>
/// <param name="query">The rerank query.</param>
/// <param name="paths">The fields to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>A new pipeline with an additional stage.</returns>
public static PipelineDefinition<TInput, TOutput> Rerank<TInput, TOutput>(
Comment thread
sanych-sun marked this conversation as resolved.
this PipelineDefinition<TInput, TOutput> pipeline,
RerankQuery query,
IEnumerable<FieldDefinition<TOutput>> paths,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's either use IEnumerable parameter or params in both cases: here and on the line 1126.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two multi-path overloads serve different use cases: the params Expression[] one is for inline type-safe expressions (x => x.Title, x => x.Plot), and the IEnumerable<FieldDefinition> one is for pre-built or dynamic collections of string-based paths. Using params for FieldDefinition wouldn't add much since there's no type inference benefit, and using IEnumerable<Expression> would lose the params convenience. I think the inconsistency is justified by the different use cases here.
Thoughts?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this as a different way to provide the same parameters. Having one as an array and another as a params looks confusing a little. About different use cases: both expressions and FieldDefinitions could be used as fixed list of values or dynamic collections I believe. I think I would prefer consistency between overloads if it's possible. @BorisDog @papafe @ajcvickers WDYT?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with @adelinowona.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing my mind, as I already commented. I don't agree with, "Having one as an array and another as a params looks confusing a little. " but I do like keeping the parameter order the same between the overloads.

int numDocsToRerank,
string model)
{
// Guard the receiver first: pipeline.OutputSerializer is dereferenced below.
Ensure.IsNotNull(pipeline, nameof(pipeline));
// query, paths, numDocsToRerank and model are forwarded unchanged to the stage builder.
return pipeline.AppendStage(
PipelineStageDefinitionBuilder.Rerank(query, paths, numDocsToRerank, model),
pipeline.OutputSerializer);
}

/// <summary>
/// Appends a $replaceRoot stage to the pipeline.
/// </summary>
Expand Down
88 changes: 88 additions & 0 deletions src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1737,6 +1737,94 @@ public static PipelineStageDefinition<TInput, TOutput> ReplaceWith<TInput, TOutp
return ReplaceWith(new ExpressionAggregateExpressionDefinition<TInput, TOutput>(newRoot));
}

/// <summary>
/// Creates a $rerank stage whose field is selected with a lambda expression.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The stage.</returns>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput, TField>(
    RerankQuery query,
    Expression<Func<TInput, TField>> path,
    int numDocsToRerank,
    string model)
{
    // The expression is null-checked here, before it is wrapped in a field definition.
    var checkedPath = Ensure.IsNotNull(path, nameof(path));
    var fieldDefinition = new ExpressionFieldDefinition<TInput>(checkedPath);

    return Rerank(query, fieldDefinition, numDocsToRerank, model);
}

/// <summary>
/// Creates a $rerank stage that sends a single field to the reranker.
/// </summary>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="path">The field to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The stage.</returns>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
    RerankQuery query,
    FieldDefinition<TInput> path,
    int numDocsToRerank,
    string model)
{
    // The single path is null-checked here and then wrapped in a one-element sequence
    // so that the multi-path overload does the actual work.
    var checkedPath = Ensure.IsNotNull(path, nameof(path));
    IEnumerable<FieldDefinition<TInput>> singlePath = new[] { checkedPath };

    return Rerank(query, singlePath, numDocsToRerank, model);
}

/// <summary>
/// Creates a $rerank stage that sends one or more fields to the reranker.
/// </summary>
/// <remarks>
/// The <paramref name="paths"/> sequence is copied when this method is called, so the rendered
/// stage always contains exactly the paths that were validated, even if the caller passes a
/// deferred query or mutates the source collection afterwards. A single path is rendered as a
/// string; several paths are rendered as an array.
/// </remarks>
/// <typeparam name="TInput">The type of the input documents.</typeparam>
/// <param name="query">The rerank query.</param>
/// <param name="paths">The fields to send to the reranker.</param>
/// <param name="numDocsToRerank">The maximum number of documents to rerank.</param>
/// <param name="model">The reranking model name.</param>
/// <returns>The stage.</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="paths"/> contains a null element.</exception>
public static PipelineStageDefinition<TInput, TInput> Rerank<TInput>(
    RerankQuery query,
    IEnumerable<FieldDefinition<TInput>> paths,
    int numDocsToRerank,
    string model)
{
    Ensure.IsNotNull(query, nameof(query));

    // Enumerate the caller's sequence exactly once. The render delegate below runs later
    // (possibly several times), so it must not re-enumerate a lazy or mutable source.
    var pathList = Ensure.IsNotNull(paths, nameof(paths)).ToList();
    Ensure.IsNotNullOrEmpty(pathList, nameof(paths));
    if (pathList.Any(p => p is null))
    {
        // Fail at call time instead of with a NullReferenceException during rendering.
        throw new ArgumentException("Value cannot contain null elements.", nameof(paths));
    }

    Ensure.IsGreaterThanZero(numDocsToRerank, nameof(numDocsToRerank));
    Ensure.IsNotNull(model, nameof(model));

    const string operatorName = "$rerank";
    var stage = new DelegatedPipelineStageDefinition<TInput, TInput>(
        operatorName,
        args =>
        {
            ClientSideProjectionHelper.ThrowIfClientSideProjection(args.DocumentSerializer, operatorName);

            // One path renders as a plain string, several paths as an array.
            var renderedPaths = pathList.Select(p => p.Render(args).FieldName).ToList();
            BsonValue pathValue = renderedPaths.Count == 1
                ? renderedPaths[0]
                : new BsonArray(renderedPaths);

            var rerankDocument = new BsonDocument
            {
                { "query", query.Render() },
                { "path", pathValue },
                { "numDocsToRerank", numDocsToRerank },
                { "model", model }
            };

            var document = new BsonDocument(operatorName, rerankDocument);
            return new RenderedPipelineStageDefinition<TInput>(operatorName, document, args.DocumentSerializer);
        });

    return stage;
}

/// <summary>
/// Creates a $set stage.
/// </summary>
Expand Down
43 changes: 43 additions & 0 deletions src/MongoDB.Driver/RerankQuery.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using MongoDB.Bson;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver;

/// <summary>
/// Describes the query that the $rerank aggregation stage ranks documents against.
/// </summary>
/// <remarks>
/// The constructor is private; instances are obtained through the static factory methods.
/// </remarks>
/// <seealso cref="PipelineStageDefinitionBuilder.Rerank{TInput}(RerankQuery, FieldDefinition{TInput}, int, string)"/>
public sealed class RerankQuery
{
    private readonly string _text;

    private RerankQuery(string text) => _text = text;

    /// <summary>
    /// Creates a rerank query from a piece of text.
    /// </summary>
    /// <param name="text">The text to rerank against.</param>
    /// <returns>A text rerank query.</returns>
    public static RerankQuery Text(string text)
    {
        var validatedText = Ensure.IsNotNullOrEmpty(text, nameof(text));
        return new RerankQuery(validatedText);
    }

    // Renders the query as a document with a single "text" element.
    internal BsonDocument Render()
    {
        return new BsonDocument("text", _text);
    }
}
76 changes: 76 additions & 0 deletions tests/MongoDB.Driver.Tests/PipelineDefinitionBuilderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,80 @@ public void RankFusion_should_throw_when_pipeline_is_null()
}).ParamName.Should().Be("pipeline");
}

[Fact]
public void Rerank_with_single_string_path_should_render_expected_stage()
{
    // Arrange: a single string path is expected to render as a plain string value.
    var emptyPipeline = new EmptyPipelineDefinition<BsonDocument>();
    var query = RerankQuery.Text("machine learning");
    var expectedStage = """
        {
        $rerank: {
        "query": { "text": "machine learning" },
        "path": "plot",
        "numDocsToRerank": 25,
        "model": "rerank-2.5"
        }
        }
        """;

    // Act
    var reranked = emptyPipeline.Rerank(query, "plot", 25, "rerank-2.5");
    var renderedStages = RenderStages(reranked, BsonDocumentSerializer.Instance);

    // Assert
    renderedStages.Count.Should().Be(1);
    renderedStages[0].Should().Be(expectedStage);
}

[Fact]
public void Rerank_with_multiple_string_paths_should_render_expected_stage()
{
    // Arrange: several string paths are expected to render as an array.
    var emptyPipeline = new EmptyPipelineDefinition<BsonDocument>();
    var query = RerankQuery.Text("machine learning");
    var expectedStage = """
        {
        $rerank: {
        "query": { "text": "machine learning" },
        "path": ["plot", "title"],
        "numDocsToRerank": 100,
        "model": "rerank-2.5-lite"
        }
        }
        """;

    // Act
    var reranked = emptyPipeline.Rerank(query, ["plot", "title"], 100, "rerank-2.5-lite");
    var renderedStages = RenderStages(reranked, BsonDocumentSerializer.Instance);

    // Assert
    renderedStages.Count.Should().Be(1);
    renderedStages[0].Should().Be(expectedStage);
}

[Fact]
public void Rerank_with_expression_path_should_render_expected_stage()
{
    // Arrange: the lambda path is expected to render as the member name "Title".
    var emptyPipeline = new EmptyPipelineDefinition<MovieWithPlot>();
    var movieSerializer = BsonSerializer.SerializerRegistry.GetSerializer<MovieWithPlot>();
    var query = RerankQuery.Text("tutorials");
    var expectedStage = """
        {
        $rerank: {
        "query": { "text": "tutorials" },
        "path": "Title",
        "numDocsToRerank": 50,
        "model": "rerank-2"
        }
        }
        """;

    // Act
    var reranked = emptyPipeline.Rerank(query, x => x.Title, 50, "rerank-2");
    var renderedStages = RenderStages(reranked, movieSerializer);

    // Assert
    renderedStages.Count.Should().Be(1);
    renderedStages[0].Should().Be(expectedStage);
}

[Fact]
public void Rerank_should_throw_when_pipeline_is_null()
{
    // Arrange: the extension method is invoked on a null receiver.
    PipelineDefinition<BsonDocument, BsonDocument> nullPipeline = null;

    // Act
    var thrown = Record.Exception(() => nullPipeline.Rerank(RerankQuery.Text("query"), "field", 10, "rerank-2.5"));

    // Assert: the failure must name the "pipeline" parameter.
    var argumentNullException = thrown.Should().BeOfType<ArgumentNullException>().Which;
    argumentNullException.ParamName.Should().Be("pipeline");
}

[Theory]
[InlineData(0)]
[InlineData(15)]
Expand Down Expand Up @@ -705,6 +779,8 @@ public void VectorSearch_should_add_expected_stage_with_parent_filters([Values(f

private class MovieWithPlot
{
public string Title { get; set; }
public string Synopsis { get; set; }
public int Year { get; set; }
public NestedPlot Plot { get; set; }
public float[] NonNestedEmbedding { get; set; }
Expand Down
Loading