Skip to content

Commit 4021283

Browse files
authored
Support cookie-backed API key security inference (#3910)
* Fix cookie-backed API key security inference * Deduplicate security-bound OpenAPI parameters * Address OpenAPI review feedback
1 parent 8df277a commit 4021283

13 files changed

Lines changed: 236 additions & 183 deletions

File tree

AGENTS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ No commented-out code—delete dead code.
5757

5858
- **Use NameScope helpers** for type references: `GoTypeRef`, `GoFullTypeRef`, `GoTypeName`. Never concatenate strings for types.
5959
- Let Goa decide pointer/value semantics. Do not force `pointer=true` except in transport validation.
60+
- **Keep helper visibility minimal**: If logic is shared only inside one codegen area, keep it package-private or move it under an `internal` package. Do not export helpers from a parent package just to share them across sibling generators.
61+
- **Avoid pass-through wrappers**: When two helper functions differ only by forwarding arguments or hard-coding `nil`, collapse them into a single implementation instead of adding an extra layer.
6062

6163
### Documentation
6264

codegen/sections_test.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,23 @@ import (
7171
package testpackage
7272
7373
`
74-
imprt = []*ImportSpec{{Path: "test"}}
75-
imports = append(imprt, &ImportSpec{Path: "other"})
76-
pathImport = []*ImportSpec{{Path: "import/with/slashes"}}
77-
pathImports = append(pathImport, &ImportSpec{Path: "other/import/with/slashes"})
78-
pathNamedImport = []*ImportSpec{{Name: "myname", Path: "import/with/slashes"}}
74+
imprt = func() []*ImportSpec {
75+
imports := make([]*ImportSpec, 0, 2)
76+
imports = append(imports, &ImportSpec{Path: "test"})
77+
return imports
78+
}()
79+
imports = append(imprt, &ImportSpec{Path: "other"})
80+
pathImport = func() []*ImportSpec {
81+
imports := make([]*ImportSpec, 0, 2)
82+
imports = append(imports, &ImportSpec{Path: "import/with/slashes"})
83+
return imports
84+
}()
85+
pathImports = append(pathImport, &ImportSpec{Path: "other/import/with/slashes"})
86+
pathNamedImport = func() []*ImportSpec {
87+
imports := make([]*ImportSpec, 0, 2)
88+
imports = append(imports, &ImportSpec{Name: "myname", Path: "import/with/slashes"})
89+
return imports
90+
}()
7991
pathNamedImports = append(pathNamedImport, &ImportSpec{Name: "myothername", Path: "other/import/with/slashes"})
8092
)
8193
cases := map[string]struct {

expr/helpers.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,17 @@ func Title(s string) string {
1212
}
1313

1414
// findKey finds the given key in the endpoint expression and returns the
15-
// transport element name and the position (header, query, or body for HTTP or
16-
// message, metadata for gRPC endpoint).
15+
// transport element name and the position (header, query, cookie, or body for
16+
// HTTP or message, metadata for gRPC endpoint).
1717
func findKey(exp eval.Expression, keyAtt string) (string, string) {
1818
switch e := exp.(type) {
1919
case *HTTPEndpointExpr:
2020
if n, exists := e.Params.FindKey(keyAtt); exists {
2121
return n, "query"
2222
} else if n, exists := e.Headers.FindKey(keyAtt); exists {
2323
return n, "header"
24+
} else if n, exists := e.Cookies.FindKey(keyAtt); exists {
25+
return n, "cookie"
2426
} else if e.Body == nil {
2527
return "", "header"
2628
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package codegen
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"path/filepath"
7+
"testing"
8+
"text/template"
9+
10+
"github.com/getkin/kin-openapi/openapi2"
11+
"github.com/getkin/kin-openapi/openapi3"
12+
"github.com/stretchr/testify/require"
13+
14+
"goa.design/goa/v3/codegen"
15+
"goa.design/goa/v3/dsl"
16+
"goa.design/goa/v3/expr"
17+
"goa.design/goa/v3/http/codegen/openapi"
18+
openapiv2 "goa.design/goa/v3/http/codegen/openapi/v2"
19+
openapiv3 "goa.design/goa/v3/http/codegen/openapi/v3"
20+
)
21+
22+
func TestCookieAPIKeySecurity(t *testing.T) {
23+
t.Run("endpoint requirement uses cookie transport", func(t *testing.T) {
24+
root := RunHTTPDSL(t, cookieAPIKeySecurityDSL)
25+
endpoint := root.API.HTTP.Services[0].HTTPEndpoints[0]
26+
require.Len(t, endpoint.Requirements, 1)
27+
require.Len(t, endpoint.Requirements[0].Schemes, 1)
28+
29+
scheme := endpoint.Requirements[0].Schemes[0]
30+
require.Equal(t, "cookie", scheme.In)
31+
require.Equal(t, "__Host-ak_session", scheme.Name)
32+
33+
headers := expr.AsObject(endpoint.Headers.Type)
34+
require.Zero(t, len(*headers), "cookie-backed api key must not synthesize an Authorization header")
35+
})
36+
37+
t.Run("openapi uses cookie security scheme", func(t *testing.T) {
38+
root := RunHTTPDSL(t, cookieAPIKeySecurityDSL)
39+
openapi.Definitions = make(map[string]*openapi.Schema)
40+
41+
v2JSON := renderOpenAPIJSON(t, openapiv2.Files, root)
42+
var swagger openapi2.T
43+
require.NoError(t, swagger.UnmarshalJSON(v2JSON))
44+
require.Len(t, swagger.SecurityDefinitions, 1)
45+
require.Len(t, swagger.Paths, 1)
46+
require.Contains(t, swagger.Paths, "/auth/profile")
47+
require.NotNil(t, swagger.Paths["/auth/profile"].Get.Security)
48+
require.Len(t, *swagger.Paths["/auth/profile"].Get.Security, 1)
49+
for name, def := range swagger.SecurityDefinitions {
50+
require.Equal(t, "apiKey", def.Type, name)
51+
require.Equal(t, "cookie", def.In, name)
52+
require.Equal(t, "__Host-ak_session", def.Name, name)
53+
require.Contains(t, (*swagger.Paths["/auth/profile"].Get.Security)[0], name)
54+
}
55+
56+
openapi.Definitions = make(map[string]*openapi.Schema)
57+
v3JSON := renderOpenAPIJSON(t, openapiv3.Files, root)
58+
loader := openapi3.NewLoader()
59+
doc, err := loader.LoadFromData(v3JSON)
60+
require.NoError(t, err)
61+
require.NoError(t, doc.Validate(context.Background()))
62+
require.Len(t, doc.Components.SecuritySchemes, 1)
63+
require.NotNil(t, doc.Paths.Find("/auth/profile"))
64+
require.NotNil(t, doc.Paths.Find("/auth/profile").Get.Security)
65+
require.Len(t, *doc.Paths.Find("/auth/profile").Get.Security, 1)
66+
for name, ref := range doc.Components.SecuritySchemes {
67+
require.NotNil(t, ref.Value, name)
68+
require.Equal(t, "apiKey", ref.Value.Type, name)
69+
require.Equal(t, "cookie", ref.Value.In, name)
70+
require.Equal(t, "__Host-ak_session", ref.Value.Name, name)
71+
require.Contains(t, (*doc.Paths.Find("/auth/profile").Get.Security)[0], name)
72+
}
73+
})
74+
75+
t.Run("http codegen does not duplicate cookie-backed auth fields", func(t *testing.T) {
76+
root := RunHTTPDSL(t, cookieAPIKeySecurityDSL)
77+
services := CreateHTTPServices(root)
78+
79+
serverTypes := serverType("gen", root.API.HTTP.Services[0], services)
80+
var serverTypesBuf bytes.Buffer
81+
for _, section := range serverTypes.SectionTemplates[1:] {
82+
require.NoError(t, section.Write(&serverTypesBuf))
83+
}
84+
serverTypesCode := codegen.FormatTestCode(t, "package foo\n"+serverTypesBuf.String())
85+
require.Contains(t, serverTypesCode, "func NewProfilePayload(browserSession string)")
86+
require.NotContains(t, serverTypesCode, "browserSession *string, browserSession *string")
87+
require.NotContains(t, serverTypesCode, "browserSession string, browserSession string")
88+
89+
serverFiles := ServerFiles("", services)
90+
require.Len(t, serverFiles, 2)
91+
serverDecode := codegen.SectionCode(t, serverFiles[1].SectionTemplates[2])
92+
require.Contains(t, serverDecode, `r.Cookie("__Host-ak_session")`)
93+
require.NotContains(t, serverDecode, "Authorization")
94+
require.NotContains(t, serverDecode, "browserSession *string, browserSession *string")
95+
require.NotContains(t, serverDecode, "browserSession string, browserSession string")
96+
97+
clientFiles := ClientFiles("", services)
98+
require.Len(t, clientFiles, 2)
99+
clientEncode := codegen.SectionCode(t, clientFiles[1].SectionTemplates[2])
100+
require.Contains(t, clientEncode, `req.AddCookie(&http.Cookie{`)
101+
require.Contains(t, clientEncode, `Name: "__Host-ak_session"`)
102+
require.NotContains(t, clientEncode, "Authorization")
103+
})
104+
}
105+
106+
func renderOpenAPIJSON(
107+
t *testing.T,
108+
build func(*expr.RootExpr) ([]*codegen.File, error),
109+
root *expr.RootExpr,
110+
) []byte {
111+
t.Helper()
112+
113+
files, err := build(root)
114+
require.NoError(t, err)
115+
for _, f := range files {
116+
if filepath.Ext(f.Path) != ".json" {
117+
continue
118+
}
119+
require.Len(t, f.SectionTemplates, 1)
120+
section := f.SectionTemplates[0]
121+
require.NotEmpty(t, section.Source)
122+
require.NotNil(t, section.Data)
123+
124+
var buf bytes.Buffer
125+
tmpl := template.Must(template.New("openapi").Funcs(section.FuncMap).Parse(section.Source))
126+
require.NoError(t, tmpl.Execute(&buf, section.Data))
127+
return buf.Bytes()
128+
}
129+
130+
t.Fatalf("no JSON OpenAPI file generated")
131+
return nil
132+
}
133+
134+
var cookieAPIKeySecurityDSL = func() {
135+
var browserSessionCookie = dsl.APIKeySecurity("browser_session_cookie", func() {
136+
dsl.Description("Browser session cookie")
137+
})
138+
139+
dsl.Service("cookieSecurity", func() {
140+
dsl.Method("profile", func() {
141+
dsl.Security(browserSessionCookie)
142+
dsl.Payload(func() {
143+
dsl.APIKey("browser_session_cookie", "browser_session", dsl.String, func() {
144+
dsl.Description("Opaque browser session cookie")
145+
})
146+
dsl.Required("browser_session")
147+
})
148+
dsl.Result(dsl.Empty)
149+
dsl.HTTP(func() {
150+
dsl.GET("/auth/profile")
151+
dsl.Cookie("browser_session:__Host-ak_session")
152+
dsl.Response(dsl.StatusOK)
153+
})
154+
})
155+
})
156+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package internal
2+
3+
import (
4+
"strings"
5+
6+
"goa.design/goa/v3/expr"
7+
)
8+
9+
// IsSecurityParameter returns true if the given HTTP transport element is used
10+
// by one of the endpoint security schemes and should therefore not be emitted
11+
// again as a regular OpenAPI parameter.
12+
func IsSecurityParameter(endpoint *expr.HTTPEndpointExpr, in, name string) bool {
13+
if endpoint == nil {
14+
return false
15+
}
16+
for _, req := range endpoint.Requirements {
17+
for _, scheme := range req.Schemes {
18+
if scheme.In != in {
19+
continue
20+
}
21+
if in == "header" {
22+
if strings.EqualFold(scheme.Name, name) {
23+
return true
24+
}
25+
continue
26+
}
27+
if scheme.Name == name {
28+
return true
29+
}
30+
}
31+
}
32+
return false
33+
}

http/codegen/openapi/v2/builder.go

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"goa.design/goa/v3/codegen"
1313
"goa.design/goa/v3/expr"
1414
"goa.design/goa/v3/http/codegen/openapi"
15+
openapiinternal "goa.design/goa/v3/http/codegen/openapi/internal"
1516
)
1617

1718
// NewV2 returns the OpenAPI v2 specification for the given API.
@@ -35,7 +36,7 @@ func NewV2(root *expr.RootExpr, h *expr.HostExpr) (*V2, error) {
3536
if hasAbsoluteRoutes(root) {
3637
basePath = ""
3738
}
38-
params := paramsFromExpr(root.API.HTTP.Params, basePath)
39+
params := paramsFromExpr(nil, root.API.HTTP.Params, basePath)
3940
var paramMap map[string]*Parameter
4041
if len(params) > 0 {
4142
paramMap = make(map[string]*Parameter, len(params))
@@ -269,7 +270,7 @@ func summaryFromMeta(name string, meta expr.MetaExpr) string {
269270
return name
270271
}
271272

272-
func paramsFromExpr(params *expr.MappedAttributeExpr, path string) []*Parameter {
273+
func paramsFromExpr(endpoint *expr.HTTPEndpointExpr, params *expr.MappedAttributeExpr, path string) []*Parameter {
273274
if params == nil {
274275
return nil
275276
}
@@ -283,6 +284,9 @@ func paramsFromExpr(params *expr.MappedAttributeExpr, path string) []*Parameter
283284
in = "path"
284285
required = true
285286
}
287+
if endpoint != nil && in != "path" && openapiinternal.IsSecurityParameter(endpoint, in, pn) {
288+
return nil
289+
}
286290
param := paramFor(at, pn, in, required)
287291
res = append(res, param)
288292
return nil
@@ -294,24 +298,14 @@ func paramsFromHeaders(endpoint *expr.HTTPEndpointExpr) []*Parameter {
294298
var params []*Parameter
295299

296300
expr.WalkMappedAttr(endpoint.Headers, func(name, elem string, att *expr.AttributeExpr) error { // nolint: errcheck
301+
if openapiinternal.IsSecurityParameter(endpoint, "header", elem) {
302+
return nil
303+
}
297304
required := endpoint.Headers.IsRequiredNoDefault(name)
298305
params = append(params, paramFor(att, elem, "header", required))
299306
return nil
300307
})
301308

302-
// Add basic auth to headers
303-
if att := expr.TaggedAttribute(endpoint.MethodExpr.Payload, "security:username"); att != "" {
304-
// Basic Auth is always encoded in the Authorization header
305-
// https://golang.org/pkg/net/http/#Request.SetBasicAuth
306-
params = append(params, &Parameter{
307-
In: "header",
308-
Name: "Authorization",
309-
Required: endpoint.MethodExpr.Payload.IsRequired(att),
310-
Description: "Basic Auth security using Basic scheme (https://tools.ietf.org/html/rfc7617)",
311-
Type: "string",
312-
})
313-
}
314-
315309
return params
316310
}
317311

@@ -502,7 +496,7 @@ func buildPathFromExpr(s *V2, root *expr.RootExpr, h *expr.HostExpr, route *expr
502496
// Remove any wildcards that is defined in path as a workaround to
503497
// https://github.com/OAI/OpenAPI-Specification/issues/291
504498
key = expr.HTTPWildcardRegex.ReplaceAllString(key, "/{$1}")
505-
params := paramsFromExpr(endpoint.Params, key)
499+
params := paramsFromExpr(endpoint, endpoint.Params, key)
506500
params = append(params, paramsFromHeaders(endpoint)...)
507501
var produces []string
508502

http/codegen/openapi/v2/testdata/TestSections/security_file0.golden

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,6 @@
1414
"get": {
1515
"description": "\n**Required security scopes for basic**:\n * `api:read`\n\n**Required security scopes for jwt**:\n * `api:read`\n\n**Required security scopes for api_key**:\n * `api:read`",
1616
"operationId": "testService#testEndpointA",
17-
"parameters": [
18-
{
19-
"in": "query",
20-
"name": "k",
21-
"required": true,
22-
"type": "string"
23-
},
24-
{
25-
"in": "header",
26-
"name": "Token",
27-
"required": true,
28-
"type": "string"
29-
},
30-
{
31-
"in": "header",
32-
"name": "X-Authorization",
33-
"required": true,
34-
"type": "string"
35-
},
36-
{
37-
"description": "Basic Auth security using Basic scheme (https://tools.ietf.org/html/rfc7617)",
38-
"in": "header",
39-
"name": "Authorization",
40-
"required": true,
41-
"type": "string"
42-
}
43-
],
4417
"responses": {
4518
"204": {
4619
"description": "No Content response."
@@ -66,20 +39,6 @@
6639
},
6740
"post": {
6841
"operationId": "testService#testEndpointB",
69-
"parameters": [
70-
{
71-
"in": "query",
72-
"name": "auth",
73-
"required": true,
74-
"type": "string"
75-
},
76-
{
77-
"in": "header",
78-
"name": "Authorization",
79-
"required": true,
80-
"type": "string"
81-
}
82-
],
8342
"responses": {
8443
"204": {
8544
"description": "No Content response."

0 commit comments

Comments
 (0)