diff --git a/Makefile b/Makefile index d1d85445..2f668da3 100644 --- a/Makefile +++ b/Makefile @@ -43,14 +43,14 @@ help: @echo "" @echo "Current binaries to build: $(BINARIES)" -# Build all binaries +# Build all binaries. This always rebuilds because generator implementations live +# under internal/*, not just cmd/*, and stale binaries break golden tests. .PHONY: build -build: $(BINARY_PATHS) - -# Pattern rule to build each binary -$(BIN_DIR)/%: $(CMD_DIR)/%/*.go | $(BIN_DIR) - @echo "Building $*..." - @go build -o $@ ./$(CMD_DIR)/$* +build: | $(BIN_DIR) + @for binary in $(BINARIES); do \ + echo "Building $$binary..."; \ + go build -o $(BIN_DIR)/$$binary ./$(CMD_DIR)/$$binary; \ + done # Create bin directory $(BIN_DIR): @@ -285,4 +285,4 @@ ci-validate: else \ echo "actionlint not found. Install with: go install github.com/rhysd/actionlint/cmd/actionlint@latest"; \ fi; \ - done \ No newline at end of file + done diff --git a/README.md b/README.md index 135a9dcf..e101e6b4 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ This starts a working HTTP API with JSON endpoints and OpenAPI docs - all genera |-----------|--------| | `protoc-gen-go-http` | Go HTTP servers with routing, request binding, validation, and error handling | | `protoc-gen-go-client` | Go HTTP clients with type safety, header helpers, and per-call options | +| `protoc-gen-csharp-http` | C# contracts and `HttpClient` service clients for typed SDKs and integrations | | `protoc-gen-ts-client` | TypeScript HTTP clients with type safety, header helpers, and per-call options | | `protoc-gen-ts-server` | TypeScript HTTP servers with routing, request binding, validation, and error handling — runs on Node, Deno, Bun, Cloudflare Workers | | `protoc-gen-openapiv3` | OpenAPI v3.1 specs that stay in sync with your code, one file per service | @@ -135,6 +136,7 @@ UserService.openapi.yaml # Install the tools go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-go-http@latest go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-go-client@latest +go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-csharp-http@latest go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-openapiv3@latest go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-ts-client@latest go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-ts-server@latest @@ -180,6 +182,7 @@ sebuf is used at [Sarwa](https://www.sarwa.co/), the fastest-growing investment - **[Complete Tutorial](./examples/simple-api/)** - Full walkthrough with working code - **[Documentation](./docs/)** - Comprehensive guides and API reference +- **[C# Contract Generation](./docs/csharp-generation.md)** - C# plugin options, supported annotations, and examples - **[More Examples](./docs/examples/)** - Additional patterns and use cases ## Built on Great Tools diff --git a/cmd/protoc-gen-csharp-http/main.go b/cmd/protoc-gen-csharp-http/main.go new file mode 100644 index 00000000..edbe805e --- /dev/null +++ b/cmd/protoc-gen-csharp-http/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/types/pluginpb" + + "github.com/SebastienMelki/sebuf/internal/csharpgen" +) + +func main() { + options, cfg := csharpgen.NewOptions() + options.Run(func(plugin *protogen.Plugin) error { + plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) + gen := csharpgen.New(plugin, *cfg) + return gen.Generate() + }) +} diff --git a/docs/csharp-generation.md b/docs/csharp-generation.md new file mode 100644 index 00000000..ed26329f --- /dev/null +++ b/docs/csharp-generation.md @@ -0,0 +1,143 @@ +# C# HTTP Client Generation + +> **Generate C# contracts and `HttpClient` service clients from protobuf services** + +`protoc-gen-csharp-http` generates C# contract types and `HttpClient`-based service clients from annotated protobuf packages. It is designed for SDKs, typed API integrations, and shared contracts where C# needs the same JSON-facing shape and HTTP calling surface as other sebuf generators. + +## Quick Start + +### Installation + +```bash +go install github.com/SebastienMelki/sebuf/cmd/protoc-gen-csharp-http@latest +``` + +### Buf Configuration + +Add the plugin to `buf.gen.yaml`: + +```yaml +version: v2 +plugins: + - local: protoc-gen-csharp-http + out: gen/csharp + opt: + - namespace=Acme.Contracts + - json_lib=system_text_json +``` + +### Protoc Usage + +```bash +protoc \ + --plugin=protoc-gen-csharp-http="$(go env GOPATH)/bin/protoc-gen-csharp-http" \ + --csharp-http_out=gen/csharp \ + --csharp-http_opt=namespace=Acme.Contracts,json_lib=newtonsoft \ + --proto_path=. \ + --proto_path=./proto \ + proto/example/v1/service.proto +``` + +## Generated Output + +For each generated package, the plugin emits one `Contracts.g.cs` file containing: + +- C# `enum` types for protobuf enums +- C# classes for protobuf messages +- `I{Service}Client` and `{Service}Client` types built on `HttpClient` +- `{Service}ClientOptions` and `{Service}CallOptions` for headers and transport configuration +- `ApiException` for non-success responses +- `ServiceContracts` metadata with service name, base path, HTTP method, route, request type, and response type per RPC + +Nested protobuf messages and enums are flattened into idiomatic C# names such as `WidgetProfile` and `WidgetState`. + +## Supported Options + +### Generator Options + +- `namespace` + Sets the C# namespace. Default: `Sebuf.Generated` +- `json_lib` + Chooses JSON attributes and converters. Supported values: + - `newtonsoft` + - `system_text_json` + +## JSON Contract Behavior + +The generator reflects the JSON-facing contract shape for the supported annotations below. + +### Field and Message Shape + +- `flatten` + Flattens child message fields into the parent contract, honoring `flatten_prefix` +- `oneof_config` + Emits discriminator properties and flattened discriminated-union fields when configured +- `unwrap` + Root unwrap messages generate collection-shaped contracts such as `List`, and map-value unwrap is preserved during client request/response serialization +- `nullable` + Uses nullable C# reference/value types where the JSON contract can be `null` +- `empty_behavior` + Uses nullable contract fields for `NULL` and `OMIT` empty-message behavior + +### Value Encoding + +- `int64_encoding` + Maps `int64` JSON number encoding to `long`; otherwise uses `string` +- `enum_encoding` + Supports numeric enums or string enums with JSON converters +- `enum_value` + Applies custom string values via `[EnumMember(Value = "...")]` +- `timestamp_format` + Maps timestamp fields to `string` or `long` depending on configured format +- `bytes_encoding` + Represents bytes fields as `byte[]` and re-encodes on the wire for `hex`, `base64_raw`, `base64url`, and `base64url_raw` + +## Example + +Proto: + +```proto +message Widget { + optional string display_name = 1 [(sebuf.http.nullable) = true]; + Profile profile = 2 [(sebuf.http.flatten) = true, (sebuf.http.flatten_prefix) = "meta_"]; + + message Profile { + string note = 1; + } +} +``` + +Generated C#: + +```csharp +public sealed class Widget +{ + [JsonProperty("display_name")] + public string? DisplayName { get; set; } + + [JsonProperty("meta_note")] + public string? MetaNote { get; set; } +} +``` + +## Client Runtime + +For each protobuf service, the generator emits: + +- `IWidgetServiceClient` +- `WidgetServiceClient` +- `WidgetServiceClientOptions` +- `WidgetServiceCallOptions` + +Generated clients: + +- use `HttpClient` +- build paths from annotated route params +- add annotated query params for `GET` / `DELETE` +- apply service-level and method-level headers +- serialize request bodies as JSON +- deserialize JSON responses into generated contracts +- preserve `unwrap` and `bytes_encoding` wire behavior +- throw `ApiException` for non-2xx responses + +See [examples/csharp-contracts-demo](../examples/csharp-contracts-demo/) for a working generation example. diff --git a/docs/examples/README.md b/docs/examples/README.md index cde6d20d..632a49d1 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -28,6 +28,7 @@ This starts a working HTTP API with user management, authentication, and OpenAPI | **[nested-resources](../../examples/nested-resources/)** | Organization hierarchy API | Deep path nesting (3 levels), multiple path params per endpoint | | **[multi-service-api](../../examples/multi-service-api/)** | Multi-tenant platform | Multiple services, different auth levels, service/method headers | | **[market-data-unwrap](../../examples/market-data-unwrap/)** | Financial market data API | Unwrap annotation for map values, JSON/protobuf compatibility | +| **[csharp-contracts-demo](../../examples/csharp-contracts-demo/)** | C# contract generation demo | C# contracts, flattened fields, oneof discriminator metadata, root unwrap | | **[ts-client-demo](../../examples/ts-client-demo/)** | TypeScript client demo | TypeScript HTTP client, CRUD API, query params, headers, error handling | | **[ts-fullstack-demo](../../examples/ts-fullstack-demo/)** | TypeScript full-stack demo | TS client + TS server from same proto, CRUD, unwrap, custom errors | @@ -139,6 +140,16 @@ cd examples/ts-client-demo && make demo **Prerequisites**: Node.js (for the TypeScript client) +### csharp-contracts-demo +HTTP client generation example for `protoc-gen-csharp-http`. +- Generates C# contracts plus `HttpClient` service clients +- Shows `flatten`, `nullable`, `oneof_config`, `unwrap`, `bytes_encoding`, and service route metadata +- Supports both `newtonsoft` and `System.Text.Json` output + +```bash +cd examples/csharp-contracts-demo && make generate +``` + ### ts-fullstack-demo Full TypeScript stack: both client and server generated from the same proto. - Generated TypeScript server from `protoc-gen-ts-server` (Web Fetch API) diff --git a/examples/csharp-contracts-demo/.gitignore b/examples/csharp-contracts-demo/.gitignore new file mode 100644 index 00000000..e8e450be --- /dev/null +++ b/examples/csharp-contracts-demo/.gitignore @@ -0,0 +1 @@ +gen/ diff --git a/examples/csharp-contracts-demo/Makefile b/examples/csharp-contracts-demo/Makefile new file mode 100644 index 00000000..03cb3e4d --- /dev/null +++ b/examples/csharp-contracts-demo/Makefile @@ -0,0 +1,25 @@ +PROTO_DIR := proto +OUT_DIR := gen + +.PHONY: generate +generate: + @mkdir -p $(OUT_DIR)/newtonsoft $(OUT_DIR)/system-text-json + @protoc \ + --plugin=protoc-gen-csharp-http=../../bin/protoc-gen-csharp-http \ + --proto_path=$(PROTO_DIR) \ + --proto_path=../../proto \ + --csharp-http_out=$(OUT_DIR)/newtonsoft \ + --csharp-http_opt=namespace=Demo.Contracts,json_lib=newtonsoft \ + $(PROTO_DIR)/contracts.proto + @protoc \ + --plugin=protoc-gen-csharp-http=../../bin/protoc-gen-csharp-http \ + --proto_path=$(PROTO_DIR) \ + --proto_path=../../proto \ + --csharp-http_out=$(OUT_DIR)/system-text-json \ + --csharp-http_opt=namespace=Demo.Contracts,json_lib=system_text_json \ + $(PROTO_DIR)/contracts.proto + @echo "Generated C# contracts into $(OUT_DIR)/" + +.PHONY: clean +clean: + @rm -rf $(OUT_DIR) diff --git a/examples/csharp-contracts-demo/README.md b/examples/csharp-contracts-demo/README.md new file mode 100644 index 00000000..909ee464 --- /dev/null +++ b/examples/csharp-contracts-demo/README.md @@ -0,0 +1,37 @@ +# C# HTTP Client Demo + +This example shows how to generate C# contracts and `HttpClient` service clients with `protoc-gen-csharp-http`. + +## What It Covers + +- flattened message fields with `flatten` and `flatten_prefix` +- discriminated oneofs with `oneof_config` +- nullable contract fields +- root unwrap collection contracts +- generated `I{Service}Client` / `{Service}Client` types +- request/response JSON handling for `unwrap` and `bytes_encoding` +- service route metadata +- both `newtonsoft` and `System.Text.Json` output modes + +## Generate + +```bash +cd examples/csharp-contracts-demo +make generate +``` + +Generated files: + +- `gen/newtonsoft/demo/contracts/v1/Contracts.g.cs` +- `gen/system-text-json/demo/contracts/v1/Contracts.g.cs` + +Each generated file includes: + +- message and enum contracts +- service clients and per-call options +- `ApiException` +- `ServiceContracts` metadata + +## Proto + +The example proto lives at [proto/contracts.proto](./proto/contracts.proto). diff --git a/examples/csharp-contracts-demo/proto/contracts.proto b/examples/csharp-contracts-demo/proto/contracts.proto new file mode 100644 index 00000000..09150646 --- /dev/null +++ b/examples/csharp-contracts-demo/proto/contracts.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package demo.contracts.v1; + +option go_package = "github.com/SebastienMelki/sebuf/examples/csharp-contracts-demo/gen/go;contracts"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; +import "sebuf/http/annotations.proto"; + +message Product { + string id = 1; + optional string display_name = 2 [(sebuf.http.nullable) = true]; + Metadata metadata = 3 [(sebuf.http.flatten) = true, (sebuf.http.flatten_prefix) = "meta_"]; + google.protobuf.Timestamp updated_at = 4 [(sebuf.http.timestamp_format) = TIMESTAMP_FORMAT_UNIX_MILLIS]; + + message Metadata { + string owner = 1; + } +} + +message ProductEvent { + oneof payload { + option (sebuf.http.oneof_config) = { + discriminator: "kind" + flatten: true + }; + + Created created = 1; + Deleted deleted = 2 [(sebuf.http.oneof_value) = "removed"]; + } + + message Created { + string product_id = 1; + } + + message Deleted { + string product_id = 1; + string reason = 2; + } +} + +message ProductIds { + repeated string values = 1 [(sebuf.http.unwrap) = true]; +} + +message ArchiveProductRequest { + string id = 1; +} + +service ProductService { + option (sebuf.http.service_config) = { + base_path: "/api/v1" + }; + + rpc ArchiveProduct(ArchiveProductRequest) returns (google.protobuf.Empty) { + option (sebuf.http.config) = { + method: HTTP_METHOD_POST + path: "/products/{id}:archive" + }; + } +} diff --git a/internal/clientgen/golden_test.go b/internal/clientgen/golden_test.go index f9dff7bb..a999a979 100644 --- a/internal/clientgen/golden_test.go +++ b/internal/clientgen/golden_test.go @@ -5,8 +5,9 @@ import ( "os" "os/exec" "path/filepath" - "strings" "testing" + + "github.com/SebastienMelki/sebuf/internal/testutil" ) // TestClientGenGoldenFiles tests HTTP client generation against golden files. @@ -239,53 +240,8 @@ func compareGoldenFile(t *testing.T, expectedFile, goldenPath string, generatedC "Run with UPDATE_GOLDEN=1 to update golden files after reviewing changes.\n"+ "Diff:\n%s", expectedFile, - diffStrings(string(goldenContent), string(generatedContent))) - } -} - -// diffStrings returns a simple diff between two strings. -func diffStrings(expected, actual string) string { - expectedLines := strings.Split(expected, "\n") - actualLines := strings.Split(actual, "\n") - - var diff strings.Builder - maxLines := len(expectedLines) - if len(actualLines) > maxLines { - maxLines = len(actualLines) + testutil.DiffStrings(string(goldenContent), string(generatedContent))) } - - diffCount := 0 - const maxDiffs = 20 - - for i := 0; i < maxLines && diffCount < maxDiffs; i++ { - var expLine, actLine string - if i < len(expectedLines) { - expLine = expectedLines[i] - } - if i < len(actualLines) { - actLine = actualLines[i] - } - - if expLine != actLine { - diff.WriteString("Line ") - diff.WriteRune(rune('0' + i/100)) - diff.WriteRune(rune('0' + (i/10)%10)) - diff.WriteRune(rune('0' + i%10)) - diff.WriteString(":\n") - diff.WriteString(" expected: ") - diff.WriteString(expLine) - diff.WriteString("\n actual: ") - diff.WriteString(actLine) - diff.WriteString("\n") - diffCount++ - } - } - - if diffCount >= maxDiffs { - diff.WriteString("... (more differences truncated)\n") - } - - return diff.String() } // TestGeneratedClientCodeCompiles verifies that generated code compiles correctly. diff --git a/internal/contractmodel/model.go b/internal/contractmodel/model.go new file mode 100644 index 00000000..07b8dec4 --- /dev/null +++ b/internal/contractmodel/model.go @@ -0,0 +1,633 @@ +package contractmodel + +import ( + "sort" + "strings" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/reflect/protoreflect" + + sebufhttp "github.com/SebastienMelki/sebuf/http" + "github.com/SebastienMelki/sebuf/internal/annotations" +) + +const ( + AnyFullName protoreflect.FullName = "google.protobuf.Any" + DurationFullName protoreflect.FullName = "google.protobuf.Duration" + EmptyFullName protoreflect.FullName = "google.protobuf.Empty" + FieldMaskFullName protoreflect.FullName = "google.protobuf.FieldMask" + ListValueFullName protoreflect.FullName = "google.protobuf.ListValue" + StructFullName protoreflect.FullName = "google.protobuf.Struct" + TimestampFullName protoreflect.FullName = "google.protobuf.Timestamp" + ValueFullName protoreflect.FullName = "google.protobuf.Value" + DoubleValueName protoreflect.FullName = "google.protobuf.DoubleValue" + FloatValueName protoreflect.FullName = "google.protobuf.FloatValue" + Int64ValueName protoreflect.FullName = "google.protobuf.Int64Value" + UInt64ValueName protoreflect.FullName = "google.protobuf.UInt64Value" + Int32ValueName protoreflect.FullName = "google.protobuf.Int32Value" + UInt32ValueName protoreflect.FullName = "google.protobuf.UInt32Value" + BoolValueName protoreflect.FullName = "google.protobuf.BoolValue" + StringValueName protoreflect.FullName = "google.protobuf.StringValue" + BytesValueName protoreflect.FullName = "google.protobuf.BytesValue" +) + +type Kind int + +const ( + KindScalar Kind = iota + KindEnum + KindMessage + KindWellKnown + KindMap +) + +type WellKnownType string + +const ( + WellKnownAny WellKnownType = "any" + WellKnownDuration WellKnownType = "duration" + WellKnownEmpty WellKnownType = "empty" + WellKnownFieldMask WellKnownType = "field_mask" + WellKnownListValue WellKnownType = "list_value" + WellKnownStruct WellKnownType = "struct" + WellKnownTimestamp WellKnownType = "timestamp" + WellKnownValue WellKnownType = "value" + WellKnownDoubleWrap WellKnownType = "double_wrapper" + WellKnownFloatWrap WellKnownType = "float_wrapper" + WellKnownInt64Wrap WellKnownType = "int64_wrapper" + WellKnownUInt64Wrap WellKnownType = "uint64_wrapper" + WellKnownInt32Wrap WellKnownType = "int32_wrapper" + WellKnownUInt32Wrap WellKnownType = "uint32_wrapper" + WellKnownBoolWrap WellKnownType = "bool_wrapper" + WellKnownStringWrap WellKnownType = "string_wrapper" + WellKnownBytesWrap WellKnownType = "bytes_wrapper" +) + +type TypeRef struct { + Kind Kind + Name string + WellKnown WellKnownType + MapKey *TypeRef + MapValue *TypeRef +} + +type Query struct { + Name string + Required bool +} + +type Header struct { + Name string + Description string + Required bool +} + +type FieldAnnotations struct { + Query *Query + Unwrap bool + Int64Encoding sebufhttp.Int64Encoding + EnumEncoding sebufhttp.EnumEncoding + Nullable bool + EmptyBehavior sebufhttp.EmptyBehavior + TimestampFormat sebufhttp.TimestampFormat + BytesEncoding sebufhttp.BytesEncoding + Flatten bool + FlattenPrefix string + OneofValue string +} + +type Field struct { + Name string + JSONName string + Type *TypeRef + Repeated bool + Optional bool + HasPresence bool + IsMap bool + IsOneofVariant bool + OneofName string + Annotations FieldAnnotations +} + +type EnumValue struct { + Name string + JSONValue string + Number int32 +} + +type Enum struct { + Name string + ProtoName string + Values []*EnumValue +} + +type OneofVariant struct { + FieldName string + DiscriminatorValue string + Type *TypeRef + IsMessage bool +} + +type Oneof struct { + Name string + Discriminator string + Flatten bool + Variants []*OneofVariant +} + +type Unwrap struct { + FieldName string + IsRoot bool + IsMapField bool + ElementType *TypeRef +} + +type Message struct { + Name string + ProtoName string + Fields []*Field + Oneofs []*Oneof + Unwrap *Unwrap +} + +type Method struct { + Name string + InputType string + ResponseType string + HTTPMethod string + Path string + PathParams []string + Headers []*Header +} + +type Service struct { + Name string + BasePath string + Headers []*Header + Methods []*Method +} + +type Package struct { + Name string + SourceFiles []string + Enums []*Enum + Messages []*Message + Services []*Service +} + +type symbols struct { + messages map[protoreflect.FullName]string + enums map[protoreflect.FullName]string +} + +func Packages(files []*protogen.File) []*Package { + byPackage := make(map[string][]*protogen.File) + for _, file := range files { + if !file.Generate { + continue + } + pkg := string(file.Desc.Package()) + if pkg == "" { + pkg = "default" + } + byPackage[pkg] = append(byPackage[pkg], file) + } + + names := make([]string, 0, len(byPackage)) + for name := range byPackage { + names = append(names, name) + } + sort.Strings(names) + + packages := make([]*Package, 0, len(names)) + for _, name := range names { + filesForPackage := byPackage[name] + sort.Slice(filesForPackage, func(i, j int) bool { + return filesForPackage[i].Desc.Path() < filesForPackage[j].Desc.Path() + }) + + table := buildSymbols(filesForPackage) + pkg := &Package{ + Name: name, + SourceFiles: sourceFiles(filesForPackage), + } + pkg.Enums = collectEnums(filesForPackage, table) + pkg.Messages = collectMessages(filesForPackage, table) + pkg.Services = collectServices(filesForPackage, table) + packages = append(packages, pkg) + } + + return packages +} + +func sourceFiles(files []*protogen.File) []string { + result := make([]string, 0, len(files)) + for _, file := range files { + result = append(result, file.Desc.Path()) + } + return result +} + +func buildSymbols(files []*protogen.File) *symbols { + table := &symbols{ + messages: make(map[protoreflect.FullName]string), + enums: make(map[protoreflect.FullName]string), + } + for _, file := range files { + for _, enum := range file.Enums { + table.enums[enum.Desc.FullName()] = string(enum.Desc.Name()) + } + for _, msg := range file.Messages { + walkMessageSymbols(msg, nil, table) + } + } + return table +} + +func walkMessageSymbols(msg *protogen.Message, parents []string, table *symbols) { + name := append(append([]string{}, parents...), string(msg.Desc.Name())) + symbol := strings.Join(name, "") + if !msg.Desc.IsMapEntry() { + table.messages[msg.Desc.FullName()] = symbol + } + for _, enum := range msg.Enums { + table.enums[enum.Desc.FullName()] = symbol + string(enum.Desc.Name()) + } + for _, nested := range msg.Messages { + walkMessageSymbols(nested, name, table) + } +} + +func collectEnums(files []*protogen.File, table *symbols) []*Enum { + var result []*Enum + for _, file := range files { + for _, enum := range file.Enums { + result = append(result, buildEnum(enum, table)) + } + for _, msg := range file.Messages { + collectNestedEnums(msg, table, &result) + } + } + return result +} + +func collectNestedEnums(msg *protogen.Message, table *symbols, out *[]*Enum) { + for _, enum := range msg.Enums { + *out = append(*out, buildEnum(enum, table)) + } + for _, nested := range msg.Messages { + collectNestedEnums(nested, table, out) + } +} + +func buildEnum(enum *protogen.Enum, table *symbols) *Enum { + values := make([]*EnumValue, 0, len(enum.Values)) + for _, value := range enum.Values { + jsonValue := annotations.GetEnumValueMapping(value) + if jsonValue == "" { + jsonValue = string(value.Desc.Name()) + } + values = append(values, &EnumValue{ + Name: string(value.Desc.Name()), + JSONValue: jsonValue, + Number: int32(value.Desc.Number()), + }) + } + return &Enum{ + Name: table.enums[enum.Desc.FullName()], + ProtoName: string(enum.Desc.Name()), + Values: values, + } +} + +func collectMessages(files []*protogen.File, table *symbols) []*Message { + var result []*Message + for _, file := range files { + for _, msg := range file.Messages { + collectMessage(msg, table, &result) + } + } + return result +} + +func collectMessage(msg *protogen.Message, table *symbols, out *[]*Message) { + if !msg.Desc.IsMapEntry() { + queryByField := make(map[string]annotations.QueryParam) + for _, param := range annotations.GetQueryParams(msg) { + queryByField[param.FieldName] = param + } + + fields := make([]*Field, 0, len(msg.Fields)) + for _, field := range msg.Fields { + fields = append(fields, &Field{ + Name: string(field.Desc.Name()), + JSONName: field.Desc.JSONName(), + Type: resolveType(field, table), + Repeated: field.Desc.IsList() && !field.Desc.IsMap(), + Optional: field.Desc.HasOptionalKeyword(), + HasPresence: field.Desc.HasPresence(), + IsMap: field.Desc.IsMap(), + IsOneofVariant: field.Oneof != nil && !field.Oneof.Desc.IsSynthetic(), + OneofName: oneofName(field), + Annotations: fieldAnnotations(field, queryByField[string(field.Desc.Name())]), + }) + } + + message := &Message{ + Name: table.messages[msg.Desc.FullName()], + ProtoName: string(msg.Desc.Name()), + Fields: fields, + Oneofs: collectOneofs(msg, table), + Unwrap: collectUnwrap(msg, table), + } + *out = append(*out, message) + } + for _, nested := range msg.Messages { + collectMessage(nested, table, out) + } +} + +func oneofName(field *protogen.Field) string { + if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() { + return "" + } + return string(field.Oneof.Desc.Name()) +} + +func fieldAnnotations(field *protogen.Field, query annotations.QueryParam) FieldAnnotations { + result := FieldAnnotations{ + Unwrap: annotations.HasUnwrapAnnotation(field), + Int64Encoding: annotations.GetInt64Encoding(field), + EnumEncoding: annotations.GetEnumEncoding(field), + Nullable: annotations.IsNullableField(field), + EmptyBehavior: annotations.GetEmptyBehavior(field), + TimestampFormat: annotations.GetTimestampFormat(field), + BytesEncoding: annotations.GetBytesEncoding(field), + Flatten: annotations.IsFlattenField(field), + FlattenPrefix: annotations.GetFlattenPrefix(field), + OneofValue: annotations.GetOneofVariantValue(field), + } + if query.FieldName != "" { + result.Query = &Query{Name: query.ParamName, Required: query.Required} + } + return result +} + +func collectUnwrap(msg *protogen.Message, table *symbols) *Unwrap { + info, err := annotations.GetUnwrapField(msg) + if err != nil || info == nil { + return nil + } + + var elementType *TypeRef + if info.ElementType != nil { + elementType = resolveMessageType(info.ElementType, table) + } + + return &Unwrap{ + FieldName: string(info.Field.Desc.Name()), + IsRoot: info.IsRootUnwrap, + IsMapField: info.IsMapField, + ElementType: elementType, + } +} + +func collectOneofs(msg *protogen.Message, table *symbols) []*Oneof { + var result []*Oneof + for _, oneof := range msg.Oneofs { + if oneof.Desc.IsSynthetic() { + continue + } + + model := &Oneof{Name: string(oneof.Desc.Name())} + if info := annotations.GetOneofDiscriminatorInfo(oneof); info != nil { + model.Discriminator = info.Discriminator + model.Flatten = info.Flatten + for _, variant := range info.Variants { + model.Variants = append(model.Variants, &OneofVariant{ + FieldName: string(variant.Field.Desc.Name()), + DiscriminatorValue: variant.DiscriminatorVal, + Type: resolveType(variant.Field, table), + IsMessage: variant.IsMessage, + }) + } + } else { + for _, field := range oneof.Fields { + model.Variants = append(model.Variants, &OneofVariant{ + FieldName: string(field.Desc.Name()), + DiscriminatorValue: string(field.Desc.Name()), + Type: resolveType(field, table), + IsMessage: field.Message != nil, + }) + } + } + + result = append(result, model) + } + return result +} + +func collectServices(files []*protogen.File, table *symbols) []*Service { + var result []*Service + for _, file := range files { + for _, service := range file.Services { + basePath := annotations.GetServiceBasePath(service) + serviceHeaders := headersFromAnnotation(annotations.GetServiceHeaders(service)) + methods := make([]*Method, 0, len(service.Methods)) + for _, method := range service.Methods { + httpConfig := annotations.GetMethodHTTPConfig(method) + httpMethod := "POST" + fullPath := annotations.BuildHTTPPath(basePath, "") + var pathParams []string + if httpConfig != nil { + httpMethod = httpConfig.Method + fullPath = annotations.BuildHTTPPath(basePath, httpConfig.Path) + pathParams = append(pathParams, httpConfig.PathParams...) + } + + methods = append(methods, &Method{ + Name: method.GoName, + InputType: resolveMessageName(method.Input, table), + ResponseType: resolveMessageName(method.Output, table), + HTTPMethod: httpMethod, + Path: fullPath, + PathParams: pathParams, + Headers: headersFromAnnotation(annotations.GetMethodHeaders(method)), + }) + } + result = append(result, &Service{ + Name: service.GoName, + BasePath: basePath, + Headers: serviceHeaders, + Methods: methods, + }) + } + } + return result +} + +func headersFromAnnotation(headers []*sebufhttp.Header) []*Header { + result := make([]*Header, 0, len(headers)) + for _, header := range headers { + if header == nil || header.GetName() == "" { + continue + } + result = append(result, &Header{ + Name: header.GetName(), + Description: header.GetDescription(), + Required: header.GetRequired(), + }) + } + return result +} + +func resolveMessageName(message *protogen.Message, table *symbols) string { + if message == nil { + return "" + } + if name, ok := table.messages[message.Desc.FullName()]; ok { + return name + } + return string(message.Desc.Name()) +} + +func resolveMessageType(message *protogen.Message, table *symbols) *TypeRef { + return &TypeRef{Kind: KindMessage, Name: resolveMessageName(message, table)} +} + +func resolveType(field *protogen.Field, table *symbols) *TypeRef { + if field.Desc.IsMap() { + keyField := field.Message.Fields[0] + valueField := field.Message.Fields[1] + return &TypeRef{ + Kind: KindMap, + MapKey: scalarTypeRef(keyField.Desc.Kind()), + MapValue: resolveNonMapType(valueField, table), + } + } + return resolveNonMapType(field, table) +} + +func resolveNonMapType(field *protogen.Field, table *symbols) *TypeRef { + //nolint:exhaustive // scalar kinds intentionally fall through to scalarTypeRef in the default case + switch field.Desc.Kind() { + case protoreflect.EnumKind: + if name, ok := table.enums[field.Enum.Desc.FullName()]; ok { + return &TypeRef{Kind: KindEnum, Name: name} + } + return &TypeRef{Kind: KindEnum, Name: string(field.Enum.Desc.Name())} + case protoreflect.MessageKind, protoreflect.GroupKind: + if ref := wellKnownTypeRef(field.Message.Desc.FullName()); ref != nil { + return ref + } + return resolveMessageType(field.Message, table) + default: + return scalarTypeRef(field.Desc.Kind()) + } +} + +func wellKnownTypeRef(fullName protoreflect.FullName) *TypeRef { + switch fullName { + case AnyFullName: + return &TypeRef{Kind: KindWellKnown, Name: csharpFriendlyName(WellKnownAny), WellKnown: WellKnownAny} + case DurationFullName: + return &TypeRef{Kind: KindWellKnown, Name: csharpFriendlyName(WellKnownDuration), WellKnown: WellKnownDuration} + case EmptyFullName: + return &TypeRef{Kind: KindWellKnown, Name: csharpFriendlyName(WellKnownEmpty), WellKnown: WellKnownEmpty} + case FieldMaskFullName: + return &TypeRef{ + Kind: KindWellKnown, + Name: csharpFriendlyName(WellKnownFieldMask), + WellKnown: WellKnownFieldMask, + } + case ListValueFullName: + return &TypeRef{ + Kind: KindWellKnown, + Name: csharpFriendlyName(WellKnownListValue), + WellKnown: WellKnownListValue, + } + case StructFullName: + return &TypeRef{ + Kind: KindWellKnown, + Name: csharpFriendlyName(WellKnownStruct), + WellKnown: WellKnownStruct, + } + case TimestampFullName: + return &TypeRef{ + Kind: KindWellKnown, + Name: csharpFriendlyName(WellKnownTimestamp), + WellKnown: WellKnownTimestamp, + } + case ValueFullName: + return &TypeRef{ + Kind: KindWellKnown, + Name: csharpFriendlyName(WellKnownValue), + WellKnown: WellKnownValue, + } + case DoubleValueName: + return &TypeRef{Kind: KindWellKnown, Name: "double", WellKnown: WellKnownDoubleWrap} + case FloatValueName: + return &TypeRef{Kind: KindWellKnown, Name: "float", WellKnown: WellKnownFloatWrap} + case Int64ValueName: + return &TypeRef{Kind: KindWellKnown, Name: "int64", WellKnown: WellKnownInt64Wrap} + case UInt64ValueName: + return &TypeRef{Kind: KindWellKnown, Name: "uint64", WellKnown: WellKnownUInt64Wrap} + case Int32ValueName: + return &TypeRef{Kind: KindWellKnown, Name: "int32", WellKnown: WellKnownInt32Wrap} + case UInt32ValueName: + return &TypeRef{Kind: KindWellKnown, Name: "uint32", WellKnown: WellKnownUInt32Wrap} + case BoolValueName: + return &TypeRef{Kind: KindWellKnown, Name: "bool", WellKnown: WellKnownBoolWrap} + case StringValueName: + return &TypeRef{Kind: KindWellKnown, Name: "string", WellKnown: WellKnownStringWrap} + case BytesValueName: + return &TypeRef{Kind: KindWellKnown, Name: "bytes", WellKnown: WellKnownBytesWrap} + default: + return nil + } +} + +func csharpFriendlyName(kind WellKnownType) string { + switch kind { + case WellKnownAny: + return "Any" + case WellKnownDuration: + return "Duration" + case WellKnownEmpty: + return "Empty" + case WellKnownFieldMask: + return "FieldMask" + case WellKnownListValue: + return "ListValue" + case WellKnownStruct: + return "Struct" + case WellKnownTimestamp: + return "Timestamp" + case WellKnownValue: + return "Value" + case WellKnownDoubleWrap: + return "double" + case WellKnownFloatWrap: + return "float" + case WellKnownInt64Wrap: + return "int64" + case WellKnownUInt64Wrap: + return "uint64" + case WellKnownInt32Wrap: + return "int32" + case WellKnownUInt32Wrap: + return "uint32" + case WellKnownBoolWrap: + return "bool" + case WellKnownStringWrap: + return "string" + case WellKnownBytesWrap: + return "bytes" + default: + return string(kind) + } +} + +func scalarTypeRef(kind protoreflect.Kind) *TypeRef { + return &TypeRef{Kind: KindScalar, Name: strings.ToLower(kind.String())} +} diff --git a/internal/contractmodel/model_test.go b/internal/contractmodel/model_test.go new file mode 100644 index 00000000..c783ddbd --- /dev/null +++ b/internal/contractmodel/model_test.go @@ -0,0 +1,660 @@ +package contractmodel + +import ( + "slices" + "testing" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" + "google.golang.org/protobuf/types/pluginpb" + + sebufhttp "github.com/SebastienMelki/sebuf/http" +) + +func TestPackagesBuildsRichModel(t *testing.T) { + plugin := newContractModelPlugin(t) + + pkgs := Packages(plugin.Files) + if len(pkgs) != 1 { + t.Fatalf("Packages() returned %d packages, want 1", len(pkgs)) + } + + pkg := pkgs[0] + if got, want := pkg.Name, "test.contracts.v1"; got != want { + t.Fatalf("Package.Name = %q, want %q", got, want) + } + if got, want := pkg.SourceFiles, []string{"widget.proto", "widget_service.proto"}; !slices.Equal(got, want) { + t.Fatalf("Package.SourceFiles = %v, want %v", got, want) + } + + assertEnumModel(t, pkg) + assertWidgetModel(t, pkg) + assertOneofAndUnwrapModel(t, pkg) + assertServiceModel(t, pkg) +} + +func assertEnumModel(t *testing.T, pkg *Package) { + t.Helper() + widgetState := findEnum(t, pkg, "WidgetState") + if widgetState.ProtoName != "State" { + t.Fatalf("Enum.ProtoName = %q, want %q", widgetState.ProtoName, "State") + } + if got := widgetState.Values[1].JSONValue; got != "ready" { + t.Fatalf("Enum JSON mapping = %q, want %q", got, "ready") + } +} + +func assertWidgetModel(t *testing.T, pkg *Package) { + t.Helper() + widgetDetails := findMessage(t, pkg, "WidgetDetails") + if widgetDetails.ProtoName != "Details" { + t.Fatalf("Nested message proto name = %q, want %q", widgetDetails.ProtoName, "Details") + } + + widget := findMessage(t, pkg, "Widget") + displayName := findField(t, widget, "display_name") + if !displayName.Optional || !displayName.HasPresence { + t.Fatalf("display_name field should preserve optional presence: %+v", displayName) + } + if !displayName.Annotations.Nullable { + t.Fatalf("display_name field should carry nullable annotation") + } + + ownerID := findField(t, widget, "owner_id") + if ownerID.Annotations.Query == nil || + ownerID.Annotations.Query.Name != "owner" || + !ownerID.Annotations.Query.Required { + t.Fatalf("owner_id query annotation = %+v, want owner/required", ownerID.Annotations.Query) + } + + createdAt := findField(t, widget, "created_at") + if got := createdAt.Type.WellKnown; got != WellKnownTimestamp { + t.Fatalf("created_at WellKnown = %q, want %q", got, WellKnownTimestamp) + } + if got := createdAt.Annotations.TimestampFormat; got != sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_UNIX_MILLIS { + t.Fatalf("created_at TimestampFormat = %v, want UNIX_MILLIS", got) + } + + payload := findField(t, widget, "payload") + if got := payload.Annotations.BytesEncoding; got != sebufhttp.BytesEncoding_BYTES_ENCODING_HEX { + t.Fatalf("payload BytesEncoding = %v, want HEX", got) + } + + version := findField(t, widget, "version") + if got := version.Annotations.Int64Encoding; got != sebufhttp.Int64Encoding_INT64_ENCODING_NUMBER { + t.Fatalf("version Int64Encoding = %v, want NUMBER", got) + } + + state := findField(t, widget, "state") + if got := state.Type.Name; got != "WidgetState" { + t.Fatalf("state type = %q, want %q", got, "WidgetState") + } + if got := state.Annotations.EnumEncoding; got != sebufhttp.EnumEncoding_ENUM_ENCODING_NUMBER { + t.Fatalf("state EnumEncoding = %v, want NUMBER", got) + } + + profile := findField(t, widget, "profile") + if !profile.Annotations.Flatten || profile.Annotations.FlattenPrefix != "meta_" { + t.Fatalf("profile flatten annotations = %+v", profile.Annotations) + } +} + +func assertOneofAndUnwrapModel(t *testing.T, pkg *Package) { + t.Helper() + shapeHolder := findMessage(t, pkg, "ShapeHolder") + if len(shapeHolder.Oneofs) != 1 { + t.Fatalf("ShapeHolder.Oneofs = %d, want 1", len(shapeHolder.Oneofs)) + } + shape := shapeHolder.Oneofs[0] + if shape.Name != "shape" || shape.Discriminator != "kind" || !shape.Flatten { + t.Fatalf("ShapeHolder oneof = %+v, want named discriminated flatten oneof", shape) + } + if len(shape.Variants) != 2 || shape.Variants[0].DiscriminatorValue != "circle_shape" { + t.Fatalf("ShapeHolder variants = %+v", shape.Variants) + } + + tags := findMessage(t, pkg, "Tags") + if tags.Unwrap == nil || !tags.Unwrap.IsRoot || tags.Unwrap.FieldName != "items" { + t.Fatalf("Tags unwrap = %+v, want root unwrap on items", tags.Unwrap) + } +} + +func assertServiceModel(t *testing.T, pkg *Package) { + t.Helper() + service := findService(t, pkg, "WidgetService") + if service.BasePath != "/api/v1" { + t.Fatalf("Service.BasePath = %q, want %q", service.BasePath, "/api/v1") + } + if len(service.Headers) != 1 || service.Headers[0].Name != "X-API-Key" { + t.Fatalf("Service.Headers = %+v, want X-API-Key", service.Headers) + } + getWidget := findMethod(t, service, "GetWidget") + if getWidget.HTTPMethod != "GET" || getWidget.Path != "/api/v1/widgets/{id}" { + t.Fatalf("GetWidget metadata = %+v", getWidget) + } + if got, want := getWidget.PathParams, []string{"id"}; !slices.Equal(got, want) { + t.Fatalf("GetWidget.PathParams = %v, want %v", got, want) + } + if len(getWidget.Headers) != 1 || getWidget.Headers[0].Name != "X-Request-ID" { + t.Fatalf("GetWidget.Headers = %+v, want X-Request-ID", getWidget.Headers) + } +} + +func TestPackagesResolveWellKnownTypesAndCrossFileMessages(t *testing.T) { + plugin := newContractModelPlugin(t) + pkg := Packages(plugin.Files)[0] + + holder := findMessage(t, pkg, "WellKnownHolder") + cases := map[string]WellKnownType{ + "meta": WellKnownStruct, + "any_value": WellKnownAny, + "ttl": WellKnownDuration, + "raw_value": WellKnownValue, + "items": WellKnownListValue, + "mask": WellKnownFieldMask, + "label": WellKnownStringWrap, + } + for fieldName, want := range cases { + if got := findField(t, holder, fieldName).Type.WellKnown; got != want { + t.Fatalf("%s WellKnown = %q, want %q", fieldName, got, want) + } + } + + reset := findMethod(t, findService(t, pkg, "WidgetService"), "ResetWidget") + if reset.ResponseType != "Empty" { + t.Fatalf("ResetWidget.ResponseType = %q, want %q", reset.ResponseType, "Empty") + } + + getWidget := findMethod(t, findService(t, pkg, "WidgetService"), "GetWidget") + if getWidget.ResponseType != "Widget" { + t.Fatalf("cross-file response type = %q, want %q", getWidget.ResponseType, "Widget") + } +} + +func newContractModelPlugin(t *testing.T) *protogen.Plugin { + t.Helper() + + widgetFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("widget.proto"), + Package: proto.String("test.contracts.v1"), + Syntax: proto.String("proto3"), + Dependency: []string{ + "google/protobuf/struct.proto", + "google/protobuf/timestamp.proto", + "proto/sebuf/http/annotations.proto", + }, + Options: &descriptorpb.FileOptions{ + GoPackage: proto.String("github.com/SebastienMelki/sebuf/internal/testcontracts/widget;widgetpb"), + }, + MessageType: []*descriptorpb.DescriptorProto{ + widgetDescriptor(t), + { + Name: proto.String("Tags"), + Field: []*descriptorpb.FieldDescriptorProto{ + repeatedScalarField( + "items", + 1, + descriptorpb.FieldDescriptorProto_TYPE_STRING, + withFieldOption(t, sebufhttp.E_Unwrap, true), + ), + }, + }, + }, + } + + serviceFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("widget_service.proto"), + Package: proto.String("test.contracts.v1"), + Syntax: proto.String("proto3"), + Dependency: []string{ + "widget.proto", + "google/protobuf/any.proto", + "google/protobuf/duration.proto", + "google/protobuf/empty.proto", + "google/protobuf/field_mask.proto", + "google/protobuf/struct.proto", + "google/protobuf/wrappers.proto", + "proto/sebuf/http/annotations.proto", + }, + Options: &descriptorpb.FileOptions{ + GoPackage: proto.String( + "github.com/SebastienMelki/sebuf/internal/testcontracts/widgetservice;widgetservicepb", + ), + }, + MessageType: []*descriptorpb.DescriptorProto{ + shapeHolderDescriptor(t), + wellKnownHolderDescriptor(), + { + Name: proto.String("GetWidgetRequest"), + Field: []*descriptorpb.FieldDescriptorProto{ + scalarField("id", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("WidgetService"), + Options: withServiceOptions( + withServiceOption(t, sebufhttp.E_ServiceConfig, &sebufhttp.ServiceConfig{BasePath: "/api/v1"}), + withServiceOption(t, sebufhttp.E_ServiceHeaders, &sebufhttp.ServiceHeaders{ + RequiredHeaders: []*sebufhttp.Header{{Name: "X-API-Key", Required: true}}, + }), + ), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("GetWidget"), + InputType: proto.String(".test.contracts.v1.GetWidgetRequest"), + OutputType: proto.String(".test.contracts.v1.Widget"), + Options: withMethodOptions( + withMethodOption(t, sebufhttp.E_Config, &sebufhttp.HttpConfig{ + Path: "/widgets/{id}", + Method: sebufhttp.HttpMethod_HTTP_METHOD_GET, + }), + withMethodOption(t, sebufhttp.E_MethodHeaders, &sebufhttp.MethodHeaders{ + RequiredHeaders: []*sebufhttp.Header{{Name: "X-Request-ID", Required: true}}, + }), + ), + }, + { + Name: proto.String("ResetWidget"), + InputType: proto.String(".test.contracts.v1.GetWidgetRequest"), + OutputType: proto.String(".google.protobuf.Empty"), + Options: withMethodOption(t, sebufhttp.E_Config, &sebufhttp.HttpConfig{ + Path: "/widgets/{id}:reset", + Method: sebufhttp.HttpMethod_HTTP_METHOD_POST, + }), + }, + }, + }, + }, + } + + req := &pluginpb.CodeGeneratorRequest{ + Parameter: proto.String("paths=source_relative"), + FileToGenerate: []string{"widget.proto", "widget_service.proto"}, + ProtoFile: append( + []*descriptorpb.FileDescriptorProto{widgetFile, serviceFile}, + testDependencyProtos()..., + ), + } + + plugin, err := protogen.Options{}.New(req) + if err != nil { + t.Fatalf("protogen.Options.New() error = %v", err) + } + return plugin +} + +func widgetDescriptor(t *testing.T) *descriptorpb.DescriptorProto { + t.Helper() + return &descriptorpb.DescriptorProto{ + Name: proto.String("Widget"), + Field: []*descriptorpb.FieldDescriptorProto{ + scalarField("id", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + scalarField( + "owner_id", + 2, + descriptorpb.FieldDescriptorProto_TYPE_STRING, + withFieldOption(t, sebufhttp.E_Query, &sebufhttp.QueryConfig{Name: "owner", Required: true}), + ), + optionalScalarField( + "display_name", + 3, + descriptorpb.FieldDescriptorProto_TYPE_STRING, + withFieldOption(t, sebufhttp.E_Nullable, true), + ), + messageField( + "created_at", + 4, + ".google.protobuf.Timestamp", + withFieldOption( + t, + sebufhttp.E_TimestampFormat, + sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_UNIX_MILLIS, + ), + ), + scalarField( + "payload", + 5, + descriptorpb.FieldDescriptorProto_TYPE_BYTES, + withFieldOption(t, sebufhttp.E_BytesEncoding, sebufhttp.BytesEncoding_BYTES_ENCODING_HEX), + ), + scalarField( + "version", + 6, + descriptorpb.FieldDescriptorProto_TYPE_INT64, + withFieldOption(t, sebufhttp.E_Int64Encoding, sebufhttp.Int64Encoding_INT64_ENCODING_NUMBER), + ), + enumField( + "state", + 7, + ".test.contracts.v1.Widget.State", + withFieldOption(t, sebufhttp.E_EnumEncoding, sebufhttp.EnumEncoding_ENUM_ENCODING_NUMBER), + ), + messageField("details", 8, ".test.contracts.v1.Widget.Details"), + messageField( + "profile", + 9, + ".test.contracts.v1.Widget.Profile", + withFieldOption(t, sebufhttp.E_Flatten, true), + withFieldOption(t, sebufhttp.E_FlattenPrefix, "meta_"), + ), + }, + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Details"), + Field: []*descriptorpb.FieldDescriptorProto{ + scalarField("note", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + { + Name: proto.String("Profile"), + Field: []*descriptorpb.FieldDescriptorProto{ + scalarField("label", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("State"), + Value: []*descriptorpb.EnumValueDescriptorProto{ + {Name: proto.String("STATE_UNSPECIFIED"), Number: proto.Int32(0)}, + { + Name: proto.String("STATE_READY"), + Number: proto.Int32(1), + Options: withEnumValueOption(t, sebufhttp.E_EnumValue, "ready"), + }, + }, + }, + }, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + {Name: proto.String("_display_name")}, + }, + } +} + +func shapeHolderDescriptor(t *testing.T) *descriptorpb.DescriptorProto { + t.Helper() + return &descriptorpb.DescriptorProto{ + Name: proto.String("ShapeHolder"), + Field: []*descriptorpb.FieldDescriptorProto{ + messageFieldWithOneof( + "circle", + 1, + ".test.contracts.v1.ShapeHolder.Circle", + 0, + withFieldOption(t, sebufhttp.E_OneofValue, "circle_shape"), + ), + messageFieldWithOneof("rectangle", 2, ".test.contracts.v1.ShapeHolder.Rectangle", 0), + }, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + { + Name: proto.String("shape"), + Options: withOneofOption( + t, + sebufhttp.E_OneofConfig, + &sebufhttp.OneofConfig{Discriminator: "kind", Flatten: true}, + ), + }, + }, + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Circle"), + Field: []*descriptorpb.FieldDescriptorProto{ + scalarField("radius", 1, descriptorpb.FieldDescriptorProto_TYPE_DOUBLE), + }, + }, + { + Name: proto.String("Rectangle"), + Field: []*descriptorpb.FieldDescriptorProto{ + scalarField("width", 1, descriptorpb.FieldDescriptorProto_TYPE_DOUBLE), + scalarField("height", 2, descriptorpb.FieldDescriptorProto_TYPE_DOUBLE), + }, + }, + }, + } +} + +func wellKnownHolderDescriptor() *descriptorpb.DescriptorProto { + return &descriptorpb.DescriptorProto{ + Name: proto.String("WellKnownHolder"), + Field: []*descriptorpb.FieldDescriptorProto{ + messageField("meta", 1, ".google.protobuf.Struct"), + messageField("any_value", 2, ".google.protobuf.Any"), + messageField("ttl", 3, ".google.protobuf.Duration"), + messageField("raw_value", 4, ".google.protobuf.Value"), + messageField("items", 5, ".google.protobuf.ListValue"), + messageField("mask", 6, ".google.protobuf.FieldMask"), + messageField("label", 7, ".google.protobuf.StringValue"), + }, + } +} + +func scalarField( + name string, + number int32, + kind descriptorpb.FieldDescriptorProto_Type, + options ...*descriptorpb.FieldOptions, +) *descriptorpb.FieldDescriptorProto { + field := &descriptorpb.FieldDescriptorProto{ + Name: proto.String(name), + Number: proto.Int32(number), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: kind.Enum(), + } + if len(options) > 0 { + field.Options = mergeFieldOptions(options...) + } + return field +} + +func repeatedScalarField( + name string, + number int32, + kind descriptorpb.FieldDescriptorProto_Type, + options ...*descriptorpb.FieldOptions, +) *descriptorpb.FieldDescriptorProto { + field := scalarField(name, number, kind, options...) + field.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + return field +} + +func messageField( + name string, + number int32, + typeName string, + options ...*descriptorpb.FieldOptions, +) *descriptorpb.FieldDescriptorProto { + field := &descriptorpb.FieldDescriptorProto{ + Name: proto.String(name), + Number: proto.Int32(number), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(typeName), + } + if len(options) > 0 { + field.Options = mergeFieldOptions(options...) + } + return field +} + +func enumField( + name string, + number int32, + typeName string, + options ...*descriptorpb.FieldOptions, +) *descriptorpb.FieldDescriptorProto { + field := &descriptorpb.FieldDescriptorProto{ + Name: proto.String(name), + Number: proto.Int32(number), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum(), + TypeName: proto.String(typeName), + } + if len(options) > 0 { + field.Options = mergeFieldOptions(options...) + } + return field +} + +func optionalScalarField( + name string, + number int32, + kind descriptorpb.FieldDescriptorProto_Type, + options ...*descriptorpb.FieldOptions, +) *descriptorpb.FieldDescriptorProto { + field := scalarField(name, number, kind, options...) + field.OneofIndex = proto.Int32(0) + field.Proto3Optional = proto.Bool(true) + return field +} + +func messageFieldWithOneof( + name string, + number int32, + typeName string, + oneofIndex int32, + options ...*descriptorpb.FieldOptions, +) *descriptorpb.FieldDescriptorProto { + field := messageField(name, number, typeName, options...) + field.OneofIndex = proto.Int32(oneofIndex) + return field +} + +func mergeFieldOptions(options ...*descriptorpb.FieldOptions) *descriptorpb.FieldOptions { + merged := &descriptorpb.FieldOptions{} + for _, option := range options { + proto.Merge(merged, option) + } + return merged +} + +func withFieldOption(t *testing.T, ext protoreflect.ExtensionType, value any) *descriptorpb.FieldOptions { + t.Helper() + opts := &descriptorpb.FieldOptions{} + proto.SetExtension(opts, ext, value) + return opts +} + +func withEnumValueOption(t *testing.T, ext protoreflect.ExtensionType, value any) *descriptorpb.EnumValueOptions { + t.Helper() + opts := &descriptorpb.EnumValueOptions{} + proto.SetExtension(opts, ext, value) + return opts +} + +func withMethodOption(t *testing.T, ext protoreflect.ExtensionType, value any) *descriptorpb.MethodOptions { + t.Helper() + opts := &descriptorpb.MethodOptions{} + proto.SetExtension(opts, ext, value) + return opts +} + +func withMethodOptions(options ...*descriptorpb.MethodOptions) *descriptorpb.MethodOptions { + merged := &descriptorpb.MethodOptions{} + for _, option := range options { + proto.Merge(merged, option) + } + return merged +} + +func withServiceOption(t *testing.T, ext protoreflect.ExtensionType, value any) *descriptorpb.ServiceOptions { + t.Helper() + opts := &descriptorpb.ServiceOptions{} + proto.SetExtension(opts, ext, value) + return opts +} + +func withServiceOptions(options ...*descriptorpb.ServiceOptions) *descriptorpb.ServiceOptions { + merged := &descriptorpb.ServiceOptions{} + for _, option := range options { + proto.Merge(merged, option) + } + return merged +} + +func withOneofOption(t *testing.T, ext protoreflect.ExtensionType, value any) *descriptorpb.OneofOptions { + t.Helper() + opts := &descriptorpb.OneofOptions{} + proto.SetExtension(opts, ext, value) + return opts +} + +func testDependencyProtos() []*descriptorpb.FileDescriptorProto { + return []*descriptorpb.FileDescriptorProto{ + protodesc.ToFileDescriptorProto(descriptorpb.File_google_protobuf_descriptor_proto), + protodesc.ToFileDescriptorProto(anypb.File_google_protobuf_any_proto), + protodesc.ToFileDescriptorProto(durationpb.File_google_protobuf_duration_proto), + protodesc.ToFileDescriptorProto(emptypb.File_google_protobuf_empty_proto), + protodesc.ToFileDescriptorProto(fieldmaskpb.File_google_protobuf_field_mask_proto), + protodesc.ToFileDescriptorProto(structpb.File_google_protobuf_struct_proto), + protodesc.ToFileDescriptorProto(timestamppb.File_google_protobuf_timestamp_proto), + protodesc.ToFileDescriptorProto(wrapperspb.File_google_protobuf_wrappers_proto), + protodesc.ToFileDescriptorProto(sebufhttp.File_proto_sebuf_http_annotations_proto), + } +} + +func findEnum(t *testing.T, pkg *Package, name string) *Enum { + t.Helper() + for _, enum := range pkg.Enums { + if enum.Name == name { + return enum + } + } + t.Fatalf("enum %q not found", name) + return nil +} + +func findMessage(t *testing.T, pkg *Package, name string) *Message { + t.Helper() + for _, message := range pkg.Messages { + if message.Name == name { + return message + } + } + t.Fatalf("message %q not found", name) + return nil +} + +func findField(t *testing.T, msg *Message, name string) *Field { + t.Helper() + for _, field := range msg.Fields { + if field.Name == name { + return field + } + } + t.Fatalf("field %q not found", name) + return nil +} + +func findService(t *testing.T, pkg *Package, name string) *Service { + t.Helper() + for _, service := range pkg.Services { + if service.Name == name { + return service + } + } + t.Fatalf("service %q not found", name) + return nil +} + +func findMethod(t *testing.T, service *Service, name string) *Method { + t.Helper() + for _, method := range service.Methods { + if method.Name == name { + return method + } + } + t.Fatalf("method %q not found", name) + return nil +} diff --git a/internal/csharpgen/generator.go b/internal/csharpgen/generator.go new file mode 100644 index 00000000..bf5f049e --- /dev/null +++ b/internal/csharpgen/generator.go @@ -0,0 +1,2188 @@ +package csharpgen + +import ( + "flag" + "fmt" + "net/http" + "path" + "slices" + "sort" + "strings" + + "google.golang.org/protobuf/compiler/protogen" + + sebufhttp "github.com/SebastienMelki/sebuf/http" + "github.com/SebastienMelki/sebuf/internal/contractmodel" + "github.com/SebastienMelki/sebuf/internal/tscommon" +) + +const ( + csharpObjectType = "object" + csharpStringType = "string" + csharpDoubleType = "double" + csharpBoolType = "bool" + csharpLongType = "long" + emptyTypeName = "Empty" + bytesTypeName = "bytes" + base64Encoding = "base64" +) + +type Options struct { + Namespace string + JSONLib string +} + +type Generator struct { + plugin *protogen.Plugin + opts Options +} + +func New(plugin *protogen.Plugin, opts Options) *Generator { + return &Generator{plugin: plugin, opts: opts} +} + +func NewOptions() (protogen.Options, *Options) { + var flags flag.FlagSet + cfg := &Options{} + flags.StringVar(&cfg.Namespace, "namespace", "Sebuf.Generated", "C# namespace") + flags.StringVar(&cfg.JSONLib, "json_lib", "newtonsoft", "JSON annotation library: newtonsoft or system_text_json") + return protogen.Options{ParamFunc: flags.Set}, cfg +} + +func (g *Generator) Generate() error { + for _, pkg := range contractmodel.Packages(g.plugin.Files) { + if err := g.generatePackage(pkg); err != nil { + return err + } + } + return nil +} + +func (g *Generator) generatePackage(pkg *contractmodel.Package) error { + messages := g.packageMessages(pkg) + messageIndex := make(map[string]*contractmodel.Message, len(messages)) + for _, message := range messages { + messageIndex[message.Name] = message + } + + packagePath := strings.ReplaceAll(pkg.Name, ".", "/") + gf := g.plugin.NewGeneratedFile(path.Join(packagePath, "Contracts.g.cs"), "") + gf.P("// Code generated by protoc-gen-csharp-http. DO NOT EDIT.") + gf.P("#nullable enable") + gf.P("using System;") + gf.P("using System.Collections.Generic;") + gf.P("using System.Linq;") + gf.P("using System.Net.Http;") + gf.P("using System.Runtime.Serialization;") + gf.P("using System.Text;") + gf.P("using System.Threading;") + gf.P("using System.Threading.Tasks;") + if g.useNewtonsoft() { + gf.P("using Newtonsoft.Json;") + gf.P("using Newtonsoft.Json.Converters;") + gf.P("using Newtonsoft.Json.Linq;") + } else { + gf.P("using System.Text.Json;") + gf.P("using System.Text.Json.Serialization;") + gf.P("using System.Text.Json.Nodes;") + } + gf.P() + gf.P("namespace ", g.opts.Namespace) + gf.P("{") + g.generateEnums(gf, pkg.Enums) + g.generateAPIException(gf) + g.generateMessages(gf, messages, messageIndex) + g.generateServiceClients(gf, pkg.Services, messageIndex) + g.generateServices(gf, pkg.Services) + gf.P("}") + return nil +} + +type generatedProperty struct { + jsonName string + name string + typ string + converter string +} + +func (g *Generator) messageProperties( + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) []generatedProperty { + var properties []generatedProperty + usedJSONNames := make(map[string]bool) + skippedFields := make(map[string]bool) + + g.markOneofFields(message, skippedFields) + g.appendStandardMessageProperties(&properties, usedJSONNames, skippedFields, message, messageIndex) + g.appendOneofProperties(&properties, usedJSONNames, message, messageIndex) + + return properties +} + +func (g *Generator) generateEnums(gf *protogen.GeneratedFile, enums []*contractmodel.Enum) { + for _, enum := range enums { + gf.P(" public enum ", enum.Name) + gf.P(" {") + for _, value := range enum.Values { + gf.P(` [EnumMember(Value = "`, value.JSONValue, `")]`) + gf.P(" ", pascalCase(value.Name), " = ", value.Number, ",") + } + gf.P(" }") + gf.P() + } +} + +func (g *Generator) generateMessages( + gf *protogen.GeneratedFile, + messages []*contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + for _, message := range messages { + if isRootUnwrapMessage(message) { + gf.P(" public sealed class ", message.Name, " : ", rootUnwrapBaseType(message)) + gf.P(" {") + gf.P(" }") + gf.P() + continue + } + + gf.P(" public sealed class ", message.Name) + gf.P(" {") + for _, property := range g.messageProperties(message, messageIndex) { + gf.P(" ", g.jsonAttribute(property.jsonName)) + if property.converter != "" { + gf.P(" ", property.converter) + } + gf.P(" public ", property.typ, " ", property.name, " { get; set; }") + } + gf.P(" }") + gf.P() + } +} + +func (g *Generator) packageMessages(pkg *contractmodel.Package) []*contractmodel.Message { + messages := slices.Clone(pkg.Messages) + if g.needsSyntheticEmpty(pkg) { + messages = append(messages, &contractmodel.Message{Name: emptyTypeName}) + } + return messages +} + +func (g *Generator) needsSyntheticEmpty(pkg *contractmodel.Package) bool { + if hasMessageNamed(pkg.Messages, emptyTypeName) { + return false + } + for _, service := range pkg.Services { + for _, method := range service.Methods { + if method.InputType == emptyTypeName || method.ResponseType == emptyTypeName { + return true + } + } + } + return false +} + +func hasMessageNamed(messages []*contractmodel.Message, name string) bool { + for _, message := range messages { + if message.Name == name { + return true + } + } + return false +} + +func (g *Generator) generateAPIException(gf *protogen.GeneratedFile) { + gf.P(" public sealed class ApiException : Exception") + gf.P(" {") + gf.P(" public int StatusCode { get; }") + gf.P(" public string ResponseBody { get; }") + gf.P() + gf.P(" public ApiException(int statusCode, string responseBody)") + gf.P(` : base($\"Request failed with status {statusCode}: {responseBody}\")`) + gf.P(" {") + gf.P(" StatusCode = statusCode;") + gf.P(" ResponseBody = responseBody;") + gf.P(" }") + gf.P(" }") + gf.P() +} + +func (g *Generator) generateServices(gf *protogen.GeneratedFile, services []*contractmodel.Service) { + gf.P(" public static class ServiceContracts") + gf.P(" {") + for _, service := range services { + gf.P(" public static class ", service.Name) + gf.P(" {") + gf.P(` public const string Name = "`, service.Name, `";`) + gf.P(` public const string BasePath = "`, service.BasePath, `";`) + for _, method := range service.Methods { + gf.P(" public static class ", method.Name) + gf.P(" {") + gf.P(` public const string HttpMethod = "`, method.HTTPMethod, `";`) + gf.P(` public const string Path = "`, method.Path, `";`) + gf.P(` public const string RequestType = "`, method.InputType, `";`) + gf.P(` public const string ResponseType = "`, method.ResponseType, `";`) + gf.P(" }") + } + gf.P(" }") + } + gf.P(" }") +} + +func (g *Generator) generateServiceClients( + gf *protogen.GeneratedFile, + services []*contractmodel.Service, + messageIndex map[string]*contractmodel.Message, +) { + for _, service := range services { + g.generateServiceClientOptions(gf, service) + g.generateServiceCallOptions(gf, service) + g.generateServiceClientInterface(gf, service) + g.generateServiceClientClass(gf, service, messageIndex) + } +} + +func (g *Generator) generateServiceClientOptions(gf *protogen.GeneratedFile, service *contractmodel.Service) { + gf.P(" public sealed class ", service.Name, "ClientOptions") + gf.P(" {") + gf.P(" public HttpClient? HttpClient { get; set; }") + gf.P(" public Dictionary? DefaultHeaders { get; set; }") + for _, header := range service.Headers { + gf.P(" public string? ", upperFirst(tscommon.HeaderNameToPropertyName(header.Name)), " { get; set; }") + } + gf.P(" }") + gf.P() +} + +func (g *Generator) generateServiceCallOptions(gf *protogen.GeneratedFile, service *contractmodel.Service) { + gf.P(" public sealed class ", service.Name, "CallOptions") + gf.P(" {") + gf.P(" public Dictionary? Headers { get; set; }") + for _, header := range serviceCallHeaders(service) { + gf.P(" public string? ", upperFirst(tscommon.HeaderNameToPropertyName(header.Name)), " { get; set; }") + } + gf.P(" }") + gf.P() +} + +func (g *Generator) generateServiceClientInterface(gf *protogen.GeneratedFile, service *contractmodel.Service) { + gf.P(" public interface I", service.Name, "Client") + gf.P(" {") + for _, method := range service.Methods { + responseType := clientResponseType(method) + gf.P( + " Task<", responseType, "> ", method.Name, + "Async(", method.InputType, " req, ", service.Name, + "CallOptions? options = null, CancellationToken cancellationToken = default);", + ) + } + gf.P(" }") + gf.P() +} + +func (g *Generator) generateServiceClientClass( + gf *protogen.GeneratedFile, + service *contractmodel.Service, + messageIndex map[string]*contractmodel.Message, +) { + gf.P(" public sealed class ", service.Name, "Client : I", service.Name, "Client") + gf.P(" {") + gf.P(" private readonly string _baseUrl;") + gf.P(" private readonly HttpClient _httpClient;") + gf.P(" private readonly Dictionary _defaultHeaders;") + if !g.useNewtonsoft() { + gf.P(" private static readonly JsonSerializerOptions JsonOptions = new()") + gf.P(" {") + gf.P(" PropertyNamingPolicy = null,") + gf.P(" DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull") + gf.P(" };") + } + gf.P() + g.generateServiceClientConstructor(gf, service) + for _, method := range service.Methods { + g.generateServiceClientMethod(gf, service, method, messageIndex) + } + g.generateServiceClientHelpers(gf, service, messageIndex) + gf.P(" }") + gf.P() +} + +func (g *Generator) generateServiceClientConstructor(gf *protogen.GeneratedFile, service *contractmodel.Service) { + gf.P( + " public ", service.Name, "Client(string baseUrl, ", service.Name, "ClientOptions? options = null)", + ) + gf.P(" {") + gf.P(` _baseUrl = baseUrl.TrimEnd('/');`) + gf.P(" _httpClient = options?.HttpClient ?? new HttpClient();") + gf.P(" _defaultHeaders = options?.DefaultHeaders is null") + gf.P(" ? new Dictionary()") + gf.P(" : new Dictionary(options.DefaultHeaders);") + for _, header := range service.Headers { + prop := upperFirst(tscommon.HeaderNameToPropertyName(header.Name)) + gf.P(" if (!string.IsNullOrEmpty(options?.", prop, "))") + gf.P(" {") + gf.P(` _defaultHeaders["`, header.Name, `"] = options.`, prop, "!;") + gf.P(" }") + } + gf.P(" }") + gf.P() +} + +func (g *Generator) generateServiceClientMethod( + gf *protogen.GeneratedFile, + service *contractmodel.Service, + method *contractmodel.Method, + messageIndex map[string]*contractmodel.Message, +) { + requestMessage := messageIndex[method.InputType] + responseType := clientResponseType(method) + gf.P( + " public async Task<", responseType, "> ", method.Name, "Async(", + method.InputType, " req, ", service.Name, + "CallOptions? options = null, CancellationToken cancellationToken = default)", + ) + gf.P(" {") + gf.P(` var path = "`, method.Path, `";`) + for _, pathParam := range method.PathParams { + gf.P( + ` path = path.Replace("{`, pathParam, `}", Uri.EscapeDataString(FormatPathValue(req.`, + pascalCase(pathParam), ")));", + ) + } + if requestMessage != nil { + g.generateQueryString(gf, requestMessage, method) + } else { + gf.P(" var query = new List();") + } + gf.P(` var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query);`) + gf.P(" var headers = BuildHeaders(options);") + for _, header := range method.Headers { + prop := upperFirst(tscommon.HeaderNameToPropertyName(header.Name)) + gf.P(" if (!string.IsNullOrEmpty(options?.", prop, "))") + gf.P(" {") + gf.P(` headers["`, header.Name, `"] = options.`, prop, "!;") + gf.P(" }") + } + bodyExpr := "null" + if methodHasBody(method) { + bodyExpr = "req" + } + gf.P( + " return await SendAsync<", responseType, ">(HttpMethod.", + httpMethodName(method.HTTPMethod), ", requestUri, ", bodyExpr, ", headers, cancellationToken);", + ) + gf.P(" }") + gf.P() +} + +func (g *Generator) generateQueryString( + gf *protogen.GeneratedFile, + requestMessage *contractmodel.Message, + method *contractmodel.Method, +) { + gf.P(" var query = new List();") + if requestMessage == nil || (method.HTTPMethod != http.MethodGet && method.HTTPMethod != http.MethodDelete) { + return + } + + for _, field := range requestMessage.Fields { + if field.Annotations.Query == nil { + continue + } + prop := pascalCase(field.Name) + paramName := field.Annotations.Query.Name + if field.Repeated { + gf.P(" if (req.", prop, " is not null)") + gf.P(" {") + gf.P(" foreach (var item in req.", prop, ")") + gf.P(" {") + gf.P( + ` query.Add(Uri.EscapeDataString("`, + paramName, + `") + "=" + Uri.EscapeDataString(FormatQueryValue(item)));`, + ) + gf.P(" }") + gf.P(" }") + continue + } + gf.P(" if (", csharpQueryCondition(field, "req."+prop), ")") + gf.P(" {") + gf.P( + ` query.Add(Uri.EscapeDataString("`, + paramName, + `") + "=" + Uri.EscapeDataString(FormatQueryValue(req.`, + prop, + ")));", + ) + gf.P(" }") + } +} + +func (g *Generator) generateServiceClientHelpers( + gf *protogen.GeneratedFile, + service *contractmodel.Service, + messageIndex map[string]*contractmodel.Message, +) { + gf.P(" private Dictionary BuildHeaders(", service.Name, "CallOptions? options)") + gf.P(" {") + gf.P(" var headers = new Dictionary(_defaultHeaders);") + gf.P(" if (options?.Headers is not null)") + gf.P(" {") + gf.P(" foreach (var pair in options.Headers)") + gf.P(" {") + gf.P(" headers[pair.Key] = pair.Value;") + gf.P(" }") + gf.P(" }") + for _, header := range service.Headers { + prop := upperFirst(tscommon.HeaderNameToPropertyName(header.Name)) + gf.P(" if (!string.IsNullOrEmpty(options?.", prop, "))") + gf.P(" {") + gf.P(` headers["`, header.Name, `"] = options.`, prop, "!;") + gf.P(" }") + } + gf.P(" return headers;") + gf.P(" }") + gf.P() + g.generateSendAsync(gf) + g.generateSerializeRequest(gf) + g.generateDeserializeResponse(gf) + g.generateJSONNormalizationHelpers(gf, messageIndex) + g.generatePathAndQueryHelpers(gf) +} + +func (g *Generator) generateSendAsync(gf *protogen.GeneratedFile) { + gf.P( + " private async Task SendAsync(", + "HttpMethod method, string requestUri, object? body, Dictionary headers, CancellationToken cancellationToken) where TResponse : new()", + ) + gf.P(" {") + gf.P(" using var request = new HttpRequestMessage(method, _baseUrl + requestUri);") + gf.P(" foreach (var header in headers)") + gf.P(" {") + gf.P(" request.Headers.TryAddWithoutValidation(header.Key, header.Value);") + gf.P(" }") + gf.P(" if (body is not null)") + gf.P(" {") + gf.P( + ` request.Content = new StringContent(SerializeRequest(body), Encoding.UTF8, "application/json");`, + ) + gf.P(" }") + gf.P(" using var response = await _httpClient.SendAsync(request, cancellationToken);") + gf.P( + ` var responseBody = response.Content is null ? string.Empty : await response.Content.ReadAsStringAsync(cancellationToken);`, + ) + gf.P(" if (!response.IsSuccessStatusCode)") + gf.P(" {") + gf.P(" throw new ApiException((int)response.StatusCode, responseBody);") + gf.P(" }") + gf.P(" if (typeof(TResponse) == typeof(", emptyTypeName, ") && string.IsNullOrWhiteSpace(responseBody))") + gf.P(" {") + gf.P(" return new TResponse();") + gf.P(" }") + gf.P(" if (string.IsNullOrWhiteSpace(responseBody))") + gf.P(" {") + gf.P(" return new TResponse();") + gf.P(" }") + gf.P(" var result = DeserializeResponse(responseBody);") + gf.P(" return result is null ? new TResponse() : result;") + gf.P(" }") + gf.P() +} + +func (g *Generator) generateSerializeRequest(gf *protogen.GeneratedFile) { + gf.P(" private string SerializeRequest(object value)") + gf.P(" {") + if g.useNewtonsoft() { + gf.P(" var json = JsonConvert.SerializeObject(value);") + gf.P(" return NormalizeSerializedJson(value, json);") + } else { + gf.P(" var json = JsonSerializer.Serialize(value, JsonOptions);") + gf.P(" return NormalizeSerializedJson(value, json);") + } + gf.P(" }") + gf.P() +} + +func (g *Generator) generateDeserializeResponse(gf *protogen.GeneratedFile) { + gf.P(" private static TResponse? DeserializeResponse(string json)") + gf.P(" {") + gf.P(" json = NormalizeResponseJson(typeof(TResponse), json);") + if g.useNewtonsoft() { + gf.P(" return JsonConvert.DeserializeObject(json);") + } else { + gf.P(" return JsonSerializer.Deserialize(json, JsonOptions);") + } + gf.P(" }") + gf.P() +} + +func (g *Generator) generateJSONNormalizationHelpers( + gf *protogen.GeneratedFile, + messageIndex map[string]*contractmodel.Message, +) { + messages := messagesRequiringJSONNormalization(messageIndex) + if len(messages) == 0 { + gf.P(" private static string NormalizeSerializedJson(object value, string json)") + gf.P(" {") + gf.P(" return json;") + gf.P(" }") + gf.P() + gf.P(" private static string NormalizeResponseJson(Type responseType, string json)") + gf.P(" {") + gf.P(" return json;") + gf.P(" }") + gf.P() + g.generateBytesEncodingHelpers(gf) + return + } + + if g.useNewtonsoft() { + g.generateNewtonsoftJSONNormalizationHelpers(gf, messages, messageIndex) + } else { + g.generateSystemTextJSONNormalizationHelpers(gf, messages, messageIndex) + } + g.generateBytesEncodingHelpers(gf) +} + +//nolint:funlen // Code generation for JSON normalization is intentionally verbose. +func (g *Generator) generateNewtonsoftJSONNormalizationHelpers( + gf *protogen.GeneratedFile, + messages []*contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + gf.P(" private static string NormalizeSerializedJson(object value, string json)") + gf.P(" {") + gf.P(" var token = JToken.Parse(json);") + gf.P(" var normalized = NormalizeSerializedToken(value.GetType(), token);") + gf.P(" return normalized.ToString(Formatting.None);") + gf.P(" }") + gf.P() + gf.P(" private static string NormalizeResponseJson(Type responseType, string json)") + gf.P(" {") + gf.P(" var token = JToken.Parse(json);") + gf.P(" var normalized = NormalizeResponseToken(responseType, token);") + gf.P(" return normalized.ToString(Formatting.None);") + gf.P(" }") + gf.P() + gf.P(" private static JToken NormalizeSerializedToken(Type messageType, JToken token)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + gf.P(` "`, message.Name, `" => NormalizeSerialized`, message.Name, "(token),") + } + gf.P(" _ => token") + gf.P(" };") + gf.P(" }") + gf.P(" private static JToken NormalizeResponseToken(Type messageType, JToken token)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + gf.P(` "`, message.Name, `" => NormalizeResponse`, message.Name, "(token),") + } + gf.P(" _ => token") + gf.P(" };") + gf.P(" }") + gf.P(" private static JToken NormalizeMapValueForSerialization(JToken token, Type messageType)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + if !mapValueUsesUnwrap(message) { + continue + } + field := findField(message, message.Unwrap.FieldName) + if field == nil { + continue + } + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + gf.P( + ` "`, message.Name, + `" => token is JObject obj && obj.TryGetValue("`, jsonName, + `", out var value) ? value : token,`, + ) + } + gf.P(" _ => NormalizeSerializedToken(messageType, token)") + gf.P(" };") + gf.P(" }") + gf.P(" private static JToken NormalizeMapValueForResponse(JToken token, Type messageType)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + if !mapValueUsesUnwrap(message) { + continue + } + field := findField(message, message.Unwrap.FieldName) + if field == nil { + continue + } + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + gf.P(` "`, message.Name, `" => new JObject { ["`, jsonName, `"] = token },`) + } + gf.P(" _ => NormalizeResponseToken(messageType, token)") + gf.P(" };") + gf.P(" }") + gf.P(" private static bool IsEmptyObject(JToken token)") + gf.P(" {") + gf.P(" return token is JObject obj && !obj.Properties().Any();") + gf.P(" }") + gf.P(" private static bool ShouldOmitEmptyField(JToken token)") + gf.P(" {") + gf.P(" return token.Type == JTokenType.Null || IsEmptyObject(token);") + gf.P(" }") + gf.P() + for _, message := range messages { + g.generateNewtonsoftMessageNormalizers(gf, message, messageIndex) + } +} + +//nolint:funlen // Code generation for JSON normalization is intentionally verbose. +func (g *Generator) generateSystemTextJSONNormalizationHelpers( + gf *protogen.GeneratedFile, + messages []*contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + gf.P(" private static string NormalizeSerializedJson(object value, string json)") + gf.P(" {") + gf.P(" var token = JsonNode.Parse(json);") + gf.P(" if (token is null)") + gf.P(" {") + gf.P(" return json;") + gf.P(" }") + gf.P(" var normalized = NormalizeSerializedNode(value.GetType(), token);") + gf.P(" return normalized.ToJsonString();") + gf.P(" }") + gf.P() + gf.P(" private static string NormalizeResponseJson(Type responseType, string json)") + gf.P(" {") + gf.P(" var token = JsonNode.Parse(json);") + gf.P(" if (token is null)") + gf.P(" {") + gf.P(" return json;") + gf.P(" }") + gf.P(" var normalized = NormalizeResponseNode(responseType, token);") + gf.P(" return normalized.ToJsonString();") + gf.P(" }") + gf.P() + gf.P(" private static JsonNode NormalizeSerializedNode(Type messageType, JsonNode token)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + gf.P(` "`, message.Name, `" => NormalizeSerialized`, message.Name, "(token),") + } + gf.P(" _ => token") + gf.P(" };") + gf.P(" }") + gf.P(" private static JsonNode NormalizeResponseNode(Type messageType, JsonNode token)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + gf.P(` "`, message.Name, `" => NormalizeResponse`, message.Name, "(token),") + } + gf.P(" _ => token") + gf.P(" };") + gf.P(" }") + gf.P(" private static JsonNode NormalizeMapValueForSerialization(JsonNode token, Type messageType)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + if !mapValueUsesUnwrap(message) { + continue + } + field := findField(message, message.Unwrap.FieldName) + if field == nil { + continue + } + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + gf.P( + ` "`, message.Name, + `" => token is JsonObject obj && obj["`, jsonName, + `"] is JsonNode value ? value : token,`, + ) + } + gf.P(" _ => NormalizeSerializedNode(messageType, token)") + gf.P(" };") + gf.P(" }") + gf.P(" private static JsonNode NormalizeMapValueForResponse(JsonNode token, Type messageType)") + gf.P(" {") + gf.P(" return messageType.Name switch") + gf.P(" {") + for _, message := range messages { + if !mapValueUsesUnwrap(message) { + continue + } + field := findField(message, message.Unwrap.FieldName) + if field == nil { + continue + } + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + gf.P(` "`, message.Name, `" => new JsonObject { ["`, jsonName, `"] = token.DeepClone() },`) + } + gf.P(" _ => NormalizeResponseNode(messageType, token)") + gf.P(" };") + gf.P(" }") + gf.P(" private static bool IsEmptyObject(JsonNode? token)") + gf.P(" {") + gf.P(" return token is JsonObject obj && obj.Count == 0;") + gf.P(" }") + gf.P(" private static bool ShouldOmitEmptyField(JsonNode? token)") + gf.P(" {") + gf.P(" return token is null || IsEmptyObject(token);") + gf.P(" }") + gf.P() + for _, message := range messages { + g.generateSystemTextMessageNormalizers(gf, message, messageIndex) + } +} + +func (g *Generator) generateNewtonsoftMessageNormalizers( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + gf.P(" private static JToken NormalizeSerialized", message.Name, "(JToken token)") + gf.P(" {") + g.generateNewtonsoftMessageNormalizationBody(gf, message, messageIndex, true) + gf.P(" }") + gf.P() + gf.P(" private static JToken NormalizeResponse", message.Name, "(JToken token)") + gf.P(" {") + g.generateNewtonsoftMessageNormalizationBody(gf, message, messageIndex, false) + gf.P(" }") + gf.P() +} + +func (g *Generator) generateSystemTextMessageNormalizers( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + gf.P(" private static JsonNode NormalizeSerialized", message.Name, "(JsonNode token)") + gf.P(" {") + g.generateSystemTextMessageNormalizationBody(gf, message, messageIndex, true) + gf.P(" }") + gf.P() + gf.P(" private static JsonNode NormalizeResponse", message.Name, "(JsonNode token)") + gf.P(" {") + g.generateSystemTextMessageNormalizationBody(gf, message, messageIndex, false) + gf.P(" }") + gf.P() +} + +func (g *Generator) generateBytesEncodingHelpers(gf *protogen.GeneratedFile) { + gf.P(" private static string EncodeBytes(byte[] bytes, string encoding)") + gf.P(" {") + gf.P(" var base64 = Convert.ToBase64String(bytes);") + gf.P(" return encoding switch") + gf.P(" {") + gf.P(` "base64_raw" => base64.TrimEnd('='),`) + gf.P(` "base64url" => base64.Replace('+', '-').Replace('/', '_'),`) + gf.P(` "base64url_raw" => base64.Replace('+', '-').Replace('/', '_').TrimEnd('='),`) + gf.P(` "hex" => Convert.ToHexString(bytes).ToLowerInvariant(),`) + gf.P(" _ => base64") + gf.P(" };") + gf.P(" }") + gf.P() + gf.P(" private static string ReencodeBytes(string encoded, string fromEncoding, string toEncoding)") + gf.P(" {") + gf.P(" return EncodeBytes(DecodeBytes(encoded, fromEncoding), toEncoding);") + gf.P(" }") + gf.P() + gf.P(" private static byte[] DecodeBytes(string encoded, string encoding)") + gf.P(" {") + gf.P(" return encoding switch") + gf.P(" {") + gf.P(` "hex" => Convert.FromHexString(encoded),`) + gf.P( + ` "base64url" => Convert.FromBase64String(`, + `NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))),`, + ) + gf.P( + ` "base64url_raw" => Convert.FromBase64String(`, + `NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))),`, + ) + gf.P(` "base64_raw" => Convert.FromBase64String(NormalizeBase64(encoded)),`) + gf.P(" _ => Convert.FromBase64String(NormalizeBase64(encoded))") + gf.P(" };") + gf.P(" }") + gf.P() + gf.P(" private static string NormalizeBase64(string value)") + gf.P(" {") + gf.P(" var remainder = value.Length % 4;") + gf.P(" if (remainder == 0)") + gf.P(" {") + gf.P(" return value;") + gf.P(" }") + gf.P(` return value + new string('=', 4 - remainder);`) + gf.P(" }") + gf.P() +} + +func (g *Generator) generateNewtonsoftMessageNormalizationBody( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, + serialize bool, +) { + if isRootUnwrapMessage(message) { + g.generateNewtonsoftRootUnwrapNormalizationBody(gf, message, messageIndex, serialize) + return + } + + gf.P(" if (token is not JObject obj)") + gf.P(" {") + gf.P(" return token;") + gf.P(" }") + for _, field := range message.Fields { + g.generateNewtonsoftFieldNormalization(gf, field, messageIndex, serialize) + } + g.generateNewtonsoftOneofNormalization(gf, message, messageIndex) + gf.P(" return obj;") +} + +//nolint:nestif // Branching mirrors generated JSON normalization cases. +func (g *Generator) generateNewtonsoftRootUnwrapNormalizationBody( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, + serialize bool, +) { + if len(message.Fields) == 0 { + gf.P(" return token;") + return + } + field := message.Fields[0] + if field.Repeated { + if field.Type == nil || field.Type.Kind != contractmodel.KindMessage || + !messageNeedsJSONNormalization(messageIndex[field.Type.Name], messageIndex) { + gf.P(" return token;") + return + } + gf.P(" if (token is not JArray array)") + gf.P(" {") + gf.P(" return token;") + gf.P(" }") + gf.P(" for (var i = 0; i < array.Count; i++)") + gf.P(" {") + if serialize { + gf.P(" array[i] = NormalizeSerializedToken(typeof(", field.Type.Name, "), array[i]!);") + } else { + gf.P(" array[i] = NormalizeResponseToken(typeof(", field.Type.Name, "), array[i]!);") + } + gf.P(" }") + gf.P(" return array;") + return + } + if !field.IsMap || field.Type == nil || field.Type.MapValue == nil { + gf.P(" return token;") + return + } + childMessage := field.Type.MapValue.Kind == contractmodel.KindMessage && + messageNeedsJSONNormalization(messageIndex[field.Type.MapValue.Name], messageIndex) + if !childMessage { + gf.P(" return token;") + return + } + gf.P(" if (token is not JObject obj)") + gf.P(" {") + gf.P(" return token;") + gf.P(" }") + gf.P(" foreach (var property in obj.Properties().ToList())") + gf.P(" {") + if serialize { + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " property.Value = NormalizeMapValueForSerialization(", + "property.Value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " property.Value = NormalizeSerializedToken(", + "typeof(", field.Type.MapValue.Name, "), property.Value);", + ) + } + } else { + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " property.Value = NormalizeMapValueForResponse(", + "property.Value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " property.Value = NormalizeResponseToken(", + "typeof(", field.Type.MapValue.Name, "), property.Value);", + ) + } + } + gf.P(" }") + gf.P(" return obj;") +} + +//nolint:funlen,gocognit,nestif,golines // Branching mirrors generated JSON normalization cases. +func (g *Generator) generateNewtonsoftFieldNormalization( + gf *protogen.GeneratedFile, + field *contractmodel.Field, + messageIndex map[string]*contractmodel.Message, + serialize bool, +) { + if field == nil { + return + } + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + + if needsBytesEncodingNormalization(field) { + gf.P( + ` if (obj.TryGetValue("`, jsonName, + `", out var `, pascalCase(jsonName), `Token) && `, + pascalCase(jsonName), `Token.Type == JTokenType.String)`, + ) + gf.P(" {") + if serialize { + gf.P( + ` obj["`, jsonName, `"] = ReencodeBytes(`, pascalCase(jsonName), + `Token.Value()!, "`, base64Encoding, `", "`, + bytesEncodingName(field.Annotations.BytesEncoding), `");`, + ) + } else { + gf.P( + ` obj["`, jsonName, `"] = ReencodeBytes(`, pascalCase(jsonName), + `Token.Value()!, "`, bytesEncodingName(field.Annotations.BytesEncoding), + `", "`, base64Encoding, `");`, + ) + } + gf.P(" }") + } + + if emptyBehaviorNeedsNormalization(field) { + g.generateNewtonsoftEmptyBehaviorNormalization(gf, field, jsonName, serialize) + } + + if field.IsMap && field.Type != nil && field.Type.MapValue != nil && field.Type.MapValue.Kind == contractmodel.KindMessage && + messageNeedsJSONNormalization(messageIndex[field.Type.MapValue.Name], messageIndex) { + gf.P(` if (obj.TryGetValue("`, jsonName, `", out var `, pascalCase(jsonName), `Map) && `, pascalCase(jsonName), `Map is JObject `, pascalCase(jsonName), `Object)`) + gf.P(" {") + gf.P(" foreach (var property in ", pascalCase(jsonName), "Object.Properties().ToList())") + gf.P(" {") + if serialize { + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " property.Value = NormalizeMapValueForSerialization(", + "property.Value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " property.Value = NormalizeSerializedToken(", + "typeof(", field.Type.MapValue.Name, "), property.Value);", + ) + } + } else { + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " property.Value = NormalizeMapValueForResponse(", + "property.Value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " property.Value = NormalizeResponseToken(", + "typeof(", field.Type.MapValue.Name, "), property.Value);", + ) + } + } + gf.P(" }") + gf.P(" }") + return + } + + if field.Type != nil && field.Type.Kind == contractmodel.KindMessage && messageNeedsJSONNormalization(messageIndex[field.Type.Name], messageIndex) { + if field.Repeated { + gf.P(` if (obj.TryGetValue("`, jsonName, `", out var `, pascalCase(jsonName), `List) && `, pascalCase(jsonName), `List is JArray `, pascalCase(jsonName), `Array)`) + gf.P(" {") + gf.P(" for (var i = 0; i < ", pascalCase(jsonName), "Array.Count; i++)") + gf.P(" {") + if serialize { + gf.P( + " ", pascalCase(jsonName), + "Array[i] = NormalizeSerializedToken(typeof(", + field.Type.Name, "), ", pascalCase(jsonName), "Array[i]!);", + ) + } else { + gf.P( + " ", pascalCase(jsonName), + "Array[i] = NormalizeResponseToken(typeof(", + field.Type.Name, "), ", pascalCase(jsonName), "Array[i]!);", + ) + } + gf.P(" }") + gf.P(" }") + return + } + gf.P(` if (obj.TryGetValue("`, jsonName, `", out var `, pascalCase(jsonName), `Child) && `, pascalCase(jsonName), `Child.Type == JTokenType.Object)`) + gf.P(" {") + if serialize { + gf.P( + " obj[\"", jsonName, + "\"] = NormalizeSerializedToken(typeof(", field.Type.Name, + "), ", pascalCase(jsonName), "Child);", + ) + } else { + gf.P( + " obj[\"", jsonName, + "\"] = NormalizeResponseToken(typeof(", field.Type.Name, + "), ", pascalCase(jsonName), "Child);", + ) + } + gf.P(" }") + } +} + +func (g *Generator) generateSystemTextMessageNormalizationBody( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, + serialize bool, +) { + if isRootUnwrapMessage(message) { + g.generateSystemTextRootUnwrapNormalizationBody(gf, message, messageIndex, serialize) + return + } + + gf.P(" if (token is not JsonObject obj)") + gf.P(" {") + gf.P(" return token;") + gf.P(" }") + for _, field := range message.Fields { + g.generateSystemTextFieldNormalization(gf, field, messageIndex, serialize) + } + g.generateSystemTextOneofNormalization(gf, message, messageIndex) + gf.P(" return obj;") +} + +func (g *Generator) generateNewtonsoftOneofNormalization( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + for _, oneof := range message.Oneofs { + if oneof.Discriminator == "" { + continue + } + g.generateNewtonsoftDiscriminatorInference(gf, oneof, messageIndex) + g.generateNewtonsoftDiscriminatorCleanup(gf, oneof, messageIndex) + } +} + +func (g *Generator) generateSystemTextOneofNormalization( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + for _, oneof := range message.Oneofs { + if oneof.Discriminator == "" { + continue + } + g.generateSystemTextDiscriminatorInference(gf, oneof, messageIndex) + g.generateSystemTextDiscriminatorCleanup(gf, oneof, messageIndex) + } +} + +func (g *Generator) generateNewtonsoftDiscriminatorInference( + gf *protogen.GeneratedFile, + oneof *contractmodel.Oneof, + messageIndex map[string]*contractmodel.Message, +) { + discriminator := oneof.Discriminator + tokenName := pascalCase(discriminator) + "Discriminator" + gf.P( + ` if (!obj.TryGetValue("`, discriminator, `", out var `, tokenName, + `) || `, tokenName, `.Type == JTokenType.Null || string.IsNullOrEmpty(`, tokenName, `.Value()))`, + ) + gf.P(" {") + first := true + for _, variant := range oneof.Variants { + condition := newtonsoftVariantCondition(oneof, variant, messageIndex) + if condition == "" { + continue + } + if first { + gf.P(" if (", condition, ")") + first = false + } else { + gf.P(" else if (", condition, ")") + } + gf.P(" {") + gf.P(` obj["`, discriminator, `"] = "`, variant.DiscriminatorValue, `";`) + gf.P(" }") + } + gf.P(" }") +} + +func (g *Generator) generateNewtonsoftDiscriminatorCleanup( + gf *protogen.GeneratedFile, + oneof *contractmodel.Oneof, + messageIndex map[string]*contractmodel.Message, +) { + discriminator := oneof.Discriminator + tokenName := pascalCase(discriminator) + "Selected" + gf.P( + ` if (obj.TryGetValue("`, discriminator, `", out var `, tokenName, + `) && `, tokenName, `.Type == JTokenType.String)`, + ) + gf.P(" {") + gf.P(" switch (", tokenName, ".Value())") + gf.P(" {") + for _, variant := range oneof.Variants { + gf.P(` case "`, variant.DiscriminatorValue, `":`) + for _, jsonName := range oneofOtherVariantJSONFields(oneof, variant, messageIndex) { + gf.P(` obj.Remove("`, jsonName, `");`) + } + gf.P(" break;") + } + gf.P(" }") + gf.P(" }") +} + +func (g *Generator) generateSystemTextDiscriminatorInference( + gf *protogen.GeneratedFile, + oneof *contractmodel.Oneof, + messageIndex map[string]*contractmodel.Message, +) { + discriminator := oneof.Discriminator + tokenName := pascalCase(discriminator) + "Discriminator" + gf.P( + ` if (!obj.TryGetPropertyValue("`, discriminator, `", out var `, tokenName, + `) || `, tokenName, ` is null || string.IsNullOrEmpty(`, tokenName, `.GetValue()))`, + ) + gf.P(" {") + first := true + for _, variant := range oneof.Variants { + condition := systemTextVariantCondition(oneof, variant, messageIndex) + if condition == "" { + continue + } + if first { + gf.P(" if (", condition, ")") + first = false + } else { + gf.P(" else if (", condition, ")") + } + gf.P(" {") + gf.P(` obj["`, discriminator, `"] = "`, variant.DiscriminatorValue, `";`) + gf.P(" }") + } + gf.P(" }") +} + +func (g *Generator) generateSystemTextDiscriminatorCleanup( + gf *protogen.GeneratedFile, + oneof *contractmodel.Oneof, + messageIndex map[string]*contractmodel.Message, +) { + discriminator := oneof.Discriminator + tokenName := pascalCase(discriminator) + "Selected" + gf.P( + ` if (obj.TryGetPropertyValue("`, discriminator, `", out var `, tokenName, + `) && `, tokenName, ` is JsonValue)`, + ) + gf.P(" {") + gf.P(" switch (", tokenName, "!.GetValue())") + gf.P(" {") + for _, variant := range oneof.Variants { + gf.P(` case "`, variant.DiscriminatorValue, `":`) + for _, jsonName := range oneofOtherVariantJSONFields(oneof, variant, messageIndex) { + gf.P(` obj.Remove("`, jsonName, `");`) + } + gf.P(" break;") + } + gf.P(" }") + gf.P(" }") +} + +//nolint:nestif // Branching mirrors generated JSON normalization cases. +func (g *Generator) generateSystemTextRootUnwrapNormalizationBody( + gf *protogen.GeneratedFile, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, + serialize bool, +) { + if len(message.Fields) == 0 { + gf.P(" return token;") + return + } + field := message.Fields[0] + if field.Repeated { + if field.Type == nil || field.Type.Kind != contractmodel.KindMessage || + !messageNeedsJSONNormalization(messageIndex[field.Type.Name], messageIndex) { + gf.P(" return token;") + return + } + gf.P(" if (token is not JsonArray array)") + gf.P(" {") + gf.P(" return token;") + gf.P(" }") + gf.P(" for (var i = 0; i < array.Count; i++)") + gf.P(" {") + gf.P(" if (array[i] is not JsonNode item)") + gf.P(" {") + gf.P(" continue;") + gf.P(" }") + if serialize { + gf.P(" array[i] = NormalizeSerializedNode(typeof(", field.Type.Name, "), item);") + } else { + gf.P(" array[i] = NormalizeResponseNode(typeof(", field.Type.Name, "), item);") + } + gf.P(" }") + gf.P(" return array;") + return + } + if !field.IsMap || field.Type == nil || field.Type.MapValue == nil { + gf.P(" return token;") + return + } + childMessage := field.Type.MapValue.Kind == contractmodel.KindMessage && + messageNeedsJSONNormalization(messageIndex[field.Type.MapValue.Name], messageIndex) + if !childMessage { + gf.P(" return token;") + return + } + gf.P(" if (token is not JsonObject obj)") + gf.P(" {") + gf.P(" return token;") + gf.P(" }") + gf.P(" foreach (var key in obj.Select(pair => pair.Key).ToList())") + gf.P(" {") + if serialize { + gf.P(" if (obj[key] is JsonNode value)") + gf.P(" {") + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " obj[key] = NormalizeMapValueForSerialization(", + "value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " obj[key] = NormalizeSerializedNode(", + "typeof(", field.Type.MapValue.Name, "), value);", + ) + } + gf.P(" }") + } else { + gf.P(" if (obj[key] is JsonNode value)") + gf.P(" {") + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " obj[key] = NormalizeMapValueForResponse(", + "value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " obj[key] = NormalizeResponseNode(", + "typeof(", field.Type.MapValue.Name, "), value);", + ) + } + gf.P(" }") + } + gf.P(" }") + gf.P(" return obj;") +} + +//nolint:funlen,gocognit,nestif,golines // Branching mirrors generated JSON normalization cases. +func (g *Generator) generateSystemTextFieldNormalization( + gf *protogen.GeneratedFile, + field *contractmodel.Field, + messageIndex map[string]*contractmodel.Message, + serialize bool, +) { + if field == nil { + return + } + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + + if needsBytesEncodingNormalization(field) { + gf.P( + ` if (obj["`, jsonName, `"] is JsonValue `, + pascalCase(jsonName), `Token && `, + pascalCase(jsonName), `Token.TryGetValue(out var `, + pascalCase(jsonName), `Value))`, + ) + gf.P(" {") + if serialize { + gf.P( + ` obj["`, jsonName, `"] = ReencodeBytes(`, pascalCase(jsonName), + `Value, "`, base64Encoding, `", "`, + bytesEncodingName(field.Annotations.BytesEncoding), `");`, + ) + } else { + gf.P( + ` obj["`, jsonName, `"] = ReencodeBytes(`, pascalCase(jsonName), + `Value, "`, bytesEncodingName(field.Annotations.BytesEncoding), + `", "`, base64Encoding, `");`, + ) + } + gf.P(" }") + } + + if emptyBehaviorNeedsNormalization(field) { + g.generateSystemTextEmptyBehaviorNormalization(gf, field, jsonName, serialize) + } + + if field.IsMap && field.Type != nil && field.Type.MapValue != nil && field.Type.MapValue.Kind == contractmodel.KindMessage && + messageNeedsJSONNormalization(messageIndex[field.Type.MapValue.Name], messageIndex) { + gf.P(` if (obj["`, jsonName, `"] is JsonObject `, pascalCase(jsonName), `Object)`) + gf.P(" {") + gf.P(" foreach (var key in ", pascalCase(jsonName), "Object.Select(pair => pair.Key).ToList())") + gf.P(" {") + gf.P(" if (", pascalCase(jsonName), "Object[key] is JsonNode value)") + gf.P(" {") + if serialize { + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " ", pascalCase(jsonName), + "Object[key] = NormalizeMapValueForSerialization(", + "value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " ", pascalCase(jsonName), + "Object[key] = NormalizeSerializedNode(typeof(", + field.Type.MapValue.Name, "), value);", + ) + } + } else { + if mapValueUsesUnwrap(messageIndex[field.Type.MapValue.Name]) { + gf.P( + " ", pascalCase(jsonName), + "Object[key] = NormalizeMapValueForResponse(", + "value, typeof(", field.Type.MapValue.Name, "));", + ) + } else { + gf.P( + " ", pascalCase(jsonName), + "Object[key] = NormalizeResponseNode(typeof(", + field.Type.MapValue.Name, "), value);", + ) + } + } + gf.P(" }") + gf.P(" }") + gf.P(" }") + return + } + + if field.Type != nil && field.Type.Kind == contractmodel.KindMessage && messageNeedsJSONNormalization(messageIndex[field.Type.Name], messageIndex) { + if field.Repeated { + gf.P(` if (obj["`, jsonName, `"] is JsonArray `, pascalCase(jsonName), `Array)`) + gf.P(" {") + gf.P(" for (var i = 0; i < ", pascalCase(jsonName), "Array.Count; i++)") + gf.P(" {") + gf.P(" if (", pascalCase(jsonName), "Array[i] is not JsonNode item)") + gf.P(" {") + gf.P(" continue;") + gf.P(" }") + if serialize { + gf.P( + " ", pascalCase(jsonName), + "Array[i] = NormalizeSerializedNode(typeof(", + field.Type.Name, "), item);", + ) + } else { + gf.P( + " ", pascalCase(jsonName), + "Array[i] = NormalizeResponseNode(typeof(", + field.Type.Name, "), item);", + ) + } + gf.P(" }") + gf.P(" }") + return + } + gf.P(` if (obj["`, jsonName, `"] is JsonNode child && child is JsonObject)`) + gf.P(" {") + if serialize { + gf.P( + " obj[\"", jsonName, + "\"] = NormalizeSerializedNode(typeof(", field.Type.Name, + "), child);", + ) + } else { + gf.P( + " obj[\"", jsonName, + "\"] = NormalizeResponseNode(typeof(", field.Type.Name, + "), child);", + ) + } + gf.P(" }") + } +} + +func messagesRequiringJSONNormalization(messageIndex map[string]*contractmodel.Message) []*contractmodel.Message { + if len(messageIndex) == 0 { + return nil + } + names := make([]string, 0, len(messageIndex)) + for name := range messageIndex { + if messageNeedsJSONNormalization(messageIndex[name], messageIndex) { + names = append(names, name) + } + } + sort.Strings(names) + result := make([]*contractmodel.Message, 0, len(names)) + for _, name := range names { + result = append(result, messageIndex[name]) + } + return result +} + +//nolint:gocognit // This is the central predicate for generated JSON normalization. +func messageNeedsJSONNormalization( + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) bool { + if message == nil { + return false + } + if oneofNeedsNormalization(message, messageIndex) { + return true + } + if isRootUnwrapMessage(message) { + field := message.Fields[0] + if field.Repeated && field.Type != nil && field.Type.Kind == contractmodel.KindMessage { + return messageNeedsJSONNormalization(messageIndex[field.Type.Name], messageIndex) + } + if field.IsMap && field.Type != nil && field.Type.MapValue != nil && + field.Type.MapValue.Kind == contractmodel.KindMessage { + return messageNeedsJSONNormalization(messageIndex[field.Type.MapValue.Name], messageIndex) + } + return false + } + for _, field := range message.Fields { + if emptyBehaviorNeedsNormalization(field) { + return true + } + if needsBytesEncodingNormalization(field) { + return true + } + if field.IsMap && field.Type != nil && field.Type.MapValue != nil && + field.Type.MapValue.Kind == contractmodel.KindMessage && + messageNeedsJSONNormalization(messageIndex[field.Type.MapValue.Name], messageIndex) { + return true + } + if field.Type != nil && field.Type.Kind == contractmodel.KindMessage && + messageNeedsJSONNormalization(messageIndex[field.Type.Name], messageIndex) { + return true + } + } + return false +} + +func oneofNeedsNormalization(message *contractmodel.Message, messageIndex map[string]*contractmodel.Message) bool { + if message == nil { + return false + } + for _, oneof := range message.Oneofs { + if oneof.Discriminator == "" { + continue + } + if len(oneofVariantJSONFields(oneof, messageIndex)) > 0 { + return true + } + } + return false +} + +func newtonsoftVariantCondition( + oneof *contractmodel.Oneof, + variant *contractmodel.OneofVariant, + messageIndex map[string]*contractmodel.Message, +) string { + parts := variantPresenceExpressions(oneof, variant, messageIndex, func(jsonName string) string { + tokenName := pascalCase(variant.FieldName) + pascalCase(jsonName) + "Token" + return `obj.TryGetValue("` + jsonName + `", out var ` + tokenName + `) && ` + tokenName + `.Type != JTokenType.Null` + }) + return strings.Join(parts, " || ") +} + +func systemTextVariantCondition( + oneof *contractmodel.Oneof, + variant *contractmodel.OneofVariant, + messageIndex map[string]*contractmodel.Message, +) string { + parts := variantPresenceExpressions(oneof, variant, messageIndex, func(jsonName string) string { + tokenName := pascalCase(variant.FieldName) + pascalCase(jsonName) + "Token" + return `obj.TryGetPropertyValue("` + jsonName + `", out var ` + tokenName + `) && ` + tokenName + ` is not null` + }) + return strings.Join(parts, " || ") +} + +func variantPresenceExpressions( + oneof *contractmodel.Oneof, + variant *contractmodel.OneofVariant, + messageIndex map[string]*contractmodel.Message, + expr func(string) string, +) []string { + jsonFields := oneofVariantJSONFieldsForVariant(oneof, variant, messageIndex) + parts := make([]string, 0, len(jsonFields)) + for _, jsonName := range jsonFields { + parts = append(parts, expr(jsonName)) + } + return parts +} + +func oneofVariantJSONFields( + oneof *contractmodel.Oneof, + messageIndex map[string]*contractmodel.Message, +) []string { + var fields []string + for _, variant := range oneof.Variants { + fields = append(fields, oneofVariantJSONFieldsForVariant(oneof, variant, messageIndex)...) + } + return fields +} + +func oneofOtherVariantJSONFields( + oneof *contractmodel.Oneof, + selected *contractmodel.OneofVariant, + messageIndex map[string]*contractmodel.Message, +) []string { + var fields []string + for _, variant := range oneof.Variants { + if variant.FieldName == selected.FieldName { + continue + } + fields = append(fields, oneofVariantJSONFieldsForVariant(oneof, variant, messageIndex)...) + } + return fields +} + +func oneofVariantJSONFieldsForVariant( + oneof *contractmodel.Oneof, + variant *contractmodel.OneofVariant, + messageIndex map[string]*contractmodel.Message, +) []string { + if oneof.Flatten && variant.IsMessage { + child := messageIndex[variant.Type.Name] + if child == nil { + return nil + } + fields := make([]string, 0, len(child.Fields)) + for _, field := range child.Fields { + jsonName := field.JSONName + if jsonName == "" { + jsonName = field.Name + } + fields = append(fields, jsonName) + } + return fields + } + return []string{variant.FieldName} +} + +func mapValueUsesUnwrap(message *contractmodel.Message) bool { + return message != nil && message.Unwrap != nil && !message.Unwrap.IsRoot +} + +func needsBytesEncodingNormalization(field *contractmodel.Field) bool { + return field != nil && + field.Type != nil && + field.Type.Name == bytesTypeName && + field.Annotations.BytesEncoding != sebufhttp.BytesEncoding_BYTES_ENCODING_UNSPECIFIED && + field.Annotations.BytesEncoding != sebufhttp.BytesEncoding_BYTES_ENCODING_BASE64 +} + +func emptyBehaviorNeedsNormalization(field *contractmodel.Field) bool { + return field != nil && + field.Type != nil && + field.Type.Kind == contractmodel.KindMessage && + !field.Repeated && + !field.IsMap && + (field.Annotations.EmptyBehavior == sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_NULL || + field.Annotations.EmptyBehavior == sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_OMIT) +} + +func (g *Generator) generateNewtonsoftEmptyBehaviorNormalization( + gf *protogen.GeneratedFile, + field *contractmodel.Field, + jsonName string, + serialize bool, +) { + tokenName := pascalCase(jsonName) + "EmptyBehavior" + gf.P(` if (obj.TryGetValue("`, jsonName, `", out var `, tokenName, `))`) + gf.P(" {") + g.generateEmptyBehaviorNormalizationBody( + gf, field, jsonName, tokenName, serialize, ` obj["`+jsonName+`"] = JValue.CreateNull();`, + ) + gf.P(" }") +} + +func (g *Generator) generateSystemTextEmptyBehaviorNormalization( + gf *protogen.GeneratedFile, + field *contractmodel.Field, + jsonName string, + serialize bool, +) { + tokenName := pascalCase(jsonName) + "EmptyBehavior" + gf.P(` if (obj.TryGetPropertyValue("`, jsonName, `", out var `, tokenName, `))`) + gf.P(" {") + g.generateEmptyBehaviorNormalizationBody( + gf, field, jsonName, tokenName, serialize, ` obj["`+jsonName+`"] = null;`, + ) + gf.P(" }") +} + +func (g *Generator) generateEmptyBehaviorNormalizationBody( + gf *protogen.GeneratedFile, + field *contractmodel.Field, + jsonName string, + tokenName string, + serialize bool, + nullAssignment string, +) { + switch field.Annotations.EmptyBehavior { + case sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_UNSPECIFIED, + sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_PRESERVE: + // No special wire normalization needed. + case sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_NULL: + gf.P(" if (IsEmptyObject(", tokenName, "))") + gf.P(" {") + gf.P(nullAssignment) + gf.P(" }") + case sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_OMIT: + check := "IsEmptyObject(" + tokenName + ")" + if serialize { + check = "ShouldOmitEmptyField(" + tokenName + ")" + } + gf.P(" if (", check, ")") + gf.P(" {") + gf.P(` obj.Remove("`, jsonName, `");`) + gf.P(" }") + } +} + +func bytesEncodingName(encoding sebufhttp.BytesEncoding) string { + switch encoding { + case sebufhttp.BytesEncoding_BYTES_ENCODING_UNSPECIFIED: + return base64Encoding + case sebufhttp.BytesEncoding_BYTES_ENCODING_BASE64: + return base64Encoding + case sebufhttp.BytesEncoding_BYTES_ENCODING_BASE64_RAW: + return "base64_raw" + case sebufhttp.BytesEncoding_BYTES_ENCODING_BASE64URL: + return "base64url" + case sebufhttp.BytesEncoding_BYTES_ENCODING_BASE64URL_RAW: + return "base64url_raw" + case sebufhttp.BytesEncoding_BYTES_ENCODING_HEX: + return "hex" + default: + return base64Encoding + } +} + +func (g *Generator) generatePathAndQueryHelpers(gf *protogen.GeneratedFile) { + gf.P(" private static string FormatPathValue(object? value)") + gf.P(" {") + gf.P(` return value?.ToString() ?? string.Empty;`) + gf.P(" }") + gf.P() + gf.P(" private static string FormatQueryValue(object? value)") + gf.P(" {") + gf.P(` return value?.ToString() ?? string.Empty;`) + gf.P(" }") + gf.P() +} + +func serviceCallHeaders(service *contractmodel.Service) []*contractmodel.Header { + headers := make([]*contractmodel.Header, 0, len(service.Headers)) + seen := make(map[string]bool) + for _, header := range service.Headers { + if header == nil || seen[header.Name] { + continue + } + seen[header.Name] = true + headers = append(headers, header) + } + for _, method := range service.Methods { + for _, header := range method.Headers { + if header == nil || seen[header.Name] { + continue + } + seen[header.Name] = true + headers = append(headers, header) + } + } + return headers +} + +func methodHasBody(method *contractmodel.Method) bool { + switch method.HTTPMethod { + case http.MethodPost, http.MethodPut, http.MethodPatch: + return true + default: + return false + } +} + +func httpMethodName(method string) string { + switch method { + case http.MethodGet: + return "Get" + case http.MethodPost: + return "Post" + case http.MethodPut: + return "Put" + case http.MethodDelete: + return "Delete" + case http.MethodPatch: + return "Patch" + default: + return "Post" + } +} + +func clientResponseType(method *contractmodel.Method) string { + if method.ResponseType == "" { + return emptyTypeName + } + return method.ResponseType +} + +func csharpQueryCondition(field *contractmodel.Field, expr string) string { + if field == nil || field.Type == nil { + return expr + " != null" + } + if field.Optional || field.HasPresence { + return expr + " != null" + } + switch field.Type.Kind { + case contractmodel.KindScalar: + switch field.Type.Name { + case csharpStringType, bytesTypeName: + return "!string.IsNullOrEmpty(" + expr + ")" + case "bool": + return expr + case "int32", "fixed32", "uint32", "sfixed32", "sint32", "int64", "uint64", "fixed64", "sfixed64", "sint64": + return expr + " != 0" + case "float", "double": + return expr + " != 0" + default: + return expr + " != null" + } + case contractmodel.KindEnum: + return expr + " != 0" + case contractmodel.KindMessage, contractmodel.KindWellKnown, contractmodel.KindMap: + return expr + " != null" + default: + return expr + " != null" + } +} + +func (g *Generator) markOneofFields(message *contractmodel.Message, skippedFields map[string]bool) { + for _, oneof := range message.Oneofs { + for _, variant := range oneof.Variants { + skippedFields[variant.FieldName] = true + } + } +} + +func (g *Generator) appendStandardMessageProperties( + properties *[]generatedProperty, + usedJSONNames map[string]bool, + skippedFields map[string]bool, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + for _, field := range message.Fields { + if skippedFields[field.Name] { + continue + } + if field.Annotations.Flatten && g.appendFlattenedProperties(properties, usedJSONNames, field, messageIndex) { + continue + } + g.appendProperty( + properties, + usedJSONNames, + field.Name, + pascalCase(field.Name), + csharpPropertyType(field, false), + g.enumConverterAttribute(field), + ) + } +} + +func (g *Generator) appendOneofProperties( + properties *[]generatedProperty, + usedJSONNames map[string]bool, + message *contractmodel.Message, + messageIndex map[string]*contractmodel.Message, +) { + for _, oneof := range message.Oneofs { + if oneof.Discriminator != "" { + g.appendProperty( + properties, + usedJSONNames, + oneof.Discriminator, + pascalCase(oneof.Discriminator), + "string?", + "", + ) + } + for _, variant := range oneof.Variants { + if oneof.Flatten && variant.IsMessage && + g.appendFlattenedVariantProperties(properties, usedJSONNames, variant, messageIndex) { + continue + } + + variantField := findField(message, variant.FieldName) + if variantField == nil { + continue + } + g.appendProperty( + properties, + usedJSONNames, + variant.FieldName, + pascalCase(variant.FieldName), + csharpPropertyType(variantField, true), + g.enumConverterAttribute(variantField), + ) + } + } +} + +func (g *Generator) appendFlattenedProperties( + properties *[]generatedProperty, + usedJSONNames map[string]bool, + field *contractmodel.Field, + messageIndex map[string]*contractmodel.Message, +) bool { + child := childMessage(messageIndex, field) + if child == nil { + return false + } + + for _, childField := range child.Fields { + jsonName := field.Annotations.FlattenPrefix + childField.Name + g.appendProperty( + properties, + usedJSONNames, + jsonName, + pascalCase(jsonName), + csharpPropertyType(childField, true), + g.enumConverterAttribute(childField), + ) + } + return true +} + +func (g *Generator) appendFlattenedVariantProperties( + properties *[]generatedProperty, + usedJSONNames map[string]bool, + variant *contractmodel.OneofVariant, + messageIndex map[string]*contractmodel.Message, +) bool { + child := messageIndex[variant.Type.Name] + if child == nil { + return false + } + + for _, childField := range child.Fields { + jsonName := childField.Name + if usedJSONNames[jsonName] { + return false + } + } + + for _, childField := range child.Fields { + jsonName := childField.Name + g.appendProperty( + properties, + usedJSONNames, + jsonName, + pascalCase(jsonName), + csharpPropertyType(childField, true), + g.enumConverterAttribute(childField), + ) + } + return true +} + +func (g *Generator) appendProperty( + properties *[]generatedProperty, + usedJSONNames map[string]bool, + jsonName string, + name string, + typ string, + converter string, +) { + if usedJSONNames[jsonName] { + return + } + usedJSONNames[jsonName] = true + *properties = append(*properties, generatedProperty{ + jsonName: jsonName, + name: name, + typ: typ, + converter: converter, + }) +} + +func childMessage(messageIndex map[string]*contractmodel.Message, field *contractmodel.Field) *contractmodel.Message { + if field == nil || field.Type == nil || field.Type.Kind != contractmodel.KindMessage { + return nil + } + return messageIndex[field.Type.Name] +} + +func findField(message *contractmodel.Message, name string) *contractmodel.Field { + for _, field := range message.Fields { + if field.Name == name { + return field + } + } + return nil +} + +func isRootUnwrapMessage(message *contractmodel.Message) bool { + return message != nil && message.Unwrap != nil && message.Unwrap.IsRoot && len(message.Fields) == 1 +} + +func rootUnwrapBaseType(message *contractmodel.Message) string { + if message == nil || len(message.Fields) == 0 { + return csharpObjectType + } + field := message.Fields[0] + if field.IsMap { + return csharpBaseType(field) + } + if field.Repeated && field.Type != nil { + return "List<" + csharpBaseType(&contractmodel.Field{Type: field.Type, Annotations: field.Annotations}) + ">" + } + return csharpPropertyType(field, false) +} + +func (g *Generator) useNewtonsoft() bool { + return strings.EqualFold(g.opts.JSONLib, "newtonsoft") +} + +func (g *Generator) jsonAttribute(name string) string { + if g.useNewtonsoft() { + return `[JsonProperty("` + name + `")]` + } + return `[JsonPropertyName("` + name + `")]` +} + +func csharpType(field *contractmodel.Field) string { + return csharpPropertyType(field, false) +} + +func csharpPropertyType(field *contractmodel.Field, forceNullable bool) string { + ref := field.Type + if ref == nil { + return csharpObjectType + } + base := csharpBaseType(field) + if field.Repeated { + if forceNullable { + return "List<" + base + ">?" + } + return "List<" + base + ">" + } + if field.IsMap { + if forceNullable { + return base + "?" + } + return base + } + if shouldUseNullableType(field, ref, forceNullable) { + return base + "?" + } + return base +} + +func csharpBaseType(field *contractmodel.Field) string { + ref := field.Type + if ref == nil { + return csharpObjectType + } + + switch ref.Kind { + case contractmodel.KindMessage: + return ref.Name + case contractmodel.KindEnum: + return ref.Name + case contractmodel.KindWellKnown: + return csharpWellKnownType(field) + case contractmodel.KindMap: + return fmt.Sprintf("Dictionary<%s, %s>", csharpBaseTypeForRef(ref.MapKey), csharpBaseTypeForRef(ref.MapValue)) + case contractmodel.KindScalar: + return csharpScalar(field) + default: + return csharpObjectType + } +} + +func csharpBaseTypeForRef(ref *contractmodel.TypeRef) string { + return csharpBaseType(&contractmodel.Field{Type: ref}) +} + +func csharpWellKnownType(field *contractmodel.Field) string { + ref := field.Type + switch ref.WellKnown { + case contractmodel.WellKnownStruct: + return "Dictionary" + case contractmodel.WellKnownTimestamp: + switch field.Annotations.TimestampFormat { + case sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_UNIX_SECONDS, + sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_UNIX_MILLIS: + return csharpLongType + case sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_UNSPECIFIED, + sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_RFC3339, + sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_DATE: + return csharpStringType + default: + return csharpStringType + } + case contractmodel.WellKnownDuration, contractmodel.WellKnownFieldMask: + return csharpStringType + case contractmodel.WellKnownListValue: + return "List" + case contractmodel.WellKnownAny, contractmodel.WellKnownValue: + return csharpObjectType + case contractmodel.WellKnownEmpty: + return "object" + case contractmodel.WellKnownDoubleWrap, + contractmodel.WellKnownFloatWrap, + contractmodel.WellKnownInt64Wrap, + contractmodel.WellKnownUInt64Wrap, + contractmodel.WellKnownInt32Wrap, + contractmodel.WellKnownUInt32Wrap, + contractmodel.WellKnownBoolWrap, + contractmodel.WellKnownStringWrap, + contractmodel.WellKnownBytesWrap: + return csharpScalar(&contractmodel.Field{Type: ref, Annotations: field.Annotations}) + default: + return csharpObjectType + } +} + +func csharpScalar(field *contractmodel.Field) string { + kind := field.Type.Name + switch kind { + case csharpDoubleType, "float": + return csharpDoubleType + case "int64", "uint64", "fixed64", "sfixed64", "sint64": + if field.Annotations.Int64Encoding != sebufhttp.Int64Encoding_INT64_ENCODING_NUMBER { + return csharpStringType + } + return csharpLongType + case "int32", "fixed32", "uint32", "sfixed32", "sint32": + return "int" + case csharpBoolType: + return csharpBoolType + case csharpStringType: + return csharpStringType + case bytesTypeName: + return "byte[]" + default: + return csharpObjectType + } +} + +func (g *Generator) enumConverterAttribute(field *contractmodel.Field) string { + if field.Type == nil || field.Type.Kind != contractmodel.KindEnum { + return "" + } + if field.Annotations.EnumEncoding == sebufhttp.EnumEncoding_ENUM_ENCODING_NUMBER { + return "" + } + if g.useNewtonsoft() { + return `[JsonConverter(typeof(StringEnumConverter))]` + } + return `[JsonConverter(typeof(JsonStringEnumConverter))]` +} + +func shouldUseNullableType(field *contractmodel.Field, ref *contractmodel.TypeRef, forceNullable bool) bool { + if ref == nil { + return false + } + if forceNullable { + return true + } + if ref.Kind == contractmodel.KindWellKnown && isWrapper(ref.WellKnown) { + return true + } + if field.Annotations.Nullable || field.Annotations.EmptyBehavior == sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_NULL || + field.Annotations.EmptyBehavior == sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_OMIT { + return true + } + if !field.Optional && !field.HasPresence && !field.IsOneofVariant { + return false + } + return true +} + +func isWrapper(kind contractmodel.WellKnownType) bool { + switch kind { + case contractmodel.WellKnownDoubleWrap, + contractmodel.WellKnownFloatWrap, + contractmodel.WellKnownInt64Wrap, + contractmodel.WellKnownUInt64Wrap, + contractmodel.WellKnownInt32Wrap, + contractmodel.WellKnownUInt32Wrap, + contractmodel.WellKnownBoolWrap, + contractmodel.WellKnownStringWrap, + contractmodel.WellKnownBytesWrap: + return true + case contractmodel.WellKnownAny, + contractmodel.WellKnownDuration, + contractmodel.WellKnownEmpty, + contractmodel.WellKnownFieldMask, + contractmodel.WellKnownListValue, + contractmodel.WellKnownStruct, + contractmodel.WellKnownTimestamp, + contractmodel.WellKnownValue: + return false + default: + return false + } +} + +func pascalCase(name string) string { + parts := strings.Split(name, "_") + for i, part := range parts { + if part == "" { + continue + } + part = strings.ToLower(part) + parts[i] = strings.ToUpper(part[:1]) + part[1:] + } + return strings.Join(parts, "") +} + +func upperFirst(value string) string { + if value == "" { + return "" + } + return strings.ToUpper(value[:1]) + value[1:] +} diff --git a/internal/csharpgen/generator_test.go b/internal/csharpgen/generator_test.go new file mode 100644 index 00000000..1108a779 --- /dev/null +++ b/internal/csharpgen/generator_test.go @@ -0,0 +1,964 @@ +package csharpgen + +import ( + "strings" + "testing" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/pluginpb" + + sebufhttp "github.com/SebastienMelki/sebuf/http" + "github.com/SebastienMelki/sebuf/internal/contractmodel" +) + +func TestPascalCase(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "STATE_UNSPECIFIED", want: "StateUnspecified"}, + {input: "item_state", want: "ItemState"}, + {input: "already", want: "Already"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := pascalCase(tt.input); got != tt.want { + t.Fatalf("pascalCase(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestJSONAttribute(t *testing.T) { + newtonsoft := &Generator{opts: Options{JSONLib: "newtonsoft"}} + if got := newtonsoft.jsonAttribute("owner_id"); got != `[JsonProperty("owner_id")]` { + t.Fatalf("Newtonsoft jsonAttribute = %q", got) + } + + systemText := &Generator{opts: Options{JSONLib: "system_text_json"}} + if got := systemText.jsonAttribute("owner_id"); got != `[JsonPropertyName("owner_id")]` { + t.Fatalf("System.Text.Json jsonAttribute = %q", got) + } +} + +func TestCSharpTypeMappings(t *testing.T) { + enumType := &contractmodel.TypeRef{Kind: contractmodel.KindEnum, Name: "WidgetState"} + if got := csharpType( + &contractmodel.Field{Name: "state", Type: enumType, HasPresence: true}, + ); got != "WidgetState?" { + t.Fatalf("enum csharpType = %q, want %q", got, "WidgetState?") + } + + mapType := &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "string", + }, + MapValue: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "int32", + }, + } + if got := csharpType(&contractmodel.Field{Name: "scores", Type: mapType}); got != "Dictionary" { + t.Fatalf("map csharpType = %q, want %q", got, "Dictionary") + } + + enumMapType := &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "string", + }, + MapValue: &contractmodel.TypeRef{ + Kind: contractmodel.KindEnum, + Name: "WidgetState", + }, + } + if got := csharpType( + &contractmodel.Field{Name: "states", Type: enumMapType}, + ); got != "Dictionary" { + t.Fatalf("enum map csharpType = %q, want %q", got, "Dictionary") + } + + messageMapType := &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "string", + }, + MapValue: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "WidgetProfile", + }, + } + if got := csharpType( + &contractmodel.Field{Name: "profiles", Type: messageMapType}, + ); got != "Dictionary" { + t.Fatalf("message map csharpType = %q, want %q", got, "Dictionary") + } + + int64String := &contractmodel.Field{ + Name: "version", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "int64"}, + } + if got := csharpType(int64String); got != "string" { + t.Fatalf("default int64 csharpType = %q, want %q", got, "string") + } + + int64Number := &contractmodel.Field{ + Name: "version", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "int64"}, + Annotations: contractmodel.FieldAnnotations{ + Int64Encoding: sebufhttp.Int64Encoding_INT64_ENCODING_NUMBER, + }, + } + if got := csharpType(int64Number); got != "long" { + t.Fatalf("number int64 csharpType = %q, want %q", got, "long") + } + + timestampString := &contractmodel.Field{ + Name: "created_at", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindWellKnown, WellKnown: contractmodel.WellKnownTimestamp}, + } + if got := csharpType(timestampString); got != "string" { + t.Fatalf("default timestamp csharpType = %q, want %q", got, "string") + } + + timestampNumber := &contractmodel.Field{ + Name: "created_at", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindWellKnown, WellKnown: contractmodel.WellKnownTimestamp}, + Annotations: contractmodel.FieldAnnotations{ + TimestampFormat: sebufhttp.TimestampFormat_TIMESTAMP_FORMAT_UNIX_MILLIS, + }, + } + if got := csharpType(timestampNumber); got != "long" { + t.Fatalf("unix millis timestamp csharpType = %q, want %q", got, "long") + } + + durationType := &contractmodel.Field{ + Name: "ttl", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindWellKnown, WellKnown: contractmodel.WellKnownDuration}, + } + if got := csharpType(durationType); got != "string" { + t.Fatalf("duration csharpType = %q, want %q", got, "string") + } + + fieldMaskType := &contractmodel.Field{ + Name: "mask", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindWellKnown, WellKnown: contractmodel.WellKnownFieldMask}, + } + if got := csharpType(fieldMaskType); got != "string" { + t.Fatalf("field mask csharpType = %q, want %q", got, "string") + } + + listValueType := &contractmodel.Field{ + Name: "items", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindWellKnown, WellKnown: contractmodel.WellKnownListValue}, + } + if got := csharpType(listValueType); got != "List" { + t.Fatalf("list value csharpType = %q, want %q", got, "List") + } + + anyType := &contractmodel.Field{ + Name: "raw", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindWellKnown, WellKnown: contractmodel.WellKnownAny}, + } + if got := csharpType(anyType); got != "object" { + t.Fatalf("any csharpType = %q, want %q", got, "object") + } + + bytesWrapperType := &contractmodel.TypeRef{ + Kind: contractmodel.KindWellKnown, + Name: "bytes", + WellKnown: contractmodel.WellKnownBytesWrap, + } + if got := csharpType(&contractmodel.Field{Name: "wrapped_payload", Type: bytesWrapperType}); got != "byte[]?" { + t.Fatalf("bytes wrapper csharpType = %q, want %q", got, "byte[]?") + } + + wrapperType := &contractmodel.TypeRef{ + Kind: contractmodel.KindWellKnown, + Name: "int32", + WellKnown: contractmodel.WellKnownInt32Wrap, + } + if got := csharpType(&contractmodel.Field{Name: "count", Type: wrapperType}); got != "int?" { + t.Fatalf("wrapper csharpType = %q, want %q", got, "int?") + } + + repeatedMessage := &contractmodel.Field{ + Name: "items", + Repeated: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "WidgetDetails", + }, + } + if got := csharpType(repeatedMessage); got != "List" { + t.Fatalf("repeated message csharpType = %q, want %q", got, "List") + } + + nullableString := &contractmodel.Field{ + Name: "display_name", + Optional: true, + HasPresence: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + Annotations: contractmodel.FieldAnnotations{Nullable: true}, + } + if got := csharpType(nullableString); got != "string?" { + t.Fatalf("nullable string csharpType = %q, want %q", got, "string?") + } + + emptyBehaviorNull := &contractmodel.Field{ + Name: "meta", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Metadata"}, + Annotations: contractmodel.FieldAnnotations{ + EmptyBehavior: sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_NULL, + }, + } + if got := csharpType(emptyBehaviorNull); got != "Metadata?" { + t.Fatalf("empty_behavior null csharpType = %q, want %q", got, "Metadata?") + } + + bytesHex := &contractmodel.Field{ + Name: "payload", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "bytes"}, + Annotations: contractmodel.FieldAnnotations{ + BytesEncoding: sebufhttp.BytesEncoding_BYTES_ENCODING_HEX, + }, + } + if got := csharpType(bytesHex); got != "byte[]" { + t.Fatalf("bytes csharpType = %q, want %q", got, "byte[]") + } +} + +func TestMessageProperties(t *testing.T) { + gen := &Generator{opts: Options{JSONLib: "newtonsoft"}} + profile := &contractmodel.Message{ + Name: "WidgetProfile", + Fields: []*contractmodel.Field{ + {Name: "note", Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}}, + }, + } + circle := &contractmodel.Message{ + Name: "ShapeEnvelopeCircle", + Fields: []*contractmodel.Field{ + {Name: "radius", Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "double"}}, + }, + } + rectangle := &contractmodel.Message{ + Name: "ShapeEnvelopeRectangle", + Fields: []*contractmodel.Field{ + {Name: "width", Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "double"}}, + {Name: "height", Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "double"}}, + }, + } + index := map[string]*contractmodel.Message{ + profile.Name: profile, + circle.Name: circle, + rectangle.Name: rectangle, + } + + message := &contractmodel.Message{ + Name: "Widget", + Fields: []*contractmodel.Field{ + {Name: "id", Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}}, + { + Name: "profile", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "WidgetProfile"}, + Annotations: contractmodel.FieldAnnotations{ + Flatten: true, + FlattenPrefix: "meta_", + }, + }, + }, + Oneofs: []*contractmodel.Oneof{ + { + Name: "shape", + Discriminator: "kind", + Flatten: true, + Variants: []*contractmodel.OneofVariant{ + { + FieldName: "circle", + DiscriminatorValue: "circle_shape", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "ShapeEnvelopeCircle", + }, + IsMessage: true, + }, + { + FieldName: "rectangle", + DiscriminatorValue: "rectangle", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "ShapeEnvelopeRectangle", + }, + IsMessage: true, + }, + }, + }, + }, + } + + properties := gen.messageProperties(message, index) + got := make(map[string]string, len(properties)) + for _, property := range properties { + got[property.jsonName] = property.typ + } + + for jsonName, wantType := range map[string]string{ + "id": "string", + "meta_note": "string?", + "kind": "string?", + "radius": "double?", + "width": "double?", + "height": "double?", + } { + if got[jsonName] != wantType { + t.Fatalf("property %q = %q, want %q (all: %#v)", jsonName, got[jsonName], wantType, got) + } + } +} + +func TestRootUnwrapBaseType(t *testing.T) { + message := &contractmodel.Message{ + Name: "TagList", + Fields: []*contractmodel.Field{ + { + Name: "values", + Repeated: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + }, + Unwrap: &contractmodel.Unwrap{ + FieldName: "values", + IsRoot: true, + }, + } + + if !isRootUnwrapMessage(message) { + t.Fatalf("expected root unwrap message") + } + if got := rootUnwrapBaseType(message); got != "List" { + t.Fatalf("rootUnwrapBaseType() = %q, want %q", got, "List") + } +} + +func TestMessageNeedsJSONNormalizationForRootUnwrap(t *testing.T) { + messageIndex := map[string]*contractmodel.Message{ + "Widget": { + Name: "Widget", + Fields: []*contractmodel.Field{ + { + Name: "payload", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "bytes"}, + Annotations: contractmodel.FieldAnnotations{ + BytesEncoding: sebufhttp.BytesEncoding_BYTES_ENCODING_HEX, + }, + }, + }, + }, + "OptionBarsList": { + Name: "OptionBarsList", + Fields: []*contractmodel.Field{ + { + Name: "bars", + Repeated: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Widget"}, + }, + }, + Unwrap: &contractmodel.Unwrap{FieldName: "bars"}, + }, + } + + rootRepeated := &contractmodel.Message{ + Name: "RootRepeatedResponse", + Fields: []*contractmodel.Field{ + { + Name: "items", + Repeated: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Widget"}, + }, + }, + Unwrap: &contractmodel.Unwrap{FieldName: "items", IsRoot: true}, + } + rootMap := &contractmodel.Message{ + Name: "RootMapResponse", + Fields: []*contractmodel.Field{ + { + Name: "items", + IsMap: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + MapValue: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Widget"}, + }, + }, + }, + Unwrap: &contractmodel.Unwrap{FieldName: "items", IsRoot: true, IsMapField: true}, + } + rootMapValueUnwrap := &contractmodel.Message{ + Name: "RootMapWithValueUnwrapResponse", + Fields: []*contractmodel.Field{ + { + Name: "items", + IsMap: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + MapValue: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "OptionBarsList"}, + }, + }, + }, + Unwrap: &contractmodel.Unwrap{FieldName: "items", IsRoot: true, IsMapField: true}, + } + + for _, tt := range []struct { + name string + message *contractmodel.Message + }{ + {name: "root repeated", message: rootRepeated}, + {name: "root map", message: rootMap}, + {name: "root map value unwrap", message: rootMapValueUnwrap}, + } { + t.Run(tt.name, func(t *testing.T) { + if !messageNeedsJSONNormalization(tt.message, messageIndex) { + t.Fatalf("expected %s to require normalization", tt.name) + } + }) + } +} + +func TestMessageNeedsJSONNormalizationForNestedAnnotatedChildren(t *testing.T) { + messageIndex := map[string]*contractmodel.Message{ + "Inner": { + Name: "Inner", + Fields: []*contractmodel.Field{ + { + Name: "metadata_null", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "EmptyMessage"}, + Annotations: contractmodel.FieldAnnotations{ + EmptyBehavior: sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_NULL, + }, + }, + }, + }, + "EmptyMessage": {Name: "EmptyMessage"}, + } + + outer := &contractmodel.Message{ + Name: "Outer", + Fields: []*contractmodel.Field{ + { + Name: "inner", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Inner"}, + }, + }, + } + outerMap := &contractmodel.Message{ + Name: "OuterMap", + Fields: []*contractmodel.Field{ + { + Name: "entries", + IsMap: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + MapValue: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Inner"}, + }, + }, + }, + } + + for _, tt := range []struct { + name string + message *contractmodel.Message + }{ + {name: "nested child", message: outer}, + {name: "map child", message: outerMap}, + } { + t.Run(tt.name, func(t *testing.T) { + if !messageNeedsJSONNormalization(tt.message, messageIndex) { + t.Fatalf("expected %s to require normalization", tt.name) + } + }) + } +} + +func TestGeneratePackage(t *testing.T) { + plugin := newCSharpTestPlugin(t) + gen := New(plugin, Options{Namespace: "Test.Contracts", JSONLib: "newtonsoft"}) + + pkg := &contractmodel.Package{ + Name: "test.contracts.v1", + Enums: []*contractmodel.Enum{ + { + Name: "WidgetState", + Values: []*contractmodel.EnumValue{ + {Name: "STATE_UNSPECIFIED", JSONValue: "STATE_UNSPECIFIED", Number: 0}, + {Name: "STATE_READY", JSONValue: "ready", Number: 1}, + }, + }, + }, + Messages: []*contractmodel.Message{ + { + Name: "WidgetProfile", + Fields: []*contractmodel.Field{ + { + Name: "note", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + }, + }, + { + Name: "Widget", + Fields: []*contractmodel.Field{ + { + Name: "id", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + { + Name: "display_name", + Optional: true, + HasPresence: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + Annotations: contractmodel.FieldAnnotations{ + Nullable: true, + }, + }, + { + Name: "state", + HasPresence: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindEnum, Name: "WidgetState"}, + Annotations: contractmodel.FieldAnnotations{ + EnumEncoding: sebufhttp.EnumEncoding_ENUM_ENCODING_STRING, + }, + }, + { + Name: "meta", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindWellKnown, + WellKnown: contractmodel.WellKnownStruct, + }, + Repeated: false, + }, + { + Name: "profile", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "WidgetProfile"}, + Annotations: contractmodel.FieldAnnotations{ + Flatten: true, + FlattenPrefix: "meta_", + }, + }, + { + Name: "payload", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "bytes"}, + Annotations: contractmodel.FieldAnnotations{ + BytesEncoding: sebufhttp.BytesEncoding_BYTES_ENCODING_HEX, + }, + }, + }, + Oneofs: []*contractmodel.Oneof{ + { + Name: "shape", + Discriminator: "kind", + Flatten: true, + Variants: []*contractmodel.OneofVariant{ + { + FieldName: "circle", + DiscriminatorValue: "circle_shape", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "ShapeEnvelopeCircle", + }, + IsMessage: true, + }, + { + FieldName: "rectangle", + DiscriminatorValue: "rectangle", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "ShapeEnvelopeRectangle", + }, + IsMessage: true, + }, + }, + }, + }, + }, + { + Name: "ShapeEnvelopeCircle", + Fields: []*contractmodel.Field{ + { + Name: "radius", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "double"}, + }, + }, + }, + { + Name: "ShapeEnvelopeRectangle", + Fields: []*contractmodel.Field{ + { + Name: "width", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "double"}, + }, + { + Name: "height", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "double"}, + }, + }, + }, + { + Name: "NestedTextContent", + Fields: []*contractmodel.Field{ + { + Name: "body", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + }, + }, + { + Name: "NestedImageContent", + Fields: []*contractmodel.Field{ + { + Name: "url", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + }, + }, + { + Name: "NestedEvent", + Fields: []*contractmodel.Field{ + { + Name: "id", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + { + Name: "text", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "NestedTextContent"}, + }, + { + Name: "image", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "NestedImageContent"}, + }, + }, + Oneofs: []*contractmodel.Oneof{ + { + Name: "content", + Discriminator: "kind", + Variants: []*contractmodel.OneofVariant{ + { + FieldName: "text", + DiscriminatorValue: "text", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "NestedTextContent", + }, + IsMessage: true, + }, + { + FieldName: "image", + DiscriminatorValue: "img", + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "NestedImageContent", + }, + IsMessage: true, + }, + }, + }, + }, + }, + { + Name: "TagList", + Fields: []*contractmodel.Field{ + { + Name: "values", + Repeated: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + }, + Unwrap: &contractmodel.Unwrap{ + FieldName: "values", + IsRoot: true, + }, + }, + { + Name: "EmptyMessage", + }, + { + Name: "EmptyBehaviorHolder", + Fields: []*contractmodel.Field{ + { + Name: "metadata_null", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "EmptyMessage"}, + Annotations: contractmodel.FieldAnnotations{ + EmptyBehavior: sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_NULL, + }, + }, + { + Name: "metadata_omit", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "EmptyMessage"}, + Annotations: contractmodel.FieldAnnotations{ + EmptyBehavior: sebufhttp.EmptyBehavior_EMPTY_BEHAVIOR_OMIT, + }, + }, + }, + }, + { + Name: "OptionBarsList", + Fields: []*contractmodel.Field{ + { + Name: "bars", + Repeated: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Widget"}, + }, + }, + Unwrap: &contractmodel.Unwrap{ + FieldName: "bars", + }, + }, + { + Name: "RootRepeatedResponse", + Fields: []*contractmodel.Field{ + { + Name: "items", + Repeated: true, + Type: &contractmodel.TypeRef{Kind: contractmodel.KindMessage, Name: "Widget"}, + }, + }, + Unwrap: &contractmodel.Unwrap{ + FieldName: "items", + IsRoot: true, + }, + }, + { + Name: "RootMapResponse", + Fields: []*contractmodel.Field{ + { + Name: "items", + IsMap: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "string", + }, + MapValue: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "Widget", + }, + }, + }, + }, + Unwrap: &contractmodel.Unwrap{ + FieldName: "items", + IsRoot: true, + IsMapField: true, + }, + }, + { + Name: "RootMapWithValueUnwrapResponse", + Fields: []*contractmodel.Field{ + { + Name: "items", + IsMap: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "string", + }, + MapValue: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "OptionBarsList", + }, + }, + }, + }, + Unwrap: &contractmodel.Unwrap{ + FieldName: "items", + IsRoot: true, + IsMapField: true, + }, + }, + { + Name: "GetOptionBarsResponse", + Fields: []*contractmodel.Field{ + { + Name: "bars", + IsMap: true, + Type: &contractmodel.TypeRef{ + Kind: contractmodel.KindMap, + MapKey: &contractmodel.TypeRef{ + Kind: contractmodel.KindScalar, + Name: "string", + }, + MapValue: &contractmodel.TypeRef{ + Kind: contractmodel.KindMessage, + Name: "OptionBarsList", + }, + }, + }, + }, + }, + { + Name: "GetWidgetRequest", + Fields: []*contractmodel.Field{ + { + Name: "id", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + }, + { + Name: "owner_id", + Type: &contractmodel.TypeRef{Kind: contractmodel.KindScalar, Name: "string"}, + Annotations: contractmodel.FieldAnnotations{ + Query: &contractmodel.Query{Name: "owner"}, + }, + }, + }, + }, + }, + Services: []*contractmodel.Service{ + { + Name: "WidgetService", + BasePath: "/api/v1", + Headers: []*contractmodel.Header{ + {Name: "X-API-Key", Required: true}, + }, + Methods: []*contractmodel.Method{ + { + Name: "GetWidget", + HTTPMethod: "GET", + Path: "/api/v1/widgets/{id}", + InputType: "GetWidgetRequest", + ResponseType: "Widget", + PathParams: []string{"id"}, + Headers: []*contractmodel.Header{ + {Name: "X-Request-ID", Required: true}, + }, + }, + { + Name: "GetOptionBars", + HTTPMethod: "POST", + Path: "/api/v1/options/bars", + InputType: "GetWidgetRequest", + ResponseType: "GetOptionBarsResponse", + }, + }, + }, + }, + } + + if err := gen.generatePackage(pkg); err != nil { + t.Fatalf("generatePackage() error = %v", err) + } + + output := generatedCSharpContent(t, plugin, "test/contracts/v1/Contracts.g.cs") + for _, want := range []string{ + "public enum WidgetState", + `[EnumMember(Value = "ready")]`, + "StateUnspecified = 0", + `[JsonConverter(typeof(StringEnumConverter))]`, + "public WidgetState? State { get; set; }", + "public string? DisplayName { get; set; }", + `[JsonProperty("meta")]`, + "public Dictionary Meta { get; set; }", + `[JsonProperty("meta_note")]`, + "public string? MetaNote { get; set; }", + `[JsonProperty("payload")]`, + "public byte[] Payload { get; set; }", + `[JsonProperty("kind")]`, + "public string? Kind { get; set; }", + `[JsonProperty("radius")]`, + "public double? Radius { get; set; }", + "public sealed class TagList : List", + "public sealed class RootRepeatedResponse : List", + "public sealed class RootMapResponse : Dictionary", + "public sealed class RootMapWithValueUnwrapResponse : Dictionary", + "public sealed class ApiException : Exception", + "public sealed class WidgetServiceClientOptions", + "public sealed class WidgetServiceCallOptions", + "public interface IWidgetServiceClient", + "public sealed class WidgetServiceClient : IWidgetServiceClient", + "private static string NormalizeSerializedJson(object value, string json)", + "private static string NormalizeResponseJson(Type responseType, string json)", + "private static JToken NormalizeSerializedWidget(JToken token)", + "private static JToken NormalizeResponseWidget(JToken token)", + `obj["kind"] = "circle_shape";`, + `obj.Remove("width");`, + `obj.Remove("height");`, + "private static JToken NormalizeSerializedNestedEvent(JToken token)", + `obj["kind"] = "text";`, + `obj.Remove("image");`, + "private static JToken NormalizeSerializedRootRepeatedResponse(JToken token)", + "array[i] = NormalizeSerializedToken(typeof(Widget), array[i]!);", + "private static JToken NormalizeSerializedRootMapResponse(JToken token)", + "property.Value = NormalizeSerializedToken(typeof(Widget), property.Value);", + "private static JToken NormalizeSerializedRootMapWithValueUnwrapResponse(JToken token)", + "property.Value = NormalizeMapValueForSerialization(property.Value, typeof(OptionBarsList));", + "private static JToken NormalizeSerializedEmptyBehaviorHolder(JToken token)", + "private static JToken NormalizeResponseEmptyBehaviorHolder(JToken token)", + "private static bool IsEmptyObject(JToken token)", + "private static bool ShouldOmitEmptyField(JToken token)", + `obj["metadata_null"] = JValue.CreateNull();`, + `obj.Remove("metadata_omit");`, + "private static JToken NormalizeMapValueForSerialization(JToken token, Type messageType)", + `"hex" => Convert.ToHexString(bytes).ToLowerInvariant(),`, + "public string? ApiKey { get; set; }", + "public string? RequestId { get; set; }", + "public async Task GetWidgetAsync(GetWidgetRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default)", + `path = path.Replace("{id}", Uri.EscapeDataString(FormatPathValue(req.Id)));`, + `query.Add(Uri.EscapeDataString("owner") + "=" + Uri.EscapeDataString(FormatQueryValue(req.OwnerId)));`, + `headers["X-API-Key"] = options.ApiKey!;`, + `headers["X-Request-ID"] = options.RequestId!;`, + "return await SendAsync(HttpMethod.Get, requestUri, null, headers, cancellationToken);", + "public static class WidgetService", + `public const string Path = "/api/v1/widgets/{id}";`, + `public const string RequestType = "GetWidgetRequest";`, + } { + if !strings.Contains(output, want) { + t.Fatalf("generated output missing %q:\n%s", want, output) + } + } +} + +func newCSharpTestPlugin(t *testing.T) *protogen.Plugin { + t.Helper() + req := &pluginpb.CodeGeneratorRequest{ + Parameter: proto.String("paths=source_relative"), + FileToGenerate: []string{"placeholder.proto"}, + ProtoFile: []*descriptorpb.FileDescriptorProto{ + { + Name: proto.String("placeholder.proto"), + Package: proto.String("test.contracts.v1"), + Syntax: proto.String("proto3"), + Options: &descriptorpb.FileOptions{ + GoPackage: proto.String("github.com/SebastienMelki/sebuf/internal/testdata/csharp;csharptest"), + }, + }, + }, + } + + plugin, err := protogen.Options{}.New(req) + if err != nil { + t.Fatalf("protogen.Options.New() error = %v", err) + } + return plugin +} + +func generatedCSharpContent(t *testing.T, plugin *protogen.Plugin, filename string) string { + t.Helper() + resp := plugin.Response() + for _, file := range resp.GetFile() { + if file.GetName() == filename { + return file.GetContent() + } + } + t.Fatalf("generated file %q not found", filename) + return "" +} diff --git a/internal/csharpgen/golden_test.go b/internal/csharpgen/golden_test.go new file mode 100644 index 00000000..0e5672eb --- /dev/null +++ b/internal/csharpgen/golden_test.go @@ -0,0 +1,109 @@ +package csharpgen + +import ( + "bytes" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/SebastienMelki/sebuf/internal/testutil" +) + +func TestCSharpGenGoldenFiles(t *testing.T) { + if _, err := exec.LookPath("protoc"); err != nil { + t.Skip("protoc not found, skipping golden file tests") + } + + testCases := []struct { + name string + protoFiles []string + opt string + expectedFile string + }{ + { + name: "simple contracts newtonsoft", + protoFiles: []string{"contracts.proto"}, + opt: "namespace=Test.Contracts,json_lib=newtonsoft", + expectedFile: "Contracts.g.cs", + }, + { + name: "comprehensive contracts newtonsoft", + protoFiles: []string{"comprehensive_models.proto", "comprehensive_services.proto"}, + opt: "namespace=Test.Contracts,json_lib=newtonsoft", + expectedFile: "Comprehensive.Newtonsoft.g.cs", + }, + { + name: "comprehensive contracts system text json", + protoFiles: []string{"comprehensive_models.proto", "comprehensive_services.proto"}, + opt: "namespace=Test.Contracts,json_lib=system_text_json", + expectedFile: "Comprehensive.SystemTextJson.g.cs", + }, + } + + baseDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get working directory: %v", err) + } + + projectRoot := filepath.Join(baseDir, "..", "..") + protoDir := filepath.Join(baseDir, "testdata", "proto") + goldenDir := filepath.Join(baseDir, "testdata", "golden") + pluginPath := filepath.Join(projectRoot, "bin", "protoc-gen-csharp-http") + + buildCmd := exec.Command("make", "build") + buildCmd.Dir = projectRoot + if buildErr := buildCmd.Run(); buildErr != nil { + t.Fatalf("Failed to build plugin: %v", buildErr) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tempDir := t.TempDir() + + args := []string{ + "--plugin=protoc-gen-csharp-http=" + pluginPath, + "--csharp-http_out=" + tempDir, + "--csharp-http_opt=" + tc.opt, + "--proto_path=" + protoDir, + "--proto_path=" + filepath.Join(projectRoot, "proto"), + } + args = append(args, tc.protoFiles...) + + cmd := exec.Command("protoc", args...) + cmd.Dir = protoDir + + var stderr bytes.Buffer + cmd.Stderr = &stderr + if runErr := cmd.Run(); runErr != nil { + t.Fatalf("protoc failed: %v\nstderr: %s", runErr, stderr.String()) + } + + generatedPath := filepath.Join(tempDir, "test/contracts/v1/Contracts.g.cs") + goldenPath := filepath.Join(goldenDir, tc.expectedFile) + generatedContent, readErr := os.ReadFile(generatedPath) + if readErr != nil { + t.Fatalf("Failed to read generated file %s: %v", generatedPath, readErr) + } + + if os.Getenv("UPDATE_GOLDEN") == "1" { + if writeErr := os.WriteFile(goldenPath, generatedContent, 0o644); writeErr != nil { + t.Fatalf("Failed to write golden file %s: %v", goldenPath, writeErr) + } + return + } + + goldenContent, goldenErr := os.ReadFile(goldenPath) + if goldenErr != nil { + t.Fatalf("Failed to read golden file %s: %v", goldenPath, goldenErr) + } + if !bytes.Equal(generatedContent, goldenContent) { + t.Fatalf( + "Generated file %s does not match golden file.\nDiff:\n%s", + tc.expectedFile, + testutil.DiffStrings(string(goldenContent), string(generatedContent)), + ) + } + }) + } +} diff --git a/internal/csharpgen/testdata/golden/Comprehensive.Newtonsoft.g.cs b/internal/csharpgen/testdata/golden/Comprehensive.Newtonsoft.g.cs new file mode 100644 index 00000000..f6f7da4c --- /dev/null +++ b/internal/csharpgen/testdata/golden/Comprehensive.Newtonsoft.g.cs @@ -0,0 +1,1069 @@ +// Code generated by protoc-gen-csharp-http. DO NOT EDIT. +#nullable enable +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Runtime.Serialization; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; + +namespace Test.Contracts +{ + public enum WidgetState + { + [EnumMember(Value = "STATE_UNSPECIFIED")] + StateUnspecified = 0, + [EnumMember(Value = "ready")] + StateReady = 1, + } + + public sealed class ApiException : Exception + { + public int StatusCode { get; } + public string ResponseBody { get; } + + public ApiException(int statusCode, string responseBody) + : base($\"Request failed with status {statusCode}: {responseBody}\") + { + StatusCode = statusCode; + ResponseBody = responseBody; + } + } + + public sealed class Widget + { + [JsonProperty("id")] + public string Id { get; set; } + [JsonProperty("display_name")] + public string? DisplayName { get; set; } + [JsonProperty("scores")] + public Dictionary Scores { get; set; } + [JsonProperty("meta")] + public Dictionary? Meta { get; set; } + [JsonProperty("created_at")] + public long? CreatedAt { get; set; } + [JsonProperty("alias")] + public string? Alias { get; set; } + [JsonProperty("state")] + public WidgetState State { get; set; } + [JsonProperty("meta_note")] + public string? MetaNote { get; set; } + [JsonProperty("tags")] + public List Tags { get; set; } + [JsonProperty("owner_id")] + public string OwnerId { get; set; } + [JsonProperty("payload")] + public byte[] Payload { get; set; } + [JsonProperty("version")] + public long Version { get; set; } + [JsonProperty("state_labels")] + public Dictionary StateLabels { get; set; } + [JsonProperty("profiles_by_id")] + public Dictionary ProfilesById { get; set; } + } + + public sealed class WidgetProfile + { + [JsonProperty("note")] + public string Note { get; set; } + } + + public sealed class ShapeEnvelope + { + [JsonProperty("kind")] + public string? Kind { get; set; } + [JsonProperty("radius")] + public double? Radius { get; set; } + [JsonProperty("width")] + public double? Width { get; set; } + [JsonProperty("height")] + public double? Height { get; set; } + } + + public sealed class ShapeEnvelopeCircle + { + [JsonProperty("radius")] + public double Radius { get; set; } + } + + public sealed class ShapeEnvelopeRectangle + { + [JsonProperty("width")] + public double Width { get; set; } + [JsonProperty("height")] + public double Height { get; set; } + } + + public sealed class NestedShapeEnvelope + { + [JsonProperty("kind")] + public string? Kind { get; set; } + [JsonProperty("circle")] + public NestedShapeEnvelopeNestedCircle? Circle { get; set; } + [JsonProperty("rectangle")] + public NestedShapeEnvelopeNestedRectangle? Rectangle { get; set; } + } + + public sealed class NestedShapeEnvelopeNestedCircle + { + [JsonProperty("radius")] + public double Radius { get; set; } + } + + public sealed class NestedShapeEnvelopeNestedRectangle + { + [JsonProperty("width")] + public double Width { get; set; } + [JsonProperty("height")] + public double Height { get; set; } + } + + public sealed class DeepNest + { + [JsonProperty("level1")] + public DeepNestLevel1? Level1 { get; set; } + } + + public sealed class DeepNestLevel1 + { + [JsonProperty("level2")] + public DeepNestLevel1Level2? Level2 { get; set; } + } + + public sealed class DeepNestLevel1Level2 + { + [JsonProperty("code")] + public string Code { get; set; } + } + + public sealed class WellKnownHolder + { + [JsonProperty("any_value")] + public object? AnyValue { get; set; } + [JsonProperty("ttl")] + public string? Ttl { get; set; } + [JsonProperty("mask")] + public string? Mask { get; set; } + [JsonProperty("items")] + public List? Items { get; set; } + [JsonProperty("raw_value")] + public object? RawValue { get; set; } + } + + public sealed class EmptyMessage + { + } + + public sealed class EmptyBehaviorHolder + { + [JsonProperty("metadata_preserve")] + public EmptyMessage? MetadataPreserve { get; set; } + [JsonProperty("metadata_null")] + public EmptyMessage? MetadataNull { get; set; } + [JsonProperty("metadata_omit")] + public EmptyMessage? MetadataOmit { get; set; } + } + + public sealed class TagList : List + { + } + + public sealed class GetWidgetRequest + { + [JsonProperty("id")] + public string Id { get; set; } + } + + public sealed class SearchWidgetsRequest + { + [JsonProperty("owner_id")] + public string OwnerId { get; set; } + [JsonProperty("tag_ids")] + public List TagIds { get; set; } + } + + public sealed class Empty + { + } + + public sealed class WidgetServiceClientOptions + { + public HttpClient? HttpClient { get; set; } + public Dictionary? DefaultHeaders { get; set; } + } + + public sealed class WidgetServiceCallOptions + { + public Dictionary? Headers { get; set; } + } + + public interface IWidgetServiceClient + { + Task GetWidgetAsync(GetWidgetRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default); + Task SearchWidgetsAsync(SearchWidgetsRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default); + } + + public sealed class WidgetServiceClient : IWidgetServiceClient + { + private readonly string _baseUrl; + private readonly HttpClient _httpClient; + private readonly Dictionary _defaultHeaders; + + public WidgetServiceClient(string baseUrl, WidgetServiceClientOptions? options = null) + { + _baseUrl = baseUrl.TrimEnd('/'); + _httpClient = options?.HttpClient ?? new HttpClient(); + _defaultHeaders = options?.DefaultHeaders is null + ? new Dictionary() + : new Dictionary(options.DefaultHeaders); + } + + public async Task GetWidgetAsync(GetWidgetRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/api/v1/widgets/{id}"; + path = path.Replace("{id}", Uri.EscapeDataString(FormatPathValue(req.Id))); + var query = new List(); + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Get, requestUri, null, headers, cancellationToken); + } + + public async Task SearchWidgetsAsync(SearchWidgetsRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/api/v1/widgets"; + var query = new List(); + if (!string.IsNullOrEmpty(req.OwnerId)) + { + query.Add(Uri.EscapeDataString("owner") + "=" + Uri.EscapeDataString(FormatQueryValue(req.OwnerId))); + } + if (req.TagIds is not null) + { + foreach (var item in req.TagIds) + { + query.Add(Uri.EscapeDataString("tag_id") + "=" + Uri.EscapeDataString(FormatQueryValue(item))); + } + } + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Get, requestUri, null, headers, cancellationToken); + } + + private Dictionary BuildHeaders(WidgetServiceCallOptions? options) + { + var headers = new Dictionary(_defaultHeaders); + if (options?.Headers is not null) + { + foreach (var pair in options.Headers) + { + headers[pair.Key] = pair.Value; + } + } + return headers; + } + + private async Task SendAsync(HttpMethod method, string requestUri, object? body, Dictionary headers, CancellationToken cancellationToken) where TResponse : new() + { + using var request = new HttpRequestMessage(method, _baseUrl + requestUri); + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + if (body is not null) + { + request.Content = new StringContent(SerializeRequest(body), Encoding.UTF8, "application/json"); + } + using var response = await _httpClient.SendAsync(request, cancellationToken); + var responseBody = response.Content is null ? string.Empty : await response.Content.ReadAsStringAsync(cancellationToken); + if (!response.IsSuccessStatusCode) + { + throw new ApiException((int)response.StatusCode, responseBody); + } + if (typeof(TResponse) == typeof(Empty) && string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + if (string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + var result = DeserializeResponse(responseBody); + return result is null ? new TResponse() : result; + } + + private string SerializeRequest(object value) + { + var json = JsonConvert.SerializeObject(value); + return NormalizeSerializedJson(value, json); + } + + private static TResponse? DeserializeResponse(string json) + { + json = NormalizeResponseJson(typeof(TResponse), json); + return JsonConvert.DeserializeObject(json); + } + + private static string NormalizeSerializedJson(object value, string json) + { + var token = JToken.Parse(json); + var normalized = NormalizeSerializedToken(value.GetType(), token); + return normalized.ToString(Formatting.None); + } + + private static string NormalizeResponseJson(Type responseType, string json) + { + var token = JToken.Parse(json); + var normalized = NormalizeResponseToken(responseType, token); + return normalized.ToString(Formatting.None); + } + + private static JToken NormalizeSerializedToken(Type messageType, JToken token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeSerializedEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeSerializedNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeSerializedShapeEnvelope(token), + "Widget" => NormalizeSerializedWidget(token), + _ => token + }; + } + private static JToken NormalizeResponseToken(Type messageType, JToken token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeResponseEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeResponseNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeResponseShapeEnvelope(token), + "Widget" => NormalizeResponseWidget(token), + _ => token + }; + } + private static JToken NormalizeMapValueForSerialization(JToken token, Type messageType) + { + return messageType.Name switch + { + "Widget" => token is JObject obj && obj.TryGetValue("tags", out var value) ? value : token, + _ => NormalizeSerializedToken(messageType, token) + }; + } + private static JToken NormalizeMapValueForResponse(JToken token, Type messageType) + { + return messageType.Name switch + { + "Widget" => new JObject { ["tags"] = token }, + _ => NormalizeResponseToken(messageType, token) + }; + } + private static bool IsEmptyObject(JToken token) + { + return token is JObject obj && !obj.Properties().Any(); + } + private static bool ShouldOmitEmptyField(JToken token) + { + return token.Type == JTokenType.Null || IsEmptyObject(token); + } + + private static JToken NormalizeSerializedEmptyBehaviorHolder(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = JValue.CreateNull(); + } + } + if (obj.TryGetValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (ShouldOmitEmptyField(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JToken NormalizeResponseEmptyBehaviorHolder(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = JValue.CreateNull(); + } + } + if (obj.TryGetValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (IsEmptyObject(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JToken NormalizeSerializedNestedShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("circle", out var CircleCircleToken) && CircleCircleToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JToken NormalizeResponseNestedShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("circle", out var CircleCircleToken) && CircleCircleToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JToken NormalizeSerializedShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("radius", out var CircleRadiusToken) && CircleRadiusToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("width", out var RectangleWidthToken) && RectangleWidthToken.Type != JTokenType.Null || obj.TryGetValue("height", out var RectangleHeightToken) && RectangleHeightToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JToken NormalizeResponseShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("radius", out var CircleRadiusToken) && CircleRadiusToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("width", out var RectangleWidthToken) && RectangleWidthToken.Type != JTokenType.Null || obj.TryGetValue("height", out var RectangleHeightToken) && RectangleHeightToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JToken NormalizeSerializedWidget(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("payload", out var PayloadToken) && PayloadToken.Type == JTokenType.String) + { + obj["payload"] = ReencodeBytes(PayloadToken.Value()!, "base64", "hex"); + } + return obj; + } + + private static JToken NormalizeResponseWidget(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("payload", out var PayloadToken) && PayloadToken.Type == JTokenType.String) + { + obj["payload"] = ReencodeBytes(PayloadToken.Value()!, "hex", "base64"); + } + return obj; + } + + private static string EncodeBytes(byte[] bytes, string encoding) + { + var base64 = Convert.ToBase64String(bytes); + return encoding switch + { + "base64_raw" => base64.TrimEnd('='), + "base64url" => base64.Replace('+', '-').Replace('/', '_'), + "base64url_raw" => base64.Replace('+', '-').Replace('/', '_').TrimEnd('='), + "hex" => Convert.ToHexString(bytes).ToLowerInvariant(), + _ => base64 + }; + } + + private static string ReencodeBytes(string encoded, string fromEncoding, string toEncoding) + { + return EncodeBytes(DecodeBytes(encoded, fromEncoding), toEncoding); + } + + private static byte[] DecodeBytes(string encoded, string encoding) + { + return encoding switch + { + "hex" => Convert.FromHexString(encoded), + "base64url" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64url_raw" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64_raw" => Convert.FromBase64String(NormalizeBase64(encoded)), + _ => Convert.FromBase64String(NormalizeBase64(encoded)) + }; + } + + private static string NormalizeBase64(string value) + { + var remainder = value.Length % 4; + if (remainder == 0) + { + return value; + } + return value + new string('=', 4 - remainder); + } + + private static string FormatPathValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + private static string FormatQueryValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + } + + public sealed class AdminServiceClientOptions + { + public HttpClient? HttpClient { get; set; } + public Dictionary? DefaultHeaders { get; set; } + } + + public sealed class AdminServiceCallOptions + { + public Dictionary? Headers { get; set; } + } + + public interface IAdminServiceClient + { + Task ResetWidgetAsync(GetWidgetRequest req, AdminServiceCallOptions? options = null, CancellationToken cancellationToken = default); + } + + public sealed class AdminServiceClient : IAdminServiceClient + { + private readonly string _baseUrl; + private readonly HttpClient _httpClient; + private readonly Dictionary _defaultHeaders; + + public AdminServiceClient(string baseUrl, AdminServiceClientOptions? options = null) + { + _baseUrl = baseUrl.TrimEnd('/'); + _httpClient = options?.HttpClient ?? new HttpClient(); + _defaultHeaders = options?.DefaultHeaders is null + ? new Dictionary() + : new Dictionary(options.DefaultHeaders); + } + + public async Task ResetWidgetAsync(GetWidgetRequest req, AdminServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/api/v1/admin/widgets/{id}:reset"; + path = path.Replace("{id}", Uri.EscapeDataString(FormatPathValue(req.Id))); + var query = new List(); + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Post, requestUri, req, headers, cancellationToken); + } + + private Dictionary BuildHeaders(AdminServiceCallOptions? options) + { + var headers = new Dictionary(_defaultHeaders); + if (options?.Headers is not null) + { + foreach (var pair in options.Headers) + { + headers[pair.Key] = pair.Value; + } + } + return headers; + } + + private async Task SendAsync(HttpMethod method, string requestUri, object? body, Dictionary headers, CancellationToken cancellationToken) where TResponse : new() + { + using var request = new HttpRequestMessage(method, _baseUrl + requestUri); + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + if (body is not null) + { + request.Content = new StringContent(SerializeRequest(body), Encoding.UTF8, "application/json"); + } + using var response = await _httpClient.SendAsync(request, cancellationToken); + var responseBody = response.Content is null ? string.Empty : await response.Content.ReadAsStringAsync(cancellationToken); + if (!response.IsSuccessStatusCode) + { + throw new ApiException((int)response.StatusCode, responseBody); + } + if (typeof(TResponse) == typeof(Empty) && string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + if (string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + var result = DeserializeResponse(responseBody); + return result is null ? new TResponse() : result; + } + + private string SerializeRequest(object value) + { + var json = JsonConvert.SerializeObject(value); + return NormalizeSerializedJson(value, json); + } + + private static TResponse? DeserializeResponse(string json) + { + json = NormalizeResponseJson(typeof(TResponse), json); + return JsonConvert.DeserializeObject(json); + } + + private static string NormalizeSerializedJson(object value, string json) + { + var token = JToken.Parse(json); + var normalized = NormalizeSerializedToken(value.GetType(), token); + return normalized.ToString(Formatting.None); + } + + private static string NormalizeResponseJson(Type responseType, string json) + { + var token = JToken.Parse(json); + var normalized = NormalizeResponseToken(responseType, token); + return normalized.ToString(Formatting.None); + } + + private static JToken NormalizeSerializedToken(Type messageType, JToken token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeSerializedEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeSerializedNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeSerializedShapeEnvelope(token), + "Widget" => NormalizeSerializedWidget(token), + _ => token + }; + } + private static JToken NormalizeResponseToken(Type messageType, JToken token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeResponseEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeResponseNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeResponseShapeEnvelope(token), + "Widget" => NormalizeResponseWidget(token), + _ => token + }; + } + private static JToken NormalizeMapValueForSerialization(JToken token, Type messageType) + { + return messageType.Name switch + { + "Widget" => token is JObject obj && obj.TryGetValue("tags", out var value) ? value : token, + _ => NormalizeSerializedToken(messageType, token) + }; + } + private static JToken NormalizeMapValueForResponse(JToken token, Type messageType) + { + return messageType.Name switch + { + "Widget" => new JObject { ["tags"] = token }, + _ => NormalizeResponseToken(messageType, token) + }; + } + private static bool IsEmptyObject(JToken token) + { + return token is JObject obj && !obj.Properties().Any(); + } + private static bool ShouldOmitEmptyField(JToken token) + { + return token.Type == JTokenType.Null || IsEmptyObject(token); + } + + private static JToken NormalizeSerializedEmptyBehaviorHolder(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = JValue.CreateNull(); + } + } + if (obj.TryGetValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (ShouldOmitEmptyField(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JToken NormalizeResponseEmptyBehaviorHolder(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = JValue.CreateNull(); + } + } + if (obj.TryGetValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (IsEmptyObject(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JToken NormalizeSerializedNestedShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("circle", out var CircleCircleToken) && CircleCircleToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JToken NormalizeResponseNestedShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("circle", out var CircleCircleToken) && CircleCircleToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JToken NormalizeSerializedShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("radius", out var CircleRadiusToken) && CircleRadiusToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("width", out var RectangleWidthToken) && RectangleWidthToken.Type != JTokenType.Null || obj.TryGetValue("height", out var RectangleHeightToken) && RectangleHeightToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JToken NormalizeResponseShapeEnvelope(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (!obj.TryGetValue("kind", out var KindDiscriminator) || KindDiscriminator.Type == JTokenType.Null || string.IsNullOrEmpty(KindDiscriminator.Value())) + { + if (obj.TryGetValue("radius", out var CircleRadiusToken) && CircleRadiusToken.Type != JTokenType.Null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetValue("width", out var RectangleWidthToken) && RectangleWidthToken.Type != JTokenType.Null || obj.TryGetValue("height", out var RectangleHeightToken) && RectangleHeightToken.Type != JTokenType.Null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetValue("kind", out var KindSelected) && KindSelected.Type == JTokenType.String) + { + switch (KindSelected.Value()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JToken NormalizeSerializedWidget(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("payload", out var PayloadToken) && PayloadToken.Type == JTokenType.String) + { + obj["payload"] = ReencodeBytes(PayloadToken.Value()!, "base64", "hex"); + } + return obj; + } + + private static JToken NormalizeResponseWidget(JToken token) + { + if (token is not JObject obj) + { + return token; + } + if (obj.TryGetValue("payload", out var PayloadToken) && PayloadToken.Type == JTokenType.String) + { + obj["payload"] = ReencodeBytes(PayloadToken.Value()!, "hex", "base64"); + } + return obj; + } + + private static string EncodeBytes(byte[] bytes, string encoding) + { + var base64 = Convert.ToBase64String(bytes); + return encoding switch + { + "base64_raw" => base64.TrimEnd('='), + "base64url" => base64.Replace('+', '-').Replace('/', '_'), + "base64url_raw" => base64.Replace('+', '-').Replace('/', '_').TrimEnd('='), + "hex" => Convert.ToHexString(bytes).ToLowerInvariant(), + _ => base64 + }; + } + + private static string ReencodeBytes(string encoded, string fromEncoding, string toEncoding) + { + return EncodeBytes(DecodeBytes(encoded, fromEncoding), toEncoding); + } + + private static byte[] DecodeBytes(string encoded, string encoding) + { + return encoding switch + { + "hex" => Convert.FromHexString(encoded), + "base64url" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64url_raw" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64_raw" => Convert.FromBase64String(NormalizeBase64(encoded)), + _ => Convert.FromBase64String(NormalizeBase64(encoded)) + }; + } + + private static string NormalizeBase64(string value) + { + var remainder = value.Length % 4; + if (remainder == 0) + { + return value; + } + return value + new string('=', 4 - remainder); + } + + private static string FormatPathValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + private static string FormatQueryValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + } + + public static class ServiceContracts + { + public static class WidgetService + { + public const string Name = "WidgetService"; + public const string BasePath = "/api/v1"; + public static class GetWidget + { + public const string HttpMethod = "GET"; + public const string Path = "/api/v1/widgets/{id}"; + public const string RequestType = "GetWidgetRequest"; + public const string ResponseType = "Widget"; + } + public static class SearchWidgets + { + public const string HttpMethod = "GET"; + public const string Path = "/api/v1/widgets"; + public const string RequestType = "SearchWidgetsRequest"; + public const string ResponseType = "Widget"; + } + } + public static class AdminService + { + public const string Name = "AdminService"; + public const string BasePath = "/api/v1/admin"; + public static class ResetWidget + { + public const string HttpMethod = "POST"; + public const string Path = "/api/v1/admin/widgets/{id}:reset"; + public const string RequestType = "GetWidgetRequest"; + public const string ResponseType = "Empty"; + } + } + } +} diff --git a/internal/csharpgen/testdata/golden/Comprehensive.SystemTextJson.g.cs b/internal/csharpgen/testdata/golden/Comprehensive.SystemTextJson.g.cs new file mode 100644 index 00000000..cfe5b84b --- /dev/null +++ b/internal/csharpgen/testdata/golden/Comprehensive.SystemTextJson.g.cs @@ -0,0 +1,1095 @@ +// Code generated by protoc-gen-csharp-http. DO NOT EDIT. +#nullable enable +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Runtime.Serialization; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Nodes; + +namespace Test.Contracts +{ + public enum WidgetState + { + [EnumMember(Value = "STATE_UNSPECIFIED")] + StateUnspecified = 0, + [EnumMember(Value = "ready")] + StateReady = 1, + } + + public sealed class ApiException : Exception + { + public int StatusCode { get; } + public string ResponseBody { get; } + + public ApiException(int statusCode, string responseBody) + : base($\"Request failed with status {statusCode}: {responseBody}\") + { + StatusCode = statusCode; + ResponseBody = responseBody; + } + } + + public sealed class Widget + { + [JsonPropertyName("id")] + public string Id { get; set; } + [JsonPropertyName("display_name")] + public string? DisplayName { get; set; } + [JsonPropertyName("scores")] + public Dictionary Scores { get; set; } + [JsonPropertyName("meta")] + public Dictionary? Meta { get; set; } + [JsonPropertyName("created_at")] + public long? CreatedAt { get; set; } + [JsonPropertyName("alias")] + public string? Alias { get; set; } + [JsonPropertyName("state")] + public WidgetState State { get; set; } + [JsonPropertyName("meta_note")] + public string? MetaNote { get; set; } + [JsonPropertyName("tags")] + public List Tags { get; set; } + [JsonPropertyName("owner_id")] + public string OwnerId { get; set; } + [JsonPropertyName("payload")] + public byte[] Payload { get; set; } + [JsonPropertyName("version")] + public long Version { get; set; } + [JsonPropertyName("state_labels")] + public Dictionary StateLabels { get; set; } + [JsonPropertyName("profiles_by_id")] + public Dictionary ProfilesById { get; set; } + } + + public sealed class WidgetProfile + { + [JsonPropertyName("note")] + public string Note { get; set; } + } + + public sealed class ShapeEnvelope + { + [JsonPropertyName("kind")] + public string? Kind { get; set; } + [JsonPropertyName("radius")] + public double? Radius { get; set; } + [JsonPropertyName("width")] + public double? Width { get; set; } + [JsonPropertyName("height")] + public double? Height { get; set; } + } + + public sealed class ShapeEnvelopeCircle + { + [JsonPropertyName("radius")] + public double Radius { get; set; } + } + + public sealed class ShapeEnvelopeRectangle + { + [JsonPropertyName("width")] + public double Width { get; set; } + [JsonPropertyName("height")] + public double Height { get; set; } + } + + public sealed class NestedShapeEnvelope + { + [JsonPropertyName("kind")] + public string? Kind { get; set; } + [JsonPropertyName("circle")] + public NestedShapeEnvelopeNestedCircle? Circle { get; set; } + [JsonPropertyName("rectangle")] + public NestedShapeEnvelopeNestedRectangle? Rectangle { get; set; } + } + + public sealed class NestedShapeEnvelopeNestedCircle + { + [JsonPropertyName("radius")] + public double Radius { get; set; } + } + + public sealed class NestedShapeEnvelopeNestedRectangle + { + [JsonPropertyName("width")] + public double Width { get; set; } + [JsonPropertyName("height")] + public double Height { get; set; } + } + + public sealed class DeepNest + { + [JsonPropertyName("level1")] + public DeepNestLevel1? Level1 { get; set; } + } + + public sealed class DeepNestLevel1 + { + [JsonPropertyName("level2")] + public DeepNestLevel1Level2? Level2 { get; set; } + } + + public sealed class DeepNestLevel1Level2 + { + [JsonPropertyName("code")] + public string Code { get; set; } + } + + public sealed class WellKnownHolder + { + [JsonPropertyName("any_value")] + public object? AnyValue { get; set; } + [JsonPropertyName("ttl")] + public string? Ttl { get; set; } + [JsonPropertyName("mask")] + public string? Mask { get; set; } + [JsonPropertyName("items")] + public List? Items { get; set; } + [JsonPropertyName("raw_value")] + public object? RawValue { get; set; } + } + + public sealed class EmptyMessage + { + } + + public sealed class EmptyBehaviorHolder + { + [JsonPropertyName("metadata_preserve")] + public EmptyMessage? MetadataPreserve { get; set; } + [JsonPropertyName("metadata_null")] + public EmptyMessage? MetadataNull { get; set; } + [JsonPropertyName("metadata_omit")] + public EmptyMessage? MetadataOmit { get; set; } + } + + public sealed class TagList : List + { + } + + public sealed class GetWidgetRequest + { + [JsonPropertyName("id")] + public string Id { get; set; } + } + + public sealed class SearchWidgetsRequest + { + [JsonPropertyName("owner_id")] + public string OwnerId { get; set; } + [JsonPropertyName("tag_ids")] + public List TagIds { get; set; } + } + + public sealed class Empty + { + } + + public sealed class WidgetServiceClientOptions + { + public HttpClient? HttpClient { get; set; } + public Dictionary? DefaultHeaders { get; set; } + } + + public sealed class WidgetServiceCallOptions + { + public Dictionary? Headers { get; set; } + } + + public interface IWidgetServiceClient + { + Task GetWidgetAsync(GetWidgetRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default); + Task SearchWidgetsAsync(SearchWidgetsRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default); + } + + public sealed class WidgetServiceClient : IWidgetServiceClient + { + private readonly string _baseUrl; + private readonly HttpClient _httpClient; + private readonly Dictionary _defaultHeaders; + private static readonly JsonSerializerOptions JsonOptions = new() + { + PropertyNamingPolicy = null, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + public WidgetServiceClient(string baseUrl, WidgetServiceClientOptions? options = null) + { + _baseUrl = baseUrl.TrimEnd('/'); + _httpClient = options?.HttpClient ?? new HttpClient(); + _defaultHeaders = options?.DefaultHeaders is null + ? new Dictionary() + : new Dictionary(options.DefaultHeaders); + } + + public async Task GetWidgetAsync(GetWidgetRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/api/v1/widgets/{id}"; + path = path.Replace("{id}", Uri.EscapeDataString(FormatPathValue(req.Id))); + var query = new List(); + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Get, requestUri, null, headers, cancellationToken); + } + + public async Task SearchWidgetsAsync(SearchWidgetsRequest req, WidgetServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/api/v1/widgets"; + var query = new List(); + if (!string.IsNullOrEmpty(req.OwnerId)) + { + query.Add(Uri.EscapeDataString("owner") + "=" + Uri.EscapeDataString(FormatQueryValue(req.OwnerId))); + } + if (req.TagIds is not null) + { + foreach (var item in req.TagIds) + { + query.Add(Uri.EscapeDataString("tag_id") + "=" + Uri.EscapeDataString(FormatQueryValue(item))); + } + } + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Get, requestUri, null, headers, cancellationToken); + } + + private Dictionary BuildHeaders(WidgetServiceCallOptions? options) + { + var headers = new Dictionary(_defaultHeaders); + if (options?.Headers is not null) + { + foreach (var pair in options.Headers) + { + headers[pair.Key] = pair.Value; + } + } + return headers; + } + + private async Task SendAsync(HttpMethod method, string requestUri, object? body, Dictionary headers, CancellationToken cancellationToken) where TResponse : new() + { + using var request = new HttpRequestMessage(method, _baseUrl + requestUri); + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + if (body is not null) + { + request.Content = new StringContent(SerializeRequest(body), Encoding.UTF8, "application/json"); + } + using var response = await _httpClient.SendAsync(request, cancellationToken); + var responseBody = response.Content is null ? string.Empty : await response.Content.ReadAsStringAsync(cancellationToken); + if (!response.IsSuccessStatusCode) + { + throw new ApiException((int)response.StatusCode, responseBody); + } + if (typeof(TResponse) == typeof(Empty) && string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + if (string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + var result = DeserializeResponse(responseBody); + return result is null ? new TResponse() : result; + } + + private string SerializeRequest(object value) + { + var json = JsonSerializer.Serialize(value, JsonOptions); + return NormalizeSerializedJson(value, json); + } + + private static TResponse? DeserializeResponse(string json) + { + json = NormalizeResponseJson(typeof(TResponse), json); + return JsonSerializer.Deserialize(json, JsonOptions); + } + + private static string NormalizeSerializedJson(object value, string json) + { + var token = JsonNode.Parse(json); + if (token is null) + { + return json; + } + var normalized = NormalizeSerializedNode(value.GetType(), token); + return normalized.ToJsonString(); + } + + private static string NormalizeResponseJson(Type responseType, string json) + { + var token = JsonNode.Parse(json); + if (token is null) + { + return json; + } + var normalized = NormalizeResponseNode(responseType, token); + return normalized.ToJsonString(); + } + + private static JsonNode NormalizeSerializedNode(Type messageType, JsonNode token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeSerializedEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeSerializedNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeSerializedShapeEnvelope(token), + "Widget" => NormalizeSerializedWidget(token), + _ => token + }; + } + private static JsonNode NormalizeResponseNode(Type messageType, JsonNode token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeResponseEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeResponseNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeResponseShapeEnvelope(token), + "Widget" => NormalizeResponseWidget(token), + _ => token + }; + } + private static JsonNode NormalizeMapValueForSerialization(JsonNode token, Type messageType) + { + return messageType.Name switch + { + "Widget" => token is JsonObject obj && obj["tags"] is JsonNode value ? value : token, + _ => NormalizeSerializedNode(messageType, token) + }; + } + private static JsonNode NormalizeMapValueForResponse(JsonNode token, Type messageType) + { + return messageType.Name switch + { + "Widget" => new JsonObject { ["tags"] = token.DeepClone() }, + _ => NormalizeResponseNode(messageType, token) + }; + } + private static bool IsEmptyObject(JsonNode? token) + { + return token is JsonObject obj && obj.Count == 0; + } + private static bool ShouldOmitEmptyField(JsonNode? token) + { + return token is null || IsEmptyObject(token); + } + + private static JsonNode NormalizeSerializedEmptyBehaviorHolder(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj.TryGetPropertyValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = null; + } + } + if (obj.TryGetPropertyValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (ShouldOmitEmptyField(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JsonNode NormalizeResponseEmptyBehaviorHolder(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj.TryGetPropertyValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = null; + } + } + if (obj.TryGetPropertyValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (IsEmptyObject(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JsonNode NormalizeSerializedNestedShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("circle", out var CircleCircleToken) && CircleCircleToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeResponseNestedShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("circle", out var CircleCircleToken) && CircleCircleToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeSerializedShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("radius", out var CircleRadiusToken) && CircleRadiusToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("width", out var RectangleWidthToken) && RectangleWidthToken is not null || obj.TryGetPropertyValue("height", out var RectangleHeightToken) && RectangleHeightToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeResponseShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("radius", out var CircleRadiusToken) && CircleRadiusToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("width", out var RectangleWidthToken) && RectangleWidthToken is not null || obj.TryGetPropertyValue("height", out var RectangleHeightToken) && RectangleHeightToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeSerializedWidget(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj["payload"] is JsonValue PayloadToken && PayloadToken.TryGetValue(out var PayloadValue)) + { + obj["payload"] = ReencodeBytes(PayloadValue, "base64", "hex"); + } + return obj; + } + + private static JsonNode NormalizeResponseWidget(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj["payload"] is JsonValue PayloadToken && PayloadToken.TryGetValue(out var PayloadValue)) + { + obj["payload"] = ReencodeBytes(PayloadValue, "hex", "base64"); + } + return obj; + } + + private static string EncodeBytes(byte[] bytes, string encoding) + { + var base64 = Convert.ToBase64String(bytes); + return encoding switch + { + "base64_raw" => base64.TrimEnd('='), + "base64url" => base64.Replace('+', '-').Replace('/', '_'), + "base64url_raw" => base64.Replace('+', '-').Replace('/', '_').TrimEnd('='), + "hex" => Convert.ToHexString(bytes).ToLowerInvariant(), + _ => base64 + }; + } + + private static string ReencodeBytes(string encoded, string fromEncoding, string toEncoding) + { + return EncodeBytes(DecodeBytes(encoded, fromEncoding), toEncoding); + } + + private static byte[] DecodeBytes(string encoded, string encoding) + { + return encoding switch + { + "hex" => Convert.FromHexString(encoded), + "base64url" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64url_raw" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64_raw" => Convert.FromBase64String(NormalizeBase64(encoded)), + _ => Convert.FromBase64String(NormalizeBase64(encoded)) + }; + } + + private static string NormalizeBase64(string value) + { + var remainder = value.Length % 4; + if (remainder == 0) + { + return value; + } + return value + new string('=', 4 - remainder); + } + + private static string FormatPathValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + private static string FormatQueryValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + } + + public sealed class AdminServiceClientOptions + { + public HttpClient? HttpClient { get; set; } + public Dictionary? DefaultHeaders { get; set; } + } + + public sealed class AdminServiceCallOptions + { + public Dictionary? Headers { get; set; } + } + + public interface IAdminServiceClient + { + Task ResetWidgetAsync(GetWidgetRequest req, AdminServiceCallOptions? options = null, CancellationToken cancellationToken = default); + } + + public sealed class AdminServiceClient : IAdminServiceClient + { + private readonly string _baseUrl; + private readonly HttpClient _httpClient; + private readonly Dictionary _defaultHeaders; + private static readonly JsonSerializerOptions JsonOptions = new() + { + PropertyNamingPolicy = null, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + public AdminServiceClient(string baseUrl, AdminServiceClientOptions? options = null) + { + _baseUrl = baseUrl.TrimEnd('/'); + _httpClient = options?.HttpClient ?? new HttpClient(); + _defaultHeaders = options?.DefaultHeaders is null + ? new Dictionary() + : new Dictionary(options.DefaultHeaders); + } + + public async Task ResetWidgetAsync(GetWidgetRequest req, AdminServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/api/v1/admin/widgets/{id}:reset"; + path = path.Replace("{id}", Uri.EscapeDataString(FormatPathValue(req.Id))); + var query = new List(); + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Post, requestUri, req, headers, cancellationToken); + } + + private Dictionary BuildHeaders(AdminServiceCallOptions? options) + { + var headers = new Dictionary(_defaultHeaders); + if (options?.Headers is not null) + { + foreach (var pair in options.Headers) + { + headers[pair.Key] = pair.Value; + } + } + return headers; + } + + private async Task SendAsync(HttpMethod method, string requestUri, object? body, Dictionary headers, CancellationToken cancellationToken) where TResponse : new() + { + using var request = new HttpRequestMessage(method, _baseUrl + requestUri); + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + if (body is not null) + { + request.Content = new StringContent(SerializeRequest(body), Encoding.UTF8, "application/json"); + } + using var response = await _httpClient.SendAsync(request, cancellationToken); + var responseBody = response.Content is null ? string.Empty : await response.Content.ReadAsStringAsync(cancellationToken); + if (!response.IsSuccessStatusCode) + { + throw new ApiException((int)response.StatusCode, responseBody); + } + if (typeof(TResponse) == typeof(Empty) && string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + if (string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + var result = DeserializeResponse(responseBody); + return result is null ? new TResponse() : result; + } + + private string SerializeRequest(object value) + { + var json = JsonSerializer.Serialize(value, JsonOptions); + return NormalizeSerializedJson(value, json); + } + + private static TResponse? DeserializeResponse(string json) + { + json = NormalizeResponseJson(typeof(TResponse), json); + return JsonSerializer.Deserialize(json, JsonOptions); + } + + private static string NormalizeSerializedJson(object value, string json) + { + var token = JsonNode.Parse(json); + if (token is null) + { + return json; + } + var normalized = NormalizeSerializedNode(value.GetType(), token); + return normalized.ToJsonString(); + } + + private static string NormalizeResponseJson(Type responseType, string json) + { + var token = JsonNode.Parse(json); + if (token is null) + { + return json; + } + var normalized = NormalizeResponseNode(responseType, token); + return normalized.ToJsonString(); + } + + private static JsonNode NormalizeSerializedNode(Type messageType, JsonNode token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeSerializedEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeSerializedNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeSerializedShapeEnvelope(token), + "Widget" => NormalizeSerializedWidget(token), + _ => token + }; + } + private static JsonNode NormalizeResponseNode(Type messageType, JsonNode token) + { + return messageType.Name switch + { + "EmptyBehaviorHolder" => NormalizeResponseEmptyBehaviorHolder(token), + "NestedShapeEnvelope" => NormalizeResponseNestedShapeEnvelope(token), + "ShapeEnvelope" => NormalizeResponseShapeEnvelope(token), + "Widget" => NormalizeResponseWidget(token), + _ => token + }; + } + private static JsonNode NormalizeMapValueForSerialization(JsonNode token, Type messageType) + { + return messageType.Name switch + { + "Widget" => token is JsonObject obj && obj["tags"] is JsonNode value ? value : token, + _ => NormalizeSerializedNode(messageType, token) + }; + } + private static JsonNode NormalizeMapValueForResponse(JsonNode token, Type messageType) + { + return messageType.Name switch + { + "Widget" => new JsonObject { ["tags"] = token.DeepClone() }, + _ => NormalizeResponseNode(messageType, token) + }; + } + private static bool IsEmptyObject(JsonNode? token) + { + return token is JsonObject obj && obj.Count == 0; + } + private static bool ShouldOmitEmptyField(JsonNode? token) + { + return token is null || IsEmptyObject(token); + } + + private static JsonNode NormalizeSerializedEmptyBehaviorHolder(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj.TryGetPropertyValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = null; + } + } + if (obj.TryGetPropertyValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (ShouldOmitEmptyField(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JsonNode NormalizeResponseEmptyBehaviorHolder(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj.TryGetPropertyValue("metadataNull", out var MetadatanullEmptyBehavior)) + { + if (IsEmptyObject(MetadatanullEmptyBehavior)) + { + obj["metadataNull"] = null; + } + } + if (obj.TryGetPropertyValue("metadataOmit", out var MetadataomitEmptyBehavior)) + { + if (IsEmptyObject(MetadataomitEmptyBehavior)) + { + obj.Remove("metadataOmit"); + } + } + return obj; + } + + private static JsonNode NormalizeSerializedNestedShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("circle", out var CircleCircleToken) && CircleCircleToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeResponseNestedShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("circle", out var CircleCircleToken) && CircleCircleToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("rectangle", out var RectangleRectangleToken) && RectangleRectangleToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("rectangle"); + break; + case "rectangle": + obj.Remove("circle"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeSerializedShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("radius", out var CircleRadiusToken) && CircleRadiusToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("width", out var RectangleWidthToken) && RectangleWidthToken is not null || obj.TryGetPropertyValue("height", out var RectangleHeightToken) && RectangleHeightToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeResponseShapeEnvelope(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (!obj.TryGetPropertyValue("kind", out var KindDiscriminator) || KindDiscriminator is null || string.IsNullOrEmpty(KindDiscriminator.GetValue())) + { + if (obj.TryGetPropertyValue("radius", out var CircleRadiusToken) && CircleRadiusToken is not null) + { + obj["kind"] = "circle_shape"; + } + else if (obj.TryGetPropertyValue("width", out var RectangleWidthToken) && RectangleWidthToken is not null || obj.TryGetPropertyValue("height", out var RectangleHeightToken) && RectangleHeightToken is not null) + { + obj["kind"] = "rectangle"; + } + } + if (obj.TryGetPropertyValue("kind", out var KindSelected) && KindSelected is JsonValue) + { + switch (KindSelected!.GetValue()) + { + case "circle_shape": + obj.Remove("width"); + obj.Remove("height"); + break; + case "rectangle": + obj.Remove("radius"); + break; + } + } + return obj; + } + + private static JsonNode NormalizeSerializedWidget(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj["payload"] is JsonValue PayloadToken && PayloadToken.TryGetValue(out var PayloadValue)) + { + obj["payload"] = ReencodeBytes(PayloadValue, "base64", "hex"); + } + return obj; + } + + private static JsonNode NormalizeResponseWidget(JsonNode token) + { + if (token is not JsonObject obj) + { + return token; + } + if (obj["payload"] is JsonValue PayloadToken && PayloadToken.TryGetValue(out var PayloadValue)) + { + obj["payload"] = ReencodeBytes(PayloadValue, "hex", "base64"); + } + return obj; + } + + private static string EncodeBytes(byte[] bytes, string encoding) + { + var base64 = Convert.ToBase64String(bytes); + return encoding switch + { + "base64_raw" => base64.TrimEnd('='), + "base64url" => base64.Replace('+', '-').Replace('/', '_'), + "base64url_raw" => base64.Replace('+', '-').Replace('/', '_').TrimEnd('='), + "hex" => Convert.ToHexString(bytes).ToLowerInvariant(), + _ => base64 + }; + } + + private static string ReencodeBytes(string encoded, string fromEncoding, string toEncoding) + { + return EncodeBytes(DecodeBytes(encoded, fromEncoding), toEncoding); + } + + private static byte[] DecodeBytes(string encoded, string encoding) + { + return encoding switch + { + "hex" => Convert.FromHexString(encoded), + "base64url" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64url_raw" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64_raw" => Convert.FromBase64String(NormalizeBase64(encoded)), + _ => Convert.FromBase64String(NormalizeBase64(encoded)) + }; + } + + private static string NormalizeBase64(string value) + { + var remainder = value.Length % 4; + if (remainder == 0) + { + return value; + } + return value + new string('=', 4 - remainder); + } + + private static string FormatPathValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + private static string FormatQueryValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + } + + public static class ServiceContracts + { + public static class WidgetService + { + public const string Name = "WidgetService"; + public const string BasePath = "/api/v1"; + public static class GetWidget + { + public const string HttpMethod = "GET"; + public const string Path = "/api/v1/widgets/{id}"; + public const string RequestType = "GetWidgetRequest"; + public const string ResponseType = "Widget"; + } + public static class SearchWidgets + { + public const string HttpMethod = "GET"; + public const string Path = "/api/v1/widgets"; + public const string RequestType = "SearchWidgetsRequest"; + public const string ResponseType = "Widget"; + } + } + public static class AdminService + { + public const string Name = "AdminService"; + public const string BasePath = "/api/v1/admin"; + public static class ResetWidget + { + public const string HttpMethod = "POST"; + public const string Path = "/api/v1/admin/widgets/{id}:reset"; + public const string RequestType = "GetWidgetRequest"; + public const string ResponseType = "Empty"; + } + } + } +} diff --git a/internal/csharpgen/testdata/golden/Contracts.g.cs b/internal/csharpgen/testdata/golden/Contracts.g.cs new file mode 100644 index 00000000..ad156424 --- /dev/null +++ b/internal/csharpgen/testdata/golden/Contracts.g.cs @@ -0,0 +1,238 @@ +// Code generated by protoc-gen-csharp-http. DO NOT EDIT. +#nullable enable +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Runtime.Serialization; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; + +namespace Test.Contracts +{ + public enum ItemState + { + [EnumMember(Value = "STATE_UNSPECIFIED")] + StateUnspecified = 0, + [EnumMember(Value = "STATE_READY")] + StateReady = 1, + } + + public sealed class ApiException : Exception + { + public int StatusCode { get; } + public string ResponseBody { get; } + + public ApiException(int statusCode, string responseBody) + : base($\"Request failed with status {statusCode}: {responseBody}\") + { + StatusCode = statusCode; + ResponseBody = responseBody; + } + } + + public sealed class Item + { + [JsonProperty("id")] + public string Id { get; set; } + [JsonProperty("tags")] + public List Tags { get; set; } + [JsonProperty("meta")] + public Dictionary? Meta { get; set; } + [JsonProperty("state")] + [JsonConverter(typeof(StringEnumConverter))] + public ItemState State { get; set; } + [JsonProperty("details")] + public ItemDetails? Details { get; set; } + } + + public sealed class ItemDetails + { + [JsonProperty("note")] + public string Note { get; set; } + [JsonProperty("scores")] + public List Scores { get; set; } + } + + public sealed class FetchItemRequest + { + [JsonProperty("id")] + public string Id { get; set; } + } + + public sealed class ContractServiceClientOptions + { + public HttpClient? HttpClient { get; set; } + public Dictionary? DefaultHeaders { get; set; } + } + + public sealed class ContractServiceCallOptions + { + public Dictionary? Headers { get; set; } + } + + public interface IContractServiceClient + { + Task FetchItemAsync(FetchItemRequest req, ContractServiceCallOptions? options = null, CancellationToken cancellationToken = default); + } + + public sealed class ContractServiceClient : IContractServiceClient + { + private readonly string _baseUrl; + private readonly HttpClient _httpClient; + private readonly Dictionary _defaultHeaders; + + public ContractServiceClient(string baseUrl, ContractServiceClientOptions? options = null) + { + _baseUrl = baseUrl.TrimEnd('/'); + _httpClient = options?.HttpClient ?? new HttpClient(); + _defaultHeaders = options?.DefaultHeaders is null + ? new Dictionary() + : new Dictionary(options.DefaultHeaders); + } + + public async Task FetchItemAsync(FetchItemRequest req, ContractServiceCallOptions? options = null, CancellationToken cancellationToken = default) + { + var path = "/"; + var query = new List(); + var requestUri = query.Count == 0 ? path : path + "?" + string.Join("&", query); + var headers = BuildHeaders(options); + return await SendAsync(HttpMethod.Post, requestUri, req, headers, cancellationToken); + } + + private Dictionary BuildHeaders(ContractServiceCallOptions? options) + { + var headers = new Dictionary(_defaultHeaders); + if (options?.Headers is not null) + { + foreach (var pair in options.Headers) + { + headers[pair.Key] = pair.Value; + } + } + return headers; + } + + private async Task SendAsync(HttpMethod method, string requestUri, object? body, Dictionary headers, CancellationToken cancellationToken) where TResponse : new() + { + using var request = new HttpRequestMessage(method, _baseUrl + requestUri); + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + if (body is not null) + { + request.Content = new StringContent(SerializeRequest(body), Encoding.UTF8, "application/json"); + } + using var response = await _httpClient.SendAsync(request, cancellationToken); + var responseBody = response.Content is null ? string.Empty : await response.Content.ReadAsStringAsync(cancellationToken); + if (!response.IsSuccessStatusCode) + { + throw new ApiException((int)response.StatusCode, responseBody); + } + if (typeof(TResponse) == typeof(Empty) && string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + if (string.IsNullOrWhiteSpace(responseBody)) + { + return new TResponse(); + } + var result = DeserializeResponse(responseBody); + return result is null ? new TResponse() : result; + } + + private string SerializeRequest(object value) + { + var json = JsonConvert.SerializeObject(value); + return NormalizeSerializedJson(value, json); + } + + private static TResponse? DeserializeResponse(string json) + { + json = NormalizeResponseJson(typeof(TResponse), json); + return JsonConvert.DeserializeObject(json); + } + + private static string NormalizeSerializedJson(object value, string json) + { + return json; + } + + private static string NormalizeResponseJson(Type responseType, string json) + { + return json; + } + + private static string EncodeBytes(byte[] bytes, string encoding) + { + var base64 = Convert.ToBase64String(bytes); + return encoding switch + { + "base64_raw" => base64.TrimEnd('='), + "base64url" => base64.Replace('+', '-').Replace('/', '_'), + "base64url_raw" => base64.Replace('+', '-').Replace('/', '_').TrimEnd('='), + "hex" => Convert.ToHexString(bytes).ToLowerInvariant(), + _ => base64 + }; + } + + private static string ReencodeBytes(string encoded, string fromEncoding, string toEncoding) + { + return EncodeBytes(DecodeBytes(encoded, fromEncoding), toEncoding); + } + + private static byte[] DecodeBytes(string encoded, string encoding) + { + return encoding switch + { + "hex" => Convert.FromHexString(encoded), + "base64url" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64url_raw" => Convert.FromBase64String(NormalizeBase64(encoded.Replace('-', '+').Replace('_', '/'))), + "base64_raw" => Convert.FromBase64String(NormalizeBase64(encoded)), + _ => Convert.FromBase64String(NormalizeBase64(encoded)) + }; + } + + private static string NormalizeBase64(string value) + { + var remainder = value.Length % 4; + if (remainder == 0) + { + return value; + } + return value + new string('=', 4 - remainder); + } + + private static string FormatPathValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + private static string FormatQueryValue(object? value) + { + return value?.ToString() ?? string.Empty; + } + + } + + public static class ServiceContracts + { + public static class ContractService + { + public const string Name = "ContractService"; + public const string BasePath = ""; + public static class FetchItem + { + public const string HttpMethod = "POST"; + public const string Path = "/"; + public const string RequestType = "FetchItemRequest"; + public const string ResponseType = "Item"; + } + } + } +} diff --git a/internal/csharpgen/testdata/proto/comprehensive_models.proto b/internal/csharpgen/testdata/proto/comprehensive_models.proto new file mode 100644 index 00000000..1b3fcbbd --- /dev/null +++ b/internal/csharpgen/testdata/proto/comprehensive_models.proto @@ -0,0 +1,113 @@ +syntax = "proto3"; + +package test.contracts.v1; + +option go_package = "github.com/SebastienMelki/sebuf/internal/testcontracts;testcontracts"; + +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/field_mask.proto"; +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; +import "sebuf/http/annotations.proto"; + +message Widget { + string id = 1; + optional string display_name = 2 [(sebuf.http.nullable) = true]; + map scores = 3; + google.protobuf.Struct meta = 4; + google.protobuf.Timestamp created_at = 5 [(sebuf.http.timestamp_format) = TIMESTAMP_FORMAT_UNIX_MILLIS]; + google.protobuf.StringValue alias = 6; + State state = 7 [(sebuf.http.enum_encoding) = ENUM_ENCODING_NUMBER]; + Profile profile = 8 [(sebuf.http.flatten) = true, (sebuf.http.flatten_prefix) = "meta_"]; + repeated string tags = 9 [(sebuf.http.unwrap) = true]; + string owner_id = 10 [(sebuf.http.query) = { name: "owner", required: true }]; + bytes payload = 11 [(sebuf.http.bytes_encoding) = BYTES_ENCODING_HEX]; + int64 version = 12 [(sebuf.http.int64_encoding) = INT64_ENCODING_NUMBER]; + map state_labels = 13; + map profiles_by_id = 14; + + enum State { + STATE_UNSPECIFIED = 0; + STATE_READY = 1 [(sebuf.http.enum_value) = "ready"]; + } + + message Profile { + string note = 1; + } +} + +message ShapeEnvelope { + oneof shape { + option (sebuf.http.oneof_config) = { + discriminator: "kind" + flatten: true + }; + + Circle circle = 1 [(sebuf.http.oneof_value) = "circle_shape"]; + Rectangle rectangle = 2; + } + + message Circle { + double radius = 1; + } + + message Rectangle { + double width = 1; + double height = 2; + } +} + +message NestedShapeEnvelope { + oneof shape { + option (sebuf.http.oneof_config) = { + discriminator: "kind" + flatten: false + }; + + NestedCircle circle = 1 [(sebuf.http.oneof_value) = "circle_shape"]; + NestedRectangle rectangle = 2; + } + + message NestedCircle { + double radius = 1; + } + + message NestedRectangle { + double width = 1; + double height = 2; + } +} + +message DeepNest { + Level1 level1 = 1; + + message Level1 { + Level2 level2 = 1; + + message Level2 { + string code = 1; + } + } +} + +message WellKnownHolder { + google.protobuf.Any any_value = 1; + google.protobuf.Duration ttl = 2; + google.protobuf.FieldMask mask = 3; + google.protobuf.ListValue items = 4; + google.protobuf.Value raw_value = 5; +} + +message EmptyMessage {} + +message EmptyBehaviorHolder { + EmptyMessage metadata_preserve = 1; + EmptyMessage metadata_null = 2 [(sebuf.http.empty_behavior) = EMPTY_BEHAVIOR_NULL]; + EmptyMessage metadata_omit = 3 [(sebuf.http.empty_behavior) = EMPTY_BEHAVIOR_OMIT]; +} + +message TagList { + repeated string values = 1 [(sebuf.http.unwrap) = true]; +} diff --git a/internal/csharpgen/testdata/proto/comprehensive_services.proto b/internal/csharpgen/testdata/proto/comprehensive_services.proto new file mode 100644 index 00000000..cdf52c76 --- /dev/null +++ b/internal/csharpgen/testdata/proto/comprehensive_services.proto @@ -0,0 +1,51 @@ +syntax = "proto3"; + +package test.contracts.v1; + +option go_package = "github.com/SebastienMelki/sebuf/internal/testcontracts;testcontracts"; + +import "google/protobuf/empty.proto"; +import "sebuf/http/annotations.proto"; +import "comprehensive_models.proto"; + +message GetWidgetRequest { + string id = 1; +} + +message SearchWidgetsRequest { + string owner_id = 1 [(sebuf.http.query) = { name: "owner" }]; + repeated string tag_ids = 2 [(sebuf.http.query) = { name: "tag_id" }]; +} + +service WidgetService { + option (sebuf.http.service_config) = { + base_path: "/api/v1" + }; + + rpc GetWidget(GetWidgetRequest) returns (Widget) { + option (sebuf.http.config) = { + method: HTTP_METHOD_GET + path: "/widgets/{id}" + }; + } + + rpc SearchWidgets(SearchWidgetsRequest) returns (Widget) { + option (sebuf.http.config) = { + method: HTTP_METHOD_GET + path: "/widgets" + }; + } +} + +service AdminService { + option (sebuf.http.service_config) = { + base_path: "/api/v1/admin" + }; + + rpc ResetWidget(GetWidgetRequest) returns (google.protobuf.Empty) { + option (sebuf.http.config) = { + method: HTTP_METHOD_POST + path: "/widgets/{id}:reset" + }; + } +} diff --git a/internal/csharpgen/testdata/proto/contracts.proto b/internal/csharpgen/testdata/proto/contracts.proto new file mode 100644 index 00000000..bc78d73d --- /dev/null +++ b/internal/csharpgen/testdata/proto/contracts.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package test.contracts.v1; + +option go_package = "github.com/SebastienMelki/sebuf/internal/testcontracts;testcontracts"; + +import "google/protobuf/struct.proto"; + +message Item { + string id = 1; + repeated string tags = 2; + google.protobuf.Struct meta = 3; + State state = 4; + Details details = 5; + + enum State { + STATE_UNSPECIFIED = 0; + STATE_READY = 1; + } + + message Details { + string note = 1; + repeated int32 scores = 2; + } +} + +message FetchItemRequest { + string id = 1; +} + +service ContractService { + rpc FetchItem(FetchItemRequest) returns (Item); +} diff --git a/internal/httpgen/generator.go b/internal/httpgen/generator.go index 0ba1611d..28e22fc3 100644 --- a/internal/httpgen/generator.go +++ b/internal/httpgen/generator.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" "strings" + "unicode" "google.golang.org/protobuf/compiler/protogen" @@ -887,18 +888,21 @@ func (g *Generator) getPathParams(method *protogen.Method) []string { } func camelToSnake(s string) string { - var result []byte + const snakeGrowthFactor = 2 + + var result strings.Builder + result.Grow(len(s) * snakeGrowthFactor) for i, r := range s { - if r >= 'A' && r <= 'Z' { + if unicode.IsUpper(r) { if i > 0 { - result = append(result, '_') + result.WriteByte('_') } - result = append(result, byte(r+'a'-'A')) + result.WriteRune(unicode.ToLower(r)) } else { - result = append(result, byte(r)) + result.WriteRune(r) } } - return string(result) + return result.String() } // generateErrorResponseFunctions generates error response helper functions. diff --git a/internal/httpgen/golden_test.go b/internal/httpgen/golden_test.go index a3794224..55233ec3 100644 --- a/internal/httpgen/golden_test.go +++ b/internal/httpgen/golden_test.go @@ -6,10 +6,10 @@ import ( "os" "os/exec" "path/filepath" - "strings" "testing" "github.com/SebastienMelki/sebuf/internal/annotations" + "github.com/SebastienMelki/sebuf/internal/testutil" ) // TestHTTPGenGoldenFiles tests HTTP handler generation against golden files. @@ -260,55 +260,10 @@ func compareGoldenFile(t *testing.T, expectedFile, goldenPath string, generatedC "Run with UPDATE_GOLDEN=1 to update golden files after reviewing changes.\n"+ "Diff:\n%s", expectedFile, - diffStrings(string(goldenContent), string(generatedContent))) + testutil.DiffStrings(string(goldenContent), string(generatedContent))) } } -// diffStrings returns a simple diff between two strings. -func diffStrings(expected, actual string) string { - expectedLines := strings.Split(expected, "\n") - actualLines := strings.Split(actual, "\n") - - var diff strings.Builder - maxLines := len(expectedLines) - if len(actualLines) > maxLines { - maxLines = len(actualLines) - } - - diffCount := 0 - const maxDiffs = 20 - - for i := 0; i < maxLines && diffCount < maxDiffs; i++ { - var expLine, actLine string - if i < len(expectedLines) { - expLine = expectedLines[i] - } - if i < len(actualLines) { - actLine = actualLines[i] - } - - if expLine != actLine { - diff.WriteString("Line ") - diff.WriteRune(rune('0' + i/100)) - diff.WriteRune(rune('0' + (i/10)%10)) - diff.WriteRune(rune('0' + i%10)) - diff.WriteString(":\n") - diff.WriteString(" expected: ") - diff.WriteString(expLine) - diff.WriteString("\n actual: ") - diff.WriteString(actLine) - diff.WriteString("\n") - diffCount++ - } - } - - if diffCount >= maxDiffs { - diff.WriteString("... (more differences truncated)\n") - } - - return diff.String() -} - // TestHTTPGenValidation tests that invalid configurations produce expected errors. func TestHTTPGenValidation(t *testing.T) { // These tests verify validation error messages are clear and actionable diff --git a/internal/testutil/diff.go b/internal/testutil/diff.go new file mode 100644 index 00000000..bbc518fd --- /dev/null +++ b/internal/testutil/diff.go @@ -0,0 +1,49 @@ +package testutil + +import ( + "strconv" + "strings" +) + +// DiffStrings returns a simple line-by-line diff capped at a small number of entries. +func DiffStrings(expected, actual string) string { + expectedLines := strings.Split(expected, "\n") + actualLines := strings.Split(actual, "\n") + + var diff strings.Builder + maxLines := len(expectedLines) + if len(actualLines) > maxLines { + maxLines = len(actualLines) + } + + diffCount := 0 + const maxDiffs = 20 + + for i := 0; i < maxLines && diffCount < maxDiffs; i++ { + var expLine, actLine string + if i < len(expectedLines) { + expLine = expectedLines[i] + } + if i < len(actualLines) { + actLine = actualLines[i] + } + + if expLine != actLine { + diff.WriteString("Line ") + diff.WriteString(strconv.Itoa(i + 1)) + diff.WriteString(":\n") + diff.WriteString(" expected: ") + diff.WriteString(expLine) + diff.WriteString("\n actual: ") + diff.WriteString(actLine) + diff.WriteString("\n") + diffCount++ + } + } + + if diffCount >= maxDiffs { + diff.WriteString("... (more differences truncated)\n") + } + + return diff.String() +} diff --git a/internal/tsclientgen/golden_test.go b/internal/tsclientgen/golden_test.go index f5a6249b..c664b461 100644 --- a/internal/tsclientgen/golden_test.go +++ b/internal/tsclientgen/golden_test.go @@ -5,8 +5,9 @@ import ( "os" "os/exec" "path/filepath" - "strings" "testing" + + "github.com/SebastienMelki/sebuf/internal/testutil" ) // TestTSClientGenGoldenFiles tests TypeScript client generation against golden files. @@ -218,50 +219,6 @@ func compareGoldenFile(t *testing.T, expectedFile, goldenPath string, generatedC "Run with UPDATE_GOLDEN=1 to update golden files after reviewing changes.\n"+ "Diff:\n%s", expectedFile, - diffStrings(string(goldenContent), string(generatedContent))) - } -} - -func diffStrings(expected, actual string) string { - expectedLines := strings.Split(expected, "\n") - actualLines := strings.Split(actual, "\n") - - var diff strings.Builder - maxLines := len(expectedLines) - if len(actualLines) > maxLines { - maxLines = len(actualLines) + testutil.DiffStrings(string(goldenContent), string(generatedContent))) } - - diffCount := 0 - const maxDiffs = 20 - - for i := 0; i < maxLines && diffCount < maxDiffs; i++ { - var expLine, actLine string - if i < len(expectedLines) { - expLine = expectedLines[i] - } - if i < len(actualLines) { - actLine = actualLines[i] - } - - if expLine != actLine { - diff.WriteString("Line ") - diff.WriteRune(rune('0' + i/100)) - diff.WriteRune(rune('0' + (i/10)%10)) - diff.WriteRune(rune('0' + i%10)) - diff.WriteString(":\n") - diff.WriteString(" expected: ") - diff.WriteString(expLine) - diff.WriteString("\n actual: ") - diff.WriteString(actLine) - diff.WriteString("\n") - diffCount++ - } - } - - if diffCount >= maxDiffs { - diff.WriteString("... (more differences truncated)\n") - } - - return diff.String() } diff --git a/internal/tsservergen/golden_test.go b/internal/tsservergen/golden_test.go index 58f05697..5d8240b7 100644 --- a/internal/tsservergen/golden_test.go +++ b/internal/tsservergen/golden_test.go @@ -7,6 +7,8 @@ import ( "path/filepath" "strings" "testing" + + "github.com/SebastienMelki/sebuf/internal/testutil" ) // TestTSServerGenGoldenFiles tests TypeScript server generation against golden files. @@ -218,7 +220,7 @@ func compareGoldenFile(t *testing.T, expectedFile, goldenPath string, generatedC "Run with UPDATE_GOLDEN=1 to update golden files after reviewing changes.\n"+ "Diff:\n%s", expectedFile, - diffStrings(string(goldenContent), string(generatedContent))) + testutil.DiffStrings(string(goldenContent), string(generatedContent))) } } @@ -293,47 +295,3 @@ func TestTSServerGenValidationErrors(t *testing.T) { }) } } - -func diffStrings(expected, actual string) string { - expectedLines := strings.Split(expected, "\n") - actualLines := strings.Split(actual, "\n") - - var diff strings.Builder - maxLines := len(expectedLines) - if len(actualLines) > maxLines { - maxLines = len(actualLines) - } - - diffCount := 0 - const maxDiffs = 20 - - for i := 0; i < maxLines && diffCount < maxDiffs; i++ { - var expLine, actLine string - if i < len(expectedLines) { - expLine = expectedLines[i] - } - if i < len(actualLines) { - actLine = actualLines[i] - } - - if expLine != actLine { - diff.WriteString("Line ") - diff.WriteRune(rune('0' + i/100)) - diff.WriteRune(rune('0' + (i/10)%10)) - diff.WriteRune(rune('0' + i%10)) - diff.WriteString(":\n") - diff.WriteString(" expected: ") - diff.WriteString(expLine) - diff.WriteString("\n actual: ") - diff.WriteString(actLine) - diff.WriteString("\n") - diffCount++ - } - } - - if diffCount >= maxDiffs { - diff.WriteString("... (more differences truncated)\n") - } - - return diff.String() -}