Skip to content

Commit 3ec96dc

Browse files
committed
Merge remote-tracking branch 'origin/main' into fix/medium-severity-audit-findings
# Conflicts: # framework/database/sql.go # framework/database/sql_test.go # framework/hooks_writer_test.go # framework/middleware/cors_test.go # framework/middleware/rate_limit.go # framework/middleware/rate_limit_test.go # framework/session/middleware.go # router/router_test.go
2 parents 0a1e5d2 + e0f5eab commit 3ec96dc

23 files changed

Lines changed: 1315 additions & 204 deletions

contract/request/body.go

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
package request
22

33
import (
4+
"bytes"
45
"encoding/json"
56
"encoding/xml"
7+
"errors"
68
"io"
79
"net/http"
810
)
911

12+
// ErrBodyTooLarge is returned when the request body exceeds
13+
// the maximum allowed size. Callers can check for this error
14+
// using [errors.Is].
15+
var ErrBodyTooLarge = errors.New("request body too large")
16+
1017
// DefaultMaxBodySize is the default maximum request body size
1118
// (10 MB) used by the size-limited body reading functions.
1219
// This prevents denial-of-service attacks via excessively
@@ -24,16 +31,26 @@ func Bytes(r *http.Request) ([]byte, error) {
2431
}
2532

2633
// LimitedBytes reads the request body up to maxSize bytes and
27-
// returns it as a byte slice. If the body exceeds maxSize, an
28-
// error is returned. This prevents denial-of-service attacks via
29-
// excessively large request bodies. Pass -1 to use
34+
// returns it as a byte slice. If the body exceeds maxSize,
35+
// [ErrBodyTooLarge] is returned. This prevents denial-of-service
36+
// attacks via excessively large request bodies. Pass -1 to use
3037
// [DefaultMaxBodySize].
3138
func LimitedBytes(r *http.Request, maxSize int64) ([]byte, error) {
3239
if maxSize < 0 {
3340
maxSize = DefaultMaxBodySize
3441
}
3542

36-
return io.ReadAll(io.LimitReader(r.Body, maxSize+1))
43+
data, err := io.ReadAll(io.LimitReader(r.Body, maxSize+1))
44+
45+
if err != nil {
46+
return nil, err
47+
}
48+
49+
if int64(len(data)) > maxSize {
50+
return nil, ErrBodyTooLarge
51+
}
52+
53+
return data, nil
3754
}
3855

3956
// String reads the request body and returns it as a string.
@@ -54,8 +71,8 @@ func String(r *http.Request) (string, error) {
5471
}
5572

5673
// LimitedString reads the request body up to maxSize bytes and
57-
// returns it as a string. If the body exceeds maxSize, the result
58-
// is truncated. Pass -1 to use [DefaultMaxBodySize].
74+
// returns it as a string. If the body exceeds maxSize,
75+
// [ErrBodyTooLarge] is returned. Pass -1 to use [DefaultMaxBodySize].
5976
func LimitedString(r *http.Request, maxSize int64) (string, error) {
6077
body, err := LimitedBytes(r, maxSize)
6178

@@ -106,17 +123,26 @@ func StrictJSON[T any](r *http.Request) (value T, err error) {
106123
}
107124

108125
// LimitedJSON decodes JSON data from the request body into a value
109-
// of type T, reading at most maxSize bytes. This prevents
126+
// of type T, reading at most maxSize bytes. If the body exceeds
127+
// maxSize, [ErrBodyTooLarge] is returned. This prevents
110128
// denial-of-service attacks via oversized JSON payloads. Pass -1
111129
// to use [DefaultMaxBodySize].
112130
func LimitedJSON[T any](r *http.Request, maxSize int64) (value T, err error) {
113131
if maxSize < 0 {
114132
maxSize = DefaultMaxBodySize
115133
}
116134

117-
limited := io.LimitReader(r.Body, maxSize+1)
135+
data, err := io.ReadAll(io.LimitReader(r.Body, maxSize+1))
118136

119-
if err := json.NewDecoder(limited).Decode(&value); err != nil {
137+
if err != nil {
138+
return value, err
139+
}
140+
141+
if int64(len(data)) > maxSize {
142+
return value, ErrBodyTooLarge
143+
}
144+
145+
if err := json.NewDecoder(bytes.NewReader(data)).Decode(&value); err != nil {
120146
return value, err
121147
}
122148

@@ -125,14 +151,24 @@ func LimitedJSON[T any](r *http.Request, maxSize int64) (value T, err error) {
125151

126152
// StrictLimitedJSON decodes JSON data from the request body into
127153
// a value of type T, reading at most maxSize bytes and rejecting
128-
// unknown fields. Pass -1 to use [DefaultMaxBodySize].
154+
// unknown fields. If the body exceeds maxSize, [ErrBodyTooLarge]
155+
// is returned. Pass -1 to use [DefaultMaxBodySize].
129156
func StrictLimitedJSON[T any](r *http.Request, maxSize int64) (value T, err error) {
130157
if maxSize < 0 {
131158
maxSize = DefaultMaxBodySize
132159
}
133160

134-
limited := io.LimitReader(r.Body, maxSize+1)
135-
decoder := json.NewDecoder(limited)
161+
data, err := io.ReadAll(io.LimitReader(r.Body, maxSize+1))
162+
163+
if err != nil {
164+
return value, err
165+
}
166+
167+
if int64(len(data)) > maxSize {
168+
return value, ErrBodyTooLarge
169+
}
170+
171+
decoder := json.NewDecoder(bytes.NewReader(data))
136172
decoder.DisallowUnknownFields()
137173

138174
if err := decoder.Decode(&value); err != nil {
@@ -149,6 +185,11 @@ func StrictLimitedJSON[T any](r *http.Request, maxSize int64) (value T, err erro
149185
// WARNING: This function decodes without any body size limit.
150186
// Prefer [LimitedXML] or apply [http.MaxBytesReader] in a
151187
// middleware to prevent memory exhaustion from oversized requests.
188+
//
189+
// WARNING: Go's [encoding/xml] does not protect against entity
190+
// expansion attacks (e.g. "Billion Laughs"). Callers should
191+
// validate or sanitize XML input before processing, or use
192+
// [LimitedXML] with a small size limit to bound expansion.
152193
func XML[T any](r *http.Request) (value T, err error) {
153194
if err := xml.NewDecoder(r.Body).Decode(&value); err != nil {
154195
return value, err
@@ -158,17 +199,32 @@ func XML[T any](r *http.Request) (value T, err error) {
158199
}
159200

160201
// LimitedXML decodes XML data from the request body into a value
161-
// of type T, reading at most maxSize bytes. This prevents
202+
// of type T, reading at most maxSize bytes. If the body exceeds
203+
// maxSize, [ErrBodyTooLarge] is returned. This prevents
162204
// denial-of-service attacks via oversized XML payloads. Pass -1
163205
// to use [DefaultMaxBodySize].
206+
//
207+
// WARNING: Go's [encoding/xml] does not protect against entity
208+
// expansion attacks (e.g. "Billion Laughs"). Even with a small
209+
// maxSize, a crafted XML document may expand to significantly
210+
// more memory than its wire size. Callers should validate or
211+
// sanitize XML input before processing.
164212
func LimitedXML[T any](r *http.Request, maxSize int64) (value T, err error) {
165213
if maxSize < 0 {
166214
maxSize = DefaultMaxBodySize
167215
}
168216

169-
limited := io.LimitReader(r.Body, maxSize+1)
217+
data, err := io.ReadAll(io.LimitReader(r.Body, maxSize+1))
218+
219+
if err != nil {
220+
return value, err
221+
}
222+
223+
if int64(len(data)) > maxSize {
224+
return value, ErrBodyTooLarge
225+
}
170226

171-
if err := xml.NewDecoder(limited).Decode(&value); err != nil {
227+
if err := xml.NewDecoder(bytes.NewReader(data)).Decode(&value); err != nil {
172228
return value, err
173229
}
174230

contract/request/body_test.go

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package request_test
22

33
import (
4+
"errors"
45
"io"
56
"net/http"
67
"net/http/httptest"
@@ -38,14 +39,15 @@ func TestBytesErrorOnFailedRead(t *testing.T) {
3839
require.Error(t, err)
3940
}
4041

41-
func TestLimitedBytesReadsUpToLimit(t *testing.T) {
42+
func TestLimitedBytesReturnsErrorWhenBodyExceedsLimit(t *testing.T) {
43+
t.Parallel()
44+
4245
body := "abcdefghij"
4346
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
4447

45-
result, err := request.LimitedBytes(r, 5)
48+
_, err := request.LimitedBytes(r, 5)
4649

47-
require.NoError(t, err)
48-
require.Len(t, result, 6)
50+
require.ErrorIs(t, err, request.ErrBodyTooLarge)
4951
}
5052

5153
func TestLimitedBytesReadsFullBodyUnderLimit(t *testing.T) {
@@ -95,14 +97,15 @@ func TestStringErrorOnFailedRead(t *testing.T) {
9597
require.Error(t, err)
9698
}
9799

98-
func TestLimitedStringReadsUpToLimit(t *testing.T) {
100+
func TestLimitedStringReturnsErrorWhenBodyExceedsLimit(t *testing.T) {
101+
t.Parallel()
102+
99103
body := "abcdefghij"
100104
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
101105

102-
result, err := request.LimitedString(r, 5)
106+
_, err := request.LimitedString(r, 5)
103107

104-
require.NoError(t, err)
105-
require.Len(t, result, 6)
108+
require.ErrorIs(t, err, request.ErrBodyTooLarge)
106109
}
107110

108111
func TestLimitedStringReadsFullBodyUnderLimit(t *testing.T) {
@@ -415,6 +418,98 @@ func TestLimitedXMLReturnsErrorOnInvalidPayload(t *testing.T) {
415418
require.Error(t, err)
416419
}
417420

421+
func TestLimitedBytesReturnsDataWhenBodyWithinLimit(t *testing.T) {
422+
t.Parallel()
423+
424+
body := "hello"
425+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
426+
427+
result, err := request.LimitedBytes(r, 10)
428+
429+
require.NoError(t, err)
430+
require.Equal(t, []byte("hello"), result)
431+
}
432+
433+
func TestLimitedBytesReturnsDataWhenBodyExactlyAtLimit(t *testing.T) {
434+
t.Parallel()
435+
436+
body := "12345"
437+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
438+
439+
result, err := request.LimitedBytes(r, 5)
440+
441+
require.NoError(t, err)
442+
require.Equal(t, []byte("12345"), result)
443+
}
444+
445+
func TestLimitedStringReturnsDataWhenBodyWithinLimit(t *testing.T) {
446+
t.Parallel()
447+
448+
body := "hello"
449+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
450+
451+
result, err := request.LimitedString(r, 10)
452+
453+
require.NoError(t, err)
454+
require.Equal(t, "hello", result)
455+
}
456+
457+
func TestLimitedJSONReturnsErrorWhenBodyExceedsLimit(t *testing.T) {
458+
t.Parallel()
459+
460+
type payload struct {
461+
Name string `json:"name"`
462+
}
463+
464+
body := `{"name":"this is a very long name that exceeds the limit"}`
465+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
466+
467+
_, err := request.LimitedJSON[payload](r, 10)
468+
469+
require.ErrorIs(t, err, request.ErrBodyTooLarge)
470+
}
471+
472+
func TestStrictLimitedJSONReturnsErrorWhenBodyExceedsLimit(t *testing.T) {
473+
t.Parallel()
474+
475+
type payload struct {
476+
Name string `json:"name"`
477+
}
478+
479+
body := `{"name":"this is a very long name that exceeds the limit"}`
480+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
481+
482+
_, err := request.StrictLimitedJSON[payload](r, 10)
483+
484+
require.ErrorIs(t, err, request.ErrBodyTooLarge)
485+
}
486+
487+
func TestLimitedXMLReturnsErrorWhenBodyExceedsLimit(t *testing.T) {
488+
t.Parallel()
489+
490+
type payload struct {
491+
Name string `xml:"name"`
492+
}
493+
494+
body := `<payload><name>this is a very long name that exceeds the limit</name></payload>`
495+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
496+
497+
_, err := request.LimitedXML[payload](r, 10)
498+
499+
require.ErrorIs(t, err, request.ErrBodyTooLarge)
500+
}
501+
502+
func TestErrBodyTooLargeIsCheckableWithErrorsIs(t *testing.T) {
503+
t.Parallel()
504+
505+
body := "abcdefghij"
506+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
507+
508+
_, err := request.LimitedBytes(r, 5)
509+
510+
require.True(t, errors.Is(err, request.ErrBodyTooLarge))
511+
}
512+
418513
// errReader is an io.Reader that always returns an error.
419514
type errReader struct{}
420515

framework/cache/redis.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cache
22

33
import (
44
"context"
5-
"encoding/json"
65
"errors"
76
"time"
87

@@ -72,9 +71,9 @@ func (client *RedisClient) Has(ctx context.Context, key string) (bool, error) {
7271
}
7372

7473
// Pull atomically retrieves and deletes a key using Redis GETDEL.
75-
// The stored value is JSON-decoded into the return value.
76-
func (client *RedisClient) Pull(ctx context.Context, key string) (value any, err error) {
77-
encoded, err := (*redis.Client)(client).GetDel(ctx, key).Result()
74+
// Returns contract.ErrCacheKeyNotFound when the key does not exist.
75+
func (client *RedisClient) Pull(ctx context.Context, key string) (any, error) {
76+
value, err := (*redis.Client)(client).GetDel(ctx, key).Result()
7877

7978
if errors.Is(err, redis.Nil) {
8079
return nil, contract.ErrCacheKeyNotFound
@@ -84,10 +83,6 @@ func (client *RedisClient) Pull(ctx context.Context, key string) (value any, err
8483
return nil, err
8584
}
8685

87-
if err := json.Unmarshal([]byte(encoded), &value); err != nil {
88-
return nil, err
89-
}
90-
9186
return value, nil
9287
}
9388

framework/database/sql.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ func (database *SQL) FindNamed(ctx context.Context, query string, dest any, arg
178178
// original error and any rollback error are joined. If fn succeeds,
179179
// the transaction is committed. Nested transactions are not supported
180180
// and return contract.ErrDatabaseNestedTransaction.
181-
func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Database) error) error {
181+
//
182+
// If fn panics, the transaction is rolled back before the panic is
183+
// re-raised, preventing connection pool leaks.
184+
func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Database) error) (retErr error) {
182185
if _, ok := database.db.(*sqlx.Tx); ok {
183186
return contract.ErrDatabaseNestedTransaction
184187
}
@@ -189,10 +192,23 @@ func (database *SQL) WithTransaction(ctx context.Context, fn func(tx contract.Da
189192
return err
190193
}
191194

195+
defer func() {
196+
if p := recover(); p != nil {
197+
_ = tx.Rollback()
198+
panic(p)
199+
}
200+
201+
if retErr != nil {
202+
retErr = errors.Join(retErr, tx.Rollback())
203+
}
204+
}()
205+
192206
txWrapper := &sqlTx{SQL{db: tx, raw: database.raw}}
193207

194208
if err := fn(txWrapper); err != nil {
195-
return errors.Join(err, tx.Rollback())
209+
retErr = err
210+
211+
return
196212
}
197213

198214
return tx.Commit()

0 commit comments

Comments
 (0)