Skip to content

Commit b70eb70

Browse files
committed
fix: bound stdio JSON-RPC message size
1 parent 5045d86 commit b70eb70

3 files changed

Lines changed: 108 additions & 24 deletions

File tree

mcp/cmd.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ var defaultTerminateDuration = 5 * time.Second // mutable for testing
1919
// with it over stdin/stdout, using newline-delimited JSON.
2020
type CommandTransport struct {
2121
Command *exec.Cmd
22+
// MaxMessageBytes, if positive, rejects incoming JSON-RPC messages from the
23+
// subprocess larger than this many bytes.
24+
MaxMessageBytes int64
2225
// TerminateDuration controls how long Close waits after closing stdin
2326
// for the process to exit before sending SIGTERM.
2427
// If zero or negative, the default of 5s is used.
@@ -43,7 +46,7 @@ func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) {
4346
if td <= 0 {
4447
td = defaultTerminateDuration
4548
}
46-
return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil
49+
return newIOConnWithOptions(&pipeRWC{t.Command, stdout, stdin, td}, t.MaxMessageBytes), nil
4750
}
4851

4952
// A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over

mcp/transport.go

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
package mcp
66

77
import (
8+
"bufio"
9+
"bytes"
810
"context"
911
"encoding/json"
1012
"errors"
@@ -98,11 +100,15 @@ type serverConnection interface {
98100

99101
// A StdioTransport is a [Transport] that communicates over stdin/stdout using
100102
// newline-delimited JSON.
101-
type StdioTransport struct{}
103+
type StdioTransport struct {
104+
// MaxMessageBytes, if positive, rejects incoming JSON-RPC messages larger
105+
// than this many bytes.
106+
MaxMessageBytes int64
107+
}
102108

103109
// Connect implements the [Transport] interface.
104-
func (*StdioTransport) Connect(context.Context) (Connection, error) {
105-
return newIOConn(rwc{os.Stdin, nopCloserWriter{os.Stdout}}), nil
110+
func (t *StdioTransport) Connect(context.Context) (Connection, error) {
111+
return newIOConnWithOptions(rwc{os.Stdin, nopCloserWriter{os.Stdout}}, t.MaxMessageBytes), nil
106112
}
107113

108114
// nopCloserWriter is an io.WriteCloser with a trivial Close method.
@@ -115,13 +121,14 @@ func (nopCloserWriter) Close() error { return nil }
115121
// An IOTransport is a [Transport] that communicates over separate
116122
// io.ReadCloser and io.WriteCloser using newline-delimited JSON.
117123
type IOTransport struct {
118-
Reader io.ReadCloser
119-
Writer io.WriteCloser
124+
Reader io.ReadCloser
125+
Writer io.WriteCloser
126+
MaxMessageBytes int64
120127
}
121128

122129
// Connect implements the [Transport] interface.
123130
func (t *IOTransport) Connect(context.Context) (Connection, error) {
124-
return newIOConn(rwc{t.Reader, t.Writer}), nil
131+
return newIOConnWithOptions(rwc{t.Reader, t.Writer}, t.MaxMessageBytes), nil
125132
}
126133

127134
// An InMemoryTransport is a [Transport] that communicates over an in-memory
@@ -392,6 +399,10 @@ type msgOrErr struct {
392399
}
393400

394401
func newIOConn(rwc io.ReadWriteCloser) *ioConn {
402+
return newIOConnWithOptions(rwc, 0)
403+
}
404+
405+
func newIOConnWithOptions(rwc io.ReadWriteCloser, maxMessageBytes int64) *ioConn {
395406
var (
396407
incoming = make(chan msgOrErr)
397408
closed = make(chan struct{})
@@ -403,24 +414,9 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
403414
// but that is unavoidable since AFAIK there is no (easy and portable) way to
404415
// guarantee that reads of stdin are unblocked when closed.
405416
go func() {
406-
dec := json.NewDecoder(rwc)
417+
reader := bufio.NewReader(rwc)
407418
for {
408-
var raw json.RawMessage
409-
err := dec.Decode(&raw)
410-
// If decoding was successful, check for trailing data at the end of the stream.
411-
if err == nil {
412-
// Read the next byte to check if there is trailing data.
413-
var tr [1]byte
414-
if n, readErr := dec.Buffered().Read(tr[:]); n > 0 {
415-
// If read byte is not a newline, it is an error.
416-
// Support both Unix (\n) and Windows (\r\n) line endings.
417-
if tr[0] != '\n' && tr[0] != '\r' {
418-
err = fmt.Errorf("invalid trailing data at the end of stream")
419-
}
420-
} else if readErr != nil && readErr != io.EOF {
421-
err = readErr
422-
}
423-
}
419+
raw, err := readFrame(reader, maxMessageBytes)
424420
select {
425421
case incoming <- msgOrErr{msg: raw, err: err}:
426422
case <-closed:
@@ -438,6 +434,55 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
438434
}
439435
}
440436

437+
func readFrame(reader *bufio.Reader, maxMessageBytes int64) (json.RawMessage, error) {
438+
var frame []byte
439+
for {
440+
part, err := reader.ReadSlice('\n')
441+
if maxMessageBytes > 0 && int64(len(frame)+len(part)) > maxMessageBytes {
442+
return nil, fmt.Errorf("JSON-RPC message exceeds maximum size of %d bytes", maxMessageBytes)
443+
}
444+
frame = append(frame, part...)
445+
switch {
446+
case err == nil:
447+
if n := len(frame); n > 0 && frame[n-1] == '\n' {
448+
frame = frame[:n-1]
449+
}
450+
if n := len(frame); n > 0 && frame[n-1] == '\r' {
451+
frame = frame[:n-1]
452+
}
453+
if err := validateJSONFrame(frame); err != nil {
454+
return nil, err
455+
}
456+
return json.RawMessage(frame), nil
457+
case errors.Is(err, bufio.ErrBufferFull):
458+
continue
459+
case errors.Is(err, io.EOF):
460+
if len(frame) == 0 {
461+
return nil, io.EOF
462+
}
463+
if err := validateJSONFrame(frame); err != nil {
464+
return nil, err
465+
}
466+
return json.RawMessage(frame), nil
467+
default:
468+
return nil, err
469+
}
470+
}
471+
}
472+
473+
func validateJSONFrame(frame []byte) error {
474+
dec := json.NewDecoder(bytes.NewReader(frame))
475+
var raw json.RawMessage
476+
if err := dec.Decode(&raw); err != nil {
477+
return err
478+
}
479+
var extra json.RawMessage
480+
if err := dec.Decode(&extra); err != io.EOF {
481+
return fmt.Errorf("invalid trailing data at the end of stream")
482+
}
483+
return nil
484+
}
485+
441486
func (c *ioConn) SessionID() string { return "" }
442487

443488
func (c *ioConn) sessionUpdated(state ServerSessionState) {

mcp/transport_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,39 @@ func TestIOConnRead(t *testing.T) {
124124
})
125125
}
126126
}
127+
128+
func TestIOConnReadMaxMessageBytes(t *testing.T) {
129+
ctx := context.Background()
130+
input := `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`
131+
132+
t.Run("allows frame at limit", func(t *testing.T) {
133+
tr := newIOConnWithOptions(rwc{
134+
rc: io.NopCloser(strings.NewReader(input)),
135+
}, int64(len(input)))
136+
t.Cleanup(func() { tr.Close() })
137+
138+
msg, err := tr.Read(ctx)
139+
if err != nil {
140+
t.Fatalf("Read() returned error: %v", err)
141+
}
142+
if got := msg.(*jsonrpc.Request).Method; got != "test" {
143+
t.Fatalf("Read() method = %q, want test", got)
144+
}
145+
})
146+
147+
t.Run("rejects frame over limit", func(t *testing.T) {
148+
tr := newIOConnWithOptions(rwc{
149+
rc: io.NopCloser(strings.NewReader(input)),
150+
}, int64(len(input)-1))
151+
t.Cleanup(func() { tr.Close() })
152+
153+
_, err := tr.Read(ctx)
154+
if err == nil {
155+
t.Fatal("Read() returned nil error")
156+
}
157+
want := "JSON-RPC message exceeds maximum size"
158+
if !strings.Contains(err.Error(), want) {
159+
t.Fatalf("Read() error = %q, want substring %q", err, want)
160+
}
161+
})
162+
}

0 commit comments

Comments
 (0)