Skip to content

Commit 3cdccab

Browse files
committed
Handle if the last data in async path.
1 parent 868c9ae commit 3cdccab

1 file changed

Lines changed: 139 additions & 59 deletions

File tree

src/DotNetty.Handlers/Tls/TlsHandler.cs

Lines changed: 139 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
2727

2828
static readonly Exception ChannelClosedException = new IOException("Channel is closed");
2929
static readonly Action<Task, object> HandshakeCompletionCallback = new Action<Task, object>(HandleHandshakeCompleted);
30+
static readonly Action<Task<int>, object> UnwrapCompletedCallback = new Action<Task<int>, object>(UnwrapCompleted);
3031

3132
readonly SslStream sslStream;
3233
readonly MediationStream mediationStream;
@@ -40,6 +41,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4041
bool firedChannelRead;
4142
volatile FlushMode flushMode = FlushMode.ForceFlush;
4243
IByteBuffer pendingSslStreamReadBuffer;
44+
int pendingSslStreamReadLength;
4345
Task<int> pendingSslStreamReadFuture;
4446

4547
public TlsHandler(TlsSettings settings)
@@ -342,10 +344,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
342344
Contract.Assert(this.pendingSslStreamReadBuffer != null);
343345

344346
outputBuffer = this.pendingSslStreamReadBuffer;
345-
outputBufferLength = outputBuffer.WritableBytes;
347+
outputBufferLength = this.pendingSslStreamReadLength;
346348

347349
this.pendingSslStreamReadFuture = null;
348350
this.pendingSslStreamReadBuffer = null;
351+
this.pendingSslStreamReadLength = 0;
349352
}
350353
else
351354
{
@@ -358,86 +361,78 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
358361
int currentPacketLength = packetLengths[packetIndex];
359362
this.mediationStream.ExpandSource(currentPacketLength);
360363

361-
if (currentReadFuture != null)
364+
while (true)
362365
{
363-
// there was a read pending already, so we make sure we completed that first
364-
365-
if (!currentReadFuture.IsCompleted)
366+
int totalRead = 0;
367+
if (currentReadFuture != null)
366368
{
367-
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
369+
// there was a read pending already, so we make sure we completed that first
368370

369-
continue;
370-
}
371+
if (!currentReadFuture.IsCompleted)
372+
{
373+
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
371374

372-
int read = currentReadFuture.Result;
375+
break;
376+
}
373377

374-
if (read == 0)
375-
{
376-
//Stream closed
377-
return;
378-
}
378+
int read = currentReadFuture.Result;
379+
totalRead += read;
379380

380-
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
381-
AddBufferToOutput(outputBuffer, read, output);
381+
if (read == 0)
382+
{
383+
//Stream closed
384+
return;
385+
}
382386

383-
currentReadFuture = null;
384-
outputBuffer = null;
385-
if (this.mediationStream.SourceReadableBytes == 0)
386-
{
387-
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
387+
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
388+
AddBufferToOutput(outputBuffer, read, output);
388389

389-
if (read < outputBufferLength)
390+
currentReadFuture = null;
391+
outputBuffer = null;
392+
if (this.mediationStream.TotalReadableBytes == 0)
390393
{
391-
// SslStream returned non-full buffer and there's no more input to go through ->
392-
// typically it means SslStream is done reading current frame so we skip
393-
continue;
394-
}
394+
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
395395

396-
// we've read out `read` bytes out of current packet to fulfil previously outstanding read
397-
outputBufferLength = currentPacketLength - read;
398-
if (outputBufferLength <= 0)
396+
if (read < outputBufferLength)
397+
{
398+
// SslStream returned non-full buffer and there's no more input to go through ->
399+
// typically it means SslStream is done reading current frame so we skip
400+
break;
401+
}
402+
403+
// we've read out `read` bytes out of current packet to fulfil previously outstanding read
404+
outputBufferLength = currentPacketLength - totalRead;
405+
if (outputBufferLength <= 0)
406+
{
407+
// after feeding to SslStream current frame it read out more bytes than current packet size
408+
outputBufferLength = FallbackReadBufferSize;
409+
}
410+
}
411+
else
399412
{
400-
// after feeding to SslStream current frame it read out more bytes than current packet size
401-
outputBufferLength = FallbackReadBufferSize;
413+
// SslStream did not get to reading current frame so it completed previous read sync
414+
// and the next read will likely read out the new frame
415+
outputBufferLength = currentPacketLength;
402416
}
403417
}
404418
else
405419
{
406-
// SslStream did not get to reading current frame so it completed previous read sync
407-
// and the next read will likely read out the new frame
420+
// there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient
408421
outputBufferLength = currentPacketLength;
409422
}
410-
}
411-
else
412-
{
413-
// there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient
414-
outputBufferLength = currentPacketLength;
415-
}
416423

417-
outputBuffer = ctx.Allocator.Buffer(outputBufferLength);
418-
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength);
424+
outputBuffer = ctx.Allocator.Buffer(outputBufferLength);
425+
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength);
426+
}
419427
}
420428

421-
// read out the rest of SslStream's output (if any) at risk of going async
422-
// using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading
423-
while (true)
429+
if (currentReadFuture != null)
424430
{
425-
if (currentReadFuture != null)
426-
{
427-
if (!currentReadFuture.IsCompleted)
428-
{
429-
break;
430-
}
431-
int read = currentReadFuture.Result;
432-
AddBufferToOutput(outputBuffer, read, output);
433-
}
434-
outputBuffer = ctx.Allocator.Buffer(FallbackReadBufferSize);
435-
currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, FallbackReadBufferSize);
431+
pending = true;
432+
this.pendingSslStreamReadBuffer = outputBuffer;
433+
this.pendingSslStreamReadFuture = currentReadFuture;
434+
this.pendingSslStreamReadLength = outputBufferLength;
436435
}
437-
438-
pending = true;
439-
this.pendingSslStreamReadBuffer = outputBuffer;
440-
this.pendingSslStreamReadFuture = currentReadFuture;
441436
}
442437
catch (Exception ex)
443438
{
@@ -458,6 +453,91 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
458453
outputBuffer.SafeRelease();
459454
}
460455
}
456+
457+
if (pending)
458+
{
459+
//Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here.
460+
this.pendingSslStreamReadFuture?.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None);
461+
}
462+
}
463+
}
464+
465+
static void UnwrapCompleted(Task<int> task, object state)
466+
{
467+
// Mono(with legacy provider) finish ReadAsync in async,
468+
// so extra check is needed to receive data in async
469+
var self = (TlsHandler)state;
470+
Debug.Assert(self.capturedContext.Executor.InEventLoop);
471+
472+
//Ignore task completed in Unwrap
473+
if (task == self.pendingSslStreamReadFuture)
474+
{
475+
IByteBuffer buf = self.pendingSslStreamReadBuffer;
476+
int outputBufferLength = self.pendingSslStreamReadLength;
477+
478+
self.pendingSslStreamReadFuture = null;
479+
self.pendingSslStreamReadBuffer = null;
480+
self.pendingSslStreamReadLength = 0;
481+
482+
while (true)
483+
{
484+
switch (task.Status)
485+
{
486+
case TaskStatus.RanToCompletion:
487+
{
488+
//The logic is the same as the one in Unwrap()
489+
var read = task.Result;
490+
//Stream Closed
491+
if (read == 0)
492+
return;
493+
self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read));
494+
495+
if (self.mediationStream.TotalReadableBytes == 0)
496+
{
497+
self.capturedContext.FireChannelReadComplete();
498+
self.mediationStream.ResetSource(self.capturedContext.Allocator);
499+
500+
if (read < outputBufferLength)
501+
{
502+
// SslStream returned non-full buffer and there's no more input to go through ->
503+
// typically it means SslStream is done reading current frame so we skip
504+
return;
505+
}
506+
}
507+
508+
outputBufferLength = self.mediationStream.TotalReadableBytes;
509+
if (outputBufferLength <= 0)
510+
outputBufferLength = FallbackReadBufferSize;
511+
512+
buf = self.capturedContext.Allocator.Buffer(outputBufferLength);
513+
task = self.ReadFromSslStreamAsync(buf, outputBufferLength);
514+
if (task.IsCompleted)
515+
{
516+
continue;
517+
}
518+
519+
self.pendingSslStreamReadFuture = task;
520+
self.pendingSslStreamReadBuffer = buf;
521+
self.pendingSslStreamReadLength = outputBufferLength;
522+
task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously);
523+
return;
524+
}
525+
526+
case TaskStatus.Canceled:
527+
case TaskStatus.Faulted:
528+
{
529+
buf.SafeRelease();
530+
self.HandleFailure(task.Exception);
531+
return;
532+
}
533+
534+
default:
535+
{
536+
buf.SafeRelease();
537+
throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status);
538+
}
539+
}
540+
}
461541
}
462542
}
463543

0 commit comments

Comments
 (0)