Browse Source

Kernel: Store LocalSocket address as a KString internally

Just because we deal with sockaddr_un at the userspace API layer doesn't
mean we have to store an awkward C type internally. :^)
Andreas Kling 3 years ago
parent
commit
70b2225b3d
2 changed files with 47 additions and 21 deletions
  1. 44 20
      Kernel/Net/LocalSocket.cpp
  2. 3 1
      Kernel/Net/LocalSocket.h

+ 44 - 20
Kernel/Net/LocalSocket.cpp

@@ -55,12 +55,13 @@ KResultOr<SocketPair> LocalSocket::try_create_connected_pair(int type)
         return socket_or_error.error();
         return socket_or_error.error();
 
 
     auto socket = socket_or_error.release_value();
     auto socket = socket_or_error.release_value();
+
     auto description1_result = FileDescription::try_create(*socket);
     auto description1_result = FileDescription::try_create(*socket);
     if (description1_result.is_error())
     if (description1_result.is_error())
         return description1_result.error();
         return description1_result.error();
 
 
-    socket->m_address.sun_family = AF_LOCAL;
-    memcpy(socket->m_address.sun_path, "[socketpair]", 13);
+    if (auto result = socket->try_set_path("[socketpair]"sv); result.is_error())
+        return result;
 
 
     socket->set_acceptor(Process::current());
     socket->set_acceptor(Process::current());
     socket->set_connected(true);
     socket->set_connected(true);
@@ -107,8 +108,13 @@ LocalSocket::~LocalSocket()
 
 
 void LocalSocket::get_local_address(sockaddr* address, socklen_t* address_size)
 void LocalSocket::get_local_address(sockaddr* address, socklen_t* address_size)
 {
 {
-    size_t bytes_to_copy = min(static_cast<size_t>(*address_size), sizeof(sockaddr_un));
-    memcpy(address, &m_address, bytes_to_copy);
+    if (!m_path || m_path->is_empty()) {
+        size_t bytes_to_copy = min(static_cast<size_t>(*address_size), sizeof(sockaddr_un));
+        memset(address, 0, bytes_to_copy);
+    } else {
+        size_t bytes_to_copy = min(m_path->length(), min(static_cast<size_t>(*address_size), sizeof(sockaddr_un)));
+        memcpy(address, m_path->characters(), bytes_to_copy);
+    }
     *address_size = sizeof(sockaddr_un);
     *address_size = sizeof(sockaddr_un);
 }
 }
 
 
@@ -123,20 +129,22 @@ KResult LocalSocket::bind(Userspace<const sockaddr*> user_address, socklen_t add
     if (address_size != sizeof(sockaddr_un))
     if (address_size != sizeof(sockaddr_un))
         return set_so_error(EINVAL);
         return set_so_error(EINVAL);
 
 
-    sockaddr_un address;
+    sockaddr_un address = {};
     if (!copy_from_user(&address, user_address, sizeof(sockaddr_un)))
     if (!copy_from_user(&address, user_address, sizeof(sockaddr_un)))
         return set_so_error(EFAULT);
         return set_so_error(EFAULT);
 
 
     if (address.sun_family != AF_LOCAL)
     if (address.sun_family != AF_LOCAL)
         return set_so_error(EINVAL);
         return set_so_error(EINVAL);
 
 
-    auto path = String(address.sun_path, strnlen(address.sun_path, sizeof(address.sun_path)));
+    auto path = KString::try_create(StringView { address.sun_path, strnlen(address.sun_path, sizeof(address.sun_path)) });
+    if (!path)
+        return set_so_error(ENOMEM);
 
 
     dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) bind({})", this, path);
     dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) bind({})", this, path);
 
 
     mode_t mode = S_IFSOCK | (m_prebind_mode & 0777);
     mode_t mode = S_IFSOCK | (m_prebind_mode & 0777);
     UidAndGid owner { m_prebind_uid, m_prebind_gid };
     UidAndGid owner { m_prebind_uid, m_prebind_gid };
-    auto result = VirtualFileSystem::the().open(path, O_CREAT | O_EXCL | O_NOFOLLOW_NOERROR, mode, Process::current().current_directory(), owner);
+    auto result = VirtualFileSystem::the().open(path->view(), O_CREAT | O_EXCL | O_NOFOLLOW_NOERROR, mode, Process::current().current_directory(), owner);
     if (result.is_error()) {
     if (result.is_error()) {
         if (result.error() == EEXIST)
         if (result.error() == EEXIST)
             return set_so_error(EADDRINUSE);
             return set_so_error(EADDRINUSE);
@@ -151,7 +159,7 @@ KResult LocalSocket::bind(Userspace<const sockaddr*> user_address, socklen_t add
 
 
     m_file = move(file);
     m_file = move(file);
 
 
-    m_address = address;
+    m_path = move(path);
     m_bound = true;
     m_bound = true;
     return KSuccess;
     return KSuccess;
 }
 }
@@ -170,15 +178,22 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka
     if (is_connected())
     if (is_connected())
         return set_so_error(EISCONN);
         return set_so_error(EISCONN);
 
 
-    const auto& local_address = *reinterpret_cast<const sockaddr_un*>(user_address);
-    char safe_address[sizeof(local_address.sun_path) + 1] = { 0 };
-    if (!copy_from_user(&safe_address[0], &local_address.sun_path[0], sizeof(safe_address) - 1))
-        return set_so_error(EFAULT);
-    safe_address[sizeof(safe_address) - 1] = '\0';
+    OwnPtr<KString> maybe_path;
+    {
+        auto const& local_address = *reinterpret_cast<sockaddr_un const*>(user_address);
+        char safe_address[sizeof(local_address.sun_path) + 1] = { 0 };
+        if (!copy_from_user(&safe_address[0], &local_address.sun_path[0], sizeof(safe_address) - 1))
+            return set_so_error(EFAULT);
+        safe_address[sizeof(safe_address) - 1] = '\0';
+        maybe_path = KString::try_create(safe_address);
+        if (!maybe_path)
+            return set_so_error(ENOMEM);
+    }
 
 
-    dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) connect({})", this, safe_address);
+    auto path = maybe_path.release_nonnull();
+    dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) connect({})", this, *path);
 
 
-    auto description_or_error = VirtualFileSystem::the().open(safe_address, O_RDWR, 0, Process::current().current_directory());
+    auto description_or_error = VirtualFileSystem::the().open(path->view(), O_RDWR, 0, Process::current().current_directory());
     if (description_or_error.is_error())
     if (description_or_error.is_error())
         return set_so_error(ECONNREFUSED);
         return set_so_error(ECONNREFUSED);
 
 
@@ -188,8 +203,7 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka
     if (!m_file->inode()->socket())
     if (!m_file->inode()->socket())
         return set_so_error(ECONNREFUSED);
         return set_so_error(ECONNREFUSED);
 
 
-    m_address.sun_family = sa_family_copy;
-    memcpy(m_address.sun_path, safe_address, sizeof(m_address.sun_path));
+    m_path = move(path);
 
 
     VERIFY(m_connect_side_fd == &description);
     VERIFY(m_connect_side_fd == &description);
     set_connect_side_role(Role::Connecting);
     set_connect_side_role(Role::Connecting);
@@ -212,7 +226,7 @@ KResult LocalSocket::connect(FileDescription& description, Userspace<const socka
         return set_so_error(EINTR);
         return set_so_error(EINTR);
     }
     }
 
 
-    dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) connect({}) status is {}", this, safe_address, to_string(setup_state()));
+    dbgln_if(LOCAL_SOCKET_DEBUG, "LocalSocket({}) connect({}) status is {}", this, *m_path, to_string(setup_state()));
 
 
     if (!has_flag(unblock_flags, Thread::FileDescriptionBlocker::BlockFlags::Connect)) {
     if (!has_flag(unblock_flags, Thread::FileDescriptionBlocker::BlockFlags::Connect)) {
         set_connect_side_role(Role::None);
         set_connect_side_role(Role::None);
@@ -356,8 +370,9 @@ KResultOr<size_t> LocalSocket::recvfrom(FileDescription& description, UserOrKern
 
 
 StringView LocalSocket::socket_path() const
 StringView LocalSocket::socket_path() const
 {
 {
-    size_t len = strnlen(m_address.sun_path, sizeof(m_address.sun_path));
-    return { m_address.sun_path, len };
+    if (!m_path)
+        return {};
+    return m_path->view();
 }
 }
 
 
 String LocalSocket::absolute_path(const FileDescription& description) const
 String LocalSocket::absolute_path(const FileDescription& description) const
@@ -517,4 +532,13 @@ KResultOr<NonnullRefPtr<FileDescription>> LocalSocket::recvfd(const FileDescript
     return queue.take_first();
     return queue.take_first();
 }
 }
 
 
+KResult LocalSocket::try_set_path(StringView path)
+{
+    auto kstring = KString::try_create(path);
+    if (!kstring)
+        return ENOMEM;
+    m_path = move(kstring);
+    return KSuccess;
+}
+
 }
 }

+ 3 - 1
Kernel/Net/LocalSocket.h

@@ -69,6 +69,8 @@ private:
             evaluate_block_conditions();
             evaluate_block_conditions();
     }
     }
 
 
+    KResult try_set_path(StringView);
+
     // An open socket file on the filesystem.
     // An open socket file on the filesystem.
     RefPtr<FileDescription> m_file;
     RefPtr<FileDescription> m_file;
 
 
@@ -92,7 +94,7 @@ private:
 
 
     bool m_bound { false };
     bool m_bound { false };
     bool m_accept_side_fd_open { false };
     bool m_accept_side_fd_open { false };
-    sockaddr_un m_address { 0, { 0 } };
+    OwnPtr<KString> m_path;
 
 
     NonnullOwnPtr<DoubleBuffer> m_for_client;
     NonnullOwnPtr<DoubleBuffer> m_for_client;
     NonnullOwnPtr<DoubleBuffer> m_for_server;
     NonnullOwnPtr<DoubleBuffer> m_for_server;