55package mcp
66
77import (
8+ "bufio"
9+ "bytes"
810 "context"
911 "encoding/json"
1012 "errors"
@@ -98,11 +100,15 @@ type serverConnection interface {
98100
99101// A StdioTransport is a [Transport] that communicates over stdin/stdout using
100102// newline-delimited JSON.
101- type StdioTransport struct {}
103+ type StdioTransport struct {
104+ // MaxMessageBytes, if positive, rejects incoming JSON-RPC messages larger
105+ // than this many bytes.
106+ MaxMessageBytes int64
107+ }
102108
103109// Connect implements the [Transport] interface.
104- func (* StdioTransport ) Connect (context.Context ) (Connection , error ) {
105- return newIOConn (rwc {os .Stdin , nopCloserWriter {os .Stdout }}), nil
110+ func (t * StdioTransport ) Connect (context.Context ) (Connection , error ) {
111+ return newIOConnWithOptions (rwc {os .Stdin , nopCloserWriter {os .Stdout }}, t . MaxMessageBytes ), nil
106112}
107113
108114// nopCloserWriter is an io.WriteCloser with a trivial Close method.
@@ -115,13 +121,14 @@ func (nopCloserWriter) Close() error { return nil }
115121// An IOTransport is a [Transport] that communicates over separate
116122// io.ReadCloser and io.WriteCloser using newline-delimited JSON.
117123type IOTransport struct {
118- Reader io.ReadCloser
119- Writer io.WriteCloser
124+ Reader io.ReadCloser
125+ Writer io.WriteCloser
126+ MaxMessageBytes int64
120127}
121128
122129// Connect implements the [Transport] interface.
123130func (t * IOTransport ) Connect (context.Context ) (Connection , error ) {
124- return newIOConn (rwc {t .Reader , t .Writer }), nil
131+ return newIOConnWithOptions (rwc {t .Reader , t .Writer }, t . MaxMessageBytes ), nil
125132}
126133
127134// An InMemoryTransport is a [Transport] that communicates over an in-memory
@@ -392,6 +399,10 @@ type msgOrErr struct {
392399}
393400
394401func newIOConn (rwc io.ReadWriteCloser ) * ioConn {
402+ return newIOConnWithOptions (rwc , 0 )
403+ }
404+
405+ func newIOConnWithOptions (rwc io.ReadWriteCloser , maxMessageBytes int64 ) * ioConn {
395406 var (
396407 incoming = make (chan msgOrErr )
397408 closed = make (chan struct {})
@@ -403,24 +414,9 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
403414 // but that is unavoidable since AFAIK there is no (easy and portable) way to
404415 // guarantee that reads of stdin are unblocked when closed.
405416 go func () {
406- dec := json . NewDecoder (rwc )
417+ reader := bufio . NewReader (rwc )
407418 for {
408- var raw json.RawMessage
409- err := dec .Decode (& raw )
410- // If decoding was successful, check for trailing data at the end of the stream.
411- if err == nil {
412- // Read the next byte to check if there is trailing data.
413- var tr [1 ]byte
414- if n , readErr := dec .Buffered ().Read (tr [:]); n > 0 {
415- // If read byte is not a newline, it is an error.
416- // Support both Unix (\n) and Windows (\r\n) line endings.
417- if tr [0 ] != '\n' && tr [0 ] != '\r' {
418- err = fmt .Errorf ("invalid trailing data at the end of stream" )
419- }
420- } else if readErr != nil && readErr != io .EOF {
421- err = readErr
422- }
423- }
419+ raw , err := readFrame (reader , maxMessageBytes )
424420 select {
425421 case incoming <- msgOrErr {msg : raw , err : err }:
426422 case <- closed :
@@ -438,6 +434,55 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
438434 }
439435}
440436
437+ func readFrame (reader * bufio.Reader , maxMessageBytes int64 ) (json.RawMessage , error ) {
438+ var frame []byte
439+ for {
440+ part , err := reader .ReadSlice ('\n' )
441+ if maxMessageBytes > 0 && int64 (len (frame )+ len (part )) > maxMessageBytes {
442+ return nil , fmt .Errorf ("JSON-RPC message exceeds maximum size of %d bytes" , maxMessageBytes )
443+ }
444+ frame = append (frame , part ... )
445+ switch {
446+ case err == nil :
447+ if n := len (frame ); n > 0 && frame [n - 1 ] == '\n' {
448+ frame = frame [:n - 1 ]
449+ }
450+ if n := len (frame ); n > 0 && frame [n - 1 ] == '\r' {
451+ frame = frame [:n - 1 ]
452+ }
453+ if err := validateJSONFrame (frame ); err != nil {
454+ return nil , err
455+ }
456+ return json .RawMessage (frame ), nil
457+ case errors .Is (err , bufio .ErrBufferFull ):
458+ continue
459+ case errors .Is (err , io .EOF ):
460+ if len (frame ) == 0 {
461+ return nil , io .EOF
462+ }
463+ if err := validateJSONFrame (frame ); err != nil {
464+ return nil , err
465+ }
466+ return json .RawMessage (frame ), nil
467+ default :
468+ return nil , err
469+ }
470+ }
471+ }
472+
473+ func validateJSONFrame (frame []byte ) error {
474+ dec := json .NewDecoder (bytes .NewReader (frame ))
475+ var raw json.RawMessage
476+ if err := dec .Decode (& raw ); err != nil {
477+ return err
478+ }
479+ var extra json.RawMessage
480+ if err := dec .Decode (& extra ); err != io .EOF {
481+ return fmt .Errorf ("invalid trailing data at the end of stream" )
482+ }
483+ return nil
484+ }
485+
441486func (c * ioConn ) SessionID () string { return "" }
442487
443488func (c * ioConn ) sessionUpdated (state ServerSessionState ) {
0 commit comments