@@ -485,23 +485,23 @@ absl::Status UnixSocket::SendFds(const std::vector<FileDescriptor> &fds,
485485#pragma GCC diagnostic push
486486#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
487487#endif
488- struct msghdr msg = {.msg_iov = &iov,
489- .msg_iovlen = 1 ,
490- .msg_control = control_buf.data (),
491- .msg_controllen =
492- static_cast <socklen_t >(CMSG_SPACE (fds_size))};
488+ struct msghdr msg = {.msg_iov = &iov, .msg_iovlen = 1 };
493489#if defined(__clang__)
494490#pragma clang diagnostic pop
495491#elif defined(__GNUC__)
496492#pragma GCC diagnostic pop
497493#endif
498- struct cmsghdr *cmsg = CMSG_FIRSTHDR (&msg);
499- cmsg->cmsg_level = SOL_SOCKET ;
500- cmsg->cmsg_type = SCM_RIGHTS ;
501- cmsg->cmsg_len = CMSG_LEN (fds_size);
502- int *fdptr = reinterpret_cast <int *>(CMSG_DATA (cmsg));
503- for (size_t i = first_fd; i < first_fd + fds_to_send; i++) {
504- *fdptr++ = fds[i].Fd ();
494+ if (fds_to_send > 0 ) {
495+ msg.msg_control = control_buf.data ();
496+ msg.msg_controllen = static_cast <socklen_t >(CMSG_SPACE (fds_size));
497+ struct cmsghdr *cmsg = CMSG_FIRSTHDR (&msg);
498+ cmsg->cmsg_level = SOL_SOCKET ;
499+ cmsg->cmsg_type = SCM_RIGHTS ;
500+ cmsg->cmsg_len = CMSG_LEN (fds_size);
501+ int *fdptr = reinterpret_cast <int *>(CMSG_DATA (cmsg));
502+ for (size_t i = first_fd; i < first_fd + fds_to_send; i++) {
503+ *fdptr++ = fds[i].Fd ();
504+ }
505505 }
506506
507507 if (c != nullptr ) {
@@ -531,14 +531,19 @@ absl::Status UnixSocket::ReceiveFds(std::vector<FileDescriptor> &fds,
531531
532532 int32_t num_fds_received = 0 ;
533533 for (;;) {
534- std::fill (control_buf.begin (), control_buf.end (), 0 );
535-
536534 // The total number of fds we need to see. This is
537535 // sent in each message, but each message contains only portion
538536 // of the total (there's a limit per message).
539- int32_t total_fds;
540- struct iovec iov = {.iov_base = reinterpret_cast <void *>(&total_fds),
541- .iov_len = sizeof (int32_t )};
537+ int32_t total_fds = 0 ;
538+ size_t total_fds_bytes = 0 ;
539+ bool saw_rights = false ;
540+ int num_fds = 0 ;
541+
542+ while (total_fds_bytes < sizeof (total_fds)) {
543+ std::fill (control_buf.begin (), control_buf.end (), 0 );
544+ struct iovec iov = {
545+ .iov_base = reinterpret_cast <char *>(&total_fds) + total_fds_bytes,
546+ .iov_len = sizeof (total_fds) - total_fds_bytes};
542547
543548#if defined(__clang__)
544549#pragma clang diagnostic push
@@ -547,62 +552,67 @@ absl::Status UnixSocket::ReceiveFds(std::vector<FileDescriptor> &fds,
547552#pragma GCC diagnostic push
548553#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
549554#endif
550- struct msghdr msg = {.msg_iov = &iov,
551- .msg_iovlen = 1 ,
552- .msg_control = control_buf.data (),
553- .msg_controllen =
554- static_cast <socklen_t >(control_buf.size ())};
555+ struct msghdr msg = {.msg_iov = &iov,
556+ .msg_iovlen = 1 ,
557+ .msg_control = control_buf.data (),
558+ .msg_controllen =
559+ static_cast <socklen_t >(control_buf.size ())};
555560#if defined(__clang__)
556561#pragma clang diagnostic pop
557562#elif defined(__GNUC__)
558563#pragma GCC diagnostic pop
559564#endif
560- if (c != nullptr ) {
561- int fd = c->Wait (fd_.Fd (), POLLIN );
562- if (fd != fd_.Fd ()) {
563- return absl::InternalError (" Interrupted" );
564- }
565- }
566- ssize_t n = ::recvmsg (fd_.Fd (), &msg, 0 );
567- if (n == -1 ) {
568- return absl::InternalError (absl::StrFormat (
569- " Failed to read fds to unix socket: %s" , strerror (errno)));
570- }
571- if (n == 0 ) {
572- return absl::InternalError (
573- absl::StrFormat (" EOF from socket while reading fds\n " ));
574- }
575-
576- if ((msg.msg_flags & MSG_CTRUNC ) != 0 ) {
577- return absl::InternalError (
578- " Control data was truncated while reading fds from unix socket" );
579- }
580-
581- bool saw_rights = false ;
582- int num_fds = 0 ;
583- for (struct cmsghdr *cmsg = CMSG_FIRSTHDR (&msg); cmsg != nullptr ;
584- cmsg = CMSG_NXTHDR (&msg, cmsg)) {
585- if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS ) {
586- continue ;
565+ if (c != nullptr ) {
566+ int fd = c->Wait (fd_.Fd (), POLLIN );
567+ if (fd != fd_.Fd ()) {
568+ return absl::InternalError (" Interrupted" );
569+ }
587570 }
588- saw_rights = true ;
589- if (cmsg-> cmsg_len < CMSG_LEN ( 0 ) ) {
571+ ssize_t n = :: recvmsg (fd_. Fd (), &msg, 0 ) ;
572+ if (n == - 1 ) {
590573 return absl::InternalError (absl::StrFormat (
591- " Invalid SCM_RIGHTS control length %zu while reading fds" ,
592- static_cast <size_t >(cmsg->cmsg_len )));
574+ " Failed to read fds to unix socket: %s" , strerror (errno)));
593575 }
594- size_t data_len = cmsg->cmsg_len - CMSG_LEN (0 );
595- if (data_len % sizeof (int ) != 0 ) {
596- return absl::InternalError (absl::StrFormat (
597- " Misaligned SCM_RIGHTS control length %zu while reading fds" ,
598- static_cast <size_t >(cmsg->cmsg_len )));
576+ if (n == 0 ) {
577+ return absl::InternalError (
578+ absl::StrFormat (" EOF from socket while reading fds\n " ));
599579 }
600- int *fdptr = reinterpret_cast <int *>(CMSG_DATA (cmsg));
601- int fds_in_message = static_cast <int >(data_len / sizeof (int ));
602- for (int i = 0 ; i < fds_in_message; i++) {
603- fds.emplace_back (fdptr[i]);
580+
581+ total_fds_bytes += static_cast <size_t >(n);
582+
583+ if ((msg.msg_flags & MSG_CTRUNC ) != 0 ) {
584+ return absl::InternalError (
585+ " Control data was truncated while reading fds from unix socket" );
586+ }
587+
588+ for (struct cmsghdr *cmsg = CMSG_FIRSTHDR (&msg); cmsg != nullptr ;
589+ cmsg = CMSG_NXTHDR (&msg, cmsg)) {
590+ if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS ) {
591+ continue ;
592+ }
593+ saw_rights = true ;
594+ if (cmsg->cmsg_len < CMSG_LEN (0 )) {
595+ return absl::InternalError (absl::StrFormat (
596+ " Invalid SCM_RIGHTS control length %zu while reading fds" ,
597+ static_cast <size_t >(cmsg->cmsg_len )));
598+ }
599+ size_t data_len = cmsg->cmsg_len - CMSG_LEN (0 );
600+ if (data_len % sizeof (int ) != 0 ) {
601+ return absl::InternalError (absl::StrFormat (
602+ " Misaligned SCM_RIGHTS control length %zu while reading fds" ,
603+ static_cast <size_t >(cmsg->cmsg_len )));
604+ }
605+ int *fdptr = reinterpret_cast <int *>(CMSG_DATA (cmsg));
606+ int fds_in_message = static_cast <int >(data_len / sizeof (int ));
607+ for (int i = 0 ; i < fds_in_message; i++) {
608+ fds.emplace_back (fdptr[i]);
609+ }
610+ num_fds += fds_in_message;
604611 }
605- num_fds += fds_in_message;
612+ }
613+
614+ if (total_fds == 0 ) {
615+ break ;
606616 }
607617 if (!saw_rights && total_fds > 0 ) {
608618 return absl::InternalError (absl::StrFormat (
0 commit comments