diff --git a/EssentialCSharp.Web/Controllers/ChatController.cs b/EssentialCSharp.Web/Controllers/ChatController.cs index 0fcd8577..66a2866c 100644 --- a/EssentialCSharp.Web/Controllers/ChatController.cs +++ b/EssentialCSharp.Web/Controllers/ChatController.cs @@ -1,3 +1,4 @@ +using System.IO; using System.Security.Claims; using System.Text.Json; using EssentialCSharp.Chat.Common.Services; @@ -15,13 +16,13 @@ namespace EssentialCSharp.Web.Controllers; [IgnoreAntiforgeryToken] public partial class ChatController : ControllerBase { - private readonly AIChatService _AiChatService; + private readonly AIChatService _AIChatService; private readonly ResponseIdValidationService _ResponseIdValidationService; private readonly ILogger _Logger; public ChatController(ILogger logger, AIChatService aiChatService, ResponseIdValidationService responseIdValidationService) { - _AiChatService = aiChatService; + _AIChatService = aiChatService; _ResponseIdValidationService = responseIdValidationService; _Logger = logger; } @@ -46,7 +47,7 @@ public async Task SendMessage([FromBody] ChatMessageRequest reque try { - var (response, responseId) = await _AiChatService.GetChatCompletion( + var (response, responseId) = await _AIChatService.GetChatCompletion( prompt: request.Message, previousResponseId: previousResponseId, enableContextualSearch: request.EnableContextualSearch, @@ -75,7 +76,7 @@ public async Task StreamMessage([FromBody] ChatMessageRequest request, Cancellat if (string.IsNullOrEmpty(request.Message)) { Response.StatusCode = 400; - await Response.WriteAsJsonAsync(new { error = "Message cannot be empty." }, CancellationToken.None); + await Response.WriteAsJsonAsync(new { error = "Message cannot be empty." }, cancellationToken); return; } @@ -83,7 +84,7 @@ public async Task StreamMessage([FromBody] ChatMessageRequest request, Cancellat if (string.IsNullOrEmpty(userId)) { Response.StatusCode = 401; - await Response.WriteAsJsonAsync(new { error = "Unauthorized." }, CancellationToken.None); + await Response.WriteAsJsonAsync(new { error = "Unauthorized." }, cancellationToken); return; } @@ -94,7 +95,7 @@ public async Task StreamMessage([FromBody] ChatMessageRequest request, Cancellat if (!_ResponseIdValidationService.ValidateResponseId(userId, previousResponseId)) { Response.StatusCode = 400; - await Response.WriteAsJsonAsync(new { error = "Invalid conversation context." }, CancellationToken.None); + await Response.WriteAsJsonAsync(new { error = "Invalid conversation context." }, cancellationToken); return; } @@ -104,7 +105,7 @@ public async Task StreamMessage([FromBody] ChatMessageRequest request, Cancellat try { - await foreach (var (text, responseId) in _AiChatService.GetChatCompletionStream( + await foreach (var (text, responseId) in _AIChatService.GetChatCompletionStream( prompt: request.Message, previousResponseId: previousResponseId, enableContextualSearch: request.EnableContextualSearch, @@ -133,53 +134,102 @@ public async Task StreamMessage([FromBody] ChatMessageRequest request, Cancellat await Response.WriteAsync("data: [DONE]\n\n", cancellationToken); await Response.Body.FlushAsync(cancellationToken); } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested) + catch (OperationCanceledException) { - LogChatStreamCancelled(_Logger, User.Identity?.Name); - } - catch (ConversationContextLimitExceededException) when (!Response.HasStarted) - { - Response.StatusCode = 400; - Response.ContentType = "application/json"; - await Response.WriteAsJsonAsync(new { error = "This conversation has grown too long. Please start a new one.", errorCode = "context_limit_exceeded" }, CancellationToken.None); + if (cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested) + { + LogChatStreamCancelled(_Logger, User.Identity?.Name); + return; + } + + throw; } catch (ConversationContextLimitExceededException ex) { - LogChatStreamErrorMidStream(_Logger, ex, User.Identity?.Name); - try + if (!Response.HasStarted) { - await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"This conversation has grown too long. Please start a new one.\",\"errorCode\":\"context_limit_exceeded\"}\n\n", CancellationToken.None); - await Response.Body.FlushAsync(CancellationToken.None); + if (cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested) + return; + + Response.StatusCode = 400; + Response.ContentType = "application/json"; + try + { + var writeCancellationToken = + cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested + ? CancellationToken.None + : cancellationToken; + await Response.WriteAsJsonAsync(new { error = "This conversation has grown too long. Please start a new one.", errorCode = "context_limit_exceeded" }, writeCancellationToken); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested) + { + // Best-effort write during an aborted request — no response body can be delivered. + } + catch (IOException) when (HttpContext.RequestAborted.IsCancellationRequested) + { + // Expected client disconnect while attempting a best-effort error response write. + } + catch (ObjectDisposedException) when (HttpContext.RequestAborted.IsCancellationRequested) + { + // Response stream can already be disposed after an abrupt client disconnect. + } } - catch (Exception) + else { - // Best-effort write to an already-streaming response. Kestrel can throw - // IOException (connection reset), OperationCanceledException, or - // ObjectDisposedException on abrupt client disconnect — swallow all to - // avoid masking the original exception. + LogChatStreamErrorMidStream(_Logger, ex, User.Identity?.Name); + try + { + await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"This conversation has grown too long. Please start a new one.\",\"errorCode\":\"context_limit_exceeded\"}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); + } + catch (Exception writeException) when (writeException is IOException or OperationCanceledException or ObjectDisposedException) + { + // Best-effort write to an already-streaming response. Kestrel can throw + // IOException (connection reset), OperationCanceledException, or + // ObjectDisposedException on abrupt client disconnect — swallow expected + // transport/disconnect exceptions to avoid masking the original exception. + } } } catch (Exception ex) when (!Response.HasStarted) { LogChatStreamErrorBeforeResponseStarted(_Logger, ex, User.Identity?.Name); + if (cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested) + return; + Response.StatusCode = 500; Response.ContentType = "application/json"; - await Response.WriteAsJsonAsync(new { error = "Chat service unavailable" }, CancellationToken.None); + try + { + await Response.WriteAsJsonAsync(new { error = "Chat service unavailable" }, cancellationToken); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested || HttpContext.RequestAborted.IsCancellationRequested) + { + // Best-effort write during an aborted request — no response body can be delivered. + } + catch (IOException) when (HttpContext.RequestAborted.IsCancellationRequested) + { + // Expected client disconnect while attempting a best-effort error response write. + } + catch (ObjectDisposedException) when (HttpContext.RequestAborted.IsCancellationRequested) + { + // Response stream can already be disposed after an abrupt client disconnect. + } } catch (Exception ex) { LogChatStreamErrorMidStream(_Logger, ex, User.Identity?.Name); try { - await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Stream interrupted\"}\n\n", CancellationToken.None); - await Response.Body.FlushAsync(CancellationToken.None); + await Response.WriteAsync("data: {\"type\":\"error\",\"message\":\"Stream interrupted\"}\n\n", cancellationToken); + await Response.Body.FlushAsync(cancellationToken); } - catch (Exception) + catch (Exception writeException) when (writeException is IOException or OperationCanceledException or ObjectDisposedException) { // Best-effort write to an already-streaming response. Kestrel can throw // IOException (connection reset), OperationCanceledException, or - // ObjectDisposedException on abrupt client disconnect — swallow all to - // avoid masking the original exception. + // ObjectDisposedException on abrupt client disconnect — swallow expected + // transport/disconnect exceptions to avoid masking the original exception. } } } diff --git a/EssentialCSharp.Web/Extensions/IServiceCollectionExtensions.cs b/EssentialCSharp.Web/Extensions/IServiceCollectionExtensions.cs index 9ef4a7b1..74967040 100644 --- a/EssentialCSharp.Web/Extensions/IServiceCollectionExtensions.cs +++ b/EssentialCSharp.Web/Extensions/IServiceCollectionExtensions.cs @@ -1,4 +1,7 @@ -using EssentialCSharp.Web.Services; +using System.Net; +using System.Net.Sockets; +using EssentialCSharp.Web.Services; +using Microsoft.AspNetCore.HttpOverrides; namespace EssentialCSharp.Web.Extensions; @@ -13,4 +16,77 @@ public static void AddCaptchaService(this IServiceCollection services, IConfigur c.BaseAddress = new Uri("https://api.hcaptcha.com"); }); } + + public static void AddTrustedForwardedHeaders(this IServiceCollection services, IConfiguration configuration, IHostEnvironment environment) + { + services.Configure(options => + { + options.ForwardedHeaders = + ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto; + options.ForwardLimit = 1; + + var trustedProxyCidrs = configuration + .GetSection("ForwardedHeaders:TrustedProxyCidrs") + .Get() ?? []; + var trustedProxies = configuration + .GetSection("ForwardedHeaders:TrustedProxies") + .Get() ?? []; + + if (trustedProxyCidrs.Length == 0 && trustedProxies.Length == 0) + { + if (!environment.IsDevelopment()) + { + throw new InvalidOperationException( + "Forwarded headers are enabled but no trusted proxies are configured. " + + "Set ForwardedHeaders:TrustedProxyCidrs or ForwardedHeaders:TrustedProxies."); + } + return; + } + + options.KnownIPNetworks.Clear(); + options.KnownProxies.Clear(); + + foreach (var cidr in trustedProxyCidrs) + { + if (!TryParseCidr(cidr, out var network)) + throw new InvalidOperationException($"Invalid ForwardedHeaders:TrustedProxyCidrs entry '{cidr}'. Use CIDR notation, e.g. '10.0.0.0/8'."); + + options.KnownIPNetworks.Add(network); + } + + foreach (var proxy in trustedProxies) + { + if (!IPAddress.TryParse(proxy, out var proxyAddress)) + throw new InvalidOperationException($"Invalid ForwardedHeaders:TrustedProxies entry '{proxy}'. Use a valid IP address."); + + options.KnownProxies.Add(proxyAddress); + } + }); + } + + private static bool TryParseCidr(string cidr, out System.Net.IPNetwork network) + { + network = default!; + if (string.IsNullOrWhiteSpace(cidr)) + return false; + + string[] parts = cidr.Split('/', 2, StringSplitOptions.TrimEntries); + if (parts.Length != 2 + || !IPAddress.TryParse(parts[0], out var networkAddress) + || !int.TryParse(parts[1], out var prefixLength)) + return false; + + int maxPrefixLength = networkAddress.AddressFamily switch + { + AddressFamily.InterNetwork => 32, + AddressFamily.InterNetworkV6 => 128, + _ => -1 + }; + + if (maxPrefixLength < 0 || prefixLength < 0 || prefixLength > maxPrefixLength) + return false; + + network = new System.Net.IPNetwork(networkAddress, prefixLength); + return true; + } } diff --git a/EssentialCSharp.Web/Program.cs b/EssentialCSharp.Web/Program.cs index 78b3fef1..be4341b7 100644 --- a/EssentialCSharp.Web/Program.cs +++ b/EssentialCSharp.Web/Program.cs @@ -119,17 +119,7 @@ private static void Main(string[] args) - builder.Services.Configure(options => - { - options.ForwardedHeaders = - ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto; - - // Only loopback proxies are allowed by default. - // Clear that restriction because forwarders are enabled by explicit - // configuration. - options.KnownIPNetworks.Clear(); - options.KnownProxies.Clear(); - }); + builder.Services.AddTrustedForwardedHeaders(builder.Configuration, builder.Environment); ConfigurationManager configuration = builder.Configuration; string connectionString = builder.Configuration.GetConnectionString("EssentialCSharpWebContextConnection") ?? throw new InvalidOperationException("Connection string 'EssentialCSharpWebContextConnection' not found."); diff --git a/EssentialCSharp.Web/appsettings.json b/EssentialCSharp.Web/appsettings.json index 12e8530a..7d922fa7 100644 --- a/EssentialCSharp.Web/appsettings.json +++ b/EssentialCSharp.Web/appsettings.json @@ -9,6 +9,10 @@ } }, "AllowedHosts": "*", + "ForwardedHeaders": { + "TrustedProxyCidrs": [], + "TrustedProxies": [] + }, "HCaptcha": { "SecretKey": "0x0000000000000000000000000000000000000000", "SiteKey": "10000000-ffff-ffff-ffff-000000000001"