diff --git a/EssentialCSharp.Chat.Shared/Extensions/ServiceCollectionExtensions.cs b/EssentialCSharp.Chat.Shared/Extensions/ServiceCollectionExtensions.cs index 2846fca2..ccc3c5b5 100644 --- a/EssentialCSharp.Chat.Shared/Extensions/ServiceCollectionExtensions.cs +++ b/EssentialCSharp.Chat.Shared/Extensions/ServiceCollectionExtensions.cs @@ -100,8 +100,8 @@ public static IServiceCollection AddAzureOpenAIServices( // Configure AI options from configuration services.Configure(configuration.GetSection("AIOptions")); - // Configure retry options from configuration section - // Environment variables like EmbeddingRetry:MaxRetries will override defaults + // Configure retry options from configuration section. + // Environment variables can override via AIOptions__EmbeddingRetry__*. services.AddOptions<EmbeddingRetryOptions>() .Bind(configuration.GetSection(EmbeddingRetryOptions.SectionPath)) .ValidateDataAnnotations() diff --git a/EssentialCSharp.Chat.Shared/Models/EmbeddingRetryOptions.cs b/EssentialCSharp.Chat.Shared/Models/EmbeddingRetryOptions.cs index 40064f0f..ce943146 100644 --- a/EssentialCSharp.Chat.Shared/Models/EmbeddingRetryOptions.cs +++ b/EssentialCSharp.Chat.Shared/Models/EmbeddingRetryOptions.cs @@ -34,6 +34,20 @@ public sealed class EmbeddingRetryOptions [Range(1, 600000)] public int MaxDelayMs { get; set; } = 60000; + /// <summary> + /// Maximum number of inputs sent in a single embedding API call. + /// The service may adaptively downshift below this value when throttled. + /// </summary> + [Range(1, 2048)] + public int MaxEmbeddingBatchSize { get; set; } = 2048; + + /// <summary> + /// Minimum delay between embedding API requests in milliseconds. + /// This adds request pacing to reduce sustained rate-limit pressure. + /// </summary> + [Range(0, 600000)] + public int MinInterRequestDelayMs { get; set; } = 250; + /// <summary> /// Exponential backoff multiplier. Each retry delay is multiplied by this value. 
/// For example, with baseDelay=1000ms and multiplier=2.0: @@ -74,6 +88,15 @@ public void Validate() if (BaseDelayMs > MaxDelayMs) throw new InvalidOperationException("BaseDelayMs must be less than or equal to MaxDelayMs."); + if (MaxEmbeddingBatchSize <= 0) + throw new InvalidOperationException("MaxEmbeddingBatchSize must be positive."); + + if (MaxEmbeddingBatchSize > 2048) + throw new InvalidOperationException("MaxEmbeddingBatchSize cannot exceed Azure embedding API limit (2048)."); + + if (MinInterRequestDelayMs < 0) + throw new InvalidOperationException("MinInterRequestDelayMs must be non-negative."); + if (BackoffMultiplier < 1.0) throw new InvalidOperationException("BackoffMultiplier must be >= 1.0."); diff --git a/EssentialCSharp.Chat.Shared/Services/EmbeddingService.cs b/EssentialCSharp.Chat.Shared/Services/EmbeddingService.cs index d7b6428c..a192fdb2 100644 --- a/EssentialCSharp.Chat.Shared/Services/EmbeddingService.cs +++ b/EssentialCSharp.Chat.Shared/Services/EmbeddingService.cs @@ -6,6 +6,7 @@ using Microsoft.Extensions.VectorData; using Npgsql; using System.ClientModel; +using System.ClientModel.Primitives; using System.Globalization; namespace EssentialCSharp.Chat.Common.Services; @@ -32,9 +33,12 @@ public partial class EmbeddingService( private readonly EmbeddingRetryOptions _retryOptions = ValidateRetryOptions(retryOptions?.Value ?? new EmbeddingRetryOptions()); private readonly ILogger? _logger = logger; + private static readonly SemaphoreSlim _embeddingRequestLock = new(1, 1); + private DateTimeOffset _lastEmbeddingRequestStartedUtc = DateTimeOffset.MinValue; // Only allow simple identifiers: letters, digits, and underscores, starting with a letter or underscore. 
private static readonly Regex _safeIdentifierRegex = new(@"^[a-zA-Z_][a-zA-Z0-9_]*$", RegexOptions.Compiled); + private static readonly Regex _retryAfterSecondsRegex = new(@"retry\s+after\s+(?<seconds>\d+)\s+seconds?", RegexOptions.Compiled | RegexOptions.IgnoreCase); private static EmbeddingRetryOptions ValidateRetryOptions(EmbeddingRetryOptions options) { @@ -96,6 +100,10 @@ private static bool IsTransientStatusCode(int statusCode) => return ex.InnerException is null ? null : TryGetStatusCode(ex.InnerException); } + private static bool IsRateLimitError(Exception ex) => + TryGetStatusCode(ex) == 429 + || ex.Message.Contains("RateLimitReached", StringComparison.OrdinalIgnoreCase); + /// <summary> /// Extracts the Retry-After delay from known exception types if present. /// Returns null if the header is not present or invalid. @@ -105,19 +113,61 @@ private static bool IsTransientStatusCode(int statusCode) => if (ex is ClientResultException clientResultException) { var rawResponse = clientResultException.GetRawResponse(); - var headerValue = rawResponse?.Headers.TryGetValue("retry-after", out var value) == true - ? value - : null; - if (TryParseRetryAfterValue(headerValue, out var retryAfter)) + if (TryGetHeaderValue(rawResponse, "x-ms-retry-after-ms", out var msHeaderValue) + && TryParseMilliseconds(msHeaderValue, out var msRetryAfter)) + { + return msRetryAfter; + } + + if (TryGetHeaderValue(rawResponse, "retry-after-ms", out var retryAfterMsHeaderValue) + && TryParseMilliseconds(retryAfterMsHeaderValue, out var retryAfterMs)) + { + return retryAfterMs; + } + + if (TryGetHeaderValue(rawResponse, "retry-after", out var headerValue) + && TryParseRetryAfterValue(headerValue, out var retryAfter)) + { return retryAfter; + } } - if (ex is HttpRequestException) - return null; + if (TryParseRetryAfterValue(ex.Message, out var messageRetryAfter)) + return messageRetryAfter; return ex.InnerException is null ? 
null : ExtractRetryAfter(ex.InnerException); } + private static bool TryGetHeaderValue(PipelineResponse? response, string headerName, out string? headerValue) + { + headerValue = null; + if (response is null) + return false; + + if (response.Headers.TryGetValue(headerName, out var value) && !string.IsNullOrWhiteSpace(value)) + { + headerValue = value; + return true; + } + + return false; + } + + private static bool TryParseMilliseconds(string? value, out TimeSpan retryAfter) + { + retryAfter = default; + if (string.IsNullOrWhiteSpace(value)) + return false; + + if (int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out var msValue) && msValue >= 0) + { + retryAfter = TimeSpan.FromMilliseconds(msValue); + return true; + } + + return false; + } + private static bool TryParseRetryAfterValue(string? headerValue, out TimeSpan retryAfter) { retryAfter = default; @@ -140,6 +190,15 @@ private static bool TryParseRetryAfterValue(string? headerValue, out TimeSpan re } } + var secondsMatch = _retryAfterSecondsRegex.Match(headerValue); + if (secondsMatch.Success + && int.TryParse(secondsMatch.Groups["seconds"].Value, NumberStyles.Integer, CultureInfo.InvariantCulture, out var extractedSeconds) + && extractedSeconds >= 0) + { + retryAfter = TimeSpan.FromSeconds(extractedSeconds); + return true; + } + return false; } @@ -165,6 +224,31 @@ private TimeSpan ClampRetryDelay(TimeSpan delay) => ? 
TimeSpan.FromMilliseconds(_retryOptions.MaxDelayMs) : delay; + private async Task ExecuteEmbeddingRequestWithPacingAsync( + Func> embeddingRequest, + CancellationToken cancellationToken) + { + await _embeddingRequestLock.WaitAsync(cancellationToken); + try + { + var minimumSpacing = TimeSpan.FromMilliseconds(_retryOptions.MinInterRequestDelayMs); + if (minimumSpacing > TimeSpan.Zero && _lastEmbeddingRequestStartedUtc != DateTimeOffset.MinValue) + { + var elapsed = DateTimeOffset.UtcNow - _lastEmbeddingRequestStartedUtc; + var remainingDelay = minimumSpacing - elapsed; + if (remainingDelay > TimeSpan.Zero) + await Task.Delay(remainingDelay, cancellationToken); + } + + _lastEmbeddingRequestStartedUtc = DateTimeOffset.UtcNow; + return await embeddingRequest(cancellationToken); + } + finally + { + _embeddingRequestLock.Release(); + } + } + /// /// Wraps an async operation with retry logic for transient failures. /// @@ -237,7 +321,9 @@ private async Task ExecuteWithRetryAsync( public async Task> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) { var embedding = await ExecuteWithRetryAsync( - async ct => await embeddingGenerator.GenerateAsync(text, cancellationToken: ct), + async ct => await ExecuteEmbeddingRequestWithPacingAsync( + async pacingCt => await embeddingGenerator.GenerateAsync(text, cancellationToken: pacingCt), + ct), "GenerateEmbedding", cancellationToken); return embedding.Vector; @@ -287,31 +373,151 @@ public async Task GenerateBookContentEmbeddingsAndUploadToVectorStore( // ── Step 2 & 3: Batch-embed and immediately upsert each batch ───────────────── // Azure OpenAI supports at most EmbeddingBatchSize inputs per GenerateAsync call. - // bookContents is streamed in fixed-size batches without full upfront materialization, - // keeping peak memory bounded to one batch of chunk objects and their embeddings at a time. 
+ // The effective request size starts at min(EmbeddingBatchSize, MaxEmbeddingBatchSize) + // and adaptively downshifts on 429 throttling responses. + // bookContents is streamed in batches without full upfront materialization, + // keeping peak memory bounded to one batch (or adaptive split) at a time. // The staging-swap (Step 3) is safe because it only runs after all batches have // been successfully upserted. - var buffer = new List(EmbeddingBatchSize); + var configuredMaxBatchSize = Math.Clamp(_retryOptions.MaxEmbeddingBatchSize, 1, EmbeddingBatchSize); + var adaptiveBatchSize = configuredMaxBatchSize; + var buffer = new List(configuredMaxBatchSize); + var knownTotalChunks = bookContents.TryGetNonEnumeratedCount(out var totalChunkCount) ? totalChunkCount : (int?)null; + var nextProgressPercentToLog = 10; + var nextProgressChunkCountToLog = 500; + var successfulBatchRequestCounts = new Dictionary(); + var successfulBatchChunkTotals = new Dictionary(); int totalCount = 0; - async Task EmbedAndUpsertBatchAsync() + if (_logger is not null) { + LogEmbeddingRebuildStarted( + _logger, + knownTotalChunks, + configuredMaxBatchSize, + _retryOptions.MinInterRequestDelayMs); + } + + void LogProgressIfNeeded() + { + if (_logger is null) + return; + + if (knownTotalChunks is > 0) + { + while (nextProgressPercentToLog <= 100 + && (long)totalCount * 100 >= (long)knownTotalChunks.Value * nextProgressPercentToLog) + { + LogEmbeddingProgressPercent(_logger, totalCount, knownTotalChunks.Value, nextProgressPercentToLog, adaptiveBatchSize); + nextProgressPercentToLog += 10; + } + } + else if (totalCount >= nextProgressChunkCountToLog) + { + LogEmbeddingProgressCount(_logger, totalCount, adaptiveBatchSize); + nextProgressChunkCountToLog += 500; + } + } + + async Task EmbedAndUpsertExactBatchAsync(IReadOnlyList batch) + { + const string operationName = "GenerateBatchEmbeddings"; + int attemptNumber = 0; + var batchEmbeddings = await ExecuteWithRetryAsync( - async ct => await 
embeddingGenerator.GenerateAsync( - buffer.Select(c => c.ChunkText), cancellationToken: ct), - $"GenerateBatchEmbeddings(size={buffer.Count})", + async ct => + { + attemptNumber++; + if (_logger is not null) + { + LogEmbeddingBatchRequestState( + _logger, + operationName, + batch.Count, + adaptiveBatchSize, + batch.Count, + attemptNumber, + false, + false); + } + + return await ExecuteEmbeddingRequestWithPacingAsync( + async pacingCt => await embeddingGenerator.GenerateAsync( + batch.Select(c => c.ChunkText), cancellationToken: pacingCt), + ct); + }, + $"{operationName}(size={batch.Count})", cancellationToken); - if (batchEmbeddings.Count != buffer.Count) + if (batchEmbeddings.Count != batch.Count) throw new InvalidOperationException( - $"Embedding count mismatch: expected {buffer.Count}, got {batchEmbeddings.Count}."); + $"Embedding count mismatch: expected {batch.Count}, got {batchEmbeddings.Count}."); + + for (int i = 0; i < batch.Count; i++) + batch[i].TextEmbedding = batchEmbeddings[i].Vector; + + await staging.UpsertAsync(batch, cancellationToken); + if (_logger is not null) + { + LogEmbeddingBatchRequestState( + _logger, + operationName, + batch.Count, + adaptiveBatchSize, + batch.Count, + attemptNumber, + true, + false); + } + + if (!successfulBatchRequestCounts.TryAdd(batch.Count, 1)) + successfulBatchRequestCounts[batch.Count]++; - for (int i = 0; i < buffer.Count; i++) - buffer[i].TextEmbedding = batchEmbeddings[i].Vector; + if (!successfulBatchChunkTotals.TryAdd(batch.Count, batch.Count)) + successfulBatchChunkTotals[batch.Count] += batch.Count; - await staging.UpsertAsync(buffer, cancellationToken); - totalCount += buffer.Count; - buffer.Clear(); + totalCount += batch.Count; + LogProgressIfNeeded(); + } + + async Task EmbedAndUpsertBatchAdaptiveAsync(IReadOnlyList batch) + { + try + { + await EmbedAndUpsertExactBatchAsync(batch); + } + catch (Exception ex) when (IsRateLimitError(ex) && batch.Count > 1) + { + var splitSize = Math.Max(1, batch.Count / 
2); + if (adaptiveBatchSize > splitSize) + { + var previousAdaptiveBatchSize = adaptiveBatchSize; + adaptiveBatchSize = splitSize; + if (_logger is not null) + { + LogEmbeddingBatchRequestState( + _logger, + "GenerateBatchEmbeddings", + previousAdaptiveBatchSize, + adaptiveBatchSize, + batch.Count, + _retryOptions.MaxRetries + 1, + false, + true); + } + } + + var first = batch.Take(splitSize).ToArray(); + var second = batch.Skip(splitSize).ToArray(); + await EmbedAndUpsertBatchAdaptiveAsync(first); + await EmbedAndUpsertBatchAdaptiveAsync(second); + } + catch (Exception ex) when (IsRateLimitError(ex) && batch.Count == 1) + { + throw new InvalidOperationException( + $"Embedding request failed with repeated 429 rate limits even at batch size 1 after {_retryOptions.MaxRetries + 1} attempts.", + ex); + } } try @@ -319,12 +525,33 @@ async Task EmbedAndUpsertBatchAsync() foreach (var chunk in bookContents) { buffer.Add(chunk); - if (buffer.Count == EmbeddingBatchSize) - await EmbedAndUpsertBatchAsync(); + if (buffer.Count >= adaptiveBatchSize) + { + var batchToProcess = buffer.ToArray(); + buffer.Clear(); + await EmbedAndUpsertBatchAdaptiveAsync(batchToProcess); + } } if (buffer.Count > 0) - await EmbedAndUpsertBatchAsync(); + { + var batchToProcess = buffer.ToArray(); + buffer.Clear(); + await EmbedAndUpsertBatchAdaptiveAsync(batchToProcess); + } + + if (_logger is not null) + { + foreach (var entry in successfulBatchRequestCounts.OrderBy(kvp => kvp.Key)) + { + successfulBatchChunkTotals.TryGetValue(entry.Key, out var successfulChunkCount); + LogEmbeddingBatchSizeSummary( + _logger, + entry.Key, + entry.Value, + successfulChunkCount); + } + } Console.WriteLine($"Uploaded {totalCount} chunks to staging collection '{stagingName}'."); } @@ -417,4 +644,58 @@ private static partial void LogEmbeddingRetryAttemptsExhausted( int attemptCount, string lastError, int? 
statusCode); + + [LoggerMessage( + EventId = 12004, + Level = LogLevel.Information, + Message = "Embedding batch request state: operation_name={operation_name} current_batch_size={current_batch_size} effective_batch_size={effective_batch_size} chunk_count_in_request={chunk_count_in_request} attempt_number={attempt_number} request_succeeded={request_succeeded} request_throttled={request_throttled}")] + private static partial void LogEmbeddingBatchRequestState( + ILogger logger, + string operation_name, + int current_batch_size, + int effective_batch_size, + int chunk_count_in_request, + int attempt_number, + bool request_succeeded, + bool request_throttled); + + [LoggerMessage( + EventId = 12005, + Level = LogLevel.Information, + Message = "Embedding rebuild started. TotalChunks={TotalChunks}, InitialBatchSize={InitialBatchSize}, MinInterRequestDelayMs={MinInterRequestDelayMs}")] + private static partial void LogEmbeddingRebuildStarted( + ILogger logger, + int? totalChunks, + int initialBatchSize, + int minInterRequestDelayMs); + + [LoggerMessage( + EventId = 12006, + Level = LogLevel.Information, + Message = "Embedding progress: {ProcessedChunks}/{TotalChunks} chunks ({ProgressPercent}%). CurrentAdaptiveBatchSize={AdaptiveBatchSize}")] + private static partial void LogEmbeddingProgressPercent( + ILogger logger, + int processedChunks, + int totalChunks, + int progressPercent, + int adaptiveBatchSize); + + [LoggerMessage( + EventId = 12007, + Level = LogLevel.Information, + Message = "Embedding progress: {ProcessedChunks} chunks processed. 
CurrentAdaptiveBatchSize={AdaptiveBatchSize}")] + private static partial void LogEmbeddingProgressCount( + ILogger logger, + int processedChunks, + int adaptiveBatchSize); + + [LoggerMessage( + EventId = 12008, + Level = LogLevel.Information, + Message = "Embedding successful batch-size summary: successful_batch_size={successful_batch_size} successful_request_count={successful_request_count} successful_chunk_count={successful_chunk_count}")] + private static partial void LogEmbeddingBatchSizeSummary( + ILogger logger, + int successful_batch_size, + int successful_request_count, + int successful_chunk_count); } diff --git a/EssentialCSharp.Web/appsettings.json b/EssentialCSharp.Web/appsettings.json index 12e8530a..1f32663d 100644 --- a/EssentialCSharp.Web/appsettings.json +++ b/EssentialCSharp.Web/appsettings.json @@ -25,6 +25,8 @@ "MaxRetries": 5, "BaseDelayMs": 1000, "MaxDelayMs": 60000, + "MaxEmbeddingBatchSize": 2048, + "MinInterRequestDelayMs": 250, "BackoffMultiplier": 2.0, "MaxJitterFraction": 0.2 },