Browse Source

Kernel: Add getpeername() syscall, and fix getsockname() behavior.

We were copying the raw IPv4 addresses into the wrong part of sockaddr_in,
and we didn't set sa_family or sa_port.
Andreas Kling 6 years ago
parent
commit
ae470ec955

+ 19 - 3
Kernel/Net/IPv4Socket.cpp

@@ -46,12 +46,28 @@ IPv4Socket::~IPv4Socket()
     all_sockets().resource().remove(this);
 }
 
-bool IPv4Socket::get_address(sockaddr* address, socklen_t* address_size)
+bool IPv4Socket::get_local_address(sockaddr* address, socklen_t* address_size)
 {
     // FIXME: Look into what fallback behavior we should have here.
-    if (*address_size != sizeof(sockaddr_in))
+    if (*address_size < sizeof(sockaddr_in))
         return false;
-    memcpy(address, &m_peer_address, sizeof(sockaddr_in));
+    auto& ia = (sockaddr_in&)*address;
+    ia.sin_family = AF_INET;
+    ia.sin_port = m_local_port;
+    memcpy(&ia.sin_addr, &m_local_address, sizeof(IPv4Address));
+    *address_size = sizeof(sockaddr_in);
+    return true;
+}
+
+bool IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size)
+{
+    // FIXME: Look into what fallback behavior we should have here.
+    if (*address_size < sizeof(sockaddr_in))
+        return false;
+    auto& ia = (sockaddr_in&)*address;
+    ia.sin_family = AF_INET;
+    ia.sin_port = m_peer_port;
+    memcpy(&ia.sin_addr, &m_peer_address, sizeof(IPv4Address));
     *address_size = sizeof(sockaddr_in);
     return true;
 }

+ 2 - 1
Kernel/Net/IPv4Socket.h

@@ -22,7 +22,8 @@ public:
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult connect(FileDescriptor&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
-    virtual bool get_address(sockaddr*, socklen_t*) override;
+    virtual bool get_local_address(sockaddr*, socklen_t*) override;
+    virtual bool get_peer_address(sockaddr*, socklen_t*) override;
     virtual void attach(FileDescriptor&) override;
     virtual void detach(FileDescriptor&) override;
     virtual bool can_read(FileDescriptor&) const override;

+ 6 - 1
Kernel/Net/LocalSocket.cpp

@@ -24,7 +24,7 @@ LocalSocket::~LocalSocket()
 {
 }
 
-bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size)
+bool LocalSocket::get_local_address(sockaddr* address, socklen_t* address_size)
 {
     // FIXME: Look into what fallback behavior we should have here.
     if (*address_size != sizeof(sockaddr_un))
@@ -34,6 +34,11 @@ bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size)
     return true;
 }
 
+bool LocalSocket::get_peer_address(sockaddr* address, socklen_t* address_size)
+{
+    return get_local_address(address, address_size);
+}
+
 KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size)
 {
     ASSERT(!is_connected());

+ 2 - 1
Kernel/Net/LocalSocket.h

@@ -12,7 +12,8 @@ public:
 
     virtual KResult bind(const sockaddr*, socklen_t) override;
     virtual KResult connect(FileDescriptor&, const sockaddr*, socklen_t, ShouldBlock = ShouldBlock::Yes) override;
-    virtual bool get_address(sockaddr*, socklen_t*) override;
+    virtual bool get_local_address(sockaddr*, socklen_t*) override;
+    virtual bool get_peer_address(sockaddr*, socklen_t*) override;
     virtual void attach(FileDescriptor&) override;
     virtual void detach(FileDescriptor&) override;
     virtual bool can_read(FileDescriptor&) const override;

+ 2 - 1
Kernel/Net/Socket.h

@@ -30,7 +30,8 @@ public:
 
     virtual KResult bind(const sockaddr*, socklen_t) = 0;
     virtual KResult connect(FileDescriptor&, const sockaddr*, socklen_t, ShouldBlock) = 0;
-    virtual bool get_address(sockaddr*, socklen_t*) = 0;
+    virtual bool get_local_address(sockaddr*, socklen_t*) = 0;
+    virtual bool get_peer_address(sockaddr*, socklen_t*) = 0;
     virtual bool is_local() const { return false; }
     virtual bool is_ipv4() const { return false; }
     virtual void attach(FileDescriptor&) = 0;

+ 28 - 2
Kernel/Process.cpp

@@ -2120,7 +2120,7 @@ int Process::sys$accept(int accepting_socket_fd, sockaddr* address, socklen_t* a
     }
     auto accepted_socket = socket.accept();
     ASSERT(accepted_socket);
-    bool success = accepted_socket->get_address(address, address_size);
+    bool success = accepted_socket->get_local_address(address, address_size);
     ASSERT(success);
     auto accepted_socket_descriptor = FileDescriptor::create(move(accepted_socket), SocketRole::Accepted);
     // NOTE: The accepted socket inherits fd flags from the accepting socket.
@@ -2240,7 +2240,33 @@ int Process::sys$getsockname(int sockfd, sockaddr* addr, socklen_t* addrlen)
         return -ENOTSOCK;
 
     auto& socket = *descriptor->socket();
-    if (!socket.get_address(addr, addrlen))
+    if (!socket.get_local_address(addr, addrlen))
+        return -EINVAL; // FIXME: Should this be another error? I'm not sure.
+
+    return 0;
+}
+
+int Process::sys$getpeername(int sockfd, sockaddr* addr, socklen_t* addrlen)
+{
+    if (!validate_read_typed(addrlen))
+        return -EFAULT;
+
+    if (*addrlen <= 0)
+        return -EINVAL;
+
+    if (!validate_write(addr, *addrlen))
+        return -EFAULT;
+
+    auto* descriptor = file_descriptor(sockfd);
+    if (!descriptor)
+        return -EBADF;
+
+    if (!descriptor->is_socket())
+        return -ENOTSOCK;
+
+    auto& socket = *descriptor->socket();
+
+    if (!socket.get_peer_address(addr, addrlen))
         return -EINVAL; // FIXME: Should this be another error? I'm not sure.
 
     return 0;

+ 1 - 0
Kernel/Process.h

@@ -176,6 +176,7 @@ public:
     int sys$getsockopt(const Syscall::SC_getsockopt_params*);
     int sys$setsockopt(const Syscall::SC_setsockopt_params*);
     int sys$getsockname(int sockfd, sockaddr* addr, socklen_t* addrlen);
+    int sys$getpeername(int sockfd, sockaddr* addr, socklen_t* addrlen);
     int sys$restore_signal_mask(dword mask);
     int sys$create_thread(int(*)(void*), void*);
     void sys$exit_thread(int code);

+ 2 - 0
Kernel/Syscall.cpp

@@ -274,6 +274,8 @@ static dword handle(RegisterDump& regs, dword function, dword arg1, dword arg2,
         return current->process().sys$writev((int)arg1, (const struct iovec*)arg2, (int)arg3);
     case Syscall::SC_getsockname:
         return current->process().sys$getsockname((int)arg1, (sockaddr*)arg2, (socklen_t*)arg3);
+    case Syscall::SC_getpeername:
+        return current->process().sys$getpeername((int)arg1, (sockaddr*)arg2, (socklen_t*)arg3);
     default:
         kprintf("<%u> int0x82: Unknown function %u requested {%x, %x, %x}\n", current->process().pid(), function, arg1, arg2, arg3);
         break;

+ 1 - 0
Kernel/Syscall.h

@@ -105,6 +105,7 @@
     __ENUMERATE_SYSCALL(writev) \
     __ENUMERATE_SYSCALL(beep) \
     __ENUMERATE_SYSCALL(getsockname) \
+    __ENUMERATE_SYSCALL(getpeername) \
 
 
 namespace Syscall {

+ 8 - 0
LibC/sys/socket.cpp

@@ -1,6 +1,8 @@
 #include <sys/socket.h>
 #include <errno.h>
 #include <Kernel/Syscall.h>
+#include <AK/Assertions.h>
+#include <stdio.h>
 
 extern "C" {
 
@@ -78,4 +80,10 @@ int getsockname(int sockfd, struct sockaddr* addr, socklen_t* addrlen)
     __RETURN_WITH_ERRNO(rc, rc, -1);
 }
 
+int getpeername(int sockfd, struct sockaddr* addr, socklen_t* addrlen)
+{
+    int rc = syscall(SC_getpeername, sockfd, addr, addrlen);
+    __RETURN_WITH_ERRNO(rc, rc, -1);
+}
+
 }

+ 2 - 0
LibC/sys/socket.h

@@ -55,6 +55,7 @@ struct sockaddr_in {
 
 #define SO_RCVTIMEO 1
 #define SO_SNDTIMEO 2
+#define SO_KEEPALIVE 3
 
 int socket(int domain, int type, int protocol);
 int bind(int sockfd, const struct sockaddr* addr, socklen_t);
@@ -68,6 +69,7 @@ ssize_t recvfrom(int sockfd, void*, size_t, int flags, struct sockaddr*, socklen
 int getsockopt(int sockfd, int level, int option, void*, socklen_t*);
 int setsockopt(int sockfd, int level, int option, const void*, socklen_t);
 int getsockname(int sockfd, struct sockaddr*, socklen_t*);
+int getpeername(int sockfd, struct sockaddr*, socklen_t*);
 
 __END_DECLS