Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dsl/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ func Message(fn func()) {
// request metadata unless specified explicitly in request message using
// Message function. All other attributes in method payload are added to the
// request message unless specified explicitly using Metadata (in which case
// will be added to the metadata).
// will be added to the metadata). For methods that also define
// StreamingPayload, the ordinary request message is carried as the initial
// typed stream frame rather than being rewritten into metadata.
//
// Metadata takes one argument of function type which lists the attributes
// that must be set in the request metadata instead of the message.
Expand Down
3 changes: 3 additions & 0 deletions dsl/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func Payload(val any, args ...any) {
// StreamingPayload requires a transport that supports client-to-server streaming
// such as gRPC or WebSockets. When using HTTP or JSON-RPC transports, methods
// with StreamingPayload must use WebSockets (via GET endpoints).
// For gRPC methods that define both Payload and StreamingPayload, the ordinary
// method payload is sent once as the initial typed stream frame and the
// StreamingPayload values are sent as subsequent stream item frames.
//
// Examples:
//
Expand Down
17 changes: 1 addition & 16 deletions expr/grpc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,6 @@ func (e *GRPCEndpointExpr) Finalize() {
}
}

// If endpoint defines streaming payload, then add the attributes in method
// payload type to request metadata.
if e.MethodExpr.StreamingPayload.Type != Empty {
for _, nat := range *pobj {
addToMetadata(nat.Name, "")
}
}

// msgObj contains only the attributes in the method payload that must
// be added to the request message type after removing attributes
// specified in the request metadata.
Expand Down Expand Up @@ -387,14 +379,7 @@ func (e *GRPCEndpointExpr) Finalize() {
}
} else {
// method payload is not an object type.
if e.MethodExpr.StreamingPayload.Type != Empty {
// endpoint defines streaming payload. So add the method payload to
// request metadata under "goa-payload" field
e.Metadata.Type.(*Object).Set("goa_payload", e.MethodExpr.Payload)
e.Metadata.Validation.AddRequired("goa_payload")
} else {
initAttrFromDesign(e.Request, e.MethodExpr.Payload)
}
initAttrFromDesign(e.Request, e.MethodExpr.Payload)
}

// Finalize streaming payload type if defined
Expand Down
16 changes: 16 additions & 0 deletions expr/grpc_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"testing"

"github.com/stretchr/testify/require"

"goa.design/goa/v3/eval"
"goa.design/goa/v3/expr"
"goa.design/goa/v3/expr/testdata"
Expand Down Expand Up @@ -84,3 +86,17 @@ service "Service" method "MethodUnion": union type choice has map elements, not
})
}
}

func TestGRPCEndpointStreamingPayloadKeepsInitialRequest(t *testing.T) {
root := expr.RunDSL(t, testdata.GRPCEndpointWithStreamingPayloadInitialRequest)
grpcSvc := root.API.GRPC.Service("Service")
require.NotNil(t, grpcSvc)
require.Len(t, grpcSvc.GRPCEndpoints, 1)

endpoint := grpcSvc.GRPCEndpoints[0]
req := expr.AsObject(endpoint.Request.Type)
require.NotNil(t, req)
require.NotNil(t, req.Attribute("repository_id"))
require.NotNil(t, req.Attribute("version_ref"))
require.True(t, endpoint.Metadata.IsEmpty())
}
25 changes: 25 additions & 0 deletions expr/testdata/endpoint_dsls.go
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,28 @@ var GRPCEndpointWithUnionContainingAny = func() {
})
})
}

var GRPCEndpointWithStreamingPayloadInitialRequest = func() {
var VersionRef = Type("VersionRef", func() {
OneOf("ref_type", func() {
Field(1, "version_id", String)
Field(2, "ref_name", String)
})
Required("ref_type")
})
var UploadChunk = Type("UploadChunk", func() {
Field(1, "chunk", Bytes)
Required("chunk")
})
Service("Service", func() {
Method("Method", func() {
Payload(func() {
Field(1, "repository_id", String)
Field(2, "version_ref", VersionRef)
Required("repository_id", "version_ref")
})
StreamingPayload(UploadChunk)
GRPC(func() {})
})
})
}
6 changes: 6 additions & 0 deletions grpc/codegen/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ func serverFile(genpkg string, svc *expr.GRPCServiceExpr, services *ServicesData
{Path: path.Join(genpkg, svcName, "views"), Name: data.Service.ViewsPkg},
{Path: path.Join(genpkg, "grpc", svcName, pbPkgName), Name: data.PkgName},
}
for _, e := range data.Endpoints {
if e.Request.StreamEnvelope != nil {
imports = append(imports, &codegen.ImportSpec{Path: "io"})
break
}
}
sections = []*codegen.SectionTemplate{
codegen.Header(svc.Name()+" gRPC server", "server", imports),
{
Expand Down
133 changes: 118 additions & 15 deletions grpc/codegen/service_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,17 @@ type (
RequestData struct {
// Description is the request description.
Description string
// Message is the gRPC request message.
// Message is the gRPC request message used by the transport. For
// streaming payload methods with an initial payload frame, this is the
// synthesized stream envelope.
Message *service.UserTypeData
// PayloadMessage is the gRPC message that carries the one-shot method
// payload fields before any stream envelope wrapping.
PayloadMessage *service.UserTypeData
// StreamEnvelope describes the synthesized stream envelope when the
// transport must carry both the one-shot payload and streaming payload
// items through the same streamed protobuf message.
StreamEnvelope *StreamEnvelopeData
// Metadata is the request metadata.
Metadata []*MetadataData
// ServerConvert is the request data with constructor function to
Expand All @@ -195,6 +204,23 @@ type (
CLIArgs []*InitArgData
}

// StreamEnvelopeData describes a synthesized streamed protobuf envelope.
StreamEnvelopeData struct {
// FieldName is the protobuf oneof field name on the envelope message.
FieldName string
// InitialFieldName is the name of the initial payload branch field.
InitialFieldName string
// InitialWrapperRef is the fully qualified protobuf wrapper type for the
// initial payload branch.
InitialWrapperRef string
// StreamItemFieldName is the name of the streaming payload item branch
// field.
StreamItemFieldName string
// StreamItemWrapperRef is the fully qualified protobuf wrapper type for
// the streaming payload item branch.
StreamItemWrapperRef string
}

// ResponseData describes a gRPC success or error response.
ResponseData struct {
// StatusCode is the return code of the response.
Expand Down Expand Up @@ -462,10 +488,26 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
}
seen, imported := make(map[string]struct{}), make(map[string]struct{})
for _, e := range gs.GRPCEndpoints {
hasRequestMessage := !isEmpty(e.Request.Type)
useStreamEnvelope := usesStreamEnvelope(e)

// convert request and response types to protocol buffer message types
e.Request = makeProtoBufMessage(e.Request, protoBufify(e.Name()+"_request", true, true), sd)
if e.MethodExpr.StreamingPayload.Type != expr.Empty {
e.StreamingRequest = makeProtoBufMessage(e.StreamingRequest, protoBufify(e.Name()+"_streaming_request", true, true), sd)
streamMessageName := protoBufify(e.Name()+"_streaming_request", true, true)
if useStreamEnvelope {
streamMessageName = protoBufify(e.Name()+"_stream_item", true, true)
}
e.StreamingRequest = makeProtoBufMessage(e.StreamingRequest, streamMessageName, sd)
}
var requestEnvelope *expr.AttributeExpr
if useStreamEnvelope {
requestEnvelope = makeProtoBufStreamEnvelope(
e.Request,
e.StreamingRequest,
protoBufify(e.Name()+"_streaming_request", true, true),
sd,
)
}
e.Response.Message = makeProtoBufMessage(e.Response.Message, protoBufify(e.Name()+"_response", true, true), sd)
for _, er := range e.GRPCErrors {
Expand Down Expand Up @@ -540,6 +582,9 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
ServerConvert: d.buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, true),
ClientConvert: d.buildRequestConvertData(e.Request, e.MethodExpr.Payload, reqMD, e, sd, false),
}
if hasRequestMessage {
request.PayloadMessage = collect(e.Request)
}
if obj := expr.AsObject(e.Request.Type); (obj != nil && len(*obj) > 0) || expr.IsUnion(e.Request.Type) {
// add the request message as the first argument to the CLI
request.CLIArgs = append(request.CLIArgs, &InitArgData{
Expand Down Expand Up @@ -567,9 +612,13 @@ func (d *ServicesData) analyze(gs *expr.GRPCServiceExpr) *ServiceData {
DefaultValue: m.DefaultValue,
})
}
if e.StreamingRequest.Type != expr.Empty {
switch {
case requestEnvelope != nil:
request.Message = collect(requestEnvelope)
request.StreamEnvelope = buildStreamEnvelopeData(requestEnvelope, request.Message, sd)
case e.StreamingRequest.Type != expr.Empty:
request.Message = collect(e.StreamingRequest)
} else {
default:
request.Message = collect(e.Request)
}

Expand Down Expand Up @@ -872,20 +921,20 @@ func userTypeAttribute(ut expr.UserType) *expr.AttributeExpr {

// buildRequestConvertData builds the convert data for the server and client
// requests.
// - server side - converts generated gRPC request type in *.pb.go and the
// gRPC metadata to method payload type.
// - client side - converts method payload type to generated gRPC request
// type in *.pb.go.
// - server side - converts the one-shot gRPC request message (if any) and
// gRPC metadata to the method payload type.
// - client side - converts the method payload type to the one-shot gRPC
// request message sent before any stream items.
//
// svr param indicates that the convert data is generated for server side.
func (d *ServicesData) buildRequestConvertData(request, payload *expr.AttributeExpr, md []*MetadataData, e *expr.GRPCEndpointExpr, sd *ServiceData, svr bool) *ConvertData {
// Server-side: No need to build convert data if payload is empty or payload
// is not an object type and endpoint streams payload (the payload is
// encoded in metadata under "goa-payload" in this case).
if (svr && (isEmpty(payload.Type) || !expr.IsObject(payload.Type) && e.MethodExpr.IsPayloadStreaming())) ||
// Client-side: No need to build convert data if streaming payload since
// all attributes in method payload is encoded into request metadata.
(!svr && e.MethodExpr.IsPayloadStreaming()) {
if svr && isEmpty(payload.Type) {
return nil
}
if !svr && e.MethodExpr.IsPayloadStreaming() && isEmpty(request.Type) {
return nil
}
if svr && e.MethodExpr.IsPayloadStreaming() && isEmpty(request.Type) && !expr.IsObject(payload.Type) {
return nil
}

Expand Down Expand Up @@ -1297,6 +1346,60 @@ func extractMetadata(a *expr.MappedAttributeExpr, service *expr.AttributeExpr, s
return metadata
}

// usesStreamEnvelope reports whether the transport needs a typed stream
// envelope to carry both the one-shot method payload and streaming payload
// items.
func usesStreamEnvelope(e *expr.GRPCEndpointExpr) bool {
return e.MethodExpr.IsPayloadStreaming() && !isEmpty(e.Request.Type)
}

// makeProtoBufStreamEnvelope builds the protobuf stream envelope that carries
// the initial request payload frame and subsequent stream item frames.
func makeProtoBufStreamEnvelope(request, stream *expr.AttributeExpr, tname string, sd *ServiceData) *expr.AttributeExpr {
initial := expr.DupAtt(request)
initial.Meta = initial.Meta.Dup()
initial.Meta["rpc:tag"] = []string{"1"}
streamItem := expr.DupAtt(stream)
streamItem.Meta = streamItem.Meta.Dup()
streamItem.Meta["rpc:tag"] = []string{"2"}
envelope := &expr.AttributeExpr{
Type: &expr.Object{
&expr.NamedAttributeExpr{
Name: "body",
Attribute: &expr.AttributeExpr{
Type: &expr.Union{
TypeName: "body",
Values: []*expr.NamedAttributeExpr{
{Name: "initial_payload", Attribute: initial},
{Name: "stream_item", Attribute: streamItem},
},
},
},
},
},
Validation: &expr.ValidationExpr{Required: []string{"body"}},
}
return makeProtoBufMessage(envelope, tname, sd)
}

// buildStreamEnvelopeData computes the generated Go names for the protobuf
// oneof field and wrapper types of the synthesized stream envelope.
func buildStreamEnvelopeData(envelope *expr.AttributeExpr, message *service.UserTypeData, sd *ServiceData) *StreamEnvelopeData {
body := envelope.Find("body")
union := expr.AsUnion(body.Type)
scope := &protoBufScope{scope: sd.Scope}
fieldName := scope.Field(body, union.TypeName, true)
initialFieldName := scope.Field(union.Values[0].Attribute, union.Values[0].Name, true)
streamItemFieldName := scope.Field(union.Values[1].Attribute, union.Values[1].Name, true)
return &StreamEnvelopeData{
FieldName: fieldName,
InitialFieldName: initialFieldName,
InitialWrapperRef: fmt.Sprintf("%s.%s_%s", sd.PkgName, message.VarName, initialFieldName),
StreamItemFieldName: streamItemFieldName,
StreamItemWrapperRef: fmt.Sprintf("%s.%s_%s", sd.PkgName, message.VarName, streamItemFieldName),
}
}

func unalias(att *expr.AttributeExpr) *expr.AttributeExpr {
if ut, ok := att.Type.(expr.UserType); ok {
if _, ok := ut.Attribute().Type.(expr.Primitive); ok {
Expand Down
39 changes: 39 additions & 0 deletions grpc/codegen/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"goa.design/goa/v3/codegen"
"goa.design/goa/v3/grpc/codegen/testdata"
Expand Down Expand Up @@ -148,3 +149,41 @@ func TestStreaming(t *testing.T) {
})
}
}

func TestStreamingPayloadEnvelopeWithUnionPayload(t *testing.T) {
root := RunGRPCDSL(t, testdata.ClientStreamingRPCWithUnionPayloadDSL)
services := CreateGRPCServices(root)

clientfs := ClientFiles("", services)
require.Len(t, clientfs, 2)
serverfs := ServerFiles("", services)
require.Len(t, serverfs, 2)
protofs := ProtoFiles("", services)
require.Len(t, protofs, 1)

requestEncoder := codegen.SectionsCode(t, clientfs[1].Section("request-encoder"))
assert.Contains(t, requestEncoder, "InitialPayload")
assert.Contains(t, requestEncoder, "MethodClientStreamingRPCWithUnionPayloadStreamingRequest")

clientSend := codegen.SectionsCode(t, clientfs[0].Section("client-stream-send"))
assert.Contains(t, clientSend, "StreamItem")
assert.Contains(t, clientSend, "UploadChunk")

serverInterface := codegen.SectionsCode(t, serverfs[0].Section("server-grpc-interface"))
assert.Contains(t, serverInterface, "message, err := stream.Recv()")
assert.Contains(t, serverInterface, "Decode(ctx, reqpb)")

requestDecoder := codegen.SectionsCode(t, serverfs[1].Section("request-decoder"))
assert.Contains(t, requestDecoder, "InitialPayload")
assert.Contains(t, requestDecoder, "stream_item")
assert.Contains(t, requestDecoder, "NewMethodClientStreamingRPCWithUnionPayloadPayload(message)")

proto := sectionCode(t, protofs[0].SectionTemplates[1:]...)
assert.Contains(t, proto, "message MethodClientStreamingRPCWithUnionPayloadStreamingRequest")
assert.Contains(t, proto, "oneof body")
assert.Contains(t, proto, "MethodClientStreamingRPCWithUnionPayloadRequest initial_payload")
assert.Contains(t, proto, "MethodClientStreamingRPCWithUnionPayloadStreamItem stream_item")

fpath := codegen.CreateTempFile(t, proto)
assert.NoError(t, protoc(defaultProtocCmd, fpath, nil))
}
21 changes: 17 additions & 4 deletions grpc/codegen/templates/remote_method_builder.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,22 @@ func Build{{ .Method.VarName }}Func(grpccli {{ .PkgName }}.{{ .ClientInterface }
for _, opt := range cliopts {
opts = append(opts, opt)
}
if reqpb != nil {
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, reqpb.({{ .Request.ClientConvert.TgtRef }}){{ end }}, opts...)
}
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, &{{ .Request.ClientConvert.TgtName }}{}{{ end }}, opts...)
{{- if .Request.StreamEnvelope }}
stream, err := grpccli.{{ .ClientMethodName }}(ctx, opts...)
if err != nil {
return nil, err
}
if reqpb != nil {
if err := stream.Send(reqpb.({{ .Request.Message.Ref }})); err != nil {
return nil, err
}
}
return stream, nil
{{- else }}
if reqpb != nil {
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, reqpb.({{ .Request.ClientConvert.TgtRef }}){{ end }}, opts...)
}
return grpccli.{{ .ClientMethodName }}(ctx{{ if not .Method.StreamingPayload }}, &{{ .Request.ClientConvert.TgtName }}{}{{ end }}, opts...)
{{- end }}
}
}
Loading
Loading