Skip to content

Commit b0d621f

Browse files
authored
.Net: Implement PostgreSQL hybrid search (#13502)
**Note: this PR is based on top of #13501; skip the first commit when reviewing.** Implements PostgreSQL hybrid search and also adds provider-specific annotation support. Closes #11084. Closes #10359.
1 parent 7d7a649 commit b0d621f

12 files changed

Lines changed: 474 additions & 84 deletions

File tree

dotnet/src/VectorData/PgVector/PostgresCollection.cs

Lines changed: 96 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector;
2424
/// <typeparam name="TKey">The type of the key.</typeparam>
2525
/// <typeparam name="TRecord">The type of the record.</typeparam>
2626
#pragma warning disable CA1711 // Identifiers should not have incorrect suffix
27-
public class PostgresCollection<TKey, TRecord> : VectorStoreCollection<TKey, TRecord>
27+
public class PostgresCollection<TKey, TRecord> : VectorStoreCollection<TKey, TRecord>, IKeywordHybridSearchable<TRecord>
2828
#pragma warning restore CA1711 // Identifiers should not have incorrect suffix
2929
where TKey : notnull
3030
where TRecord : class
@@ -52,6 +52,9 @@ public class PostgresCollection<TKey, TRecord> : VectorStoreCollection<TKey, TRe
5252
/// <summary>The default options for vector search.</summary>
5353
private static readonly VectorSearchOptions<TRecord> s_defaultVectorSearchOptions = new();
5454

55+
/// <summary>The default options for hybrid search.</summary>
56+
private static readonly HybridSearchOptions<TRecord> s_defaultHybridSearchOptions = new();
57+
5558
/// <summary>
5659
/// Initializes a new instance of the <see cref="PostgresCollection{TKey, TRecord}"/> class.
5760
/// </summary>
@@ -396,43 +399,7 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
396399
}
397400

398401
var vectorProperty = this._model.GetVectorPropertyOrSingle(options);
399-
400-
object vector = searchValue switch
401-
{
402-
// Dense float32
403-
ReadOnlyMemory<float> r => r,
404-
float[] f => new ReadOnlyMemory<float>(f),
405-
Embedding<float> e => e.Vector,
406-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
407-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
408-
409-
#if NET
410-
// Dense float16
411-
ReadOnlyMemory<Half> r => r,
412-
Half[] f => new ReadOnlyMemory<Half>(f),
413-
Embedding<Half> e => e.Vector,
414-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<Half>> generator
415-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
416-
#endif
417-
418-
// Dense Binary
419-
BitArray b => b,
420-
BinaryEmbedding e => e.Vector,
421-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, BinaryEmbedding> generator
422-
=> await generator.GenerateAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
423-
424-
// Sparse
425-
SparseVector sv => sv,
426-
// TODO: Add a PG-specific SparseVectorEmbedding type
427-
428-
_ => vectorProperty.EmbeddingGenerator is null
429-
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), PostgresModelBuilder.SupportedVectorTypes))
430-
: throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType()))
431-
};
432-
433-
var pgVector = PostgresPropertyMapping.MapVectorForStorageModel(vector);
434-
435-
Verify.NotNull(pgVector);
402+
var pgVector = await this.ConvertSearchInputToVectorAsync(searchValue, vectorProperty, cancellationToken).ConfigureAwait(false);
436403

437404
// Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination
438405
// and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used.
@@ -460,6 +427,51 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, BinaryEm
460427
}
461428
}
462429

430+
/// <inheritdoc />
public async IAsyncEnumerable<VectorSearchResult<TRecord>> HybridSearchAsync<TInput>(
    TInput searchValue,
    ICollection<string> keywords,
    int top,
    HybridSearchOptions<TRecord>? options = null,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
    where TInput : notnull
{
    Verify.NotNull(searchValue);
    Verify.NotNull(keywords);
    Verify.NotLessThan(top, 1);

    options ??= s_defaultHybridSearchOptions;

    // Returning vectors is rejected when the model requires embedding generation.
    if (options.IncludeVectors && this._model.EmbeddingGenerationRequired)
    {
        throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
    }

    // Resolve the vector property and the full-text data property that are handed to the hybrid query builder,
    // then turn the search input into the vector value sent to PostgreSQL.
    var vectorProperty = this._model.GetVectorPropertyOrSingle<TRecord>(new() { VectorProperty = options.VectorProperty });
    var textProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty);
    var pgVector = await this.ConvertSearchInputToVectorAsync(searchValue, vectorProperty, cancellationToken).ConfigureAwait(false);

    using var connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
    using var command = connection.CreateCommand();
    PostgresSqlBuilder.BuildHybridSearchCommand(command, this._schema, this.Name, this._model, vectorProperty, textProperty, pgVector, keywords,
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
        options.OldFilter,
#pragma warning restore CS0618 // VectorSearchFilter is obsolete
        options.Filter, options.Skip, options.IncludeVectors, top, options.ScoreThreshold);

    using var reader = await connection.ExecuteWithErrorHandlingAsync(
        this._collectionMetadata,
        "HybridSearch",
        () => command.ExecuteReaderAsync(cancellationToken),
        cancellationToken).ConfigureAwait(false);

    // The ordinal of the score column is the same for every row of the result set, so it is looked up
    // only once. The lookup is deferred to the first row so that an empty result performs no lookup at all.
    int? scoreOrdinal = null;

    while (await reader.ReadWithErrorHandlingAsync(this._collectionMetadata, "HybridSearch", cancellationToken).ConfigureAwait(false))
    {
        var record = this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors);
        scoreOrdinal ??= reader.GetOrdinal(PostgresConstants.DistanceColumnName);

        yield return new VectorSearchResult<TRecord>(record, reader.GetDouble(scoreOrdinal.Value));
    }
}
474+
463475
#endregion Search
464476

465477
/// <inheritdoc />
@@ -513,11 +525,11 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT
513525
batch.BatchCommands.Add(
514526
new NpgsqlBatchCommand(PostgresSqlBuilder.BuildCreateTableSql(this._schema, this.Name, this._model, pgVersion, ifNotExists)));
515527

516-
foreach (var (column, kind, function, isVector) in PostgresPropertyMapping.GetIndexInfo(this._model.Properties))
528+
foreach (var (column, kind, function, isVector, isFullText, fullTextLanguage) in PostgresPropertyMapping.GetIndexInfo(this._model.Properties))
517529
{
518530
batch.BatchCommands.Add(
519531
new NpgsqlBatchCommand(
520-
PostgresSqlBuilder.BuildCreateIndexSql(this._schema, this.Name, column, kind, function, isVector, ifNotExists)));
532+
PostgresSqlBuilder.BuildCreateIndexSql(this._schema, this.Name, column, kind, function, isVector, isFullText, fullTextLanguage, ifNotExists)));
521533
}
522534

523535
await batch.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
@@ -535,4 +547,48 @@ private Task<T> RunOperationAsync<T>(string operationName, Func<Task<T>> operati
535547
this._collectionMetadata,
536548
operationName,
537549
operation);
550+
551+
/// <summary>
/// Resolves the value passed to a search method into the vector object that is sent to PostgreSQL,
/// invoking the embedding generator configured on <paramref name="vectorProperty"/> when the input is not already a vector.
/// </summary>
/// <typeparam name="TInput">The type of the search input.</typeparam>
/// <param name="searchValue">A vector, an embedding, or an input for the configured embedding generator.</param>
/// <param name="vectorProperty">The vector property being searched; supplies the embedding generator, if one is configured.</param>
/// <param name="cancellationToken">The token forwarded to the embedding generator.</param>
/// <returns>The non-null result of <c>PostgresPropertyMapping.MapVectorForStorageModel</c> for the resolved vector.</returns>
/// <exception cref="NotSupportedException">The input is not a supported vector type and no embedding generator is configured.</exception>
/// <exception cref="InvalidOperationException">An embedding generator is configured, but it is not compatible with <typeparamref name="TInput"/>.</exception>
private async Task<object> ConvertSearchInputToVectorAsync<TInput>(TInput searchValue, VectorPropertyModel vectorProperty, CancellationToken cancellationToken)
    where TInput : notnull
{
    var embeddingGenerator = vectorProperty.EmbeddingGenerator;
    object rawVector;

    // The checks are evaluated top to bottom and the first match wins, so their order is significant.

    // Dense float32
    if (searchValue is ReadOnlyMemory<float> singleMemory)
    {
        rawVector = singleMemory;
    }
    else if (searchValue is float[] singleArray)
    {
        rawVector = new ReadOnlyMemory<float>(singleArray);
    }
    else if (searchValue is Embedding<float> singleEmbedding)
    {
        rawVector = singleEmbedding.Vector;
    }
    else if (embeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> singleGenerator)
    {
        rawVector = await singleGenerator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false);
    }
#if NET
    // Dense float16
    else if (searchValue is ReadOnlyMemory<Half> halfMemory)
    {
        rawVector = halfMemory;
    }
    else if (searchValue is Half[] halfArray)
    {
        rawVector = new ReadOnlyMemory<Half>(halfArray);
    }
    else if (searchValue is Embedding<Half> halfEmbedding)
    {
        rawVector = halfEmbedding.Vector;
    }
    else if (embeddingGenerator is IEmbeddingGenerator<TInput, Embedding<Half>> halfGenerator)
    {
        rawVector = await halfGenerator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false);
    }
#endif
    // Dense Binary
    else if (searchValue is BitArray bitArray)
    {
        rawVector = bitArray;
    }
    else if (searchValue is BinaryEmbedding binaryEmbedding)
    {
        rawVector = binaryEmbedding.Vector;
    }
    else if (embeddingGenerator is IEmbeddingGenerator<TInput, BinaryEmbedding> binaryGenerator)
    {
        rawVector = await binaryGenerator.GenerateAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false);
    }
    // Sparse
    else if (searchValue is SparseVector sparseVector)
    {
        rawVector = sparseVector;
    }
    // TODO: Add a PG-specific SparseVectorEmbedding type
    else if (embeddingGenerator is null)
    {
        throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), PostgresModelBuilder.SupportedVectorTypes));
    }
    else
    {
        throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), embeddingGenerator.GetType()));
    }

    var storageVector = PostgresPropertyMapping.MapVectorForStorageModel(rawVector);
    Verify.NotNull(storageVector);
    return storageVector;
}
538594
}

dotnet/src/VectorData/PgVector/PostgresConstants.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ internal static class PostgresConstants
2424
/// <summary>The default distance function.</summary>
2525
public const string DefaultDistanceFunction = DistanceFunction.CosineDistance;
2626

27+
/// <summary>The default full-text search language for PostgreSQL.</summary>
28+
public const string DefaultFullTextSearchLanguage = "english";
29+
2730
public static readonly Dictionary<string, int> IndexMaxDimensions = new()
2831
{
2932
{ IndexKind.Hnsw, 2000 },
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using Microsoft.Extensions.VectorData;
4+
using Microsoft.Extensions.VectorData.ProviderServices;
5+
6+
namespace Microsoft.SemanticKernel.Connectors.PgVector;
7+
8+
/// <summary>
9+
/// Extension methods for configuring PostgreSQL-specific properties on vector store property definitions.
10+
/// </summary>
11+
public static class PostgresPropertyExtensions
12+
{
13+
private const string FullTextSearchLanguageKey = "Postgres:FullTextSearchLanguage";
14+
15+
/// <summary>
16+
/// Sets the PostgreSQL full-text search language for a data property.
17+
/// </summary>
18+
/// <param name="property">The data property to configure.</param>
19+
/// <param name="language">The PostgreSQL text search language name (e.g., "english", "spanish", "german").</param>
20+
/// <returns>The same property instance for method chaining.</returns>
21+
/// <remarks>
22+
/// This language is used with PostgreSQL's <c>to_tsvector</c> and <c>plainto_tsquery</c> functions
23+
/// when creating GIN indexes and performing full-text search operations.
24+
/// Common language options include: "simple", "english", "spanish", "german", "french", etc.
25+
/// See PostgreSQL documentation for the full list of available text search configurations.
26+
/// </remarks>
27+
public static VectorStoreDataProperty WithFullTextSearchLanguage(this VectorStoreDataProperty property, string? language)
28+
{
29+
property.ProviderAnnotations ??= [];
30+
property.ProviderAnnotations[FullTextSearchLanguageKey] = language;
31+
return property;
32+
}
33+
34+
/// <summary>
35+
/// Gets the PostgreSQL full-text search language configured for a data property.
36+
/// </summary>
37+
/// <param name="property">The data property to read from.</param>
38+
/// <returns>The configured language, or <see langword="null"/> if not set.</returns>
39+
public static string? GetFullTextSearchLanguage(this VectorStoreDataProperty property)
40+
=> property.ProviderAnnotations?.TryGetValue(FullTextSearchLanguageKey, out var value) == true
41+
? value as string
42+
: null;
43+
44+
/// <summary>
45+
/// Gets the PostgreSQL full-text search language configured for a data property model.
46+
/// </summary>
47+
/// <param name="property">The data property model to read from.</param>
48+
/// <returns>The configured language, or the default language ("english") if not set.</returns>
49+
internal static string GetFullTextSearchLanguageOrDefault(this DataPropertyModel property)
50+
=> property.ProviderAnnotations?.TryGetValue(FullTextSearchLanguageKey, out var value) == true && value is string language
51+
? language
52+
: PostgresConstants.DefaultFullTextSearchLanguage;
53+
}

dotnet/src/VectorData/PgVector/PostgresPropertyMapping.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,13 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value)
152152
/// Returns information about indexes to create, validating that the dimensions of the vector are supported.
153153
/// </summary>
154154
/// <param name="properties">The properties of the vector store record.</param>
155-
/// <returns>A list of tuples containing the column name, index kind, and distance function for each property.</returns>
155+
/// <returns>A list of tuples containing, for each index to create, the column name, index kind, distance function, whether it is a vector index, whether it is a full-text index, and the full-text language (if any).</returns>
156156
/// <remarks>
157157
/// The default index kind is "Flat", which prevents the creation of an index.
158158
/// </remarks>
159-
public static List<(string column, string kind, string function, bool isVector)> GetIndexInfo(IReadOnlyList<PropertyModel> properties)
159+
public static List<(string column, string kind, string function, bool isVector, bool isFullText, string? fullTextLanguage)> GetIndexInfo(IReadOnlyList<PropertyModel> properties)
160160
{
161-
var vectorIndexesToCreate = new List<(string column, string kind, string function, bool isVector)>();
161+
var vectorIndexesToCreate = new List<(string column, string kind, string function, bool isVector, bool isFullText, string? fullTextLanguage)>();
162162
foreach (var property in properties)
163163
{
164164
switch (property)
@@ -185,15 +185,21 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value)
185185
);
186186
}
187187

188-
vectorIndexesToCreate.Add((vectorProperty.StorageName, indexKind, distanceFunction, isVector: true));
188+
vectorIndexesToCreate.Add((vectorProperty.StorageName, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null));
189189
}
190190

191191
break;
192192

193193
case DataPropertyModel dataProperty:
194194
if (dataProperty.IsIndexed)
195195
{
196-
vectorIndexesToCreate.Add((dataProperty.StorageName, "", "", isVector: false));
196+
vectorIndexesToCreate.Add((dataProperty.StorageName, kind: "", function: "", isVector: false, isFullText: false, fullTextLanguage: null));
197+
}
198+
199+
if (dataProperty.IsFullTextIndexed)
200+
{
201+
var language = dataProperty.GetFullTextSearchLanguageOrDefault();
202+
vectorIndexesToCreate.Add((dataProperty.StorageName, kind: "", function: "", isVector: false, isFullText: true, fullTextLanguage: language));
197203
}
198204
break;
199205

dotnet/src/VectorData/PgVector/PostgresServiceCollectionExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ private static void AddAbstractions<TKey, TRecord>(IServiceCollection services,
286286
services.Add(new ServiceDescriptor(typeof(IVectorSearchable<TRecord>), serviceKey,
287287
static (sp, key) => sp.GetRequiredKeyedService<PostgresCollection<TKey, TRecord>>(key), lifetime));
288288

289-
// Once HybridSearch supports get implemented by PostgresCollection,
290-
// we need to add IKeywordHybridSearchable abstraction here as well.
289+
services.Add(new ServiceDescriptor(typeof(IKeywordHybridSearchable<TRecord>), serviceKey,
290+
static (sp, key) => sp.GetRequiredKeyedService<PostgresCollection<TKey, TRecord>>(key), lifetime));
291291
}
292292

293293
private static PostgresVectorStoreOptions? GetStoreOptions(IServiceProvider sp, Func<IServiceProvider, PostgresVectorStoreOptions?>? optionsProvider)

0 commit comments

Comments
 (0)