Skip to content

Commit f293a76

Browse files
authored
fix: Limit webhook payload size in ValidatePayloadFromBody (#4125)
1 parent a276aa8 commit f293a76

File tree

2 files changed

+58
-4
lines changed

2 files changed

+58
-4
lines changed

github/messages.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ const (
4242
EventTypeHeader = "X-Github-Event"
4343
// DeliveryIDHeader is the GitHub header key used to pass the unique ID for the webhook event.
4444
DeliveryIDHeader = "X-Github-Delivery"
45+
46+
// maxPayloadSize is the maximum size of a GitHub webhook payload.
47+
// GitHub documents a 25 MB limit for webhook payloads.
48+
maxPayloadSize = 25 * 1024 * 1024
4549
)
4650

4751
var (
@@ -146,8 +150,19 @@ func checkMAC(message, messageMAC, key []byte, hashFunc func() hash.Hash) bool {
146150
return hmac.Equal(messageMAC, expectedMAC)
147151
}
148152

149-
// messageMAC returns the hex-decoded HMAC tag from the signature and its
150-
// corresponding hash function.
153+
// readPayloadBody reads the body from readable, enforcing maxPayloadSize.
154+
func readPayloadBody(readable io.Reader) ([]byte, error) {
155+
body, err := io.ReadAll(io.LimitReader(readable, maxPayloadSize+1))
156+
if err != nil {
157+
return nil, err
158+
}
159+
if len(body) > maxPayloadSize {
160+
return nil, errors.New("webhook payload exceeds maximum allowed size")
161+
}
162+
return body, nil
163+
}
164+
165+
// messageMAC returns the MAC method and the corresponding hash function.
151166
func messageMAC(signature string) ([]byte, func() hash.Hash, error) {
152167
if signature == "" {
153168
return nil, nil, errors.New("missing signature")
@@ -199,7 +214,7 @@ func ValidatePayloadFromBody(contentType string, readable io.Reader, signature s
199214
switch contentType {
200215
case "application/json":
201216
var err error
202-
if body, err = io.ReadAll(readable); err != nil {
217+
if body, err = readPayloadBody(readable); err != nil {
203218
return nil, err
204219
}
205220

@@ -213,7 +228,7 @@ func ValidatePayloadFromBody(contentType string, readable io.Reader, signature s
213228
const payloadFormParam = "payload"
214229

215230
var err error
216-
if body, err = io.ReadAll(readable); err != nil {
231+
if body, err = readPayloadBody(readable); err != nil {
217232
return nil, err
218233
}
219234

github/messages_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"encoding/json"
1111
"errors"
1212
"fmt"
13+
"io"
1314
"net/http"
1415
"net/url"
1516
"strings"
@@ -205,6 +206,10 @@ func (b *badReader) Read([]byte) (int, error) {
205206

206207
func (b *badReader) Close() error { return errors.New("bad reader") }
207208

209+
type readerFunc func([]byte) (int, error)
210+
211+
func (f readerFunc) Read(p []byte) (int, error) { return f(p) }
212+
208213
func TestValidatePayload_BadRequestBody(t *testing.T) {
209214
t.Parallel()
210215
tests := []struct {
@@ -228,6 +233,40 @@ func TestValidatePayload_BadRequestBody(t *testing.T) {
228233
}
229234
}
230235

236+
func TestValidatePayload_OversizedBody(t *testing.T) {
237+
t.Parallel()
238+
tests := []struct {
239+
contentType string
240+
}{
241+
{contentType: "application/json"},
242+
{contentType: "application/x-www-form-urlencoded"},
243+
}
244+
245+
for i, tt := range tests {
246+
t.Run(fmt.Sprintf("test #%v", i), func(t *testing.T) {
247+
t.Parallel()
248+
// Simulate a reader that reports more than maxPayloadSize bytes.
249+
oversized := io.LimitReader(readerFunc(func(p []byte) (int, error) {
250+
for i := range p {
251+
p[i] = 0
252+
}
253+
return len(p), nil
254+
}), maxPayloadSize+1)
255+
req := &http.Request{
256+
Header: http.Header{"Content-Type": []string{tt.contentType}},
257+
Body: io.NopCloser(oversized),
258+
}
259+
_, err := ValidatePayload(req, nil)
260+
if err == nil {
261+
t.Fatal("ValidatePayload returned nil; want error for oversized body")
262+
}
263+
if want := "webhook payload exceeds maximum allowed size"; err.Error() != want {
264+
t.Errorf("ValidatePayload error = %q, want %q", err.Error(), want)
265+
}
266+
})
267+
}
268+
}
269+
231270
func TestValidatePayload_InvalidContentTypeParams(t *testing.T) {
232271
t.Parallel()
233272
req, err := http.NewRequest("POST", "http://localhost/event", nil)

0 commit comments

Comments
 (0)