Skip to content

Commit 96d1a62

Browse files
committed
Fix client validators for constructor union bodies
1 parent d1d6b27 commit 96d1a62

3 files changed

Lines changed: 113 additions & 3 deletions

File tree

http/codegen/client_cli_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package codegen
33
import (
44
"bytes"
55
"goa.design/goa/v3/codegen/testutil"
6+
"regexp"
67
"strings"
78
"testing"
89

@@ -94,3 +95,32 @@ func TestConstructorUnionClientCLIFiles(t *testing.T) {
9495
})
9596
}
9697
}
98+
99+
func TestConstructorUnionClientCLIPayloadValidatorsExistInClientTypes(t *testing.T) {
100+
root := RunHTTPDSL(t, testdata.ConstructorUnionClientValidatorReferenceHTTPDSL)
101+
services := CreateHTTPServices(root)
102+
103+
cliFiles := ClientCLIFiles("", services)
104+
require.GreaterOrEqual(t, len(cliFiles), 2, "expected parser and payload builder files")
105+
var builder bytes.Buffer
106+
for _, s := range cliFiles[1].SectionTemplates {
107+
require.NoError(t, s.Write(&builder))
108+
}
109+
builderCode := codegen.FormatTestCode(t, builder.String())
110+
111+
typeFiles := ClientTypeFiles("", services)
112+
require.NotEmpty(t, typeFiles, "expected client type files")
113+
var types bytes.Buffer
114+
for _, s := range typeFiles[0].SectionTemplates {
115+
require.NoError(t, s.Write(&types))
116+
}
117+
typesCode := codegen.FormatTestCode(t, types.String())
118+
119+
re := regexp.MustCompile(`Validate([A-Za-z0-9]+RequestBody)\(`)
120+
matches := re.FindAllStringSubmatch(builderCode, -1)
121+
require.NotEmpty(t, matches, "expected constructor-union builder to reference request-body validators")
122+
for _, match := range matches {
123+
name := match[1]
124+
require.Contains(t, typesCode, "func Validate"+name+"(", "missing client validator for %s", name)
125+
}
126+
}

http/codegen/service_data.go

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,9 +2116,11 @@ func (sds *ServicesData) buildRequestBodyType(body, att *expr.AttributeExpr, e *
21162116
def = goTypeDef(sd.Scope, ut.Attribute(), svr, !svr)
21172117
desc = fmt.Sprintf("%s is the type of the %q service %q endpoint HTTP request body.",
21182118
varname, svc.Name, e.Name())
2119-
if svr {
2120-
// generate validation code for unmarshaled type (server-side).
2121-
validateDef = codegen.ValidationCode(ut.Attribute(), ut, httpctx, true, expr.IsAlias(ut), false, "body")
2119+
// Generate validation code for unmarshaled request bodies on the server,
2120+
// and for client request bodies only when constructor unions require the
2121+
// corresponding validator helper during CLI payload validation.
2122+
if svr || containsUnionType(body.Type) {
2123+
validateDef = codegen.ValidationCode(body, ut, httpctx, true, expr.IsAlias(ut), false, "body")
21222124
if validateDef != "" {
21232125
validateRef = fmt.Sprintf("err = Validate%s(&body)", varname)
21242126
}
@@ -2834,6 +2836,9 @@ func (sds *ServicesData) attributeTypeData(ut expr.UserType, req, ptr, server bo
28342836
// requests server-side and CLI.
28352837
// Alias types are validated inline in the parent type
28362838
validate = codegen.ValidationCode(ut.Attribute(), ut, hctx, true, expr.IsAlias(ut), false, "body")
2839+
if validate == "" && req && !server && needsClientRequestBodyValidatorStub(ut) {
2840+
validate = "// no validations"
2841+
}
28372842
}
28382843
if validate != "" {
28392844
validateRef = fmt.Sprintf("err = Validate%s(v)", name)
@@ -2850,6 +2855,44 @@ func (sds *ServicesData) attributeTypeData(ut expr.UserType, req, ptr, server bo
28502855
}
28512856
}
28522857

2858+
func needsClientRequestBodyValidatorStub(ut expr.UserType) bool {
2859+
if ut == nil || ut.Attribute() == nil || ut.Attribute().Meta == nil {
2860+
return false
2861+
}
2862+
_, ok := ut.Attribute().Meta.Last("oneof:type:tag")
2863+
return ok
2864+
}
2865+
2866+
func containsUnionType(dt expr.DataType) bool {
2867+
return containsUnionTypeRecursive(dt, make(map[string]struct{}))
2868+
}
2869+
2870+
func containsUnionTypeRecursive(dt expr.DataType, seen map[string]struct{}) bool {
2871+
switch actual := dt.(type) {
2872+
case nil:
2873+
return false
2874+
case *expr.Union:
2875+
return true
2876+
case expr.UserType:
2877+
if _, ok := seen[actual.ID()]; ok {
2878+
return false
2879+
}
2880+
seen[actual.ID()] = struct{}{}
2881+
return containsUnionTypeRecursive(actual.Attribute().Type, seen)
2882+
case *expr.Object:
2883+
for _, nat := range *actual {
2884+
if containsUnionTypeRecursive(nat.Attribute.Type, seen) {
2885+
return true
2886+
}
2887+
}
2888+
case *expr.Array:
2889+
return containsUnionTypeRecursive(actual.ElemType.Type, seen)
2890+
case *expr.Map:
2891+
return containsUnionTypeRecursive(actual.KeyType.Type, seen) || containsUnionTypeRecursive(actual.ElemType.Type, seen)
2892+
}
2893+
return false
2894+
}
2895+
28532896
// httpContext returns a context for attributes of types used to marshal and
28542897
// unmarshal HTTP requests and responses.
28552898
//

http/codegen/testdata/constructor_union_dsls.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,43 @@ var NestedTopLevelConstructorUnionCustomKeysHTTPDSL = func() {
472472
})
473473
}
474474

475+
var ConstructorUnionClientValidatorReferenceHTTPDSL = func() {
476+
var All = Type("ClientValidatorReferenceAll", func() {
477+
Meta("name:original", "All")
478+
Meta("oneof:type:tag", "all")
479+
})
480+
var Single = Type("ClientValidatorReferenceSingle", func() {
481+
Meta("name:original", "Single")
482+
Meta("oneof:type:tag", "single")
483+
Attribute("task_id", String)
484+
Required("task_id")
485+
})
486+
var Batch = Type("ClientValidatorReferenceBatch", func() {
487+
Meta("name:original", "Batch")
488+
Meta("oneof:type:tag", "batch")
489+
Attribute("task_ids", ArrayOf(String), func() {
490+
MinLength(1)
491+
})
492+
Required("task_ids")
493+
})
494+
var PayloadType = Type("ClientValidatorReferencePayload", func() {
495+
Attribute("value", OneOf(All, Single, Batch), func() {
496+
Meta("oneof:typename", "ClientValidatorReferenceMode")
497+
Meta("oneof:type:field", "mode")
498+
Meta("oneof:value:field", "value")
499+
})
500+
Required("value")
501+
})
502+
Service("ClientValidatorReference", func() {
503+
Method("Show", func() {
504+
Payload(PayloadType)
505+
HTTP(func() {
506+
POST("/")
507+
})
508+
})
509+
})
510+
}
511+
475512
var RecursiveConstructorUnionHTTPDSL = func() {
476513
var Leaf = Type("Leaf", func() {
477514
Attribute("value", String)

0 commit comments

Comments
 (0)