Browse Source

Kernel: Make Socket::bind() take credentials as input

Andreas Kling 2 years ago
parent
commit
51318d51a4

+ 3 - 3
Kernel/Net/IPv4Socket.cpp

@@ -94,7 +94,7 @@ void IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size)
     *address_size = sizeof(sockaddr_in);
     *address_size = sizeof(sockaddr_in);
 }
 }
 
 
-ErrorOr<void> IPv4Socket::bind(Userspace<sockaddr const*> user_address, socklen_t address_size)
+ErrorOr<void> IPv4Socket::bind(Credentials const& credentials, Userspace<sockaddr const*> user_address, socklen_t address_size)
 {
 {
     VERIFY(setup_state() == SetupState::Unstarted);
     VERIFY(setup_state() == SetupState::Unstarted);
     if (address_size != sizeof(sockaddr_in))
     if (address_size != sizeof(sockaddr_in))
@@ -107,9 +107,9 @@ ErrorOr<void> IPv4Socket::bind(Userspace<sockaddr const*> user_address, socklen_
         return set_so_error(EINVAL);
         return set_so_error(EINVAL);
 
 
     auto requested_local_port = ntohs(address.sin_port);
     auto requested_local_port = ntohs(address.sin_port);
-    if (!Process::current().is_superuser()) {
+    if (!credentials.is_superuser()) {
         if (requested_local_port > 0 && requested_local_port < 1024) {
         if (requested_local_port > 0 && requested_local_port < 1024) {
-            dbgln("UID {} attempted to bind {} to port {}", Process::current().uid(), class_name(), requested_local_port);
+            dbgln("UID {} attempted to bind {} to port {}", credentials.uid(), class_name(), requested_local_port);
             return set_so_error(EACCES);
             return set_so_error(EACCES);
         }
         }
     }
     }

+ 1 - 1
Kernel/Net/IPv4Socket.h

@@ -32,7 +32,7 @@ public:
     virtual ~IPv4Socket() override;
     virtual ~IPv4Socket() override;
 
 
     virtual ErrorOr<void> close() override;
     virtual ErrorOr<void> close() override;
-    virtual ErrorOr<void> bind(Userspace<sockaddr const*>, socklen_t) override;
+    virtual ErrorOr<void> bind(Credentials const&, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<void> connect(OpenFileDescription&, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<void> connect(OpenFileDescription&, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<void> listen(size_t) override;
     virtual ErrorOr<void> listen(size_t) override;
     virtual void get_local_address(sockaddr*, socklen_t*) override;
     virtual void get_local_address(sockaddr*, socklen_t*) override;

+ 2 - 2
Kernel/Net/LocalSocket.cpp

@@ -122,7 +122,7 @@ void LocalSocket::get_peer_address(sockaddr* address, socklen_t* address_size)
     get_local_address(address, address_size);
     get_local_address(address, address_size);
 }
 }
 
 
-ErrorOr<void> LocalSocket::bind(Userspace<sockaddr const*> user_address, socklen_t address_size)
+ErrorOr<void> LocalSocket::bind(Credentials const& credentials, Userspace<sockaddr const*> user_address, socklen_t address_size)
 {
 {
     VERIFY(setup_state() == SetupState::Unstarted);
     VERIFY(setup_state() == SetupState::Unstarted);
     if (address_size > sizeof(sockaddr_un))
     if (address_size > sizeof(sockaddr_un))
@@ -139,7 +139,7 @@ ErrorOr<void> LocalSocket::bind(Userspace<sockaddr const*> user_address, socklen
 
 
     mode_t mode = S_IFSOCK | (m_prebind_mode & 0777);
     mode_t mode = S_IFSOCK | (m_prebind_mode & 0777);
     UidAndGid owner { m_prebind_uid, m_prebind_gid };
     UidAndGid owner { m_prebind_uid, m_prebind_gid };
-    auto result = VirtualFileSystem::the().open(Process::current().credentials(), path->view(), O_CREAT | O_EXCL | O_NOFOLLOW_NOERROR, mode, Process::current().current_directory(), owner);
+    auto result = VirtualFileSystem::the().open(credentials, path->view(), O_CREAT | O_EXCL | O_NOFOLLOW_NOERROR, mode, Process::current().current_directory(), owner);
     if (result.is_error()) {
     if (result.is_error()) {
         if (result.error().code() == EEXIST)
         if (result.error().code() == EEXIST)
             return set_so_error(EADDRINUSE);
             return set_so_error(EADDRINUSE);

+ 1 - 1
Kernel/Net/LocalSocket.h

@@ -36,7 +36,7 @@ public:
     ErrorOr<NonnullOwnPtr<KString>> pseudo_path(OpenFileDescription const& description) const override;
     ErrorOr<NonnullOwnPtr<KString>> pseudo_path(OpenFileDescription const& description) const override;
 
 
     // ^Socket
     // ^Socket
-    virtual ErrorOr<void> bind(Userspace<sockaddr const*>, socklen_t) override;
+    virtual ErrorOr<void> bind(Credentials const&, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<void> connect(OpenFileDescription&, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<void> connect(OpenFileDescription&, Userspace<sockaddr const*>, socklen_t) override;
     virtual ErrorOr<void> listen(size_t) override;
     virtual ErrorOr<void> listen(size_t) override;
     virtual void get_local_address(sockaddr*, socklen_t*) override;
     virtual void get_local_address(sockaddr*, socklen_t*) override;

+ 1 - 1
Kernel/Net/Socket.h

@@ -72,7 +72,7 @@ public:
 
 
     ErrorOr<void> shutdown(int how);
     ErrorOr<void> shutdown(int how);
 
 
-    virtual ErrorOr<void> bind(Userspace<sockaddr const*>, socklen_t) = 0;
+    virtual ErrorOr<void> bind(Credentials const&, Userspace<sockaddr const*>, socklen_t) = 0;
     virtual ErrorOr<void> connect(OpenFileDescription&, Userspace<sockaddr const*>, socklen_t) = 0;
     virtual ErrorOr<void> connect(OpenFileDescription&, Userspace<sockaddr const*>, socklen_t) = 0;
     virtual ErrorOr<void> listen(size_t) = 0;
     virtual ErrorOr<void> listen(size_t) = 0;
     virtual void get_local_address(sockaddr*, socklen_t*) = 0;
     virtual void get_local_address(sockaddr*, socklen_t*) = 0;

+ 1 - 1
Kernel/Syscalls/socket.cpp

@@ -56,7 +56,7 @@ ErrorOr<FlatPtr> Process::sys$bind(int sockfd, Userspace<sockaddr const*> addres
         return ENOTSOCK;
         return ENOTSOCK;
     auto& socket = *description->socket();
     auto& socket = *description->socket();
     REQUIRE_PROMISE_FOR_SOCKET_DOMAIN(socket.domain());
     REQUIRE_PROMISE_FOR_SOCKET_DOMAIN(socket.domain());
-    TRY(socket.bind(address, address_length));
+    TRY(socket.bind(credentials(), address, address_length));
     return 0;
     return 0;
 }
 }