RequestClient.cpp 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. /*
  2. * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
  3. *
  4. * SPDX-License-Identifier: BSD-2-Clause
  5. */
  6. #include <AK/FileStream.h>
  7. #include <LibProtocol/Request.h>
  8. #include <LibProtocol/RequestClient.h>
  9. namespace Protocol {
  10. RequestClient::RequestClient()
  11. : IPC::ServerConnection<RequestClientEndpoint, RequestServerEndpoint>(*this, "/tmp/portal/request")
  12. {
  13. handshake();
  14. }
  15. void RequestClient::handshake()
  16. {
  17. send_sync<Messages::RequestServer::Greet>();
  18. }
  19. bool RequestClient::is_supported_protocol(const String& protocol)
  20. {
  21. return send_sync<Messages::RequestServer::IsSupportedProtocol>(protocol)->supported();
  22. }
  23. template<typename RequestHashMapTraits>
  24. RefPtr<Request> RequestClient::start_request(const String& method, const String& url, const HashMap<String, String, RequestHashMapTraits>& request_headers, ReadonlyBytes request_body)
  25. {
  26. IPC::Dictionary header_dictionary;
  27. for (auto& it : request_headers)
  28. header_dictionary.add(it.key, it.value);
  29. auto response = send_sync<Messages::RequestServer::StartRequest>(method, url, header_dictionary, ByteBuffer::copy(request_body));
  30. auto request_id = response->request_id();
  31. if (request_id < 0 || !response->response_fd().has_value())
  32. return nullptr;
  33. auto response_fd = response->response_fd().value().take_fd();
  34. auto request = Request::create_from_id({}, *this, request_id);
  35. request->set_request_fd({}, response_fd);
  36. m_requests.set(request_id, request);
  37. return request;
  38. }
  39. bool RequestClient::stop_request(Badge<Request>, Request& request)
  40. {
  41. if (!m_requests.contains(request.id()))
  42. return false;
  43. return send_sync<Messages::RequestServer::StopRequest>(request.id())->success();
  44. }
  45. bool RequestClient::set_certificate(Badge<Request>, Request& request, String certificate, String key)
  46. {
  47. if (!m_requests.contains(request.id()))
  48. return false;
  49. return send_sync<Messages::RequestServer::SetCertificate>(request.id(), move(certificate), move(key))->success();
  50. }
  51. void RequestClient::handle(const Messages::RequestClient::RequestFinished& message)
  52. {
  53. RefPtr<Request> request;
  54. if ((request = m_requests.get(message.request_id()).value_or(nullptr))) {
  55. request->did_finish({}, message.success(), message.total_size());
  56. }
  57. m_requests.remove(message.request_id());
  58. }
  59. void RequestClient::handle(const Messages::RequestClient::RequestProgress& message)
  60. {
  61. if (auto request = const_cast<Request*>(m_requests.get(message.request_id()).value_or(nullptr))) {
  62. request->did_progress({}, message.total_size(), message.downloaded_size());
  63. }
  64. }
  65. void RequestClient::handle(const Messages::RequestClient::HeadersBecameAvailable& message)
  66. {
  67. if (auto request = const_cast<Request*>(m_requests.get(message.request_id()).value_or(nullptr))) {
  68. HashMap<String, String, CaseInsensitiveStringTraits> headers;
  69. message.response_headers().for_each_entry([&](auto& name, auto& value) { headers.set(name, value); });
  70. request->did_receive_headers({}, headers, message.status_code());
  71. }
  72. }
  73. Messages::RequestClient::CertificateRequestedResponse RequestClient::handle(const Messages::RequestClient::CertificateRequested& message)
  74. {
  75. if (auto request = const_cast<Request*>(m_requests.get(message.request_id()).value_or(nullptr))) {
  76. request->did_request_certificates({});
  77. }
  78. return {};
  79. }
  80. }
  81. template RefPtr<Protocol::Request> Protocol::RequestClient::start_request(const String& method, const String& url, const HashMap<String, String>& request_headers, ReadonlyBytes request_body);
  82. template RefPtr<Protocol::Request> Protocol::RequestClient::start_request(const String& method, const String& url, const HashMap<String, String, CaseInsensitiveStringTraits>& request_headers, ReadonlyBytes request_body);