@@ -27,6 +27,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
2727
2828 static readonly Exception ChannelClosedException = new IOException ( "Channel is closed" ) ;
2929 static readonly Action < Task , object > HandshakeCompletionCallback = new Action < Task , object > ( HandleHandshakeCompleted ) ;
30+ static readonly Action < Task < int > , object > UnwrapCompletedCallback = new Action < Task < int > , object > ( UnwrapCompleted ) ;
3031
3132 readonly SslStream sslStream ;
3233 readonly MediationStream mediationStream ;
@@ -40,6 +41,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4041 bool firedChannelRead ;
4142 volatile FlushMode flushMode = FlushMode . ForceFlush ;
4243 IByteBuffer pendingSslStreamReadBuffer ;
44+ int pendingSslStreamReadLength ;
4345 Task < int > pendingSslStreamReadFuture ;
4446
4547 public TlsHandler ( TlsSettings settings )
@@ -342,10 +344,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
342344 Contract . Assert ( this . pendingSslStreamReadBuffer != null ) ;
343345
344346 outputBuffer = this . pendingSslStreamReadBuffer ;
345- outputBufferLength = outputBuffer . WritableBytes ;
347+ outputBufferLength = this . pendingSslStreamReadLength ;
346348
347349 this . pendingSslStreamReadFuture = null ;
348350 this . pendingSslStreamReadBuffer = null ;
351+ this . pendingSslStreamReadLength = 0 ;
349352 }
350353 else
351354 {
@@ -358,86 +361,78 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
358361 int currentPacketLength = packetLengths [ packetIndex ] ;
359362 this . mediationStream . ExpandSource ( currentPacketLength ) ;
360363
361- if ( currentReadFuture != null )
364+ while ( true )
362365 {
363- // there was a read pending already, so we make sure we completed that first
364-
365- if ( ! currentReadFuture . IsCompleted )
366+ int totalRead = 0 ;
367+ if ( currentReadFuture != null )
366368 {
367- // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
369+ // there was a read pending already, so we make sure we completed that first
368370
369- continue ;
370- }
371+ if ( ! currentReadFuture . IsCompleted )
372+ {
373+ // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
371374
372- int read = currentReadFuture . Result ;
375+ break ;
376+ }
373377
374- if ( read == 0 )
375- {
376- //Stream closed
377- return ;
378- }
378+ int read = currentReadFuture . Result ;
379+ totalRead += read ;
379380
380- // Now output the result of previous read and decide whether to do an extra read on the same source or move forward
381- AddBufferToOutput ( outputBuffer , read , output ) ;
381+ if ( read == 0 )
382+ {
383+ //Stream closed
384+ return ;
385+ }
382386
383- currentReadFuture = null ;
384- outputBuffer = null ;
385- if ( this . mediationStream . SourceReadableBytes == 0 )
386- {
387- // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
387+ // Now output the result of previous read and decide whether to do an extra read on the same source or move forward
388+ AddBufferToOutput ( outputBuffer , read , output ) ;
388389
389- if ( read < outputBufferLength )
390+ currentReadFuture = null ;
391+ outputBuffer = null ;
392+ if ( this . mediationStream . TotalReadableBytes == 0 )
390393 {
391- // SslStream returned non-full buffer and there's no more input to go through ->
392- // typically it means SslStream is done reading current frame so we skip
393- continue ;
394- }
394+ // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
395395
396- // we've read out `read` bytes out of current packet to fulfil previously outstanding read
397- outputBufferLength = currentPacketLength - read ;
398- if ( outputBufferLength <= 0 )
396+ if ( read < outputBufferLength )
397+ {
398+ // SslStream returned non-full buffer and there's no more input to go through ->
399+ // typically it means SslStream is done reading current frame so we skip
400+ break ;
401+ }
402+
403+ // we've read out `read` bytes out of current packet to fulfil previously outstanding read
404+ outputBufferLength = currentPacketLength - totalRead ;
405+ if ( outputBufferLength <= 0 )
406+ {
407+ // after feeding to SslStream current frame it read out more bytes than current packet size
408+ outputBufferLength = FallbackReadBufferSize ;
409+ }
410+ }
411+ else
399412 {
400- // after feeding to SslStream current frame it read out more bytes than current packet size
401- outputBufferLength = FallbackReadBufferSize ;
413+ // SslStream did not get to reading current frame so it completed previous read sync
414+ // and the next read will likely read out the new frame
415+ outputBufferLength = currentPacketLength ;
402416 }
403417 }
404418 else
405419 {
406- // SslStream did not get to reading current frame so it completed previous read sync
407- // and the next read will likely read out the new frame
420+ // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient
408421 outputBufferLength = currentPacketLength ;
409422 }
410- }
411- else
412- {
413- // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient
414- outputBufferLength = currentPacketLength ;
415- }
416423
417- outputBuffer = ctx . Allocator . Buffer ( outputBufferLength ) ;
418- currentReadFuture = this . ReadFromSslStreamAsync ( outputBuffer , outputBufferLength ) ;
424+ outputBuffer = ctx . Allocator . Buffer ( outputBufferLength ) ;
425+ currentReadFuture = this . ReadFromSslStreamAsync ( outputBuffer , outputBufferLength ) ;
426+ }
419427 }
420428
421- // read out the rest of SslStream's output (if any) at risk of going async
422- // using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading
423- while ( true )
429+ if ( currentReadFuture != null )
424430 {
425- if ( currentReadFuture != null )
426- {
427- if ( ! currentReadFuture . IsCompleted )
428- {
429- break ;
430- }
431- int read = currentReadFuture . Result ;
432- AddBufferToOutput ( outputBuffer , read , output ) ;
433- }
434- outputBuffer = ctx . Allocator . Buffer ( FallbackReadBufferSize ) ;
435- currentReadFuture = this . ReadFromSslStreamAsync ( outputBuffer , FallbackReadBufferSize ) ;
431+ pending = true ;
432+ this . pendingSslStreamReadBuffer = outputBuffer ;
433+ this . pendingSslStreamReadFuture = currentReadFuture ;
434+ this . pendingSslStreamReadLength = outputBufferLength ;
436435 }
437-
438- pending = true ;
439- this . pendingSslStreamReadBuffer = outputBuffer ;
440- this . pendingSslStreamReadFuture = currentReadFuture ;
441436 }
442437 catch ( Exception ex )
443438 {
@@ -458,6 +453,91 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
458453 outputBuffer . SafeRelease ( ) ;
459454 }
460455 }
456+
457+ if ( pending )
458+ {
459+ //Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here.
460+ this . pendingSslStreamReadFuture ? . ContinueWith ( UnwrapCompletedCallback , this , TaskContinuationOptions . None ) ;
461+ }
462+ }
463+ }
464+
465+ static void UnwrapCompleted ( Task < int > task , object state )
466+ {
467+ // Mono(with legacy provider) finish ReadAsync in async,
468+ // so extra check is needed to receive data in async
469+ var self = ( TlsHandler ) state ;
470+ Debug . Assert ( self . capturedContext . Executor . InEventLoop ) ;
471+
472+ //Ignore task completed in Unwrap
473+ if ( task == self . pendingSslStreamReadFuture )
474+ {
475+ IByteBuffer buf = self . pendingSslStreamReadBuffer ;
476+ int outputBufferLength = self . pendingSslStreamReadLength ;
477+
478+ self . pendingSslStreamReadFuture = null ;
479+ self . pendingSslStreamReadBuffer = null ;
480+ self . pendingSslStreamReadLength = 0 ;
481+
482+ while ( true )
483+ {
484+ switch ( task . Status )
485+ {
486+ case TaskStatus . RanToCompletion :
487+ {
488+ //The logic is the same as the one in Unwrap()
489+ var read = task . Result ;
490+ //Stream Closed
491+ if ( read == 0 )
492+ return ;
493+ self . capturedContext . FireChannelRead ( buf . SetWriterIndex ( buf . WriterIndex + read ) ) ;
494+
495+ if ( self . mediationStream . TotalReadableBytes == 0 )
496+ {
497+ self . capturedContext . FireChannelReadComplete ( ) ;
498+ self . mediationStream . ResetSource ( self . capturedContext . Allocator ) ;
499+
500+ if ( read < outputBufferLength )
501+ {
502+ // SslStream returned non-full buffer and there's no more input to go through ->
503+ // typically it means SslStream is done reading current frame so we skip
504+ return ;
505+ }
506+ }
507+
508+ outputBufferLength = self . mediationStream . TotalReadableBytes ;
509+ if ( outputBufferLength <= 0 )
510+ outputBufferLength = FallbackReadBufferSize ;
511+
512+ buf = self . capturedContext . Allocator . Buffer ( outputBufferLength ) ;
513+ task = self . ReadFromSslStreamAsync ( buf , outputBufferLength ) ;
514+ if ( task . IsCompleted )
515+ {
516+ continue ;
517+ }
518+
519+ self . pendingSslStreamReadFuture = task ;
520+ self . pendingSslStreamReadBuffer = buf ;
521+ self . pendingSslStreamReadLength = outputBufferLength ;
522+ task . ContinueWith ( UnwrapCompletedCallback , self , TaskContinuationOptions . ExecuteSynchronously ) ;
523+ return ;
524+ }
525+
526+ case TaskStatus . Canceled :
527+ case TaskStatus . Faulted :
528+ {
529+ buf . SafeRelease ( ) ;
530+ self . HandleFailure ( task . Exception ) ;
531+ return ;
532+ }
533+
534+ default :
535+ {
536+ buf . SafeRelease ( ) ;
537+ throw new ArgumentOutOfRangeException ( nameof ( task ) , "Unexpected task status: " + task . Status ) ;
538+ }
539+ }
540+ }
461541 }
462542 }
463543
0 commit comments