Skip to content

Commit 08cf34f

Browse files
vkuttypCopilot
andcommitted
Fix handler race condition and MySQL caching_sha2_password auth
- Fix PgHandler and MyHandler race: schedule TryDeliver on the event loop thread when called from SetPendingResponse (fixes buffer corruption at high row counts with localhost latency) - Fix MySQL caching_sha2_password fast-auth algorithm (H1 XOR H3 where H3 = SHA256(H2 || nonce), was incorrectly including password bytes) - Implement MySQL caching_sha2_password full-auth (RSA OAEP-SHA1 path for cache-miss first connection, including 0x01 0x03 fast-auth-OK and 0x01 0x04 full-auth-required handling) - Fix MySQL 8.4 compatibility: mysql_native_password removed, now use caching_sha2_password with proper auth switch handling - Fix QueryJsonStreamAsync for MySQL: JSON_OBJECT() returns Blob type, decode as UTF-8 string instead of calling AsString() (which base64s) - Fix JSON_OBJECT() SQL syntax: use comma syntax ('k', v) not VALUE kw - Postgres: 6/6 tests pass; MySQL: 5/5 tests pass Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 2baba33 commit 08cf34f

7 files changed

Lines changed: 481 additions & 21 deletions

File tree

src/SqlDotnetty.MySql/MySqlConnection.cs

Lines changed: 182 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Net;
2+
using System.Security.Cryptography;
23
using System.Text;
34
using DotNetty.Buffers;
45
using DotNetty.Transport.Bootstrapping;
@@ -63,10 +64,21 @@ internal sealed class MyHandler : ChannelHandlerAdapter
6364
{
6465
private readonly List<byte> _buffer = new();
6566
private TaskCompletionSource<byte[]>? _pendingTcs;
67+
private IChannelHandlerContext? _ctx;
6668

67-
public void SetPendingResponse(TaskCompletionSource<byte[]> tcs)
68-
=> _pendingTcs = tcs;
69+
public override void ChannelActive(IChannelHandlerContext ctx) { _ctx = ctx; base.ChannelActive(ctx); }
6970

71+
public void SetPendingResponse(TaskCompletionSource<byte[]> tcs)
72+
{
73+
_pendingTcs = tcs;
74+
// Schedule TryDeliver on the event loop to avoid a race on _buffer.
75+
var ctx = _ctx;
76+
if (ctx != null)
77+
{
78+
if (ctx.Executor.InEventLoop) TryDeliver();
79+
else ctx.Executor.Execute(TryDeliver);
80+
}
81+
}
7082
public override void ChannelRead(IChannelHandlerContext ctx, object msg)
7183
{
7284
if (msg is not IByteBuffer incoming)
@@ -189,6 +201,106 @@ public Task<SqlDataTable> QueryTableAsync(
189201
.ContinueWith(t => SqlDataTable.From(sql, t.Result),
190202
TaskContinuationOptions.OnlyOnRanToCompletion);
191203

204+
// ── Streaming (IAsyncEnumerable) ──────────────────────────────────────────
205+
206+
/// <summary>
207+
/// Execute a query and stream rows one at a time as they arrive from the server,
208+
/// without buffering the full result set in memory.
209+
/// </summary>
210+
public async IAsyncEnumerable<SqlRow> QueryStreamAsync(
211+
string sql,
212+
IReadOnlyList<SqlParameter>? parameters = null,
213+
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
214+
{
215+
// Parameterized path: use prepared statement (buffers internally — MySQL binary protocol
216+
// doesn't expose individual row packets to us here without a larger refactor).
217+
if (parameters is { Count: > 0 })
218+
{
219+
var buffered = await QueryAsync(sql, parameters, ct).ConfigureAwait(false);
220+
foreach (var row in buffered) yield return row;
221+
yield break;
222+
}
223+
224+
await _lock.WaitAsync(ct).ConfigureAwait(false);
225+
try
226+
{
227+
_seqId = 0;
228+
await WritePacketAsync(MyQueryMessage.Build(sql)).ConfigureAwait(false);
229+
230+
// Read result-set header (column count + column definitions).
231+
var firstPayload = await ReceivePayloadAsync(ct).ConfigureAwait(false);
232+
if (firstPayload[0] == 0x00) yield break; // OK (DDL)
233+
if (firstPayload[0] == 0xFF)
234+
{
235+
var e = MyDecoder.ParseErr(firstPayload);
236+
throw SqlException.Query($"MySQL error [{e.ErrorCode}]: {e.ErrorMessage}");
237+
}
238+
239+
int tmp = 0;
240+
long colCount = MyDecoder.ReadLenEnc(firstPayload, out tmp);
241+
var columns = new List<MyColumnDef>((int)colCount);
242+
for (int i = 0; i < colCount; i++)
243+
{
244+
var colPayload = await ReceivePayloadAsync(ct).ConfigureAwait(false);
245+
columns.Add(MyDecoder.ParseColumnDef(colPayload));
246+
}
247+
var sqlColumns = columns.Select(c => new SqlColumn(c.Name, c.FieldType.ToString())).ToList();
248+
249+
// Stream rows.
250+
while (true)
251+
{
252+
var rowPayload = await ReceivePayloadAsync(ct).ConfigureAwait(false);
253+
if (rowPayload[0] == 0xFE && rowPayload.Length < 9) yield break; // EOF
254+
if (rowPayload[0] == 0xFF)
255+
{
256+
var e = MyDecoder.ParseErr(rowPayload);
257+
throw SqlException.Query($"MySQL error [{e.ErrorCode}]: {e.ErrorMessage}");
258+
}
259+
if (rowPayload[0] == 0x00 && rowPayload.Length <= 7) yield break; // OK (deprecate EOF)
260+
261+
var values = ParseTextRow(rowPayload, columns);
262+
yield return new SqlRow(sqlColumns, values);
263+
}
264+
}
265+
finally { _lock.Release(); }
266+
}
267+
268+
/// <summary>
269+
/// Execute a query where each row contains a complete JSON object in
270+
/// <paramref name="jsonColumnIndex"/> (e.g. <c>SELECT JSON_OBJECT(…) FROM …</c>)
271+
/// and stream one <see cref="System.Text.Json.JsonElement"/> per row.
272+
/// </summary>
273+
/// <remarks>
274+
/// Unlike SQL Server's <c>FOR JSON PATH</c>, MySQL's <c>JSON_OBJECT()</c> returns
275+
/// one <b>complete</b> JSON object per row — no chunk reassembly is needed. Each row
276+
/// is parsed independently, so memory usage is bounded to the largest single object.
277+
/// </remarks>
278+
/// <example>
279+
/// <code>
280+
/// await foreach (var elem in conn.QueryJsonStreamAsync(
281+
/// "SELECT JSON_OBJECT('Id', Id, 'Name', Name) FROM Products"))
282+
/// {
283+
/// var id = elem.GetProperty("Id").GetInt32();
284+
/// }
285+
/// </code>
286+
/// </example>
287+
public async IAsyncEnumerable<System.Text.Json.JsonElement> QueryJsonStreamAsync(
288+
string sql,
289+
IReadOnlyList<SqlParameter>? parameters = null,
290+
int jsonColumnIndex = 0,
291+
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
292+
{
293+
await foreach (var row in QueryStreamAsync(sql, parameters, ct).ConfigureAwait(false))
294+
{
295+
string? json = row[jsonColumnIndex] is SqlValue.Bytes bytes
296+
? System.Text.Encoding.UTF8.GetString(bytes.Value)
297+
: row[jsonColumnIndex].AsString();
298+
if (string.IsNullOrEmpty(json)) continue;
299+
using var doc = System.Text.Json.JsonDocument.Parse(json);
300+
yield return doc.RootElement.Clone();
301+
}
302+
}
303+
192304
public async Task BeginTransactionAsync(CancellationToken ct = default)
193305
{
194306
await _lock.WaitAsync(ct).ConfigureAwait(false);
@@ -310,7 +422,74 @@ private async Task ConnectAsync(CancellationToken ct)
310422
var err = MyDecoder.ParseErr(responseRaw);
311423
throw SqlException.Auth($"MySQL authentication error [{err.ErrorCode}]: {err.ErrorMessage}");
312424
}
313-
if (responseRaw[0] != 0x00)
425+
if (responseRaw[0] == 0xFE)
426+
{
427+
// AuthSwitchRequest: server wants us to use a different auth plugin
428+
// Packet: 0xFE | plugin_name\0 | new_auth_data
429+
int nameEnd = Array.IndexOf(responseRaw, (byte)0, 1);
430+
if (nameEnd < 0) nameEnd = responseRaw.Length;
431+
var switchPlugin = Encoding.UTF8.GetString(responseRaw, 1, nameEnd - 1);
432+
var switchData = nameEnd + 1 < responseRaw.Length
433+
? responseRaw[(nameEnd + 1)..]
434+
: Array.Empty<byte>();
435+
436+
byte[] switchResponse = switchPlugin switch
437+
{
438+
"caching_sha2_password" => MyHandshakeResponse41.ComputeCachingSha2Auth(_config.Password, switchData),
439+
_ => MyHandshakeResponse41.ComputeNativePasswordAuth(_config.Password, switchData),
440+
};
441+
442+
// seqId continues from the packet sequence (3rd packet sent = seqId 3)
443+
await WritePacketAsync(switchResponse, seqId: 3).ConfigureAwait(false);
444+
445+
// Read response: may be fast-auth OK (0x01 0x03), full-auth required (0x01 0x04), or ERR
446+
var finalRaw = await ReceivePayloadAsync(ct).ConfigureAwait(false);
447+
if (finalRaw[0] == 0xFF)
448+
{
449+
var err = MyDecoder.ParseErr(finalRaw);
450+
throw SqlException.Auth($"MySQL auth switch error [{err.ErrorCode}]: {err.ErrorMessage}");
451+
}
452+
// 0x01 0x03 = fast auth success — read following OK
453+
if (finalRaw[0] == 0x01 && finalRaw.Length > 1 && finalRaw[1] == 0x03)
454+
{
455+
finalRaw = await ReceivePayloadAsync(ct).ConfigureAwait(false);
456+
if (finalRaw[0] == 0xFF)
457+
{
458+
var err2 = MyDecoder.ParseErr(finalRaw);
459+
throw SqlException.Auth($"MySQL auth fast-auth OK error [{err2.ErrorCode}]: {err2.ErrorMessage}");
460+
}
461+
}
462+
// 0x01 0x04 = full auth required — exchange RSA public key and encrypt password
463+
else if (finalRaw[0] == 0x01 && finalRaw.Length > 1 && finalRaw[1] == 0x04)
464+
{
465+
// Request server's RSA public key
466+
await WritePacketAsync(new byte[] { 0x02 }, seqId: 5).ConfigureAwait(false);
467+
var pkPacket = await ReceivePayloadAsync(ct).ConfigureAwait(false);
468+
// pkPacket[0] == 0x01, rest is PEM public key
469+
int pemStart = pkPacket[0] == 0x01 ? 1 : 0;
470+
var pem = Encoding.UTF8.GetString(pkPacket, pemStart, pkPacket.Length - pemStart).Trim();
471+
// XOR password bytes (null-terminated) with nonce (cyclic)
472+
var pwBytes = Encoding.UTF8.GetBytes(_config.Password + "\0");
473+
var nonce = switchData;
474+
var obfuscated = new byte[pwBytes.Length];
475+
for (int i = 0; i < pwBytes.Length; i++)
476+
obfuscated[i] = (byte)(pwBytes[i] ^ nonce[i % nonce.Length]);
477+
// RSA-OAEP-SHA1 encrypt
478+
using var rsa = RSA.Create();
479+
rsa.ImportFromPem(pem);
480+
var encrypted = rsa.Encrypt(obfuscated, RSAEncryptionPadding.OaepSHA1);
481+
await WritePacketAsync(encrypted, seqId: 7).ConfigureAwait(false);
482+
finalRaw = await ReceivePayloadAsync(ct).ConfigureAwait(false);
483+
if (finalRaw[0] == 0xFF)
484+
{
485+
var err3 = MyDecoder.ParseErr(finalRaw);
486+
throw SqlException.Auth($"MySQL RSA auth error [{err3.ErrorCode}]: {err3.ErrorMessage}");
487+
}
488+
}
489+
if (finalRaw[0] != 0x00)
490+
throw SqlException.Auth($"MySQL: unexpected post-switch response byte 0x{finalRaw[0]:X2}.");
491+
}
492+
else if (responseRaw[0] != 0x00)
314493
{
315494
// Could be auth switch request or other — for simplicity, fail
316495
throw SqlException.Auth($"MySQL: unexpected authentication response byte 0x{responseRaw[0]:X2}.");

src/SqlDotnetty.MySql/MySqlConnectionPool.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,35 @@ public async Task RollbackAsync(CancellationToken ct = default)
122122
finally { await ReleaseAsync(conn).ConfigureAwait(false); }
123123
}
124124

125+
public async IAsyncEnumerable<SqlRow> QueryStreamAsync(
126+
string sql,
127+
IReadOnlyList<SqlParameter>? parameters = null,
128+
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
129+
{
130+
var conn = await AcquireAsync(ct).ConfigureAwait(false);
131+
try
132+
{
133+
await foreach (var row in conn.QueryStreamAsync(sql, parameters, ct).ConfigureAwait(false))
134+
yield return row;
135+
}
136+
finally { await ReleaseAsync(conn).ConfigureAwait(false); }
137+
}
138+
139+
public async IAsyncEnumerable<System.Text.Json.JsonElement> QueryJsonStreamAsync(
140+
string sql,
141+
IReadOnlyList<SqlParameter>? parameters = null,
142+
int jsonColumnIndex = 0,
143+
[System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
144+
{
145+
var conn = await AcquireAsync(ct).ConfigureAwait(false);
146+
try
147+
{
148+
await foreach (var elem in conn.QueryJsonStreamAsync(sql, parameters, jsonColumnIndex, ct).ConfigureAwait(false))
149+
yield return elem;
150+
}
151+
finally { await ReleaseAsync(conn).ConfigureAwait(false); }
152+
}
153+
125154
public Task CloseAsync() => DisposeAsync().AsTask();
126155

127156
public async ValueTask DisposeAsync()

src/SqlDotnetty.MySql/Proto/MyMessage.cs

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public static byte[] Build(
124124
return ms.ToArray();
125125
}
126126

127-
private static byte[] ComputeNativePasswordAuth(string password, byte[] scramble)
127+
internal static byte[] ComputeNativePasswordAuth(string password, byte[] scramble)
128128
{
129129
if (string.IsNullOrEmpty(password)) return Array.Empty<byte>();
130130

@@ -144,25 +144,24 @@ private static byte[] ComputeNativePasswordAuth(string password, byte[] scramble
144144
return result;
145145
}
146146

147-
private static byte[] ComputeCachingSha2Auth(string password, byte[] scramble)
147+
internal static byte[] ComputeCachingSha2Auth(string password, byte[] scramble)
148148
{
149149
if (string.IsNullOrEmpty(password)) return Array.Empty<byte>();
150150

151151
var pwBytes = Encoding.UTF8.GetBytes(password);
152-
// SHA256(password)
153-
var sha256pw = SHA256.HashData(pwBytes);
154-
// SHA256(SHA256(password))
155-
var sha256pw2 = SHA256.HashData(sha256pw);
156-
// SHA256(password + scramble + SHA256(SHA256(password)))
157-
var combined = new byte[pwBytes.Length + scramble.Length + sha256pw2.Length];
158-
pwBytes.CopyTo(combined, 0);
159-
scramble.CopyTo(combined, pwBytes.Length);
160-
sha256pw2.CopyTo(combined, pwBytes.Length + scramble.Length);
161-
var hash = SHA256.HashData(combined);
162-
// XOR with SHA256(password)
163-
var result = new byte[hash.Length];
164-
for (int i = 0; i < hash.Length; i++)
165-
result[i] = (byte)(hash[i] ^ sha256pw[i]);
152+
// H1 = SHA256(password)
153+
var h1 = SHA256.HashData(pwBytes);
154+
// H2 = SHA256(SHA256(password))
155+
var h2 = SHA256.HashData(h1);
156+
// H3 = SHA256(H2 || nonce)
157+
var combined = new byte[h2.Length + scramble.Length];
158+
h2.CopyTo(combined, 0);
159+
scramble.CopyTo(combined, h2.Length);
160+
var h3 = SHA256.HashData(combined);
161+
// token = H1 XOR H3
162+
var result = new byte[h1.Length];
163+
for (int i = 0; i < h1.Length; i++)
164+
result[i] = (byte)(h1[i] ^ h3[i]);
166165
return result;
167166
}
168167
}

0 commit comments

Comments
 (0)