-
Notifications
You must be signed in to change notification settings - Fork 863
Expand file tree
/
Copy pathVectorStoreWriter.cs
More file actions
162 lines (139 loc) · 6.37 KB
/
VectorStoreWriter.cs
File metadata and controls
162 lines (139 loc) · 6.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.VectorData;
using Microsoft.Shared.Diagnostics;
namespace Microsoft.Extensions.DataIngestion;
/// <summary>
/// Writes chunks to a <see cref="VectorStoreCollection{TKey, TRecord}"/>.
/// </summary>
/// <typeparam name="TChunk">The type of the chunk content.</typeparam>
/// <typeparam name="TRecord">The type of the record stored in the vector store.</typeparam>
public class VectorStoreWriter<TChunk, TRecord> : IngestionChunkWriter<TChunk>
    where TRecord : IngestionChunkVectorRecord<TChunk>, new()
{
    private readonly VectorStoreWriterOptions _options;

    // Tracks whether EnsureCollectionExistsAsync has already been awaited,
    // so the (potentially expensive) call happens at most once per writer.
    private bool _collectionEnsured;

    /// <summary>
    /// Initializes a new instance of the <see cref="VectorStoreWriter{TChunk, TRecord}"/> class.
    /// </summary>
    /// <param name="collection">The <see cref="VectorStoreCollection{TKey, TRecord}"/> to use to store the <see cref="IngestionChunk{T}"/> instances.</param>
    /// <param name="options">The options for the vector store writer.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="collection"/> is null.</exception>
    /// <remarks>
    /// You can use the <see cref="VectorStoreExtensions.GetIngestionRecordCollection{TRecord, TChunk}(VectorStore, string, int, string?, string?)"/>
    /// helper to create a <see cref="VectorStoreCollection{TKey, TRecord}"/> with the appropriate schema for storing ingestion chunks.
    /// </remarks>
    public VectorStoreWriter(VectorStoreCollection<Guid, TRecord> collection, VectorStoreWriterOptions? options = default)
    {
        VectorStoreCollection = Throw.IfNull(collection);
        _options = options ?? new VectorStoreWriterOptions();
    }

    /// <summary>
    /// Gets the underlying <see cref="VectorStoreCollection{TKey,TRecord}"/> used to store the chunks.
    /// </summary>
    public VectorStoreCollection<Guid, TRecord> VectorStoreCollection { get; }

    /// <inheritdoc/>
    public override async Task WriteAsync(IngestionDocument document, IAsyncEnumerable<IngestionChunk<TChunk>> chunks, CancellationToken cancellationToken = default)
    {
        _ = Throw.IfNull(document);
        _ = Throw.IfNull(chunks);

        IReadOnlyList<Guid>? staleKeys = null;
        List<TRecord>? pending = null;
        long pendingTokens = 0;

        await foreach (IngestionChunk<TChunk> chunk in chunks.WithCancellation(cancellationToken))
        {
            // Lazily create the collection the first time we actually have something to write.
            if (!_collectionEnsured)
            {
                await VectorStoreCollection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);
                _collectionEnsured = true;
            }

            // Snapshot the IDs of the chunks already stored for this document before any upsert,
            // and delete them only after all new chunks are inserted. This avoids a window where
            // the old chunks are gone but inserting the new ones fails.
            staleKeys ??= await GetPreExistingChunksIdsAsync(document, cancellationToken).ConfigureAwait(false);

            TRecord record = CreateRecord(document, chunk);

            pending ??= [];

            // Flush the current batch when this chunk would push it past the token budget.
            // An oversized chunk on an empty batch is still accepted as a batch of one.
            bool overBudget = pendingTokens + chunk.TokenCount > _options.BatchTokenCount;
            if (overBudget && pending.Count > 0)
            {
                await VectorStoreCollection.UpsertAsync(pending, cancellationToken).ConfigureAwait(false);
                pending.Clear();
                pendingTokens = 0;
            }

            pending.Add(record);
            pendingTokens += chunk.TokenCount;
        }

        // Flush whatever is left in the final (possibly only) batch.
        if (pending is { Count: > 0 })
        {
            await VectorStoreCollection.UpsertAsync(pending, cancellationToken).ConfigureAwait(false);
        }

        // All new chunks are in place; now it is safe to remove the superseded ones.
        if (staleKeys is { Count: > 0 })
        {
            await VectorStoreCollection.DeleteAsync(staleKeys, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Sets a metadata value on the record.
    /// </summary>
    /// <param name="record">The record on which to set the metadata.</param>
    /// <param name="key">The metadata key.</param>
    /// <param name="value">The metadata value.</param>
    /// <remarks>
    /// Override this method in derived classes to store metadata as typed properties with
    /// <see cref="VectorStoreDataAttribute"/> attributes.
    /// </remarks>
    protected virtual void SetMetadata(TRecord record, string key, object? value)
    {
        throw new NotSupportedException($"Metadata key '{key}' is not supported. Override {nameof(SetMetadata)} in a derived class to handle metadata.");
    }

    // Builds the vector store record for a single chunk, copying content, context,
    // the owning document's identifier, and any per-chunk metadata.
    private TRecord CreateRecord(IngestionDocument document, IngestionChunk<TChunk> chunk)
    {
        TRecord record = new()
        {
            Content = chunk.Content,
            Context = chunk.Context,
            DocumentId = document.Identifier,
        };

        if (chunk.HasMetadata)
        {
            foreach (var pair in chunk.Metadata)
            {
                SetMetadata(record, pair.Key, pair.Value);
            }
        }

        return record;
    }

    // Returns the keys of every record already stored for the given document,
    // or an empty list when incremental ingestion is disabled.
    private async Task<IReadOnlyList<Guid>> GetPreExistingChunksIdsAsync(IngestionDocument document, CancellationToken cancellationToken)
    {
        if (!_options.IncrementalIngestion)
        {
            return [];
        }

        // Each Vector Store has a different max top count limit, so we use low value and loop.
        // Use smaller batch size in debug to be able to test the looping logic without needing to insert a lot of records.
        const int PageSize =
#if RELEASE
            1_000;
#else
            10;
#endif

        List<Guid> keys = [];
        while (true)
        {
            int fetched = 0;
            await foreach (var existing in VectorStoreCollection.GetAsync(
                filter: r => r.DocumentId == document.Identifier,
                top: PageSize,
                options: new() { Skip = keys.Count },
                cancellationToken: cancellationToken).ConfigureAwait(false))
            {
                keys.Add(existing.Key);
                fetched++;
            }

            // A short page means we have seen every matching record.
            if (fetched < PageSize)
            {
                return keys;
            }
        }
    }
}