|
1 | 1 | #nullable enable |
| 2 | +using System; |
| 3 | +using System.Diagnostics; |
| 4 | +using System.Globalization; |
2 | 5 | using System.IO; |
3 | 6 |
|
4 | 7 | using Renci.SshNet.Abstractions; |
@@ -38,116 +41,85 @@ protected override void WriteBytes(SshDataStream stream) |
38 | 41 | base.WriteBytes(stream); |
39 | 42 | } |
40 | 43 |
|
41 | | - /// <returns>[4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes].</returns> |
42 | | - internal byte[] GetPacket(byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding = false, int macLength = 0) |
| 44 | + /// <returns>The number of bytes occupied by the packet in <paramref name="buffer"/>.</returns> |
| 45 | + /// <remarks> |
| 46 | + /// [4 bytes] || packet_len || padding_len || payload || padding || [macLength bytes]. |
| 47 | + /// </remarks> |
| 48 | + internal int GetPacket(ref byte[] buffer, byte paddingMultiplier, Compressor? compressor, bool excludePacketLengthFieldWhenPadding, int macLength) |
43 | 49 | { |
44 | 50 | const int outboundPacketSequenceSize = 4; |
45 | 51 |
|
46 | 52 | var messageLength = BufferCapacity; |
47 | 53 |
|
| 54 | + ArraySegment<byte> payload = default; |
| 55 | + |
48 | 56 | if (messageLength == -1 || compressor != null) |
49 | 57 | { |
50 | | - using (var sshDataStream = new SshDataStream(DefaultCapacity)) |
| 58 | + using (var sshDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity)) |
51 | 59 | { |
52 | | - // skip: |
53 | | - // * 4 bytes for the outbound packet sequence |
54 | | - // * 4 bytes for the packet data length |
55 | | - // * one byte for the packet padding length |
56 | | - _ = sshDataStream.Seek(outboundPacketSequenceSize + 4 + 1, SeekOrigin.Begin); |
57 | | - |
58 | | - if (compressor != null) |
59 | | - { |
60 | | - // obtain uncompressed message payload |
61 | | - using (var uncompressedDataStream = new SshDataStream(messageLength != -1 ? messageLength : DefaultCapacity)) |
62 | | - { |
63 | | - WriteBytes(uncompressedDataStream); |
64 | | - |
65 | | - // compress message payload |
66 | | - var compressedMessageData = compressor.Compress(uncompressedDataStream.ToArray()); |
67 | | - |
68 | | - // add compressed message payload |
69 | | - sshDataStream.Write(compressedMessageData, 0, compressedMessageData.Length); |
70 | | - } |
71 | | - } |
72 | | - else |
73 | | - { |
74 | | - // add message payload |
75 | | - WriteBytes(sshDataStream); |
76 | | - } |
77 | | - |
78 | | - messageLength = (int)sshDataStream.Length - (outboundPacketSequenceSize + 4 + 1); |
79 | | - |
80 | | - var packetLength = messageLength + 4 + 1; |
81 | | - |
82 | | - // determine the padding length |
83 | | - // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the |
84 | | - // padding length calculation |
85 | | - var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength); |
| 60 | + WriteBytes(sshDataStream); |
86 | 61 |
|
87 | | - var packetDataLength = GetPacketDataLength(messageLength, paddingLength); |
| 62 | + var success = sshDataStream.TryGetBuffer(out payload); |
88 | 63 |
|
89 | | - // skip bytes for outbound packet sequence |
90 | | - _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin); |
| 64 | + Debug.Assert(success); |
| 65 | + } |
91 | 66 |
|
92 | | - // add packet data length |
93 | | - sshDataStream.Write(packetDataLength); |
| 67 | + if (compressor != null) |
| 68 | + { |
| 69 | + payload = new(compressor.Compress(payload.Array, payload.Offset, payload.Count)); |
| 70 | + } |
94 | 71 |
|
95 | | - // add packet padding length |
96 | | - sshDataStream.WriteByte(paddingLength); |
| 72 | + messageLength = payload.Count; |
| 73 | + } |
97 | 74 |
|
98 | | - _ = sshDataStream.Seek(0, SeekOrigin.End); |
| 75 | + // determine the padding length |
| 76 | + // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the |
| 77 | + // padding length calculation |
| 78 | + var paddingLength = GetPaddingLength( |
| 79 | + paddingMultiplier, (excludePacketLengthFieldWhenPadding ? 0 : 4) + 1 + messageLength); |
99 | 80 |
|
100 | | - sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength); |
| 81 | + var packetLength = 1 + messageLength + paddingLength; |
101 | 82 |
|
102 | | - var buffer = sshDataStream.ToArray(); |
| 83 | + var bytesRequired = 4 + 4 + packetLength + macLength; |
103 | 84 |
|
104 | | - // add padding bytes |
105 | | - CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength); |
| 85 | + if ((uint)bytesRequired > (uint)Session.MaximumSshPacketSize) |
| 86 | + { |
| 87 | + throw new InvalidOperationException(string.Format(CultureInfo.CurrentCulture, "Packet is too big. Maximum packet size is {0} bytes.", Session.MaximumSshPacketSize)); |
| 88 | + } |
106 | 89 |
|
107 | | - return buffer; |
108 | | - } |
| 90 | + if (buffer.Length < bytesRequired) |
| 91 | + { |
| 92 | + Array.Resize(ref buffer, Math.Max(bytesRequired, 2 * buffer.Length)); |
109 | 93 | } |
110 | | - else |
| 94 | + |
| 95 | + using (var sshDataStream = new SshDataStream(buffer)) |
111 | 96 | { |
112 | | - var packetLength = messageLength + 4 + 1; |
| 97 | + // skip bytes for outbound packet sequenceSize |
| 98 | + _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin); |
113 | 99 |
|
114 | | - // determine the padding length |
115 | | - // in Encrypt-then-MAC mode or AEAD, the length field is not encrypted, so we should keep it out of the |
116 | | - // padding length calculation |
117 | | - var paddingLength = GetPaddingLength(paddingMultiplier, excludePacketLengthFieldWhenPadding ? packetLength - 4 : packetLength); |
| 100 | + // add packet length |
| 101 | + sshDataStream.Write((uint)packetLength); |
118 | 102 |
|
119 | | - var packetDataLength = GetPacketDataLength(messageLength, paddingLength); |
| 103 | + // add padding length |
| 104 | + sshDataStream.WriteByte(paddingLength); |
120 | 105 |
|
121 | | - // lets construct an SSH data stream of the exact size required |
122 | | - using (var sshDataStream = new SshDataStream(packetLength + paddingLength + outboundPacketSequenceSize + macLength)) |
| 106 | + // add message payload |
| 107 | + if (payload != default) |
| 108 | + { |
| 109 | + sshDataStream.Write(payload.Array!, payload.Offset, payload.Count); |
| 110 | + } |
| 111 | + else |
123 | 112 | { |
124 | | - // skip bytes for outbound packet sequenceSize |
125 | | - _ = sshDataStream.Seek(outboundPacketSequenceSize, SeekOrigin.Begin); |
126 | | - |
127 | | - // add packet data length |
128 | | - sshDataStream.Write(packetDataLength); |
129 | | - |
130 | | - // add packet padding length |
131 | | - sshDataStream.WriteByte(paddingLength); |
132 | | - |
133 | | - // add message payload |
134 | 113 | WriteBytes(sshDataStream); |
| 114 | + } |
135 | 115 |
|
136 | | - sshDataStream.SetLength(sshDataStream.Length + paddingLength + macLength); |
137 | | - |
138 | | - var buffer = sshDataStream.ToArray(); |
139 | | - |
140 | | - // add padding bytes |
141 | | - CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength); |
| 116 | + Debug.Assert(sshDataStream.Position == bytesRequired - macLength - paddingLength); |
142 | 117 |
|
143 | | - return buffer; |
144 | | - } |
| 118 | + // add padding bytes |
| 119 | + CryptoAbstraction.Randomizer.GetBytes(buffer, (int)sshDataStream.Position, paddingLength); |
145 | 120 | } |
146 | | - } |
147 | 121 |
|
148 | | - private static uint GetPacketDataLength(int messageLength, byte paddingLength) |
149 | | - { |
150 | | - return (uint)(messageLength + paddingLength + 1); |
| 122 | + return bytesRequired; |
151 | 123 | } |
152 | 124 |
|
153 | 125 | private static byte GetPaddingLength(byte paddingMultiplier, long packetLength) |
|
0 commit comments