Skip to content

Commit e06a8f8

Browse files
committed
Harden MSSQL pool
1 parent 4dc5f68 commit e06a8f8

4 files changed

Lines changed: 117 additions & 25 deletions

File tree

Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
<Authors>vkuttyp</Authors>
88
<PackageLicenseExpression>MIT</PackageLicenseExpression>
99
<RepositoryUrl>https://github.com/vkuttyp/CosmoSQLClient-Dotnet</RepositoryUrl>
10-
<Version>1.9.42</Version>
10+
<Version>1.9.43</Version>
1111
</PropertyGroup>
1212
</Project>
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
using System.Runtime.CompilerServices;
2+
3+
[assembly: InternalsVisibleTo("CosmoSQLClient.MsSql.Tests")]

src/CosmoSQLClient.MsSql/MsSqlConnectionPool.cs

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ public sealed class MsSqlConnectionPool : ISqlDatabase
2626
private readonly int _maxConnections;
2727
private readonly int _minIdle;
2828
private readonly Channel<MsSqlConnection> _idle;
29+
private readonly Func<MsSqlConnection, CancellationToken, ValueTask<bool>> _validateConnection;
30+
private readonly Func<CancellationToken, Task<MsSqlConnection>> _connectionFactory;
2931
private int _count;
3032
private bool _disposed;
3133
private CancellationTokenSource? _keepAliveCts;
@@ -50,7 +52,9 @@ public sealed class MsSqlConnectionPool : ISqlDatabase
5052
public MsSqlConnectionPool(
5153
MsSqlConfiguration config,
5254
int maxConnections = 10,
53-
int minIdle = 0)
55+
int minIdle = 0,
56+
Func<MsSqlConnection, CancellationToken, ValueTask<bool>>? validateConnection = null,
57+
Func<CancellationToken, Task<MsSqlConnection>>? connectionFactory = null)
5458
{
5559
_config = config;
5660
_maxConnections = Math.Max(1, maxConnections);
@@ -88,6 +92,9 @@ public MsSqlConnectionPool(
8892
_sslOptions = new object(); // Placeholder for netstandard2.0, ConnectAsync handles it
8993
#endif
9094
}
95+
96+
_validateConnection = validateConnection ?? DefaultValidateConnectionAsync;
97+
_connectionFactory = connectionFactory ?? OpenConnectionAsync;
9198
}
9299

93100
public bool IsOpen => !_disposed;
@@ -101,7 +108,7 @@ public MsSqlConnectionPool(
101108
public async Task WarmUpAsync(int? count = null, CancellationToken ct = default)
102109
{
103110
int n = Math.Min(count ?? _minIdle, _maxConnections);
104-
var tasks = Enumerable.Range(0, n).Select(_ => OpenConnectionAsync(ct));
111+
var tasks = Enumerable.Range(0, n).Select(_ => _connectionFactory(ct));
105112
var conns = await Task.WhenAll(tasks).ConfigureAwait(false);
106113
foreach (var c in conns)
107114
await _idle.Writer.WriteAsync(c, ct).ConfigureAwait(false);
@@ -192,35 +199,72 @@ private async Task PingIdleAsync(CancellationToken ct)
192199

193200
// ── Pool acquire/release ───────────────────────────────────────────────────
194201

195-
public async Task<MsSqlConnection> AcquireAsync(CancellationToken ct = default)
196-
{
197-
// Try an idle connection first (non-blocking).
198-
if (_idle.Reader.TryRead(out var conn))
199-
return conn;
202+
public async Task<MsSqlConnection> AcquireAsync(CancellationToken ct = default)
203+
{
204+
while (true)
205+
{
206+
if (_idle.Reader.TryRead(out var idleConn))
207+
{
208+
if (await _validateConnection(idleConn, ct).ConfigureAwait(false))
209+
return idleConn;
210+
211+
await idleConn.DisposeAsync().ConfigureAwait(false);
212+
Interlocked.Decrement(ref _count);
213+
continue;
214+
}
215+
216+
int current = Interlocked.Increment(ref _count);
217+
if (current <= _maxConnections)
218+
{
219+
try { return await _connectionFactory(ct).ConfigureAwait(false); }
220+
catch { Interlocked.Decrement(ref _count); throw; }
221+
}
200222

201-
// If below capacity, create a new connection.
202-
int current = Interlocked.Increment(ref _count);
203-
if (current <= _maxConnections)
223+
Interlocked.Decrement(ref _count);
224+
var pooledConn = await _idle.Reader.ReadAsync(ct).ConfigureAwait(false);
225+
if (await _validateConnection(pooledConn, ct).ConfigureAwait(false))
226+
return pooledConn;
227+
228+
await pooledConn.DisposeAsync().ConfigureAwait(false);
229+
Interlocked.Decrement(ref _count);
230+
}
231+
}
232+
233+
internal async Task InjectIdleConnectionAsync(MsSqlConnection conn)
204234
{
205-
try { return await OpenConnectionAsync(ct).ConfigureAwait(false); }
206-
catch { Interlocked.Decrement(ref _count); throw; }
235+
await _idle.Writer.WriteAsync(conn).ConfigureAwait(false);
236+
Interlocked.Increment(ref _count);
207237
}
208238

209-
// Already at max — wait for an idle one.
210-
Interlocked.Decrement(ref _count);
211-
return await _idle.Reader.ReadAsync(ct).ConfigureAwait(false);
212-
}
239+
public async Task ReleaseAsync(MsSqlConnection conn)
240+
{
241+
if (conn.IsOpen && !_disposed)
242+
await _idle.Writer.WriteAsync(conn).ConfigureAwait(false);
243+
else
244+
{
245+
await conn.DisposeAsync().ConfigureAwait(false);
246+
Interlocked.Decrement(ref _count);
247+
}
248+
}
213249

214-
public async Task ReleaseAsync(MsSqlConnection conn)
215-
{
216-
if (conn.IsOpen && !_disposed)
217-
await _idle.Writer.WriteAsync(conn).ConfigureAwait(false);
218-
else
250+
private async ValueTask<bool> DefaultValidateConnectionAsync(MsSqlConnection conn, CancellationToken ct)
219251
{
220-
await conn.DisposeAsync().ConfigureAwait(false);
221-
Interlocked.Decrement(ref _count);
252+
if (!conn.IsOpen)
253+
return false;
254+
try
255+
{
256+
await conn.QueryAsync("SELECT 1", ct: ct).ConfigureAwait(false);
257+
return true;
258+
}
259+
catch (OperationCanceledException)
260+
{
261+
throw;
262+
}
263+
catch
264+
{
265+
return false;
266+
}
222267
}
223-
}
224268

225269
// ── ISqlDatabase ───────────────────────────────────────────────────────────
226270

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using System.Threading;
2+
using System.Threading.Tasks;
3+
using CosmoSQLClient.MsSql;
4+
using Xunit;
5+
6+
namespace CosmoSQLClient.MsSql.Tests;
7+
8+
public sealed class MsSqlConnectionPoolTests
9+
{
10+
[Fact]
11+
public async Task AcquireAsync_RecreatesConnectionWhenValidationFails()
12+
{
13+
var config = new MsSqlConfiguration
14+
{
15+
Host = "localhost",
16+
Database = "master",
17+
TrustServerCertificate = true
18+
};
19+
20+
int validatorCalls = 0;
21+
int factoryCalls = 0;
22+
23+
var pool = new MsSqlConnectionPool(
24+
config,
25+
maxConnections: 1,
26+
minIdle: 0,
27+
validateConnection: (conn, ct) =>
28+
{
29+
Interlocked.Increment(ref validatorCalls);
30+
return new ValueTask<bool>(false);
31+
},
32+
connectionFactory: ct =>
33+
{
34+
Interlocked.Increment(ref factoryCalls);
35+
return Task.FromResult(new MsSqlConnection("Server=localhost;Database=master;User Id=sa;Password=pass;Pooling=false;"));
36+
});
37+
38+
await pool.InjectIdleConnectionAsync(new MsSqlConnection("Server=localhost;Database=master;User Id=sa;Password=pass;Pooling=false;"));
39+
var connection = await pool.AcquireAsync();
40+
await pool.ReleaseAsync(connection);
41+
42+
Assert.Equal(1, factoryCalls);
43+
Assert.Equal(1, validatorCalls);
44+
}
45+
}

0 commit comments

Comments
 (0)