From 656bde15181e05727f7763f54f8b0901db3f6afc Mon Sep 17 00:00:00 2001 From: Raphael Simon Date: Thu, 10 Apr 2025 17:28:41 -0700 Subject: [PATCH 1/4] Initial work on adding SSE support Added DSL and expressions. --- dsl/sse.go | 276 ++++++++++++++++++++++++++++++++++++++++++ eval/eval_test.go | 118 ++++++++++++++++++ expr/http.go | 3 + expr/http_endpoint.go | 30 +++++ expr/http_service.go | 3 + expr/http_sse.go | 103 ++++++++++++++++ expr/http_sse_test.go | 175 ++++++++++++++++++++++++++ 7 files changed, 708 insertions(+) create mode 100644 dsl/sse.go create mode 100644 expr/http_sse.go create mode 100644 expr/http_sse_test.go diff --git a/dsl/sse.go b/dsl/sse.go new file mode 100644 index 0000000000..66f752e70e --- /dev/null +++ b/dsl/sse.go @@ -0,0 +1,276 @@ +package dsl + +import ( + "goa.design/goa/v3/eval" + "goa.design/goa/v3/expr" +) + +// ServerSentEvents specifies that a streaming endpoint should use the +// Server-Sent Events protocol for streaming instead of WebSockets. It can be +// used in three ways: +// +// 1. ServerSentEvents() - StreamingResult type is used directly as the "data" field +// 2. ServerSentEvents("attributeName") - The specified attribute is used as the "data" field +// 3. ServerSentEvents(func() { ... }) - Custom mapping of attributes to SSE fields +// +// ServerSentEvents can appear in an API HTTP expression (to specify SSE for all streaming +// methods in the API), in a Service HTTP expression (to specify SSE for all streaming +// methods in the service), or in a Method HTTP expression. When specified at the +// API or service level, any method with a StreamingPayload will fall back to using WebSockets +// as SSE only supports server-to-client streaming. +// +// See SSEData, SSEID, SSEType, SSERetry for more details on mapping result attributes +// to SSE fields. +// +// Example: +// +// var Notification = Type("Notification", func() { +// Attribute("message", String)message +// Attribute("timestamp", String) +// Required("message", "timestamp") +// }) +// +// var _ = Service("events", func() { +// HTTP(func() { +// ServerSentEvents() // All streaming methods in this service use SSE by default +// }) +// +// // Simple method with just data field +// Method("stream", func() { +// StreamingResult(Notification) +// HTTP(func() { +// GET("/events") // Messages are sent as {"data": {"message": , "timestamp": }} +// }) +// }) +// }) +// +// var _ = Service("other", func() { +// // Method using WebSockets +// Method("stream", func() { +// StreamingResult(Notification) +// HTTP(func() { +// GET("/websocket") +// }) +// }) +// +// // Method using SSE +// Method("stream", func() { +// Payload(func() { +// Attribute("id", String) +// }) +// StreamingResult(Notification) +// HTTP(func() { +// ServerSentEvents(func() { // Use SSE for this method +// SSERequestID("id") // Use payload "id" field to set "Last-Event-Id" request header +// SSEEventID("timestamp") // Use result "timestamp" attribute for "id" event field +// SSEEventData("message") // Use result "message" attribute for "data" event field +// }) +// GET("/sse") // Messages are sent as {"id": , "data": } +// }) +// }) +// }) +func ServerSentEvents(val any) { + var fn func() + var dataField string + + switch actual := val.(type) { + case func(): + fn = actual + case string: + dataField = actual + case nil: + // Use the entire result as data field + default: + eval.InvalidArgError("function or string", val) + return + } + + sse := &expr.HTTPSSEExpr{ + DataField: dataField, + } + + switch actual := eval.Current().(type) { + case *expr.HTTPExpr: + actual.SSE = sse + case *expr.HTTPServiceExpr: + actual.SSE = sse + case *expr.HTTPEndpointExpr: + actual.SSE = sse + default: + eval.IncompatibleDSL() + } + + if fn != nil { + eval.Execute(fn, sse) + } +} + +// SSERequestID defines the attribute of the Payload type that provides the +// Last-Event-ID request header value. The attribute must exist in the Payload +// type and must be of type String. +// +// SSERequestID must appear in a `ServerSentEvents` expression. +// +// SSERequestID accepts a single argument: the name of the attribute of the +// Payload type that provides the Last-Event-ID request header value. +// +// Example: +// +// Method("stream", func() { +// Payload(func() { +// Attribute("id", String) +// }) +// StreamingResult(Notification) +// HTTP(func() { +// GET("/events") +// ServerSentEvents(func() { // Use SSE for this method +// SSERequestID("id") // Use payload "id" field to set "Last-Event-Id" request header +// SSEEventID("timestamp") // Use result "timestamp" attribute for "id" event field +// SSEEventData("message") // Use result "message" attribute for "data" event field +// }) +// }) +// }) +func SSERequestID(name string) { + if name == "" { + eval.ReportError("request ID field name cannot be empty") + return + } + sse, ok := eval.Current().(*expr.HTTPSSEExpr) + if !ok { + eval.IncompatibleDSL() + return + } + sse.RequestIDField = name +} + +// SSEEventData defines the attribute of the StreamingResult type that provides the +// data field for a Server-Sent Event. The attribute must exist in the +// StreamingResult type. +// +// SSEEventData must appear in a `ServerSentEvents` expression. +// +// SSEEventData accepts a single argument: the name of the attribute of the +// StreamingResult type that provides the data field for a Server-Sent Event. +// +// Example: +// +// Method("stream", func() { +// StreamingResult(Payload) +// HTTP(func() { +// GET("/events") +// ServerSentEvents(func() { +// SSEEventData("message") // Use payload "message" attribute for SSE data field, other attributes are ignored +// }) +// }) +// }) +func SSEEventData(name string) { + if name == "" { + eval.ReportError("data field name cannot be empty") + return + } + sse, ok := eval.Current().(*expr.HTTPSSEExpr) + if !ok { + eval.IncompatibleDSL() + return + } + sse.DataField = name +} + +// SSEEventID defines the attribute of the StreamingResult type that provides the +// id field for a Server-Sent Event. The attribute must exist in the +// StreamingResult type and must be of type String. +// +// SSEEventID must appear in a `ServerSentEvents` expression. +// +// SSEEventID accepts a single argument: the name of the attribute of the +// StreamingResult type that provides the id field for a Server-Sent Event. +// +// Example: +// +// Method("stream", func() { +// StreamingResult(Payload) +// HTTP(func() { +// GET("/events") +// ServerSentEvents(func() { +// SSEEventID("timestamp") // Use "timestamp" attribute for SSE id field +// }) +// }) +// }) +func SSEEventID(name string) { + if name == "" { + eval.ReportError("id field name cannot be empty") + return + } + sse, ok := eval.Current().(*expr.HTTPSSEExpr) + if !ok { + eval.IncompatibleDSL() + return + } + sse.IDField = name +} + +// SSEEventType defines the attribute of the StreamingResult type that provides the +// event field (event type) for a Server-Sent Event. The attribute must exist in the +// StreamingResult type and must be of type String. +// +// SSEEventType must appear in a `ServerSentEvents` expression. +// +// SSEEventType accepts a single argument: the name of the attribute of the +// StreamingResult type that provides the event field for a Server-Sent Event. +// +// Example: +// +// Method("stream", func() { +// StreamingResult(Payload) +// HTTP(func() { +// GET("/events") +// ServerSentEvents(func() { +// SSEEventType("type") // Use payload "type" attribute for SSE event field +// }) +// }) +// }) +func SSEEventType(name string) { + if name == "" { + eval.ReportError("event field name cannot be empty") + return + } + sse, ok := eval.Current().(*expr.HTTPSSEExpr) + if !ok { + eval.IncompatibleDSL() + return + } + sse.EventField = name +} + +// SSEEventRetry defines the attribute of the StreamingResult type that provides +// the retry field for a Server-Sent Event. The attribute must exist in the +// StreamingResult type and must be of type Int or UInt. +// +// SSEEventRetry must appear in a `ServerSentEvents` expression. +// +// SSEEventRetry accepts a single argument: the name of the attribute of the +// StreamingResult type that provides the retry field for a Server-Sent Event. +// +// Example: +// +// Method("stream", func() { +// StreamingResult(Notification) +// HTTP(func() { +// GET("/events") +// ServerSentEvents(func() { +// SSEEventRetry("retry") // Use "retry" attribute for SSE retry field +// }) +// }) +// }) +func SSEEventRetry(name string) { + if name == "" { + eval.ReportError("retry field name cannot be empty") + return + } + sse, ok := eval.Current().(*expr.HTTPSSEExpr) + if !ok { + eval.IncompatibleDSL() + return + } + sse.RetryField = name +} diff --git a/eval/eval_test.go b/eval/eval_test.go index d334fec1da..25ffa2d836 100644 --- a/eval/eval_test.go +++ b/eval/eval_test.go @@ -1,11 +1,14 @@ package eval_test import ( + "errors" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" . "goa.design/goa/v3/dsl" + "goa.design/goa/v3/eval" "goa.design/goa/v3/expr" ) @@ -114,3 +117,118 @@ func TestTooManyArgError(t *testing.T) { }) } } + +// mockExpr is a test implementation of the Expression interface +type mockExpr struct{} + +func (m *mockExpr) EvalName() string { + return "MockExpression" +} + +// TestValidationErrors tests the ValidationErrors type methods +func TestValidationErrors(t *testing.T) { + + t.Run("Error", func(t *testing.T) { + // Test Error() method + expr1 := &mockExpr{} + expr2 := &mockExpr{} + + verr := &eval.ValidationErrors{ + Errors: []error{errors.New("error1"), errors.New("error2")}, + Expressions: []eval.Expression{expr1, expr2}, + } + + errStr := verr.Error() + assert.Contains(t, errStr, "MockExpression: error1") + assert.Contains(t, errStr, "MockExpression: error2") + assert.Contains(t, errStr, "\n") // Should contain a newline between errors + }) + + t.Run("Merge", func(t *testing.T) { + // Test Merge() method + expr1 := &mockExpr{} + expr2 := &mockExpr{} + expr3 := &mockExpr{} + + verr1 := &eval.ValidationErrors{ + Errors: []error{errors.New("error1")}, + Expressions: []eval.Expression{expr1}, + } + + verr2 := &eval.ValidationErrors{ + Errors: []error{errors.New("error2"), errors.New("error3")}, + Expressions: []eval.Expression{expr2, expr3}, + } + + // Test merging with nil + verrCopy := *verr1 + verrCopy.Merge(nil) + assert.Equal(t, 1, len(verrCopy.Errors)) + assert.Equal(t, 1, len(verrCopy.Expressions)) + + // Test merging with empty but non-nil ValidationErrors + verrCopy2 := *verr1 + emptyVerr := &eval.ValidationErrors{} + verrCopy2.Merge(emptyVerr) + assert.Equal(t, 1, len(verrCopy2.Errors), "Merging with empty ValidationErrors should not change the error count") + assert.Equal(t, 1, len(verrCopy2.Expressions), "Merging with empty ValidationErrors should not change the expression count") + + // Test normal merge + verr1.Merge(verr2) + assert.Equal(t, 3, len(verr1.Errors)) + assert.Equal(t, 3, len(verr1.Expressions)) + assert.Equal(t, expr1, verr1.Expressions[0]) + assert.Equal(t, expr2, verr1.Expressions[1]) + assert.Equal(t, expr3, verr1.Expressions[2]) + }) + + t.Run("Add", func(t *testing.T) { + // Test Add() method + expr1 := &mockExpr{} + + verr := &eval.ValidationErrors{} + verr.Add(expr1, "test error %s", "message") + + require.Equal(t, 1, len(verr.Errors)) + require.Equal(t, 1, len(verr.Expressions)) + assert.Equal(t, "test error message", verr.Errors[0].Error()) + assert.Equal(t, expr1, verr.Expressions[0]) + }) + + t.Run("AddError_Simple", func(t *testing.T) { + // Test AddError() with a simple error + expr1 := &mockExpr{} + simpleErr := errors.New("simple error") + + verr := &eval.ValidationErrors{} + verr.AddError(expr1, simpleErr) + + require.Equal(t, 1, len(verr.Errors)) + require.Equal(t, 1, len(verr.Expressions)) + assert.Equal(t, simpleErr, verr.Errors[0]) + assert.Equal(t, expr1, verr.Expressions[0]) + }) + + t.Run("AddError_ValidationErrors", func(t *testing.T) { + // Test AddError() with another ValidationErrors + expr1 := &mockExpr{} + expr2 := &mockExpr{} + expr3 := &mockExpr{} + + nestedVerr := &eval.ValidationErrors{ + Errors: []error{errors.New("nested error 1"), errors.New("nested error 2")}, + Expressions: []eval.Expression{expr2, expr3}, + } + + verr := &eval.ValidationErrors{} + verr.AddError(expr1, nestedVerr) // expr1 should be ignored since we're adding a ValidationErrors + + // The nested ValidationErrors should be flattened + require.Equal(t, 2, len(verr.Errors)) + require.Equal(t, 2, len(verr.Expressions)) + assert.Equal(t, "nested error 1", verr.Errors[0].Error()) + assert.Equal(t, "nested error 2", verr.Errors[1].Error()) + assert.Equal(t, expr2, verr.Expressions[0]) + assert.Equal(t, expr3, verr.Expressions[1]) + }) +} diff --git a/expr/http.go b/expr/http.go index 557adb64e4..c59b0330b1 100644 --- a/expr/http.go +++ b/expr/http.go @@ -29,6 +29,9 @@ type ( Services []*HTTPServiceExpr // Errors lists the error HTTP responses. Errors []*HTTPErrorExpr + // SSE contains the Server-Sent Events configuration for all + // streaming endpoints in the API. + SSE *HTTPSSEExpr } ) diff --git a/expr/http_endpoint.go b/expr/http_endpoint.go index bc6c92a679..6dc4f46941 100644 --- a/expr/http_endpoint.go +++ b/expr/http_endpoint.go @@ -66,6 +66,10 @@ type ( MultipartRequest bool // Redirect defines a redirect for the endpoint. Redirect *HTTPRedirectExpr + // SSE defines the Server-Sent Events configuration for this endpoint if it's + // a streaming endpoint. If nil, the endpoint uses WebSockets by default or + // inherits the service-level SSE configuration if defined. + SSE *HTTPSSEExpr // Meta is a set of key/value pairs with semantic that is // specific to each generator, see dsl.Meta. Meta MetaExpr @@ -274,6 +278,15 @@ func (e *HTTPEndpointExpr) Prepare() { e.Responses = []*HTTPResponseExpr{{StatusCode: status}} } + // Inherit SSE configuration from service or API level for streaming endpoints + if e.MethodExpr.Stream == ServerStreamKind && e.SSE == nil { + if e.Service.SSE != nil { + e.SSE = e.Service.SSE + } else if Root.API.HTTP.SSE != nil { + e.SSE = Root.API.HTTP.SSE + } + } + // Error -> ResponseError methodErrors := map[string]struct{}{} for _, v := range e.HTTPErrors { @@ -377,6 +390,23 @@ func (e *HTTPEndpointExpr) Validate() error { } } + // Validate streaming endpoints for SSE compatibility + if e.MethodExpr.Stream == ServerStreamKind { + // Prepare already handles inheriting SSE from service or API level + if e.SSE != nil { + verr.Merge(e.SSE.Validate(e).(*eval.ValidationErrors)) + } + } else if e.SSE != nil { + // Error if SSE is defined but endpoint is not server streaming + if e.MethodExpr.Stream == BidirectionalStreamKind { + verr.Add(e, "Server-Sent Events cannot be used with bidirectional streaming endpoints") + } else if e.MethodExpr.Stream == ClientStreamKind { + verr.Add(e, "Server-Sent Events cannot be used with client-to-server streaming endpoints") + } else { + verr.Add(e, "Server-Sent Events can only be used with endpoints that have a streaming result") + } + } + // Redirect is not compatible with Response. if e.Redirect != nil { found := false diff --git a/expr/http_service.go b/expr/http_service.go index 1944ed929f..1e1a1799f3 100644 --- a/expr/http_service.go +++ b/expr/http_service.go @@ -40,6 +40,9 @@ type ( HTTPErrors []*HTTPErrorExpr // FileServers is the list of static asset serving endpoints FileServers []*HTTPFileServerExpr + // SSE defines the Server-Sent Events configuration for all streaming endpoints + // in this service. If nil, streaming endpoints use WebSockets by default. + SSE *HTTPSSEExpr // Meta is a set of key/value pairs with semantic that is // specific to each generator. Meta MetaExpr diff --git a/expr/http_sse.go b/expr/http_sse.go new file mode 100644 index 0000000000..3b7ddfcdff --- /dev/null +++ b/expr/http_sse.go @@ -0,0 +1,103 @@ +package expr + +import ( + "fmt" + "strings" + + "slices" + + "goa.design/goa/v3/eval" +) + +type ( + // HTTPSSEExpr describes a Server-Sent Events configuration for a HTTP endpoint. + // It defines how a streaming endpoint should use the Server-Sent Events protocol + // instead of WebSockets. + HTTPSSEExpr struct { + // RequestIDField is the name of the attribute in the Payload type + // that provides the Last-Event-ID request header value. + // If empty, no Last-Event-ID request header is included in the request. + RequestIDField string + // DataField is the name of the attribute in the StreamingResult type + // that provides the data field for a Server-Sent Event. + // If empty, the entire StreamingResult is used as the data field. + DataField string + // IDField is the name of the attribute in the StreamingResult type + // that provides the id field for a Server-Sent Event. + // If empty, no id field is included in the event. + IDField string + // EventField is the name of the attribute in the StreamingResult type + // that provides the event field (event type) for a Server-Sent Event. + // If empty, no event field is included in the event. + EventField string + // RetryField is the name of the attribute in the StreamingResult type + // that provides the retry field for a Server-Sent Event. + // If empty, no retry field is included in the event. + RetryField string + } +) + +// EvalName returns the generic expression name used in error messages. +func (e *HTTPSSEExpr) EvalName() string { + return "Server-Sent Events" +} + +// Validate validates the Server-Sent Events expression against a specific result type. +func (e *HTTPSSEExpr) Validate(endpoint *HTTPEndpointExpr) error { + if endpoint == nil || endpoint.MethodExpr == nil || endpoint.MethodExpr.Result == nil { + return nil + } + + verr := new(eval.ValidationErrors) + if err := validateSSEField(endpoint.MethodExpr.Payload, e.RequestIDField, "request ID", []DataType{String}); err != nil { + verr.Add(endpoint, err.Error()) + } + if err := validateSSEField(endpoint.MethodExpr.Result, e.DataField, "event data", nil); err != nil { + verr.Add(endpoint, err.Error()) + } + if err := validateSSEField(endpoint.MethodExpr.Result, e.IDField, "event id", []DataType{String}); err != nil { + verr.Add(endpoint, err.Error()) + } + if err := validateSSEField(endpoint.MethodExpr.Result, e.EventField, "event type", []DataType{String}); err != nil { + verr.Add(endpoint, err.Error()) + } + if err := validateSSEField(endpoint.MethodExpr.Result, e.RetryField, "event retry", []DataType{Int, Int32, Int64, UInt, UInt32, UInt64}); err != nil { + verr.Add(endpoint, err.Error()) + } + + if len(verr.Errors) == 0 { + return nil + } + return verr +} + +// validateSSEField validates that the given field exists in the result type and has the expected type. +func validateSSEField(rt *AttributeExpr, field, desc string, expectedTypes []DataType) error { + if field == "" { + return nil + } + + if rt == nil { + return fmt.Errorf("cannot use %q for SSE %s field: result type is nil", field, desc) + } + + obj := AsObject(rt.Type) + if obj == nil { + return fmt.Errorf("cannot use %q for SSE %s field: result type is not an object", field, desc) + } + + att := obj.Attribute(field) + if att == nil { + return fmt.Errorf("cannot use %q for SSE %s field: attribute not found in result type", field, desc) + } + + if len(expectedTypes) > 0 && !slices.Contains(expectedTypes, att.Type) { + typeNames := make([]string, len(expectedTypes)) + for i, t := range expectedTypes { + typeNames[i] = t.Name() + } + return fmt.Errorf("cannot use %q for SSE %s field: attribute type must be one of %s", field, desc, strings.Join(typeNames, ", ")) + } + + return nil +} diff --git a/expr/http_sse_test.go b/expr/http_sse_test.go new file mode 100644 index 0000000000..da731f2f9d --- /dev/null +++ b/expr/http_sse_test.go @@ -0,0 +1,175 @@ +package expr_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "goa.design/goa/v3/expr" +) + +func TestHTTPSSEExprValidation(t *testing.T) { + cases := map[string]struct { + SSE *expr.HTTPSSEExpr + Payload *expr.AttributeExpr + Result *expr.AttributeExpr + ExpectedErrs []string + }{ + "valid-empty": { + SSE: &expr.HTTPSSEExpr{}, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{Type: &expr.Object{}}, + }, + "valid-with-fields": { + SSE: &expr.HTTPSSEExpr{ + RequestIDField: "request_id", + DataField: "data", + IDField: "id", + EventField: "event", + RetryField: "retry", + }, + Payload: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "request_id", Attribute: &expr.AttributeExpr{Type: expr.String}}, + }, + }, + Result: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "data", Attribute: &expr.AttributeExpr{Type: expr.String}}, + &expr.NamedAttributeExpr{Name: "id", Attribute: &expr.AttributeExpr{Type: expr.String}}, + &expr.NamedAttributeExpr{Name: "event", Attribute: &expr.AttributeExpr{Type: expr.String}}, + &expr.NamedAttributeExpr{Name: "retry", Attribute: &expr.AttributeExpr{Type: expr.Int}}, + }, + }, + }, + "invalid-id-field-type": { + SSE: &expr.HTTPSSEExpr{ + IDField: "id", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "id", Attribute: &expr.AttributeExpr{Type: expr.Int}}, + }, + }, + ExpectedErrs: []string{"cannot use \"id\" for SSE event id field: attribute type must be one of string"}, + }, + "invalid-request-id-field-type": { + SSE: &expr.HTTPSSEExpr{ + RequestIDField: "request_id", + }, + Payload: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "request_id", Attribute: &expr.AttributeExpr{Type: expr.Int}}, + }, + }, + Result: &expr.AttributeExpr{Type: &expr.Object{}}, + ExpectedErrs: []string{"cannot use \"request_id\" for SSE request ID field: attribute type must be one of string"}, + }, + "invalid-event-field-type": { + SSE: &expr.HTTPSSEExpr{ + EventField: "event", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "event", Attribute: &expr.AttributeExpr{Type: expr.Int}}, + }, + }, + ExpectedErrs: []string{"cannot use \"event\" for SSE event type field: attribute type must be one of string"}, + }, + "invalid-retry-field-type": { + SSE: &expr.HTTPSSEExpr{ + RetryField: "retry", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "retry", Attribute: &expr.AttributeExpr{Type: expr.Boolean}}, + }, + }, + ExpectedErrs: []string{"cannot use \"retry\" for SSE event retry field: attribute type must be one of int, int32, int64, uint, uint32, uint64"}, + }, + "missing-field": { + SSE: &expr.HTTPSSEExpr{ + DataField: "missing", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "data", Attribute: &expr.AttributeExpr{Type: expr.String}}, + }, + }, + ExpectedErrs: []string{"cannot use \"missing\" for SSE event data field: attribute not found in result type"}, + }, + "missing-request-id-field": { + SSE: &expr.HTTPSSEExpr{ + RequestIDField: "missing", + }, + Payload: &expr.AttributeExpr{ + Type: &expr.Object{ + &expr.NamedAttributeExpr{Name: "request_id", Attribute: &expr.AttributeExpr{Type: expr.String}}, + }, + }, + Result: &expr.AttributeExpr{Type: &expr.Object{}}, + ExpectedErrs: []string{"cannot use \"missing\" for SSE request ID field: attribute not found in result type"}, + }, + "nil-result-type": { + SSE: &expr.HTTPSSEExpr{ + DataField: "data", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + }, + "nil-payload-type": { + SSE: &expr.HTTPSSEExpr{ + RequestIDField: "request_id", + }, + }, + "empty-result-type": { + SSE: &expr.HTTPSSEExpr{ + DataField: "data", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{Type: expr.Int}, + ExpectedErrs: []string{"cannot use \"data\" for SSE event data field: result type is not an object"}, + }, + "non-object-result-type": { + SSE: &expr.HTTPSSEExpr{ + DataField: "data", + }, + Payload: &expr.AttributeExpr{Type: &expr.Object{}}, + Result: &expr.AttributeExpr{Type: expr.String}, + ExpectedErrs: []string{"cannot use \"data\" for SSE event data field: result type is not an object"}, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + // Create a mock HTTP endpoint expression for validation + methodExpr := &expr.MethodExpr{ + Name: "TestMethod", + Payload: tc.Payload, + Result: tc.Result, + Stream: expr.ServerStreamKind, // Must be a streaming method for SSE + } + endpoint := &expr.HTTPEndpointExpr{ + MethodExpr: methodExpr, + SSE: tc.SSE, + } + + // Run validation + err := tc.SSE.Validate(endpoint) + + // Check results + if len(tc.ExpectedErrs) == 0 { + require.NoError(t, err, "expected no error") + } else { + require.Error(t, err, "expected error, got none") + for _, expected := range tc.ExpectedErrs { + assert.Contains(t, err.Error(), expected, "error should contain expected message") + } + } + }) + } +} From afb7a838f81e80920a600f557028fc56bb9d6909 Mon Sep 17 00:00:00 2001 From: Raphael Simon Date: Sat, 12 Apr 2025 15:14:22 -0700 Subject: [PATCH 2/4] Add server side streaming This commit adds server side rendering of SSE responses. TBD: * Handling the incoming Last-Request-Id * Client side handling --- dsl/sse.go | 64 ++++-- expr/http_endpoint.go | 24 +- http/codegen/example_cli_test.go | 2 +- http/codegen/example_server_test.go | 2 +- http/codegen/server.go | 10 +- http/codegen/service_data.go | 4 + http/codegen/sse.go | 210 ++++++++++++++++++ http/codegen/sse_server_test.go | 41 ++++ .../templates/partial/sse_format.go.tpl | 21 ++ .../templates/server_handler_init.go.tpl | 36 ++- http/codegen/templates/server_sse.go.tpl | 104 +++++++++ .../{ => golden}/client-no-server.golden | 0 ...nt-server-hosting-multiple-services.golden | 0 ...lient-server-hosting-service-subset.golden | 0 .../client-streaming-multiple-services.golden | 0 .../{ => golden}/client-streaming.golden | 0 .../{ => golden}/server-no-server.golden | 0 ...er-server-hosting-multiple-services.golden | 0 ...erver-server-hosting-service-subset.golden | 0 ...er-hosting-service-with-file-server.golden | 0 .../{ => golden}/server-streaming.golden | 0 .../testdata/golden/sse-all-fields.golden | 77 +++++++ http/codegen/testdata/golden/sse-bool.golden | 58 +++++ .../testdata/golden/sse-data-field.golden | 62 ++++++ .../testdata/golden/sse-data-id-field.golden | 67 ++++++ http/codegen/testdata/golden/sse-int.golden | 54 +++++ .../codegen/testdata/golden/sse-object.golden | 60 +++++ .../testdata/golden/sse-request-id.golden | 55 +++++ .../codegen/testdata/golden/sse-string.golden | 55 +++++ http/codegen/testdata/sse_dsls.go | 142 ++++++++++++ http/codegen/websocket.go | 3 + 31 files changed, 1116 insertions(+), 35 deletions(-) create mode 100644 http/codegen/sse.go create mode 100644 http/codegen/sse_server_test.go create mode 100644 http/codegen/templates/partial/sse_format.go.tpl create mode 100644 http/codegen/templates/server_sse.go.tpl rename http/codegen/testdata/{ => golden}/client-no-server.golden (100%) rename http/codegen/testdata/{ => golden}/client-server-hosting-multiple-services.golden (100%) rename http/codegen/testdata/{ => golden}/client-server-hosting-service-subset.golden (100%) rename http/codegen/testdata/{ => golden}/client-streaming-multiple-services.golden (100%) rename http/codegen/testdata/{ => golden}/client-streaming.golden (100%) rename http/codegen/testdata/{ => golden}/server-no-server.golden (100%) rename http/codegen/testdata/{ => golden}/server-server-hosting-multiple-services.golden (100%) rename http/codegen/testdata/{ => golden}/server-server-hosting-service-subset.golden (100%) rename http/codegen/testdata/{ => golden}/server-server-hosting-service-with-file-server.golden (100%) rename http/codegen/testdata/{ => golden}/server-streaming.golden (100%) create mode 100644 http/codegen/testdata/golden/sse-all-fields.golden create mode 100644 http/codegen/testdata/golden/sse-bool.golden create mode 100644 http/codegen/testdata/golden/sse-data-field.golden create mode 100644 http/codegen/testdata/golden/sse-data-id-field.golden create mode 100644 http/codegen/testdata/golden/sse-int.golden create mode 100644 http/codegen/testdata/golden/sse-object.golden create mode 100644 http/codegen/testdata/golden/sse-request-id.golden create mode 100644 http/codegen/testdata/golden/sse-string.golden create mode 100644 http/codegen/testdata/sse_dsls.go diff --git a/dsl/sse.go b/dsl/sse.go index 66f752e70e..07482eca6c 100644 --- a/dsl/sse.go +++ b/dsl/sse.go @@ -7,11 +7,16 @@ import ( // ServerSentEvents specifies that a streaming endpoint should use the // Server-Sent Events protocol for streaming instead of WebSockets. It can be -// used in three ways: +// used in four ways: // -// 1. ServerSentEvents() - StreamingResult type is used directly as the "data" field -// 2. ServerSentEvents("attributeName") - The specified attribute is used as the "data" field -// 3. ServerSentEvents(func() { ... }) - Custom mapping of attributes to SSE fields +// 1. ServerSentEvents(): StreamingResult type is used directly as the event +// "data" field (serialized into JSON if not a primitive type) +// 2. ServerSentEvents("attributeName"): The specified attribute is used as the +// event "data" field (serialized into JSON if not a primitive type) +// 3. ServerSentEvents(func() { ... }): Custom mapping of attributes to event +// fields +// 4. ServerSentEvents("attributeName", func() { ... }): Define attribute name +// used as the "data" field and custom mapping for others. // // ServerSentEvents can appear in an API HTTP expression (to specify SSE for all streaming // methods in the API), in a Service HTTP expression (to specify SSE for all streaming @@ -19,13 +24,14 @@ import ( // API or service level, any method with a StreamingPayload will fall back to using WebSockets // as SSE only supports server-to-client streaming. // -// See SSEData, SSEID, SSEType, SSERetry for more details on mapping result attributes -// to SSE fields. +// See SSEEventData, SSEEventID, SSEEventType, SSEEventRetry for more details on +// mapping result attributes to event fields. See SSERequestID for more details on +// mapping payload attributes to the Last-Event-ID request header. // // Example: // // var Notification = Type("Notification", func() { -// Attribute("message", String)message +// Attribute("message", String) // Attribute("timestamp", String) // Required("message", "timestamp") // }) @@ -69,20 +75,40 @@ import ( // }) // }) // }) -func ServerSentEvents(val any) { +func ServerSentEvents(args ...any) { + if len(args) > 2 { + eval.TooManyArgError() + return + } + if len(args) == 2 { + if _, ok := args[1].(func()); !ok { + eval.InvalidArgError("function", args[1]) + return + } + } + var fn func() var dataField string - - switch actual := val.(type) { - case func(): - fn = actual - case string: - dataField = actual - case nil: - // Use the entire result as data field - default: - eval.InvalidArgError("function or string", val) - return + if len(args) > 0 { + switch actual := args[0].(type) { + case func(): + fn = actual + case string: + dataField = actual + case nil: + // Use the entire result as data field + default: + eval.InvalidArgError("function or string", args[0]) + return + } + if len(args) == 2 { + var ok bool + fn, ok = args[1].(func()) + if !ok { + eval.InvalidArgError("function", args[1]) + return + } + } } sse := &expr.HTTPSSEExpr{ diff --git a/expr/http_endpoint.go b/expr/http_endpoint.go index 6dc4f46941..4ce5fdba2b 100644 --- a/expr/http_endpoint.go +++ b/expr/http_endpoint.go @@ -1,6 +1,7 @@ package expr import ( + "errors" "fmt" "path" "strings" @@ -394,15 +395,21 @@ func (e *HTTPEndpointExpr) Validate() error { if e.MethodExpr.Stream == ServerStreamKind { // Prepare already handles inheriting SSE from service or API level if e.SSE != nil { - verr.Merge(e.SSE.Validate(e).(*eval.ValidationErrors)) + if err := e.SSE.Validate(e); err != nil { + var valErr *eval.ValidationErrors + if errors.As(err, &valErr) { + verr.Merge(valErr) + } + } } } else if e.SSE != nil { // Error if SSE is defined but endpoint is not server streaming - if e.MethodExpr.Stream == BidirectionalStreamKind { + switch e.MethodExpr.Stream { + case BidirectionalStreamKind: verr.Add(e, "Server-Sent Events cannot be used with bidirectional streaming endpoints") - } else if e.MethodExpr.Stream == ClientStreamKind { + case ClientStreamKind: verr.Add(e, "Server-Sent Events cannot be used with client-to-server streaming endpoints") - } else { + default: verr.Add(e, "Server-Sent Events can only be used with endpoints that have a streaming result") } } @@ -673,10 +680,13 @@ func (e *HTTPEndpointExpr) Validate() error { if e.SkipRequestBodyEncodeDecode && body.Type != Empty { verr.Add(e, "HTTP endpoint request body must be empty when using SkipRequestBodyEncodeDecode but not all method payload attributes are mapped to headers and params. Make sure to define Headers and Params as needed.") } + // For streaming endpoints, check if request body is allowed if e.MethodExpr.IsStreaming() && body.Type != Empty { - // Refer Websocket protocol - https://tools.ietf.org/html/rfc6455 - // Protocol does not allow HTTP request body to be passed. - verr.Add(e, "HTTP endpoint request body must be empty when the endpoint uses streaming. Payload attributes must be mapped to headers and/or params.") + // SSE endpoints can have request bodies, but WebSocket endpoints cannot + // Refer WebSocket protocol - https://tools.ietf.org/html/rfc6455 + if e.SSE == nil { // Only apply this validation to non-SSE streaming endpoints + verr.Add(e, "HTTP endpoint request body must be empty when the endpoint uses streaming. Payload attributes must be mapped to headers and/or params.") + } } return verr diff --git a/http/codegen/example_cli_test.go b/http/codegen/example_cli_test.go index c6de2bae7f..6b12fea1f1 100644 --- a/http/codegen/example_cli_test.go +++ b/http/codegen/example_cli_test.go @@ -38,7 +38,7 @@ func TestExampleCLIFiles(t *testing.T) { require.NoError(t, s.Write(&buf)) } code := codegen.FormatTestCode(t, "package foo\n"+buf.String()) - golden := filepath.Join("testdata", "client-"+c.Name+".golden") + golden := filepath.Join("testdata", "golden", "client-"+c.Name+".golden") compareOrUpdateGolden(t, code, golden) }) } diff --git a/http/codegen/example_server_test.go b/http/codegen/example_server_test.go index ffd6e1e7be..e0f61014d3 100644 --- a/http/codegen/example_server_test.go +++ b/http/codegen/example_server_test.go @@ -106,7 +106,7 @@ func TestExampleServerFiles(t *testing.T) { require.NoError(t, s.Write(&buf)) } code := codegen.FormatTestCode(t, "package foo\n"+buf.String()) - golden := filepath.Join("testdata", "server-"+c.Name+".golden") + golden := filepath.Join("testdata", "golden", "server-"+c.Name+".golden") compareOrUpdateGolden(t, code, golden) }) } diff --git a/http/codegen/server.go b/http/codegen/server.go index c514baa223..93d8428d13 100644 --- a/http/codegen/server.go +++ b/http/codegen/server.go @@ -20,6 +20,9 @@ func ServerFiles(genpkg string, root *expr.RootExpr) []*codegen.File { if f := websocketServerFile(genpkg, svc); f != nil { files = append(files, f) } + if f := sseServerFile(genpkg, svc); f != nil { + files = append(files, f) + } } for _, svc := range root.API.HTTP.Services { if f := serverEncodeDecodeFile(genpkg, svc); f != nil { @@ -39,6 +42,7 @@ func serverFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { "join": strings.Join, "hasWebSocket": hasWebSocket, "isWebSocketEndpoint": isWebSocketEndpoint, + "isSSEEndpoint": isSSEEndpoint, "viewedServerBody": viewedServerBody, "mustDecodeRequest": mustDecodeRequest, "addLeadingSlash": addLeadingSlash, @@ -266,10 +270,10 @@ func viewedServerBody(sbd []*TypeData, view string) *TypeData { } func addLeadingSlash(s string) string { - if strings.HasPrefix(s, "/") { - return s + if s == "" || s[0] != '/' { + return "/" + s } - return "/" + s + return s } func mapQueryDecodeData(dt expr.DataType, varName string, inc int) map[string]any { diff --git a/http/codegen/service_data.go b/http/codegen/service_data.go index 43ef37e932..cf7764ddf6 100644 --- a/http/codegen/service_data.go +++ b/http/codegen/service_data.go @@ -152,6 +152,9 @@ type ( // ServerWebSocket holds the data to render the server struct which // implements the server stream interface. ServerWebSocket *WebSocketData + // SSE holds the data to render the server struct which implements the + // server stream interface for SSE. + SSE *SSEData // Redirect defines a redirect for the endpoint. Redirect *RedirectData @@ -838,6 +841,7 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { } if httpEndpoint.MethodExpr.IsStreaming() { initWebSocketData(ed, httpEndpoint, sd) + initSSEData(ed, httpEndpoint, sd) } if httpEndpoint.MultipartRequest { diff --git a/http/codegen/sse.go b/http/codegen/sse.go new file mode 100644 index 0000000000..32f54ea6cd --- /dev/null +++ b/http/codegen/sse.go @@ -0,0 +1,210 @@ +package codegen + +import ( + "fmt" + "path/filepath" + + "goa.design/goa/v3/codegen" + "goa.design/goa/v3/expr" +) + +type ( + // SSEData contains the data needed to render struct type that + // implements the server stream interface for SSE. + SSEData struct { + // VarName is the name of the struct. + VarName string + // Interface is the fully qualified name of the interface that + // the struct implements. + Interface string + // Endpoint is endpoint data that defines streaming result. + Endpoint *EndpointData + // Response is the successful response data for the streaming + // endpoint. + Response *ResponseData + // SendName is the name of the send function. + SendName string + // SendDesc is the description for the send function. + SendDesc string + // SendWithContextName is the name of the send function with context. + SendWithContextName string + // SendWithContextDesc is the description for the send function with context. + SendWithContextDesc string + // SendTypeName is the fully qualified type name sent through + // the stream. + SendTypeName string + // SendTypeRef is the fully qualified type ref sent through the + // stream. + SendTypeRef string + + // PkgName is the service package name. + PkgName string + // SSEConfig contains the SSE configuration for this endpoint. + SSEConfig *expr.HTTPSSEExpr + // WriteHeaderName is the name of the WriteHeader function. + WriteHeaderName string + // WriteHeaderDesc is the description for the WriteHeader function. + WriteHeaderDesc string + + // DataFieldType is the type of the data field if SSEConfig.DataField is set. + // It's computed during initialization to avoid complex template logic. + DataFieldType expr.DataType + // ResultType is the type of the result. + ResultType expr.DataType + } +) + +// initSSEData initializes the SSE related data in ed. +func initSSEData(ed *EndpointData, e *expr.HTTPEndpointExpr, sd *ServiceData) { + if e.SSE == nil { + return + } + + md := ed.Method + svc := sd.Service + svrSendTypeName := ed.Result.Name + svrSendTypeRef := ed.Result.Ref + svrSendDesc := fmt.Sprintf("%s streams instances of %q to the %q endpoint SSE connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) + svrSendWithContextDesc := fmt.Sprintf("%s streams instances of %q to the %q endpoint SSE connection with context.", md.ServerStream.SendWithContextName, svrSendTypeName, md.Name) + writeHeaderDesc := fmt.Sprintf("%s writes the given header to the HTTP response.", "WriteHeader") + + // Set the result type for use in the template + var resultType expr.DataType + if e.MethodExpr != nil && e.MethodExpr.Result != nil { + resultType = e.MethodExpr.Result.Type + } + + // Compute the data field type if a data field is specified + var dataFieldType expr.DataType + if e.SSE.DataField != "" && resultType != nil { + // If the result type is an object and has the data field, extract its type + if obj, ok := resultType.(*expr.Object); ok { + for _, nat := range *obj { + if nat.Name == e.SSE.DataField { + dataFieldType = nat.Attribute.Type + break + } + } + } + } + + // Create SSE data for server + ed.SSE = &SSEData{ + VarName: md.ServerStream.VarName, + Interface: fmt.Sprintf("%s.%s", svc.PkgName, md.ServerStream.Interface), + Endpoint: ed, + Response: ed.Result.Responses[0], + PkgName: svc.PkgName, + SendName: md.ServerStream.SendName, + SendDesc: svrSendDesc, + SendWithContextName: md.ServerStream.SendWithContextName, + SendWithContextDesc: svrSendWithContextDesc, + SendTypeName: svrSendTypeName, + SendTypeRef: svrSendTypeRef, + SSEConfig: e.SSE, + WriteHeaderName: "WriteHeader", + WriteHeaderDesc: writeHeaderDesc, + DataFieldType: dataFieldType, + ResultType: resultType, + } +} + +// We don't need the getPrimitiveFormatString function anymore +// since we're using a partial template for formatting + +// sseServerFile returns the file implementing the SSE server +// streaming implementation if any. +func sseServerFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { + data := HTTPServices.Get(svc.Name()) + if data == nil { + return nil + } + + // Check if any endpoint has SSE + hasSSE := false + for _, ed := range data.Endpoints { + if ed.SSE != nil { + hasSSE = true + break + } + } + if !hasSSE { + return nil + } + + path := filepath.Join(codegen.Gendir, "http", codegen.SnakeCase(svc.Name()), "server", "sse.go") + sections := []*codegen.SectionTemplate{ + codegen.Header( + "sse", + "server", + []*codegen.ImportSpec{ + {Path: "context"}, + {Path: "io"}, + {Path: "net/http"}, + {Path: "sync"}, + {Path: "time"}, + {Path: "encoding/json"}, + {Path: "fmt"}, + {Path: "goa.design/goa/v3/http"}, + {Path: genpkg + "/" + codegen.SnakeCase(svc.Name())}, + {Path: genpkg + "/" + codegen.SnakeCase(svc.Name()) + "/views"}, + }, + ), + } + sections = append(sections, sseTemplateSections(data)...) + return &codegen.File{Path: path, SectionTemplates: sections} +} + +// sseTemplateSections returns section templates for SSE endpoints. +func sseTemplateSections(data *ServiceData) []*codegen.SectionTemplate { + sections := make([]*codegen.SectionTemplate, 0) + for _, ed := range data.Endpoints { + if ed.SSE == nil { + continue + } + // Create a map of template functions needed for the SSE template + funcs := map[string]interface{}{ + "add": func(a, b int) int { return a + b }, + "dict": func(values ...interface{}) (map[string]interface{}, error) { + if len(values)%2 != 0 { + return nil, fmt.Errorf("odd number of arguments") + } + dict := make(map[string]interface{}, len(values)/2) + for i := 0; i < len(values); i += 2 { + key, ok := values[i].(string) + if !ok { + return nil, fmt.Errorf("dict keys must be strings") + } + dict[key] = values[i+1] + } + return dict, nil + }, + "AsObject": func(dt expr.DataType) map[string]interface{} { + if obj, ok := dt.(*expr.Object); ok { + result := make(map[string]interface{}) + for _, nat := range *obj { + result[nat.Name] = map[string]interface{}{ + "Attribute": nat.Attribute, + "Name": nat.Name, + } + } + return result + } + return nil + }, + } + sections = append(sections, &codegen.SectionTemplate{ + Name: "server-sse", + Source: readTemplate("server_sse", "sse_format"), + Data: ed.SSE, + FuncMap: funcs, + }) + } + return sections +} + +// isSSEEndpoint returns true if the endpoint defines a streaming result +// with SSE. +func isSSEEndpoint(ed *EndpointData) bool { + return ed.SSE != nil +} diff --git a/http/codegen/sse_server_test.go b/http/codegen/sse_server_test.go new file mode 100644 index 0000000000..118a3b07df --- /dev/null +++ b/http/codegen/sse_server_test.go @@ -0,0 +1,41 @@ +package codegen + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "goa.design/goa/v3/codegen" + "goa.design/goa/v3/expr" + "goa.design/goa/v3/http/codegen/testdata" +) + +func TestSSE(t *testing.T) { + cases := []struct { + Name string + DSL func() + }{ + {"string", testdata.SSEStringDSL}, + {"int", testdata.SSEIntDSL}, + {"bool", testdata.SSEBoolDSL}, + {"object", testdata.SSEObjectDSL}, + {"data-field", testdata.SSEDataFieldDSL}, + {"data-id-field", testdata.SSEDataIDFieldDSL}, + {"request-id", testdata.SSERequestIDDSL}, + {"all-fields", testdata.SSEAllFieldsDSL}, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + RunHTTPDSL(t, c.DSL) + fs := ServerFiles("", expr.Root) + require.Len(t, fs, 3) + sections := fs[1].SectionTemplates + require.Greater(t, len(sections), 1) + code := codegen.SectionCode(t, sections[1]) + golden := filepath.Join("testdata", "golden", "sse-"+c.Name+".golden") + compareOrUpdateGolden(t, code, golden) + }) + } +} diff --git a/http/codegen/templates/partial/sse_format.go.tpl b/http/codegen/templates/partial/sse_format.go.tpl new file mode 100644 index 0000000000..37d33acdd7 --- /dev/null +++ b/http/codegen/templates/partial/sse_format.go.tpl @@ -0,0 +1,21 @@ +{{- if eq .Type.Name "string" }} + data = {{ .VarName }} +{{- else if eq .Type.Name "boolean" }} + if {{ .VarName }} { + data = "true" + } else { + data = "false" + } +{{- else if eq .Type.Name "bytes" }} + data = string({{ .VarName }}) +{{- else if or (eq .Type.Name "int") (eq .Type.Name "int32") (eq .Type.Name "int64") (eq .Type.Name "uint") (eq .Type.Name "uint32") (eq .Type.Name "uint64") }} + data = fmt.Sprintf("%d", {{ .VarName }}) +{{- else if or (eq .Type.Name "float32") (eq .Type.Name "float64") }} + data = fmt.Sprintf("%g", {{ .VarName }}) +{{- else }} + byts, err := json.Marshal({{ .VarName }}) + if err != nil { + return err + } + data = string(byts) +{{- end }} \ No newline at end of file diff --git a/http/codegen/templates/server_handler_init.go.tpl b/http/codegen/templates/server_handler_init.go.tpl index fdd3a452b1..36c39fc0b1 100644 --- a/http/codegen/templates/server_handler_init.go.tpl +++ b/http/codegen/templates/server_handler_init.go.tpl @@ -11,19 +11,19 @@ func {{ .HandlerInit }}( configurer goahttp.ConnConfigureFunc, {{- end }} ) http.Handler { - {{- if (or (mustDecodeRequest .) (not (or .Redirect (isWebSocketEndpoint .))) (not .Redirect) .Method.SkipResponseBodyEncodeDecode) }} + {{- if (or (mustDecodeRequest .) (not (or .Redirect (isWebSocketEndpoint .) (isSSEEndpoint .))) (not .Redirect) .Method.SkipResponseBodyEncodeDecode) }} var ( {{- end }} {{- if mustDecodeRequest . }} decodeRequest = {{ .RequestDecoder }}(mux, decoder) {{- end }} - {{- if not (or .Redirect (isWebSocketEndpoint .)) }} + {{- if not (or .Redirect (isWebSocketEndpoint .) (isSSEEndpoint .)) }} encodeResponse = {{ .ResponseEncoder }}(encoder) {{- end }} {{- if (or (mustDecodeRequest .) (not .Redirect) .Method.SkipResponseBodyEncodeDecode) }} encodeError = {{ if .Errors }}{{ .ErrorEncoder }}{{ else }}goahttp.ErrorEncoder{{ end }}(encoder, formatter) {{- end }} - {{- if (or (mustDecodeRequest .) (not (or .Redirect (isWebSocketEndpoint .))) (not .Redirect) .Method.SkipResponseBodyEncodeDecode) }} + {{- if (or (mustDecodeRequest .) (not (or .Redirect (isWebSocketEndpoint .) (isSSEEndpoint .))) (not .Redirect) .Method.SkipResponseBodyEncodeDecode) }} ) {{- end }} return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -58,6 +58,30 @@ func {{ .HandlerInit }}( {{- end }} } _, err = endpoint(ctx, v) + {{- else if isSSEEndpoint . }} + {{- if .SSE.SSEConfig.RequestIDField }} + // Set Last-Event-ID header if present + if lastEventID := r.Header.Get("Last-Event-ID"); lastEventID != "" { + ctx = context.WithValue(ctx, "last-event-id", lastEventID) + {{- if .Payload.Ref }} + {{- if eq .Method.Payload.Type.Name "Object" }} + p := payload.({{ .Payload.Ref }}) + p.{{ .SSE.SSEConfig.RequestIDField }} = lastEventID + payload = p + {{- end }} + {{- end }} + } + {{- end }} + v := &{{ .ServicePkgName }}.{{ .Method.ServerStream.EndpointStruct }}{ + Stream: &{{ .SSE.VarName }}{ + w: w, + r: r, + }, + {{- if .Payload.Ref }} + Payload: payload.({{ .Payload.Ref }}), + {{- end }} + } + _, err = endpoint(ctx, v) {{- else if .Method.SkipRequestBodyEncodeDecode }} data := &{{ .ServicePkgName }}.{{ .Method.RequestStruct }}{ {{ if .Payload.Ref }}Payload: payload.({{ .Payload.Ref }}), {{ end }}Body: r.Body } res, err := endpoint(ctx, data) @@ -75,6 +99,10 @@ func {{ .HandlerInit }}( return } {{- end }} + {{- if isSSEEndpoint . }} + // For SSE, we need to set appropriate error headers + w.Header().Set("Content-Type", "application/json") + {{- end }} if err := encodeError(ctx, w, err); err != nil { errhandler(ctx, w, err) } @@ -115,7 +143,7 @@ func {{ .HandlerInit }}( return } {{- end }} - {{- if not (or .Redirect (isWebSocketEndpoint .)) }} + {{- if not (or .Redirect (isWebSocketEndpoint .) (isSSEEndpoint .)) }} if err := encodeResponse(ctx, w, {{ if and .Method.SkipResponseBodyEncodeDecode .Result.Ref }}o.Result{{ else }}res{{ end }}); err != nil { errhandler(ctx, w, err) {{- if .Method.SkipResponseBodyEncodeDecode }} diff --git a/http/codegen/templates/server_sse.go.tpl b/http/codegen/templates/server_sse.go.tpl new file mode 100644 index 0000000000..31c3c1c2c7 --- /dev/null +++ b/http/codegen/templates/server_sse.go.tpl @@ -0,0 +1,104 @@ +{{ printf "%s implements the %s interface using Server-Sent Events." .VarName .Interface | comment }} +type {{ .VarName }} struct { + {{ printf "once ensures the headers are written once." | comment }} + once sync.Once + {{ printf "w is the HTTP response writer used to send the SSE events." | comment }} + w http.ResponseWriter + {{ printf "r is the HTTP request." | comment }} + r *http.Request +} + +{{ printf "%s %s" .SendName .SendDesc | comment }} +func (s *{{ .VarName }}) {{ .SendName }}(v {{ .SendTypeRef }}) error { + return s.{{ .SendWithContextName }}(context.Background(), v) +} + +{{ printf "%s %s" .SendWithContextName .SendWithContextDesc | comment }} +func (s *{{ .VarName }}) {{ .SendWithContextName }}(ctx context.Context, v {{ .SendTypeRef }}) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + + {{- if .Endpoint.Method.ViewedResult }} + {{- if .Endpoint.Method.ViewedResult.ViewName }} + res := {{ .PkgName }}.{{ .Endpoint.Method.ViewedResult.Init.Name }}(v, {{ printf "%q" .Endpoint.Method.ViewedResult.ViewName }}) + {{- else }} + res := {{ .PkgName }}.{{ .Endpoint.Method.ViewedResult.Init.Name }}(v, "default") + {{- end }} + {{- else }} + res := v + {{- end }} + + {{ if .SSEConfig.IDField }} + id := res.{{ .SSEConfig.IDField }} + if id != "" { + fmt.Fprintf(s.w, "id: %s\n", id) + } + {{- end }} + + {{ if .SSEConfig.EventField }} + eventType := res.{{ .SSEConfig.EventField }} + if eventType != "" { + fmt.Fprintf(s.w, "event: %s\n", eventType) + } + {{- end }} + + {{ if .SSEConfig.RetryField }} + retry := res.{{ .SSEConfig.RetryField }} + if retry > 0 { + fmt.Fprintf(s.w, "retry: %d\n", retry) + } + {{- end }} + + {{ if .SSEConfig.DataField }} + var data string + dataField := res.{{ .SSEConfig.DataField }} + {{- if .DataFieldType }} + {{- template "partial_sse_format" dict "Type" .DataFieldType "VarName" "dataField" }} + {{- else }} + byts, err := json.Marshal(dataField) + if err != nil { + return err + } + data = string(byts) + {{- end }} + fmt.Fprintf(s.w, "data: %s\n\n", data) + {{- else }} + var data string + {{- if .ResultType }} + {{- template "partial_sse_format" dict "Type" .ResultType "VarName" "res" }} + {{- else }} + byts, err := json.Marshal(res) + if err != nil { + return err + } + data = string(byts) + {{- end }} + fmt.Fprintf(s.w, "data: %s\n\n", data) + {{- end }} + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +{{ printf "WriteHeader writes the given header to the HTTP response." | comment }} +func (s *{{ .VarName }}) {{ .WriteHeaderName }}(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/client-no-server.golden b/http/codegen/testdata/golden/client-no-server.golden similarity index 100% rename from http/codegen/testdata/client-no-server.golden rename to http/codegen/testdata/golden/client-no-server.golden diff --git a/http/codegen/testdata/client-server-hosting-multiple-services.golden b/http/codegen/testdata/golden/client-server-hosting-multiple-services.golden similarity index 100% rename from http/codegen/testdata/client-server-hosting-multiple-services.golden rename to http/codegen/testdata/golden/client-server-hosting-multiple-services.golden diff --git a/http/codegen/testdata/client-server-hosting-service-subset.golden b/http/codegen/testdata/golden/client-server-hosting-service-subset.golden similarity index 100% rename from http/codegen/testdata/client-server-hosting-service-subset.golden rename to http/codegen/testdata/golden/client-server-hosting-service-subset.golden diff --git a/http/codegen/testdata/client-streaming-multiple-services.golden b/http/codegen/testdata/golden/client-streaming-multiple-services.golden similarity index 100% rename from http/codegen/testdata/client-streaming-multiple-services.golden rename to http/codegen/testdata/golden/client-streaming-multiple-services.golden diff --git a/http/codegen/testdata/client-streaming.golden b/http/codegen/testdata/golden/client-streaming.golden similarity index 100% rename from http/codegen/testdata/client-streaming.golden rename to http/codegen/testdata/golden/client-streaming.golden diff --git a/http/codegen/testdata/server-no-server.golden b/http/codegen/testdata/golden/server-no-server.golden similarity index 100% rename from http/codegen/testdata/server-no-server.golden rename to http/codegen/testdata/golden/server-no-server.golden diff --git a/http/codegen/testdata/server-server-hosting-multiple-services.golden b/http/codegen/testdata/golden/server-server-hosting-multiple-services.golden similarity index 100% rename from http/codegen/testdata/server-server-hosting-multiple-services.golden rename to http/codegen/testdata/golden/server-server-hosting-multiple-services.golden diff --git a/http/codegen/testdata/server-server-hosting-service-subset.golden b/http/codegen/testdata/golden/server-server-hosting-service-subset.golden similarity index 100% rename from http/codegen/testdata/server-server-hosting-service-subset.golden rename to http/codegen/testdata/golden/server-server-hosting-service-subset.golden diff --git a/http/codegen/testdata/server-server-hosting-service-with-file-server.golden b/http/codegen/testdata/golden/server-server-hosting-service-with-file-server.golden similarity index 100% rename from http/codegen/testdata/server-server-hosting-service-with-file-server.golden rename to http/codegen/testdata/golden/server-server-hosting-service-with-file-server.golden diff --git a/http/codegen/testdata/server-streaming.golden b/http/codegen/testdata/golden/server-streaming.golden similarity index 100% rename from http/codegen/testdata/server-streaming.golden rename to http/codegen/testdata/golden/server-streaming.golden diff --git a/http/codegen/testdata/golden/sse-all-fields.golden b/http/codegen/testdata/golden/sse-all-fields.golden new file mode 100644 index 0000000000..4a50f7c0b9 --- /dev/null +++ b/http/codegen/testdata/golden/sse-all-fields.golden @@ -0,0 +1,77 @@ +// SSEAllFieldsMethodServerStream implements the +// sseallfieldsservice.SSEAllFieldsMethodServerStream interface using +// Server-Sent Events. +type SSEAllFieldsMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of +// "sseallfieldsservice.SSEAllFieldsMethodResult" to the "SSEAllFieldsMethod" +// endpoint SSE connection. +func (s *SSEAllFieldsMethodServerStream) Send(v *sseallfieldsservice.SSEAllFieldsMethodResult) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of +// "sseallfieldsservice.SSEAllFieldsMethodResult" to the "SSEAllFieldsMethod" +// endpoint SSE connection with context. +func (s *SSEAllFieldsMethodServerStream) SendWithContext(ctx context.Context, v *sseallfieldsservice.SSEAllFieldsMethodResult) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + id := res.id + if id != "" { + fmt.Fprintf(s.w, "id: %s\n", id) + } + + eventType := res.event + if eventType != "" { + fmt.Fprintf(s.w, "event: %s\n", eventType) + } + + retry := res.retry + if retry > 0 { + fmt.Fprintf(s.w, "retry: %d\n", retry) + } + + var data string + dataField := res.data + byts, err := json.Marshal(dataField) + if err != nil { + return err + } + data = string(byts) + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEAllFieldsMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-bool.golden b/http/codegen/testdata/golden/sse-bool.golden new file mode 100644 index 0000000000..8580e30421 --- /dev/null +++ b/http/codegen/testdata/golden/sse-bool.golden @@ -0,0 +1,58 @@ +// SSEBoolMethodServerStream implements the +// sseboolservice.SSEBoolMethodServerStream interface using Server-Sent Events. +type SSEBoolMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of "bool" to the "SSEBoolMethod" endpoint SSE +// connection. +func (s *SSEBoolMethodServerStream) Send(v bool) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of "bool" to the +// "SSEBoolMethod" endpoint SSE connection with context. +func (s *SSEBoolMethodServerStream) SendWithContext(ctx context.Context, v bool) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + var data string + if res { + data = "true" + } else { + data = "false" + } + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEBoolMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-data-field.golden b/http/codegen/testdata/golden/sse-data-field.golden new file mode 100644 index 0000000000..50d5dafd9e --- /dev/null +++ b/http/codegen/testdata/golden/sse-data-field.golden @@ -0,0 +1,62 @@ +// SSEDataFieldMethodServerStream implements the +// ssedatafieldservice.SSEDataFieldMethodServerStream interface using +// Server-Sent Events. +type SSEDataFieldMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of +// "ssedatafieldservice.SSEDataFieldMethodResult" to the "SSEDataFieldMethod" +// endpoint SSE connection. +func (s *SSEDataFieldMethodServerStream) Send(v *ssedatafieldservice.SSEDataFieldMethodResult) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of +// "ssedatafieldservice.SSEDataFieldMethodResult" to the "SSEDataFieldMethod" +// endpoint SSE connection with context. +func (s *SSEDataFieldMethodServerStream) SendWithContext(ctx context.Context, v *ssedatafieldservice.SSEDataFieldMethodResult) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + var data string + dataField := res.data + byts, err := json.Marshal(dataField) + if err != nil { + return err + } + data = string(byts) + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEDataFieldMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-data-id-field.golden b/http/codegen/testdata/golden/sse-data-id-field.golden new file mode 100644 index 0000000000..5afef3eda4 --- /dev/null +++ b/http/codegen/testdata/golden/sse-data-id-field.golden @@ -0,0 +1,67 @@ +// SSEDataIDFieldMethodServerStream implements the +// ssedataidfieldservice.SSEDataIDFieldMethodServerStream interface using +// Server-Sent Events. +type SSEDataIDFieldMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of +// "ssedataidfieldservice.SSEDataIDFieldMethodResult" to the +// "SSEDataIDFieldMethod" endpoint SSE connection. +func (s *SSEDataIDFieldMethodServerStream) Send(v *ssedataidfieldservice.SSEDataIDFieldMethodResult) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of +// "ssedataidfieldservice.SSEDataIDFieldMethodResult" to the +// "SSEDataIDFieldMethod" endpoint SSE connection with context. +func (s *SSEDataIDFieldMethodServerStream) SendWithContext(ctx context.Context, v *ssedataidfieldservice.SSEDataIDFieldMethodResult) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + id := res.id + if id != "" { + fmt.Fprintf(s.w, "id: %s\n", id) + } + + var data string + dataField := res.data + byts, err := json.Marshal(dataField) + if err != nil { + return err + } + data = string(byts) + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEDataIDFieldMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-int.golden b/http/codegen/testdata/golden/sse-int.golden new file mode 100644 index 0000000000..b09e8073b9 --- /dev/null +++ b/http/codegen/testdata/golden/sse-int.golden @@ -0,0 +1,54 @@ +// SSEIntMethodServerStream implements the +// sseintservice.SSEIntMethodServerStream interface using Server-Sent Events. +type SSEIntMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of "int" to the "SSEIntMethod" endpoint SSE +// connection. +func (s *SSEIntMethodServerStream) Send(v int) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of "int" to the +// "SSEIntMethod" endpoint SSE connection with context. +func (s *SSEIntMethodServerStream) SendWithContext(ctx context.Context, v int) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + var data string + data = fmt.Sprintf("%d", res) + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEIntMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-object.golden b/http/codegen/testdata/golden/sse-object.golden new file mode 100644 index 0000000000..d65fb33900 --- /dev/null +++ b/http/codegen/testdata/golden/sse-object.golden @@ -0,0 +1,60 @@ +// SSEObjectMethodServerStream implements the +// sseobjectservice.SSEObjectMethodServerStream interface using Server-Sent +// Events. +type SSEObjectMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of "sseobjectservice.SSEObjectMethodResult" to +// the "SSEObjectMethod" endpoint SSE connection. +func (s *SSEObjectMethodServerStream) Send(v *sseobjectservice.SSEObjectMethodResult) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of +// "sseobjectservice.SSEObjectMethodResult" to the "SSEObjectMethod" endpoint +// SSE connection with context. +func (s *SSEObjectMethodServerStream) SendWithContext(ctx context.Context, v *sseobjectservice.SSEObjectMethodResult) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + var data string + byts, err := json.Marshal(res) + if err != nil { + return err + } + data = string(byts) + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEObjectMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-request-id.golden b/http/codegen/testdata/golden/sse-request-id.golden new file mode 100644 index 0000000000..0b05ceba34 --- /dev/null +++ b/http/codegen/testdata/golden/sse-request-id.golden @@ -0,0 +1,55 @@ +// SSERequestIDMethodServerStream implements the +// sserequestidservice.SSERequestIDMethodServerStream interface using +// Server-Sent Events. +type SSERequestIDMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of "string" to the "SSERequestIDMethod" endpoint +// SSE connection. +func (s *SSERequestIDMethodServerStream) Send(v string) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of "string" to the +// "SSERequestIDMethod" endpoint SSE connection with context. +func (s *SSERequestIDMethodServerStream) SendWithContext(ctx context.Context, v string) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + var data string + data = res + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSERequestIDMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/golden/sse-string.golden b/http/codegen/testdata/golden/sse-string.golden new file mode 100644 index 0000000000..a63572061c --- /dev/null +++ b/http/codegen/testdata/golden/sse-string.golden @@ -0,0 +1,55 @@ +// SSEStringMethodServerStream implements the +// ssestringservice.SSEStringMethodServerStream interface using Server-Sent +// Events. +type SSEStringMethodServerStream struct { + // once ensures the headers are written once. + once sync.Once + // w is the HTTP response writer used to send the SSE events. + w http.ResponseWriter + // r is the HTTP request. + r *http.Request +} + +// Send Send streams instances of "string" to the "SSEStringMethod" endpoint +// SSE connection. +func (s *SSEStringMethodServerStream) Send(v string) error { + return s.SendWithContext(context.Background(), v) +} + +// SendWithContext SendWithContext streams instances of "string" to the +// "SSEStringMethod" endpoint SSE connection with context. +func (s *SSEStringMethodServerStream) SendWithContext(ctx context.Context, v string) error { + s.once.Do(func() { + // Set default SSE headers if not already set + header := s.w.Header() + if header.Get("Content-Type") == "" { + header.Set("Content-Type", "text/event-stream") + } + if header.Get("Cache-Control") == "" { + header.Set("Cache-Control", "no-cache") + } + if header.Get("Connection") == "" { + header.Set("Connection", "keep-alive") + } + s.w.WriteHeader(http.StatusOK) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + }) + res := v + + var data string + data = res + fmt.Fprintf(s.w, "data: %s\n\n", data) + + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + + return nil +} + +// WriteHeader writes the given header to the HTTP response. +func (s *SSEStringMethodServerStream) WriteHeader(key, value string) { + s.w.Header().Set(key, value) +} diff --git a/http/codegen/testdata/sse_dsls.go b/http/codegen/testdata/sse_dsls.go new file mode 100644 index 0000000000..d7de8bbf2b --- /dev/null +++ b/http/codegen/testdata/sse_dsls.go @@ -0,0 +1,142 @@ +package testdata + +import ( + . "goa.design/goa/v3/dsl" +) + +var SSEStringDSL = func() { + Service("SSEStringService", func() { + Method("SSEStringMethod", func() { + StreamingResult(String) + HTTP(func() { + GET("/string") + ServerSentEvents() + }) + }) + }) +} + +var SSEIntDSL = func() { + Service("SSEIntService", func() { + Method("SSEIntMethod", func() { + StreamingResult(Int) + HTTP(func() { + GET("/int") + ServerSentEvents() + }) + }) + }) +} + +var SSEBoolDSL = func() { + Service("SSEBoolService", func() { + Method("SSEBoolMethod", func() { + StreamingResult(Boolean) + HTTP(func() { + GET("/bool") + ServerSentEvents() + }) + }) + }) +} + +var SSEObjectDSL = func() { + Service("SSEObjectService", func() { + Method("SSEObjectMethod", func() { + StreamingResult(func() { + Attribute("id", String) + Attribute("value", Int) + Attribute("flag", Boolean) + }) + HTTP(func() { + GET("/object") + ServerSentEvents() + }) + }) + }) +} + +var SSEDataFieldDSL = func() { + Service("SSEDataFieldService", func() { + Method("SSEDataFieldMethod", func() { + StreamingResult(func() { + Attribute("data", String) + Attribute("flag", Boolean) + }) + HTTP(func() { + GET("/data-field") + ServerSentEvents("data") + }) + }) + }) +} + +var SSEDataIDFieldDSL = func() { + Service("SSEDataIDFieldService", func() { + Method("SSEDataIDFieldMethod", func() { + StreamingResult(func() { + Attribute("data", String) + Attribute("id", String) + }) + HTTP(func() { + GET("/data-id-field") + ServerSentEvents("data", func() { + SSEEventID("id") + }) + }) + }) + }) +} + +var SSERequestIDDSL = func() { + Service("SSERequestIDService", func() { + Method("SSERequestIDMethod", func() { + Payload(func() { + Attribute("id", String) + }) + StreamingResult(String) + HTTP(func() { + GET("/request-id") + ServerSentEvents(func() { + SSERequestID("id") + }) + }) + }) + }) +} + +var SSEAllFieldsDSL = func() { + Service("SSEAllFieldsService", func() { + Method("SSEAllFieldsMethod", func() { + Payload(func() { + Attribute("id", String) + }) + StreamingResult(func() { + Attribute("id", String, func() { + Example("123") + }) + Attribute("event", String, func() { + Example("update") + }) + Attribute("retry", Int, func() { + Example(3000) + }) + Attribute("data", func() { + Attribute("message", String, func() { + Example("Hello, world!") + }) + }) + }) + HTTP(func() { + GET("/all-fields") + ServerSentEvents(func() { + SSERequestID("id") + SSEEventID("id") + SSEEventType("event") + SSEEventRetry("retry") + SSEEventData("data") + }) + }) + }) + }) +} diff --git a/http/codegen/websocket.go b/http/codegen/websocket.go index b0df17604a..6be5e5c673 100644 --- a/http/codegen/websocket.go +++ b/http/codegen/websocket.go @@ -74,6 +74,9 @@ type ( // initWebSocketData initializes the WebSocket related data in ed. func initWebSocketData(ed *EndpointData, e *expr.HTTPEndpointExpr, sd *ServiceData) { + if e.SSE != nil { + return + } var ( svrRecvTypeName string svrRecvTypeRef string From 92e111a3b61723114ea6724d0fad327a4a301043 Mon Sep 17 00:00:00 2001 From: Raphael Simon Date: Sun, 20 Apr 2025 22:59:13 -0700 Subject: [PATCH 3/4] Initial complete implementation of SSE --- .gitignore | 4 +- http/codegen/client.go | 26 +- http/codegen/client_cli.go | 8 +- http/codegen/example_cli.go | 4 +- http/codegen/example_server.go | 2 +- http/codegen/service_data.go | 17 +- http/codegen/sse.go | 143 +++++----- http/codegen/sse_client.go | 85 ++++++ http/codegen/templates/cli_end.go.tpl | 2 +- http/codegen/templates/cli_streaming.go.tpl | 2 +- http/codegen/templates/client_sse.go.tpl | 245 ++++++++++++++++++ http/codegen/templates/endpoint_init.go.tpl | 21 ++ http/codegen/templates/parse_endpoint.go.tpl | 4 +- .../templates/partial/sse_format.go.tpl | 10 +- .../templates/partial/sse_parse.go.tpl | 53 ++++ http/codegen/templates/request_init.go.tpl | 4 +- .../codegen/templates/server_configure.go.tpl | 2 +- .../templates/server_handler_init.go.tpl | 10 +- http/codegen/templates/server_sse.go.tpl | 82 ++---- .../testdata/golden/sse-all-fields.golden | 25 +- http/codegen/testdata/golden/sse-bool.golden | 20 +- .../testdata/golden/sse-data-field.golden | 12 +- .../testdata/golden/sse-data-id-field.golden | 15 +- http/codegen/testdata/golden/sse-int.golden | 12 +- .../codegen/testdata/golden/sse-object.golden | 12 +- .../testdata/golden/sse-request-id.golden | 12 +- .../codegen/testdata/golden/sse-string.golden | 12 +- http/codegen/websocket.go | 8 +- 28 files changed, 587 insertions(+), 265 deletions(-) create mode 100644 http/codegen/sse_client.go create mode 100644 http/codegen/templates/client_sse.go.tpl create mode 100644 http/codegen/templates/partial/sse_parse.go.tpl diff --git a/.gitignore b/.gitignore index 2119f65eba..5d07e62b61 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,6 @@ cmd/goa/goa # DeepSource cruft cover.out -bin/ + +# MacOS cruft +**/.DS_Store diff --git a/http/codegen/client.go b/http/codegen/client.go index 5d70b2ed8d..14e68654d2 100644 --- a/http/codegen/client.go +++ b/http/codegen/client.go @@ -17,6 +17,9 @@ func ClientFiles(genpkg string, root *expr.RootExpr) []*codegen.File { if f := websocketClientFile(genpkg, svc); f != nil { files = append(files, f) } + if f := sseClientFile(genpkg, svc); f != nil { + files = append(files, f) + } } for _, svc := range root.API.HTTP.Services { if f := clientEncodeDecodeFile(genpkg, svc); f != nil { @@ -50,10 +53,13 @@ func clientFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { }), } sections = append(sections, &codegen.SectionTemplate{ - Name: "client-struct", - Source: readTemplate("client_struct"), - Data: data, - FuncMap: map[string]any{"hasWebSocket": hasWebSocket}, + Name: "client-struct", + Source: readTemplate("client_struct"), + Data: data, + FuncMap: map[string]any{ + "hasWebSocket": hasWebSocket, + "hasSSE": hasSSE, + }, }) for _, e := range data.Endpoints { @@ -67,10 +73,13 @@ func clientFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { } sections = append(sections, &codegen.SectionTemplate{ - Name: "http-client-init", - Source: readTemplate("client_init"), - Data: data, - FuncMap: map[string]any{"hasWebSocket": hasWebSocket}, + Name: "http-client-init", + Source: readTemplate("client_init"), + Data: data, + FuncMap: map[string]any{ + "hasWebSocket": hasWebSocket, + "hasSSE": hasSSE, + }, }) for _, e := range data.Endpoints { @@ -80,6 +89,7 @@ func clientFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { Data: e, FuncMap: map[string]any{ "isWebSocketEndpoint": isWebSocketEndpoint, + "isSSEEndpoint": isSSEEndpoint, "responseStructPkg": responseStructPkg, }, }) diff --git a/http/codegen/client_cli.go b/http/codegen/client_cli.go index da4fa238a8..ad85e35dbf 100644 --- a/http/codegen/client_cli.go +++ b/http/codegen/client_cli.go @@ -14,8 +14,8 @@ type commandData struct { *cli.CommandData // Subcommands is the list of endpoint commands. Subcommands []*subcommandData - // NeedStream if true initializes the websocket dialer. - NeedStream bool + // NeedDialer if true initializes the websocket dialer. + NeedDialer bool } // commandData wraps the common SubcommandData and adds HTTP-specific fields. @@ -50,7 +50,7 @@ func ClientCLIFiles(genpkg string, root *expr.RootExpr) []*codegen.File { if len(sd.Endpoints) > 0 { command := &commandData{ CommandData: cli.BuildCommandData(sd.Service), - NeedStream: hasWebSocket(sd), + NeedDialer: hasWebSocket(sd), } for _, e := range sd.Endpoints { @@ -292,7 +292,7 @@ func streamFlag(svcn, en string) *cli.FlagData { // uses stream for sending payload/result. func streamingCmdExists(data []*commandData) bool { for _, c := range data { - if c.NeedStream { + if c.NeedDialer { return true } } diff --git a/http/codegen/example_cli.go b/http/codegen/example_cli.go index c6d57c8340..937f6067d7 100644 --- a/http/codegen/example_cli.go +++ b/http/codegen/example_cli.go @@ -85,7 +85,7 @@ func exampleCLI(genpkg string, root *expr.RootExpr, svr *expr.ServerExpr) *codeg "Services": svcData, }, FuncMap: map[string]any{ - "needStream": needStream, + "needDialer": needDialer, }, }, { @@ -96,7 +96,7 @@ func exampleCLI(genpkg string, root *expr.RootExpr, svr *expr.ServerExpr) *codeg "APIPkg": apiPkg, }, FuncMap: map[string]any{ - "needStream": needStream, + "needDialer": needDialer, "hasWebSocket": hasWebSocket, }, }, diff --git a/http/codegen/example_server.go b/http/codegen/example_server.go index ffa9835af7..829781bda4 100644 --- a/http/codegen/example_server.go +++ b/http/codegen/example_server.go @@ -105,7 +105,7 @@ func exampleServer(genpkg string, root *expr.RootExpr, svr *expr.ServerExpr) *co "Services": svcdata, "APIPkg": apiPkg, }, - FuncMap: map[string]any{"needStream": needStream, "hasWebSocket": hasWebSocket}, + FuncMap: map[string]any{"needDialer": needDialer, "hasWebSocket": hasWebSocket}, }, { Name: "server-http-middleware", diff --git a/http/codegen/service_data.go b/http/codegen/service_data.go index cf7764ddf6..2d8bf0f091 100644 --- a/http/codegen/service_data.go +++ b/http/codegen/service_data.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "net/http" + "slices" "sort" "strconv" "strings" @@ -36,6 +37,7 @@ var ( _, ok := dt.(expr.UserType) return ok }, + "isWebSocketEndpoint": isWebSocketEndpoint, }). Parse(readTemplate("request_init")), ) @@ -797,7 +799,7 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { "Args": args, "PathInit": routes[0].PathInit, "Verb": routes[0].Verb, - "IsStreaming": httpEndpoint.MethodExpr.IsStreaming(), + "IsWebSocket": httpEndpoint.MethodExpr.IsStreaming() && httpEndpoint.SSE == nil, } if httpEndpoint.SkipRequestBodyEncodeDecode { data["RequestStruct"] = pkg + "." + method.RequestStruct @@ -2774,13 +2776,8 @@ func upgradeParams(e *EndpointData, fn string) map[string]any { } } -// needStream returns true if at least one method in the defined services -// uses stream for sending payload/result. -func needStream(data []*ServiceData) bool { - for _, svc := range data { - if hasWebSocket(svc) { - return true - } - } - return false +// needDialer returns true if at least one method in the defined services +// uses WebSocket for sending payload or result. +func needDialer(data []*ServiceData) bool { + return slices.ContainsFunc(data, hasWebSocket) } diff --git a/http/codegen/sse.go b/http/codegen/sse.go index 32f54ea6cd..9494f1a98f 100644 --- a/http/codegen/sse.go +++ b/http/codegen/sse.go @@ -4,24 +4,22 @@ import ( "fmt" "path/filepath" + "slices" + "goa.design/goa/v3/codegen" "goa.design/goa/v3/expr" ) type ( // SSEData contains the data needed to render struct type that - // implements the server stream interface for SSE. + // implements the server and client stream interface for SSE. SSEData struct { - // VarName is the name of the struct. - VarName string + // StructName is the name of the generated struct which encapsulates the + // server implementation. + StructName string // Interface is the fully qualified name of the interface that // the struct implements. Interface string - // Endpoint is endpoint data that defines streaming result. - Endpoint *EndpointData - // Response is the successful response data for the streaming - // endpoint. - Response *ResponseData // SendName is the name of the send function. SendName string // SendDesc is the description for the send function. @@ -30,27 +28,33 @@ type ( SendWithContextName string // SendWithContextDesc is the description for the send function with context. SendWithContextDesc string - // SendTypeName is the fully qualified type name sent through - // the stream. - SendTypeName string - // SendTypeRef is the fully qualified type ref sent through the - // stream. - SendTypeRef string - - // PkgName is the service package name. - PkgName string - // SSEConfig contains the SSE configuration for this endpoint. - SSEConfig *expr.HTTPSSEExpr - // WriteHeaderName is the name of the WriteHeader function. - WriteHeaderName string - // WriteHeaderDesc is the description for the WriteHeader function. - WriteHeaderDesc string - - // DataFieldType is the type of the data field if SSEConfig.DataField is set. - // It's computed during initialization to avoid complex template logic. - DataFieldType expr.DataType - // ResultType is the type of the result. - ResultType expr.DataType + // RecvName is the name of the client method to connect to the SSE endpoint. + RecvName string + // RecvDesc is the description for the client method. + RecvDesc string + // EventTypeRef is the fully qualified type ref for the event type. + EventTypeRef string + // EventTypeName is the name of the event type without package qualifier. + EventTypeName string + // EventIsStruct indicates whether the SSE method return type is a struct. + EventIsStruct bool + // DataFieldTypeRef is the fully qualified type ref for the data field if any. + DataFieldTypeRef string + // DataField is the name of the result type event data attribute if any. + // If empty, the entire result type is used as the data field. + DataField string + // IDField is the name of the result type event ID attribute if any. + // If empty, no id field is included in the event. + IDField string + // EventField is the name of the result type event field if any. + // If empty, no event field is included in the event. + EventField string + // RetryField is the name of the result type event retry field if any. + // If empty, no retry field is included in the event. + RetryField string + // RequestIDField is the name of the payload field that maps to the Last-Event-ID header if any. + // If empty, no last event id is included in the request. + RequestIDField string } ) @@ -62,56 +66,43 @@ func initSSEData(ed *EndpointData, e *expr.HTTPEndpointExpr, sd *ServiceData) { md := ed.Method svc := sd.Service - svrSendTypeName := ed.Result.Name - svrSendTypeRef := ed.Result.Ref - svrSendDesc := fmt.Sprintf("%s streams instances of %q to the %q endpoint SSE connection.", md.ServerStream.SendName, svrSendTypeName, md.Name) - svrSendWithContextDesc := fmt.Sprintf("%s streams instances of %q to the %q endpoint SSE connection with context.", md.ServerStream.SendWithContextName, svrSendTypeName, md.Name) - writeHeaderDesc := fmt.Sprintf("%s writes the given header to the HTTP response.", "WriteHeader") + sendDesc := fmt.Sprintf("%s streams instances of %q to the %q endpoint SSE connection.", md.ServerStream.SendName, ed.Result.Name, md.Name) + sendWithContextDesc := fmt.Sprintf("%s streams instances of %q to the %q endpoint SSE connection with context.", md.ServerStream.SendWithContextName, ed.Result.Name, md.Name) + recvDesc := fmt.Sprintf("%s connects to the %q SSE endpoint and streams events.", md.ServerStream.RecvName, md.Name) - // Set the result type for use in the template - var resultType expr.DataType - if e.MethodExpr != nil && e.MethodExpr.Result != nil { - resultType = e.MethodExpr.Result.Type - } - - // Compute the data field type if a data field is specified - var dataFieldType expr.DataType - if e.SSE.DataField != "" && resultType != nil { - // If the result type is an object and has the data field, extract its type - if obj, ok := resultType.(*expr.Object); ok { + var dataFieldTypeRef string + if e.SSE.DataField != "" { + if obj, ok := e.MethodExpr.Result.Type.(*expr.Object); ok { for _, nat := range *obj { if nat.Name == e.SSE.DataField { - dataFieldType = nat.Attribute.Type + dataFieldTypeRef = sd.Service.Scope.GoFullTypeRef(nat.Attribute, svc.PkgName) break } } } } - // Create SSE data for server ed.SSE = &SSEData{ - VarName: md.ServerStream.VarName, + StructName: md.ServerStream.VarName, Interface: fmt.Sprintf("%s.%s", svc.PkgName, md.ServerStream.Interface), - Endpoint: ed, - Response: ed.Result.Responses[0], - PkgName: svc.PkgName, SendName: md.ServerStream.SendName, - SendDesc: svrSendDesc, + SendDesc: sendDesc, SendWithContextName: md.ServerStream.SendWithContextName, - SendWithContextDesc: svrSendWithContextDesc, - SendTypeName: svrSendTypeName, - SendTypeRef: svrSendTypeRef, - SSEConfig: e.SSE, - WriteHeaderName: "WriteHeader", - WriteHeaderDesc: writeHeaderDesc, - DataFieldType: dataFieldType, - ResultType: resultType, + SendWithContextDesc: sendWithContextDesc, + RecvName: md.ClientStream.RecvName, + RecvDesc: recvDesc, + EventTypeRef: ed.Result.Ref, + EventTypeName: ed.Result.Name, + EventIsStruct: ed.Result.IsStruct, + DataFieldTypeRef: dataFieldTypeRef, + DataField: e.SSE.DataField, + IDField: e.SSE.IDField, + EventField: e.SSE.EventField, + RetryField: e.SSE.RetryField, + RequestIDField: e.SSE.RequestIDField, } } -// We don't need the getPrimitiveFormatString function anymore -// since we're using a partial template for formatting - // sseServerFile returns the file implementing the SSE server // streaming implementation if any. func sseServerFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { @@ -145,7 +136,6 @@ func sseServerFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { {Path: "time"}, {Path: "encoding/json"}, {Path: "fmt"}, - {Path: "goa.design/goa/v3/http"}, {Path: genpkg + "/" + codegen.SnakeCase(svc.Name())}, {Path: genpkg + "/" + codegen.SnakeCase(svc.Name()) + "/views"}, }, @@ -164,12 +154,11 @@ func sseTemplateSections(data *ServiceData) []*codegen.SectionTemplate { } // Create a map of template functions needed for the SSE template funcs := map[string]interface{}{ - "add": func(a, b int) int { return a + b }, - "dict": func(values ...interface{}) (map[string]interface{}, error) { + "dict": func(values ...any) (map[string]any, error) { if len(values)%2 != 0 { return nil, fmt.Errorf("odd number of arguments") } - dict := make(map[string]interface{}, len(values)/2) + dict := make(map[string]any, len(values)/2) for i := 0; i < len(values); i += 2 { key, ok := values[i].(string) if !ok { @@ -179,24 +168,11 @@ func sseTemplateSections(data *ServiceData) []*codegen.SectionTemplate { } return dict, nil }, - "AsObject": func(dt expr.DataType) map[string]interface{} { - if obj, ok := dt.(*expr.Object); ok { - result := make(map[string]interface{}) - for _, nat := range *obj { - result[nat.Name] = map[string]interface{}{ - "Attribute": nat.Attribute, - "Name": nat.Name, - } - } - return result - } - return nil - }, } sections = append(sections, &codegen.SectionTemplate{ Name: "server-sse", Source: readTemplate("server_sse", "sse_format"), - Data: ed.SSE, + Data: ed, FuncMap: funcs, }) } @@ -208,3 +184,8 @@ func sseTemplateSections(data *ServiceData) []*codegen.SectionTemplate { func isSSEEndpoint(ed *EndpointData) bool { return ed.SSE != nil } + +// hasSSE returns true if at least one endpoint in the service uses SSE. +func hasSSE(data *ServiceData) bool { + return slices.ContainsFunc(data.Endpoints, isSSEEndpoint) +} diff --git a/http/codegen/sse_client.go b/http/codegen/sse_client.go new file mode 100644 index 0000000000..5852e93220 --- /dev/null +++ b/http/codegen/sse_client.go @@ -0,0 +1,85 @@ +package codegen + +import ( + "fmt" + "path/filepath" + + "goa.design/goa/v3/codegen" + "goa.design/goa/v3/expr" +) + +// sseClientFile returns the file implementing the SSE client code for SSE endpoints if any. +// Relies on SSEData (ed.SSE) for all codegen needs. +func sseClientFile(genpkg string, svc *expr.HTTPServiceExpr) *codegen.File { + data := HTTPServices.Get(svc.Name()) + if data == nil { + return nil + } + // Check if any endpoint has SSE + hasSSE := false + for _, ed := range data.Endpoints { + if ed.SSE != nil { + hasSSE = true + break + } + } + if !hasSSE { + return nil + } + path := filepath.Join(codegen.Gendir, "http", codegen.SnakeCase(svc.Name()), "client", "sse.go") + sections := []*codegen.SectionTemplate{ + codegen.Header( + "sse-client", + "client", + []*codegen.ImportSpec{ + {Path: "bytes"}, + {Path: "context"}, + {Path: "encoding/json"}, + {Path: "io"}, + {Path: "net/http"}, + {Path: "fmt"}, + {Path: "strings"}, + {Path: "strconv"}, + {Path: "sync"}, + {Path: genpkg + "/" + codegen.SnakeCase(svc.Name())}, + {Path: genpkg + "/" + codegen.SnakeCase(svc.Name()) + "/views"}, + }, + ), + } + sections = append(sections, sseClientTemplateSections(data)...) // add SSE client methods + return &codegen.File{Path: path, SectionTemplates: sections} +} + +// sseClientTemplateSections returns section templates for SSE client endpoints. +func sseClientTemplateSections(data *ServiceData) []*codegen.SectionTemplate { + sections := make([]*codegen.SectionTemplate, 0) + for _, ed := range data.Endpoints { + if ed.SSE == nil { + continue + } + // Create a map of template functions needed for the SSE template + funcs := map[string]interface{}{ + "dict": func(values ...any) (map[string]any, error) { + if len(values)%2 != 0 { + return nil, fmt.Errorf("odd number of arguments") + } + dict := make(map[string]any, len(values)/2) + for i := 0; i < len(values); i += 2 { + key, ok := values[i].(string) + if !ok { + return nil, fmt.Errorf("dict keys must be strings") + } + dict[key] = values[i+1] + } + return dict, nil + }, + } + sections = append(sections, &codegen.SectionTemplate{ + Name: "client-sse", + Source: readTemplate("client_sse", "sse_parse"), + Data: ed, + FuncMap: funcs, + }) + } + return sections +} diff --git a/http/codegen/templates/cli_end.go.tpl b/http/codegen/templates/cli_end.go.tpl index 1a006c4c3f..bb31b97dc4 100644 --- a/http/codegen/templates/cli_end.go.tpl +++ b/http/codegen/templates/cli_end.go.tpl @@ -5,7 +5,7 @@ return cli.ParseEndpoint( goahttp.RequestEncoder, goahttp.ResponseDecoder, debug, -{{- if needStream .Services }} +{{- if needDialer .Services }} dialer, {{- range $svc := .Services }} {{- if hasWebSocket $svc }} diff --git a/http/codegen/templates/cli_streaming.go.tpl b/http/codegen/templates/cli_streaming.go.tpl index a1cb16b441..076e22c2d7 100644 --- a/http/codegen/templates/cli_streaming.go.tpl +++ b/http/codegen/templates/cli_streaming.go.tpl @@ -1,4 +1,4 @@ -{{- if needStream .Services }} +{{- if needDialer .Services }} var ( dialer *websocket.Dialer ) diff --git a/http/codegen/templates/client_sse.go.tpl b/http/codegen/templates/client_sse.go.tpl new file mode 100644 index 0000000000..e767c433ee --- /dev/null +++ b/http/codegen/templates/client_sse.go.tpl @@ -0,0 +1,245 @@ +type ( + // {{ .Method.VarName }}StreamImpl implements the {{ .ServiceName }}.{{ .Method.VarName }}ClientStream interface. + {{ .Method.VarName }}StreamImpl struct { + resp *http.Response + buffer []byte // Buffer for unprocessed data + lock sync.Mutex + closed bool + } +) + +// {{ .Method.VarName }}StreamImpl implements the {{ .ServiceName }}.{{ .Method.VarName }}ClientStream interface. +var _ {{ .ServiceName }}.{{ .Method.VarName }}ClientStream = (*{{ .Method.VarName }}StreamImpl)(nil) + +// New{{ .Method.VarName }}Stream creates a new {{ .ServiceName }}.{{ .Method.VarName }}ClientStream. +func New{{ .Method.VarName }}Stream(resp *http.Response) {{ .ServiceName }}.{{ .Method.VarName }}ClientStream { + return &{{ .Method.VarName }}StreamImpl{ + resp: resp, + buffer: make([]byte, 0, 4096), // Pre-allocate buffer + } +} + +// Recv reads and returns the next event from the SSE stream. +func (s *{{ .Method.VarName }}StreamImpl) Recv() (event {{ .SSE.EventTypeRef }}, err error) { + return s.RecvWithContext(context.Background()) +} + +// RecvWithContext reads and returns the next event from the SSE stream, respecting context cancellation. +func (s *{{ .Method.VarName }}StreamImpl) RecvWithContext(ctx context.Context) (event {{ .SSE.EventTypeRef }}, err error) { + var byts []byte + byts, err = s.readEvent(ctx) + if err != nil { + if err == io.EOF || err == context.Canceled || err == context.DeadlineExceeded { + // Clean up on EOF or context cancellation + s.Close() + if err == io.EOF { + err = nil + } + } + return + } + return s.processEvent(byts) +} + +// readEvent reads a single SSE event from the stream, respecting context +// cancellation. It first checks the internal buffer for a complete event +// (delimited by double newlines). If no complete event is found, it reads from +// the HTTP response body until it either finds an event boundary, reaches EOF, +// or encounters an error. Any data after the event boundary is saved in the +// buffer for the next call. +func (s *{{ .Method.VarName }}StreamImpl) readEvent(ctx context.Context) ([]byte, error) { + const bufSize = 4096 // 4KB buffer size + + // Check for event in existing buffer + event, ok := s.checkBuffer() + if ok { + return event, nil + } + + // Initialize with any data from buffer + eventData := event + wasNewline := len(eventData) > 0 && eventData[len(eventData)-1] == '\n' + buf := make([]byte, bufSize) + + // Read data in chunks until we find an event or hit EOF + for { + // Check if context is done + select { + case <-ctx.Done(): + if len(eventData) > 0 { + return eventData, nil + } + return nil, ctx.Err() + default: + // Continue processing + } + + // Check if stream is closed + s.lock.Lock() + if s.closed { + s.lock.Unlock() + if len(eventData) > 0 { + return eventData, nil + } + return nil, io.EOF + } + + // Read next chunk + n, err := s.resp.Body.Read(buf) + s.lock.Unlock() + + // Handle read errors + if err != nil && err != io.EOF { + return nil, err + } + + // Process data if we got any + if n > 0 { + // Look for event boundary in this chunk + for i := 0; i < n; i++ { + b := buf[i] + eventData = append(eventData, b) + + // Check for double newlines (event boundary) + if b == '\n' && wasNewline { + // Save any remaining data for next read + if i+1 < n { + s.lock.Lock() + s.buffer = append(s.buffer[:0], buf[i+1:n]...) + s.lock.Unlock() + } + return eventData, nil + } + + // Update newline tracking + wasNewline = (b == '\n') + } + } + + // Return partial data at EOF + if err == io.EOF { + if len(eventData) > 0 { + return eventData, nil + } + return nil, io.EOF + } + } +} + +// checkBuffer examines the internal buffer for a complete SSE event (delimited +// by double newlines). It returns two values: the event data (or all buffer +// contents if no complete event is found), and a boolean indicating whether a +// complete event was found. If a complete event is found, any remaining data +// after the event is kept in the buffer for the next call. +func (s *{{ .Method.VarName }}StreamImpl) checkBuffer() ([]byte, bool) { + s.lock.Lock() + defer s.lock.Unlock() + + // Quick return if buffer is empty + if len(s.buffer) == 0 { + return nil, false + } + + // Look for double newline in buffer + for i := 0; i < len(s.buffer)-1; i++ { + if s.buffer[i] == '\n' && s.buffer[i+1] == '\n' { + // Found complete event + eventEnd := i + 2 // Include both newlines + eventData := s.buffer[:eventEnd] + + // Save remaining data for next time + if eventEnd < len(s.buffer) { + s.buffer = append(s.buffer[:0], s.buffer[eventEnd:]...) + } else { + s.buffer = s.buffer[:0] + } + + return eventData, true + } + } + + // No complete event found, return buffer contents + eventData := s.buffer + s.buffer = s.buffer[:0] // Clear buffer but keep capacity + return eventData, false +} + +// Close closes the SSE stream and releases any associated resources. +func (s *{{ .Method.VarName }}StreamImpl) Close() error { + s.lock.Lock() + defer s.lock.Unlock() + if s.closed { + return nil + } + s.closed = true + return s.resp.Body.Close() +} + +// processEvent processes a raw SSE event into the expected type +func (s *{{ .Method.VarName }}StreamImpl) processEvent(eventData []byte) (event {{ .SSE.EventTypeRef }}, err error) { + {{- if .SSE.EventIsStruct }} + event = new({{ .SSE.EventTypeName }}) + {{- end }} + {{- if .SSE.IDField }} + var id string + {{- end }} + {{- if .SSE.EventField }} + var eventType string + {{- end }} + {{- if .SSE.RetryField }} + var retry int + {{- end }} + var dataLines []string + for _, line := range bytes.Split(eventData, []byte("\n")) { + if len(line) == 0 { + continue + } + if bytes.HasPrefix(line, []byte("data:")) { + dataLines = append(dataLines, s.trimHeader(len("data:"), line)) + continue + } + + {{- if .SSE.IDField }} + if bytes.HasPrefix(line, []byte("id:")) { + event.{{ .SSE.IDField }} = s.trimHeader(len("id:"), line) + continue + } + {{- end }} + + {{- if .SSE.EventField }} + if bytes.HasPrefix(line, []byte("event:")) { + event.{{ .SSE.EventField }} = s.trimHeader(len("event:"), line) + continue + } + {{- end }} + + {{- if .SSE.RetryField }} + if bytes.HasPrefix(line, []byte("retry:")) { + event.{{ .SSE.RetryField }} = s.trimHeader(len("retry:"), line) + continue + } + {{- end }} + } + if len(dataLines) > 0 { + dataContent := strings.Join(dataLines, "\n") + {{- if .SSE.DataField }} + {{ template "partial_sse_parse" dict "Target" (printf "event.%s" .SSE.DataField) "TypeRef" .SSE.DataFieldTypeRef }} + {{- else }} + {{ template "partial_sse_parse" dict "Target" "event" "TypeRef" .SSE.EventTypeRef }} + {{- end }} + } + + return +} + +// trimHeader removes the header prefix and optional leading space +func (s *{{ .Method.VarName }}StreamImpl) trimHeader(size int, data []byte) string { + if len(data) < size { + return string(data) + } + data = data[size:] + if len(data) > 0 && data[0] == ' ' { + data = data[1:] + } + return string(data) +} diff --git a/http/codegen/templates/endpoint_init.go.tpl b/http/codegen/templates/endpoint_init.go.tpl index b0f117db06..e2c5ad0a96 100644 --- a/http/codegen/templates/endpoint_init.go.tpl +++ b/http/codegen/templates/endpoint_init.go.tpl @@ -8,7 +8,9 @@ func (c *{{ .ClientStruct }}) {{ .EndpointInit }}({{ if .MultipartRequestEncoder encodeRequest = {{ .RequestEncoder }}({{ if .MultipartRequestEncoder }}{{ .MultipartRequestEncoder.InitName }}({{ .MultipartRequestEncoder.VarName }}){{ else }}c.encoder{{ end }}) {{- end }} {{- end }} + {{- if not (isSSEEndpoint .) }} decodeResponse = {{ .ResponseDecoder }}(c.decoder, c.RestoreResponseBody) + {{- end }} ) return func(ctx context.Context, v any) (any, error) { req, err := c.{{ .RequestInit.Name }}(ctx, {{ range .RequestInit.ClientArgs }}{{ .Ref }}, {{ end }}) @@ -58,6 +60,25 @@ func (c *{{ .ClientStruct }}) {{ .EndpointInit }}({{ if .MultipartRequestEncoder {{- end }} {{- end }} return stream, nil + {{- else if isSSEEndpoint . }} + // For SSE endpoints, connect and return a stream + resp, err := c.{{ .Method.VarName }}Doer.Do(req) + if err != nil { + return nil, goahttp.ErrRequestError("{{ .ServiceName }}", "{{ .Method.Name }}", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status from SSE endpoint: %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != "" && !strings.HasPrefix(contentType, "text/event-stream") { + resp.Body.Close() + return nil, fmt.Errorf("unexpected content type: %s (expected text/event-stream)", contentType) + } + + return New{{ .Method.VarName }}Stream(resp), nil {{- else }} resp, err := c.{{ .Method.VarName }}Doer.Do(req) if err != nil { diff --git a/http/codegen/templates/parse_endpoint.go.tpl b/http/codegen/templates/parse_endpoint.go.tpl index 55ef173574..16f1ab1626 100644 --- a/http/codegen/templates/parse_endpoint.go.tpl +++ b/http/codegen/templates/parse_endpoint.go.tpl @@ -9,7 +9,7 @@ func ParseEndpoint( {{- if streamingCmdExists .Commands }} dialer goahttp.Dialer, {{- range .Commands }} - {{- if .NeedStream }} + {{- if .NeedDialer }} {{ .VarName }}Configurer *{{ .PkgName }}.ConnConfigurer, {{- end }} {{- end }} @@ -35,7 +35,7 @@ func ParseEndpoint( switch svcn { {{- range .Commands }} case "{{ .Name }}": - c := {{ .PkgName }}.NewClient(scheme, host, doer, enc, dec, restore{{ if .NeedStream }}, dialer, {{ .VarName }}Configurer{{ end }}) + c := {{ .PkgName }}.NewClient(scheme, host, doer, enc, dec, restore{{ if .NeedDialer }}, dialer, {{ .VarName }}Configurer{{ end }}) switch epn { {{- $pkgName := .PkgName }} {{- range .Subcommands }} diff --git a/http/codegen/templates/partial/sse_format.go.tpl b/http/codegen/templates/partial/sse_format.go.tpl index 37d33acdd7..1745c21900 100644 --- a/http/codegen/templates/partial/sse_format.go.tpl +++ b/http/codegen/templates/partial/sse_format.go.tpl @@ -1,16 +1,16 @@ -{{- if eq .Type.Name "string" }} +{{- if eq .TypeRef "string" }} data = {{ .VarName }} -{{- else if eq .Type.Name "boolean" }} +{{- else if eq .TypeRef "boolean" }} if {{ .VarName }} { data = "true" } else { data = "false" } -{{- else if eq .Type.Name "bytes" }} +{{- else if eq .TypeRef "bytes" }} data = string({{ .VarName }}) -{{- else if or (eq .Type.Name "int") (eq .Type.Name "int32") (eq .Type.Name "int64") (eq .Type.Name "uint") (eq .Type.Name "uint32") (eq .Type.Name "uint64") }} +{{- else if or (eq .TypeRef "int") (eq .TypeRef "int32") (eq .TypeRef "int64") (eq .TypeRef "uint") (eq .TypeRef "uint32") (eq .TypeRef "uint64") }} data = fmt.Sprintf("%d", {{ .VarName }}) -{{- else if or (eq .Type.Name "float32") (eq .Type.Name "float64") }} +{{- else if or (eq .TypeRef "float32") (eq .TypeRef "float64") }} data = fmt.Sprintf("%g", {{ .VarName }}) {{- else }} byts, err := json.Marshal({{ .VarName }}) diff --git a/http/codegen/templates/partial/sse_parse.go.tpl b/http/codegen/templates/partial/sse_parse.go.tpl new file mode 100644 index 0000000000..b9eb524580 --- /dev/null +++ b/http/codegen/templates/partial/sse_parse.go.tpl @@ -0,0 +1,53 @@ +{{- if eq .TypeRef "string" }} + {{ .Target }} = dataContent +{{- else if eq .TypeRef "boolean" }} + var val bool + val, err = strconv.ParseBool(dataContent) + if err != nil { + return + } + {{ .Target }} = val +{{- else if eq .TypeRef "bytes" }} + {{ .Target }} = []byte(dataContent) +{{- else if or (eq .TypeRef "int") (eq .TypeRef "int32") }} + var val int64 + val, err = strconv.ParseInt(dataContent, 10, 0) + if err != nil { + return + } + {{ .Target }} = {{ .TypeRef }}(val) +{{- else if eq .TypeRef "int64" }} + {{ .Target }}, err = strconv.ParseInt(dataContent, 10, 64) + if err != nil { + return + } +{{- else if or (eq .TypeRef "uint") (eq .TypeRef "uint32") }} + var val uint64 + val, err = strconv.ParseUint(dataContent, 10, 0) + if err != nil { + return + } + {{ .Target }} = {{ .TypeRef }}(val) +{{- else if eq .TypeRef "uint64" }} + {{ .Target }}, err = strconv.ParseUint(dataContent, 10, 64) + if err != nil { + return + } +{{- else if eq .TypeRef "float32" }} + var val float64 + val, err = strconv.ParseFloat(dataContent, 32) + if err != nil { + return + } + {{ .Target }} = float32(val) +{{- else if eq .TypeRef "float64" }} + {{ .Target }}, err = strconv.ParseFloat(dataContent, 64) + if err != nil { + return + } +{{- else }} + err = json.Unmarshal([]byte(dataContent), &{{ .Target }}) + if err != nil { + return + } +{{- end }} \ No newline at end of file diff --git a/http/codegen/templates/request_init.go.tpl b/http/codegen/templates/request_init.go.tpl index 09e19138bb..c6306197e0 100644 --- a/http/codegen/templates/request_init.go.tpl +++ b/http/codegen/templates/request_init.go.tpl @@ -45,7 +45,7 @@ } body = rd.Body {{- end }} - {{- if .IsStreaming }} + {{- if .IsWebSocket }} scheme := c.scheme switch c.scheme { case "http": @@ -54,7 +54,7 @@ scheme = "wss" } {{- end }} - u := &url.URL{Scheme: {{ if .IsStreaming }}scheme{{ else }}c.scheme{{ end }}, Host: c.host, Path: {{ .PathInit.Name }}({{ range .Args }}{{ .Ref }}, {{ end }})} + u := &url.URL{Scheme: {{ if .IsWebSocket }}scheme{{ else }}c.scheme{{ end }}, Host: c.host, Path: {{ .PathInit.Name }}({{ range .Args }}{{ .Ref }}, {{ end }})} req, err := http.NewRequest("{{ .Verb }}", u.String(), {{ if .RequestStruct }}body{{ else }}nil{{ end }}) if err != nil { return nil, goahttp.ErrInvalidURL("{{ .ServiceName }}", "{{ .EndpointName }}", u.String(), err) diff --git a/http/codegen/templates/server_configure.go.tpl b/http/codegen/templates/server_configure.go.tpl index 1f186a9a85..c337e9e6f4 100644 --- a/http/codegen/templates/server_configure.go.tpl +++ b/http/codegen/templates/server_configure.go.tpl @@ -10,7 +10,7 @@ ) { eh := errorHandler(ctx) - {{- if needStream .Services }} + {{- if needDialer .Services }} upgrader := &websocket.Upgrader{} {{- end }} {{- range $svc := .Services }} diff --git a/http/codegen/templates/server_handler_init.go.tpl b/http/codegen/templates/server_handler_init.go.tpl index 36c39fc0b1..cd2b4512ec 100644 --- a/http/codegen/templates/server_handler_init.go.tpl +++ b/http/codegen/templates/server_handler_init.go.tpl @@ -59,21 +59,21 @@ func {{ .HandlerInit }}( } _, err = endpoint(ctx, v) {{- else if isSSEEndpoint . }} - {{- if .SSE.SSEConfig.RequestIDField }} + {{- if .SSE.RequestIDField }} // Set Last-Event-ID header if present if lastEventID := r.Header.Get("Last-Event-ID"); lastEventID != "" { ctx = context.WithValue(ctx, "last-event-id", lastEventID) {{- if .Payload.Ref }} {{- if eq .Method.Payload.Type.Name "Object" }} p := payload.({{ .Payload.Ref }}) - p.{{ .SSE.SSEConfig.RequestIDField }} = lastEventID + p.{{ .SSE.RequestIDField }} = lastEventID payload = p {{- end }} {{- end }} } {{- end }} v := &{{ .ServicePkgName }}.{{ .Method.ServerStream.EndpointStruct }}{ - Stream: &{{ .SSE.VarName }}{ + Stream: &{{ .SSE.StructName }}{ w: w, r: r, }, @@ -99,10 +99,6 @@ func {{ .HandlerInit }}( return } {{- end }} - {{- if isSSEEndpoint . }} - // For SSE, we need to set appropriate error headers - w.Header().Set("Content-Type", "application/json") - {{- end }} if err := encodeError(ctx, w, err); err != nil { errhandler(ctx, w, err) } diff --git a/http/codegen/templates/server_sse.go.tpl b/http/codegen/templates/server_sse.go.tpl index 31c3c1c2c7..2135c23585 100644 --- a/http/codegen/templates/server_sse.go.tpl +++ b/http/codegen/templates/server_sse.go.tpl @@ -1,22 +1,21 @@ -{{ printf "%s implements the %s interface using Server-Sent Events." .VarName .Interface | comment }} -type {{ .VarName }} struct { - {{ printf "once ensures the headers are written once." | comment }} +{{ printf "%s implements the %s interface using Server-Sent Events." .SSE.StructName .SSE.Interface | comment }} +type {{ .SSE.StructName }} struct { + {{ comment "once ensures the headers are written once." }} once sync.Once - {{ printf "w is the HTTP response writer used to send the SSE events." | comment }} + {{ comment "w is the HTTP response writer used to send the SSE events." }} w http.ResponseWriter - {{ printf "r is the HTTP request." | comment }} + {{ comment "r is the HTTP request." }} r *http.Request } -{{ printf "%s %s" .SendName .SendDesc | comment }} -func (s *{{ .VarName }}) {{ .SendName }}(v {{ .SendTypeRef }}) error { - return s.{{ .SendWithContextName }}(context.Background(), v) +{{ printf "%s %s" .SSE.SendName .SSE.SendDesc | comment }} +func (s *{{ .SSE.StructName }}) {{ .SSE.SendName }}(v {{ .SSE.EventTypeRef }}) error { + return s.{{ .SSE.SendWithContextName }}(context.Background(), v) } -{{ printf "%s %s" .SendWithContextName .SendWithContextDesc | comment }} -func (s *{{ .VarName }}) {{ .SendWithContextName }}(ctx context.Context, v {{ .SendTypeRef }}) error { +{{ printf "%s %s" .SSE.SendWithContextName .SSE.SendWithContextDesc | comment }} +func (s *{{ .SSE.StructName }}) {{ .SSE.SendWithContextName }}(ctx context.Context, v {{ .SSE.EventTypeRef }}) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -28,77 +27,52 @@ func (s *{{ .VarName }}) {{ .SendWithContextName }}(ctx context.Context, v {{ .S header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) - {{- if .Endpoint.Method.ViewedResult }} - {{- if .Endpoint.Method.ViewedResult.ViewName }} - res := {{ .PkgName }}.{{ .Endpoint.Method.ViewedResult.Init.Name }}(v, {{ printf "%q" .Endpoint.Method.ViewedResult.ViewName }}) + {{- if .Method.ViewedResult }} + {{- if .Method.ViewedResult.ViewName }} + res := {{ .Service.PkgName }}.{{ .Method.ViewedResult.Init.Name }}(v, {{ printf "%q" .Method.ViewedResult.ViewName }}) {{- else }} - res := {{ .PkgName }}.{{ .Endpoint.Method.ViewedResult.Init.Name }}(v, "default") + res := {{ .Service.PkgName }}.{{ .Method.ViewedResult.Init.Name }}(v, "default") {{- end }} {{- else }} res := v {{- end }} - {{ if .SSEConfig.IDField }} - id := res.{{ .SSEConfig.IDField }} - if id != "" { + {{ if .SSE.IDField }} + if id := res.{{ .SSE.IDField }}; id != "" { fmt.Fprintf(s.w, "id: %s\n", id) } {{- end }} - {{ if .SSEConfig.EventField }} - eventType := res.{{ .SSEConfig.EventField }} - if eventType != "" { - fmt.Fprintf(s.w, "event: %s\n", eventType) + {{- if .SSE.EventField }} + if event := res.{{ .SSE.EventField }}; event != "" { + fmt.Fprintf(s.w, "event: %s\n", event) } {{- end }} - {{ if .SSEConfig.RetryField }} - retry := res.{{ .SSEConfig.RetryField }} - if retry > 0 { + {{- if .SSE.RetryField }} + if retry := res.{{ .SSE.RetryField }}; retry > 0 { fmt.Fprintf(s.w, "retry: %d\n", retry) } {{- end }} - {{ if .SSEConfig.DataField }} var data string - dataField := res.{{ .SSEConfig.DataField }} - {{- if .DataFieldType }} - {{- template "partial_sse_format" dict "Type" .DataFieldType "VarName" "dataField" }} - {{- else }} - byts, err := json.Marshal(dataField) - if err != nil { - return err - } - data = string(byts) - {{- end }} - fmt.Fprintf(s.w, "data: %s\n\n", data) + {{- if .SSE.DataField }} + dataField := res.{{ .SSE.DataField }} + {{- template "partial_sse_format" dict "TypeRef" .SSE.DataFieldTypeRef "VarName" "dataField" }} {{- else }} - var data string - {{- if .ResultType }} - {{- template "partial_sse_format" dict "Type" .ResultType "VarName" "res" }} - {{- else }} - byts, err := json.Marshal(res) - if err != nil { - return err - } - data = string(byts) + {{- template "partial_sse_format" dict "TypeRef" .SSE.EventTypeRef "VarName" "res" }} {{- end }} fmt.Fprintf(s.w, "data: %s\n\n", data) - {{- end }} if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -{{ printf "WriteHeader writes the given header to the HTTP response." | comment }} -func (s *{{ .VarName }}) {{ .WriteHeaderName }}(key, value string) { - s.w.Header().Set(key, value) +{{ comment "Close is a no-op for SSE. We keep the method for compatibility with other stream types." }} +func (s *{{ .SSE.StructName }}) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-all-fields.golden b/http/codegen/testdata/golden/sse-all-fields.golden index 4a50f7c0b9..6d0df711e6 100644 --- a/http/codegen/testdata/golden/sse-all-fields.golden +++ b/http/codegen/testdata/golden/sse-all-fields.golden @@ -22,7 +22,6 @@ func (s *SSEAllFieldsMethodServerStream) Send(v *sseallfieldsservice.SSEAllField // endpoint SSE connection with context. func (s *SSEAllFieldsMethodServerStream) SendWithContext(ctx context.Context, v *sseallfieldsservice.SSEAllFieldsMethodResult) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -34,24 +33,16 @@ func (s *SSEAllFieldsMethodServerStream) SendWithContext(ctx context.Context, v header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v - id := res.id - if id != "" { + if id := res.id; id != "" { fmt.Fprintf(s.w, "id: %s\n", id) } - - eventType := res.event - if eventType != "" { - fmt.Fprintf(s.w, "event: %s\n", eventType) + if event := res.event; event != "" { + fmt.Fprintf(s.w, "event: %s\n", event) } - - retry := res.retry - if retry > 0 { + if retry := res.retry; retry > 0 { fmt.Fprintf(s.w, "retry: %d\n", retry) } @@ -67,11 +58,11 @@ func (s *SSEAllFieldsMethodServerStream) SendWithContext(ctx context.Context, v if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEAllFieldsMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEAllFieldsMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-bool.golden b/http/codegen/testdata/golden/sse-bool.golden index 8580e30421..b58d798caf 100644 --- a/http/codegen/testdata/golden/sse-bool.golden +++ b/http/codegen/testdata/golden/sse-bool.golden @@ -19,7 +19,6 @@ func (s *SSEBoolMethodServerStream) Send(v bool) error { // "SSEBoolMethod" endpoint SSE connection with context. func (s *SSEBoolMethodServerStream) SendWithContext(ctx context.Context, v bool) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -31,28 +30,25 @@ func (s *SSEBoolMethodServerStream) SendWithContext(ctx context.Context, v bool) header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v var data string - if res { - data = "true" - } else { - data = "false" + byts, err := json.Marshal(res) + if err != nil { + return err } + data = string(byts) fmt.Fprintf(s.w, "data: %s\n\n", data) if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEBoolMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEBoolMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-data-field.golden b/http/codegen/testdata/golden/sse-data-field.golden index 50d5dafd9e..a09ace1c5b 100644 --- a/http/codegen/testdata/golden/sse-data-field.golden +++ b/http/codegen/testdata/golden/sse-data-field.golden @@ -22,7 +22,6 @@ func (s *SSEDataFieldMethodServerStream) Send(v *ssedatafieldservice.SSEDataFiel // endpoint SSE connection with context. func (s *SSEDataFieldMethodServerStream) SendWithContext(ctx context.Context, v *ssedatafieldservice.SSEDataFieldMethodResult) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -34,9 +33,6 @@ func (s *SSEDataFieldMethodServerStream) SendWithContext(ctx context.Context, v header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v @@ -52,11 +48,11 @@ func (s *SSEDataFieldMethodServerStream) SendWithContext(ctx context.Context, v if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEDataFieldMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEDataFieldMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-data-id-field.golden b/http/codegen/testdata/golden/sse-data-id-field.golden index 5afef3eda4..a1ee61b4a8 100644 --- a/http/codegen/testdata/golden/sse-data-id-field.golden +++ b/http/codegen/testdata/golden/sse-data-id-field.golden @@ -22,7 +22,6 @@ func (s *SSEDataIDFieldMethodServerStream) Send(v *ssedataidfieldservice.SSEData // "SSEDataIDFieldMethod" endpoint SSE connection with context. func (s *SSEDataIDFieldMethodServerStream) SendWithContext(ctx context.Context, v *ssedataidfieldservice.SSEDataIDFieldMethodResult) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -34,14 +33,10 @@ func (s *SSEDataIDFieldMethodServerStream) SendWithContext(ctx context.Context, header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v - id := res.id - if id != "" { + if id := res.id; id != "" { fmt.Fprintf(s.w, "id: %s\n", id) } @@ -57,11 +52,11 @@ func (s *SSEDataIDFieldMethodServerStream) SendWithContext(ctx context.Context, if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEDataIDFieldMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEDataIDFieldMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-int.golden b/http/codegen/testdata/golden/sse-int.golden index b09e8073b9..d789c17921 100644 --- a/http/codegen/testdata/golden/sse-int.golden +++ b/http/codegen/testdata/golden/sse-int.golden @@ -19,7 +19,6 @@ func (s *SSEIntMethodServerStream) Send(v int) error { // "SSEIntMethod" endpoint SSE connection with context. func (s *SSEIntMethodServerStream) SendWithContext(ctx context.Context, v int) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -31,9 +30,6 @@ func (s *SSEIntMethodServerStream) SendWithContext(ctx context.Context, v int) e header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v @@ -44,11 +40,11 @@ func (s *SSEIntMethodServerStream) SendWithContext(ctx context.Context, v int) e if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEIntMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEIntMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-object.golden b/http/codegen/testdata/golden/sse-object.golden index d65fb33900..a5f03b2cbd 100644 --- a/http/codegen/testdata/golden/sse-object.golden +++ b/http/codegen/testdata/golden/sse-object.golden @@ -21,7 +21,6 @@ func (s *SSEObjectMethodServerStream) Send(v *sseobjectservice.SSEObjectMethodRe // SSE connection with context. func (s *SSEObjectMethodServerStream) SendWithContext(ctx context.Context, v *sseobjectservice.SSEObjectMethodResult) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -33,9 +32,6 @@ func (s *SSEObjectMethodServerStream) SendWithContext(ctx context.Context, v *ss header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v @@ -50,11 +46,11 @@ func (s *SSEObjectMethodServerStream) SendWithContext(ctx context.Context, v *ss if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEObjectMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEObjectMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-request-id.golden b/http/codegen/testdata/golden/sse-request-id.golden index 0b05ceba34..96890d35ce 100644 --- a/http/codegen/testdata/golden/sse-request-id.golden +++ b/http/codegen/testdata/golden/sse-request-id.golden @@ -20,7 +20,6 @@ func (s *SSERequestIDMethodServerStream) Send(v string) error { // "SSERequestIDMethod" endpoint SSE connection with context. func (s *SSERequestIDMethodServerStream) SendWithContext(ctx context.Context, v string) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -32,9 +31,6 @@ func (s *SSERequestIDMethodServerStream) SendWithContext(ctx context.Context, v header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v @@ -45,11 +41,11 @@ func (s *SSERequestIDMethodServerStream) SendWithContext(ctx context.Context, v if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSERequestIDMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSERequestIDMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/testdata/golden/sse-string.golden b/http/codegen/testdata/golden/sse-string.golden index a63572061c..e013da59c4 100644 --- a/http/codegen/testdata/golden/sse-string.golden +++ b/http/codegen/testdata/golden/sse-string.golden @@ -20,7 +20,6 @@ func (s *SSEStringMethodServerStream) Send(v string) error { // "SSEStringMethod" endpoint SSE connection with context. func (s *SSEStringMethodServerStream) SendWithContext(ctx context.Context, v string) error { s.once.Do(func() { - // Set default SSE headers if not already set header := s.w.Header() if header.Get("Content-Type") == "" { header.Set("Content-Type", "text/event-stream") @@ -32,9 +31,6 @@ func (s *SSEStringMethodServerStream) SendWithContext(ctx context.Context, v str header.Set("Connection", "keep-alive") } s.w.WriteHeader(http.StatusOK) - if f, ok := s.w.(http.Flusher); ok { - f.Flush() - } }) res := v @@ -45,11 +41,11 @@ func (s *SSEStringMethodServerStream) SendWithContext(ctx context.Context, v str if f, ok := s.w.(http.Flusher); ok { f.Flush() } - return nil } -// WriteHeader writes the given header to the HTTP response. -func (s *SSEStringMethodServerStream) WriteHeader(key, value string) { - s.w.Header().Set(key, value) +// Close is a no-op for SSE. We keep the method for compatibility with other +// stream types. +func (s *SSEStringMethodServerStream) Close() error { + return nil } diff --git a/http/codegen/websocket.go b/http/codegen/websocket.go index 6be5e5c673..d4304cc3b7 100644 --- a/http/codegen/websocket.go +++ b/http/codegen/websocket.go @@ -3,6 +3,7 @@ package codegen import ( "fmt" "path/filepath" + "slices" "strings" "goa.design/goa/v3/codegen" @@ -454,12 +455,7 @@ func clientWSSections(data *ServiceData) []*codegen.SectionTemplate { // hasWebSocket returns true if at least one of the endpoints in the service // defines a streaming payload or result. func hasWebSocket(sd *ServiceData) bool { - for _, e := range sd.Endpoints { - if isWebSocketEndpoint(e) { - return true - } - } - return false + return slices.ContainsFunc(sd.Endpoints, isWebSocketEndpoint) } // isWebSocketEndpoint returns true if the endpoint defines a streaming payload From fe107128e48ab534ceedc99a41ea2ee30f97ae8f Mon Sep 17 00:00:00 2001 From: Raphael Simon Date: Sat, 3 May 2025 15:42:17 -0700 Subject: [PATCH 4/4] Fix dup check --- dsl/sse.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/dsl/sse.go b/dsl/sse.go index 07482eca6c..7d47684dca 100644 --- a/dsl/sse.go +++ b/dsl/sse.go @@ -80,12 +80,6 @@ func ServerSentEvents(args ...any) { eval.TooManyArgError() return } - if len(args) == 2 { - if _, ok := args[1].(func()); !ok { - eval.InvalidArgError("function", args[1]) - return - } - } var fn func() var dataField string @@ -102,6 +96,10 @@ func ServerSentEvents(args ...any) { return } if len(args) == 2 { + if fn != nil { + eval.TooManyArgError() + return + } var ok bool fn, ok = args[1].(func()) if !ok {