NetworkTask.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. #include <Kernel/Lock.h>
  2. #include <Kernel/Net/ARP.h>
  3. #include <Kernel/Net/EtherType.h>
  4. #include <Kernel/Net/EthernetFrameHeader.h>
  5. #include <Kernel/Net/ICMP.h>
  6. #include <Kernel/Net/IPv4.h>
  7. #include <Kernel/Net/IPv4Socket.h>
  8. #include <Kernel/Net/LoopbackAdapter.h>
  9. #include <Kernel/Net/Routing.h>
  10. #include <Kernel/Net/TCP.h>
  11. #include <Kernel/Net/TCPSocket.h>
  12. #include <Kernel/Net/UDP.h>
  13. #include <Kernel/Net/UDPSocket.h>
  14. #include <Kernel/Process.h>
  15. //#define NETWORK_TASK_DEBUG
  16. //#define ETHERNET_DEBUG
  17. //#define ETHERNET_VERY_DEBUG
  18. //#define ARP_DEBUG
  19. //#define IPV4_DEBUG
  20. //#define ICMP_DEBUG
  21. //#define UDP_DEBUG
  22. //#define TCP_DEBUG
  23. static void handle_arp(const EthernetFrameHeader&, size_t frame_size);
  24. static void handle_ipv4(const EthernetFrameHeader&, size_t frame_size);
  25. static void handle_icmp(const EthernetFrameHeader&, const IPv4Packet&);
  26. static void handle_udp(const IPv4Packet&);
  27. static void handle_tcp(const IPv4Packet&);
  28. void NetworkTask_main()
  29. {
  30. u8 octet = 15;
  31. int pending_packets = 0;
  32. NetworkAdapter::for_each([&octet, &pending_packets](auto& adapter) {
  33. if (String(adapter.class_name()) == "LoopbackAdapter") {
  34. adapter.set_ipv4_address({ 127, 0, 0, 1 });
  35. adapter.set_ipv4_netmask({ 255, 0, 0, 0 });
  36. adapter.set_ipv4_gateway({ 0, 0, 0, 0 });
  37. } else {
  38. adapter.set_ipv4_address({ 10, 0, 2, octet++ });
  39. adapter.set_ipv4_netmask({ 255, 255, 255, 0 });
  40. adapter.set_ipv4_gateway({ 10, 0, 2, 2 });
  41. }
  42. kprintf("NetworkTask: %s network adapter found: hw=%s address=%s netmask=%s gateway=%s\n",
  43. adapter.class_name(),
  44. adapter.mac_address().to_string().characters(),
  45. adapter.ipv4_address().to_string().characters(),
  46. adapter.ipv4_netmask().to_string().characters(),
  47. adapter.ipv4_gateway().to_string().characters());
  48. adapter.on_receive = [&pending_packets]() {
  49. pending_packets++;
  50. };
  51. });
  52. auto dequeue_packet = [&pending_packets]() -> Optional<KBuffer> {
  53. Optional<KBuffer> packet;
  54. NetworkAdapter::for_each([&packet, &pending_packets](auto& adapter) {
  55. if (packet.has_value() || !adapter.has_queued_packets())
  56. return;
  57. packet = adapter.dequeue_packet();
  58. pending_packets--;
  59. #ifdef NETWORK_TASK_DEBUG
  60. kprintf("NetworkTask: Dequeued packet from %s (%d bytes)\n", adapter.name().characters(), packet.value().size());
  61. #endif
  62. });
  63. return packet;
  64. };
  65. kprintf("NetworkTask: Enter main loop.\n");
  66. for (;;) {
  67. auto packet_maybe_null = dequeue_packet();
  68. if (!packet_maybe_null.has_value()) {
  69. (void)current->block_until("Networking", [&pending_packets] {
  70. return pending_packets > 0;
  71. });
  72. continue;
  73. }
  74. auto& packet = packet_maybe_null.value();
  75. if (packet.size() < sizeof(EthernetFrameHeader)) {
  76. kprintf("NetworkTask: Packet is too small to be an Ethernet packet! (%zu)\n", packet.size());
  77. continue;
  78. }
  79. auto& eth = *(const EthernetFrameHeader*)packet.data();
  80. #ifdef ETHERNET_DEBUG
  81. kprintf("NetworkTask: From %s to %s, ether_type=%w, packet_length=%u\n",
  82. eth.source().to_string().characters(),
  83. eth.destination().to_string().characters(),
  84. eth.ether_type(),
  85. packet.size());
  86. #endif
  87. #ifdef ETHERNET_VERY_DEBUG
  88. u8* data = packet.data();
  89. for (size_t i = 0; i < packet.size(); i++) {
  90. kprintf("%b", data[i]);
  91. switch (i % 16) {
  92. case 7:
  93. kprintf(" ");
  94. break;
  95. case 15:
  96. kprintf("\n");
  97. break;
  98. default:
  99. kprintf(" ");
  100. break;
  101. }
  102. }
  103. kprintf("\n");
  104. #endif
  105. switch (eth.ether_type()) {
  106. case EtherType::ARP:
  107. handle_arp(eth, packet.size());
  108. break;
  109. case EtherType::IPv4:
  110. handle_ipv4(eth, packet.size());
  111. break;
  112. case EtherType::IPv6:
  113. // ignore
  114. break;
  115. default:
  116. kprintf("NetworkTask: Unknown ethernet type %#04x\n", eth.ether_type());
  117. }
  118. }
  119. }
  120. void handle_arp(const EthernetFrameHeader& eth, size_t frame_size)
  121. {
  122. constexpr size_t minimum_arp_frame_size = sizeof(EthernetFrameHeader) + sizeof(ARPPacket);
  123. if (frame_size < minimum_arp_frame_size) {
  124. kprintf("handle_arp: Frame too small (%d, need %d)\n", frame_size, minimum_arp_frame_size);
  125. return;
  126. }
  127. auto& packet = *static_cast<const ARPPacket*>(eth.payload());
  128. if (packet.hardware_type() != 1 || packet.hardware_address_length() != sizeof(MACAddress)) {
  129. kprintf("handle_arp: Hardware type not ethernet (%w, len=%u)\n",
  130. packet.hardware_type(),
  131. packet.hardware_address_length());
  132. return;
  133. }
  134. if (packet.protocol_type() != EtherType::IPv4 || packet.protocol_address_length() != sizeof(IPv4Address)) {
  135. kprintf("handle_arp: Protocol type not IPv4 (%w, len=%u)\n",
  136. packet.hardware_type(),
  137. packet.protocol_address_length());
  138. return;
  139. }
  140. #ifdef ARP_DEBUG
  141. kprintf("handle_arp: operation=%w, sender=%s/%s, target=%s/%s\n",
  142. packet.operation(),
  143. packet.sender_hardware_address().to_string().characters(),
  144. packet.sender_protocol_address().to_string().characters(),
  145. packet.target_hardware_address().to_string().characters(),
  146. packet.target_protocol_address().to_string().characters());
  147. #endif
  148. if (!packet.sender_hardware_address().is_zero() && !packet.sender_protocol_address().is_zero()) {
  149. // Someone has this IPv4 address. I guess we can try to remember that.
  150. // FIXME: Protect against ARP spamming.
  151. // FIXME: Support static ARP table entries.
  152. LOCKER(arp_table().lock());
  153. arp_table().resource().set(packet.sender_protocol_address(), packet.sender_hardware_address());
  154. kprintf("ARP table (%d entries):\n", arp_table().resource().size());
  155. for (auto& it : arp_table().resource()) {
  156. kprintf("%s :: %s\n", it.value.to_string().characters(), it.key.to_string().characters());
  157. }
  158. }
  159. if (packet.operation() == ARPOperation::Request) {
  160. // Who has this IP address?
  161. if (auto adapter = NetworkAdapter::from_ipv4_address(packet.target_protocol_address())) {
  162. // We do!
  163. kprintf("handle_arp: Responding to ARP request for my IPv4 address (%s)\n",
  164. adapter->ipv4_address().to_string().characters());
  165. ARPPacket response;
  166. response.set_operation(ARPOperation::Response);
  167. response.set_target_hardware_address(packet.sender_hardware_address());
  168. response.set_target_protocol_address(packet.sender_protocol_address());
  169. response.set_sender_hardware_address(adapter->mac_address());
  170. response.set_sender_protocol_address(adapter->ipv4_address());
  171. adapter->send(packet.sender_hardware_address(), response);
  172. }
  173. return;
  174. }
  175. }
  176. void handle_ipv4(const EthernetFrameHeader& eth, size_t frame_size)
  177. {
  178. constexpr size_t minimum_ipv4_frame_size = sizeof(EthernetFrameHeader) + sizeof(IPv4Packet);
  179. if (frame_size < minimum_ipv4_frame_size) {
  180. kprintf("handle_ipv4: Frame too small (%d, need %d)\n", frame_size, minimum_ipv4_frame_size);
  181. return;
  182. }
  183. auto& packet = *static_cast<const IPv4Packet*>(eth.payload());
  184. if (packet.length() < sizeof(IPv4Packet)) {
  185. kprintf("handle_ipv4: IPv4 packet too short (%u, need %u)\n", packet.length(), sizeof(IPv4Packet));
  186. return;
  187. }
  188. size_t actual_ipv4_packet_length = frame_size - sizeof(EthernetFrameHeader);
  189. if (packet.length() > actual_ipv4_packet_length) {
  190. kprintf("handle_ipv4: IPv4 packet claims to be longer than it is (%u, actually %zu)\n", packet.length(), actual_ipv4_packet_length);
  191. return;
  192. }
  193. #ifdef IPV4_DEBUG
  194. kprintf("handle_ipv4: source=%s, target=%s\n",
  195. packet.source().to_string().characters(),
  196. packet.destination().to_string().characters());
  197. #endif
  198. switch ((IPv4Protocol)packet.protocol()) {
  199. case IPv4Protocol::ICMP:
  200. return handle_icmp(eth, packet);
  201. case IPv4Protocol::UDP:
  202. return handle_udp(packet);
  203. case IPv4Protocol::TCP:
  204. return handle_tcp(packet);
  205. default:
  206. kprintf("handle_ipv4: Unhandled protocol %u\n", packet.protocol());
  207. break;
  208. }
  209. }
  210. void handle_icmp(const EthernetFrameHeader& eth, const IPv4Packet& ipv4_packet)
  211. {
  212. auto& icmp_header = *static_cast<const ICMPHeader*>(ipv4_packet.payload());
  213. #ifdef ICMP_DEBUG
  214. kprintf("handle_icmp: source=%s, destination=%s, type=%b, code=%b\n",
  215. ipv4_packet.source().to_string().characters(),
  216. ipv4_packet.destination().to_string().characters(),
  217. icmp_header.type(),
  218. icmp_header.code());
  219. #endif
  220. {
  221. LOCKER(IPv4Socket::all_sockets().lock());
  222. for (RefPtr<IPv4Socket> socket : IPv4Socket::all_sockets().resource()) {
  223. LOCKER(socket->lock());
  224. if (socket->protocol() != (unsigned)IPv4Protocol::ICMP)
  225. continue;
  226. socket->did_receive(ipv4_packet.source(), 0, KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
  227. }
  228. }
  229. auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
  230. if (!adapter)
  231. return;
  232. if (icmp_header.type() == ICMPType::EchoRequest) {
  233. auto& request = reinterpret_cast<const ICMPEchoPacket&>(icmp_header);
  234. kprintf("handle_icmp: EchoRequest from %s: id=%u, seq=%u\n",
  235. ipv4_packet.source().to_string().characters(),
  236. (u16)request.identifier,
  237. (u16)request.sequence_number);
  238. size_t icmp_packet_size = ipv4_packet.payload_size();
  239. auto buffer = ByteBuffer::create_zeroed(icmp_packet_size);
  240. auto& response = *(ICMPEchoPacket*)buffer.data();
  241. response.header.set_type(ICMPType::EchoReply);
  242. response.header.set_code(0);
  243. response.identifier = request.identifier;
  244. response.sequence_number = request.sequence_number;
  245. if (size_t icmp_payload_size = icmp_packet_size - sizeof(ICMPEchoPacket))
  246. memcpy(response.payload(), request.payload(), icmp_payload_size);
  247. response.header.set_checksum(internet_checksum(&response, icmp_packet_size));
  248. // FIXME: What is the right TTL value here? Is 64 ok? Should we use the same TTL as the echo request?
  249. adapter->send_ipv4(eth.source(), ipv4_packet.source(), IPv4Protocol::ICMP, buffer.data(), buffer.size(), 64);
  250. }
  251. }
  252. void handle_udp(const IPv4Packet& ipv4_packet)
  253. {
  254. if (ipv4_packet.payload_size() < sizeof(UDPPacket)) {
  255. kprintf("handle_udp: Packet too small (%u, need %zu)\n", ipv4_packet.payload_size());
  256. return;
  257. }
  258. auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
  259. if (!adapter) {
  260. kprintf("handle_udp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters());
  261. return;
  262. }
  263. auto& udp_packet = *static_cast<const UDPPacket*>(ipv4_packet.payload());
  264. #ifdef UDP_DEBUG
  265. kprintf("handle_udp: source=%s:%u, destination=%s:%u length=%u\n",
  266. ipv4_packet.source().to_string().characters(),
  267. udp_packet.source_port(),
  268. ipv4_packet.destination().to_string().characters(),
  269. udp_packet.destination_port(),
  270. udp_packet.length());
  271. #endif
  272. auto socket = UDPSocket::from_port(udp_packet.destination_port());
  273. if (!socket) {
  274. kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port());
  275. return;
  276. }
  277. ASSERT(socket->type() == SOCK_DGRAM);
  278. ASSERT(socket->local_port() == udp_packet.destination_port());
  279. socket->did_receive(ipv4_packet.source(), udp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
  280. }
  281. void handle_tcp(const IPv4Packet& ipv4_packet)
  282. {
  283. if (ipv4_packet.payload_size() < sizeof(TCPPacket)) {
  284. kprintf("handle_tcp: IPv4 payload is too small to be a TCP packet (%u, need %zu)\n", ipv4_packet.payload_size(), sizeof(TCPPacket));
  285. return;
  286. }
  287. auto& tcp_packet = *static_cast<const TCPPacket*>(ipv4_packet.payload());
  288. size_t minimum_tcp_header_size = 5 * sizeof(u32);
  289. size_t maximum_tcp_header_size = 15 * sizeof(u32);
  290. if (tcp_packet.header_size() < minimum_tcp_header_size || tcp_packet.header_size() > maximum_tcp_header_size) {
  291. kprintf("handle_tcp: TCP packet header has invalid size %zu\n", tcp_packet.header_size());
  292. }
  293. if (ipv4_packet.payload_size() < tcp_packet.header_size()) {
  294. kprintf("handle_tcp: IPv4 payload is smaller than TCP header claims (%u, supposedly %u)\n", ipv4_packet.payload_size(), tcp_packet.header_size());
  295. return;
  296. }
  297. size_t payload_size = ipv4_packet.payload_size() - tcp_packet.header_size();
  298. #ifdef TCP_DEBUG
  299. kprintf("handle_tcp: source=%s:%u, destination=%s:%u seq_no=%u, ack_no=%u, flags=%w (%s%s%s%s), window_size=%u, payload_size=%u\n",
  300. ipv4_packet.source().to_string().characters(),
  301. tcp_packet.source_port(),
  302. ipv4_packet.destination().to_string().characters(),
  303. tcp_packet.destination_port(),
  304. tcp_packet.sequence_number(),
  305. tcp_packet.ack_number(),
  306. tcp_packet.flags(),
  307. tcp_packet.has_syn() ? "SYN " : "",
  308. tcp_packet.has_ack() ? "ACK " : "",
  309. tcp_packet.has_fin() ? "FIN " : "",
  310. tcp_packet.has_rst() ? "RST " : "",
  311. tcp_packet.window_size(),
  312. payload_size);
  313. #endif
  314. auto adapter = NetworkAdapter::from_ipv4_address(ipv4_packet.destination());
  315. if (!adapter) {
  316. kprintf("handle_tcp: this packet is not for me, it's for %s\n", ipv4_packet.destination().to_string().characters());
  317. return;
  318. }
  319. IPv4SocketTuple tuple(ipv4_packet.destination(), tcp_packet.destination_port(), ipv4_packet.source(), tcp_packet.source_port());
  320. #ifdef TCP_DEBUG
  321. kprintf("handle_tcp: looking for socket; tuple=%s\n", tuple.to_string().characters());
  322. #endif
  323. auto socket = TCPSocket::from_tuple(tuple);
  324. if (!socket) {
  325. kprintf("handle_tcp: No TCP socket for tuple %s\n", tuple.to_string().characters());
  326. return;
  327. }
  328. ASSERT(socket->type() == SOCK_STREAM);
  329. ASSERT(socket->local_port() == tcp_packet.destination_port());
  330. #ifdef TCP_DEBUG
  331. kprintf("handle_tcp: got socket; state=%s\n", socket->tuple().to_string().characters(), TCPSocket::to_string(socket->state()));
  332. #endif
  333. socket->receive_tcp_packet(tcp_packet, ipv4_packet.payload_size());
  334. switch (socket->state()) {
  335. case TCPSocket::State::Closed:
  336. kprintf("handle_tcp: unexpected flags in Closed state\n");
  337. // TODO: we may want to send an RST here, maybe as a configurable option
  338. return;
  339. case TCPSocket::State::TimeWait:
  340. kprintf("handle_tcp: unexpected flags in TimeWait state\n");
  341. socket->send_tcp_packet(TCPFlags::RST);
  342. socket->set_state(TCPSocket::State::Closed);
  343. return;
  344. case TCPSocket::State::Listen:
  345. switch (tcp_packet.flags()) {
  346. case TCPFlags::SYN: {
  347. #ifdef TCP_DEBUG
  348. kprintf("handle_tcp: incoming connection\n");
  349. #endif
  350. auto& local_address = ipv4_packet.destination();
  351. auto& peer_address = ipv4_packet.source();
  352. auto client = socket->create_client(local_address, tcp_packet.destination_port(), peer_address, tcp_packet.source_port());
  353. if (!client) {
  354. kprintf("handle_tcp: couldn't create client socket\n");
  355. return;
  356. }
  357. #ifdef TCP_DEBUG
  358. kprintf("handle_tcp: created new client socket with tuple %s\n", client->tuple().to_string().characters());
  359. #endif
  360. client->set_sequence_number(1000);
  361. client->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  362. client->send_tcp_packet(TCPFlags::SYN | TCPFlags::ACK);
  363. client->set_state(TCPSocket::State::SynReceived);
  364. return;
  365. }
  366. default:
  367. kprintf("handle_tcp: unexpected flags in Listen state\n");
  368. // socket->send_tcp_packet(TCPFlags::RST);
  369. return;
  370. }
  371. case TCPSocket::State::SynSent:
  372. switch (tcp_packet.flags()) {
  373. case TCPFlags::SYN:
  374. socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  375. socket->send_tcp_packet(TCPFlags::ACK);
  376. socket->set_state(TCPSocket::State::SynReceived);
  377. return;
  378. case TCPFlags::ACK | TCPFlags::SYN:
  379. socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  380. socket->send_tcp_packet(TCPFlags::ACK);
  381. socket->set_state(TCPSocket::State::Established);
  382. socket->set_setup_state(Socket::SetupState::Completed);
  383. socket->set_connected(true);
  384. return;
  385. case TCPFlags::ACK | TCPFlags::FIN:
  386. socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  387. socket->send_tcp_packet(TCPFlags::ACK);
  388. socket->set_state(TCPSocket::State::Closed);
  389. socket->set_error(TCPSocket::Error::FINDuringConnect);
  390. socket->set_setup_state(Socket::SetupState::Completed);
  391. return;
  392. case TCPFlags::ACK | TCPFlags::RST:
  393. socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
  394. socket->send_tcp_packet(TCPFlags::ACK);
  395. socket->set_state(TCPSocket::State::Closed);
  396. socket->set_error(TCPSocket::Error::RSTDuringConnect);
  397. socket->set_setup_state(Socket::SetupState::Completed);
  398. return;
  399. default:
  400. kprintf("handle_tcp: unexpected flags in SynSent state\n");
  401. socket->send_tcp_packet(TCPFlags::RST);
  402. socket->set_state(TCPSocket::State::Closed);
  403. socket->set_error(TCPSocket::Error::UnexpectedFlagsDuringConnect);
  404. socket->set_setup_state(Socket::SetupState::Completed);
  405. return;
  406. }
  407. case TCPSocket::State::SynReceived:
  408. switch (tcp_packet.flags()) {
  409. case TCPFlags::ACK:
  410. socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
  411. switch (socket->direction()) {
  412. case TCPSocket::Direction::Incoming:
  413. if (!socket->has_originator()) {
  414. kprintf("handle_tcp: connection doesn't have an originating socket; maybe it went away?\n");
  415. socket->send_tcp_packet(TCPFlags::RST);
  416. socket->set_state(TCPSocket::State::Closed);
  417. return;
  418. }
  419. socket->set_state(TCPSocket::State::Established);
  420. socket->set_setup_state(Socket::SetupState::Completed);
  421. socket->release_to_originator();
  422. return;
  423. case TCPSocket::Direction::Outgoing:
  424. socket->set_state(TCPSocket::State::Established);
  425. socket->set_setup_state(Socket::SetupState::Completed);
  426. socket->set_connected(true);
  427. return;
  428. default:
  429. kprintf("handle_tcp: got ACK in SynReceived state but direction is invalid (%s)\n", TCPSocket::to_string(socket->direction()));
  430. socket->send_tcp_packet(TCPFlags::RST);
  431. socket->set_state(TCPSocket::State::Closed);
  432. return;
  433. }
  434. return;
  435. default:
  436. kprintf("handle_tcp: unexpected flags in SynReceived state\n");
  437. socket->send_tcp_packet(TCPFlags::RST);
  438. socket->set_state(TCPSocket::State::Closed);
  439. return;
  440. }
  441. case TCPSocket::State::CloseWait:
  442. switch (tcp_packet.flags()) {
  443. default:
  444. kprintf("handle_tcp: unexpected flags in CloseWait state\n");
  445. socket->send_tcp_packet(TCPFlags::RST);
  446. socket->set_state(TCPSocket::State::Closed);
  447. return;
  448. }
  449. case TCPSocket::State::LastAck:
  450. switch (tcp_packet.flags()) {
  451. case TCPFlags::ACK:
  452. socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
  453. socket->set_state(TCPSocket::State::Closed);
  454. return;
  455. default:
  456. kprintf("handle_tcp: unexpected flags in LastAck state\n");
  457. socket->send_tcp_packet(TCPFlags::RST);
  458. socket->set_state(TCPSocket::State::Closed);
  459. return;
  460. }
  461. case TCPSocket::State::FinWait1:
  462. switch (tcp_packet.flags()) {
  463. case TCPFlags::ACK:
  464. socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
  465. socket->set_state(TCPSocket::State::FinWait2);
  466. return;
  467. case TCPFlags::FIN:
  468. socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  469. socket->set_state(TCPSocket::State::Closing);
  470. return;
  471. default:
  472. kprintf("handle_tcp: unexpected flags in FinWait1 state\n");
  473. socket->send_tcp_packet(TCPFlags::RST);
  474. socket->set_state(TCPSocket::State::Closed);
  475. return;
  476. }
  477. case TCPSocket::State::FinWait2:
  478. switch (tcp_packet.flags()) {
  479. case TCPFlags::FIN:
  480. socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  481. socket->set_state(TCPSocket::State::TimeWait);
  482. return;
  483. default:
  484. kprintf("handle_tcp: unexpected flags in FinWait2 state\n");
  485. socket->send_tcp_packet(TCPFlags::RST);
  486. socket->set_state(TCPSocket::State::Closed);
  487. return;
  488. }
  489. case TCPSocket::State::Closing:
  490. switch (tcp_packet.flags()) {
  491. case TCPFlags::ACK:
  492. socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
  493. socket->set_state(TCPSocket::State::TimeWait);
  494. return;
  495. default:
  496. kprintf("handle_tcp: unexpected flags in Closing state\n");
  497. socket->send_tcp_packet(TCPFlags::RST);
  498. socket->set_state(TCPSocket::State::Closed);
  499. return;
  500. }
  501. case TCPSocket::State::Established:
  502. if (tcp_packet.has_fin()) {
  503. if (payload_size != 0)
  504. socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
  505. socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
  506. // TODO: We should only send a FIN packet out once we're shutting
  507. // down our side of the socket, so we should change this back to
  508. // just being an ACK and a transition to CloseWait once we have a
  509. // shutdown() implementation.
  510. socket->send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
  511. socket->set_state(TCPSocket::State::Closing);
  512. socket->set_connected(false);
  513. return;
  514. }
  515. socket->set_ack_number(tcp_packet.sequence_number() + payload_size);
  516. #ifdef TCP_DEBUG
  517. kprintf("Got packet with ack_no=%u, seq_no=%u, payload_size=%u, acking it with new ack_no=%u, seq_no=%u\n",
  518. tcp_packet.ack_number(),
  519. tcp_packet.sequence_number(),
  520. payload_size,
  521. socket->ack_number(),
  522. socket->sequence_number());
  523. #endif
  524. socket->send_tcp_packet(TCPFlags::ACK);
  525. if (payload_size != 0)
  526. socket->did_receive(ipv4_packet.source(), tcp_packet.source_port(), KBuffer::copy(&ipv4_packet, sizeof(IPv4Packet) + ipv4_packet.payload_size()));
  527. }
  528. }