Skip to content

Commit 36883b3

Browse files
authored
codegen: carry mixed gRPC streaming requests in typed envelopes (#3918)
Stop rewriting ordinary payload into metadata for methods that also stream payload items. Generate a typed initial-payload/stream-item envelope and update the gRPC tests and docs around the new transport contract.
1 parent 3bc54e5 commit 36883b3

22 files changed

Lines changed: 419 additions & 85 deletions

dsl/grpc.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ func Message(fn func()) {
246246
// request metadata unless specified explicitly in request message using
247247
// Message function. All other attributes in method payload are added to the
248248
// request message unless specified explicitly using Metadata (in which case
249-
// will be added to the metadata).
249+
// will be added to the metadata). For methods that also define
250+
// StreamingPayload, the ordinary request message is carried as the initial
251+
// typed stream frame rather than being rewritten into metadata.
250252
//
251253
// Metadata takes one argument of function type which lists the attributes
252254
// that must be set in the request metadata instead of the message.

dsl/payload.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ func Payload(val any, args ...any) {
9191
// StreamingPayload requires a transport that supports client-to-server streaming
9292
// such as gRPC or WebSockets. When using HTTP or JSON-RPC transports, methods
9393
// with StreamingPayload must use WebSockets (via GET endpoints).
94+
// For gRPC methods that define both Payload and StreamingPayload, the ordinary
95+
// method payload is sent once as the initial typed stream frame and the
96+
// StreamingPayload values are sent as subsequent stream item frames.
9497
//
9598
// Examples:
9699
//

expr/grpc_endpoint.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,6 @@ func (e *GRPCEndpointExpr) Finalize() {
327327
}
328328
}
329329

330-
// If endpoint defines streaming payload, then add the attributes in method
331-
// payload type to request metadata.
332-
if e.MethodExpr.StreamingPayload.Type != Empty {
333-
for _, nat := range *pobj {
334-
addToMetadata(nat.Name, "")
335-
}
336-
}
337-
338330
// msgObj contains only the attributes in the method payload that must
339331
// be added to the request message type after removing attributes
340332
// specified in the request metadata.
@@ -387,14 +379,7 @@ func (e *GRPCEndpointExpr) Finalize() {
387379
}
388380
} else {
389381
// method payload is not an object type.
390-
if e.MethodExpr.StreamingPayload.Type != Empty {
391-
// endpoint defines streaming payload. So add the method payload to
392-
// request metadata under "goa-payload" field
393-
e.Metadata.Type.(*Object).Set("goa_payload", e.MethodExpr.Payload)
394-
e.Metadata.Validation.AddRequired("goa_payload")
395-
} else {
396-
initAttrFromDesign(e.Request, e.MethodExpr.Payload)
397-
}
382+
initAttrFromDesign(e.Request, e.MethodExpr.Payload)
398383
}
399384

400385
// Finalize streaming payload type if defined

expr/grpc_endpoint_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"errors"
55
"testing"
66

7+
"github.com/stretchr/testify/require"
8+
79
"goa.design/goa/v3/eval"
810
"goa.design/goa/v3/expr"
911
"goa.design/goa/v3/expr/testdata"
@@ -84,3 +86,17 @@ service "Service" method "MethodUnion": union type choice has map elements, not
8486
})
8587
}
8688
}
89+
90+
func TestGRPCEndpointStreamingPayloadKeepsInitialRequest(t *testing.T) {
91+
root := expr.RunDSL(t, testdata.GRPCEndpointWithStreamingPayloadInitialRequest)
92+
grpcSvc := root.API.GRPC.Service("Service")
93+
require.NotNil(t, grpcSvc)
94+
require.Len(t, grpcSvc.GRPCEndpoints, 1)
95+
96+
endpoint := grpcSvc.GRPCEndpoints[0]
97+
req := expr.AsObject(endpoint.Request.Type)
98+
require.NotNil(t, req)
99+
require.NotNil(t, req.Attribute("repository_id"))
100+
require.NotNil(t, req.Attribute("version_ref"))
101+
require.True(t, endpoint.Metadata.IsEmpty())
102+
}

expr/testdata/endpoint_dsls.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,3 +751,28 @@ var GRPCEndpointWithUnionContainingAny = func() {
751751
})
752752
})
753753
}
754+
755+
var GRPCEndpointWithStreamingPayloadInitialRequest = func() {
756+
var VersionRef = Type("VersionRef", func() {
757+
OneOf("ref_type", func() {
758+
Field(1, "version_id", String)
759+
Field(2, "ref_name", String)
760+
})
761+
Required("ref_type")
762+
})
763+
var UploadChunk = Type("UploadChunk", func() {
764+
Field(1, "chunk", Bytes)
765+
Required("chunk")
766+
})
767+
Service("Service", func() {
768+
Method("Method", func() {
769+
Payload(func() {
770+
Field(1, "repository_id", String)
771+
Field(2, "version_ref", VersionRef)
772+
Required("repository_id", "version_ref")
773+
})
774+
StreamingPayload(UploadChunk)
775+
GRPC(func() {})
776+
})
777+
})
778+
}

grpc/codegen/server.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ func serverFile(genpkg string, svc *expr.GRPCServiceExpr, services *ServicesData
4646
{Path: path.Join(genpkg, svcName, "views"), Name: data.Service.ViewsPkg},
4747
{Path: path.Join(genpkg, "grpc", svcName, pbPkgName), Name: data.PkgName},
4848
}
49+
for _, e := range data.Endpoints {
50+
if e.Request.StreamEnvelope != nil {
51+
imports = append(imports, &codegen.ImportSpec{Path: "io"})
52+
break
53+
}
54+
}
4955
sections = []*codegen.SectionTemplate{
5056
codegen.Header(svc.Name()+" gRPC server", "server", imports),
5157
{

grpc/codegen/service_data.go

Lines changed: 118 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,17 @@ type (
178178
RequestData struct {
179179
// Description is the request description.
180180
Description string
181-
// Message is the gRPC request message.
181+
// Message is the gRPC request message used by the transport. For
182+
// streaming payload methods with an initial payload frame, this is the
183+
// synthesized stream envelope.
182184
Message *service.UserTypeData
185+
// PayloadMessage is the gRPC message that carries the one-shot method
186+
// payload fields before any stream envelope wrapping.
187+
PayloadMessage *service.UserTypeData
188+
// StreamEnvelope describes the synthesized stream envelope when the
189+
// transport must carry both the one-shot payload and streaming payload
190+
// items through the same streamed protobuf message.
191+
StreamEnvelope *StreamEnvelopeData
183192
// Metadata is the request metadata.
184193
Metadata []*MetadataData
185194
// ServerConvert is the request data with constructor function to
@@ -195,6 +204,23 @@ type (
195204
CLIArgs []*InitArgData
196205
}
197206

207+
// StreamEnvelopeData describes a synthesized streamed protobuf envelope.
208+
StreamEnvelopeData struct {
209+
// FieldName is the protobuf oneof field name on the envelope message.
210+
FieldName string
211+
// InitialFieldName is the name of the initial payload branch field.
212+
InitialFieldName string
213+
// InitialWrapperRef is the fully qualified protobuf wrapper type for the
214+
// initial payload branch.
215+
InitialWrapperRef string
216+
// StreamItemFieldName is the name of the streaming payload item branch
217+
// field.
218+
StreamItemFieldName string
219+
// StreamItemWrapperRef is the fully qualified protobuf wrapper type for
220+
// the streaming payload item branch.
221+
StreamItemWrapperRef string
222+
}
223+
198224
// ResponseData describes a gRPC success or error response.
199225
ResponseData struct {
200226
// StatusCode is the return code of the response.
@@ -462,10 +488,26 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
462488
}
463489
seen, imported := make(map[string]struct{}), make(map[string]struct{})
464490
for _, e := range gs.GRPCEndpoints {
491+
hasRequestMessage := !isEmpty(e.Request.Type)
492+
useStreamEnvelope := usesStreamEnvelope(e)
493+
465494
// convert request and response types to protocol buffer message types
466495
e.Request = makeProtoBufMessage(e.Request, protoBufify(e.Name()+"_request", true, true), sd)
467496
if e.MethodExpr.StreamingPayload.Type != expr.Empty {
468-
e.StreamingRequest = makeProtoBufMessage(e.StreamingRequest, protoBufify(e.Name()+"_streaming_request", true, true), sd)
497+
streamMessageName := protoBufify(e.Name()+"_streaming_request", true, true)
498+
if useStreamEnvelope {
499+
streamMessageName = protoBufify(e.Name()+"_stream_item", true, true)
500+
}
501+
e.StreamingRequest = makeProtoBufMessage(e.StreamingRequest, streamMessageName, sd)
502+
}
503+
var requestEnvelope *expr.AttributeExpr
504+
if useStreamEnvelope {
505+
requestEnvelope = makeProtoBufStreamEnvelope(
506+
e.Request,
507+
e.StreamingRequest,
508+
protoBufify(e.Name()+"_streaming_request", true, true),
509+
sd,
510+
)
469511
}
470512
e.Response.Message = makeProtoBufMessage(e.Response.Message, protoBufify(e.Name()+"_response", true, true), sd)
471513
for _, er := range e.GRPCErrors {
@@ -540,6 +582,9 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
540582
ServerConvert: d.buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, true),
541583
ClientConvert: d.buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, false),
542584
}
585+
if hasRequestMessage {
586+
request.PayloadMessage = collect(e.Request)
587+
}
543588
if obj := expr.AsObject(e.Request.Type); (obj != nil && len(*obj) > 0) || expr.IsUnion(e.Request.Type) {
544589
// add the request message as the first argument to the CLI
545590
request.CLIArgs = append(request.CLIArgs, &InitArgData{
@@ -567,9 +612,13 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
567612
DefaultValue: m.DefaultValue,
568613
})
569614
}
570-
if e.StreamingRequest.Type != expr.Empty {
615+
switch {
616+
case requestEnvelope != nil:
617+
request.Message = collect(requestEnvelope)
618+
request.StreamEnvelope = buildStreamEnvelopeData(requestEnvelope, request.Message, sd)
619+
case e.StreamingRequest.Type != expr.Empty:
571620
request.Message = collect(e.StreamingRequest)
572-
} else {
621+
default:
573622
request.Message = collect(e.Request)
574623
}
575624

@@ -872,20 +921,20 @@ func userTypeAttribute(ut expr.UserType) *expr.AttributeExpr {
872921

873922
// buildRequestConvertData builds the convert data for the server and client
874923
// requests.
875-
// - server side - converts generated gRPC request type in *.pb.go and the
876-
// gRPC metadata to method payload type.
877-
// - client side - converts method payload type to generated gRPC request
878-
// type in *.pb.go.
924+
// - server side - converts the one-shot gRPC request message (if any) and
925+
// gRPC metadata to the method payload type.
926+
// - client side - converts the method payload type to the one-shot gRPC
927+
// request message sent before any stream items.
879928
//
880929
// svr param indicates that the convert data is generated for server side.
881930
func (d *ServicesData) buildRequestConvertData(request, payload *expr.AttributeExpr, md []*MetadataData, e *expr.GRPCEndpointExpr, sd *ServiceData, svr bool) *ConvertData {
882-
// Server-side: No need to build convert data if payload is empty or payload
883-
// is not an object type and endpoint streams payload (the payload is
884-
// encoded in metadata under "goa-payload" in this case).
885-
if (svr && (isEmpty(payload.Type) || !expr.IsObject(payload.Type) && e.MethodExpr.IsPayloadStreaming())) ||
886-
// Client-side: No need to build convert data if streaming payload since
887-
// all attributes in method payload is encoded into request metadata.
888-
(!svr && e.MethodExpr.IsPayloadStreaming()) {
931+
if svr && isEmpty(payload.Type) {
932+
return nil
933+
}
934+
if !svr && e.MethodExpr.IsPayloadStreaming() && isEmpty(request.Type) {
935+
return nil
936+
}
937+
if svr && e.MethodExpr.IsPayloadStreaming() && isEmpty(request.Type) && !expr.IsObject(payload.Type) {
889938
return nil
890939
}
891940

@@ -1297,6 +1346,60 @@ func extractMetadata(a *expr.MappedAttributeExpr, service *expr.AttributeExpr, s
12971346
return metadata
12981347
}
12991348

1349+
// usesStreamEnvelope reports whether the transport needs a typed stream
1350+
// envelope to carry both the one-shot method payload and streaming payload
1351+
// items.
1352+
func usesStreamEnvelope(e *expr.GRPCEndpointExpr) bool {
1353+
return e.MethodExpr.IsPayloadStreaming() && !isEmpty(e.Request.Type)
1354+
}
1355+
1356+
// makeProtoBufStreamEnvelope builds the protobuf stream envelope that carries
1357+
// the initial request payload frame and subsequent stream item frames.
1358+
func makeProtoBufStreamEnvelope(request, stream *expr.AttributeExpr, tname string, sd *ServiceData) *expr.AttributeExpr {
1359+
initial := expr.DupAtt(request)
1360+
initial.Meta = initial.Meta.Dup()
1361+
initial.Meta["rpc:tag"] = []string{"1"}
1362+
streamItem := expr.DupAtt(stream)
1363+
streamItem.Meta = streamItem.Meta.Dup()
1364+
streamItem.Meta["rpc:tag"] = []string{"2"}
1365+
envelope := &expr.AttributeExpr{
1366+
Type: &expr.Object{
1367+
&expr.NamedAttributeExpr{
1368+
Name: "body",
1369+
Attribute: &expr.AttributeExpr{
1370+
Type: &expr.Union{
1371+
TypeName: "body",
1372+
Values: []*expr.NamedAttributeExpr{
1373+
{Name: "initial_payload", Attribute: initial},
1374+
{Name: "stream_item", Attribute: streamItem},
1375+
},
1376+
},
1377+
},
1378+
},
1379+
},
1380+
Validation: &expr.ValidationExpr{Required: []string{"body"}},
1381+
}
1382+
return makeProtoBufMessage(envelope, tname, sd)
1383+
}
1384+
1385+
// buildStreamEnvelopeData computes the generated Go names for the protobuf
1386+
// oneof field and wrapper types of the synthesized stream envelope.
1387+
func buildStreamEnvelopeData(envelope *expr.AttributeExpr, message *service.UserTypeData, sd *ServiceData) *StreamEnvelopeData {
1388+
body := envelope.Find("body")
1389+
union := expr.AsUnion(body.Type)
1390+
scope := &protoBufScope{scope: sd.Scope}
1391+
fieldName := scope.Field(body, union.TypeName, true)
1392+
initialFieldName := scope.Field(union.Values[0].Attribute, union.Values[0].Name, true)
1393+
streamItemFieldName := scope.Field(union.Values[1].Attribute, union.Values[1].Name, true)
1394+
return &StreamEnvelopeData{
1395+
FieldName: fieldName,
1396+
InitialFieldName: initialFieldName,
1397+
InitialWrapperRef: fmt.Sprintf("%s.%s_%s", sd.PkgName, message.VarName, initialFieldName),
1398+
StreamItemFieldName: streamItemFieldName,
1399+
StreamItemWrapperRef: fmt.Sprintf("%s.%s_%s", sd.PkgName, message.VarName, streamItemFieldName),
1400+
}
1401+
}
1402+
13001403
func unalias(att *expr.AttributeExpr) *expr.AttributeExpr {
13011404
if ut, ok := att.Type.(expr.UserType); ok {
13021405
if _, ok := ut.Attribute().Type.(expr.Primitive); ok {

grpc/codegen/streaming_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
89

910
"goa.design/goa/v3/codegen"
1011
"goa.design/goa/v3/grpc/codegen/testdata"
@@ -148,3 +149,41 @@ func TestStreaming(t *testing.T) {
148149
})
149150
}
150151
}
152+
153+
func TestStreamingPayloadEnvelopeWithUnionPayload(t *testing.T) {
154+
root := RunGRPCDSL(t, testdata.ClientStreamingRPCWithUnionPayloadDSL)
155+
services := CreateGRPCServices(root)
156+
157+
clientfs := ClientFiles("", services)
158+
require.Len(t, clientfs, 2)
159+
serverfs := ServerFiles("", services)
160+
require.Len(t, serverfs, 2)
161+
protofs := ProtoFiles("", services)
162+
require.Len(t, protofs, 1)
163+
164+
requestEncoder := codegen.SectionsCode(t, clientfs[1].Section("request-encoder"))
165+
assert.Contains(t, requestEncoder, "InitialPayload")
166+
assert.Contains(t, requestEncoder, "MethodClientStreamingRPCWithUnionPayloadStreamingRequest")
167+
168+
clientSend := codegen.SectionsCode(t, clientfs[0].Section("client-stream-send"))
169+
assert.Contains(t, clientSend, "StreamItem")
170+
assert.Contains(t, clientSend, "UploadChunk")
171+
172+
serverInterface := codegen.SectionsCode(t, serverfs[0].Section("server-grpc-interface"))
173+
assert.Contains(t, serverInterface, "message, err := stream.Recv()")
174+
assert.Contains(t, serverInterface, "Decode(ctx, reqpb)")
175+
176+
requestDecoder := codegen.SectionsCode(t, serverfs[1].Section("request-decoder"))
177+
assert.Contains(t, requestDecoder, "InitialPayload")
178+
assert.Contains(t, requestDecoder, "stream_item")
179+
assert.Contains(t, requestDecoder, "NewMethodClientStreamingRPCWithUnionPayloadPayload(message)")
180+
181+
proto := sectionCode(t, protofs[0].SectionTemplates[1:]...)
182+
assert.Contains(t, proto, "message MethodClientStreamingRPCWithUnionPayloadStreamingRequest")
183+
assert.Contains(t, proto, "oneof body")
184+
assert.Contains(t, proto, "MethodClientStreamingRPCWithUnionPayloadRequest initial_payload")
185+
assert.Contains(t, proto, "MethodClientStreamingRPCWithUnionPayloadStreamItem stream_item")
186+
187+
fpath := codegen.CreateTempFile(t, proto)
188+
assert.NoError(t, protoc(defaultProtocCmd, fpath, nil))
189+
}

grpc/codegen/templates/remote_method_builder.go.tpl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,22 @@ func Build{{ .Method.VarName }}Func(grpccli {{ .PkgName }}.{{ .ClientInterface }
44
for _, opt := range cliopts {
55
opts = append(opts, opt)
66
}
7-
if reqpb != nil {
8-
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, reqpb.({{ .Request.ClientConvert.TgtRef }}){{ end }}, opts...)
9-
}
10-
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, &{{ .Request.ClientConvert.TgtName }}{}{{ end }}, opts...)
7+
{{- if .Request.StreamEnvelope }}
8+
stream, err := grpccli.{{ .ClientMethodName }}(ctx, opts...)
9+
if err != nil {
10+
return nil, err
11+
}
12+
if reqpb != nil {
13+
if err := stream.Send(reqpb.({{ .Request.Message.Ref }})); err != nil {
14+
return nil, err
15+
}
16+
}
17+
return stream, nil
18+
{{- else }}
19+
if reqpb != nil {
20+
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, reqpb.({{ .Request.ClientConvert.TgtRef }}){{ end }}, opts...)
21+
}
22+
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, &{{ .Request.ClientConvert.TgtName }}{}{{ end }}, opts...)
23+
{{- end }}
1124
}
1225
}

0 commit comments

Comments
 (0)