Skip to content

Commit b7cda47

Browse files
committed
feat(middleware): add Connect interceptors and HTTP middleware
New middleware package providing ConnectRPC interceptors and net/http middleware for common cross-cutting concerns: - middleware/recovery: panic recovery with logging - middleware/requestid: X-Request-ID propagation/generation - middleware/requestlog: request logging with duration and status - middleware/errorz: error sanitization for client-facing responses - middleware/cors: CORS with Connect-specific header defaults Chain builders: Default(logger) returns the standard interceptor chain, DefaultHTTP(logger) returns the standard HTTP middleware chain, ChainHTTP() composes middleware in order.
1 parent c483409 commit b7cda47

9 files changed

Lines changed: 642 additions & 4 deletions

File tree

go.mod

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module github.com/raystack/salt
22

3-
go 1.22
3+
go 1.24.0
44

55
require (
6+
connectrpc.com/connect v1.19.1
67
github.com/AlecAivazis/survey/v2 v2.3.6
78
github.com/MakeNowJust/heredoc v1.0.0
89
github.com/NYTimes/gziphandler v1.1.1
@@ -56,7 +57,7 @@ require (
5657
github.com/tidwall/match v1.1.1 // indirect
5758
github.com/tidwall/pretty v1.2.1 // indirect
5859
github.com/tidwall/sjson v1.2.5 // indirect
59-
google.golang.org/protobuf v1.35.1 // indirect
60+
google.golang.org/protobuf v1.36.9 // indirect
6061
)
6162

6263
require (

go.sum

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
22
cloud.google.com/go/compute/metadata v0.5.0 h1:Zr0eK8JbFv6+Wi4ilXAR8FJ3wyNdpxHKJNPos6LTZOY=
33
cloud.google.com/go/compute/metadata v0.5.0/go.mod h1:aHnloV2TPI38yx4s9+wAZhHykWvVCfu7hQbF+9CWoiY=
4+
connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14=
5+
connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w=
46
github.com/AlecAivazis/survey/v2 v2.3.6 h1:NvTuVHISgTHEHeBFqt6BHOe4Ny/NwGZr7w+F8S9ziyw=
57
github.com/AlecAivazis/survey/v2 v2.3.6/go.mod h1:4AuI9b7RjAR+G7v9+C4YSlX/YL3K3cWNXgWXOhllqvI=
68
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
@@ -548,8 +550,8 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
548550
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
549551
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
550552
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
551-
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
552-
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
553+
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
554+
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
553555
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
554556
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
555557
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

middleware/cors/cors.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package cors
2+
3+
import (
4+
"net/http"
5+
"strconv"
6+
"strings"
7+
)
8+
9+
// Option configures the CORS middleware.
10+
type Option func(*config)
11+
12+
type config struct {
13+
allowedOrigins []string
14+
allowedMethods []string
15+
allowedHeaders []string
16+
maxAge int
17+
}
18+
19+
// WithAllowedOrigins sets the allowed origins. Use "*" to allow all.
20+
func WithAllowedOrigins(origins ...string) Option {
21+
return func(c *config) { c.allowedOrigins = origins }
22+
}
23+
24+
// WithAllowedMethods sets the allowed HTTP methods.
25+
func WithAllowedMethods(methods ...string) Option {
26+
return func(c *config) { c.allowedMethods = methods }
27+
}
28+
29+
// WithAllowedHeaders sets the allowed request headers.
30+
func WithAllowedHeaders(headers ...string) Option {
31+
return func(c *config) { c.allowedHeaders = headers }
32+
}
33+
34+
// WithMaxAge sets the max age (in seconds) for preflight cache.
35+
func WithMaxAge(seconds int) Option {
36+
return func(c *config) { c.maxAge = seconds }
37+
}
38+
39+
// Defaults returns sensible CORS defaults for ConnectRPC services.
40+
// Includes Connect-specific headers.
41+
func Defaults() []Option {
42+
return []Option{
43+
WithAllowedOrigins("*"),
44+
WithAllowedMethods("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"),
45+
WithAllowedHeaders(
46+
"Content-Type",
47+
"Connect-Protocol-Version",
48+
"Connect-Timeout-Ms",
49+
"Grpc-Timeout",
50+
"X-Grpc-Web",
51+
"X-User-Agent",
52+
"X-Request-ID",
53+
"Authorization",
54+
),
55+
WithMaxAge(7200),
56+
}
57+
}
58+
59+
func newConfig(opts []Option) *config {
60+
c := &config{}
61+
// Apply defaults first, then user overrides.
62+
for _, opt := range Defaults() {
63+
opt(c)
64+
}
65+
for _, opt := range opts {
66+
opt(c)
67+
}
68+
return c
69+
}
70+
71+
// Middleware returns net/http CORS middleware.
72+
func Middleware(opts ...Option) func(http.Handler) http.Handler {
73+
cfg := newConfig(opts)
74+
return func(next http.Handler) http.Handler {
75+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76+
origin := r.Header.Get("Origin")
77+
if origin == "" {
78+
next.ServeHTTP(w, r)
79+
return
80+
}
81+
82+
if isOriginAllowed(cfg.allowedOrigins, origin) {
83+
w.Header().Set("Access-Control-Allow-Origin", origin)
84+
}
85+
86+
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.allowedMethods, ", "))
87+
w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.allowedHeaders, ", "))
88+
89+
if cfg.maxAge > 0 {
90+
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(cfg.maxAge))
91+
}
92+
93+
w.Header().Set("Vary", "Origin")
94+
95+
if r.Method == http.MethodOptions {
96+
w.WriteHeader(http.StatusNoContent)
97+
return
98+
}
99+
100+
next.ServeHTTP(w, r)
101+
})
102+
}
103+
}
104+
105+
func isOriginAllowed(allowed []string, origin string) bool {
106+
for _, a := range allowed {
107+
if a == "*" || a == origin {
108+
return true
109+
}
110+
}
111+
return false
112+
}

middleware/errorz/errorz.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package errorz
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"time"
8+
9+
"connectrpc.com/connect"
10+
"github.com/raystack/salt/logger"
11+
)
12+
13+
// Option configures the error sanitization middleware.
14+
type Option func(*config)
15+
16+
type config struct {
17+
verbose bool
18+
logger logger.Logger
19+
}
20+
21+
// WithVerbose enables full error messages in responses.
22+
// Useful for development/staging environments.
23+
func WithVerbose(v bool) Option {
24+
return func(c *config) { c.verbose = v }
25+
}
26+
27+
// WithLogger sets the logger for recording original errors before sanitization.
28+
func WithLogger(l logger.Logger) Option {
29+
return func(c *config) { c.logger = l }
30+
}
31+
32+
func newConfig(opts []Option) *config {
33+
c := &config{logger: &logger.Noop{}}
34+
for _, opt := range opts {
35+
opt(c)
36+
}
37+
return c
38+
}
39+
40+
// NewInterceptor returns a Connect interceptor that sanitizes internal errors.
41+
// Non-Connect errors are mapped to CodeInternal with a timestamp reference.
42+
// Connect errors with known codes are passed through.
43+
func NewInterceptor(opts ...Option) connect.UnaryInterceptorFunc {
44+
cfg := newConfig(opts)
45+
return func(next connect.UnaryFunc) connect.UnaryFunc {
46+
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
47+
resp, err := next(ctx, req)
48+
if err == nil {
49+
return resp, nil
50+
}
51+
52+
// If it's already a Connect error, preserve the code.
53+
var connectErr *connect.Error
54+
if errors.As(err, &connectErr) {
55+
if cfg.verbose {
56+
return resp, err
57+
}
58+
// Preserve code but sanitize message for client-facing codes.
59+
code := connectErr.Code()
60+
if code == connect.CodeInternal || code == connect.CodeUnknown {
61+
ref := time.Now().Unix()
62+
cfg.logger.Error("internal error",
63+
"error", err.Error(),
64+
"ref", ref,
65+
)
66+
return resp, connect.NewError(code, fmt.Errorf("internal error (ref: %d)", ref))
67+
}
68+
return resp, err
69+
}
70+
71+
// Non-Connect error: sanitize completely.
72+
ref := time.Now().Unix()
73+
cfg.logger.Error("internal error",
74+
"error", err.Error(),
75+
"ref", ref,
76+
)
77+
if cfg.verbose {
78+
return resp, connect.NewError(connect.CodeInternal, err)
79+
}
80+
return resp, connect.NewError(connect.CodeInternal, fmt.Errorf("internal error (ref: %d)", ref))
81+
}
82+
}
83+
}

middleware/middleware.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
6+
"connectrpc.com/connect"
7+
"github.com/raystack/salt/logger"
8+
"github.com/raystack/salt/middleware/cors"
9+
"github.com/raystack/salt/middleware/errorz"
10+
"github.com/raystack/salt/middleware/recovery"
11+
"github.com/raystack/salt/middleware/requestid"
12+
"github.com/raystack/salt/middleware/requestlog"
13+
)
14+
15+
// Default returns the standard raystack Connect interceptor chain:
16+
// recovery → requestid → requestlog → errorz
17+
func Default(l logger.Logger) []connect.Interceptor {
18+
return []connect.Interceptor{
19+
recovery.NewInterceptor(recovery.WithLogger(l)),
20+
requestid.NewInterceptor(),
21+
requestlog.NewInterceptor(requestlog.WithLogger(l)),
22+
errorz.NewInterceptor(errorz.WithLogger(l)),
23+
}
24+
}
25+
26+
// DefaultHTTP returns the standard raystack HTTP middleware chain:
27+
// recovery → requestid → requestlog → cors
28+
func DefaultHTTP(l logger.Logger, corsOpts ...cors.Option) func(http.Handler) http.Handler {
29+
return ChainHTTP(
30+
recovery.HTTPMiddleware(recovery.WithLogger(l)),
31+
requestid.HTTPMiddleware(),
32+
requestlog.HTTPMiddleware(requestlog.WithLogger(l)),
33+
cors.Middleware(corsOpts...),
34+
)
35+
}
36+
37+
// ChainHTTP chains net/http middleware in order.
38+
// The first middleware wraps outermost (processes request first).
39+
func ChainHTTP(mws ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
40+
return func(final http.Handler) http.Handler {
41+
for i := len(mws) - 1; i >= 0; i-- {
42+
final = mws[i](final)
43+
}
44+
return final
45+
}
46+
}

0 commit comments

Comments
 (0)