ClientConnection.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /*
  2. * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <AK/Badge.h>
  7. #include <RequestServer/ClientConnection.h>
  8. #include <RequestServer/Protocol.h>
  9. #include <RequestServer/Request.h>
  10. #include <RequestServer/RequestClientEndpoint.h>
  11. #include <netdb.h>
  12. namespace RequestServer {
  13. static HashMap<int, RefPtr<ClientConnection>> s_connections;
  14. ClientConnection::ClientConnection(NonnullOwnPtr<Core::Stream::LocalSocket> socket)
  15. : IPC::ClientConnection<RequestClientEndpoint, RequestServerEndpoint>(*this, move(socket), 1)
  16. {
  17. s_connections.set(1, *this);
  18. }
  19. ClientConnection::~ClientConnection()
  20. {
  21. }
  22. void ClientConnection::die()
  23. {
  24. s_connections.remove(client_id());
  25. if (s_connections.is_empty())
  26. Core::EventLoop::current().quit(0);
  27. }
  28. Messages::RequestServer::IsSupportedProtocolResponse ClientConnection::is_supported_protocol(String const& protocol)
  29. {
  30. bool supported = Protocol::find_by_name(protocol.to_lowercase());
  31. return supported;
  32. }
  33. Messages::RequestServer::StartRequestResponse ClientConnection::start_request(String const& method, URL const& url, IPC::Dictionary const& request_headers, ByteBuffer const& request_body)
  34. {
  35. if (!url.is_valid()) {
  36. dbgln("StartRequest: Invalid URL requested: '{}'", url);
  37. return { -1, Optional<IPC::File> {} };
  38. }
  39. auto* protocol = Protocol::find_by_name(url.protocol());
  40. if (!protocol) {
  41. dbgln("StartRequest: No protocol handler for URL: '{}'", url);
  42. return { -1, Optional<IPC::File> {} };
  43. }
  44. auto request = protocol->start_request(*this, method, url, request_headers.entries(), request_body);
  45. if (!request) {
  46. dbgln("StartRequest: Protocol handler failed to start request: '{}'", url);
  47. return { -1, Optional<IPC::File> {} };
  48. }
  49. auto id = request->id();
  50. auto fd = request->request_fd();
  51. m_requests.set(id, move(request));
  52. return { id, IPC::File(fd, IPC::File::CloseAfterSending) };
  53. }
  54. Messages::RequestServer::StopRequestResponse ClientConnection::stop_request(i32 request_id)
  55. {
  56. auto* request = const_cast<Request*>(m_requests.get(request_id).value_or(nullptr));
  57. bool success = false;
  58. if (request) {
  59. request->stop();
  60. m_requests.remove(request_id);
  61. success = true;
  62. }
  63. return success;
  64. }
  65. void ClientConnection::did_receive_headers(Badge<Request>, Request& request)
  66. {
  67. IPC::Dictionary response_headers;
  68. for (auto& it : request.response_headers())
  69. response_headers.add(it.key, it.value);
  70. async_headers_became_available(request.id(), move(response_headers), request.status_code());
  71. }
  72. void ClientConnection::did_finish_request(Badge<Request>, Request& request, bool success)
  73. {
  74. VERIFY(request.total_size().has_value());
  75. async_request_finished(request.id(), success, request.total_size().value());
  76. m_requests.remove(request.id());
  77. }
  78. void ClientConnection::did_progress_request(Badge<Request>, Request& request)
  79. {
  80. async_request_progress(request.id(), request.total_size(), request.downloaded_size());
  81. }
  82. void ClientConnection::did_request_certificates(Badge<Request>, Request& request)
  83. {
  84. async_certificate_requested(request.id());
  85. }
  86. Messages::RequestServer::SetCertificateResponse ClientConnection::set_certificate(i32 request_id, String const& certificate, String const& key)
  87. {
  88. auto* request = const_cast<Request*>(m_requests.get(request_id).value_or(nullptr));
  89. bool success = false;
  90. if (request) {
  91. request->set_certificate(certificate, key);
  92. success = true;
  93. }
  94. return success;
  95. }
  96. void ClientConnection::ensure_connection(URL const& url, ::RequestServer::CacheLevel const& cache_level)
  97. {
  98. if (!url.is_valid()) {
  99. dbgln("EnsureConnection: Invalid URL requested: '{}'", url);
  100. return;
  101. }
  102. if (cache_level == CacheLevel::ResolveOnly) {
  103. return Core::deferred_invoke([host = url.host()] {
  104. dbgln("EnsureConnection: DNS-preload for {}", host);
  105. (void)gethostbyname(host.characters());
  106. });
  107. }
  108. struct {
  109. URL m_url;
  110. void start(Core::Stream::Socket& socket)
  111. {
  112. auto is_connected = socket.is_open();
  113. VERIFY(is_connected);
  114. ConnectionCache::request_did_finish(m_url, &socket);
  115. }
  116. void fail(Core::NetworkJob::Error error)
  117. {
  118. dbgln("Pre-connect to {} failed: {}", m_url, Core::to_string(error));
  119. }
  120. } job { url };
  121. dbgln("EnsureConnection: Pre-connect to {}", url);
  122. auto do_preconnect = [&](auto& cache) {
  123. auto it = cache.find({ url.host(), url.port_or_default() });
  124. if (it == cache.end() || it->value->is_empty())
  125. ConnectionCache::get_or_create_connection(cache, url, job);
  126. };
  127. if (url.scheme() == "http"sv)
  128. do_preconnect(ConnectionCache::g_tcp_connection_cache);
  129. else if (url.scheme() == "https"sv)
  130. do_preconnect(ConnectionCache::g_tls_connection_cache);
  131. else
  132. dbgln("EnsureConnection: Invalid URL scheme: '{}'", url.scheme());
  133. }
  134. }