Skip to content

Commit 22ff22f

Browse files
authored
codegen: support mixed unary + SSE results (#3883)
Support methods that define both Result and StreamingResult with distinct types. Generate HTTP content negotiation (JSON vs SSE), endpoint wrappers that accept a unified input struct, and a split client API for unary vs streaming. Ensure SSE codegen works in mixed results mode and add targeted tests.
1 parent 83a4b9f commit 22ff22f

22 files changed

Lines changed: 487 additions & 5 deletions

codegen/service/client_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func TestClient(t *testing.T) {
2525
{"client-no-payload", testdata.NoPayloadEndpointDSL, testdata.NoPayloadMethodsClient},
2626
{"client-with-result", testdata.WithResultEndpointDSL, testdata.WithResultMethodClient},
2727
{"client-streaming-result", testdata.StreamingResultMethodDSL, testdata.StreamingResultMethodClient},
28+
{"client-mixed-results", testdata.MixedResultsEndpointDSL, testdata.MixedResultsMethodClient},
2829
{"client-streaming-result-no-payload", testdata.StreamingResultNoPayloadMethodDSL, testdata.StreamingResultNoPayloadMethodClient},
2930
{"client-streaming-payload", testdata.StreamingPayloadMethodDSL, testdata.StreamingPayloadMethodClient},
3031
{"client-streaming-payload-no-payload", testdata.StreamingPayloadNoPayloadMethodDSL, testdata.StreamingPayloadNoPayloadMethodClient},

codegen/service/endpoint.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ type (
4444
// ArgName is the name of the argument used to initialize the client
4545
// struct method field.
4646
ArgName string
47+
// StreamArgName is the name of the argument used to initialize the client
48+
// struct stream endpoint field when the method defines mixed results.
49+
//
50+
// It is only set when HasMixedResults is true.
51+
StreamArgName string
4752
// ClientVarName is the corresponding client struct field name.
4853
ClientVarName string
4954
// ServiceName is the name of the owner service.
@@ -142,16 +147,24 @@ func EndpointFile(genpkg string, service *expr.ServiceExpr, services *ServicesDa
142147

143148
func endpointData(svc *Data) *EndpointsData {
144149
methods := make([]*EndpointMethodData, len(svc.Methods))
145-
names := make([]string, len(svc.Methods))
150+
argScope := codegen.NewNameScope()
151+
var names []string
146152
for i, m := range svc.Methods {
153+
argName := argScope.Unique(codegen.Goify(m.VarName, false), "")
154+
names = append(names, argName)
155+
streamArgName := ""
156+
if m.HasMixedResults {
157+
streamArgName = argScope.Unique(argName+"Stream", "")
158+
names = append(names, streamArgName)
159+
}
147160
methods[i] = &EndpointMethodData{
148161
MethodData: m,
149-
ArgName: codegen.Goify(m.VarName, false),
162+
ArgName: argName,
163+
StreamArgName: streamArgName,
150164
ServiceName: svc.Name,
151165
ServiceVarName: serviceInterfaceName,
152166
ClientVarName: clientStructName,
153167
}
154-
names[i] = codegen.Goify(m.VarName, false)
155168
}
156169
desc := fmt.Sprintf("%s wraps the %q service endpoints.", endpointsStructName, svc.Name)
157170
return &EndpointsData{

codegen/service/endpoint_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func TestEndpoint(t *testing.T) {
2626
{"endpoint-with-result", testdata.WithResultEndpointDSL, testdata.WithResultEndpoint},
2727
{"endpoint-with-result-multiple-views", testdata.WithResultMultipleViewsEndpointDSL, testdata.WithResultMultipleViewsEndpoint},
2828
{"endpoint-streaming-result", testdata.StreamingResultEndpointDSL, testdata.StreamingResultMethodEndpoint},
29+
{"endpoint-mixed-results", testdata.MixedResultsEndpointDSL, testdata.MixedResultsMethodEndpoint},
2930
{"endpoint-streaming-result-no-payload", testdata.StreamingResultNoPayloadEndpointDSL, testdata.StreamingResultNoPayloadMethodEndpoint},
3031
{"endpoint-streaming-result-with-views", testdata.StreamingResultWithViewsMethodDSL, testdata.StreamingResultWithViewsMethodEndpoint},
3132
{"endpoint-streaming-payload", testdata.StreamingPayloadEndpointDSL, testdata.StreamingPayloadMethodEndpoint},

codegen/service/service_data.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ type (
188188
// StreamKind is the kind of the stream (payload or result or
189189
// bidirectional).
190190
StreamKind expr.StreamKind
191+
// HasMixedResults indicates whether the method defines both Result and
192+
// StreamingResult with different types, enabling content negotiation at
193+
// the transport layer (e.g. JSON vs SSE over HTTP).
194+
HasMixedResults bool
191195
// SkipRequestBodyEncodeDecode is true if the method payload includes
192196
// the raw HTTP request body reader.
193197
SkipRequestBodyEncodeDecode bool
@@ -206,6 +210,13 @@ type (
206210
// struct to store the goa.Endpoint for this method. It is computed with a
207211
// scope that includes method names to avoid field/method name collisions.
208212
EndpointField string
213+
// StreamEndpointField is the unique field name used in the generated client
214+
// struct to store the "streaming mode" goa.Endpoint for mixed results. The
215+
// transport endpoint forces server streaming (e.g. sets "Accept:
216+
// text/event-stream") and returns the client stream interface.
217+
//
218+
// It is only set when HasMixedResults is true.
219+
StreamEndpointField string
209220
}
210221

211222
// StreamData is the data used to generate client and server interfaces that
@@ -871,6 +882,9 @@ func (d *ServicesData) analyze(service *expr.ServiceExpr) *Data {
871882
// existing method names.
872883
for _, m := range methods {
873884
m.EndpointField = scope.Unique(m.VarName+"Endpoint", "")
885+
if m.HasMixedResults {
886+
m.StreamEndpointField = scope.Unique(m.VarName+"StreamEndpoint", "")
887+
}
874888
}
875889

876890
// Collect union sum-type definitions for the service.
@@ -1263,6 +1277,7 @@ func (d *ServicesData) buildMethodData(m *expr.MethodExpr, scope *codegen.NameSc
12631277
Requirements: reqs,
12641278
Schemes: schemes,
12651279
StreamKind: m.Stream,
1280+
HasMixedResults: m.HasMixedResults(),
12661281
SkipRequestBodyEncodeDecode: skipRequestBodyEncodeDecode,
12671282
SkipResponseBodyEncodeDecode: skipResponseBodyEncodeDecode,
12681283
RequestStruct: vname + "RequestData",

codegen/service/service_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ func TestService(t *testing.T) {
4343
{"service-force-generate-type", testdata.ForceGenerateTypeDSL},
4444
{"service-force-generate-type-explicit", testdata.ForceGenerateTypeExplicitDSL},
4545
{"service-streaming-result", testdata.StreamingResultMethodDSL},
46+
{"service-mixed-results", testdata.MixedResultsEndpointDSL},
4647
{"service-streaming-result-with-views", testdata.StreamingResultWithViewsMethodDSL},
4748
{"service-streaming-result-with-explicit-view", testdata.StreamingResultWithExplicitViewMethodDSL},
4849
{"service-streaming-result-no-payload", testdata.StreamingResultNoPayloadMethodDSL},

codegen/service/templates/service.go.tpl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ type Service interface {
2525
{{- if .ServerStream }}
2626
{{- if and .IsJSONRPC (not .IsJSONRPCSSE) (eq .ServerStream.Kind 2) }}
2727
{{ .VarName }}(context.Context{{ if .Payload }}, {{ .PayloadRef }}{{ end }}) ({{ if .Result }}res {{ .ResultRef }}, {{ end }}err error)
28+
{{- else if .HasMixedResults }}
29+
{{- /* Mixed results: the method may be invoked in a unary (JSON) or streaming (SSE) mode.
30+
The server stream is non-nil only when the transport negotiates streaming. */}}
31+
{{ .VarName }}(context.Context{{ if .Payload }}, {{ .PayloadRef }}{{ end }}, {{ .ServerStream.Interface }}) ({{ if .Result }}res {{ .ResultRef }}, {{ end }}{{ if .ViewedResult }}{{ if not .ViewedResult.ViewName }}view string, {{ end }}{{ end }}err error)
2832
{{- else }}
2933
{{- if and .IsJSONRPC (not .IsJSONRPCSSE) (eq .ServerStream.Kind 3) .PayloadRef }}
3034
{{- /* JSON-RPC WebSocket server streaming with non-streaming payload */ -}}

codegen/service/templates/service_client.go.tpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,8 @@
22
type {{ .ClientVarName }} struct {
33
{{- range .Methods}}
44
{{ .EndpointField }} goa.Endpoint
5+
{{- if .HasMixedResults }}
6+
{{ .StreamEndpointField }} goa.Endpoint
7+
{{- end }}
58
{{- end }}
69
}

codegen/service/templates/service_client_init.go.tpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ func New{{ .ClientVarName }}({{ .ClientInitArgs }} goa.Endpoint{{ if .HasClientI
33
return &{{ .ClientVarName }}{
44
{{- range .Methods }}
55
{{ .EndpointField }}: {{ if .ClientInterceptors }}Wrap{{ .VarName }}ClientEndpoint({{ end }}{{ .ArgName }}{{ if .ClientInterceptors }}, ci){{ end }},
6+
{{- if .HasMixedResults }}
7+
{{ .StreamEndpointField }}: {{ if .ClientInterceptors }}Wrap{{ .VarName }}ClientEndpoint({{ end }}{{ .StreamArgName }}{{ if .ClientInterceptors }}, ci){{ end }},
8+
{{- end }}
69
{{- end }}
710
}
811
}

codegen/service/templates/service_client_method.go.tpl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,38 @@
77
{{- end }}
88
// - error: internal error
99
{{- end }}
10+
{{- if .HasMixedResults }}
11+
{{- $unaryResultType := .ResultRef }}
12+
func (c *{{ .ClientVarName }}) {{ .VarName }}(ctx context.Context{{ if .PayloadRef }}, p {{ .PayloadRef }}{{ end }}{{ if .MethodData.SkipRequestBodyEncodeDecode}}, req io.ReadCloser{{ end }}) ({{ if $unaryResultType }}res {{ $unaryResultType }}, {{ end }}{{ if .MethodData.SkipResponseBodyEncodeDecode }}resp io.ReadCloser, {{ end }}err error) {
13+
{{- if or $unaryResultType .MethodData.SkipResponseBodyEncodeDecode }}
14+
var ires any
15+
{{- end }}
16+
{{ if or $unaryResultType .MethodData.SkipResponseBodyEncodeDecode }}ires{{ else }}_{{ end }}, err = c.{{ .EndpointField }}(ctx, {{ if .MethodData.SkipRequestBodyEncodeDecode }}&{{ .RequestStruct }}{ {{ if .PayloadRef }}Payload: p, {{ end }}Body: req }{{ else if .PayloadRef }}p{{ else }}nil{{ end }})
17+
{{- if not (or $unaryResultType .MethodData.SkipResponseBodyEncodeDecode) }}
18+
return
19+
{{- else }}
20+
if err != nil {
21+
return
22+
}
23+
{{- if .MethodData.SkipResponseBodyEncodeDecode }}
24+
o := ires.(*{{ .MethodData.ResponseStruct }})
25+
return {{ if .ResultRef }}o.Result, {{ end }}o.Body, nil
26+
{{- else }}
27+
return ires.({{ $unaryResultType }}), nil
28+
{{- end }}
29+
{{- end }}
30+
}
31+
32+
{{ printf "%sStream calls the %q endpoint of the %q service with server streaming enabled." .VarName .Name .ServiceName | comment }}
33+
func (c *{{ .ClientVarName }}) {{ .VarName }}Stream(ctx context.Context{{ if .PayloadRef }}, p {{ .PayloadRef }}{{ end }}{{ if .MethodData.SkipRequestBodyEncodeDecode}}, req io.ReadCloser{{ end }}) (res {{ .ClientStream.Interface }}, err error) {
34+
var ires any
35+
ires, err = c.{{ .StreamEndpointField }}(ctx, {{ if .MethodData.SkipRequestBodyEncodeDecode }}&{{ .RequestStruct }}{ {{ if .PayloadRef }}Payload: p, {{ end }}Body: req }{{ else if .PayloadRef }}p{{ else }}nil{{ end }})
36+
if err != nil {
37+
return
38+
}
39+
return ires.({{ .ClientStream.Interface }}), nil
40+
}
41+
{{- else }}
1042
{{- $resultType := .ResultRef }}
1143
{{- if .ClientStream }}
1244
{{- /* When a client stream exists, always return it from the client method. */ -}}
@@ -31,3 +63,4 @@ func (c *{{ .ClientVarName }}) {{ .VarName }}(ctx context.Context{{ if .PayloadR
3163
{{- end }}
3264
{{- end }}
3365
}
66+
{{- end }}

codegen/service/templates/service_endpoint_method.go.tpl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,24 @@ func New{{ .VarName }}Endpoint(s {{ .ServiceVarName }}{{ range .Schemes.DedupeBy
120120

121121
{{- if .ServerStream }}
122122
{{- if .ServerStream.EndpointStruct }}
123+
{{- if .HasMixedResults }}
124+
res, {{ if .ViewedResult }}{{ if not .ViewedResult.ViewName }}view, {{ end }}{{ end }}err := s.{{ .VarName }}(ctx, {{ if .PayloadRef }}{{ $payload }}, {{ end }}ep.Stream)
125+
if err != nil {
126+
return nil, err
127+
}
128+
{{- if .ViewedResult }}
129+
{{- if .ViewedResult.ViewName }}
130+
vres := {{ $.ViewedResult.Init.Name }}(res, {{ printf "%q" .ViewedResult.ViewName }})
131+
{{- else }}
132+
vres := {{ $.ViewedResult.Init.Name }}(res, view)
133+
{{- end }}
134+
return vres, nil
135+
{{- else }}
136+
return res, nil
137+
{{- end }}
138+
{{- else }}
123139
return nil, s.{{ .VarName }}(ctx, {{ if .PayloadRef }}{{ $payload }}, {{ end }}ep.Stream)
140+
{{- end }}
124141
{{- else }}
125142
{{- /* JSON-RPC WebSocket client streaming: no stream parameter, just payload */ -}}
126143
{{- if .PayloadRef }}

0 commit comments

Comments
 (0)