Skip to content

Commit 7d7a649

Browse files
authored
.Net: [MEVD] Implement support for score threshold (#13501)
For providers that don't natively support a score threshold, I've gone with applying the filter client-side; the other option would be to throw. * In EF we're quite strict about not bringing back database data in order to client-evaluate; in other words, if users want to load rows from the database and then perform some post-filtering, they need to write that explicitly in code, to avoid scenarios where tons of data get unintentionally transferred. * In MEVD we already implement Skip in various providers by bringing back the entire matching result set and then filtering client-side. This is potentially much worse than filtering for score threshold client-side, as the amount of extra data brought back by threshold post-filtering is limited to top (k), whereas with Skip/pagination queries that amount of data grows with each page. * Ultimately, MEVD is simpler and more high-level, and is also explicitly designed for layers to be composed on top (e.g. MEDI). So there does seem to be more value in features "just working" across providers. Closes #9566
1 parent 838b951 commit 7d7a649

26 files changed

Lines changed: 358 additions & 14 deletions

dotnet/src/VectorData/AzureAISearch/AzureAISearchCollection.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ public override IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord, bool
343343
VectorSearch = new(),
344344
Size = top,
345345
Skip = options.Skip,
346-
Filter = new AzureAISearchFilterTranslator().Translate(filter, this._model),
346+
Filter = new AzureAISearchFilterTranslator().Translate(filter, this._model)
347347
};
348348

349349
// Filter out vector fields if requested.
@@ -405,6 +405,15 @@ floatVector is null
405405

406406
await foreach (var record in this.SearchAndMapToDataModelAsync(null, searchOptions, options.IncludeVectors, cancellationToken).ConfigureAwait(false))
407407
{
408+
// Azure AI Search threshold filtering is in preview:
409+
// https://learn.microsoft.com/azure/search/vector-search-how-to-query#set-thresholds-to-exclude-low-scoring-results-preview
410+
// See https://github.com/microsoft/semantic-kernel/issues/13500.
411+
// For now, perform post-filtering on the client-side.
412+
if (options.ScoreThreshold.HasValue && record.Score < options.ScoreThreshold.Value)
413+
{
414+
continue;
415+
}
416+
408417
yield return record;
409418
}
410419
}
@@ -450,6 +459,12 @@ floatVector is null
450459

451460
await foreach (var record in this.SearchAndMapToDataModelAsync(keywordsCombined, searchOptions, options.IncludeVectors, cancellationToken).ConfigureAwait(false))
452461
{
462+
// Azure AI Search returns scores where higher values indicate more relevant results.
463+
if (options.ScoreThreshold.HasValue && record.Score < options.ScoreThreshold.Value)
464+
{
465+
continue;
466+
}
467+
453468
yield return record;
454469
}
455470
}

dotnet/src/VectorData/CosmosMongoDB/CosmosMongoCollection.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,13 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
419419
ScorePropertyName,
420420
DocumentPropertyName);
421421

422-
BsonDocument[] pipeline = [searchQuery, projectionQuery];
422+
List<BsonDocument> pipeline = [searchQuery, projectionQuery];
423+
424+
// Add score threshold filter as a $match stage if specified
425+
if (options.ScoreThreshold.HasValue)
426+
{
427+
pipeline.Add(CosmosMongoCollectionSearchMapping.GetScoreThresholdMatchQuery(ScorePropertyName, options.ScoreThreshold.Value));
428+
}
423429

424430
const string OperationName = "Aggregate";
425431
var cursor = await this.RunOperationAsync(

dotnet/src/VectorData/CosmosMongoDB/CosmosMongoCollectionSearchMapping.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,20 @@ public static BsonDocument GetProjectionQuery(string scorePropertyName, string d
166166
}
167167
};
168168
}
169+
170+
/// <summary>Returns a <c>$match</c> aggregation stage that keeps only documents whose score field is greater than or equal to the given threshold.</summary>
171+
/// <remarks>
172+
/// Cosmos MongoDB returns a similarity score where higher values mean more similar,
173+
/// so the <paramref name="scorePropertyName"/> field is compared with $gte to keep results at or above <paramref name="scoreThreshold"/> (inclusive). NOTE(review): the operator is always $gte here; confirm that "higher is better" holds for every distance function this provider supports.
174+
/// </remarks>
175+
public static BsonDocument GetScoreThresholdMatchQuery(string scorePropertyName, double scoreThreshold)
176+
=> new()
177+
{
178+
{
179+
"$match", new BsonDocument
180+
{
181+
{ scorePropertyName, new BsonDocument { { "$gte", scoreThreshold } } }
182+
}
183+
}
184+
};
169185
}

dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,10 +525,12 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
525525
null,
526526
this._model,
527527
vectorProperty.StorageName,
528+
vectorProperty.DistanceFunction,
528529
null,
529530
ScorePropertyName,
530531
options.OldFilter,
531532
options.Filter,
533+
options.ScoreThreshold,
532534
top,
533535
options.Skip,
534536
options.IncludeVectors);
@@ -630,10 +632,12 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn
630632
keywords,
631633
this._model,
632634
vectorProperty.StorageName,
635+
vectorProperty.DistanceFunction,
633636
textProperty.StorageName,
634637
ScorePropertyName,
635638
options.OldFilter,
636639
options.Filter,
640+
options.ScoreThreshold,
637641
top,
638642
options.Skip,
639643
options.IncludeVectors);

dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollectionQueryBuilder.cs

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
2828
ICollection<string>? keywords,
2929
CollectionModel model,
3030
string vectorPropertyName,
31+
string? distanceFunction,
3132
string? textPropertyName,
3233
string scorePropertyName,
3334
#pragma warning disable CS0618 // Type or member is obsolete
3435
VectorSearchFilter? oldFilter,
3536
#pragma warning restore CS0618 // Type or member is obsolete
3637
Expression<Func<TRecord, bool>>? filter,
38+
double? scoreThreshold,
3739
int top,
3840
int skip,
3941
bool includeVectors)
@@ -68,7 +70,7 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
6870

6971
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
7072
// Build filter object.
71-
var (whereClause, filterParameters) = (OldFilter: oldFilter, Filter: filter) switch
73+
var (filterClause, filterParameters) = (OldFilter: oldFilter, Filter: filter) switch
7274
{
7375
{ OldFilter: not null, Filter: not null } => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"),
7476
{ OldFilter: VectorSearchFilter legacyFilter } => BuildSearchFilter(legacyFilter, model),
@@ -82,6 +84,24 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
8284
[VectorVariableName] = vector
8385
};
8486

87+
// Add score threshold filter if specified.
88+
// For similarity functions (CosineSimilarity, DotProductSimilarity), higher scores are better, so filter with >=.
89+
// For distance functions (EuclideanDistance), lower scores are better, so filter with <=.
90+
const string ScoreThresholdVariableName = "@scoreThreshold";
91+
string? scoreThresholdClause = null;
92+
if (scoreThreshold.HasValue)
93+
{
94+
var comparisonOperator = distanceFunction switch
95+
{
96+
Microsoft.Extensions.VectorData.DistanceFunction.CosineSimilarity => ">=",
97+
Microsoft.Extensions.VectorData.DistanceFunction.DotProductSimilarity => ">=",
98+
Microsoft.Extensions.VectorData.DistanceFunction.EuclideanDistance => "<=",
99+
_ => throw new NotSupportedException($"Score threshold is not supported for distance function '{distanceFunction}'.")
100+
};
101+
scoreThresholdClause = $"{vectorDistanceArgument} {comparisonOperator} {ScoreThresholdVariableName}";
102+
queryParameters[ScoreThresholdVariableName] = scoreThreshold.Value;
103+
}
104+
85105
// If Offset is not configured, use Top parameter instead of Limit/Offset
86106
// since it's more optimized. Hybrid search doesn't allow top to be passed as a parameter
87107
// so directly add it to the query here.
@@ -92,9 +112,25 @@ public static QueryDefinition BuildSearchQuery<TRecord>(
92112
builder.AppendLine($"SELECT {topArgument}{selectClauseArguments}");
93113
builder.AppendLine($"FROM {tableVariableName}");
94114

95-
if (whereClause is not null)
115+
if (filterClause is not null || scoreThresholdClause is not null)
96116
{
97-
builder.Append("WHERE ").AppendLine(whereClause);
117+
builder.Append("WHERE ");
118+
119+
if (filterClause is not null)
120+
{
121+
builder.Append(filterClause);
122+
if (scoreThresholdClause is not null)
123+
{
124+
builder.Append(AndConditionDelimiter);
125+
}
126+
}
127+
128+
if (scoreThresholdClause is not null)
129+
{
130+
builder.Append(scoreThresholdClause);
131+
}
132+
133+
builder.AppendLine();
98134
}
99135

100136
builder.AppendLine($"ORDER BY {rankingArgument}");

dotnet/src/VectorData/InMemory/InMemoryCollection.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,14 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
339339
// Get the non-null results since any record with a null vector results in a null result.
340340
var nonNullResults = results.Where(x => x.HasValue).Select(x => x!.Value);
341341

342+
// Filter by score threshold if specified.
343+
if (options.ScoreThreshold is double scoreThreshold)
344+
{
345+
nonNullResults = InMemoryCollectionSearchMapping.ShouldSortDescending(vectorProperty.DistanceFunction)
346+
? nonNullResults.Where(x => x.score >= scoreThreshold)
347+
: nonNullResults.Where(x => x.score <= scoreThreshold);
348+
}
349+
342350
// Sort the results appropriately for the selected distance function and get the right page of results .
343351
var sortedScoredResults = InMemoryCollectionSearchMapping.ShouldSortDescending(vectorProperty.DistanceFunction) ?
344352
nonNullResults.OrderByDescending(x => x.score) :

dotnet/src/VectorData/MongoDB/MongoCollection.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,13 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
415415
ScorePropertyName,
416416
DocumentPropertyName);
417417

418-
BsonDocument[] pipeline = [searchQuery, projectionQuery];
418+
List<BsonDocument> pipeline = [searchQuery, projectionQuery];
419+
420+
// Add score threshold filter as a $match stage if specified
421+
if (options.ScoreThreshold.HasValue)
422+
{
423+
pipeline.Add(MongoCollectionSearchMapping.GetScoreThresholdMatchQuery(ScorePropertyName, options.ScoreThreshold.Value));
424+
}
419425

420426
const string OperationName = "Aggregate";
421427
using var cursor = await this.RunOperationWithRetryAsync(
@@ -536,7 +542,7 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn
536542

537543
var numCandidates = this._numCandidates ?? itemsAmount * MongoConstants.DefaultNumCandidatesRatio;
538544

539-
BsonDocument[] pipeline = MongoCollectionSearchMapping.GetHybridSearchPipeline(
545+
List<BsonDocument> pipeline = [.. MongoCollectionSearchMapping.GetHybridSearchPipeline(
540546
vectorArray,
541547
keywords,
542548
this.Name,
@@ -548,7 +554,13 @@ public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TIn
548554
DocumentPropertyName,
549555
itemsAmount,
550556
numCandidates,
551-
filter);
557+
filter)];
558+
559+
// Add score threshold filter as a $match stage if specified
560+
if (options.ScoreThreshold.HasValue)
561+
{
562+
pipeline.Add(MongoCollectionSearchMapping.GetScoreThresholdMatchQuery(ScorePropertyName, options.ScoreThreshold.Value));
563+
}
552564

553565
var results = await this.RunOperationWithRetryAsync(
554566
"KeywordVectorizedHybridSearch",

dotnet/src/VectorData/MongoDB/MongoCollectionSearchMapping.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public static BsonDocument GetSearchQuery<TVector>(
9292
int numCandidates,
9393
BsonDocument? filter)
9494
{
95+
// Docs: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage
9596
var searchQuery = new BsonDocument
9697
{
9798
{ "index", indexName },
@@ -127,6 +128,22 @@ public static BsonDocument GetProjectionQuery(string scorePropertyName, string d
127128
};
128129
}
129130

131+
/// <summary>Returns a <c>$match</c> aggregation stage that keeps only documents whose score field is greater than or equal to the given threshold.</summary>
132+
/// <remarks>
133+
/// MongoDB Atlas Vector Search returns a similarity score where higher values mean more similar,
134+
/// so the <paramref name="scorePropertyName"/> field is compared with $gte to keep results at or above <paramref name="scoreThreshold"/> (inclusive).
135+
/// </remarks>
136+
public static BsonDocument GetScoreThresholdMatchQuery(string scorePropertyName, double scoreThreshold)
137+
=> new()
138+
{
139+
{
140+
"$match", new BsonDocument
141+
{
142+
{ scorePropertyName, new BsonDocument { { "$gte", scoreThreshold } } }
143+
}
144+
}
145+
};
146+
130147
/// <summary>Returns a pipeline for hybrid search using vector search and full text search.</summary>
131148
public static BsonDocument[] GetHybridSearchPipeline<TVector>(
132149
TVector vector,

dotnet/src/VectorData/PgVector/PostgresCollection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, BinaryEm
444444
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
445445
options.OldFilter,
446446
#pragma warning restore CS0618 // VectorSearchFilter is obsolete
447-
options.Filter, options.Skip, options.IncludeVectors, top);
447+
options.Filter, options.Skip, options.IncludeVectors, top, options.ScoreThreshold);
448448

449449
using var reader = await connection.ExecuteWithErrorHandlingAsync(
450450
this._collectionMetadata,

dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ internal static StringBuilder AppendIdentifier(this StringBuilder sb, string ide
426426
/// <inheritdoc />
427427
internal static void BuildGetNearestMatchCommand<TRecord>(
428428
NpgsqlCommand command, string schema, string tableName, CollectionModel model, VectorPropertyModel vectorProperty, object vectorValue,
429-
VectorSearchFilter? legacyFilter, Expression<Func<TRecord, bool>>? newFilter, int? skip, bool includeVectors, int limit)
429+
VectorSearchFilter? legacyFilter, Expression<Func<TRecord, bool>>? newFilter, int? skip, bool includeVectors, int limit,
430+
double? scoreThreshold = null)
430431
{
431432
// Build column list with proper escaping
432433
StringBuilder columns = new();
@@ -501,6 +502,33 @@ internal static void BuildGetNearestMatchCommand<TRecord>(
501502
commandText = outerSql.ToString();
502503
}
503504

505+
// Apply score threshold filter if specified.
506+
// For similarity functions (higher = more similar), filter out results below the threshold.
507+
// For distance functions (lower = more similar), filter out results above the threshold.
508+
if (scoreThreshold.HasValue)
509+
{
510+
var scoreThresholdParamIndex = parameters.Count + 2;
511+
var comparisonOp = distanceFunction switch
512+
{
513+
DistanceFunction.CosineSimilarity or DistanceFunction.DotProductSimilarity
514+
=> ">=",
515+
516+
DistanceFunction.EuclideanDistance
517+
or DistanceFunction.CosineDistance
518+
or DistanceFunction.ManhattanDistance
519+
or DistanceFunction.HammingDistance
520+
=> "<=",
521+
522+
_ => throw new UnreachableException($"Unexpected distance function: {distanceFunction}")
523+
};
524+
525+
StringBuilder outerSql = new();
526+
outerSql.Append("SELECT * FROM (").Append(commandText).Append(") AS scored WHERE ")
527+
.AppendIdentifier(PostgresConstants.DistanceColumnName).Append(' ').Append(comparisonOp)
528+
.Append(" $").Append(scoreThresholdParamIndex);
529+
commandText = outerSql.ToString();
530+
}
531+
504532
command.CommandText = commandText;
505533

506534
Debug.Assert(command.Parameters.Count == 0);
@@ -510,6 +538,11 @@ internal static void BuildGetNearestMatchCommand<TRecord>(
510538
{
511539
command.Parameters.Add(new NpgsqlParameter { Value = parameter });
512540
}
541+
542+
if (scoreThreshold.HasValue)
543+
{
544+
command.Parameters.Add(new NpgsqlParameter { Value = scoreThreshold.Value });
545+
}
513546
}
514547

515548
internal static void BuildSelectWhereCommand<TRecord>(

0 commit comments

Comments
 (0)