Skip to content

Commit 04faab2

Browse files
committed
Don't rely on sendmsg boundaries for SCM_RIGHTS
1 parent cd72868 commit 04faab2

2 files changed

Lines changed: 172 additions & 63 deletions

File tree

toolbelt/sockets.cc

Lines changed: 73 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -485,23 +485,23 @@ absl::Status UnixSocket::SendFds(const std::vector<FileDescriptor> &fds,
485485
#pragma GCC diagnostic push
486486
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
487487
#endif
488-
struct msghdr msg = {.msg_iov = &iov,
489-
.msg_iovlen = 1,
490-
.msg_control = control_buf.data(),
491-
.msg_controllen =
492-
static_cast<socklen_t>(CMSG_SPACE(fds_size))};
488+
struct msghdr msg = {.msg_iov = &iov, .msg_iovlen = 1};
493489
#if defined(__clang__)
494490
#pragma clang diagnostic pop
495491
#elif defined(__GNUC__)
496492
#pragma GCC diagnostic pop
497493
#endif
498-
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
499-
cmsg->cmsg_level = SOL_SOCKET;
500-
cmsg->cmsg_type = SCM_RIGHTS;
501-
cmsg->cmsg_len = CMSG_LEN(fds_size);
502-
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
503-
for (size_t i = first_fd; i < first_fd + fds_to_send; i++) {
504-
*fdptr++ = fds[i].Fd();
494+
if (fds_to_send > 0) {
495+
msg.msg_control = control_buf.data();
496+
msg.msg_controllen = static_cast<socklen_t>(CMSG_SPACE(fds_size));
497+
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
498+
cmsg->cmsg_level = SOL_SOCKET;
499+
cmsg->cmsg_type = SCM_RIGHTS;
500+
cmsg->cmsg_len = CMSG_LEN(fds_size);
501+
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
502+
for (size_t i = first_fd; i < first_fd + fds_to_send; i++) {
503+
*fdptr++ = fds[i].Fd();
504+
}
505505
}
506506

507507
if (c != nullptr) {
@@ -531,14 +531,19 @@ absl::Status UnixSocket::ReceiveFds(std::vector<FileDescriptor> &fds,
531531

532532
int32_t num_fds_received = 0;
533533
for (;;) {
534-
std::fill(control_buf.begin(), control_buf.end(), 0);
535-
536534
// The total number of fds we need to see. This is
537535
// sent in each message, but each message contains only portion
538536
// of the total (there's a limit per message).
539-
int32_t total_fds;
540-
struct iovec iov = {.iov_base = reinterpret_cast<void *>(&total_fds),
541-
.iov_len = sizeof(int32_t)};
537+
int32_t total_fds = 0;
538+
size_t total_fds_bytes = 0;
539+
bool saw_rights = false;
540+
int num_fds = 0;
541+
542+
while (total_fds_bytes < sizeof(total_fds)) {
543+
std::fill(control_buf.begin(), control_buf.end(), 0);
544+
struct iovec iov = {
545+
.iov_base = reinterpret_cast<char *>(&total_fds) + total_fds_bytes,
546+
.iov_len = sizeof(total_fds) - total_fds_bytes};
542547

543548
#if defined(__clang__)
544549
#pragma clang diagnostic push
@@ -547,62 +552,67 @@ absl::Status UnixSocket::ReceiveFds(std::vector<FileDescriptor> &fds,
547552
#pragma GCC diagnostic push
548553
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
549554
#endif
550-
struct msghdr msg = {.msg_iov = &iov,
551-
.msg_iovlen = 1,
552-
.msg_control = control_buf.data(),
553-
.msg_controllen =
554-
static_cast<socklen_t>(control_buf.size())};
555+
struct msghdr msg = {.msg_iov = &iov,
556+
.msg_iovlen = 1,
557+
.msg_control = control_buf.data(),
558+
.msg_controllen =
559+
static_cast<socklen_t>(control_buf.size())};
555560
#if defined(__clang__)
556561
#pragma clang diagnostic pop
557562
#elif defined(__GNUC__)
558563
#pragma GCC diagnostic pop
559564
#endif
560-
if (c != nullptr) {
561-
int fd = c->Wait(fd_.Fd(), POLLIN);
562-
if (fd != fd_.Fd()) {
563-
return absl::InternalError("Interrupted");
564-
}
565-
}
566-
ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0);
567-
if (n == -1) {
568-
return absl::InternalError(absl::StrFormat(
569-
"Failed to read fds to unix socket: %s", strerror(errno)));
570-
}
571-
if (n == 0) {
572-
return absl::InternalError(
573-
absl::StrFormat("EOF from socket while reading fds\n"));
574-
}
575-
576-
if ((msg.msg_flags & MSG_CTRUNC) != 0) {
577-
return absl::InternalError(
578-
"Control data was truncated while reading fds from unix socket");
579-
}
580-
581-
bool saw_rights = false;
582-
int num_fds = 0;
583-
for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
584-
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
585-
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
586-
continue;
565+
if (c != nullptr) {
566+
int fd = c->Wait(fd_.Fd(), POLLIN);
567+
if (fd != fd_.Fd()) {
568+
return absl::InternalError("Interrupted");
569+
}
587570
}
588-
saw_rights = true;
589-
if (cmsg->cmsg_len < CMSG_LEN(0)) {
571+
ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0);
572+
if (n == -1) {
590573
return absl::InternalError(absl::StrFormat(
591-
"Invalid SCM_RIGHTS control length %zu while reading fds",
592-
static_cast<size_t>(cmsg->cmsg_len)));
574+
"Failed to read fds to unix socket: %s", strerror(errno)));
593575
}
594-
size_t data_len = cmsg->cmsg_len - CMSG_LEN(0);
595-
if (data_len % sizeof(int) != 0) {
596-
return absl::InternalError(absl::StrFormat(
597-
"Misaligned SCM_RIGHTS control length %zu while reading fds",
598-
static_cast<size_t>(cmsg->cmsg_len)));
576+
if (n == 0) {
577+
return absl::InternalError(
578+
absl::StrFormat("EOF from socket while reading fds\n"));
599579
}
600-
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
601-
int fds_in_message = static_cast<int>(data_len / sizeof(int));
602-
for (int i = 0; i < fds_in_message; i++) {
603-
fds.emplace_back(fdptr[i]);
580+
581+
total_fds_bytes += static_cast<size_t>(n);
582+
583+
if ((msg.msg_flags & MSG_CTRUNC) != 0) {
584+
return absl::InternalError(
585+
"Control data was truncated while reading fds from unix socket");
586+
}
587+
588+
for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
589+
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
590+
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
591+
continue;
592+
}
593+
saw_rights = true;
594+
if (cmsg->cmsg_len < CMSG_LEN(0)) {
595+
return absl::InternalError(absl::StrFormat(
596+
"Invalid SCM_RIGHTS control length %zu while reading fds",
597+
static_cast<size_t>(cmsg->cmsg_len)));
598+
}
599+
size_t data_len = cmsg->cmsg_len - CMSG_LEN(0);
600+
if (data_len % sizeof(int) != 0) {
601+
return absl::InternalError(absl::StrFormat(
602+
"Misaligned SCM_RIGHTS control length %zu while reading fds",
603+
static_cast<size_t>(cmsg->cmsg_len)));
604+
}
605+
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
606+
int fds_in_message = static_cast<int>(data_len / sizeof(int));
607+
for (int i = 0; i < fds_in_message; i++) {
608+
fds.emplace_back(fdptr[i]);
609+
}
610+
num_fds += fds_in_message;
604611
}
605-
num_fds += fds_in_message;
612+
}
613+
614+
if (total_fds == 0) {
615+
break;
606616
}
607617
if (!saw_rights && total_fds > 0) {
608618
return absl::InternalError(absl::StrFormat(

toolbelt/sockets_test.cc

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,105 @@ TEST(SocketsTest, UnixSocket) {
137137
remove(socket_name.c_str());
138138
}
139139

140+
TEST(SocketsTest, UnixSocketZeroFds) {
141+
char tmp[] = "/tmp/socketsXXXXXX";
142+
int fd = mkstemp(tmp);
143+
ASSERT_NE(-1, fd);
144+
std::string socket_name = tmp;
145+
close(fd);
146+
147+
unlink(socket_name.c_str());
148+
co::CoroutineScheduler scheduler;
149+
150+
toolbelt::UnixSocket listener;
151+
absl::Status status = listener.Bind(socket_name, true);
152+
ASSERT_TRUE(status.ok());
153+
154+
co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) {
155+
absl::StatusOr<toolbelt::UnixSocket> s = listener.Accept(c);
156+
ASSERT_TRUE(s.ok());
157+
auto socket = s.value();
158+
159+
std::vector<toolbelt::FileDescriptor> fds;
160+
absl::Status s2 = socket.ReceiveFds(fds, c);
161+
ASSERT_TRUE(s2.ok());
162+
ASSERT_TRUE(fds.empty());
163+
});
164+
165+
co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) {
166+
toolbelt::UnixSocket socket;
167+
absl::Status s = socket.Connect(socket_name);
168+
ASSERT_TRUE(s.ok());
169+
170+
std::vector<toolbelt::FileDescriptor> fds;
171+
absl::Status s2 = socket.SendFds(fds, c);
172+
ASSERT_TRUE(s2.ok());
173+
});
174+
175+
scheduler.Run();
176+
remove(socket_name.c_str());
177+
}
178+
179+
TEST(SocketsTest, UnixSocketShortFdCountRead) {
180+
char tmp[] = "/tmp/socketsXXXXXX";
181+
int fd = mkstemp(tmp);
182+
ASSERT_NE(-1, fd);
183+
std::string socket_name = tmp;
184+
close(fd);
185+
186+
unlink(socket_name.c_str());
187+
co::CoroutineScheduler scheduler;
188+
189+
toolbelt::UnixSocket listener;
190+
absl::Status status = listener.Bind(socket_name, true);
191+
ASSERT_TRUE(status.ok());
192+
193+
co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) {
194+
absl::StatusOr<toolbelt::UnixSocket> s = listener.Accept(c);
195+
ASSERT_TRUE(s.ok());
196+
auto socket = s.value();
197+
198+
std::vector<toolbelt::FileDescriptor> fds;
199+
absl::Status s2 = socket.ReceiveFds(fds, c);
200+
ASSERT_TRUE(s2.ok());
201+
ASSERT_EQ(1, fds.size());
202+
});
203+
204+
co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) {
205+
toolbelt::UnixSocket socket;
206+
absl::Status s = socket.Connect(socket_name);
207+
ASSERT_TRUE(s.ok());
208+
209+
int32_t num_fds = 1;
210+
char* num_fds_bytes = reinterpret_cast<char*>(&num_fds);
211+
int fd_to_send = dup(0);
212+
ASSERT_NE(-1, fd_to_send);
213+
214+
char control_buf[CMSG_SPACE(sizeof(int))] = {};
215+
struct iovec iov = {.iov_base = num_fds_bytes, .iov_len = 1};
216+
struct msghdr msg = {.msg_iov = &iov,
217+
.msg_iovlen = 1,
218+
.msg_control = control_buf,
219+
.msg_controllen = sizeof(control_buf)};
220+
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
221+
cmsg->cmsg_level = SOL_SOCKET;
222+
cmsg->cmsg_type = SCM_RIGHTS;
223+
cmsg->cmsg_len = CMSG_LEN(sizeof(int));
224+
*reinterpret_cast<int*>(CMSG_DATA(cmsg)) = fd_to_send;
225+
226+
ssize_t n = sendmsg(socket.GetFileDescriptor().Fd(), &msg, 0);
227+
ASSERT_EQ(1, n);
228+
close(fd_to_send);
229+
230+
n = send(socket.GetFileDescriptor().Fd(), num_fds_bytes + 1,
231+
sizeof(num_fds) - 1, 0);
232+
ASSERT_EQ(static_cast<ssize_t>(sizeof(num_fds) - 1), n);
233+
});
234+
235+
scheduler.Run();
236+
remove(socket_name.c_str());
237+
}
238+
140239
TEST(SocketsTest, UnixSocketErrors) {
141240
toolbelt::UnixSocket socket;
142241
// Socket is inValid, all will fail.

0 commit comments

Comments
 (0)