Skip to content

Commit 59258ac

Browse files
authored
Fix cors serving (#156)
1 parent 010110c commit 59258ac

3 files changed

Lines changed: 46 additions & 0 deletions

File tree

pkg/cmd/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ func (s *RunServer) Run(ctx server.Cmd) error {
114114
if err != nil {
115115
return fmt.Errorf("router: %w", err)
116116
}
117+
srv.SetHandler(router)
117118

118119
// Register routes
119120
for _, fn := range s.register {

pkg/httpserver/httpserver.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ func (server *server) Router() *http.ServeMux {
100100
return server.mux
101101
}
102102

103+
// SetHandler replaces the server's active HTTP handler. This is used when a
104+
// higher-level router wraps the underlying ServeMux with middleware such as
105+
// CORS, CSRF protection, tracing, or authentication.
106+
func (server *server) SetHandler(handler http.Handler) {
107+
if handler == nil {
108+
server.http.Handler = server.mux
109+
return
110+
}
111+
server.http.Handler = handler
112+
}
113+
103114
// Addr returns the listen address. After [Listen] has been called this
104115
// returns the actual bound address (which may differ from the configured
105116
// address when an ephemeral port is used).

pkg/httpserver/httpserver_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,40 @@ func Test_Options_004(t *testing.T) {
203203
assert.NotNil(s)
204204
}
205205

206+
func Test_Server_UsesCustomHandler(t *testing.T) {
207+
assert := assert.New(t)
208+
209+
s, err := httpserver.New(":0", nil)
210+
if !assert.NoError(err) {
211+
return
212+
}
213+
214+
s.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
215+
w.WriteHeader(http.StatusAccepted)
216+
}))
217+
218+
ctx, cancel := context.WithCancel(context.Background())
219+
defer cancel()
220+
221+
if !assert.NoError(s.Listen()) {
222+
return
223+
}
224+
225+
done := make(chan error, 1)
226+
go func() {
227+
done <- s.Run(ctx)
228+
}()
229+
230+
resp, err := http.Get(s.URL().String())
231+
if assert.NoError(err) {
232+
defer resp.Body.Close()
233+
assert.Equal(http.StatusAccepted, resp.StatusCode)
234+
}
235+
236+
cancel()
237+
assert.NoError(<-done)
238+
}
239+
206240
///////////////////////////////////////////////////////////////////////////////
207241
// TESTS - TLS CONFIG
208242

0 commit comments

Comments
 (0)