@@ -105,6 +105,23 @@ public sealed class Session : ISession
105105 /// </summary>
106106 private readonly SemaphoreSlim _connectLock = new SemaphoreSlim ( 1 , 1 ) ;
107107
108+ private readonly byte [ ] _inboundPacketSequenceBytes = new byte [ 4 ] ;
109+
110+ /// <summary>
111+ /// Gets or sets the incoming packet number.
112+ /// </summary>
113+ private uint InboundPacketSequence
114+ {
115+ get
116+ {
117+ return BinaryPrimitives . ReadUInt32BigEndian ( _inboundPacketSequenceBytes ) ;
118+ }
119+ set
120+ {
121+ BinaryPrimitives . WriteUInt32BigEndian ( _inboundPacketSequenceBytes , value ) ;
122+ }
123+ }
124+
108125 /// <summary>
109126 /// Holds metadata about session messages.
110127 /// </summary>
@@ -120,11 +137,6 @@ public sealed class Session : ISession
120137 /// </summary>
121138 private volatile uint _outboundPacketSequence ;
122139
123- /// <summary>
124- /// Specifies incoming packet number.
125- /// </summary>
126- private uint _inboundPacketSequence ;
127-
128140 /// <summary>
129141 /// WaitHandle to signal that last service request was accepted.
130142 /// </summary>
@@ -200,7 +212,6 @@ public sealed class Session : ISession
200212 private Socket _socket ;
201213
202214 private ArrayBuffer _receiveBuffer = new ( 4 * 1024 ) ;
203- private byte [ ] _plaintextReceiveBuffer = new byte [ 4 * 1024 ] ;
204215
205216 /// <summary>
206217 /// Gets the session semaphore that controls session channels.
@@ -1213,9 +1224,6 @@ private bool TrySendMessage(Message message)
12131224 /// </remarks>
12141225 private Message ReceiveMessage ( Socket socket )
12151226 {
1216- // the length of the packet sequence field in bytes
1217- const int inboundPacketSequenceLength = 4 ;
1218-
12191227 // The length of the "packet length" field in bytes
12201228 const int packetLengthFieldLength = 4 ;
12211229
@@ -1272,31 +1280,28 @@ private Message ReceiveMessage(Socket socket)
12721280 }
12731281 }
12741282
1275- var firstBlock = new ArraySegment < byte > (
1276- _receiveBuffer . DangerousGetUnderlyingBuffer ( ) ,
1277- _receiveBuffer . ActiveStartOffset ,
1278- blockSize ) ;
1279-
1280- var plainFirstBlock = firstBlock ;
1281-
1282- // For ETM or AES-GCM, firstBlock holds the packet length which is
1283- // not encrypted. Otherwise, we decrypt the first "blockSize" bytes.
1284- // (For chacha20-poly1305, this means passing the encrypted packet
1285- // length as AAD).
1283+ // For ETM or AES-GCM, the first "blockSize" bytes hold the packet length
1284+ // which is not encrypted. Otherwise, we decrypt them.
1285+ // (For chacha20-poly1305, this means passing the encrypted packet length
1286+ // to its AAD cipher instance - it is the awkward difference between the
1287+ // 3-arg and 5-arg Decrypt, and explains why we don't just decrypt these
1288+ // bytes in-place).
12861289 if ( _serverCipher is not null and not Security . Cryptography . Ciphers . AesGcmCipher )
12871290 {
1288- _serverCipher . SetSequenceNumber ( _inboundPacketSequence ) ;
1291+ _serverCipher . SetSequenceNumber ( InboundPacketSequence ) ;
12891292
12901293 if ( _serverMac == null || ! _serverEtm )
12911294 {
1292- plainFirstBlock = new ArraySegment < byte > ( _serverCipher . Decrypt (
1293- firstBlock . Array ,
1294- firstBlock . Offset ,
1295- firstBlock . Count ) ) ;
1295+ var plainFirstBlock = _serverCipher . Decrypt (
1296+ _receiveBuffer . DangerousGetUnderlyingBuffer ( ) ,
1297+ _receiveBuffer . ActiveStartOffset ,
1298+ blockSize ) ;
1299+
1300+ plainFirstBlock . CopyTo ( _receiveBuffer . ActiveSpan ) ;
12961301 }
12971302 }
12981303
1299- var packetLength = BinaryPrimitives . ReadInt32BigEndian ( plainFirstBlock ) ;
1304+ var packetLength = BinaryPrimitives . ReadInt32BigEndian ( _receiveBuffer . ActiveReadOnlySpan ) ;
13001305
13011306 // Test packet minimum and maximum boundaries
13021307 if ( packetLength < Math . Max ( ( byte ) 8 , blockSize ) - 4 || packetLength > MaximumSshPacketSize - 4 )
@@ -1330,26 +1335,13 @@ private Message ReceiveMessage(Socket socket)
13301335 }
13311336 }
13321337
1333- // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
1334- // to generate the hash.
1335- var plaintextLength = 4 + totalPacketLength - serverMacLength ;
1336-
1337- if ( _plaintextReceiveBuffer . Length < plaintextLength )
1338- {
1339- Array . Resize ( ref _plaintextReceiveBuffer , Math . Max ( plaintextLength , 2 * _plaintextReceiveBuffer . Length ) ) ;
1340- }
1341-
1342- BinaryPrimitives . WriteUInt32BigEndian ( _plaintextReceiveBuffer , _inboundPacketSequence ) ;
1343-
1344- plainFirstBlock . AsSpan ( ) . CopyTo ( _plaintextReceiveBuffer . AsSpan ( 4 ) ) ;
1345-
13461338 if ( _serverMac != null && _serverEtm )
13471339 {
13481340 // ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)
13491341
13501342 // sequence_number
13511343 _ = _serverMac . TransformBlock (
1352- inputBuffer : _plaintextReceiveBuffer ,
1344+ inputBuffer : _inboundPacketSequenceBytes ,
13531345 inputOffset : 0 ,
13541346 inputCount : 4 ,
13551347 outputBuffer : null ,
@@ -1377,58 +1369,77 @@ private Message ReceiveMessage(Socket socket)
13771369 {
13781370 Debug . Assert ( numberOfBytesToDecrypt % blockSize == 0 ) ;
13791371
1372+ var decryptBuffer = _receiveBuffer . DangerousGetUnderlyingBuffer ( ) ;
1373+ var decryptOffset = _receiveBuffer . ActiveStartOffset + blockSize ;
1374+
13801375 var numberOfBytesDecrypted = _serverCipher . Decrypt (
1381- input : _receiveBuffer . DangerousGetUnderlyingBuffer ( ) ,
1382- offset : _receiveBuffer . ActiveStartOffset + blockSize ,
1376+ input : decryptBuffer ,
1377+ offset : decryptOffset ,
13831378 length : numberOfBytesToDecrypt ,
1384- output : _plaintextReceiveBuffer ,
1385- outputOffset : 4 + blockSize ) ;
1379+ output : decryptBuffer ,
1380+ outputOffset : decryptOffset ) ;
13861381
13871382 Debug . Assert ( numberOfBytesDecrypted == numberOfBytesToDecrypt ) ;
13881383 }
1389- else
1390- {
1391- _receiveBuffer . ActiveReadOnlySpan
1392- . Slice ( blockSize , numberOfBytesToDecrypt )
1393- . CopyTo ( _plaintextReceiveBuffer . AsSpan ( 4 + blockSize ) ) ;
1394- }
13951384
13961385 if ( _serverMac != null && ! _serverEtm )
13971386 {
13981387 // non-ETM mac = MAC(key, sequence_number || unencrypted_packet)
13991388
1400- var clientHash = _serverMac . ComputeHash ( _plaintextReceiveBuffer , 0 , plaintextLength ) ;
1389+ // sequence_number
1390+ _ = _serverMac . TransformBlock (
1391+ inputBuffer : _inboundPacketSequenceBytes ,
1392+ inputOffset : 0 ,
1393+ inputCount : 4 ,
1394+ outputBuffer : null ,
1395+ outputOffset : 0 ) ;
1396+
1397+ // unencrypted_packet
1398+ _ = _serverMac . TransformBlock (
1399+ inputBuffer : _receiveBuffer . DangerousGetUnderlyingBuffer ( ) ,
1400+ inputOffset : _receiveBuffer . ActiveStartOffset ,
1401+ inputCount : totalPacketLength - serverMacLength ,
1402+ outputBuffer : null ,
1403+ outputOffset : 0 ) ;
1404+
1405+ _ = _serverMac . TransformFinalBlock ( Array . Empty < byte > ( ) , 0 , 0 ) ;
14011406
1402- if ( ! CryptoAbstraction . FixedTimeEquals ( clientHash , _receiveBuffer . ActiveSpan . Slice ( totalPacketLength - serverMacLength , serverMacLength ) ) )
1407+ if ( ! CryptoAbstraction . FixedTimeEquals ( _serverMac . Hash , _receiveBuffer . ActiveSpan . Slice ( totalPacketLength - serverMacLength , serverMacLength ) ) )
14031408 {
14041409 throw new SshConnectionException ( "MAC error" , DisconnectReason . MacError ) ;
14051410 }
14061411 }
14071412
1408- _receiveBuffer . Discard ( totalPacketLength ) ;
1409-
1410- var paddingLength = _plaintextReceiveBuffer [ inboundPacketSequenceLength + packetLengthFieldLength ] ;
1413+ var paddingLength = _receiveBuffer . ActiveReadOnlySpan [ packetLengthFieldLength ] ;
14111414
14121415 ArraySegment < byte > payload = new (
1413- _plaintextReceiveBuffer ,
1414- offset : inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength ,
1416+ _receiveBuffer . DangerousGetUnderlyingBuffer ( ) ,
1417+ offset : _receiveBuffer . ActiveStartOffset + packetLengthFieldLength + paddingLengthFieldLength ,
14151418 count : packetLength - paddingLength - paddingLengthFieldLength ) ;
14161419
14171420 if ( _serverDecompression != null )
14181421 {
14191422 payload = new ( _serverDecompression . Decompress ( payload . Array , payload . Offset , payload . Count ) ) ;
14201423 }
14211424
1422- _inboundPacketSequence ++ ;
1425+ var newInboundPacketSequence = ++ InboundPacketSequence ;
14231426
14241427 // The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
14251428 // It ensures the integrity of key exchange process.
1426- if ( _inboundPacketSequence == uint . MaxValue && _isInitialKex )
1429+ if ( newInboundPacketSequence == uint . MaxValue && _isInitialKex )
14271430 {
14281431 throw new SshConnectionException ( "Inbound packet sequence number is about to wrap during initial key exchange." , DisconnectReason . KeyExchangeFailed ) ;
14291432 }
14301433
1431- return LoadMessage ( payload . Array , payload . Offset , payload . Count ) ;
1434+ var message = LoadMessage ( payload . Array , payload . Offset , payload . Count ) ;
1435+
1436+ // The deserialised message may still reference data in the buffer, so calling Discard
1437+ // here might seem misguided. It is OK because Discard does not mutate the buffer
1438+ // and it will not be touched again until the next call to ReceiveMessage, which will
1439+ // only occur after the message has been fully processed.
1440+ _receiveBuffer . Discard ( totalPacketLength ) ;
1441+
1442+ return message ;
14321443 }
14331444
14341445 private void TrySendDisconnect ( DisconnectReason reasonCode , string message )
@@ -1545,7 +1556,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
15451556
15461557 _logger . LogDebug ( "[{SessionId}] Enabling strict key exchange extension." , SessionIdHex ) ;
15471558
1548- if ( _inboundPacketSequence != 1 )
1559+ if ( InboundPacketSequence != 1 )
15491560 {
15501561 throw new SshConnectionException ( "KEXINIT was not the first packet during strict key exchange." , DisconnectReason . KeyExchangeFailed ) ;
15511562 }
@@ -1646,7 +1657,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
16461657
16471658 if ( _isStrictKex )
16481659 {
1649- _inboundPacketSequence = 0 ;
1660+ InboundPacketSequence = 0 ;
16501661 }
16511662
16521663 NewKeysReceived ? . Invoke ( this , new MessageEventArgs < NewKeysMessage > ( message ) ) ;
0 commit comments