diff --git a/utils/component/openapi_generator.go b/utils/component/openapi_generator.go index ce546d52..7a9e2fdf 100644 --- a/utils/component/openapi_generator.go +++ b/utils/component/openapi_generator.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "reflect" "strings" "cuelang.org/go/cue" @@ -196,10 +197,7 @@ func getResolvedManifest(manifest string) (string, error) { if doc.Components == nil || len(doc.Components.Schemas) == 0 { return "", ErrNoSchemasFound } - stack := make(map[*openapi3.Schema]bool) - for _, schemaRef := range doc.Components.Schemas { - clearSchemaRefs(schemaRef, stack) - } + clearDocRefs(doc) resolved, err := json.Marshal(doc) if err != nil { return "", err @@ -207,44 +205,92 @@ func getResolvedManifest(manifest string) (string, error) { return string(resolved), nil } -// clearSchemaRefs recursively clears $ref strings on all nested SchemaRefs -// so that json.Marshal outputs fully inlined schemas. The stack set tracks -// Schema values (not SchemaRef pointers) on the current recursion path to -// detect circular references. kin-openapi resolves $refs by creating -// different SchemaRef objects that share the same underlying Schema pointer, -// so tracking by *Schema is necessary to catch all cycles. -func clearSchemaRefs(sr *openapi3.SchemaRef, stack map[*openapi3.Schema]bool) { - if sr == nil { - return - } - sr.Ref = "" - s := sr.Value - if s == nil { - return - } - if stack[s] { - sr.Value = &openapi3.Schema{} - return - } - stack[s] = true - for _, child := range s.AllOf { - clearSchemaRefs(child, stack) - } - for _, child := range s.AnyOf { - clearSchemaRefs(child, stack) - } - for _, child := range s.OneOf { - clearSchemaRefs(child, stack) - } - clearSchemaRefs(s.Not, stack) - if s.Items != nil { - clearSchemaRefs(s.Items, stack) - } - for _, prop := range s.Properties { - clearSchemaRefs(prop, stack) - } - if s.AdditionalProperties.Schema != nil { - clearSchemaRefs(s.AdditionalProperties.Schema, stack) +// clearDocRefs uses reflection to walk the entire OpenAPI document and clear +// all $ref strings so that json.Marshal outputs fully inlined schemas. +// It uses two tracking mechanisms: +// - visited: permanent set for general pointers to avoid re-processing +// - schemaStack: path-based set for *Schema pointers to detect circular +// schema references (add on enter, remove on exit), allowing the same +// schema to appear in multiple non-circular positions +func clearDocRefs(doc *openapi3.T) { + visited := make(map[uintptr]bool) + schemaStack := make(map[uintptr]bool) + walkAndClearRefs(reflect.ValueOf(doc), visited, schemaStack) +} + +var schemaRefType = reflect.TypeOf((*openapi3.SchemaRef)(nil)) + +func walkAndClearRefs(v reflect.Value, visited map[uintptr]bool, schemaStack map[uintptr]bool) { + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return + } + + // SchemaRef needs path-based cycle detection so shared (non-circular) + // schemas are fully expanded while true cycles are broken. + if v.Type() == schemaRefType { + sr := v.Interface().(*openapi3.SchemaRef) + sr.Ref = "" + if sr.Value == nil { + return + } + schemaPtr := reflect.ValueOf(sr.Value).Pointer() + if schemaStack[schemaPtr] { + sr.Value = &openapi3.Schema{} + return + } + schemaStack[schemaPtr] = true + walkAndClearRefs(reflect.ValueOf(sr.Value), visited, schemaStack) + delete(schemaStack, schemaPtr) + return + } + + ptr := v.Pointer() + if visited[ptr] { + return + } + visited[ptr] = true + + elem := v.Elem() + if elem.Kind() == reflect.Struct { + if refField := elem.FieldByName("Ref"); refField.IsValid() && refField.Kind() == reflect.String { + refField.SetString("") + } + } + walkAndClearRefs(elem, visited, schemaStack) + + case reflect.Struct: + // Handle types with unexported map fields (Paths, Callback, Responses) + // accessed via a Map() method. + if v.CanAddr() { + if mapMethod := v.Addr().MethodByName("Map"); mapMethod.IsValid() { + results := mapMethod.Call(nil) + if len(results) == 1 && results[0].Kind() == reflect.Map { + walkAndClearRefs(results[0], visited, schemaStack) + } + } + } + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.CanInterface() { + walkAndClearRefs(field, visited, schemaStack) + } + } + + case reflect.Map: + for _, key := range v.MapKeys() { + walkAndClearRefs(v.MapIndex(key), visited, schemaStack) + } + + case reflect.Slice: + for i := 0; i < v.Len(); i++ { + walkAndClearRefs(v.Index(i), visited, schemaStack) + } + + case reflect.Interface: + if !v.IsNil() { + walkAndClearRefs(v.Elem(), visited, schemaStack) + } } - delete(stack, s) } diff --git a/utils/component/openapi_generator_test.go b/utils/component/openapi_generator_test.go index a5afae36..b9d0c220 100644 --- a/utils/component/openapi_generator_test.go +++ b/utils/component/openapi_generator_test.go @@ -2,6 +2,7 @@ package component import ( "encoding/json" + "reflect" "strings" "testing" @@ -305,7 +306,7 @@ func TestGetResolvedManifest_AllOf(t *testing.T) { } } -func TestClearSchemaRefs(t *testing.T) { +func TestWalkAndClearRefs(t *testing.T) { tests := []struct { name string sr *openapi3.SchemaRef @@ -322,8 +323,9 @@ func TestClearSchemaRefs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - visited := make(map[*openapi3.Schema]bool) - clearSchemaRefs(tt.sr, visited) + visited := make(map[uintptr]bool) + schemaStack := make(map[uintptr]bool) + walkAndClearRefs(reflect.ValueOf(tt.sr), visited, schemaStack) if tt.sr != nil && tt.sr.Ref != "" { t.Errorf("Ref = %q, want empty", tt.sr.Ref) } @@ -331,15 +333,16 @@ func TestClearSchemaRefs(t *testing.T) { } } -func TestClearSchemaRefs_Circular(t *testing.T) { +func TestWalkAndClearRefs_Circular(t *testing.T) { // Build a circular reference: A -> B -> A a := &openapi3.SchemaRef{Ref: "#/components/schemas/A", Value: &openapi3.Schema{}} b := &openapi3.SchemaRef{Ref: "#/components/schemas/B", Value: &openapi3.Schema{}} a.Value.Properties = openapi3.Schemas{"b": b} b.Value.Properties = openapi3.Schemas{"a": a} - stack := make(map[*openapi3.Schema]bool) - clearSchemaRefs(a, stack) // must not hang or panic + visited := make(map[uintptr]bool) + schemaStack := make(map[uintptr]bool) + walkAndClearRefs(reflect.ValueOf(a), visited, schemaStack) // must not hang or panic if a.Ref != "" { t.Errorf("a.Ref = %q, want empty", a.Ref) @@ -354,15 +357,16 @@ func TestClearSchemaRefs_Circular(t *testing.T) { } } -func TestClearSchemaRefs_SelfReference(t *testing.T) { +func TestWalkAndClearRefs_SelfReference(t *testing.T) { // Schema that references itself (like JSONSchemaProps). self := &openapi3.SchemaRef{Value: &openapi3.Schema{ Type: &openapi3.Types{"object"}, }} self.Value.Properties = openapi3.Schemas{"nested": self} - stack := make(map[*openapi3.Schema]bool) - clearSchemaRefs(self, stack) + visited := make(map[uintptr]bool) + schemaStack := make(map[uintptr]bool) + walkAndClearRefs(reflect.ValueOf(self), visited, schemaStack) // The self-referencing property should be replaced, breaking the cycle. if _, err := json.Marshal(self); err != nil { @@ -370,6 +374,185 @@ func TestClearSchemaRefs_SelfReference(t *testing.T) { } } +func TestGetResolvedManifest_PathsAndComponents(t *testing.T) { + input := `{ + "openapi": "3.0.0", + "info": {"title": "test", "version": "1.0"}, + "paths": { + "/pets": { + "get": { + "parameters": [ + {"$ref": "#/components/parameters/LimitParam"} + ], + "responses": { + "200": { + "$ref": "#/components/responses/PetList" + } + } + }, + "post": { + "requestBody": { + "$ref": "#/components/requestBodies/PetBody" + }, + "responses": { + "201": { + "description": "created" + } + } + } + } + }, + "components": { + "schemas": { + "Pet": { + "type": "object", + "properties": { + "name": {"type": "string"} + } + } + }, + "parameters": { + "LimitParam": { + "name": "limit", + "in": "query", + "schema": {"type": "integer"} + } + }, + "requestBodies": { + "PetBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Pet"} + } + } + } + }, + "responses": { + "PetList": { + "description": "A list of pets", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"$ref": "#/components/schemas/Pet"} + } + } + } + } + } + } + }` + + out, err := getResolvedManifest(input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var parsed map[string]any + if err := json.Unmarshal([]byte(out), &parsed); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + // Verify parameter ref in path is resolved. + param := navigatePath(t, parsed, "paths./pets.get.parameters") + params, ok := param.([]any) + if !ok || len(params) == 0 { + t.Fatal("expected parameters array") + } + pm := params[0].(map[string]any) + if _, hasRef := pm["$ref"]; hasRef { + t.Error("parameter $ref still present in path") + } + if pm["name"] != "limit" { + t.Errorf("parameter name = %v, want limit", pm["name"]) + } + + // Verify response ref in path is resolved. + resp := navigatePath(t, parsed, "paths./pets.get.responses.200").(map[string]any) + if _, hasRef := resp["$ref"]; hasRef { + t.Error("response $ref still present in path") + } + if resp["description"] != "A list of pets" { + t.Errorf("response description = %v, want 'A list of pets'", resp["description"]) + } + + // Verify schema ref inside response content is resolved. + items := navigatePath(t, parsed, "paths./pets.get.responses.200.content.application/json.schema.items").(map[string]any) + if _, hasRef := items["$ref"]; hasRef { + t.Error("schema $ref in response items still present") + } + if items["type"] != "object" { + t.Errorf("items type = %v, want object", items["type"]) + } + + // Verify requestBody ref in path is resolved. + rb := navigatePath(t, parsed, "paths./pets.post.requestBody").(map[string]any) + if _, hasRef := rb["$ref"]; hasRef { + t.Error("requestBody $ref still present in path") + } + + // Verify schema ref inside requestBody content is resolved. + rbSchema := navigatePath(t, parsed, "paths./pets.post.requestBody.content.application/json.schema").(map[string]any) + if _, hasRef := rbSchema["$ref"]; hasRef { + t.Error("schema $ref in requestBody content still present") + } + if rbSchema["type"] != "object" { + t.Errorf("requestBody schema type = %v, want object", rbSchema["type"]) + } +} + +func TestGetResolvedManifest_HeaderRefs(t *testing.T) { + input := `{ + "openapi": "3.0.0", + "info": {"title": "test", "version": "1.0"}, + "paths": { + "/items": { + "get": { + "responses": { + "200": { + "description": "OK", + "headers": { + "X-Rate-Limit": { + "$ref": "#/components/headers/RateLimit" + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Placeholder": {"type": "string"} + }, + "headers": { + "RateLimit": { + "schema": {"type": "integer"} + } + } + } + }` + + out, err := getResolvedManifest(input) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var parsed map[string]any + if err := json.Unmarshal([]byte(out), &parsed); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + header := navigatePath(t, parsed, "paths./items.get.responses.200.headers.X-Rate-Limit").(map[string]any) + if _, hasRef := header["$ref"]; hasRef { + t.Error("header $ref still present") + } + schema := header["schema"].(map[string]any) + if schema["type"] != "integer" { + t.Errorf("header schema type = %v, want integer", schema["type"]) + } +} + // navigatePath walks a dot-separated path through nested maps. func navigatePath(t *testing.T, data map[string]any, path string) any { t.Helper()