ClientConnection.cpp 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. /*
  2. * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <AK/Badge.h>
  7. #include <RequestServer/ClientConnection.h>
  8. #include <RequestServer/Protocol.h>
  9. #include <RequestServer/Request.h>
  10. #include <RequestServer/RequestClientEndpoint.h>
  11. #include <netdb.h>
  12. namespace RequestServer {
  13. static HashMap<int, RefPtr<ClientConnection>> s_connections;
  14. ClientConnection::ClientConnection(NonnullOwnPtr<Core::Stream::LocalSocket> socket)
  15. : IPC::ClientConnection<RequestClientEndpoint, RequestServerEndpoint>(*this, move(socket), 1)
  16. {
  17. s_connections.set(1, *this);
  18. }
  19. ClientConnection::~ClientConnection()
  20. {
  21. }
  22. void ClientConnection::die()
  23. {
  24. s_connections.remove(client_id());
  25. if (s_connections.is_empty())
  26. Core::EventLoop::current().quit(0);
  27. }
  28. Messages::RequestServer::IsSupportedProtocolResponse ClientConnection::is_supported_protocol(String const& protocol)
  29. {
  30. bool supported = Protocol::find_by_name(protocol.to_lowercase());
  31. return supported;
  32. }
  33. Messages::RequestServer::StartRequestResponse ClientConnection::start_request(String const& method, URL const& url, IPC::Dictionary const& request_headers, ByteBuffer const& request_body)
  34. {
  35. if (!url.is_valid()) {
  36. dbgln("StartRequest: Invalid URL requested: '{}'", url);
  37. return { -1, Optional<IPC::File> {} };
  38. }
  39. auto* protocol = Protocol::find_by_name(url.protocol());
  40. if (!protocol) {
  41. dbgln("StartRequest: No protocol handler for URL: '{}'", url);
  42. return { -1, Optional<IPC::File> {} };
  43. }
  44. auto request = protocol->start_request(*this, method, url, request_headers.entries(), request_body);
  45. if (!request) {
  46. dbgln("StartRequest: Protocol handler failed to start request: '{}'", url);
  47. return { -1, Optional<IPC::File> {} };
  48. }
  49. auto id = request->id();
  50. auto fd = request->request_fd();
  51. m_requests.set(id, move(request));
  52. return { id, IPC::File(fd, IPC::File::CloseAfterSending) };
  53. }
  54. Messages::RequestServer::StopRequestResponse ClientConnection::stop_request(i32 request_id)
  55. {
  56. auto* request = const_cast<Request*>(m_requests.get(request_id).value_or(nullptr));
  57. bool success = false;
  58. if (request) {
  59. request->stop();
  60. m_requests.remove(request_id);
  61. success = true;
  62. }
  63. return success;
  64. }
  65. void ClientConnection::did_receive_headers(Badge<Request>, Request& request)
  66. {
  67. IPC::Dictionary response_headers;
  68. for (auto& it : request.response_headers())
  69. response_headers.add(it.key, it.value);
  70. async_headers_became_available(request.id(), move(response_headers), request.status_code());
  71. }
  72. void ClientConnection::did_finish_request(Badge<Request>, Request& request, bool success)
  73. {
  74. VERIFY(request.total_size().has_value());
  75. async_request_finished(request.id(), success, request.total_size().value());
  76. m_requests.remove(request.id());
  77. }
  78. void ClientConnection::did_progress_request(Badge<Request>, Request& request)
  79. {
  80. async_request_progress(request.id(), request.total_size(), request.downloaded_size());
  81. }
  82. void ClientConnection::did_request_certificates(Badge<Request>, Request& request)
  83. {
  84. async_certificate_requested(request.id());
  85. }
  86. Messages::RequestServer::SetCertificateResponse ClientConnection::set_certificate(i32 request_id, String const& certificate, String const& key)
  87. {
  88. auto* request = const_cast<Request*>(m_requests.get(request_id).value_or(nullptr));
  89. bool success = false;
  90. if (request) {
  91. request->set_certificate(certificate, key);
  92. success = true;
  93. }
  94. return success;
  95. }
// Prepares a connection to `url` ahead of an actual request.
//
// - Invalid URLs are logged and ignored.
// - With CacheLevel::ResolveOnly, only a deferred DNS lookup of the host is
//   scheduled. Its result is discarded; presumably the point is to warm the
//   resolver cache.
// - Otherwise, "http" and "https" URLs get a connection from the matching
//   ConnectionCache (TCP or TLS), but only when the cache has no entry for
//   (host, port_or_default()) or that entry is empty. Other schemes are logged
//   and ignored.
void ClientConnection::ensure_connection(URL const& url, ::RequestServer::CacheLevel const& cache_level)
{
    if (!url.is_valid()) {
        dbgln("EnsureConnection: Invalid URL requested: '{}'", url);
        return;
    }

    if (cache_level == CacheLevel::ResolveOnly) {
        // The host is captured by value because the deferred callback may run
        // after this function has returned.
        return Core::deferred_invoke([host = url.host()] {
            dbgln("EnsureConnection: DNS-preload for {}", host);
            (void)gethostbyname(host.characters());
        });
    }

    // Connection-start job handed to ConnectionCache::get_or_create_connection.
    // NOTE(review): presumably the cache calls start() with a freshly created
    // socket when it needs a new connection; confirm in ConnectionCache.
    struct {
        // Refers to this function's `url` parameter, so it is only valid while
        // ensure_connection() runs. The callbacks below copy it (`url = m_url`)
        // rather than keeping the reference.
        // NOTE(review): start() itself reads m_url, so it must be called
        // synchronously from get_or_create_connection. Confirm.
        URL const& m_url;

        // Starts connecting `socket` to m_url's host and port. It arranges for
        // ConnectionCache::request_did_finish(url, socket) to run in three
        // cases: once the connection is usable, on a TLS error, or immediately
        // if connect() fails.
        void start(NonnullRefPtr<Core::Socket> socket)
        {
            auto is_tls = is<TLS::TLSv12>(*socket);
            auto* tls_instance = is_tls ? static_cast<TLS::TLSv12*>(socket.ptr()) : nullptr;

            // The socket given to this job must not already be connected
            // (for TLS, the session must not already be established).
            auto is_connected = false;
            if (is_tls && tls_instance->is_established())
                is_connected = true;
            if (!is_tls && socket->is_connected())
                is_connected = true;

            VERIFY(!is_connected);

            bool did_connect;
            // The callbacks below hold a raw socket pointer (socket.ptr()).
            // NOTE(review): this relies on the socket outliving the callbacks,
            // presumably because the connection cache owns it. Confirm.
            if (is_tls) {
                tls_instance->set_root_certificates(DefaultRootCACertificates::the().certificates());
                // Once the TLS handshake completes, report the socket at the
                // next ready-to-write notification.
                // NOTE(review): that handler is never reset here. Confirm
                // whether TLSv12 can invoke it more than once.
                tls_instance->on_tls_connected = [socket = socket.ptr(), url = m_url, tls_instance] {
                    tls_instance->set_on_tls_ready_to_write([socket, url](auto&) {
                        ConnectionCache::request_did_finish(url, socket);
                    });
                };
                tls_instance->on_tls_error = [socket = socket.ptr(), url = m_url](auto) {
                    ConnectionCache::request_did_finish(url, socket);
                };
                did_connect = tls_instance->connect(m_url.host(), m_url.port_or_default());
            } else {
                socket->on_connected = [socket = socket.ptr(), url = m_url]() mutable {
                    ConnectionCache::request_did_finish(url, socket);
                };
                did_connect = socket->connect(m_url.host(), m_url.port_or_default());
            }

            // connect() failed outright, so no callback will fire.
            // Report the socket immediately.
            if (!did_connect)
                ConnectionCache::request_did_finish(m_url, socket);
        }
    } job { url };

    // This is logged before the scheme is checked, so it also appears when the
    // scheme turns out to be unsupported below.
    dbgln("EnsureConnection: Pre-connect to {}", url);

    // Request a connection only when the cache has no entry for (host, port),
    // or that entry is empty.
    auto do_preconnect = [&](auto& cache) {
        auto it = cache.find({ url.host(), url.port_or_default() });
        if (it == cache.end() || it->value->is_empty())
            ConnectionCache::get_or_create_connection(cache, url, job);
    };

    if (url.scheme() == "http"sv)
        do_preconnect(ConnectionCache::g_tcp_connection_cache);
    else if (url.scheme() == "https"sv)
        do_preconnect(ConnectionCache::g_tls_connection_cache);
    else
        dbgln("EnsureConnection: Invalid URL scheme: '{}'", url.scheme());
}
  155. }