Skip to content

Commit 92c01de

Browse files
authored
Implement FD batching per protocol spec Section 4.1 (#4)
The jsonrpc-fdpass spec calls for a mechanism to handle more FDs than can be sent in a single sendmsg() - add support for that. Assisted-by: OpenCode (Claude Opus 4) Signed-off-by: Colin Walters <walters@verbum.org>
1 parent 6df3f9d commit 92c01de

2 files changed

Lines changed: 334 additions & 40 deletions

File tree

transport.go

Lines changed: 119 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,16 @@ import (
1414
)
1515

1616
const (
17-
// MaxFDsPerMessage is the maximum number of file descriptors per message.
18-
MaxFDsPerMessage = 8
17+
// DefaultMaxFDsPerSendmsg is the default maximum number of file descriptors
18+
// per sendmsg() call. Platform limits for SCM_RIGHTS vary (e.g., ~253 on
19+
// Linux, ~512 on macOS). We start with an optimistic value; if sendmsg()
20+
// fails with EINVAL, the batch size is automatically reduced and retried.
21+
DefaultMaxFDsPerSendmsg = 500
22+
23+
// MaxFDsPerRecvmsg is the maximum number of FDs to expect in a single
24+
// recvmsg() call. Must be at least as large as the largest platform limit.
25+
MaxFDsPerRecvmsg = 512
26+
1927
// ReadBufferSize is the size of the read buffer.
2028
ReadBufferSize = 4096
2129
)
@@ -31,13 +39,27 @@ var (
3139

3240
// Sender sends JSON-RPC messages with file descriptors over a Unix socket.
3341
type Sender struct {
34-
conn *net.UnixConn
35-
mu sync.Mutex
42+
conn *net.UnixConn
43+
mu sync.Mutex
44+
maxFDsPerSendmsg int
3645
}
3746

3847
// NewSender creates a new Sender for the given Unix connection.
3948
func NewSender(conn *net.UnixConn) *Sender {
40-
return &Sender{conn: conn}
49+
return &Sender{
50+
conn: conn,
51+
maxFDsPerSendmsg: DefaultMaxFDsPerSendmsg,
52+
}
53+
}
54+
55+
// SetMaxFDsPerSendmsg sets the maximum number of file descriptors to send per
56+
// sendmsg() call. This is primarily useful for testing FD batching behavior.
57+
// The value must be at least 1.
58+
func (s *Sender) SetMaxFDsPerSendmsg(max int) {
59+
if max < 1 {
60+
max = 1
61+
}
62+
s.maxFDsPerSendmsg = max
4163
}
4264

4365
// Send sends a message with optional file descriptors.
@@ -79,37 +101,77 @@ func (s *Sender) Send(msg *MessageWithFds) error {
79101
}
80102

81103
func (s *Sender) sendWithFDs(sockfd int, data []byte, files []*os.File) error {
82-
bytesSent := 0
83-
fdsSent := false
104+
// Extract raw FD ints
105+
allFDs := make([]int, len(files))
106+
for i, f := range files {
107+
allFDs[i] = int(f.Fd())
108+
}
84109

85-
for bytesSent < len(data) {
86-
remaining := data[bytesSent:]
110+
bytesSent := 0
111+
fdsSent := 0
112+
currentMaxFDs := s.maxFDsPerSendmsg
113+
114+
// Send data and FDs in batches. Each sendmsg can only handle a limited
115+
// number of FDs. After all data bytes are sent, remaining FDs are sent
116+
// with whitespace padding bytes per the protocol spec (Section 4.1).
117+
for bytesSent < len(data) || fdsSent < len(allFDs) {
118+
remainingData := data[bytesSent:]
119+
remainingFDs := allFDs[fdsSent:]
120+
121+
// Determine how many FDs to send in this batch
122+
fdBatchSize := len(remainingFDs)
123+
if fdBatchSize > currentMaxFDs {
124+
fdBatchSize = currentMaxFDs
125+
}
126+
fdBatch := remainingFDs[:fdBatchSize]
87127

88128
var n int
89129
var err error
90130

91-
if !fdsSent && len(files) > 0 {
92-
// First chunk with FDs: use sendmsg with ancillary data
93-
fds := make([]int, len(files))
94-
for i, f := range files {
95-
fds[i] = int(f.Fd())
131+
if len(fdBatch) > 0 {
132+
// Send with FDs using sendmsg with ancillary data
133+
rights := unix.UnixRights(fdBatch...)
134+
135+
var payload []byte
136+
if len(remainingData) > 0 {
137+
payload = remainingData
138+
} else {
139+
// All data bytes already sent; send a whitespace padding byte.
140+
// The receiver's JSON parser ignores inter-message whitespace
141+
// per RFC 8259. This is required because some systems need
142+
// non-empty data for ancillary data delivery.
143+
payload = []byte{' '}
96144
}
97145

98-
rights := unix.UnixRights(fds...)
99-
n, err = unix.SendmsgN(sockfd, remaining, rights, nil, 0)
146+
n, err = unix.SendmsgN(sockfd, payload, rights, nil, 0)
100147
if err != nil {
148+
// EINVAL with multiple FDs likely means we exceeded the
149+
// kernel's SCM_MAX_FD limit. Halve the batch size and retry.
150+
if errors.Is(err, unix.EINVAL) && fdBatchSize > 1 {
151+
currentMaxFDs = fdBatchSize / 2
152+
continue
153+
}
101154
return fmt.Errorf("sendmsg failed: %w", err)
102155
}
103-
fdsSent = true
104-
} else {
105-
// No FDs or FDs already sent: use regular send
106-
n, err = unix.Write(sockfd, remaining)
156+
fdsSent += fdBatchSize
157+
158+
// Only count actual data bytes, not the padding byte
159+
if len(remainingData) > 0 {
160+
bytesSent += n
161+
}
162+
} else if len(remainingData) > 0 {
163+
// No FDs left, just send remaining data bytes
164+
n, err = unix.Write(sockfd, remainingData)
107165
if err != nil {
108166
return fmt.Errorf("write failed: %w", err)
109167
}
168+
bytesSent += n
110169
}
170+
}
111171

112-
bytesSent += n
172+
// If we discovered a lower limit, remember it for future sends
173+
if currentMaxFDs < s.maxFDsPerSendmsg {
174+
s.maxFDsPerSendmsg = currentMaxFDs
113175
}
114176

115177
return nil
@@ -139,24 +201,37 @@ func (r *Receiver) Receive() (*MessageWithFds, error) {
139201

140202
for {
141203
// Try to parse a complete message from the buffer
142-
msg, err := r.tryParseMessage()
204+
result, err := r.tryParseMessage()
143205
if err != nil {
144206
return nil, err
145207
}
146-
if msg != nil {
147-
return msg, nil
208+
if result.msg != nil {
209+
return result.msg, nil
148210
}
149211

150-
// Need more data
212+
// Need more data — either incomplete JSON or waiting for
213+
// batched FDs from continuation sendmsg() calls.
151214
if err := r.readMoreData(); err != nil {
215+
// If we had a parsed message waiting for FDs and the
216+
// connection closed, that's a mismatched count error.
217+
if result.needFDs && errors.Is(err, ErrConnectionClosed) {
218+
return nil, fmt.Errorf("%w: connection closed while waiting for batched FDs", ErrMismatchedCount)
219+
}
152220
return nil, err
153221
}
154222
}
155223
}
156224

157-
func (r *Receiver) tryParseMessage() (*MessageWithFds, error) {
225+
// tryParseResult is used internally to communicate between tryParseMessage
226+
// and Receive about whether more FDs are needed from batched sendmsg calls.
227+
type tryParseResult struct {
228+
msg *MessageWithFds
229+
needFDs bool // true when message is parsed but FD queue is short
230+
}
231+
232+
func (r *Receiver) tryParseMessage() (*tryParseResult, error) {
158233
if len(r.buffer) == 0 {
159-
return nil, nil
234+
return &tryParseResult{}, nil
160235
}
161236

162237
// Use streaming JSON decoder to find message boundaries
@@ -166,7 +241,7 @@ func (r *Receiver) tryParseMessage() (*MessageWithFds, error) {
166241
err := decoder.Decode(&value)
167242
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
168243
// Incomplete JSON - need more data
169-
return nil, nil
244+
return &tryParseResult{}, nil
170245
}
171246
if err != nil {
172247
// Actual parse error - framing error
@@ -178,20 +253,22 @@ func (r *Receiver) tryParseMessage() (*MessageWithFds, error) {
178253
bytesConsumed := decoder.InputOffset()
179254

180255
// Extract the consumed bytes for re-parsing
181-
consumedData := r.buffer[:bytesConsumed]
182-
183-
// Remove consumed bytes from buffer
184-
r.buffer = r.buffer[bytesConsumed:]
256+
consumedData := make([]byte, bytesConsumed)
257+
copy(consumedData, r.buffer[:bytesConsumed])
185258

186259
// Read the fds count from the message
187260
fdCount := GetFDCount(value)
188261

189-
// Check we have enough FDs
262+
// Check we have enough FDs. When FD batching is in use, the sender
263+
// sends continuation sendmsg() calls with whitespace padding and more
264+
// FDs. We need to read more data to collect them.
190265
if fdCount > len(r.fdQueue) {
191-
return nil, fmt.Errorf("%w: expected %d FDs, have %d in queue",
192-
ErrMismatchedCount, fdCount, len(r.fdQueue))
266+
return &tryParseResult{needFDs: true}, nil
193267
}
194268

269+
// Remove consumed bytes from buffer
270+
r.buffer = r.buffer[bytesConsumed:]
271+
195272
// Dequeue FDs
196273
fds := make([]*os.File, fdCount)
197274
copy(fds, r.fdQueue[:fdCount])
@@ -203,9 +280,11 @@ func (r *Receiver) tryParseMessage() (*MessageWithFds, error) {
203280
return nil, err
204281
}
205282

206-
return &MessageWithFds{
207-
Message: msg,
208-
FileDescriptors: fds,
283+
return &tryParseResult{
284+
msg: &MessageWithFds{
285+
Message: msg,
286+
FileDescriptors: fds,
287+
},
209288
}, nil
210289
}
211290

@@ -251,9 +330,9 @@ func (r *Receiver) readMoreData() error {
251330

252331
func (r *Receiver) recvWithFDs(sockfd int) (int, []*os.File, error) {
253332
buf := make([]byte, ReadBufferSize)
254-
// Allocate space for control message (for up to MaxFDsPerMessage FDs)
333+
// Allocate space for control message (for up to MaxFDsPerRecvmsg FDs)
255334
// Each FD is 4 bytes (int32), use CmsgSpace to get properly aligned size
256-
oob := make([]byte, unix.CmsgSpace(MaxFDsPerMessage*4))
335+
oob := make([]byte, unix.CmsgSpace(MaxFDsPerRecvmsg*4))
257336

258337
n, oobn, _, _, err := unix.Recvmsg(sockfd, buf, oob, unix.MSG_CMSG_CLOEXEC)
259338
if err != nil {

0 commit comments

Comments
 (0)