|
| 1 | +#nullable enable |
| 2 | + |
| 3 | +using System; |
| 4 | +using System.Collections.Concurrent; |
| 5 | +using System.Collections.Generic; |
| 6 | +using System.Net; |
| 7 | +using System.Threading; |
| 8 | +using Basis.Contrib.Crypto; |
| 9 | +using LiteNetLib.Layers; |
| 10 | + |
| 11 | +namespace Basis.Network.Core |
| 12 | +{ |
| 13 | + /// Per-endpoint AEAD encryption applied at the LiteNetLib socket boundary. |
| 14 | + /// Each connection has its own pair of ChaCha20-Poly1305 keys (one per |
| 15 | + /// direction) established by an X25519 handshake; see <see cref="BasisCryptoHandshake"/>. |
| 16 | + /// |
| 17 | + /// Only the user-data-bearing packet properties are encrypted (Unreliable, |
| 18 | + /// Channeled, Merged). Connection setup, NAT, MTU and out-of-band probe packets |
| 19 | + /// stay cleartext so the handshake itself never depends on a key being present. |
| 20 | + /// |
| 21 | + /// Wire layout of an encrypted datagram: |
| 22 | + /// [byte 0 : LiteNetLib header (cleartext, authenticated as AAD)] |
| 23 | + /// [bytes 1..n : ciphertext] |
| 24 | + /// [16 bytes : Poly1305 tag] |
| 25 | + /// [8 bytes : little-endian nonce counter] |
| 26 | + public sealed class BasisCryptoLayer : PacketLayerBase |
| 27 | + { |
| 28 | + public const int CounterSize = 8; |
| 29 | + public const int Overhead = BasisAeadCipher.TagSize + CounterSize; |
| 30 | + |
| 31 | + private const byte PropertyMask = 0x1F; |
| 32 | + // Mirrors LiteNetLib.PacketProperty: Unreliable = 0, Channeled = 1, Merged = 12. |
| 33 | + private const byte PropUnreliable = 0; |
| 34 | + private const byte PropChanneled = 1; |
| 35 | + private const byte PropMerged = 12; |
| 36 | + |
| 37 | + private sealed class Session |
| 38 | + { |
| 39 | + public BasisAeadCipher Send = null!; |
| 40 | + public BasisAeadCipher Recv = null!; |
| 41 | + public long SendCounter; |
| 42 | + } |
| 43 | + |
| 44 | + // Keyed by address+port only. NetPeer (seen outbound) and the plain IPEndPoint seen |
| 45 | + // inbound / at install hash differently under non-native sockets; this comparer makes |
| 46 | + // all three resolve to the same session without allocating per packet. |
| 47 | + private readonly ConcurrentDictionary<IPEndPoint, Session> _sessions |
| 48 | + = new ConcurrentDictionary<IPEndPoint, Session>(EndpointComparer.Instance); |
| 49 | + |
| 50 | + public BasisCryptoLayer() : base(Overhead) { } |
| 51 | + |
| 52 | + private sealed class EndpointComparer : IEqualityComparer<IPEndPoint> |
| 53 | + { |
| 54 | + public static readonly EndpointComparer Instance = new EndpointComparer(); |
| 55 | + |
| 56 | + public bool Equals(IPEndPoint x, IPEndPoint y) |
| 57 | + { |
| 58 | + if (ReferenceEquals(x, y)) return true; |
| 59 | + if (x is null || y is null) return false; |
| 60 | + return x.Port == y.Port && x.Address.Equals(y.Address); |
| 61 | + } |
| 62 | + |
| 63 | + public int GetHashCode(IPEndPoint ep) |
| 64 | + { |
| 65 | + if (ep is null) return 0; |
| 66 | + unchecked { return (ep.Address.GetHashCode() * 397) ^ ep.Port; } |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + public int SessionCount => _sessions.Count; |
| 71 | + |
| 72 | + /// <param name="initialSendCounter"> |
| 73 | + /// First nonce counter to use. Pass a value strictly greater than any counter |
| 74 | + /// previously used with these keys when re-installing the same keys for a |
| 75 | + /// reconnect, so a (key, nonce) pair is never reused. |
| 76 | + /// </param> |
| 77 | + public void SetEndpointKeys(IPEndPoint endpoint, byte[] sendKey, byte[] recvKey, long initialSendCounter = 0) |
| 78 | + { |
| 79 | + if (endpoint == null) return; |
| 80 | + var session = new Session |
| 81 | + { |
| 82 | + Send = new BasisAeadCipher(sendKey), |
| 83 | + Recv = new BasisAeadCipher(recvKey), |
| 84 | + SendCounter = initialSendCounter |
| 85 | + }; |
| 86 | + if (_sessions.TryRemove(endpoint, out var old)) DisposeSession(old); |
| 87 | + _sessions[endpoint] = session; |
| 88 | + } |
| 89 | + |
| 90 | + public bool HasEndpoint(IPEndPoint endpoint) => endpoint != null && _sessions.ContainsKey(endpoint); |
| 91 | + |
| 92 | + public void RemoveEndpoint(IPEndPoint endpoint) |
| 93 | + { |
| 94 | + if (endpoint != null && _sessions.TryRemove(endpoint, out var session)) DisposeSession(session); |
| 95 | + } |
| 96 | + |
| 97 | + public void RemapEndpoint(IPEndPoint oldEndpoint, IPEndPoint newEndpoint) |
| 98 | + { |
| 99 | + if (oldEndpoint == null || newEndpoint == null) return; |
| 100 | + if (_sessions.TryRemove(oldEndpoint, out var session)) _sessions[newEndpoint] = session; |
| 101 | + } |
| 102 | + |
| 103 | + public override void ProcessOutBoundPacket(ref IPEndPoint endPoint, ref byte[] data, ref int offset, ref int length) |
| 104 | + { |
| 105 | + if (length < 1) return; |
| 106 | + byte header = data[offset]; |
| 107 | + if (!IsEncryptable((byte)(header & PropertyMask))) return; |
| 108 | + if (endPoint == null || !_sessions.TryGetValue(endPoint, out var session)) return; |
| 109 | + |
| 110 | + long counter = Interlocked.Increment(ref session.SendCounter); |
| 111 | + Span<byte> nonce = stackalloc byte[BasisAeadCipher.NonceSize]; |
| 112 | + WriteCounter(nonce, counter); |
| 113 | + |
| 114 | + int tagOffset = offset + length; |
| 115 | + session.Send.Seal(nonce, header, data, offset + 1, length - 1, data, tagOffset); |
| 116 | + WriteCounterBytes(data, tagOffset + BasisAeadCipher.TagSize, counter); |
| 117 | + length += Overhead; |
| 118 | + } |
| 119 | + |
| 120 | + public override void ProcessInboundPacket(ref IPEndPoint endPoint, ref byte[] data, ref int length) |
| 121 | + { |
| 122 | + if (length < 1) return; |
| 123 | + byte header = data[0]; |
| 124 | + if (!IsEncryptable((byte)(header & PropertyMask))) return; |
| 125 | + if (endPoint == null || !_sessions.TryGetValue(endPoint, out var session)) return; |
| 126 | + |
| 127 | + if (length < 1 + Overhead) |
| 128 | + { |
| 129 | + length = 0; |
| 130 | + return; |
| 131 | + } |
| 132 | + |
| 133 | + int tagOffset = length - Overhead; |
| 134 | + int counterOffset = length - CounterSize; |
| 135 | + long counter = ReadCounterBytes(data, counterOffset); |
| 136 | + Span<byte> nonce = stackalloc byte[BasisAeadCipher.NonceSize]; |
| 137 | + WriteCounter(nonce, counter); |
| 138 | + |
| 139 | + int payloadLength = tagOffset - 1; |
| 140 | + if (!session.Recv.Open(nonce, header, data, 1, payloadLength, data, tagOffset)) |
| 141 | + { |
| 142 | + length = 0; |
| 143 | + return; |
| 144 | + } |
| 145 | + length -= Overhead; |
| 146 | + } |
| 147 | + |
| 148 | + private static bool IsEncryptable(byte property) |
| 149 | + => property == PropUnreliable || property == PropChanneled || property == PropMerged; |
| 150 | + |
| 151 | + private static void DisposeSession(Session session) |
| 152 | + { |
| 153 | + session.Send.Dispose(); |
| 154 | + session.Recv.Dispose(); |
| 155 | + } |
| 156 | + |
| 157 | + private static void WriteCounter(Span<byte> nonce, long counter) |
| 158 | + { |
| 159 | + nonce.Clear(); |
| 160 | + ulong c = (ulong)counter; |
| 161 | + nonce[0] = (byte)c; |
| 162 | + nonce[1] = (byte)(c >> 8); |
| 163 | + nonce[2] = (byte)(c >> 16); |
| 164 | + nonce[3] = (byte)(c >> 24); |
| 165 | + nonce[4] = (byte)(c >> 32); |
| 166 | + nonce[5] = (byte)(c >> 40); |
| 167 | + nonce[6] = (byte)(c >> 48); |
| 168 | + nonce[7] = (byte)(c >> 56); |
| 169 | + } |
| 170 | + |
| 171 | + private static void WriteCounterBytes(byte[] buffer, int offset, long counter) |
| 172 | + { |
| 173 | + ulong c = (ulong)counter; |
| 174 | + buffer[offset] = (byte)c; |
| 175 | + buffer[offset + 1] = (byte)(c >> 8); |
| 176 | + buffer[offset + 2] = (byte)(c >> 16); |
| 177 | + buffer[offset + 3] = (byte)(c >> 24); |
| 178 | + buffer[offset + 4] = (byte)(c >> 32); |
| 179 | + buffer[offset + 5] = (byte)(c >> 40); |
| 180 | + buffer[offset + 6] = (byte)(c >> 48); |
| 181 | + buffer[offset + 7] = (byte)(c >> 56); |
| 182 | + } |
| 183 | + |
| 184 | + private static long ReadCounterBytes(byte[] buffer, int offset) |
| 185 | + { |
| 186 | + ulong c = buffer[offset] |
| 187 | + | ((ulong)buffer[offset + 1] << 8) |
| 188 | + | ((ulong)buffer[offset + 2] << 16) |
| 189 | + | ((ulong)buffer[offset + 3] << 24) |
| 190 | + | ((ulong)buffer[offset + 4] << 32) |
| 191 | + | ((ulong)buffer[offset + 5] << 40) |
| 192 | + | ((ulong)buffer[offset + 6] << 48) |
| 193 | + | ((ulong)buffer[offset + 7] << 56); |
| 194 | + return (long)c; |
| 195 | + } |
| 196 | + } |
| 197 | +} |
0 commit comments