diff --git a/README.md b/README.md index e7c24778..b32d63ec 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ on the real value-add for your organization. ### Configuration & Filtering - **YAML-based configuration** with JSON schema validation -- **Flexible filtering** - Include/exclude by paths, tags, operation IDs, schema properties, or extensions +- **Flexible filtering** - Include/exclude by paths, tags, operation IDs, webhooks, schema properties, or extensions - **Transitive pruning** - Automatically remove schemas that are only referenced by filtered-out properties - **[OpenAPI Overlays](https://doordash-oss.github.io/oapi-codegen-dd/overlays/)** - Modify specs without editing originals (add extensions, remove paths) - **External file `$ref` resolution** - Split specs across multiple files with relative `$ref` references (auto-detected from spec path, or set `base-path` in config) @@ -85,7 +85,8 @@ Tested against [2,137 real-world OpenAPI 3.0 specs](https://github.com/mockzilla | **Name Conflicts** | Manual fix required | Automatic resolution | | **Validation** | None | `Validate()` methods | | **Server Scaffold** | Interface only, manual boilerplate | Full typed solution (service, middleware, main.go) | -| **Filtering** | Tags, operation IDs | + Paths, extensions, schema properties | +| **Webhooks** | Not supported | Full type generation with filtering | +| **Filtering** | Tags, operation IDs | + Paths, webhooks, extensions, schema properties | | **Overlays** | Single | Multiple, applied in order | | **Output** | Single file | Single or multiple files | | **Templates** | Monolithic | Composable with `{{define}}` blocks | diff --git a/configuration-schema.json b/configuration-schema.json index 3bf544b3..cabc1fa5 100644 --- a/configuration-schema.json +++ b/configuration-schema.json @@ -139,11 +139,11 @@ "properties": { "include": { "$ref": "#/definitions/FilterParamsConfig", - "description": "Paths, tags, operation IDs, and schema properties to include." + "description": "Paths, tags, operation IDs, webhooks, and schema properties to include." }, "exclude": { "$ref": "#/definitions/FilterParamsConfig", - "description": "Paths, tags, operation IDs, and schema properties to exclude." + "description": "Paths, tags, operation IDs, webhooks, and schema properties to exclude." } }, "required": [] @@ -173,6 +173,13 @@ }, "description": "List of operation IDs to include or exclude." }, + "webhooks": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of webhook names to include or exclude." + }, "schema-properties": { "type": "object", "description": "Mapping of schema names to property names to include or exclude.", diff --git a/docs/configuration.md b/docs/configuration.md index 9de03a7b..e38f634c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -427,6 +427,25 @@ filter: - deleteUser ``` +### Filter by Webhooks + +Filter webhook entries by name. This is useful for OpenAPI 3.1 specs that define [webhooks](https://spec.openapis.org/oas/v3.1.0#fixed-fields) - +schemas referenced by included webhooks are preserved during pruning. + +```yaml +filter: + include: + webhooks: + - order.created + - payment.completed + exclude: + webhooks: + - internal.debug +``` + +Webhook operations also respect tag and operation ID filters - if a webhook's operation doesn't match the tag/operation ID filter, +it will be removed just like path operations. + ### Filter by Schema Properties Filter which properties are included in generated types. This triggers **transitive pruning** - schemas that are only referenced by filtered-out properties will also be pruned. diff --git a/examples/webhooks/ex1/api.yaml b/examples/webhooks/ex1/api.yaml new file mode 100644 index 00000000..84a90080 --- /dev/null +++ b/examples/webhooks/ex1/api.yaml @@ -0,0 +1,140 @@ +openapi: 3.1.0 +info: + version: '1.0' + title: Webhook Events API + description: > + Example spec demonstrating webhook support. + Defines inbound webhook callbacks for payment and order events. +webhooks: + payment.authorized: + post: + summary: Payment has been authorized + operationId: paymentAuthorized + tags: + - payments + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PaymentEvent' + responses: + '200': + description: Acknowledged + content: + application/json: + schema: + $ref: '#/components/schemas/WebhookAck' + order.shipped: + post: + summary: Order has been shipped + operationId: orderShipped + tags: + - orders + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/OrderEvent' + responses: + '200': + description: Acknowledged + content: + application/json: + schema: + $ref: '#/components/schemas/WebhookAck' + inventory.low: + post: + summary: Inventory level is low + operationId: inventoryLow + tags: + - inventory + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - sku + - remaining + properties: + sku: + type: string + remaining: + type: integer + responses: + '200': + description: Acknowledged + internal.debug: + post: + summary: Internal debug event (excluded from generation) + operationId: internalDebug + tags: + - internal + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DebugEvent' + responses: + '200': + description: Acknowledged +components: + schemas: + PaymentEvent: + type: object + required: + - id + - amount + - currency + properties: + id: + type: string + format: uuid + amount: + type: integer + description: Amount in minor units (e.g. cents) + currency: + type: string + minLength: 3 + maxLength: 3 + metadata: + $ref: '#/components/schemas/EventMetadata' + OrderEvent: + type: object + required: + - id + - trackingNumber + properties: + id: + type: string + format: uuid + trackingNumber: + type: string + carrier: + type: string + metadata: + $ref: '#/components/schemas/EventMetadata' + EventMetadata: + type: object + properties: + timestamp: + type: string + format: date-time + source: + type: string + WebhookAck: + type: object + properties: + received: + type: boolean + DebugEvent: + type: object + properties: + level: + type: string + message: + type: string diff --git a/examples/webhooks/ex1/cfg.yaml b/examples/webhooks/ex1/cfg.yaml new file mode 100644 index 00000000..3746b9f6 --- /dev/null +++ b/examples/webhooks/ex1/cfg.yaml @@ -0,0 +1,8 @@ +# yaml-language-server: $schema=../../../configuration-schema.json +package: gen +output: + use-single-file: false +filter: + exclude: + webhooks: + - internal.debug diff --git a/examples/webhooks/ex1/gen/common.go b/examples/webhooks/ex1/gen/common.go new file mode 100644 index 00000000..a8814d2b --- /dev/null +++ b/examples/webhooks/ex1/gen/common.go @@ -0,0 +1,15 @@ +// Code generated by oapi-codegen. DO NOT EDIT. + +package gen + +import ( + "github.com/doordash-oss/oapi-codegen-dd/v3/pkg/runtime" + "github.com/go-playground/validator/v10" +) + +var typesValidator *validator.Validate + +func init() { + typesValidator = validator.New(validator.WithRequiredStructEnabled()) + runtime.RegisterCustomTypeFunc(typesValidator) +} diff --git a/examples/webhooks/ex1/gen/types.go b/examples/webhooks/ex1/gen/types.go new file mode 100644 index 00000000..ec42f9be --- /dev/null +++ b/examples/webhooks/ex1/gen/types.go @@ -0,0 +1,84 @@ +// Code generated by oapi-codegen. DO NOT EDIT. + +package gen + +import ( + "time" + + "github.com/doordash-oss/oapi-codegen-dd/v3/pkg/runtime" + "github.com/google/uuid" +) + +type PaymentEvent struct { + ID uuid.UUID `json:"id" validate:"required"` + + // Amount Amount in minor units (e.g. cents) + Amount int `json:"amount" validate:"required"` + Currency string `json:"currency" validate:"required,max=3,min=3"` + Metadata *EventMetadata `json:"metadata,omitempty"` +} + +func (p PaymentEvent) Validate() error { + var errors runtime.ValidationErrors + if v, ok := any(p.ID).(runtime.Validator); ok { + if err := v.Validate(); err != nil { + errors = errors.Append("ID", err) + } + } + if err := typesValidator.Var(p.Amount, "required"); err != nil { + errors = errors.Append("Amount", err) + } + if err := typesValidator.Var(p.Currency, "required,max=3,min=3"); err != nil { + errors = errors.Append("Currency", err) + } + if p.Metadata != nil { + if v, ok := any(p.Metadata).(runtime.Validator); ok { + if err := v.Validate(); err != nil { + errors = errors.Append("Metadata", err) + } + } + } + if len(errors) == 0 { + return nil + } + return errors +} + +type OrderEvent struct { + ID uuid.UUID `json:"id" validate:"required"` + TrackingNumber string `json:"trackingNumber" validate:"required"` + Carrier *string `json:"carrier,omitempty"` + Metadata *EventMetadata `json:"metadata,omitempty"` +} + +func (o OrderEvent) Validate() error { + var errors runtime.ValidationErrors + if v, ok := any(o.ID).(runtime.Validator); ok { + if err := v.Validate(); err != nil { + errors = errors.Append("ID", err) + } + } + if err := typesValidator.Var(o.TrackingNumber, "required"); err != nil { + errors = errors.Append("TrackingNumber", err) + } + if o.Metadata != nil { + if v, ok := any(o.Metadata).(runtime.Validator); ok { + if err := v.Validate(); err != nil { + errors = errors.Append("Metadata", err) + } + } + } + if len(errors) == 0 { + return nil + } + return errors +} + +type EventMetadata struct { + Timestamp *time.Time `json:"timestamp,omitempty"` + Source *string `json:"source,omitempty"` +} + +type WebhookAck struct { + Received *bool `json:"received,omitempty"` +} diff --git a/examples/webhooks/ex1/gen/webhooks.go b/examples/webhooks/ex1/gen/webhooks.go new file mode 100644 index 00000000..9b21b1bb --- /dev/null +++ b/examples/webhooks/ex1/gen/webhooks.go @@ -0,0 +1,16 @@ +// Code generated by oapi-codegen. DO NOT EDIT. + +package gen + +type PaymentAuthorizedBody = PaymentEvent + +type PaymentAuthorizedResponse = WebhookAck + +type OrderShippedBody = OrderEvent + +type OrderShippedResponse = WebhookAck + +type InventoryLowBody struct { + Sku string `json:"sku" validate:"required"` + Remaining int `json:"remaining" validate:"required"` +} diff --git a/examples/webhooks/ex1/generate.go b/examples/webhooks/ex1/generate.go new file mode 100644 index 00000000..ede7b3d6 --- /dev/null +++ b/examples/webhooks/ex1/generate.go @@ -0,0 +1,3 @@ +package gen + +//go:generate go run github.com/doordash-oss/oapi-codegen-dd/v3/cmd/oapi-codegen -config cfg.yaml api.yaml diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index 860ff635..8311c2f0 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -130,6 +130,14 @@ func CreateParseContextFromModel(model *v3high.Document, cfg Configuration) (*Pa responseErrors = opColl.responseErrors } + // Collect webhook type definitions (request body and response types) + webhookTypeDefs, webhookSchemas, err := collectWebhookDefinitions(model, parseOptions) + if err != nil { + return nil, fmt.Errorf("error collecting webhook definitions: %w", err) + } + typeDefs = append(typeDefs, webhookTypeDefs...) + importSchemas = append(importSchemas, webhookSchemas...) + // Collect Schemas from components for _, componentDef := range typeDefs { importSchemas = append(importSchemas, componentDef.Schema) @@ -527,3 +535,51 @@ func collectResponseErrors(errNames []string, tracker *TypeTracker) ([]string, e return res, nil } + +// collectWebhookDefinitions walks webhook operations and collects type definitions +// for request bodies and responses. Types are tagged with SpecLocationWebhook +// so they appear in a separate "webhooks" output file in multi-file mode. +func collectWebhookDefinitions(model *v3high.Document, options ParseOptions) ([]TypeDefinition, []GoSchema, error) { + if model.Webhooks == nil { + return nil, nil, nil + } + + var ( + typeDefs []TypeDefinition + importSchemas []GoSchema + ) + + for webhookName, pathItem := range model.Webhooks.FromOldest() { + for method, operation := range pathItem.GetOperations().FromOldest() { + operationID, err := createOperationID(method, "/webhooks/"+webhookName, operation.OperationId) + if err != nil { + return nil, nil, fmt.Errorf("error creating webhook operation ID: %w", err) + } + + // Process Request Body + _, bodyTypeDef, err := createBodyDefinition(operationID, operation.RequestBody, options) + if err != nil { + return nil, nil, fmt.Errorf("error generating webhook body definitions for %s: %w", webhookName, err) + } + if bodyTypeDef != nil { + bodyTypeDef.SpecLocation = SpecLocationWebhook + typeDefs = append(typeDefs, *bodyTypeDef) + importSchemas = append(importSchemas, bodyTypeDef.Schema) + } + + // Process Responses + _, responseTypes, err := getOperationResponses(operationID, operation.Responses, options) + if err != nil { + return nil, nil, fmt.Errorf("error getting webhook response types for %s: %w", webhookName, err) + } + for i := range responseTypes { + responseTypes[i].SpecLocation = SpecLocationWebhook + importSchemas = append(importSchemas, responseTypes[i].Schema) + } + typeDefs = append(typeDefs, responseTypes...) + } + } + + allTypeDefs := extractAllTypeDefinitions(typeDefs) + return allTypeDefs, importSchemas, nil +} diff --git a/pkg/codegen/codegen_test.go b/pkg/codegen/codegen_test.go index db442f15..c2a7742b 100644 --- a/pkg/codegen/codegen_test.go +++ b/pkg/codegen/codegen_test.go @@ -555,3 +555,73 @@ func TestExternalFileRefResolution(t *testing.T) { assert.Contains(t, combined, "City") }) } + +func TestCollectWebhookDefinitions(t *testing.T) { + t.Run("nil webhooks returns nothing", func(t *testing.T) { + contents, err := os.ReadFile("testdata/prune-cat-dog.yml") + require.NoError(t, err) + + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + opts := ParseOptions{typeTracker: newTypeTracker(), visited: map[string]bool{}, model: &model.Model} + typeDefs, schemas, err := collectWebhookDefinitions(&model.Model, opts) + require.NoError(t, err) + assert.Empty(t, typeDefs) + assert.Empty(t, schemas) + }) + + t.Run("ref webhooks produce alias types", func(t *testing.T) { + contents, err := os.ReadFile("testdata/webhooks-with-examples.yml") + require.NoError(t, err) + + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + opts := ParseOptions{typeTracker: newTypeTracker(), visited: map[string]bool{}, model: &model.Model} + typeDefs, schemas, err := collectWebhookDefinitions(&model.Model, opts) + require.NoError(t, err) + + assert.NotEmpty(t, typeDefs) + assert.NotEmpty(t, schemas) + + // All types should be tagged as webhook + for _, td := range typeDefs { + assert.Equal(t, SpecLocationWebhook, td.SpecLocation, "type %s should have webhook SpecLocation", td.Name) + } + + // Should have body and response aliases + names := make(map[string]bool) + for _, td := range typeDefs { + names[td.Name] = true + } + assert.True(t, names["PaymentCreatedBody"], "should have PaymentCreatedBody") + assert.True(t, names["PaymentCreatedResponse"], "should have PaymentCreatedResponse") + }) + + t.Run("inline webhook body produces struct type", func(t *testing.T) { + contents, err := os.ReadFile("testdata/filter-webhooks.yml") + require.NoError(t, err) + + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + opts := ParseOptions{typeTracker: newTypeTracker(), visited: map[string]bool{}, model: &model.Model} + typeDefs, _, err := collectWebhookDefinitions(&model.Model, opts) + require.NoError(t, err) + + // All types should be tagged as webhook + for _, td := range typeDefs { + assert.Equal(t, SpecLocationWebhook, td.SpecLocation, "type %s should have webhook SpecLocation", td.Name) + } + }) +} diff --git a/pkg/codegen/configuration.go b/pkg/codegen/configuration.go index 3d220f2b..d631b355 100644 --- a/pkg/codegen/configuration.go +++ b/pkg/codegen/configuration.go @@ -271,6 +271,7 @@ type FilterParamsConfig struct { Paths []string `yaml:"paths"` Tags []string `yaml:"tags"` OperationIDs []string `yaml:"operation-ids"` + Webhooks []string `yaml:"webhooks"` SchemaProperties map[string][]string `yaml:"schema-properties"` Extensions []string `yaml:"extensions"` } @@ -280,6 +281,7 @@ func (o FilterParamsConfig) IsEmpty() bool { return len(o.Paths) == 0 && len(o.Tags) == 0 && len(o.OperationIDs) == 0 && + len(o.Webhooks) == 0 && len(o.SchemaProperties) == 0 && len(o.Extensions) == 0 } diff --git a/pkg/codegen/filter.go b/pkg/codegen/filter.go index 3f0ea11b..c9ebf5fb 100644 --- a/pkg/codegen/filter.go +++ b/pkg/codegen/filter.go @@ -63,59 +63,91 @@ func filterOperations(model *v3high.Document, cfg FilterConfig) bool { continue } - for method, op := range pathItem.GetOperations().FromOldest() { - remove := false + removed = filterPathItemOperations(pathItem, cfg, removed) + } - // Tags - for _, tag := range op.Tags { - if slices.Contains(cfg.Exclude.Tags, tag) { - remove = true - break - } + // Filter webhooks + if model.Webhooks != nil { + webhooks := map[string]*v3high.PathItem{} + for name, pathItem := range model.Webhooks.FromOldest() { + webhooks[name] = pathItem + } + + for name, pathItem := range webhooks { + if len(cfg.Include.Webhooks) > 0 && !slices.Contains(cfg.Include.Webhooks, name) { + model.Webhooks.Delete(name) + removed = true + continue } - if !remove && len(cfg.Include.Tags) > 0 { - // Only include if it matches Include.Tags - includeMatch := false - for _, tag := range op.Tags { - if slices.Contains(cfg.Include.Tags, tag) { - includeMatch = true - break - } - } - if !includeMatch { - remove = true - } + if len(cfg.Exclude.Webhooks) > 0 && slices.Contains(cfg.Exclude.Webhooks, name) { + model.Webhooks.Delete(name) + removed = true + continue } - // OperationIDs - if len(cfg.Exclude.OperationIDs) > 0 && slices.Contains(cfg.Exclude.OperationIDs, op.OperationId) { + removed = filterPathItemOperations(pathItem, cfg, removed) + } + } + + return removed +} + +// filterPathItemOperations filters operations within a PathItem by tags and operation IDs. +// Returns the updated removed flag. +func filterPathItemOperations(pathItem *v3high.PathItem, cfg FilterConfig, removed bool) bool { + for method, op := range pathItem.GetOperations().FromOldest() { + remove := false + + // Tags + for _, tag := range op.Tags { + if slices.Contains(cfg.Exclude.Tags, tag) { remove = true + break + } + } + + if !remove && len(cfg.Include.Tags) > 0 { + // Only include if it matches Include.Tags + includeMatch := false + for _, tag := range op.Tags { + if slices.Contains(cfg.Include.Tags, tag) { + includeMatch = true + break + } } - if len(cfg.Include.OperationIDs) > 0 && !slices.Contains(cfg.Include.OperationIDs, op.OperationId) { + if !includeMatch { remove = true } + } - if remove { - removed = true - switch strings.ToLower(method) { - case "get": - pathItem.Get = nil - case "post": - pathItem.Post = nil - case "put": - pathItem.Put = nil - case "delete": - pathItem.Delete = nil - case "patch": - pathItem.Patch = nil - case "head": - pathItem.Head = nil - case "options": - pathItem.Options = nil - case "trace": - pathItem.Trace = nil - } + // OperationIDs + if len(cfg.Exclude.OperationIDs) > 0 && slices.Contains(cfg.Exclude.OperationIDs, op.OperationId) { + remove = true + } + if len(cfg.Include.OperationIDs) > 0 && !slices.Contains(cfg.Include.OperationIDs, op.OperationId) { + remove = true + } + + if remove { + removed = true + switch strings.ToLower(method) { + case "get": + pathItem.Get = nil + case "post": + pathItem.Post = nil + case "put": + pathItem.Put = nil + case "delete": + pathItem.Delete = nil + case "patch": + pathItem.Patch = nil + case "head": + pathItem.Head = nil + case "options": + pathItem.Options = nil + case "trace": + pathItem.Trace = nil } } } diff --git a/pkg/codegen/filter_test.go b/pkg/codegen/filter_test.go index c156e24b..8f8850dc 100644 --- a/pkg/codegen/filter_test.go +++ b/pkg/codegen/filter_test.go @@ -11,6 +11,7 @@ package codegen import ( + "os" "testing" "github.com/stretchr/testify/assert" @@ -344,3 +345,175 @@ func TestFilterOperationsByPath(t *testing.T) { assert.Contains(t, combined, `"/enum"`) }) } + +func TestFilterWebhooks(t *testing.T) { + contents, err := os.ReadFile("testdata/filter-webhooks.yml") + require.NoError(t, err) + + t.Run("include webhooks", func(t *testing.T) { + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + cfg := FilterConfig{ + Include: FilterParamsConfig{ + Webhooks: []string{"order.created"}, + }, + } + + removed := filterOperations(&model.Model, cfg) + assert.True(t, removed) + + // order.created should remain + _, hasOrder := model.Model.Webhooks.Get("order.created") + assert.True(t, hasOrder, "order.created webhook should be included") + + // payment.completed should be removed + _, hasPayment := model.Model.Webhooks.Get("payment.completed") + assert.False(t, hasPayment, "payment.completed webhook should be excluded") + }) + + t.Run("exclude webhooks", func(t *testing.T) { + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + cfg := FilterConfig{ + Exclude: FilterParamsConfig{ + Webhooks: []string{"payment.completed"}, + }, + } + + removed := filterOperations(&model.Model, cfg) + assert.True(t, removed) + + // order.created should remain + _, hasOrder := model.Model.Webhooks.Get("order.created") + assert.True(t, hasOrder, "order.created webhook should remain") + + // payment.completed should be removed + _, hasPayment := model.Model.Webhooks.Get("payment.completed") + assert.False(t, hasPayment, "payment.completed webhook should be excluded") + }) + + t.Run("empty webhook filter does not remove webhooks", func(t *testing.T) { + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + cfg := FilterConfig{} + removed := filterOperations(&model.Model, cfg) + assert.False(t, removed) + + // Both webhooks should remain + _, hasOrder := model.Model.Webhooks.Get("order.created") + assert.True(t, hasOrder) + _, hasPayment := model.Model.Webhooks.Get("payment.completed") + assert.True(t, hasPayment) + }) + + t.Run("filter webhook operations by tag", func(t *testing.T) { + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + cfg := FilterConfig{ + Include: FilterParamsConfig{ + Tags: []string{"orders"}, + }, + } + + removed := filterOperations(&model.Model, cfg) + assert.True(t, removed) + + // order.created should still have its post operation + orderItem, hasOrder := model.Model.Webhooks.Get("order.created") + assert.True(t, hasOrder) + assert.NotNil(t, orderItem.Post) + + // payment.completed should have its post operation nil'd out + paymentItem, hasPayment := model.Model.Webhooks.Get("payment.completed") + assert.True(t, hasPayment) + assert.Nil(t, paymentItem.Post) + }) + + t.Run("filter webhook operations by operationID", func(t *testing.T) { + doc, err := LoadDocumentFromContents(contents) + require.NoError(t, err) + + model, err := doc.BuildV3Model() + require.NoError(t, err) + + cfg := FilterConfig{ + Exclude: FilterParamsConfig{ + OperationIDs: []string{"paymentCompleted"}, + }, + } + + removed := filterOperations(&model.Model, cfg) + assert.True(t, removed) + + // order.created post should remain + orderItem, _ := model.Model.Webhooks.Get("order.created") + assert.NotNil(t, orderItem.Post) + + // payment.completed post should be nil'd out + paymentItem, _ := model.Model.Webhooks.Get("payment.completed") + assert.Nil(t, paymentItem.Post) + }) +} + +func TestGenerateWebhookTypes(t *testing.T) { + contents, err := os.ReadFile("testdata/webhooks-with-examples.yml") + require.NoError(t, err) + + cfg := Configuration{ + PackageName: "webhooktest", + Output: &Output{ + UseSingleFile: true, + }, + } + + code, err := Generate(contents, cfg) + require.NoError(t, err) + assert.NotEmpty(t, code) + + combined := code.GetCombined() + assert.Contains(t, combined, "PaymentRequest") + assert.Contains(t, combined, "Response") +} + +func TestGenerateWebhookTypesMultiFile(t *testing.T) { + contents, err := os.ReadFile("testdata/webhooks-with-examples.yml") + require.NoError(t, err) + + cfg := Configuration{ + PackageName: "webhooktest", + Output: &Output{ + UseSingleFile: false, + }, + } + + code, err := Generate(contents, cfg) + require.NoError(t, err) + assert.NotEmpty(t, code) + + // Webhook types should be in the "webhooks" output file + webhooksCode, hasWebhooks := code["webhooks"] + assert.True(t, hasWebhooks, "should have a 'webhooks' output file for webhook types") + assert.Contains(t, webhooksCode, "PaymentCreatedBody") + assert.Contains(t, webhooksCode, "PaymentCreatedResponse") + + // Component schemas should be in the "types" output file + typesCode, hasTypes := code["types"] + assert.True(t, hasTypes, "should have a 'types' output file for component schemas") + assert.Contains(t, typesCode, "PaymentRequest") +} diff --git a/pkg/codegen/parser.go b/pkg/codegen/parser.go index cbb2ca12..ab3b7150 100644 --- a/pkg/codegen/parser.go +++ b/pkg/codegen/parser.go @@ -681,6 +681,8 @@ func getSpecLocationOutName(specLocation SpecLocation) string { return "types" case SpecLocationUnion: return "unions" + case SpecLocationWebhook: + return "webhooks" default: return string(specLocation) } diff --git a/pkg/codegen/prune.go b/pkg/codegen/prune.go index ef708ac4..72dae6e7 100644 --- a/pkg/codegen/prune.go +++ b/pkg/codegen/prune.go @@ -22,8 +22,7 @@ import ( func pruneSchema(model *v3high.Document) error { // Aggressively remove everything we don't generate code for - slog.Debug("Pruning: removing webhooks, security schemes, callbacks, component examples, links") - model.Webhooks = nil + slog.Debug("Pruning: removing security schemes, callbacks, component examples, links") if model.Components != nil { // Set to nil - we don't generate code for these model.Components.SecuritySchemes = nil @@ -121,38 +120,17 @@ func removeOrphanedComponents(model *v3high.Document, refs map[string]bool) int func findOperationRefs(model *v3high.Document) map[string]bool { refSet := make(map[string]bool) - if model.Paths == nil || model.Paths.PathItems == nil { - return refSet - } - - // Walk all operations and collect refs - for _, pathItem := range model.Paths.PathItems.FromOldest() { - // Collect path-level parameters - for _, param := range pathItem.Parameters { - collectRefFromProxy(param, refSet, model) + if model.Paths != nil && model.Paths.PathItems != nil { + // Walk all operations and collect refs + for _, pathItem := range model.Paths.PathItems.FromOldest() { + collectPathItemRefs(pathItem, refSet, model) } + } - // Collect operation-level refs - for _, op := range pathItem.GetOperations().FromOldest() { - // Request body - if op.RequestBody != nil { - collectRefFromProxy(op.RequestBody, refSet, model) - } - - // Parameters - for _, param := range op.Parameters { - collectRefFromProxy(param, refSet, model) - } - - // Responses - if op.Responses != nil { - if op.Responses.Default != nil { - collectRefFromProxy(op.Responses.Default, refSet, model) - } - for _, resp := range op.Responses.Codes.FromOldest() { - collectRefFromProxy(resp, refSet, model) - } - } + // Walk webhook path items and collect refs (structurally identical to paths) + if model.Webhooks != nil { + for _, pathItem := range model.Webhooks.FromOldest() { + collectPathItemRefs(pathItem, refSet, model) } } @@ -215,6 +193,37 @@ func findOperationRefs(model *v3high.Document) map[string]bool { return refSet } +// collectPathItemRefs collects all refs from a PathItem's parameters and operations. +func collectPathItemRefs(pathItem *v3high.PathItem, refSet map[string]bool, model *v3high.Document) { + // Collect path-level parameters + for _, param := range pathItem.Parameters { + collectRefFromProxy(param, refSet, model) + } + + // Collect operation-level refs + for _, op := range pathItem.GetOperations().FromOldest() { + // Request body + if op.RequestBody != nil { + collectRefFromProxy(op.RequestBody, refSet, model) + } + + // Parameters + for _, param := range op.Parameters { + collectRefFromProxy(param, refSet, model) + } + + // Responses + if op.Responses != nil { + if op.Responses.Default != nil { + collectRefFromProxy(op.Responses.Default, refSet, model) + } + for _, resp := range op.Responses.Codes.FromOldest() { + collectRefFromProxy(resp, refSet, model) + } + } + } +} + // addParentSchemaRef adds the parent schema reference if the given ref is a property reference // e.g., if ref is "#/components/schemas/Foo/properties/bar", also add "#/components/schemas/Foo" func addParentSchemaRef(ref string, refSet map[string]bool) { diff --git a/pkg/codegen/prune_test.go b/pkg/codegen/prune_test.go index 28f32a9f..141e5179 100644 --- a/pkg/codegen/prune_test.go +++ b/pkg/codegen/prune_test.go @@ -472,7 +472,7 @@ func TestPruneExamples(t *testing.T) { assert.NotNil(t, header.Example) }) - t.Run("webhooks removed during pruning", func(t *testing.T) { + t.Run("webhook schema refs collected by findOperationRefs", func(t *testing.T) { contents, err := os.ReadFile("testdata/webhooks-with-examples.yml") assert.NoError(t, err) @@ -482,6 +482,25 @@ func TestPruneExamples(t *testing.T) { model, err := doc.BuildV3Model() assert.NoError(t, err) + refs := findOperationRefs(&model.Model) + assert.True(t, refs["#/components/schemas/PaymentRequest"], "PaymentRequest should be in refs from webhook") + assert.True(t, refs["#/components/schemas/Response"], "Response should be in refs from webhook") + }) + + t.Run("webhooks preserved during pruning", func(t *testing.T) { + contents, err := os.ReadFile("testdata/webhooks-with-examples.yml") + assert.NoError(t, err) + + doc, err := LoadDocumentFromContents(contents) + assert.NoError(t, err) + + model, err := doc.BuildV3Model() + assert.NoError(t, err) + + // Before pruning: should have webhooks and schemas + assert.NotNil(t, model.Model.Webhooks) + assert.Equal(t, 2, model.Model.Components.Schemas.Len()) + // Prune the document err = pruneSchema(&model.Model) assert.NoError(t, err) @@ -489,7 +508,16 @@ func TestPruneExamples(t *testing.T) { // components/examples should be removed (set to nil) assert.Nil(t, model.Model.Components.Examples) - // webhooks should be removed - assert.Nil(t, model.Model.Webhooks) + // webhooks should be preserved + assert.NotNil(t, model.Model.Webhooks) + + // schemas referenced by webhooks should be preserved + assert.Equal(t, 2, model.Model.Components.Schemas.Len()) + + _, hasPaymentRequest := model.Model.Components.Schemas.Get("PaymentRequest") + assert.True(t, hasPaymentRequest, "PaymentRequest schema should be preserved - referenced by webhook") + + _, hasResponse := model.Model.Components.Schemas.Get("Response") + assert.True(t, hasResponse, "Response schema should be preserved - referenced by webhook") }) } diff --git a/pkg/codegen/testdata/filter-webhooks.yml b/pkg/codegen/testdata/filter-webhooks.yml new file mode 100644 index 00000000..5828d99f --- /dev/null +++ b/pkg/codegen/testdata/filter-webhooks.yml @@ -0,0 +1,54 @@ +openapi: 3.1.0 +info: + version: '1' + title: Filter webhooks test +webhooks: + order.created: + post: + summary: Order created + operationId: orderCreated + tags: + - orders + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OrderEvent' + responses: + '200': + description: OK + payment.completed: + post: + summary: Payment completed + operationId: paymentCompleted + tags: + - payments + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PaymentEvent' + responses: + '200': + description: OK +components: + schemas: + OrderEvent: + type: object + properties: + orderId: + type: string + amount: + $ref: '#/components/schemas/Money' + PaymentEvent: + type: object + properties: + paymentId: + type: string + Money: + type: object + properties: + currency: + type: string + value: + type: integer diff --git a/pkg/codegen/typedef.go b/pkg/codegen/typedef.go index 96e10acd..e8f9c772 100644 --- a/pkg/codegen/typedef.go +++ b/pkg/codegen/typedef.go @@ -25,6 +25,7 @@ const ( SpecLocationResponse SpecLocation = "response" SpecLocationSchema SpecLocation = "schema" SpecLocationUnion SpecLocation = "union" + SpecLocationWebhook SpecLocation = "webhook" ) // TypeDefinition describes a Go type definition in generated code.