Skip to content

Commit 0a25d94

Browse files
committed
fix(middleware): clear request metadata after decompressing gzip body
1 parent 6a390cb commit 0a25d94

4 files changed

Lines changed: 156 additions & 8 deletions

File tree

middleware/body_limit_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,31 @@ func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
6868
assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
6969
}
7070

71+
func TestBodyLimitAfterDecompressUsesDecompressedSize(t *testing.T) {
72+
e := echo.New()
73+
body := "ok"
74+
gz, err := gzipString(body)
75+
assert.NoError(t, err)
76+
assert.Greater(t, len(gz), len(body))
77+
78+
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
79+
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
80+
rec := httptest.NewRecorder()
81+
c := e.NewContext(req, rec)
82+
83+
err = Decompress()(BodyLimit(int64(len(body)))(func(c *echo.Context) error {
84+
body, readErr := io.ReadAll(c.Request().Body)
85+
if readErr != nil {
86+
return readErr
87+
}
88+
return c.String(http.StatusOK, string(body))
89+
}))(c)
90+
91+
assert.NoError(t, err)
92+
assert.Equal(t, http.StatusOK, rec.Code)
93+
assert.Equal(t, body, rec.Body.String())
94+
}
95+
7196
func TestBodyLimitReader(t *testing.T) {
7297
hw := []byte("Hello, World!")
7398

middleware/decompress.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"compress/gzip"
88
"io"
99
"net/http"
10+
"strings"
1011
"sync"
1112

1213
"github.com/labstack/echo/v5"
@@ -82,7 +83,9 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
8283
return next(c)
8384
}
8485

85-
if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding {
86+
req := c.Request()
87+
contentEncoding := req.Header.Values(echo.HeaderContentEncoding)
88+
if len(contentEncoding) != 1 || strings.TrimSpace(contentEncoding[0]) != GZIPEncoding {
8689
return next(c)
8790
}
8891

@@ -96,7 +99,7 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
9699
}
97100
defer pool.Put(gr)
98101

99-
b := c.Request().Body
102+
b := req.Body
100103
defer b.Close()
101104

102105
if err := gr.Reset(b); err != nil {
@@ -111,15 +114,19 @@ func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
111114

112115
// Apply decompression size limit to prevent zip bombs
113116
if config.MaxDecompressedSize > 0 {
114-
c.Request().Body = &limitedGzipReader{
117+
req.Body = &limitedGzipReader{
115118
Reader: gr,
116119
remaining: config.MaxDecompressedSize,
117120
limit: config.MaxDecompressedSize,
118121
}
119122
} else {
120123
// -1 means explicitly unlimited (not recommended)
121-
c.Request().Body = gr
124+
req.Body = gr
122125
}
126+
req.Header.Del(echo.HeaderContentEncoding)
127+
req.Header.Del(echo.HeaderContentLength)
128+
req.ContentLength = -1
129+
req.GetBody = nil
123130

124131
return next(c)
125132
}

middleware/decompress_test.go

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"io"
1111
"net/http"
1212
"net/http/httptest"
13+
"strconv"
1314
"strings"
1415
"sync"
1516
"testing"
@@ -31,18 +32,80 @@ func TestDecompress(t *testing.T) {
3132
gz, _ := gzipString(body)
3233
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
3334
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
35+
req.Header.Set(echo.HeaderContentLength, strconv.Itoa(len(gz)))
36+
req.GetBody = func() (io.ReadCloser, error) {
37+
return io.NopCloser(bytes.NewReader(gz)), nil
38+
}
3439
rec := httptest.NewRecorder()
3540
c := e.NewContext(req, rec)
3641

3742
err := h(c)
3843
assert.NoError(t, err)
3944

40-
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
45+
assert.Empty(t, req.Header.Get(echo.HeaderContentEncoding))
46+
assert.Empty(t, req.Header.Get(echo.HeaderContentLength))
47+
assert.Equal(t, int64(-1), req.ContentLength)
48+
assert.Nil(t, req.GetBody)
4149
b, err := io.ReadAll(req.Body)
4250
assert.NoError(t, err)
4351
assert.Equal(t, body, string(b))
4452
}
4553

54+
func TestDecompress_SkipsRepeatedContentEncodingValues(t *testing.T) {
55+
e := echo.New()
56+
57+
body := `{"name": "echo"}`
58+
gz, _ := gzipString(body)
59+
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
60+
req.Header.Add(echo.HeaderContentEncoding, GZIPEncoding)
61+
req.Header.Add(echo.HeaderContentEncoding, "br")
62+
req.Header.Set(echo.HeaderContentLength, strconv.Itoa(len(gz)))
63+
req.GetBody = func() (io.ReadCloser, error) {
64+
return io.NopCloser(bytes.NewReader(gz)), nil
65+
}
66+
rec := httptest.NewRecorder()
67+
c := e.NewContext(req, rec)
68+
69+
var got []byte
70+
err := Decompress()(func(c *echo.Context) error {
71+
var readErr error
72+
got, readErr = io.ReadAll(c.Request().Body)
73+
return readErr
74+
})(c)
75+
76+
assert.NoError(t, err)
77+
assert.Equal(t, gz, got)
78+
assert.Equal(t, []string{GZIPEncoding, "br"}, req.Header.Values(echo.HeaderContentEncoding))
79+
assert.Equal(t, strconv.Itoa(len(gz)), req.Header.Get(echo.HeaderContentLength))
80+
assert.Equal(t, int64(len(gz)), req.ContentLength)
81+
assert.NotNil(t, req.GetBody)
82+
}
83+
84+
func TestDecompress_SkipsCommaSeparatedContentEncodingValues(t *testing.T) {
85+
e := echo.New()
86+
87+
body := `{"name": "echo"}`
88+
gz, _ := gzipString(body)
89+
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
90+
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding+", br")
91+
req.Header.Set(echo.HeaderContentLength, strconv.Itoa(len(gz)))
92+
rec := httptest.NewRecorder()
93+
c := e.NewContext(req, rec)
94+
95+
var got []byte
96+
err := Decompress()(func(c *echo.Context) error {
97+
var readErr error
98+
got, readErr = io.ReadAll(c.Request().Body)
99+
return readErr
100+
})(c)
101+
102+
assert.NoError(t, err)
103+
assert.Equal(t, gz, got)
104+
assert.Equal(t, GZIPEncoding+", br", req.Header.Get(echo.HeaderContentEncoding))
105+
assert.Equal(t, strconv.Itoa(len(gz)), req.Header.Get(echo.HeaderContentLength))
106+
assert.Equal(t, int64(len(gz)), req.ContentLength)
107+
}
108+
46109
func TestDecompress_skippedIfNoHeader(t *testing.T) {
47110
e := echo.New()
48111
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
@@ -99,7 +162,7 @@ func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
99162
err := h(c)
100163
assert.NoError(t, err)
101164

102-
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
165+
assert.Empty(t, req.Header.Get(echo.HeaderContentEncoding))
103166
b, err := io.ReadAll(req.Body)
104167
assert.NoError(t, err)
105168
assert.Equal(t, body, string(b))
@@ -215,8 +278,6 @@ func BenchmarkDecompress(b *testing.B) {
215278
e := echo.New()
216279
body := `{"name": "echo"}`
217280
gz, _ := gzipString(body)
218-
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
219-
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
220281

221282
h := Decompress()(func(c *echo.Context) error {
222283
c.Response().Write([]byte(body)) // For Content-Type sniffing
@@ -228,6 +289,8 @@ func BenchmarkDecompress(b *testing.B) {
228289

229290
for i := 0; i < b.N; i++ {
230291
// Decompress
292+
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
293+
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
231294
rec := httptest.NewRecorder()
232295
c := e.NewContext(req, rec)
233296
h(c)

middleware/proxy_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,59 @@ func TestProxy(t *testing.T) {
128128
e.ServeHTTP(rec, req)
129129
}
130130

131+
func TestProxyAfterDecompressForwardsDecodedBody(t *testing.T) {
132+
body := "proxied body"
133+
gz, err := gzipString(body)
134+
assert.NoError(t, err)
135+
136+
type upstreamObservation struct {
137+
body string
138+
contentEncoding string
139+
contentLength int64
140+
readErr error
141+
}
142+
observations := make(chan upstreamObservation, 1)
143+
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
144+
b, readErr := io.ReadAll(r.Body)
145+
observations <- upstreamObservation{
146+
body: string(b),
147+
contentEncoding: r.Header.Get(echo.HeaderContentEncoding),
148+
contentLength: r.ContentLength,
149+
readErr: readErr,
150+
}
151+
_, _ = w.Write(b)
152+
}))
153+
defer upstream.Close()
154+
155+
targetURL, err := url.Parse(upstream.URL)
156+
assert.NoError(t, err)
157+
158+
e := echo.New()
159+
e.Use(Decompress())
160+
e.Use(ProxyWithConfig(ProxyConfig{
161+
Balancer: NewRoundRobinBalancer([]*ProxyTarget{{URL: targetURL}}),
162+
}))
163+
164+
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
165+
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
166+
rec := httptest.NewRecorder()
167+
168+
e.ServeHTTP(rec, req)
169+
170+
var observed upstreamObservation
171+
select {
172+
case observed = <-observations:
173+
default:
174+
t.Fatal("upstream was not called")
175+
}
176+
assert.Equal(t, http.StatusOK, rec.Code)
177+
assert.Equal(t, body, rec.Body.String())
178+
assert.NoError(t, observed.readErr)
179+
assert.Equal(t, body, observed.body)
180+
assert.Empty(t, observed.contentEncoding)
181+
assert.Equal(t, int64(-1), observed.contentLength)
182+
}
183+
131184
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
132185
assert.Panics(t, func() {
133186
ProxyWithConfig(ProxyConfig{Balancer: nil})

0 commit comments

Comments
 (0)