|
1 | 1 | using System.Net; |
| 2 | +using System.Security.Cryptography; |
2 | 3 | using System.Text; |
3 | 4 | using DotNetty.Buffers; |
4 | 5 | using DotNetty.Transport.Bootstrapping; |
@@ -63,10 +64,21 @@ internal sealed class MyHandler : ChannelHandlerAdapter |
63 | 64 | { |
64 | 65 | private readonly List<byte> _buffer = new(); |
65 | 66 | private TaskCompletionSource<byte[]>? _pendingTcs; |
| 67 | + private IChannelHandlerContext? _ctx; |
66 | 68 |
|
67 | | - public void SetPendingResponse(TaskCompletionSource<byte[]> tcs) |
68 | | - => _pendingTcs = tcs; |
| 69 | + public override void ChannelActive(IChannelHandlerContext ctx) { _ctx = ctx; base.ChannelActive(ctx); } |
69 | 70 |
|
| 71 | + public void SetPendingResponse(TaskCompletionSource<byte[]> tcs) |
| 72 | + { |
| 73 | + _pendingTcs = tcs; |
| 74 | + // Schedule TryDeliver on the event loop to avoid a race on _buffer. |
| 75 | + var ctx = _ctx; |
| 76 | + if (ctx != null) |
| 77 | + { |
| 78 | + if (ctx.Executor.InEventLoop) TryDeliver(); |
| 79 | + else ctx.Executor.Execute(TryDeliver); |
| 80 | + } |
| 81 | + } |
70 | 82 | public override void ChannelRead(IChannelHandlerContext ctx, object msg) |
71 | 83 | { |
72 | 84 | if (msg is not IByteBuffer incoming) |
@@ -189,6 +201,106 @@ public Task<SqlDataTable> QueryTableAsync( |
189 | 201 | .ContinueWith(t => SqlDataTable.From(sql, t.Result), |
190 | 202 | TaskContinuationOptions.OnlyOnRanToCompletion); |
191 | 203 |
|
| 204 | + // ── Streaming (IAsyncEnumerable) ────────────────────────────────────────── |
| 205 | + |
| 206 | + /// <summary> |
| 207 | + /// Execute a query and stream rows one at a time as they arrive from the server, |
| 208 | + /// without buffering the full result set in memory. |
| 209 | + /// </summary> |
| 210 | + public async IAsyncEnumerable<SqlRow> QueryStreamAsync( |
| 211 | + string sql, |
| 212 | + IReadOnlyList<SqlParameter>? parameters = null, |
| 213 | + [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default) |
| 214 | + { |
| 215 | + // Parameterized path: use prepared statement (buffers internally — MySQL binary protocol |
| 216 | + // doesn't expose individual row packets to us here without a larger refactor). |
| 217 | + if (parameters is { Count: > 0 }) |
| 218 | + { |
| 219 | + var buffered = await QueryAsync(sql, parameters, ct).ConfigureAwait(false); |
| 220 | + foreach (var row in buffered) yield return row; |
| 221 | + yield break; |
| 222 | + } |
| 223 | + |
| 224 | + await _lock.WaitAsync(ct).ConfigureAwait(false); |
| 225 | + try |
| 226 | + { |
| 227 | + _seqId = 0; |
| 228 | + await WritePacketAsync(MyQueryMessage.Build(sql)).ConfigureAwait(false); |
| 229 | + |
| 230 | + // Read result-set header (column count + column definitions). |
| 231 | + var firstPayload = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 232 | + if (firstPayload[0] == 0x00) yield break; // OK (DDL) |
| 233 | + if (firstPayload[0] == 0xFF) |
| 234 | + { |
| 235 | + var e = MyDecoder.ParseErr(firstPayload); |
| 236 | + throw SqlException.Query($"MySQL error [{e.ErrorCode}]: {e.ErrorMessage}"); |
| 237 | + } |
| 238 | + |
| 239 | + int tmp = 0; |
| 240 | + long colCount = MyDecoder.ReadLenEnc(firstPayload, out tmp); |
| 241 | + var columns = new List<MyColumnDef>((int)colCount); |
| 242 | + for (int i = 0; i < colCount; i++) |
| 243 | + { |
| 244 | + var colPayload = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 245 | + columns.Add(MyDecoder.ParseColumnDef(colPayload)); |
| 246 | + } |
| 247 | + var sqlColumns = columns.Select(c => new SqlColumn(c.Name, c.FieldType.ToString())).ToList(); |
| 248 | + |
| 249 | + // Stream rows. |
| 250 | + while (true) |
| 251 | + { |
| 252 | + var rowPayload = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 253 | + if (rowPayload[0] == 0xFE && rowPayload.Length < 9) yield break; // EOF |
| 254 | + if (rowPayload[0] == 0xFF) |
| 255 | + { |
| 256 | + var e = MyDecoder.ParseErr(rowPayload); |
| 257 | + throw SqlException.Query($"MySQL error [{e.ErrorCode}]: {e.ErrorMessage}"); |
| 258 | + } |
| 259 | + if (rowPayload[0] == 0x00 && rowPayload.Length <= 7) yield break; // OK (deprecate EOF) |
| 260 | + |
| 261 | + var values = ParseTextRow(rowPayload, columns); |
| 262 | + yield return new SqlRow(sqlColumns, values); |
| 263 | + } |
| 264 | + } |
| 265 | + finally { _lock.Release(); } |
| 266 | + } |
| 267 | + |
| 268 | + /// <summary> |
| 269 | + /// Execute a query where each row contains a complete JSON object in |
| 270 | + /// <paramref name="jsonColumnIndex"/> (e.g. <c>SELECT JSON_OBJECT(…) FROM …</c>) |
| 271 | + /// and stream one <see cref="System.Text.Json.JsonElement"/> per row. |
| 272 | + /// </summary> |
| 273 | + /// <remarks> |
| 274 | + /// Unlike SQL Server's <c>FOR JSON PATH</c>, MySQL's <c>JSON_OBJECT()</c> returns |
| 275 | + /// one <b>complete</b> JSON object per row — no chunk reassembly is needed. Each row |
| 276 | + /// is parsed independently, so memory usage is bounded to the largest single object. |
| 277 | + /// </remarks> |
| 278 | + /// <example> |
| 279 | + /// <code> |
| 280 | + /// await foreach (var elem in conn.QueryJsonStreamAsync( |
| 281 | + /// "SELECT JSON_OBJECT('Id', Id, 'Name', Name) FROM Products")) |
| 282 | + /// { |
| 283 | + /// var id = elem.GetProperty("Id").GetInt32(); |
| 284 | + /// } |
| 285 | + /// </code> |
| 286 | + /// </example> |
| 287 | + public async IAsyncEnumerable<System.Text.Json.JsonElement> QueryJsonStreamAsync( |
| 288 | + string sql, |
| 289 | + IReadOnlyList<SqlParameter>? parameters = null, |
| 290 | + int jsonColumnIndex = 0, |
| 291 | + [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default) |
| 292 | + { |
| 293 | + await foreach (var row in QueryStreamAsync(sql, parameters, ct).ConfigureAwait(false)) |
| 294 | + { |
| 295 | + string? json = row[jsonColumnIndex] is SqlValue.Bytes bytes |
| 296 | + ? System.Text.Encoding.UTF8.GetString(bytes.Value) |
| 297 | + : row[jsonColumnIndex].AsString(); |
| 298 | + if (string.IsNullOrEmpty(json)) continue; |
| 299 | + using var doc = System.Text.Json.JsonDocument.Parse(json); |
| 300 | + yield return doc.RootElement.Clone(); |
| 301 | + } |
| 302 | + } |
| 303 | + |
192 | 304 | public async Task BeginTransactionAsync(CancellationToken ct = default) |
193 | 305 | { |
194 | 306 | await _lock.WaitAsync(ct).ConfigureAwait(false); |
@@ -310,7 +422,74 @@ private async Task ConnectAsync(CancellationToken ct) |
310 | 422 | var err = MyDecoder.ParseErr(responseRaw); |
311 | 423 | throw SqlException.Auth($"MySQL authentication error [{err.ErrorCode}]: {err.ErrorMessage}"); |
312 | 424 | } |
313 | | - if (responseRaw[0] != 0x00) |
| 425 | + if (responseRaw[0] == 0xFE) |
| 426 | + { |
| 427 | + // AuthSwitchRequest: server wants us to use a different auth plugin |
| 428 | + // Packet: 0xFE | plugin_name\0 | new_auth_data |
| 429 | + int nameEnd = Array.IndexOf(responseRaw, (byte)0, 1); |
| 430 | + if (nameEnd < 0) nameEnd = responseRaw.Length; |
| 431 | + var switchPlugin = Encoding.UTF8.GetString(responseRaw, 1, nameEnd - 1); |
| 432 | + var switchData = nameEnd + 1 < responseRaw.Length |
| 433 | + ? responseRaw[(nameEnd + 1)..] |
| 434 | + : Array.Empty<byte>(); |
| 435 | + |
| 436 | + byte[] switchResponse = switchPlugin switch |
| 437 | + { |
| 438 | + "caching_sha2_password" => MyHandshakeResponse41.ComputeCachingSha2Auth(_config.Password, switchData), |
| 439 | + _ => MyHandshakeResponse41.ComputeNativePasswordAuth(_config.Password, switchData), |
| 440 | + }; |
| 441 | + |
| 442 | + // seqId continues from the packet sequence (3rd packet sent = seqId 3) |
| 443 | + await WritePacketAsync(switchResponse, seqId: 3).ConfigureAwait(false); |
| 444 | + |
| 445 | + // Read response: may be fast-auth OK (0x01 0x03), full-auth required (0x01 0x04), or ERR |
| 446 | + var finalRaw = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 447 | + if (finalRaw[0] == 0xFF) |
| 448 | + { |
| 449 | + var err = MyDecoder.ParseErr(finalRaw); |
| 450 | + throw SqlException.Auth($"MySQL auth switch error [{err.ErrorCode}]: {err.ErrorMessage}"); |
| 451 | + } |
| 452 | + // 0x01 0x03 = fast auth success — read following OK |
| 453 | + if (finalRaw[0] == 0x01 && finalRaw.Length > 1 && finalRaw[1] == 0x03) |
| 454 | + { |
| 455 | + finalRaw = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 456 | + if (finalRaw[0] == 0xFF) |
| 457 | + { |
| 458 | + var err2 = MyDecoder.ParseErr(finalRaw); |
| 459 | + throw SqlException.Auth($"MySQL auth fast-auth OK error [{err2.ErrorCode}]: {err2.ErrorMessage}"); |
| 460 | + } |
| 461 | + } |
| 462 | + // 0x01 0x04 = full auth required — exchange RSA public key and encrypt password |
| 463 | + else if (finalRaw[0] == 0x01 && finalRaw.Length > 1 && finalRaw[1] == 0x04) |
| 464 | + { |
| 465 | + // Request server's RSA public key |
| 466 | + await WritePacketAsync(new byte[] { 0x02 }, seqId: 5).ConfigureAwait(false); |
| 467 | + var pkPacket = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 468 | + // pkPacket[0] == 0x01, rest is PEM public key |
| 469 | + int pemStart = pkPacket[0] == 0x01 ? 1 : 0; |
| 470 | + var pem = Encoding.UTF8.GetString(pkPacket, pemStart, pkPacket.Length - pemStart).Trim(); |
| 471 | + // XOR password bytes (null-terminated) with nonce (cyclic) |
| 472 | + var pwBytes = Encoding.UTF8.GetBytes(_config.Password + "\0"); |
| 473 | + var nonce = switchData; |
| 474 | + var obfuscated = new byte[pwBytes.Length]; |
| 475 | + for (int i = 0; i < pwBytes.Length; i++) |
| 476 | + obfuscated[i] = (byte)(pwBytes[i] ^ nonce[i % nonce.Length]); |
| 477 | + // RSA-OAEP-SHA1 encrypt |
| 478 | + using var rsa = RSA.Create(); |
| 479 | + rsa.ImportFromPem(pem); |
| 480 | + var encrypted = rsa.Encrypt(obfuscated, RSAEncryptionPadding.OaepSHA1); |
| 481 | + await WritePacketAsync(encrypted, seqId: 7).ConfigureAwait(false); |
| 482 | + finalRaw = await ReceivePayloadAsync(ct).ConfigureAwait(false); |
| 483 | + if (finalRaw[0] == 0xFF) |
| 484 | + { |
| 485 | + var err3 = MyDecoder.ParseErr(finalRaw); |
| 486 | + throw SqlException.Auth($"MySQL RSA auth error [{err3.ErrorCode}]: {err3.ErrorMessage}"); |
| 487 | + } |
| 488 | + } |
| 489 | + if (finalRaw[0] != 0x00) |
| 490 | + throw SqlException.Auth($"MySQL: unexpected post-switch response byte 0x{finalRaw[0]:X2}."); |
| 491 | + } |
| 492 | + else if (responseRaw[0] != 0x00) |
314 | 493 | { |
315 | 494 | // Could be auth switch request or other — for simplicity, fail |
316 | 495 | throw SqlException.Auth($"MySQL: unexpected authentication response byte 0x{responseRaw[0]:X2}."); |
|
0 commit comments