Skip to content

Commit 0a8e9ae

Browse files
committed
Remove plaintext receive buffer
PR#1752 added a persistent buffer into which to decrypt packets, rather than allocating a new array for each packet. This was on the back of sshnet#1733 which added support in the cipher types for decrypting into a given buffer, but for the case of AES-CTR, not into the same buffer in-place. sshnet#1787 adds that missing support, meaning we can now decrypt in-place and remove the plaintext buffer.
1 parent 51747c5 commit 0a8e9ae

2 files changed

Lines changed: 84 additions & 76 deletions

File tree

src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,23 +138,28 @@ public override int Encrypt(byte[] input, int offset, int length, byte[] output,
138138
/// <returns>The decrypted plaintext.</returns>
139139
public override byte[] Decrypt(byte[] input, int offset, int length)
140140
{
141-
byte[] output;
141+
var output = new byte[length];
142+
143+
_cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv));
144+
145+
var keyStream = new byte[64];
146+
_cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0);
147+
_mac.Init(new KeyParameter(keyStream, 0, 32));
142148

143149
if (_aadLength > 0)
144150
{
145151
// If we are in 'AAD mode', then put these bytes through the AAD cipher.
146152

153+
_mac.BlockUpdate(input, offset, length);
154+
147155
Debug.Assert(_aadCipher != null);
148156

149157
_aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv));
150158

151-
output = new byte[length];
152159
_aadCipher.ProcessBytes(input, offset, length, output, 0);
153160
}
154161
else
155162
{
156-
output = new byte[length];
157-
158163
var bytesWritten = Decrypt(input, offset, length, output, 0);
159164

160165
Debug.Assert(bytesWritten == length);
@@ -169,7 +174,7 @@ public override byte[] Decrypt(byte[] input, int offset, int length)
169174
/// <param name="input">
170175
/// The input data with below format:
171176
/// <code>
172-
/// [----][----Cipher AAD----(offset)][----Cipher Text----(length)][----TAG----]
177+
/// [----(offset)][----Cipher Text----(length)][----TAG----]
173178
/// </code>
174179
/// </param>
175180
/// <param name="offset">The zero-based offset in <paramref name="input"/> at which to begin decrypting and authenticating.</param>
@@ -179,16 +184,8 @@ public override byte[] Decrypt(byte[] input, int offset, int length)
179184
/// <returns>The number of plaintext bytes written to <paramref name="output"/>.</returns>
180185
public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset)
181186
{
182-
Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length");
183-
184-
_cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv));
185-
186-
var keyStream = new byte[64];
187-
_cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0);
188-
_mac.Init(new KeyParameter(keyStream, 0, 32));
189-
190187
var tag = new byte[TagSize];
191-
_mac.BlockUpdate(input, offset - _aadLength, length + _aadLength);
188+
_mac.BlockUpdate(input, offset, length);
192189
_ = _mac.DoFinal(tag, 0);
193190
if (!Arrays.FixedTimeEquals(TagSize, tag, 0, input, offset + length))
194191
{

src/Renci.SshNet/Session.cs

Lines changed: 73 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,23 @@ public sealed class Session : ISession
105105
/// </summary>
106106
private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1);
107107

108+
private readonly byte[] _inboundPacketSequenceBytes = new byte[4];
109+
110+
/// <summary>
111+
/// Gets or sets the incoming packet number.
112+
/// </summary>
113+
private uint InboundPacketSequence
114+
{
115+
get
116+
{
117+
return BinaryPrimitives.ReadUInt32BigEndian(_inboundPacketSequenceBytes);
118+
}
119+
set
120+
{
121+
BinaryPrimitives.WriteUInt32BigEndian(_inboundPacketSequenceBytes, value);
122+
}
123+
}
124+
108125
/// <summary>
109126
/// Holds metadata about session messages.
110127
/// </summary>
@@ -120,11 +137,6 @@ public sealed class Session : ISession
120137
/// </summary>
121138
private volatile uint _outboundPacketSequence;
122139

123-
/// <summary>
124-
/// Specifies incoming packet number.
125-
/// </summary>
126-
private uint _inboundPacketSequence;
127-
128140
/// <summary>
129141
/// WaitHandle to signal that last service request was accepted.
130142
/// </summary>
@@ -200,7 +212,6 @@ public sealed class Session : ISession
200212
private Socket _socket;
201213

202214
private ArrayBuffer _receiveBuffer = new(4 * 1024);
203-
private byte[] _plaintextReceiveBuffer = new byte[4 * 1024];
204215

205216
/// <summary>
206217
/// Gets the session semaphore that controls session channels.
@@ -1213,9 +1224,6 @@ private bool TrySendMessage(Message message)
12131224
/// </remarks>
12141225
private Message ReceiveMessage(Socket socket)
12151226
{
1216-
// the length of the packet sequence field in bytes
1217-
const int inboundPacketSequenceLength = 4;
1218-
12191227
// The length of the "packet length" field in bytes
12201228
const int packetLengthFieldLength = 4;
12211229

@@ -1272,31 +1280,28 @@ private Message ReceiveMessage(Socket socket)
12721280
}
12731281
}
12741282

1275-
var firstBlock = new ArraySegment<byte>(
1276-
_receiveBuffer.DangerousGetUnderlyingBuffer(),
1277-
_receiveBuffer.ActiveStartOffset,
1278-
blockSize);
1279-
1280-
var plainFirstBlock = firstBlock;
1281-
1282-
// For ETM or AES-GCM, firstBlock holds the packet length which is
1283-
// not encrypted. Otherwise, we decrypt the first "blockSize" bytes.
1284-
// (For chacha20-poly1305, this means passing the encrypted packet
1285-
// length as AAD).
1283+
// For ETM or AES-GCM, the first "blockSize" bytes hold the packet length
1284+
// which is not encrypted. Otherwise, we decrypt them.
1285+
// (For chacha20-poly1305, this means passing the encrypted packet length
1286+
// to its AAD cipher instance - it is the awkward difference between the
1287+
// 3-arg and 5-arg Decrypt, and explains why we don't just decrypt these
1288+
// bytes in-place).
12861289
if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
12871290
{
1288-
_serverCipher.SetSequenceNumber(_inboundPacketSequence);
1291+
_serverCipher.SetSequenceNumber(InboundPacketSequence);
12891292

12901293
if (_serverMac == null || !_serverEtm)
12911294
{
1292-
plainFirstBlock = new ArraySegment<byte>(_serverCipher.Decrypt(
1293-
firstBlock.Array,
1294-
firstBlock.Offset,
1295-
firstBlock.Count));
1295+
var plainFirstBlock = _serverCipher.Decrypt(
1296+
_receiveBuffer.DangerousGetUnderlyingBuffer(),
1297+
_receiveBuffer.ActiveStartOffset,
1298+
blockSize);
1299+
1300+
plainFirstBlock.CopyTo(_receiveBuffer.ActiveSpan);
12961301
}
12971302
}
12981303

1299-
var packetLength = BinaryPrimitives.ReadInt32BigEndian(plainFirstBlock);
1304+
var packetLength = BinaryPrimitives.ReadInt32BigEndian(_receiveBuffer.ActiveReadOnlySpan);
13001305

13011306
// Test packet minimum and maximum boundaries
13021307
if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
@@ -1330,26 +1335,13 @@ private Message ReceiveMessage(Socket socket)
13301335
}
13311336
}
13321337

1333-
// Construct buffer for holding the payload and the inbound packet sequence as we need both in order
1334-
// to generate the hash.
1335-
var plaintextLength = 4 + totalPacketLength - serverMacLength;
1336-
1337-
if (_plaintextReceiveBuffer.Length < plaintextLength)
1338-
{
1339-
Array.Resize(ref _plaintextReceiveBuffer, Math.Max(plaintextLength, 2 * _plaintextReceiveBuffer.Length));
1340-
}
1341-
1342-
BinaryPrimitives.WriteUInt32BigEndian(_plaintextReceiveBuffer, _inboundPacketSequence);
1343-
1344-
plainFirstBlock.AsSpan().CopyTo(_plaintextReceiveBuffer.AsSpan(4));
1345-
13461338
if (_serverMac != null && _serverEtm)
13471339
{
13481340
// ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)
13491341

13501342
// sequence_number
13511343
_ = _serverMac.TransformBlock(
1352-
inputBuffer: _plaintextReceiveBuffer,
1344+
inputBuffer: _inboundPacketSequenceBytes,
13531345
inputOffset: 0,
13541346
inputCount: 4,
13551347
outputBuffer: null,
@@ -1377,58 +1369,77 @@ private Message ReceiveMessage(Socket socket)
13771369
{
13781370
Debug.Assert(numberOfBytesToDecrypt % blockSize == 0);
13791371

1372+
var decryptBuffer = _receiveBuffer.DangerousGetUnderlyingBuffer();
1373+
var decryptOffset = _receiveBuffer.ActiveStartOffset + blockSize;
1374+
13801375
var numberOfBytesDecrypted = _serverCipher.Decrypt(
1381-
input: _receiveBuffer.DangerousGetUnderlyingBuffer(),
1382-
offset: _receiveBuffer.ActiveStartOffset + blockSize,
1376+
input: decryptBuffer,
1377+
offset: decryptOffset,
13831378
length: numberOfBytesToDecrypt,
1384-
output: _plaintextReceiveBuffer,
1385-
outputOffset: 4 + blockSize);
1379+
output: decryptBuffer,
1380+
outputOffset: decryptOffset);
13861381

13871382
Debug.Assert(numberOfBytesDecrypted == numberOfBytesToDecrypt);
13881383
}
1389-
else
1390-
{
1391-
_receiveBuffer.ActiveReadOnlySpan
1392-
.Slice(blockSize, numberOfBytesToDecrypt)
1393-
.CopyTo(_plaintextReceiveBuffer.AsSpan(4 + blockSize));
1394-
}
13951384

13961385
if (_serverMac != null && !_serverEtm)
13971386
{
13981387
// non-ETM mac = MAC(key, sequence_number || unencrypted_packet)
13991388

1400-
var clientHash = _serverMac.ComputeHash(_plaintextReceiveBuffer, 0, plaintextLength);
1389+
// sequence_number
1390+
_ = _serverMac.TransformBlock(
1391+
inputBuffer: _inboundPacketSequenceBytes,
1392+
inputOffset: 0,
1393+
inputCount: 4,
1394+
outputBuffer: null,
1395+
outputOffset: 0);
1396+
1397+
// unencrypted_packet
1398+
_ = _serverMac.TransformBlock(
1399+
inputBuffer: _receiveBuffer.DangerousGetUnderlyingBuffer(),
1400+
inputOffset: _receiveBuffer.ActiveStartOffset,
1401+
inputCount: totalPacketLength - serverMacLength,
1402+
outputBuffer: null,
1403+
outputOffset: 0);
1404+
1405+
_ = _serverMac.TransformFinalBlock(Array.Empty<byte>(), 0, 0);
14011406

1402-
if (!CryptoAbstraction.FixedTimeEquals(clientHash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength)))
1407+
if (!CryptoAbstraction.FixedTimeEquals(_serverMac.Hash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength)))
14031408
{
14041409
throw new SshConnectionException("MAC error", DisconnectReason.MacError);
14051410
}
14061411
}
14071412

1408-
_receiveBuffer.Discard(totalPacketLength);
1409-
1410-
var paddingLength = _plaintextReceiveBuffer[inboundPacketSequenceLength + packetLengthFieldLength];
1413+
var paddingLength = _receiveBuffer.ActiveReadOnlySpan[packetLengthFieldLength];
14111414

14121415
ArraySegment<byte> payload = new(
1413-
_plaintextReceiveBuffer,
1414-
offset: inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength,
1416+
_receiveBuffer.DangerousGetUnderlyingBuffer(),
1417+
offset: _receiveBuffer.ActiveStartOffset + packetLengthFieldLength + paddingLengthFieldLength,
14151418
count: packetLength - paddingLength - paddingLengthFieldLength);
14161419

14171420
if (_serverDecompression != null)
14181421
{
14191422
payload = new(_serverDecompression.Decompress(payload.Array, payload.Offset, payload.Count));
14201423
}
14211424

1422-
_inboundPacketSequence++;
1425+
var newInboundPacketSequence = ++InboundPacketSequence;
14231426

14241427
// The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
14251428
// It ensures the integrity of key exchange process.
1426-
if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
1429+
if (newInboundPacketSequence == uint.MaxValue && _isInitialKex)
14271430
{
14281431
throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
14291432
}
14301433

1431-
return LoadMessage(payload.Array, payload.Offset, payload.Count);
1434+
var message = LoadMessage(payload.Array, payload.Offset, payload.Count);
1435+
1436+
// The deserialised message may still reference data in the buffer, so calling Discard
1437+
// here might seem misguided. It is OK because Discard does not mutate the buffer
1438+
// and it will not be touched again until the next call to ReceiveMessage, which will
1439+
// only occur after the message has been fully processed.
1440+
_receiveBuffer.Discard(totalPacketLength);
1441+
1442+
return message;
14321443
}
14331444

14341445
private void TrySendDisconnect(DisconnectReason reasonCode, string message)
@@ -1545,7 +1556,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
15451556

15461557
_logger.LogDebug("[{SessionId}] Enabling strict key exchange extension.", SessionIdHex);
15471558

1548-
if (_inboundPacketSequence != 1)
1559+
if (InboundPacketSequence != 1)
15491560
{
15501561
throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
15511562
}
@@ -1646,7 +1657,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
16461657

16471658
if (_isStrictKex)
16481659
{
1649-
_inboundPacketSequence = 0;
1660+
InboundPacketSequence = 0;
16501661
}
16511662

16521663
NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));

0 commit comments

Comments
 (0)