Kernel: Support sending filedescriptors with sendmsg(2) and SCM_RIGHTS
This is necessary to support the wayland protocol. I also moved the CMSG_* macros to the kernel API since they are used in both kernel and userspace. this does not break ntpquery/SCM_TIMESTAMP.
This commit is contained in:
parent
ae5d7f542c
commit
f20902deb3
Notes:
sideshowbarker
2024-07-17 00:04:48 +09:00
Author: https://github.com/petelliott Commit: https://github.com/SerenityOS/serenity/commit/f20902deb3 Pull-request: https://github.com/SerenityOS/serenity/pull/17471 Reviewed-by: https://github.com/ADKaster Reviewed-by: https://github.com/linusg
5 changed files with 97 additions and 37 deletions
|
@ -81,6 +81,32 @@ struct msghdr {
|
|||
int msg_flags;
|
||||
};
|
||||
|
||||
// These three are non-POSIX, but common:
|
||||
#define CMSG_ALIGN(x) (((x) + sizeof(void*) - 1) & ~(sizeof(void*) - 1))
|
||||
#define CMSG_SPACE(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + CMSG_ALIGN(x))
|
||||
#define CMSG_LEN(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + (x))
|
||||
|
||||
static inline struct cmsghdr* CMSG_FIRSTHDR(struct msghdr* msg)
|
||||
{
|
||||
if (msg->msg_controllen < sizeof(struct cmsghdr))
|
||||
return (struct cmsghdr*)0;
|
||||
return (struct cmsghdr*)msg->msg_control;
|
||||
}
|
||||
|
||||
static inline struct cmsghdr* CMSG_NXTHDR(struct msghdr* msg, struct cmsghdr* cmsg)
|
||||
{
|
||||
struct cmsghdr* next = (struct cmsghdr*)((char*)cmsg + CMSG_ALIGN(cmsg->cmsg_len));
|
||||
unsigned offset = (char*)next - (char*)msg->msg_control;
|
||||
if (msg->msg_controllen < offset + sizeof(struct cmsghdr))
|
||||
return (struct cmsghdr*)0;
|
||||
return next;
|
||||
}
|
||||
|
||||
static inline void* CMSG_DATA(struct cmsghdr* cmsg)
|
||||
{
|
||||
return (void*)(cmsg + 1);
|
||||
}
|
||||
|
||||
struct sockaddr {
|
||||
sa_family_t sa_family;
|
||||
char sa_data[14];
|
||||
|
|
|
@ -520,6 +520,26 @@ ErrorOr<NonnullLockRefPtr<OpenFileDescription>> LocalSocket::recvfd(OpenFileDesc
|
|||
return queue.take_first();
|
||||
}
|
||||
|
||||
ErrorOr<NonnullLockRefPtrVector<OpenFileDescription>> LocalSocket::recvfds(OpenFileDescription const& socket_description, int n)
|
||||
{
|
||||
MutexLocker locker(mutex());
|
||||
NonnullLockRefPtrVector<OpenFileDescription> fds;
|
||||
|
||||
auto role = this->role(socket_description);
|
||||
if (role != Role::Connected && role != Role::Accepted)
|
||||
return set_so_error(EINVAL);
|
||||
auto& queue = recvfd_queue_for(socket_description);
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (queue.is_empty())
|
||||
break;
|
||||
|
||||
fds.append(queue.take_first());
|
||||
}
|
||||
|
||||
return fds;
|
||||
}
|
||||
|
||||
ErrorOr<void> LocalSocket::try_set_path(StringView path)
|
||||
{
|
||||
m_path = TRY(KString::try_create(path));
|
||||
|
|
|
@ -28,6 +28,7 @@ public:
|
|||
|
||||
ErrorOr<void> sendfd(OpenFileDescription const& socket_description, NonnullLockRefPtr<OpenFileDescription> passing_description);
|
||||
ErrorOr<NonnullLockRefPtr<OpenFileDescription>> recvfd(OpenFileDescription const& socket_description);
|
||||
ErrorOr<NonnullLockRefPtrVector<OpenFileDescription>> recvfds(OpenFileDescription const& socket_description, int n);
|
||||
|
||||
static void for_each(Function<void(LocalSocket const&)>);
|
||||
static ErrorOr<void> try_for_each(Function<ErrorOr<void>(LocalSocket const&)>);
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
* SPDX-License-Identifier: BSD-2-Clause
|
||||
*/
|
||||
|
||||
#include <AK/ByteBuffer.h>
|
||||
#include <Kernel/FileSystem/OpenFileDescription.h>
|
||||
#include <Kernel/Net/LocalSocket.h>
|
||||
#include <Kernel/Process.h>
|
||||
|
@ -199,6 +200,24 @@ ErrorOr<FlatPtr> Process::sys$sendmsg(int sockfd, Userspace<const struct msghdr*
|
|||
Thread::current()->send_signal(SIGPIPE, &Process::current());
|
||||
return EPIPE;
|
||||
}
|
||||
|
||||
if (msg.msg_controllen > 0) {
|
||||
// Handle command messages.
|
||||
auto cmsg_buffer = TRY(ByteBuffer::create_uninitialized(msg.msg_controllen));
|
||||
TRY(copy_from_user(cmsg_buffer.data(), msg.msg_control, msg.msg_controllen));
|
||||
msg.msg_control = cmsg_buffer.data();
|
||||
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
|
||||
if (socket.is_local() && cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
|
||||
auto& local_socket = static_cast<LocalSocket&>(socket);
|
||||
int* fds = (int*)CMSG_DATA(cmsg);
|
||||
size_t nfds = (cmsg->cmsg_len - CMSG_ALIGN(sizeof(struct cmsghdr))) / sizeof(int);
|
||||
for (size_t i = 0; i < nfds; ++i) {
|
||||
TRY(local_socket.sendfd(*description, TRY(open_file_description(fds[i]))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto data_buffer = TRY(UserOrKernelBuffer::for_user_buffer((u8*)iovs[0].iov_base, iovs[0].iov_len));
|
||||
|
||||
while (true) {
|
||||
|
@ -267,21 +286,41 @@ ErrorOr<FlatPtr> Process::sys$recvmsg(int sockfd, Userspace<struct msghdr*> user
|
|||
msg_flags |= MSG_TRUNC;
|
||||
}
|
||||
|
||||
if (socket.wants_timestamp()) {
|
||||
struct {
|
||||
cmsghdr cmsg;
|
||||
timeval timestamp;
|
||||
} cmsg_timestamp;
|
||||
socklen_t control_length = sizeof(cmsg_timestamp);
|
||||
if (msg.msg_controllen < control_length) {
|
||||
socklen_t current_cmsg_len = 0;
|
||||
auto try_add_cmsg = [&](int level, int type, void const* data, socklen_t len) -> ErrorOr<bool> {
|
||||
if (current_cmsg_len + len > msg.msg_controllen) {
|
||||
msg_flags |= MSG_CTRUNC;
|
||||
} else {
|
||||
cmsg_timestamp = { { control_length, SOL_SOCKET, SCM_TIMESTAMP }, timestamp.to_timeval() };
|
||||
TRY(copy_to_user(msg.msg_control, &cmsg_timestamp, control_length));
|
||||
return false;
|
||||
}
|
||||
TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_controllen, &control_length));
|
||||
|
||||
cmsghdr cmsg = { (socklen_t)CMSG_LEN(len), level, type };
|
||||
cmsghdr* target = (cmsghdr*)(((char*)msg.msg_control) + current_cmsg_len);
|
||||
TRY(copy_to_user(target, &cmsg));
|
||||
TRY(copy_to_user(CMSG_DATA(target), data, len));
|
||||
current_cmsg_len += CMSG_ALIGN(cmsg.cmsg_len);
|
||||
return true;
|
||||
};
|
||||
|
||||
if (socket.wants_timestamp()) {
|
||||
timeval time = timestamp.to_timeval();
|
||||
TRY(try_add_cmsg(SOL_SOCKET, SCM_TIMESTAMP, &time, sizeof(time)));
|
||||
}
|
||||
|
||||
int space_for_fds = (msg.msg_controllen - current_cmsg_len - sizeof(struct cmsghdr)) / sizeof(int);
|
||||
if (space_for_fds > 0 && socket.is_local()) {
|
||||
auto& local_socket = static_cast<LocalSocket&>(socket);
|
||||
auto descriptions = TRY(local_socket.recvfds(description, space_for_fds));
|
||||
Vector<int> fdnums;
|
||||
for (auto& description : descriptions) {
|
||||
auto fd_allocation = TRY(m_fds.with_exclusive([](auto& fds) { return fds.allocate(); }));
|
||||
m_fds.with_exclusive([&](auto& fds) { fds[fd_allocation.fd].set(description, 0); });
|
||||
fdnums.append(fd_allocation.fd);
|
||||
}
|
||||
TRY(try_add_cmsg(SOL_SOCKET, SCM_RIGHTS, fdnums.data(), fdnums.size() * sizeof(int)));
|
||||
}
|
||||
|
||||
TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_controllen, ¤t_cmsg_len));
|
||||
|
||||
TRY(copy_to_user(&user_msg.unsafe_userspace_ptr()->msg_flags, &msg_flags));
|
||||
return result.value();
|
||||
}
|
||||
|
|
|
@ -33,30 +33,4 @@ int socketpair(int domain, int type, int protocol, int sv[2]);
|
|||
int sendfd(int sockfd, int fd);
|
||||
int recvfd(int sockfd, int options);
|
||||
|
||||
// These three are non-POSIX, but common:
|
||||
#define CMSG_ALIGN(x) (((x) + sizeof(void*) - 1) & ~(sizeof(void*) - 1))
|
||||
#define CMSG_SPACE(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + CMSG_ALIGN(x))
|
||||
#define CMSG_LEN(x) (CMSG_ALIGN(sizeof(struct cmsghdr)) + (x))
|
||||
|
||||
static inline struct cmsghdr* CMSG_FIRSTHDR(struct msghdr* msg)
|
||||
{
|
||||
if (msg->msg_controllen < sizeof(struct cmsghdr))
|
||||
return 0;
|
||||
return (struct cmsghdr*)msg->msg_control;
|
||||
}
|
||||
|
||||
static inline struct cmsghdr* CMSG_NXTHDR(struct msghdr* msg, struct cmsghdr* cmsg)
|
||||
{
|
||||
struct cmsghdr* next = (struct cmsghdr*)((char*)cmsg + CMSG_ALIGN(cmsg->cmsg_len));
|
||||
unsigned offset = (char*)next - (char*)msg->msg_control;
|
||||
if (msg->msg_controllen < offset + sizeof(struct cmsghdr))
|
||||
return NULL;
|
||||
return next;
|
||||
}
|
||||
|
||||
static inline void* CMSG_DATA(struct cmsghdr* cmsg)
|
||||
{
|
||||
return (void*)(cmsg + 1);
|
||||
}
|
||||
|
||||
__END_DECLS
|
||||
|
|
Loading…
Add table
Reference in a new issue