IPv4Socket.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. #include <Kernel/IPv4Socket.h>
  2. #include <Kernel/UnixTypes.h>
  3. #include <Kernel/Process.h>
  4. #include <Kernel/NetworkAdapter.h>
  5. #include <Kernel/IPv4.h>
  6. #include <Kernel/ICMP.h>
  7. #include <Kernel/TCP.h>
  8. #include <Kernel/UDP.h>
  9. #include <Kernel/ARP.h>
  10. #include <LibC/errno_numbers.h>
  11. #define IPV4_SOCKET_DEBUG
  12. Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_udp_port()
  13. {
  14. static Lockable<HashMap<word, IPv4Socket*>>* s_map;
  15. if (!s_map)
  16. s_map = new Lockable<HashMap<word, IPv4Socket*>>;
  17. return *s_map;
  18. }
  19. Lockable<HashMap<word, IPv4Socket*>>& IPv4Socket::sockets_by_tcp_port()
  20. {
  21. static Lockable<HashMap<word, IPv4Socket*>>* s_map;
  22. if (!s_map)
  23. s_map = new Lockable<HashMap<word, IPv4Socket*>>;
  24. return *s_map;
  25. }
  26. Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
  27. {
  28. static Lockable<HashTable<IPv4Socket*>>* s_table;
  29. if (!s_table)
  30. s_table = new Lockable<HashTable<IPv4Socket*>>;
  31. return *s_table;
  32. }
  33. Retained<IPv4Socket> IPv4Socket::create(int type, int protocol)
  34. {
  35. return adopt(*new IPv4Socket(type, protocol));
  36. }
  37. IPv4Socket::IPv4Socket(int type, int protocol)
  38. : Socket(AF_INET, type, protocol)
  39. , m_lock("IPv4Socket")
  40. {
  41. kprintf("%s(%u) IPv4Socket{%p} created with type=%u, protocol=%d\n", current->name().characters(), current->pid(), this, type, protocol);
  42. LOCKER(all_sockets().lock());
  43. all_sockets().resource().set(this);
  44. }
  45. IPv4Socket::~IPv4Socket()
  46. {
  47. {
  48. LOCKER(all_sockets().lock());
  49. all_sockets().resource().remove(this);
  50. }
  51. if (type() == SOCK_DGRAM) {
  52. LOCKER(sockets_by_udp_port().lock());
  53. sockets_by_udp_port().resource().remove(m_source_port);
  54. }
  55. if (type() == SOCK_STREAM) {
  56. LOCKER(sockets_by_tcp_port().lock());
  57. sockets_by_tcp_port().resource().remove(m_source_port);
  58. }
  59. }
  60. bool IPv4Socket::get_address(sockaddr* address, socklen_t* address_size)
  61. {
  62. // FIXME: Look into what fallback behavior we should have here.
  63. if (*address_size != sizeof(sockaddr_in))
  64. return false;
  65. memcpy(address, &m_destination_address, sizeof(sockaddr_in));
  66. *address_size = sizeof(sockaddr_in);
  67. return true;
  68. }
  69. KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
  70. {
  71. ASSERT(!is_connected());
  72. if (address_size != sizeof(sockaddr_in))
  73. return KResult(-EINVAL);
  74. if (address->sa_family != AF_INET)
  75. return KResult(-EINVAL);
  76. ASSERT_NOT_REACHED();
  77. }
  78. KResult IPv4Socket::connect(const sockaddr* address, socklen_t address_size)
  79. {
  80. ASSERT(!m_bound);
  81. if (address_size != sizeof(sockaddr_in))
  82. return KResult(-EINVAL);
  83. if (address->sa_family != AF_INET)
  84. return KResult(-EINVAL);
  85. auto& ia = *(const sockaddr_in*)address;
  86. m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
  87. m_destination_port = ntohs(ia.sin_port);
  88. if (type() == SOCK_STREAM) {
  89. // FIXME: Figure out the adapter somehow differently.
  90. auto* adapter = NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2));
  91. if (!adapter)
  92. ASSERT_NOT_REACHED();
  93. send_tcp_packet(*adapter, TCPFlags::SYN);
  94. m_tcp_state = TCPState::Connecting1;
  95. return KSuccess;
  96. }
  97. return KSuccess;
  98. }
  99. void IPv4Socket::attach_fd(SocketRole)
  100. {
  101. ++m_attached_fds;
  102. }
  103. void IPv4Socket::detach_fd(SocketRole)
  104. {
  105. --m_attached_fds;
  106. }
  107. bool IPv4Socket::can_read(SocketRole) const
  108. {
  109. return m_can_read;
  110. }
  111. ssize_t IPv4Socket::read(SocketRole, byte*, ssize_t)
  112. {
  113. ASSERT_NOT_REACHED();
  114. }
  115. ssize_t IPv4Socket::write(SocketRole, const byte*, ssize_t)
  116. {
  117. ASSERT_NOT_REACHED();
  118. }
  119. bool IPv4Socket::can_write(SocketRole) const
  120. {
  121. ASSERT_NOT_REACHED();
  122. }
  123. void IPv4Socket::allocate_source_port_if_needed()
  124. {
  125. if (m_source_port)
  126. return;
  127. if (type() == SOCK_DGRAM) {
  128. // This is not a very efficient allocation algorithm.
  129. // FIXME: Replace it with a bitmap or some other fast-paced looker-upper.
  130. LOCKER(sockets_by_udp_port().lock());
  131. for (word port = 2000; port < 60000; ++port) {
  132. auto it = sockets_by_udp_port().resource().find(port);
  133. if (it == sockets_by_udp_port().resource().end()) {
  134. m_source_port = port;
  135. sockets_by_udp_port().resource().set(port, this);
  136. return;
  137. }
  138. }
  139. ASSERT_NOT_REACHED();
  140. }
  141. if (type() == SOCK_STREAM) {
  142. // This is not a very efficient allocation algorithm.
  143. // FIXME: Replace it with a bitmap or some other fast-paced looker-upper.
  144. LOCKER(sockets_by_tcp_port().lock());
  145. for (word port = 2000; port < 60000; ++port) {
  146. auto it = sockets_by_tcp_port().resource().find(port);
  147. if (it == sockets_by_tcp_port().resource().end()) {
  148. m_source_port = port;
  149. sockets_by_tcp_port().resource().set(port, this);
  150. return;
  151. }
  152. }
  153. ASSERT_NOT_REACHED();
  154. }
  155. }
  156. struct [[gnu::packed]] TCPPseudoHeader {
  157. IPv4Address source;
  158. IPv4Address destination;
  159. byte zero;
  160. byte protocol;
  161. NetworkOrdered<word> payload_size;
  162. };
  163. NetworkOrdered<word> IPv4Socket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, word payload_size)
  164. {
  165. TCPPseudoHeader pseudo_header { source, destination, 0, (byte)IPv4Protocol::TCP, sizeof(TCPPacket) + payload_size };
  166. dword checksum = 0;
  167. auto* w = (const NetworkOrdered<word>*)&pseudo_header;
  168. for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(word); ++i) {
  169. checksum += w[i];
  170. if (checksum > 0xffff)
  171. checksum = (checksum >> 16) + (checksum & 0xffff);
  172. }
  173. w = (const NetworkOrdered<word>*)&packet;
  174. for (size_t i = 0; i < sizeof(packet) / sizeof(word); ++i) {
  175. checksum += w[i];
  176. if (checksum > 0xffff)
  177. checksum = (checksum >> 16) + (checksum & 0xffff);
  178. }
  179. ASSERT(packet.data_offset() * 4 == sizeof(TCPPacket));
  180. w = (const NetworkOrdered<word>*)packet.payload();
  181. for (size_t i = 0; i < payload_size / sizeof(word); ++i) {
  182. checksum += w[i];
  183. if (checksum > 0xffff)
  184. checksum = (checksum >> 16) + (checksum & 0xffff);
  185. }
  186. if (payload_size & 1)
  187. ASSERT_NOT_REACHED();
  188. return ~(checksum & 0xffff);
  189. }
  190. void IPv4Socket::send_tcp_packet(NetworkAdapter& adapter, word flags, const void* payload, size_t payload_size)
  191. {
  192. auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size);
  193. auto& tcp_packet = *(TCPPacket*)(buffer.pointer());
  194. tcp_packet.set_source_port(m_source_port);
  195. tcp_packet.set_destination_port(m_destination_port);
  196. tcp_packet.set_window_size(1024);
  197. tcp_packet.set_sequence_number(m_tcp_sequence_number);
  198. tcp_packet.set_data_offset(5);
  199. tcp_packet.set_flags(flags);
  200. if (flags & TCPFlags::ACK)
  201. tcp_packet.set_ack_number(m_tcp_ack_number);
  202. if (flags == TCPFlags::SYN) {
  203. ++m_tcp_sequence_number;
  204. } else {
  205. m_tcp_sequence_number += payload_size;
  206. }
  207. memcpy(tcp_packet.payload(), payload, payload_size);
  208. tcp_packet.set_checksum(compute_tcp_checksum(adapter.ipv4_address(), m_destination_address, tcp_packet, payload_size));
  209. kprintf("sending tcp packet from %s:%u to %s:%u!\n",
  210. adapter.ipv4_address().to_string().characters(),
  211. source_port(),
  212. m_destination_address.to_string().characters(),
  213. m_destination_port);
  214. adapter.send_ipv4(MACAddress(), m_destination_address, IPv4Protocol::TCP, move(buffer));
  215. }
  216. ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length)
  217. {
  218. (void)flags;
  219. if (addr && addr_length != sizeof(sockaddr_in))
  220. return -EINVAL;
  221. // FIXME: Find the adapter some better way!
  222. auto* adapter = NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2));
  223. if (!adapter) {
  224. // FIXME: Figure out which error code to return.
  225. ASSERT_NOT_REACHED();
  226. }
  227. if (addr) {
  228. if (addr->sa_family != AF_INET) {
  229. kprintf("sendto: Bad address family: %u is not AF_INET!\n", addr->sa_family);
  230. return -EAFNOSUPPORT;
  231. }
  232. auto& ia = *(const sockaddr_in*)addr;
  233. m_destination_address = IPv4Address((const byte*)&ia.sin_addr.s_addr);
  234. m_destination_port = ntohs(ia.sin_port);
  235. }
  236. allocate_source_port_if_needed();
  237. kprintf("sendto: destination=%s:%u\n", m_destination_address.to_string().characters(), m_destination_port);
  238. if (type() == SOCK_RAW) {
  239. adapter->send_ipv4(MACAddress(), m_destination_address, (IPv4Protocol)protocol(), ByteBuffer::copy((const byte*)data, data_length));
  240. return data_length;
  241. }
  242. if (type() == SOCK_DGRAM) {
  243. auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length);
  244. auto& udp_packet = *(UDPPacket*)(buffer.pointer());
  245. udp_packet.set_source_port(m_source_port);
  246. udp_packet.set_destination_port(m_destination_port);
  247. udp_packet.set_length(sizeof(UDPPacket) + data_length);
  248. memcpy(udp_packet.payload(), data, data_length);
  249. kprintf("sending as udp packet from %s:%u to %s:%u!\n",
  250. adapter->ipv4_address().to_string().characters(),
  251. source_port(),
  252. m_destination_address.to_string().characters(),
  253. m_destination_port);
  254. adapter->send_ipv4(MACAddress(), m_destination_address, IPv4Protocol::UDP, move(buffer));
  255. return data_length;
  256. }
  257. if (type() == SOCK_STREAM) {
  258. send_tcp_packet(*adapter, 0, data, data_length);
  259. return data_length;
  260. }
  261. ASSERT_NOT_REACHED();
  262. }
  263. ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sockaddr* addr, socklen_t* addr_length)
  264. {
  265. (void)flags;
  266. if (addr_length && *addr_length < sizeof(sockaddr_in))
  267. return -EINVAL;
  268. #ifdef IPV4_SOCKET_DEBUG
  269. kprintf("recvfrom: type=%d, source_port=%u\n", type(), source_port());
  270. #endif
  271. ByteBuffer packet_buffer;
  272. {
  273. LOCKER(m_lock);
  274. if (!m_receive_queue.is_empty()) {
  275. packet_buffer = m_receive_queue.take_first();
  276. m_can_read = !m_receive_queue.is_empty();
  277. }
  278. }
  279. if (packet_buffer.is_null()) {
  280. current->set_blocked_socket(this);
  281. load_receive_deadline();
  282. block(Process::BlockedReceive);
  283. Scheduler::yield();
  284. LOCKER(m_lock);
  285. if (!m_can_read) {
  286. // Unblocked due to timeout.
  287. return -EAGAIN;
  288. }
  289. ASSERT(m_can_read);
  290. ASSERT(!m_receive_queue.is_empty());
  291. packet_buffer = m_receive_queue.take_first();
  292. m_can_read = !m_receive_queue.is_empty();
  293. }
  294. ASSERT(!packet_buffer.is_null());
  295. auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer());
  296. if (addr) {
  297. auto& ia = *(sockaddr_in*)addr;
  298. memcpy(&ia.sin_addr, &m_destination_address, sizeof(IPv4Address));
  299. ia.sin_family = AF_INET;
  300. ASSERT(addr_length);
  301. *addr_length = sizeof(sockaddr_in);
  302. }
  303. if (type() == SOCK_RAW) {
  304. ASSERT(buffer_length >= ipv4_packet.payload_size());
  305. memcpy(buffer, ipv4_packet.payload(), ipv4_packet.payload_size());
  306. return ipv4_packet.payload_size();
  307. }
  308. if (type() == SOCK_DGRAM) {
  309. auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
  310. ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier.
  311. ASSERT(buffer_length >= (udp_packet.length() - sizeof(UDPPacket)));
  312. if (addr) {
  313. auto& ia = *(sockaddr_in*)addr;
  314. ia.sin_port = htons(udp_packet.destination_port());
  315. }
  316. memcpy(buffer, udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket));
  317. return udp_packet.length() - sizeof(UDPPacket);
  318. }
  319. if (type() == SOCK_STREAM) {
  320. auto& tcp_packet = *static_cast<const TCPPacket*>(ipv4_packet.payload());
  321. size_t payload_size = packet_buffer.size() - sizeof(TCPPacket);
  322. ASSERT(buffer_length >= payload_size);
  323. if (addr) {
  324. auto& ia = *(sockaddr_in*)addr;
  325. ia.sin_port = htons(tcp_packet.destination_port());
  326. }
  327. memcpy(buffer, tcp_packet.payload(), payload_size);
  328. return payload_size;
  329. }
  330. ASSERT_NOT_REACHED();
  331. }
  332. void IPv4Socket::did_receive(ByteBuffer&& packet)
  333. {
  334. LOCKER(m_lock);
  335. m_receive_queue.append(move(packet));
  336. m_can_read = true;
  337. #ifdef IPV4_SOCKET_DEBUG
  338. kprintf("IPv4Socket(%p): did_receive %d bytes, packets in queue: %d\n", this, packet.size(), m_receive_queue.size_slow());
  339. #endif
  340. }