Skip to content

Commit 5b8382d

Browse files
authored
Serialise packets into a buffer (#1792)
A byte array is allocated to hold each plaintext packet. This removes that by adding a buffer for that purpose.
1 parent b6217cb commit 5b8382d

10 files changed

Lines changed: 103 additions & 102 deletions

src/Renci.SshNet/Messages/Message.cs

Lines changed: 54 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#nullable enable
2+
using System;
3+
using System.Diagnostics;
4+
using System.Globalization;
25
using System.IO;
36

47
using Renci.SshNet.Abstractions;
@@ -38,116 +41,85 @@ protected override void WriteBytes(SshDataStream stream)
3841
base.WriteBytes(stream);
3942
}
4043

41-
/// <returns>[4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].</returns>
42-
internal byte[] GetPacket(byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding = false, int macLength = 0)
44+
/// <returns>The number of bytes occupied by the packet in <paramref name="buffer"/>.</returns>
45+
/// <remarks>
46+
/// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].
47+
/// </remarks>
48+
internal int GetPacket(ref byte[] buffer, byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding, int macLength)
4349
{
4450
const int outboundPacketSequenceSize = 4;
4551

4652
var messageLength = BufferCapacity;
4753

54+
ArraySegment<byte> payload = default;
55+
4856
if (messageLength == -1 || compressor != null)
4957
{
50-
using (var sshDataStream = new SshDataStream(DefaultCapacity))
58+
using (var sshDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity))
5159
{
52-
// skip:
53-
// * 4 bytes for the outbound packet sequence
54-
// * 4 bytes for the packet data length
55-
// * one byte for the packet padding length
56-
_ = sshDataStream.Seek(outboundPacketSequenceSize + 4 + 1, SeekOrigin.Begin);
57-
58-
if (compressor != null)
59-
{
60-
// obtain uncompressed message payload
61-
using (var uncompressedDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity))
62-
{
63-
WriteBytes(uncompressedDataStream);
64-
65-
// compress message payload
66-
var compressedMessageData = compressor.Compress(uncompressedDataStream.ToArray());
67-
68-
// add compressed message payload
69-
sshDataStream.Write(compressedMessageData, 0, compressedMessageData.Length);
70-
}
71-
}
72-
else
73-
{
74-
// add message payload
75-
WriteBytes(sshDataStream);
76-
}
77-
78-
messageLength = (int)sshDataStream.Length - (outboundPacketSequenceSize + 4 + 1);
79-
80-
var packetLength = messageLength + 4 + 1;
81-
82-
// determine the padding length
83-
// in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
84-
// padding length calculation
85-
var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength);
60+
WriteBytes(sshDataStream);
8661

87-
var packetDataLength = GetPacketDataLength(messageLength, paddingLength);
62+
var success = sshDataStream.TryGetBuffer(out payload);
8863

89-
// skip bytes for outbound packet sequence
90-
_ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
64+
Debug.Assert(success);
65+
}
9166

92-
// add packet data length
93-
sshDataStream.Write(packetDataLength);
67+
if (compressor != null)
68+
{
69+
payload = new(compressor.Compress(payload.Array, payload.Offset, payload.Count));
70+
}
9471

95-
// add packet padding length
96-
sshDataStream.WriteByte(paddingLength);
72+
messageLength = payload.Count;
73+
}
9774

98-
_ = sshDataStream.Seek(0, SeekOrigin.End);
75+
// determine the padding length
76+
// in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
77+
// padding length calculation
78+
var paddingLength = GetPaddingLength(
79+
paddingMultiplier, (excludePacketLengthFieldWhenPadding ? 0 : 4) + 1 + messageLength);
9980

100-
sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength);
81+
var packetLength = 1 + messageLength + paddingLength;
10182

102-
var buffer = sshDataStream.ToArray();
83+
var bytesRequired = 4 + 4 + packetLength + macLength;
10384

104-
// add padding bytes
105-
CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
85+
if ((uint)bytesRequired > (uint)Session.MaximumSshPacketSize)
86+
{
87+
throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", Session.MaximumSshPacketSize));
88+
}
10689

107-
return buffer;
108-
}
90+
if (buffer.Length < bytesRequired)
91+
{
92+
Array.Resize(ref buffer, Math.Max(bytesRequired, 2 * buffer.Length));
10993
}
110-
else
94+
95+
using (var sshDataStream = new SshDataStream(buffer))
11196
{
112-
var packetLength = messageLength + 4 + 1;
97+
// skip bytes for outbound packet sequenceSize
98+
_ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
11399

114-
// determine the padding length
115-
// in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the
116-
// padding length calculation
117-
var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength);
100+
// add packet length
101+
sshDataStream.Write((uint)packetLength);
118102

119-
var packetDataLength = GetPacketDataLength(messageLength, paddingLength);
103+
// add padding length
104+
sshDataStream.WriteByte(paddingLength);
120105

121-
// lets construct an SSH data stream of the exact size required
122-
using (var sshDataStream = new SshDataStream(packetLength + paddingLength + outboundPacketSequenceSize + macLength))
106+
// add message payload
107+
if (payload != default)
108+
{
109+
sshDataStream.Write(payload.Array!, payload.Offset, payload.Count);
110+
}
111+
else
123112
{
124-
// skip bytes for outbound packet sequenceSize
125-
_ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin);
126-
127-
// add packet data length
128-
sshDataStream.Write(packetDataLength);
129-
130-
// add packet padding length
131-
sshDataStream.WriteByte(paddingLength);
132-
133-
// add message payload
134113
WriteBytes(sshDataStream);
114+
}
135115

136-
sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength);
137-
138-
var buffer = sshDataStream.ToArray();
139-
140-
// add padding bytes
141-
CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
116+
Debug.Assert(sshDataStream.Position == bytesRequired - macLength - paddingLength);
142117

143-
return buffer;
144-
}
118+
// add padding bytes
119+
CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength);
145120
}
146-
}
147121

148-
private static uint GetPacketDataLength(int messageLength, byte paddingLength)
149-
{
150-
return (uint)(messageLength + paddingLength + 1);
122+
return bytesRequired;
151123
}
152124

153125
private static byte GetPaddingLength(byte paddingMultiplier, long packetLength)

src/Renci.SshNet/Session.cs

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ private uint InboundPacketSequence
212212
private Socket _socket;
213213

214214
private ArrayBuffer _receiveBuffer = new(4 * 1024);
215+
private byte[] _sendBuffer = new byte[4 * 1024];
215216

216217
/// <summary>
217218
/// Gets the session semaphore that controls session channels.
@@ -1073,29 +1074,29 @@ internal void SendMessage(Message message)
10731074
macLength = _clientMac.HashSize / 8;
10741075
}
10751076

1076-
var packetData = message.GetPacket(paddingMultiplier, _clientCompression, _clientEtm || _clientAead, macLength);
1077-
1078-
if (packetData.Length > MaximumSshPacketSize)
1079-
{
1080-
throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", MaximumSshPacketSize));
1081-
}
1082-
10831077
// take a write lock to ensure the outbound packet sequence number is incremented
10841078
// atomically, and only after the packet has actually been sent
10851079
lock (_socketWriteLock)
10861080
{
1081+
var activeBufferLength = message.GetPacket(
1082+
ref _sendBuffer,
1083+
paddingMultiplier,
1084+
_clientCompression,
1085+
_clientEtm || _clientAead,
1086+
macLength);
1087+
10871088
// write outbound packet sequence to start of packet data
1088-
BinaryPrimitives.WriteUInt32BigEndian(packetData, _outboundPacketSequence);
1089+
BinaryPrimitives.WriteUInt32BigEndian(_sendBuffer, _outboundPacketSequence);
10891090

10901091
if (_clientMac != null && !_clientEtm)
10911092
{
10921093
// non-ETM mac = MAC(key, sequence_number || unencrypted_packet)
10931094

10941095
var hashSuccess = _clientMac.TryComputeHash(
1095-
buffer: packetData,
1096+
buffer: _sendBuffer,
10961097
offset: 0,
1097-
count: packetData.Length - macLength,
1098-
destination: packetData.AsSpan(packetData.Length - macLength),
1098+
count: activeBufferLength - macLength,
1099+
destination: _sendBuffer.AsSpan(activeBufferLength - macLength),
10991100
bytesWritten: out var bytesWritten);
11001101

11011102
Debug.Assert(hashSuccess && bytesWritten == macLength);
@@ -1110,30 +1111,30 @@ internal void SendMessage(Message message)
11101111
var offset = _clientEtm ? 8 : 4;
11111112

11121113
var numberOfBytesEncrypted = _clientCipher.Encrypt(
1113-
input: packetData,
1114+
input: _sendBuffer,
11141115
offset,
1115-
length: packetData.Length - offset - macLength,
1116-
output: packetData,
1116+
length: activeBufferLength - offset - macLength,
1117+
output: _sendBuffer,
11171118
outputOffset: offset);
11181119

1119-
Debug.Assert(numberOfBytesEncrypted == packetData.Length - offset - macLength + (_clientAead ? macLength : 0));
1120+
Debug.Assert(numberOfBytesEncrypted == activeBufferLength - offset - macLength + (_clientAead ? macLength : 0));
11201121
}
11211122

11221123
if (_clientMac != null && _clientEtm)
11231124
{
11241125
// ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet)
11251126

11261127
var hashSuccess = _clientMac.TryComputeHash(
1127-
buffer: packetData,
1128+
buffer: _sendBuffer,
11281129
offset: 0,
1129-
count: packetData.Length - macLength,
1130-
destination: packetData.AsSpan(packetData.Length - macLength),
1130+
count: activeBufferLength - macLength,
1131+
destination: _sendBuffer.AsSpan(activeBufferLength - macLength),
11311132
bytesWritten: out var bytesWritten);
11321133

11331134
Debug.Assert(hashSuccess && bytesWritten == macLength);
11341135
}
11351136

1136-
SendPacket(packetData, 4, packetData.Length - 4);
1137+
SendPacket(_sendBuffer, 4, activeBufferLength - 4);
11371138

11381139
if (_isStrictKex && message is NewKeysMessage)
11391140
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
using Renci.SshNet.Messages.Connection;
1313
using Renci.SshNet.Messages.Transport;
14+
using Renci.SshNet.Tests.Common;
1415

1516
namespace Renci.SshNet.Tests.Classes
1617
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_StrictKex.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
using Renci.SshNet.Common;
88
using Renci.SshNet.Messages.Transport;
9+
using Renci.SshNet.Tests.Common;
910

1011
namespace Renci.SshNet.Tests.Classes
1112
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsDisconnectMessageAfterKexInit_StrictKex.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using Renci.SshNet.Common;
66
using Renci.SshNet.Messages.Transport;
7+
using Renci.SshNet.Tests.Common;
78

89
namespace Renci.SshNet.Tests.Classes
910
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_NoStrictKex.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.VisualStudio.TestTools.UnitTesting;
44

55
using Renci.SshNet.Messages.Transport;
6+
using Renci.SshNet.Tests.Common;
67

78
namespace Renci.SshNet.Tests.Classes
89
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageAfterKexInit_StrictKex.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using Renci.SshNet.Common;
66
using Renci.SshNet.Messages.Transport;
7+
using Renci.SshNet.Tests.Common;
78

89
namespace Renci.SshNet.Tests.Classes
910
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_NoStrictKex.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.VisualStudio.TestTools.UnitTesting;
44

55
using Renci.SshNet.Messages.Transport;
6+
using Renci.SshNet.Tests.Common;
67

78
namespace Renci.SshNet.Tests.Classes
89
{

test/Renci.SshNet.Tests/Classes/SessionTest_Connecting_ServerSendsIgnoreMessageBeforeKexInit_StrictKex.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using Renci.SshNet.Common;
66
using Renci.SshNet.Messages.Transport;
7+
using Renci.SshNet.Tests.Common;
78

89
namespace Renci.SshNet.Tests.Classes
910
{

test/Renci.SshNet.Tests/Common/Extensions.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
using System.Collections.Generic;
1+
#nullable enable
2+
using System;
3+
using System.Collections.Generic;
24

35
using Renci.SshNet.Common;
6+
using Renci.SshNet.Compression;
7+
using Renci.SshNet.Messages;
48

59
namespace Renci.SshNet.Tests.Common
610
{
@@ -21,5 +25,22 @@ public static string AsString(this IList<ExceptionEventArgs> exceptionEvents)
2125

2226
return reportedExceptions;
2327
}
28+
29+
/// <returns>[4 bytes] || packet_len || padding_len || payload || padding.</returns>
30+
public static byte[] GetPacket(this Message message, byte paddingMultiplier, Compressor? compressor)
31+
{
32+
var buffer = Array.Empty<byte>();
33+
34+
var byteCount = message.GetPacket(
35+
ref buffer,
36+
paddingMultiplier,
37+
compressor,
38+
excludePacketLengthFieldWhenPadding: false,
39+
macLength: 0);
40+
41+
Array.Resize(ref buffer, byteCount);
42+
43+
return buffer;
44+
}
2445
}
2546
}

0 commit comments

Comments
 (0)