From a18c7c4405ece63bdaf07d237d6dcc9dba57bb17 Mon Sep 17 00:00:00 2001 From: Andrew Kaster Date: Wed, 17 Apr 2024 16:40:57 -0600 Subject: [PATCH] LibCore: Let LocalSocket send and receive messages with SCM_RIGHTS These new methods combine send/receive with send_fd/receive_fd. This is the 'correct' way to use SCM_RIGHTS, rather than trying to emulate the Serenity behavior on other Unixes. --- Userland/Libraries/LibCore/Socket.cpp | 69 +++++++++++++++++++++++++++ Userland/Libraries/LibCore/Socket.h | 4 ++ 2 files changed, 73 insertions(+) diff --git a/Userland/Libraries/LibCore/Socket.cpp b/Userland/Libraries/LibCore/Socket.cpp index c193ab1689c..bc2c6ed63c8 100644 --- a/Userland/Libraries/LibCore/Socket.cpp +++ b/Userland/Libraries/LibCore/Socket.cpp @@ -10,6 +10,8 @@ namespace Core { +static constexpr size_t MAX_LOCAL_SOCKET_TRANSFER_FDS = 64; + ErrorOr Socket::create_fd(SocketDomain domain, SocketType type) { int socket_domain; @@ -362,6 +364,73 @@ ErrorOr LocalSocket::send_fd(int fd) #endif } +ErrorOr LocalSocket::send_message(ReadonlyBytes data, int flags, Vector fds) +{ + size_t const num_fds = fds.size(); + if (num_fds == 0) + return m_helper.write(data, flags | default_flags()); + if (num_fds > MAX_LOCAL_SOCKET_TRANSFER_FDS) + return Error::from_string_literal("Too many file descriptors to send"); + + auto const fd_payload_size = num_fds * sizeof(int); + + alignas(struct cmsghdr) char control_buf[CMSG_SPACE(sizeof(int) * MAX_LOCAL_SOCKET_TRANSFER_FDS)] {}; + auto* header = new (control_buf) cmsghdr { + .cmsg_len = static_cast(CMSG_LEN(fd_payload_size)), + .cmsg_level = SOL_SOCKET, + .cmsg_type = SCM_RIGHTS, + }; + memcpy(CMSG_DATA(header), fds.data(), fd_payload_size); + struct iovec iov { + .iov_base = const_cast(data.data()), + .iov_len = data.size(), + }; + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = header; + msg.msg_controllen = CMSG_LEN(fd_payload_size); + + return TRY(Core::System::sendmsg(m_helper.fd(), &msg, default_flags() | flags)); +} + +ErrorOr LocalSocket::receive_message(AK::Bytes buffer, int flags, Vector& fds) +{ + struct iovec iov { + .iov_base = buffer.data(), + .iov_len = buffer.size(), + }; + + alignas(struct cmsghdr) char control_buf[CMSG_SPACE(sizeof(int) * MAX_LOCAL_SOCKET_TRANSFER_FDS)] {}; + + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = control_buf; + msg.msg_controllen = sizeof(control_buf); + + auto nread = TRY(Core::System::recvmsg(m_helper.fd(), &msg, default_flags() | flags)); + if (nread == 0) { + m_helper.did_reach_eof_on_read(); + return buffer.trim(nread); + } + + fds.clear(); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + while (cmsg != nullptr) { + if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + size_t num_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); + auto* fd_data = reinterpret_cast(CMSG_DATA(cmsg)); + for (size_t i = 0; i < num_fds; ++i) { + fds.append(fd_data[i]); + } + } + AK_IGNORE_DIAGNOSTIC("-Wsign-compare", cmsg = CMSG_NXTHDR(&msg, cmsg)); + } + return buffer.trim(nread); +} + ErrorOr LocalSocket::peer_pid() const { #if defined(AK_OS_MACOS) || defined(AK_OS_IOS) diff --git a/Userland/Libraries/LibCore/Socket.h b/Userland/Libraries/LibCore/Socket.h index 3a6892a9280..133610b8413 100644 --- a/Userland/Libraries/LibCore/Socket.h +++ b/Userland/Libraries/LibCore/Socket.h @@ -329,6 +329,10 @@ public: ErrorOr receive_fd(int flags); ErrorOr send_fd(int fd); + + ErrorOr receive_message(Bytes buffer, int flags, Vector& fds); + ErrorOr send_message(ReadonlyBytes msg, int flags, Vector fds = {}); + ErrorOr peer_pid() const; ErrorOr read_without_waiting(Bytes buffer);