diff --git a/dsl/grpc.go b/dsl/grpc.go index b20f4ee940..e494b4458d 100644 --- a/dsl/grpc.go +++ b/dsl/grpc.go @@ -246,7 +246,9 @@ func Message(fn func()) { // request metadata unless specified explicitly in request message using // Message function. All other attributes in method payload are added to the // request message unless specified explicitly using Metadata (in which case -// will be added to the metadata). +// will be added to the metadata). For methods that also define +// StreamingPayload, the ordinary request message is carried as the initial +// typed stream frame rather than being rewritten into metadata. // // Metadata takes one argument of function type which lists the attributes // that must be set in the request metadata instead of the message. diff --git a/dsl/payload.go b/dsl/payload.go index 46e0e9b1b5..2e96fd8a37 100644 --- a/dsl/payload.go +++ b/dsl/payload.go @@ -91,6 +91,9 @@ func Payload(val any, args ...any) { // StreamingPayload requires a transport that supports client-to-server streaming // such as gRPC or WebSockets. When using HTTP or JSON-RPC transports, methods // with StreamingPayload must use WebSockets (via GET endpoints). +// For gRPC methods that define both Payload and StreamingPayload, the ordinary +// method payload is sent once as the initial typed stream frame and the +// StreamingPayload values are sent as subsequent stream item frames. // // Examples: // diff --git a/expr/grpc_endpoint.go b/expr/grpc_endpoint.go index bd26f70e9b..a808fdb9a3 100644 --- a/expr/grpc_endpoint.go +++ b/expr/grpc_endpoint.go @@ -327,14 +327,6 @@ func (e *GRPCEndpointExpr) Finalize() { } } - // If endpoint defines streaming payload, then add the attributes in method - // payload type to request metadata. - if e.MethodExpr.StreamingPayload.Type != Empty { - for _, nat := range *pobj { - addToMetadata(nat.Name, "") - } - } - // msgObj contains only the attributes in the method payload that must // be added to the request message type after removing attributes // specified in the request metadata. @@ -387,14 +379,7 @@ func (e *GRPCEndpointExpr) Finalize() { } } else { // method payload is not an object type. - if e.MethodExpr.StreamingPayload.Type != Empty { - // endpoint defines streaming payload. So add the method payload to - // request metadata under "goa-payload" field - e.Metadata.Type.(*Object).Set("goa_payload", e.MethodExpr.Payload) - e.Metadata.Validation.AddRequired("goa_payload") - } else { - initAttrFromDesign(e.Request, e.MethodExpr.Payload) - } + initAttrFromDesign(e.Request, e.MethodExpr.Payload) } // Finalize streaming payload type if defined diff --git a/expr/grpc_endpoint_test.go b/expr/grpc_endpoint_test.go index 5e965ce1c9..8b6ffb08b4 100644 --- a/expr/grpc_endpoint_test.go +++ b/expr/grpc_endpoint_test.go @@ -4,6 +4,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/require" + "goa.design/goa/v3/eval" "goa.design/goa/v3/expr" "goa.design/goa/v3/expr/testdata" @@ -84,3 +86,17 @@ service "Service" method "MethodUnion": union type choice has map elements, not }) } } + +func TestGRPCEndpointStreamingPayloadKeepsInitialRequest(t *testing.T) { + root := expr.RunDSL(t, testdata.GRPCEndpointWithStreamingPayloadInitialRequest) + grpcSvc := root.API.GRPC.Service("Service") + require.NotNil(t, grpcSvc) + require.Len(t, grpcSvc.GRPCEndpoints, 1) + + endpoint := grpcSvc.GRPCEndpoints[0] + req := expr.AsObject(endpoint.Request.Type) + require.NotNil(t, req) + require.NotNil(t, req.Attribute("repository_id")) + require.NotNil(t, req.Attribute("version_ref")) + require.True(t, endpoint.Metadata.IsEmpty()) +} diff --git a/expr/testdata/endpoint_dsls.go b/expr/testdata/endpoint_dsls.go index 31fdab168a..74e6562ce3 100644 --- a/expr/testdata/endpoint_dsls.go +++ b/expr/testdata/endpoint_dsls.go @@ -751,3 +751,28 @@ var GRPCEndpointWithUnionContainingAny = func() { }) }) } + +var GRPCEndpointWithStreamingPayloadInitialRequest = func() { + var VersionRef = Type("VersionRef", func() { + OneOf("ref_type", func() { + Field(1, "version_id", String) + Field(2, "ref_name", String) + }) + Required("ref_type") + }) + var UploadChunk = Type("UploadChunk", func() { + Field(1, "chunk", Bytes) + Required("chunk") + }) + Service("Service", func() { + Method("Method", func() { + Payload(func() { + Field(1, "repository_id", String) + Field(2, "version_ref", VersionRef) + Required("repository_id", "version_ref") + }) + StreamingPayload(UploadChunk) + GRPC(func() {}) + }) + }) +} diff --git a/grpc/codegen/server.go b/grpc/codegen/server.go index e3886445fd..d7f73592da 100644 --- a/grpc/codegen/server.go +++ b/grpc/codegen/server.go @@ -46,6 +46,12 @@ func serverFile(genpkg string, svc *expr.GRPCServiceExpr, services *ServicesData {Path: path.Join(genpkg, svcName, "views"), Name: data.Service.ViewsPkg}, {Path: path.Join(genpkg, "grpc", svcName, pbPkgName), Name: data.PkgName}, } + for _, e := range data.Endpoints { + if e.Request.StreamEnvelope != nil { + imports = append(imports, &codegen.ImportSpec{Path: "io"}) + break + } + } sections = []*codegen.SectionTemplate{ codegen.Header(svc.Name()+" gRPC server", "server", imports), { diff --git a/grpc/codegen/service_data.go b/grpc/codegen/service_data.go index bfb7837984..d7492281b4 100644 --- a/grpc/codegen/service_data.go +++ b/grpc/codegen/service_data.go @@ -178,8 +178,17 @@ type ( RequestData struct { // Description is the request description. Description string - // Message is the gRPC request message. + // Message is the gRPC request message used by the transport. For + // streaming payload methods with an initial payload frame, this is the + // synthesized stream envelope. Message *service.UserTypeData + // PayloadMessage is the gRPC message that carries the one-shot method + // payload fields before any stream envelope wrapping. + PayloadMessage *service.UserTypeData + // StreamEnvelope describes the synthesized stream envelope when the + // transport must carry both the one-shot payload and streaming payload + // items through the same streamed protobuf message. + StreamEnvelope *StreamEnvelopeData // Metadata is the request metadata. Metadata []*MetadataData // ServerConvert is the request data with constructor function to @@ -195,6 +204,23 @@ type ( CLIArgs []*InitArgData } + // StreamEnvelopeData describes a synthesized streamed protobuf envelope. + StreamEnvelopeData struct { + // FieldName is the protobuf oneof field name on the envelope message. + FieldName string + // InitialFieldName is the name of the initial payload branch field. + InitialFieldName string + // InitialWrapperRef is the fully qualified protobuf wrapper type for the + // initial payload branch. + InitialWrapperRef string + // StreamItemFieldName is the name of the streaming payload item branch + // field. + StreamItemFieldName string + // StreamItemWrapperRef is the fully qualified protobuf wrapper type for + // the streaming payload item branch. + StreamItemWrapperRef string + } + // ResponseData describes a gRPC success or error response. ResponseData struct { // StatusCode is the return code of the response. @@ -462,10 +488,26 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData { } seen, imported := make(map[string]struct{}), make(map[string]struct{}) for _, e := range gs.GRPCEndpoints { + hasRequestMessage := !isEmpty(e.Request.Type) + useStreamEnvelope := usesStreamEnvelope(e) + // convert request and response types to protocol buffer message types e.Request = makeProtoBufMessage(e.Request, protoBufify(e.Name()+"_request", true, true), sd) if e.MethodExpr.StreamingPayload.Type != expr.Empty { - e.StreamingRequest = makeProtoBufMessage(e.StreamingRequest, protoBufify(e.Name()+"_streaming_request", true, true), sd) + streamMessageName := protoBufify(e.Name()+"_streaming_request", true, true) + if useStreamEnvelope { + streamMessageName = protoBufify(e.Name()+"_stream_item", true, true) + } + e.StreamingRequest = makeProtoBufMessage(e.StreamingRequest, streamMessageName, sd) + } + var requestEnvelope *expr.AttributeExpr + if useStreamEnvelope { + requestEnvelope = makeProtoBufStreamEnvelope( + e.Request, + e.StreamingRequest, + protoBufify(e.Name()+"_streaming_request", true, true), + sd, + ) } e.Response.Message = makeProtoBufMessage(e.Response.Message, protoBufify(e.Name()+"_response", true, true), sd) for _, er := range e.GRPCErrors { @@ -540,6 +582,9 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData { ServerConvert: d.buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, true), ClientConvert: d.buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, false), } + if hasRequestMessage { + request.PayloadMessage = collect(e.Request) + } if obj := expr.AsObject(e.Request.Type); (obj != nil && len(*obj) > 0) || expr.IsUnion(e.Request.Type) { // add the request message as the first argument to the CLI request.CLIArgs = append(request.CLIArgs, &InitArgData{ @@ -567,9 +612,13 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData { DefaultValue: m.DefaultValue, }) } - if e.StreamingRequest.Type != expr.Empty { + switch { + case requestEnvelope != nil: + request.Message = collect(requestEnvelope) + request.StreamEnvelope = buildStreamEnvelopeData(requestEnvelope, request.Message, sd) + case e.StreamingRequest.Type != expr.Empty: request.Message = collect(e.StreamingRequest) - } else { + default: request.Message = collect(e.Request) } @@ -872,20 +921,20 @@ func userTypeAttribute(ut expr.UserType) *expr.AttributeExpr { // buildRequestConvertData builds the convert data for the server and client // requests. -// - server side - converts generated gRPC request type in *.pb.go and the -// gRPC metadata to method payload type. -// - client side - converts method payload type to generated gRPC request -// type in *.pb.go. +// - server side - converts the one-shot gRPC request message (if any) and +// gRPC metadata to the method payload type. +// - client side - converts the method payload type to the one-shot gRPC +// request message sent before any stream items. // // svr param indicates that the convert data is generated for server side. func (d *ServicesData) buildRequestConvertData(request, payload *expr.AttributeExpr, md []*MetadataData, e *expr.GRPCEndpointExpr, sd *ServiceData, svr bool) *ConvertData { - // Server-side: No need to build convert data if payload is empty or payload - // is not an object type and endpoint streams payload (the payload is - // encoded in metadata under "goa-payload" in this case). - if (svr && (isEmpty(payload.Type) || !expr.IsObject(payload.Type) && e.MethodExpr.IsPayloadStreaming())) || - // Client-side: No need to build convert data if streaming payload since - // all attributes in method payload is encoded into request metadata. - (!svr && e.MethodExpr.IsPayloadStreaming()) { + if svr && isEmpty(payload.Type) { + return nil + } + if !svr && e.MethodExpr.IsPayloadStreaming() && isEmpty(request.Type) { + return nil + } + if svr && e.MethodExpr.IsPayloadStreaming() && isEmpty(request.Type) && !expr.IsObject(payload.Type) { return nil } @@ -1297,6 +1346,60 @@ func extractMetadata(a *expr.MappedAttributeExpr, service *expr.AttributeExpr, s return metadata } +// usesStreamEnvelope reports whether the transport needs a typed stream +// envelope to carry both the one-shot method payload and streaming payload +// items. +func usesStreamEnvelope(e *expr.GRPCEndpointExpr) bool { + return e.MethodExpr.IsPayloadStreaming() && !isEmpty(e.Request.Type) +} + +// makeProtoBufStreamEnvelope builds the protobuf stream envelope that carries +// the initial request payload frame and subsequent stream item frames. +func makeProtoBufStreamEnvelope(request, stream *expr.AttributeExpr, tname string, sd *ServiceData) *expr.AttributeExpr { + initial := expr.DupAtt(request) + initial.Meta = initial.Meta.Dup() + initial.Meta["rpc:tag"] = []string{"1"} + streamItem := expr.DupAtt(stream) + streamItem.Meta = streamItem.Meta.Dup() + streamItem.Meta["rpc:tag"] = []string{"2"} + envelope := &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{ + Name: "body", + Attribute: &expr.AttributeExpr{ + Type: &expr.Union{ + TypeName: "body", + Values: []*expr.NamedAttributeExpr{ + {Name: "initial_payload", Attribute: initial}, + {Name: "stream_item", Attribute: streamItem}, + }, + }, + }, + }, + }, + Validation: &expr.ValidationExpr{Required: []string{"body"}}, + } + return makeProtoBufMessage(envelope, tname, sd) +} + +// buildStreamEnvelopeData computes the generated Go names for the protobuf +// oneof field and wrapper types of the synthesized stream envelope. +func buildStreamEnvelopeData(envelope *expr.AttributeExpr, message *service.UserTypeData, sd *ServiceData) *StreamEnvelopeData { + body := envelope.Find("body") + union := expr.AsUnion(body.Type) + scope := &protoBufScope{scope: sd.Scope} + fieldName := scope.Field(body, union.TypeName, true) + initialFieldName := scope.Field(union.Values[0].Attribute, union.Values[0].Name, true) + streamItemFieldName := scope.Field(union.Values[1].Attribute, union.Values[1].Name, true) + return &StreamEnvelopeData{ + FieldName: fieldName, + InitialFieldName: initialFieldName, + InitialWrapperRef: fmt.Sprintf("%s.%s_%s", sd.PkgName, message.VarName, initialFieldName), + StreamItemFieldName: streamItemFieldName, + StreamItemWrapperRef: fmt.Sprintf("%s.%s_%s", sd.PkgName, message.VarName, streamItemFieldName), + } +} + func unalias(att *expr.AttributeExpr) *expr.AttributeExpr { if ut, ok := att.Type.(expr.UserType); ok { if _, ok := ut.Attribute().Type.(expr.Primitive); ok { diff --git a/grpc/codegen/streaming_test.go b/grpc/codegen/streaming_test.go index 91b6b8dacf..1bfc3ea67b 100644 --- a/grpc/codegen/streaming_test.go +++ b/grpc/codegen/streaming_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "goa.design/goa/v3/codegen" "goa.design/goa/v3/grpc/codegen/testdata" @@ -148,3 +149,41 @@ func TestStreaming(t *testing.T) { }) } } + +func TestStreamingPayloadEnvelopeWithUnionPayload(t *testing.T) { + root := RunGRPCDSL(t, testdata.ClientStreamingRPCWithUnionPayloadDSL) + services := CreateGRPCServices(root) + + clientfs := ClientFiles("", services) + require.Len(t, clientfs, 2) + serverfs := ServerFiles("", services) + require.Len(t, serverfs, 2) + protofs := ProtoFiles("", services) + require.Len(t, protofs, 1) + + requestEncoder := codegen.SectionsCode(t, clientfs[1].Section("request-encoder")) + assert.Contains(t, requestEncoder, "InitialPayload") + assert.Contains(t, requestEncoder, "MethodClientStreamingRPCWithUnionPayloadStreamingRequest") + + clientSend := codegen.SectionsCode(t, clientfs[0].Section("client-stream-send")) + assert.Contains(t, clientSend, "StreamItem") + assert.Contains(t, clientSend, "UploadChunk") + + serverInterface := codegen.SectionsCode(t, serverfs[0].Section("server-grpc-interface")) + assert.Contains(t, serverInterface, "message, err := stream.Recv()") + assert.Contains(t, serverInterface, "Decode(ctx, reqpb)") + + requestDecoder := codegen.SectionsCode(t, serverfs[1].Section("request-decoder")) + assert.Contains(t, requestDecoder, "InitialPayload") + assert.Contains(t, requestDecoder, "stream_item") + assert.Contains(t, requestDecoder, "NewMethodClientStreamingRPCWithUnionPayloadPayload(message)") + + proto := sectionCode(t, protofs[0].SectionTemplates[1:]...) + assert.Contains(t, proto, "message MethodClientStreamingRPCWithUnionPayloadStreamingRequest") + assert.Contains(t, proto, "oneof body") + assert.Contains(t, proto, "MethodClientStreamingRPCWithUnionPayloadRequest initial_payload") + assert.Contains(t, proto, "MethodClientStreamingRPCWithUnionPayloadStreamItem stream_item") + + fpath := codegen.CreateTempFile(t, proto) + assert.NoError(t, protoc(defaultProtocCmd, fpath, nil)) +} diff --git a/grpc/codegen/templates/remote_method_builder.go.tpl b/grpc/codegen/templates/remote_method_builder.go.tpl index 7eb479018b..c96d079b03 100644 --- a/grpc/codegen/templates/remote_method_builder.go.tpl +++ b/grpc/codegen/templates/remote_method_builder.go.tpl @@ -4,9 +4,22 @@ func Build{{ .Method.VarName }}Func(grpccli {{ .PkgName }}.{{ .ClientInterface } for _, opt := range cliopts { opts = append(opts, opt) } - if reqpb != nil { - return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, reqpb.({{ .Request.ClientConvert.TgtRef }}){{ end }}, opts...) - } - return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, &{{ .Request.ClientConvert.TgtName }}{}{{ end }}, opts...) + {{- if .Request.StreamEnvelope }} + stream, err := grpccli.{{ .ClientMethodName }}(ctx, opts...) + if err != nil { + return nil, err + } + if reqpb != nil { + if err := stream.Send(reqpb.({{ .Request.Message.Ref }})); err != nil { + return nil, err + } + } + return stream, nil + {{- else }} + if reqpb != nil { + return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, reqpb.({{ .Request.ClientConvert.TgtRef }}){{ end }}, opts...) + } + return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, &{{ .Request.ClientConvert.TgtName }}{}{{ end }}, opts...) + {{- end }} } } diff --git a/grpc/codegen/templates/request_decoder.go.tpl b/grpc/codegen/templates/request_decoder.go.tpl index ad38f5105d..b8392a9fea 100644 --- a/grpc/codegen/templates/request_decoder.go.tpl +++ b/grpc/codegen/templates/request_decoder.go.tpl @@ -67,15 +67,36 @@ func Decode{{ .Method.VarName }}Request(ctx context.Context, v any, md metadata. return nil, err } {{- end }} -{{- if and (not .Method.StreamingPayload) (not (isEmpty .Request.Message.Type)) }} +{{- if .Request.PayloadMessage }} var ( - message {{ .Request.ServerConvert.SrcRef }} + message {{ .Request.PayloadMessage.Ref }} ok bool ) { - if message, ok = v.({{ .Request.ServerConvert.SrcRef }}); !ok { - return nil, goagrpc.ErrInvalidType("{{ .ServiceName }}", "{{ .Method.Name }}", "{{ .Request.Message.Ref }}", v) + {{- if .Request.StreamEnvelope }} + if v == nil { + return nil, goa.MissingFieldError("initial_payload", "stream") + } + var envelope {{ .Request.Message.Ref }} + if envelope, ok = v.({{ .Request.Message.Ref }}); !ok { + return nil, goagrpc.ErrInvalidType("{{ .ServiceName }}", "{{ .Method.Name }}", "{{ .Request.Message.Ref }}", v) + } + switch body := envelope.{{ .Request.StreamEnvelope.FieldName }}.(type) { + case *{{ .Request.StreamEnvelope.InitialWrapperRef }}: + if body.{{ .Request.StreamEnvelope.InitialFieldName }} == nil { + return nil, goa.MissingFieldError("initial_payload", "stream") + } + message = body.{{ .Request.StreamEnvelope.InitialFieldName }} + case *{{ .Request.StreamEnvelope.StreamItemWrapperRef }}: + return nil, goa.InvalidFieldTypeError("body", "stream_item", "initial_payload") + default: + return nil, goa.MissingFieldError("initial_payload", "stream") + } + {{- else }} + if message, ok = v.({{ .Request.PayloadMessage.Ref }}); !ok { + return nil, goagrpc.ErrInvalidType("{{ .ServiceName }}", "{{ .Method.Name }}", "{{ .Request.PayloadMessage.Ref }}", v) } + {{- end }} {{- if .Request.ServerConvert.Validation }} if err {{ if .Request.Metadata }}={{ else }}:={{ end }} {{ .Request.ServerConvert.Validation.Name }}(message); err != nil { return nil, err diff --git a/grpc/codegen/templates/request_encoder.go.tpl b/grpc/codegen/templates/request_encoder.go.tpl index bb10d45946..53481c0c4b 100644 --- a/grpc/codegen/templates/request_encoder.go.tpl +++ b/grpc/codegen/templates/request_encoder.go.tpl @@ -39,7 +39,16 @@ func Encode{{ .Method.VarName }}Request(ctx context.Context, v any, md *metadata {{- end }} {{- end }} {{- if .Request.ClientConvert }} + {{- if .Request.StreamEnvelope }} + message := {{ .Request.ClientConvert.Init.Name }}({{ range .Request.ClientConvert.Init.Args }}{{ .Name }}, {{ end }}) + return &{{ .PkgName }}.{{ .Request.Message.VarName }}{ + {{ .Request.StreamEnvelope.FieldName }}: &{{ .Request.StreamEnvelope.InitialWrapperRef }}{ + {{ .Request.StreamEnvelope.InitialFieldName }}: message, + }, + }, nil + {{- else }} return {{ .Request.ClientConvert.Init.Name }}({{ range .Request.ClientConvert.Init.Args }}{{ .Name }}, {{ end }}), nil + {{- end }} {{- else }} return nil, nil {{- end }} diff --git a/grpc/codegen/templates/server_grpc_interface.go.tpl b/grpc/codegen/templates/server_grpc_interface.go.tpl index c4a79402ab..aa4e954cf4 100644 --- a/grpc/codegen/templates/server_grpc_interface.go.tpl +++ b/grpc/codegen/templates/server_grpc_interface.go.tpl @@ -10,7 +10,22 @@ func (s *{{ .ServerStruct }}) {{ .Method.VarName }}( ctx = context.WithValue(ctx, goa.ServiceKey, {{ printf "%q" .ServiceName }}) {{- if .ServerStream }} - {{if .PayloadRef }}p{{ else }}_{{ end }}, err := s.{{ .Method.VarName }}H.Decode(ctx, {{ if .Method.StreamingPayload }}nil{{ else }}message{{ end }}) + {{- if .Request.StreamEnvelope }} + var reqpb any + message, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + reqpb = nil + } else { + return goagrpc.EncodeError(err) + } + } else { + reqpb = message + } + {{if .PayloadRef }}p{{ else }}_{{ end }}, err := s.{{ .Method.VarName }}H.Decode(ctx, reqpb) + {{- else }} + {{if .PayloadRef }}p{{ else }}_{{ end }}, err := s.{{ .Method.VarName }}H.Decode(ctx, {{ if .Method.StreamingPayload }}nil{{ else }}message{{ end }}) + {{- end }} {{- template "handle_error" . }} ep := &{{ .ServicePkgName }}.{{ .Method.VarName }}EndpointInput{ Stream: &{{ .ServerStream.VarName }}{stream: stream}, diff --git a/grpc/codegen/templates/stream_recv.go.tpl b/grpc/codegen/templates/stream_recv.go.tpl index bc2a5534d1..1a0a7e4f98 100644 --- a/grpc/codegen/templates/stream_recv.go.tpl +++ b/grpc/codegen/templates/stream_recv.go.tpl @@ -1,7 +1,11 @@ {{ comment .RecvDesc }} func (s *{{ .VarName }}) {{ .RecvName }}() ({{ .RecvRef }}, error) { var res {{ .RecvRef }} + {{- if and (eq .Type "server") .Endpoint.Request.StreamEnvelope }} + message, err := s.stream.{{ .RecvName }}() + {{- else }} v, err := s.stream.{{ .RecvName }}() + {{- end }} if err != nil { {{- if and .Endpoint .Endpoint.Errors (eq .Type "client") }} resp := goagrpc.DecodeError(err) @@ -26,6 +30,21 @@ func (s *{{ .VarName }}) {{ .RecvName }}() ({{ .RecvRef }}, error) { return res, err {{- end }} } + {{- if and (eq .Type "server") .Endpoint.Request.StreamEnvelope }} + body, ok := message.{{ .Endpoint.Request.StreamEnvelope.FieldName }}.(*{{ .Endpoint.Request.StreamEnvelope.StreamItemWrapperRef }}) + if !ok { + switch message.{{ .Endpoint.Request.StreamEnvelope.FieldName }}.(type) { + case *{{ .Endpoint.Request.StreamEnvelope.InitialWrapperRef }}: + return res, goa.InvalidFieldTypeError("body", "initial_payload", "stream_item") + default: + return res, goa.MissingFieldError("stream_item", "stream") + } + } + if body.{{ .Endpoint.Request.StreamEnvelope.StreamItemFieldName }} == nil { + return res, goa.MissingFieldError("stream_item", "stream") + } + v := body.{{ .Endpoint.Request.StreamEnvelope.StreamItemFieldName }} + {{- end }} {{- if and .Endpoint.Method.ViewedResult (eq .Type "client") }} proj := {{ .RecvConvert.Init.Name }}({{ range .RecvConvert.Init.Args }}{{ .Name }}, {{ end }}) vres := {{ if not .Endpoint.Method.ViewedResult.IsCollection }}&{{ end }}{{ .Endpoint.Method.ViewedResult.FullName }}{Projected: proj, View: {{ if .Endpoint.Method.ViewedResult.ViewName }}"{{ .Endpoint.Method.ViewedResult.ViewName }}"{{ else }}s.view{{ end }} } diff --git a/grpc/codegen/templates/stream_send.go.tpl b/grpc/codegen/templates/stream_send.go.tpl index e873121a22..ef0146a0ed 100644 --- a/grpc/codegen/templates/stream_send.go.tpl +++ b/grpc/codegen/templates/stream_send.go.tpl @@ -8,7 +8,15 @@ func (s *{{ .VarName }}) {{ .SendName }}(res {{ .SendRef }}) error { {{- end }} {{- end }} v := {{ .SendConvert.Init.Name }}({{ if and .Endpoint.Method.ViewedResult (eq .Type "server") }}vres.Projected{{ else }}res{{ end }}) + {{- if and (eq .Type "client") .Endpoint.Request.StreamEnvelope }} + return s.stream.{{ .SendName }}(&{{ .Endpoint.PkgName }}.{{ .Endpoint.Request.Message.VarName }}{ + {{ .Endpoint.Request.StreamEnvelope.FieldName }}: &{{ .Endpoint.Request.StreamEnvelope.StreamItemWrapperRef }}{ + {{ .Endpoint.Request.StreamEnvelope.StreamItemFieldName }}: v, + }, + }) + {{- else }} return s.stream.{{ .SendName }}(v) + {{- end }} } {{ comment .SendWithContextDesc }} diff --git a/grpc/codegen/testdata/dsls.go b/grpc/codegen/testdata/dsls.go index df7f3d7ccf..59dbc16fcc 100644 --- a/grpc/codegen/testdata/dsls.go +++ b/grpc/codegen/testdata/dsls.go @@ -264,6 +264,32 @@ var ClientStreamingRPCWithPayloadDSL = func() { }) } +var ClientStreamingRPCWithUnionPayloadDSL = func() { + var VersionRef = Type("VersionRef", func() { + OneOf("ref_type", func() { + Field(1, "version_id", String) + Field(2, "ref_name", String) + }) + Required("ref_type") + }) + var UploadChunk = Type("UploadChunk", func() { + Field(1, "chunk", Bytes) + Required("chunk") + }) + Service("ServiceClientStreamingRPCWithUnionPayload", func() { + Method("MethodClientStreamingRPCWithUnionPayload", func() { + Payload(func() { + Field(1, "repository_id", String) + Field(2, "version_ref", VersionRef) + Required("repository_id", "version_ref") + }) + StreamingPayload(UploadChunk) + Result(String) + GRPC(func() {}) + }) + }) +} + var ClientStreamingNoResultDSL = func() { Service("ServiceClientStreamingNoResult", func() { Method("MethodClientStreamingNoResult", func() { diff --git a/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-primitive-with-streaming-payload.go.golden b/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-primitive-with-streaming-payload.go.golden index 4223dd8e09..d5b5e567c3 100644 --- a/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-primitive-with-streaming-payload.go.golden +++ b/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-primitive-with-streaming-payload.go.golden @@ -3,28 +3,32 @@ // "MethodClientStreamingRPCWithPayload" endpoint. func DecodeMethodClientStreamingRPCWithPayloadRequest(ctx context.Context, v any, md metadata.MD) (any, error) { var ( - goaPayload int - err error + message *service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadRequest + ok bool ) { - if vals := md.Get("goa_payload"); len(vals) == 0 { - err = goa.MergeErrors(err, goa.MissingFieldError("goa_payload", "metadata")) - } else { - goaPayloadRaw := vals[0] - - v, err2 := strconv.ParseInt(goaPayloadRaw, 10, strconv.IntSize) - if err2 != nil { - err = goa.MergeErrors(err, goa.InvalidFieldTypeError("goaPayload", goaPayloadRaw, "integer")) + if v == nil { + return nil, goa.MissingFieldError("initial_payload", "stream") + } + var envelope *service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest + if envelope, ok = v.(*service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest); !ok { + return nil, goagrpc.ErrInvalidType("ServiceClientStreamingRPCWithPayload", "MethodClientStreamingRPCWithPayload", "*service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest", v) + } + switch body := envelope.Body.(type) { + case *service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest_InitialPayload: + if body.InitialPayload == nil { + return nil, goa.MissingFieldError("initial_payload", "stream") } - goaPayload = int(v) + message = body.InitialPayload + case *service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest_StreamItem: + return nil, goa.InvalidFieldTypeError("body", "stream_item", "initial_payload") + default: + return nil, goa.MissingFieldError("initial_payload", "stream") } } - if err != nil { - return nil, err - } var payload int { - payload = goaPayload + payload = NewMethodClientStreamingRPCWithPayloadPayload(message) } return payload, nil } diff --git a/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-user-type-with-streaming-payload.go.golden b/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-user-type-with-streaming-payload.go.golden index 0dba83df8f..76cec748f1 100644 --- a/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-user-type-with-streaming-payload.go.golden +++ b/grpc/codegen/testdata/golden/request_decoder_request-decoder-payload-user-type-with-streaming-payload.go.golden @@ -3,31 +3,32 @@ // "MethodBidirectionalStreamingRPCWithPayload" endpoint. func DecodeMethodBidirectionalStreamingRPCWithPayloadRequest(ctx context.Context, v any, md metadata.MD) (any, error) { var ( - a *int - b *string - err error + message *service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadRequest + ok bool ) { - if vals := md.Get("a"); len(vals) > 0 { - aRaw := vals[0] - - v, err2 := strconv.ParseInt(aRaw, 10, strconv.IntSize) - if err2 != nil { - err = goa.MergeErrors(err, goa.InvalidFieldTypeError("a", aRaw, "integer")) - } - pv := int(v) - a = &pv + if v == nil { + return nil, goa.MissingFieldError("initial_payload", "stream") } - if vals := md.Get("b"); len(vals) > 0 { - b = &vals[0] + var envelope *service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest + if envelope, ok = v.(*service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest); !ok { + return nil, goagrpc.ErrInvalidType("ServiceBidirectionalStreamingRPCWithPayload", "MethodBidirectionalStreamingRPCWithPayload", "*service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest", v) + } + switch body := envelope.Body.(type) { + case *service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest_InitialPayload: + if body.InitialPayload == nil { + return nil, goa.MissingFieldError("initial_payload", "stream") + } + message = body.InitialPayload + case *service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest_StreamItem: + return nil, goa.InvalidFieldTypeError("body", "stream_item", "initial_payload") + default: + return nil, goa.MissingFieldError("initial_payload", "stream") } - } - if err != nil { - return nil, err } var payload *servicebidirectionalstreamingrpcwithpayload.Payload { - payload = NewMethodBidirectionalStreamingRPCWithPayloadPayload(a, b) + payload = NewMethodBidirectionalStreamingRPCWithPayloadPayload(message) } return payload, nil } diff --git a/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-primitive-with-streaming-payload.go.golden b/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-primitive-with-streaming-payload.go.golden index 5e69c8b11f..e32ba39174 100644 --- a/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-primitive-with-streaming-payload.go.golden +++ b/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-primitive-with-streaming-payload.go.golden @@ -6,6 +6,10 @@ func EncodeMethodClientStreamingRPCWithPayloadRequest(ctx context.Context, v any if !ok { return nil, goagrpc.ErrInvalidType("ServiceClientStreamingRPCWithPayload", "MethodClientStreamingRPCWithPayload", "int", v) } - (*md).Append("goa_payload", fmt.Sprintf("%v", payload)) - return nil, nil + message := NewProtoMethodClientStreamingRPCWithPayloadRequest(payload) + return &service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest{ + Body: &service_client_streaming_rpc_with_payloadpb.MethodClientStreamingRPCWithPayloadStreamingRequest_InitialPayload{ + InitialPayload: message, + }, + }, nil } diff --git a/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-user-type-with-streaming-payload.go.golden b/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-user-type-with-streaming-payload.go.golden index 4a754ae7e8..9424da2fbb 100644 --- a/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-user-type-with-streaming-payload.go.golden +++ b/grpc/codegen/testdata/golden/request_encoder_request-encoder-payload-user-type-with-streaming-payload.go.golden @@ -6,11 +6,10 @@ func EncodeMethodBidirectionalStreamingRPCWithPayloadRequest(ctx context.Context if !ok { return nil, goagrpc.ErrInvalidType("ServiceBidirectionalStreamingRPCWithPayload", "MethodBidirectionalStreamingRPCWithPayload", "*servicebidirectionalstreamingrpcwithpayload.Payload", v) } - if payload.A != nil { - (*md).Append("a", fmt.Sprintf("%v", *payload.A)) - } - if payload.B != nil { - (*md).Append("b", *payload.B) - } - return nil, nil + message := NewProtoMethodBidirectionalStreamingRPCWithPayloadRequest(payload) + return &service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest{ + Body: &service_bidirectional_streaming_rpc_with_payloadpb.MethodBidirectionalStreamingRPCWithPayloadStreamingRequest_InitialPayload{ + InitialPayload: message, + }, + }, nil } diff --git a/grpc/codegen/testdata/golden/server_grpc_interface_bidirectional-streaming-rpc-with-payload.go.golden b/grpc/codegen/testdata/golden/server_grpc_interface_bidirectional-streaming-rpc-with-payload.go.golden index 80edfaf71d..ebfecd1ef8 100644 --- a/grpc/codegen/testdata/golden/server_grpc_interface_bidirectional-streaming-rpc-with-payload.go.golden +++ b/grpc/codegen/testdata/golden/server_grpc_interface_bidirectional-streaming-rpc-with-payload.go.golden @@ -6,7 +6,18 @@ func (s *Server) MethodBidirectionalStreamingRPCWithPayload(stream service_bidir ctx := stream.Context() ctx = context.WithValue(ctx, goa.MethodKey, "MethodBidirectionalStreamingRPCWithPayload") ctx = context.WithValue(ctx, goa.ServiceKey, "ServiceBidirectionalStreamingRPCWithPayload") - p, err := s.MethodBidirectionalStreamingRPCWithPayloadH.Decode(ctx, nil) + var reqpb any + message, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + reqpb = nil + } else { + return goagrpc.EncodeError(err) + } + } else { + reqpb = message + } + p, err := s.MethodBidirectionalStreamingRPCWithPayloadH.Decode(ctx, reqpb) if err != nil { return goagrpc.EncodeError(err) } diff --git a/grpc/codegen/testdata/golden/server_grpc_interface_client-streaming-rpc-with-payload.go.golden b/grpc/codegen/testdata/golden/server_grpc_interface_client-streaming-rpc-with-payload.go.golden index 0067772716..354d4edbf5 100644 --- a/grpc/codegen/testdata/golden/server_grpc_interface_client-streaming-rpc-with-payload.go.golden +++ b/grpc/codegen/testdata/golden/server_grpc_interface_client-streaming-rpc-with-payload.go.golden @@ -6,7 +6,18 @@ func (s *Server) MethodClientStreamingRPCWithPayload(stream service_client_strea ctx := stream.Context() ctx = context.WithValue(ctx, goa.MethodKey, "MethodClientStreamingRPCWithPayload") ctx = context.WithValue(ctx, goa.ServiceKey, "ServiceClientStreamingRPCWithPayload") - p, err := s.MethodClientStreamingRPCWithPayloadH.Decode(ctx, nil) + var reqpb any + message, err := stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + reqpb = nil + } else { + return goagrpc.EncodeError(err) + } + } else { + reqpb = message + } + p, err := s.MethodClientStreamingRPCWithPayloadH.Decode(ctx, reqpb) if err != nil { return goagrpc.EncodeError(err) } diff --git a/grpc/docs/FAQ.md b/grpc/docs/FAQ.md index e7406106be..9442e0651e 100644 --- a/grpc/docs/FAQ.md +++ b/grpc/docs/FAQ.md @@ -50,6 +50,20 @@ type ArrayOfBool struct { } ``` +# How does goa encode methods that define both Payload and StreamingPayload? + +For gRPC, goa keeps ordinary method payload in the typed message channel. +When a method defines both `Payload(...)` and `StreamingPayload(...)`, goa +generates a streamed request envelope with two variants: + +- `initial_payload` carries the one-shot method payload once at stream setup. +- `stream_item` carries each `StreamingPayload` value after that. + +This keeps gRPC metadata reserved for explicit `GRPC.Metadata(...)` fields and +security attributes. It also means rich payload types such as objects, maps, +and unions are encoded with the normal protobuf message machinery instead of +being stringified into headers. + # How does goa handle the Any type in gRPC? Goa supports the `Any` type in gRPC by mapping it to `google.protobuf.Value`, which is specifically designed to represent dynamic JSON-like values. This is simpler and more efficient than using `google.protobuf.Any`.