Skip to content

Commit 7fd0db9

Browse files
authored
Fix WrapM pattern restore (#106)
* fix(context): restore req.Pattern on the original request in WrapM * feat(context): expose http.Flusher to std middleware via WrapM * style(response_writer): rename flusherWriter receiver from s to w * test(context): drop redundant pre-call req.Pattern assertion in TestWrapM_RestoresRequestPattern
1 parent 70ae18c commit 7fd0db9

3 files changed

Lines changed: 70 additions & 5 deletions

File tree

context.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,23 +456,30 @@ func (mw wrapM) handle(c *Context) {
456456
defer func() { req.Pattern = p }()
457457

458458
req.Pattern = c.Pattern()
459+
r := req
459460
if route := c.Route(); route != nil && route.ParamsLen() > 0 {
460461
params := slices.AppendSeq(make(Params, 0, route.ParamsLen()), c.Params())
461-
ctx := context.WithValue(c.Request().Context(), paramsKey, params)
462-
req = req.WithContext(ctx)
462+
ctx := context.WithValue(req.Context(), paramsKey, params)
463+
r = req.WithContext(ctx)
463464
}
464465

465466
mw.m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
466467
// Avoid allocation if w has not been wrapped by m.
467-
rec, ok := w.(*recorder)
468-
if !ok {
468+
var rec *recorder
469+
switch v := w.(type) {
470+
case flusherWriter:
471+
rec, _ = v.ResponseWriter.(*recorder)
472+
case *recorder:
473+
rec = v
474+
}
475+
if rec == nil {
469476
rec = new(recorder)
470477
rec.reset(w)
471478
}
472479
cc := c.CloneWith(rec, r)
473480
defer cc.Close()
474481
mw.next(cc)
475-
})).ServeHTTP(c.Writer(), req)
482+
})).ServeHTTP(flusherWriter{c.Writer()}, r)
476483
}
477484

478485
func sumLen(s []string) int {

context_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,56 @@ func TestWrapM(t *testing.T) {
617617
assert.Equal(t, "OK", w.Body.String())
618618
}
619619

620+
func TestWrapM_RestoresRequestPattern(t *testing.T) {
621+
mw := func(h http.Handler) http.Handler {
622+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
623+
h.ServeHTTP(w, r)
624+
})
625+
}
626+
627+
f := MustRouter(WithMiddleware(WrapM(mw)))
628+
f.MustAdd(MethodGet, "/foo/{bar}", func(c *Context) {
629+
assert.Equal(t, "/foo/{bar}", c.Request().Pattern)
630+
require.NoError(t, c.String(http.StatusOK, "OK"))
631+
})
632+
633+
req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil)
634+
w := httptest.NewRecorder()
635+
636+
f.ServeHTTP(w, req)
637+
638+
assert.Empty(t, req.Pattern)
639+
}
640+
641+
func TestWrapM_FlusherShim(t *testing.T) {
642+
var sawFlusher bool
643+
var flushed bool
644+
645+
mw := func(h http.Handler) http.Handler {
646+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
647+
flusher, ok := w.(http.Flusher)
648+
sawFlusher = ok
649+
h.ServeHTTP(w, r)
650+
if ok {
651+
flusher.Flush()
652+
flushed = true
653+
}
654+
})
655+
}
656+
657+
f := MustRouter(WithMiddleware(WrapM(mw)))
658+
f.MustAdd(MethodGet, "/", func(c *Context) {
659+
require.NoError(t, c.String(http.StatusOK, "ok"))
660+
})
661+
662+
w := httptest.NewRecorder()
663+
f.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
664+
665+
assert.True(t, sawFlusher)
666+
assert.True(t, flushed)
667+
assert.True(t, w.Flushed)
668+
}
669+
620670
func BenchmarkWrapH(b *testing.B) {
621671
req := httptest.NewRequest(http.MethodGet, "https://example.com/a/b/c", nil)
622672
w := httptest.NewRecorder()

response_writer.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,14 @@ type onlyWrite struct {
296296
io.Writer
297297
}
298298

299+
type flusherWriter struct {
300+
ResponseWriter
301+
}
302+
303+
func (w flusherWriter) Flush() { _ = w.FlushError() }
304+
305+
func (w flusherWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter }
306+
299307
type noopWriter struct {
300308
h http.Header
301309
}

0 commit comments

Comments
 (0)