|
| 1 | +using Lagrange.Proto.Generator.Entity; |
| 2 | +using Lagrange.Proto.Generator.Utility; |
| 3 | +using Lagrange.Proto.Generator.Utility.Extension; |
| 4 | +using Lagrange.Proto.Serialization; |
| 5 | +using Microsoft.CodeAnalysis; |
| 6 | + |
| 7 | +namespace Lagrange.Proto.Generator; |
| 8 | + |
| 9 | +public partial class ProtoSourceGenerator |
| 10 | +{ |
| 11 | + private partial class Emitter |
| 12 | + { |
| 13 | + private const string ProtoReaderTypeRef = "global::Lagrange.Proto.Primitives.ProtoReader"; |
| 14 | + private const string ProtoSerializerTypeRef = "global::Lagrange.Proto.Serialization.ProtoSerializer"; |
| 15 | + |
| 16 | + private const string ReaderVarName = "reader"; |
| 17 | + |
| 18 | + private const string DecodeVarIntMethodName = "DecodeVarInt"; |
| 19 | + private const string DecodeVarIntUnsafeMethodName = "DecodeVarIntUnsafe"; |
| 20 | + private const string DecodeFixed32MethodName = "DecodeFixed32"; |
| 21 | + private const string DecodeFixed64MethodName = "DecodeFixed64"; |
| 22 | + private const string CreateSpanMethodName = "CreateSpan"; |
| 23 | + private const string SkipFieldMethodName = "SkipField"; |
| 24 | + |
| 25 | + private const string ZigZagDecodeMethodRef = $"{ProtoHelperTypeRef}.ZigZagDecode"; |
| 26 | + |
| 27 | + private void EmitDeserializeMethod(SourceWriter source) |
| 28 | + { |
| 29 | + source.WriteLine($"public static void DeserializeHandler({_fullQualifiedName} {ObjectVarName}, ref {ProtoReaderTypeRef} {ReaderVarName})"); |
| 30 | + source.WriteLine("{"); |
| 31 | + source.Indentation++; |
| 32 | + |
| 33 | + source.WriteLine($"while (!{ReaderVarName}.IsCompleted)"); |
| 34 | + source.WriteLine("{"); |
| 35 | + source.Indentation++; |
| 36 | + |
| 37 | + source.WriteLine($"uint tag = {ReaderVarName}.{DecodeVarIntUnsafeMethodName}<uint>();"); |
| 38 | + source.WriteLine("switch (tag)"); |
| 39 | + source.WriteLine("{"); |
| 40 | + source.Indentation++; |
| 41 | + |
| 42 | + foreach (var kv in parser.Fields) |
| 43 | + { |
| 44 | + int field = kv.Key; |
| 45 | + var info = kv.Value; |
| 46 | + |
| 47 | + EmitDeserializeCase(source, field, info); |
| 48 | + } |
| 49 | + |
| 50 | + // Default case for unknown fields |
| 51 | + source.WriteLine("default:"); |
| 52 | + source.Indentation++; |
| 53 | + source.WriteLine($"{ReaderVarName}.{SkipFieldMethodName}(({WireTypeTypeRef})(tag & 0x07));"); |
| 54 | + source.WriteLine("break;"); |
| 55 | + source.Indentation--; |
| 56 | + |
| 57 | + source.Indentation--; |
| 58 | + source.WriteLine("}"); // end switch |
| 59 | + |
| 60 | + source.Indentation--; |
| 61 | + source.WriteLine("}"); // end while |
| 62 | + |
| 63 | + source.Indentation--; |
| 64 | + source.WriteLine("}"); // end method |
| 65 | + } |
| 66 | + |
| 67 | + private void EmitDeserializeCase(SourceWriter source, int field, ProtoFieldInfo info) |
| 68 | + { |
| 69 | + uint tag = (uint)field << 3 | (byte)info.WireType; |
| 70 | + source.WriteLine($"case {tag}:"); |
| 71 | + source.Indentation++; |
| 72 | + |
| 73 | + EmitDeserializeMember(source, field, info); |
| 74 | + |
| 75 | + source.WriteLine("break;"); |
| 76 | + source.Indentation--; |
| 77 | + } |
| 78 | + |
| 79 | + private void EmitDeserializeMember(SourceWriter source, int field, ProtoFieldInfo info) |
| 80 | + { |
| 81 | + // For Map types and Repeated types, fall back to TypeInfo.Fields[tag].Read() |
| 82 | + if (SymbolResolver.IsMapType(info.TypeSymbol, out _, out _) || |
| 83 | + SymbolResolver.IsRepeatedType(info.TypeSymbol, out _) || |
| 84 | + SymbolResolver.IsNodesType(info.TypeSymbol)) |
| 85 | + { |
| 86 | + uint tag = (uint)field << 3 | (byte)info.WireType; |
| 87 | + source.WriteLine($"{TypeInfoPropertyName}.Fields[{tag}].Read(ref {ReaderVarName}, {ObjectVarName});"); |
| 88 | + return; |
| 89 | + } |
| 90 | + |
| 91 | + string memberName = $"{ObjectVarName}.{info.Symbol.Name}"; |
| 92 | + var typeSymbol = info.TypeSymbol; |
| 93 | + |
| 94 | + // Handle nullable types - unwrap to underlying type |
| 95 | + bool isNullable = typeSymbol.IsValueType && typeSymbol.IsNullable(); |
| 96 | + if (isNullable && SymbolResolver.IsNullableType(typeSymbol, out var underlyingType)) |
| 97 | + { |
| 98 | + typeSymbol = underlyingType; |
| 99 | + } |
| 100 | + |
| 101 | + // Handle based on wire type and type |
| 102 | + switch (info.WireType) |
| 103 | + { |
| 104 | + case WireType.VarInt: |
| 105 | + EmitVarIntRead(source, memberName, typeSymbol, info.IsSigned); |
| 106 | + break; |
| 107 | + |
| 108 | + case WireType.Fixed32: |
| 109 | + EmitFixed32Read(source, memberName, typeSymbol, info.IsSigned); |
| 110 | + break; |
| 111 | + |
| 112 | + case WireType.Fixed64: |
| 113 | + EmitFixed64Read(source, memberName, typeSymbol, info.IsSigned); |
| 114 | + break; |
| 115 | + |
| 116 | + case WireType.LengthDelimited: |
| 117 | + EmitLengthDelimitedRead(source, memberName, typeSymbol, info, field); |
| 118 | + break; |
| 119 | + |
| 120 | + default: |
| 121 | + // Fall back to TypeInfo for unknown wire types |
| 122 | + uint tag = (uint)field << 3 | (byte)info.WireType; |
| 123 | + source.WriteLine($"{TypeInfoPropertyName}.Fields[{tag}].Read(ref {ReaderVarName}, {ObjectVarName});"); |
| 124 | + break; |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + private void EmitVarIntRead(SourceWriter source, string memberName, ITypeSymbol typeSymbol, bool isSigned) |
| 129 | + { |
| 130 | + // Handle bool specially |
| 131 | + if (typeSymbol.SpecialType == SpecialType.System_Boolean) |
| 132 | + { |
| 133 | + source.WriteLine("{"); |
| 134 | + source.Indentation++; |
| 135 | + source.WriteLine($"byte __b = {ReaderVarName}.{DecodeVarIntMethodName}<byte>();"); |
| 136 | + source.WriteLine($"{memberName} = global::System.Runtime.CompilerServices.Unsafe.As<byte, bool>(ref __b);"); |
| 137 | + source.Indentation--; |
| 138 | + source.WriteLine("}"); |
| 139 | + return; |
| 140 | + } |
| 141 | + |
| 142 | + // Handle enums |
| 143 | + if (typeSymbol.TypeKind == TypeKind.Enum) |
| 144 | + { |
| 145 | + var underlyingType = ((INamedTypeSymbol)typeSymbol).EnumUnderlyingType; |
| 146 | + string enumTypeName = typeSymbol.GetFullName(); |
| 147 | + string underlyingTypeName = GetDecodeTypeName(underlyingType!); |
| 148 | + source.WriteLine($"{memberName} = ({enumTypeName}){ReaderVarName}.{DecodeVarIntMethodName}<{underlyingTypeName}>();"); |
| 149 | + return; |
| 150 | + } |
| 151 | + |
| 152 | + // Handle signed integers with ZigZag decoding |
| 153 | + if (isSigned && typeSymbol.IsIntegerType()) |
| 154 | + { |
| 155 | + string unsignedType = GetUnsignedTypeName(typeSymbol); |
| 156 | + string signedType = GetDecodeTypeName(typeSymbol); |
| 157 | + source.WriteLine($"{memberName} = {ZigZagDecodeMethodRef}(({signedType}){ReaderVarName}.{DecodeVarIntMethodName}<{unsignedType}>());"); |
| 158 | + return; |
| 159 | + } |
| 160 | + |
| 161 | + // Regular integer types |
| 162 | + string typeName = GetDecodeTypeName(typeSymbol); |
| 163 | + source.WriteLine($"{memberName} = {ReaderVarName}.{DecodeVarIntMethodName}<{typeName}>();"); |
| 164 | + } |
| 165 | + |
| 166 | + private void EmitFixed32Read(SourceWriter source, string memberName, ITypeSymbol typeSymbol, bool isSigned) |
| 167 | + { |
| 168 | + if (isSigned && typeSymbol.IsIntegerType()) |
| 169 | + { |
| 170 | + string unsignedType = GetUnsignedTypeName(typeSymbol); |
| 171 | + string signedType = GetDecodeTypeName(typeSymbol); |
| 172 | + source.WriteLine($"{memberName} = {ZigZagDecodeMethodRef}(({signedType}){ReaderVarName}.{DecodeFixed32MethodName}<{unsignedType}>());"); |
| 173 | + return; |
| 174 | + } |
| 175 | + |
| 176 | + string typeName = GetDecodeTypeName(typeSymbol); |
| 177 | + source.WriteLine($"{memberName} = {ReaderVarName}.{DecodeFixed32MethodName}<{typeName}>();"); |
| 178 | + } |
| 179 | + |
| 180 | + private void EmitFixed64Read(SourceWriter source, string memberName, ITypeSymbol typeSymbol, bool isSigned) |
| 181 | + { |
| 182 | + if (isSigned && typeSymbol.IsIntegerType()) |
| 183 | + { |
| 184 | + string unsignedType = GetUnsignedTypeName(typeSymbol); |
| 185 | + string signedType = GetDecodeTypeName(typeSymbol); |
| 186 | + source.WriteLine($"{memberName} = {ZigZagDecodeMethodRef}(({signedType}){ReaderVarName}.{DecodeFixed64MethodName}<{unsignedType}>());"); |
| 187 | + return; |
| 188 | + } |
| 189 | + |
| 190 | + string typeName = GetDecodeTypeName(typeSymbol); |
| 191 | + source.WriteLine($"{memberName} = {ReaderVarName}.{DecodeFixed64MethodName}<{typeName}>();"); |
| 192 | + } |
| 193 | + |
| 194 | + private void EmitLengthDelimitedRead(SourceWriter source, string memberName, ITypeSymbol typeSymbol, ProtoFieldInfo info, int field) |
| 195 | + { |
| 196 | + // String |
| 197 | + if (typeSymbol.SpecialType == SpecialType.System_String) |
| 198 | + { |
| 199 | + source.WriteLine("{"); |
| 200 | + source.Indentation++; |
| 201 | + source.WriteLine($"int __len = {ReaderVarName}.{DecodeVarIntMethodName}<int>();"); |
| 202 | + source.WriteLine($"var __span = {ReaderVarName}.{CreateSpanMethodName}(__len);"); |
| 203 | + source.WriteLine($"{memberName} = __span.IsEmpty ? string.Empty : global::System.Text.Encoding.UTF8.GetString(__span);"); |
| 204 | + source.Indentation--; |
| 205 | + source.WriteLine("}"); |
| 206 | + return; |
| 207 | + } |
| 208 | + |
| 209 | + // byte[] |
| 210 | + if (typeSymbol is IArrayTypeSymbol { ElementType.SpecialType: SpecialType.System_Byte }) |
| 211 | + { |
| 212 | + source.WriteLine("{"); |
| 213 | + source.Indentation++; |
| 214 | + source.WriteLine($"int __len = {ReaderVarName}.{DecodeVarIntMethodName}<int>();"); |
| 215 | + source.WriteLine("if (__len == 0)"); |
| 216 | + source.WriteLine("{"); |
| 217 | + source.Indentation++; |
| 218 | + source.WriteLine($"{memberName} = global::System.Array.Empty<byte>();"); |
| 219 | + source.Indentation--; |
| 220 | + source.WriteLine("}"); |
| 221 | + source.WriteLine("else"); |
| 222 | + source.WriteLine("{"); |
| 223 | + source.Indentation++; |
| 224 | + source.WriteLine("var __buffer = global::System.GC.AllocateUninitializedArray<byte>(__len);"); |
| 225 | + source.WriteLine($"var __span = {ReaderVarName}.{CreateSpanMethodName}(__len);"); |
| 226 | + source.WriteLine("__span.CopyTo(__buffer);"); |
| 227 | + source.WriteLine($"{memberName} = __buffer;"); |
| 228 | + source.Indentation--; |
| 229 | + source.WriteLine("}"); |
| 230 | + source.Indentation--; |
| 231 | + source.WriteLine("}"); |
| 232 | + return; |
| 233 | + } |
| 234 | + |
| 235 | + // Memory<byte> |
| 236 | + if (typeSymbol is INamedTypeSymbol { Name: "Memory", IsGenericType: true } memoryType && |
| 237 | + memoryType.TypeArguments[0].SpecialType == SpecialType.System_Byte) |
| 238 | + { |
| 239 | + source.WriteLine("{"); |
| 240 | + source.Indentation++; |
| 241 | + source.WriteLine($"int __len = {ReaderVarName}.{DecodeVarIntMethodName}<int>();"); |
| 242 | + source.WriteLine("if (__len == 0)"); |
| 243 | + source.WriteLine("{"); |
| 244 | + source.Indentation++; |
| 245 | + source.WriteLine($"{memberName} = global::System.Memory<byte>.Empty;"); |
| 246 | + source.Indentation--; |
| 247 | + source.WriteLine("}"); |
| 248 | + source.WriteLine("else"); |
| 249 | + source.WriteLine("{"); |
| 250 | + source.Indentation++; |
| 251 | + source.WriteLine("var __buffer = global::System.GC.AllocateUninitializedArray<byte>(__len);"); |
| 252 | + source.WriteLine($"var __span = {ReaderVarName}.{CreateSpanMethodName}(__len);"); |
| 253 | + source.WriteLine("__span.CopyTo(__buffer);"); |
| 254 | + source.WriteLine($"{memberName} = __buffer;"); |
| 255 | + source.Indentation--; |
| 256 | + source.WriteLine("}"); |
| 257 | + source.Indentation--; |
| 258 | + source.WriteLine("}"); |
| 259 | + return; |
| 260 | + } |
| 261 | + |
| 262 | + // ReadOnlyMemory<byte> |
| 263 | + if (typeSymbol is INamedTypeSymbol { Name: "ReadOnlyMemory", IsGenericType: true } readOnlyMemoryType && |
| 264 | + readOnlyMemoryType.TypeArguments[0].SpecialType == SpecialType.System_Byte) |
| 265 | + { |
| 266 | + source.WriteLine("{"); |
| 267 | + source.Indentation++; |
| 268 | + source.WriteLine($"int __len = {ReaderVarName}.{DecodeVarIntMethodName}<int>();"); |
| 269 | + source.WriteLine("if (__len == 0)"); |
| 270 | + source.WriteLine("{"); |
| 271 | + source.Indentation++; |
| 272 | + source.WriteLine($"{memberName} = global::System.ReadOnlyMemory<byte>.Empty;"); |
| 273 | + source.Indentation--; |
| 274 | + source.WriteLine("}"); |
| 275 | + source.WriteLine("else"); |
| 276 | + source.WriteLine("{"); |
| 277 | + source.Indentation++; |
| 278 | + source.WriteLine("var __buffer = global::System.GC.AllocateUninitializedArray<byte>(__len);"); |
| 279 | + source.WriteLine($"var __span = {ReaderVarName}.{CreateSpanMethodName}(__len);"); |
| 280 | + source.WriteLine("__span.CopyTo(__buffer);"); |
| 281 | + source.WriteLine($"{memberName} = __buffer;"); |
| 282 | + source.Indentation--; |
| 283 | + source.WriteLine("}"); |
| 284 | + source.Indentation--; |
| 285 | + source.WriteLine("}"); |
| 286 | + return; |
| 287 | + } |
| 288 | + |
| 289 | + // Nested IProtoSerializable type |
| 290 | + if (SymbolResolver.IsProtoPackable(typeSymbol)) |
| 291 | + { |
| 292 | + string typeName = typeSymbol.GetFullName(); |
| 293 | + source.WriteLine("{"); |
| 294 | + source.Indentation++; |
| 295 | + source.WriteLine($"int __len = {ReaderVarName}.{DecodeVarIntMethodName}<int>();"); |
| 296 | + source.WriteLine($"var __span = {ReaderVarName}.{CreateSpanMethodName}(__len);"); |
| 297 | + source.WriteLine($"{memberName} = {ProtoSerializerTypeRef}.DeserializeProtoPackable<{typeName}>(__span);"); |
| 298 | + source.Indentation--; |
| 299 | + source.WriteLine("}"); |
| 300 | + return; |
| 301 | + } |
| 302 | + |
| 303 | + // Fall back to TypeInfo.Fields for other complex types |
| 304 | + uint tag = (uint)field << 3 | (byte)info.WireType; |
| 305 | + source.WriteLine($"{TypeInfoPropertyName}.Fields[{tag}].Read(ref {ReaderVarName}, {ObjectVarName});"); |
| 306 | + } |
| 307 | + |
| 308 | + private static string GetDecodeTypeName(ITypeSymbol typeSymbol) |
| 309 | + { |
| 310 | + return typeSymbol.SpecialType switch |
| 311 | + { |
| 312 | + SpecialType.System_Byte => "byte", |
| 313 | + SpecialType.System_SByte => "sbyte", |
| 314 | + SpecialType.System_Int16 => "short", |
| 315 | + SpecialType.System_UInt16 => "ushort", |
| 316 | + SpecialType.System_Int32 => "int", |
| 317 | + SpecialType.System_UInt32 => "uint", |
| 318 | + SpecialType.System_Int64 => "long", |
| 319 | + SpecialType.System_UInt64 => "ulong", |
| 320 | + SpecialType.System_Single => "float", |
| 321 | + SpecialType.System_Double => "double", |
| 322 | + _ => typeSymbol.GetFullName() |
| 323 | + }; |
| 324 | + } |
| 325 | + |
| 326 | + private static string GetUnsignedTypeName(ITypeSymbol typeSymbol) |
| 327 | + { |
| 328 | + return typeSymbol.SpecialType switch |
| 329 | + { |
| 330 | + SpecialType.System_SByte => "byte", |
| 331 | + SpecialType.System_Int16 => "ushort", |
| 332 | + SpecialType.System_Int32 => "uint", |
| 333 | + SpecialType.System_Int64 => "ulong", |
| 334 | + SpecialType.System_Byte => "byte", |
| 335 | + SpecialType.System_UInt16 => "ushort", |
| 336 | + SpecialType.System_UInt32 => "uint", |
| 337 | + SpecialType.System_UInt64 => "ulong", |
| 338 | + _ => "uint" |
| 339 | + }; |
| 340 | + } |
| 341 | + } |
| 342 | +} |
0 commit comments