Merge branch 'tls_wip'

This commit is contained in:
loonycyborg 2021-05-18 11:31:51 +03:00
commit e5ac78527f
No known key found for this signature in database
GPG key ID: 6E8233FAB8F26D61
18 changed files with 618 additions and 294 deletions

View file

@ -54,13 +54,7 @@ option(ENABLE_TESTS "Build unit tests")
option(ENABLE_NLS "Enable building of translations" ${ENABLE_GAME})
option(ENABLE_HISTORY "Enable using GNU history for history in lua console" ON)
# boost::asio::post is new with 1.66
# Ubuntu 18.04 also only has 1.65
if(ENABLE_MYSQL)
set(BOOST_VERSION "1.66")
else()
set(BOOST_VERSION "1.65")
endif(ENABLE_MYSQL)
set(BOOST_VERSION "1.66")
# set what std version to use
if(NOT CXX_STD)
@ -77,6 +71,7 @@ if(NOT APPLE)
else()
set(OPENSSL_CRYPTO_LIBRARY "-framework Security")
endif()
find_package(OpenSSL REQUIRED)
find_package(Boost ${BOOST_VERSION} REQUIRED COMPONENTS iostreams program_options regex system thread random coroutine locale filesystem)
find_package(ICU REQUIRED COMPONENTS data i18n uc)

View file

@ -183,13 +183,7 @@ if env['distcc']:
if env['ccache']: env.Tool('ccache')
# boost::asio::post is new with 1.66
# Ubuntu 18.04 also only has 1.65
if env["forum_user_handler"]:
boost_version = "1.66"
else:
boost_version = "1.65"
boost_version = "1.66"
def SortHelpText(a, b):
return (a > b) - (a < b)
@ -386,7 +380,7 @@ if env["prereqs"]:
if(env["PLATFORM"] != 'darwin'):
# Otherwise, use Security.framework
have_server_prereqs = have_server_prereqs & conf.CheckLib("libcrypto")
have_server_prereqs = have_server_prereqs & conf.CheckLib("libcrypto") & conf.CheckLib("ssl")
env = conf.Finish()

View file

@ -181,6 +181,7 @@ if(ENABLE_GAME)
wesnoth-common
${game-external-libs}
OpenSSL::Crypto
OpenSSL::SSL
Boost::iostreams
Boost::program_options
Boost::regex
@ -219,6 +220,7 @@ if(ENABLE_TESTS)
wesnoth-common
${game-external-libs}
OpenSSL::Crypto
OpenSSL::SSL
Boost::iostreams
Boost::program_options
Boost::regex
@ -254,6 +256,7 @@ if(ENABLE_SERVER)
wesnoth-common
${server-external-libs}
OpenSSL::Crypto
OpenSSL::SSL
Boost::iostreams
Boost::program_options
Boost::regex
@ -294,6 +297,7 @@ if(ENABLE_CAMPAIGN_SERVER)
wesnoth-common
${server-external-libs}
OpenSSL::Crypto
OpenSSL::SSL
Boost::iostreams
Boost::program_options
Boost::regex

View file

@ -43,11 +43,7 @@ std::deque<boost::asio::const_buffer> split_buffer(boost::asio::streambuf::const
std::deque<boost::asio::const_buffer> buffers;
unsigned int remaining_size = boost::asio::buffer_size(source_buffer);
#if BOOST_VERSION >= 106600
const uint8_t* data = static_cast<const uint8_t*>(source_buffer.data());
#else
const uint8_t* data = boost::asio::buffer_cast<const uint8_t*>(source_buffer);
#endif
while(remaining_size > 0u) {
unsigned int size = std::min(remaining_size, chunk_size);
@ -66,8 +62,11 @@ using boost::system::system_error;
connection::connection(const std::string& host, const std::string& service)
: io_context_()
, host_(host)
, service_(service)
, resolver_(io_context_)
, socket_(io_context_)
, use_tls_(true)
, socket_(raw_socket(new raw_socket::element_type{io_context_}))
, done_(false)
, write_buf_()
, read_buf_()
@ -78,23 +77,38 @@ connection::connection(const std::string& host, const std::string& service)
, bytes_to_read_(0)
, bytes_read_(0)
{
#if BOOST_VERSION >= 106600
resolver_.async_resolve(host, service,
#else
resolver_.async_resolve(boost::asio::ip::tcp::resolver::query(host, service),
#endif
std::bind(&connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2));
boost::system::error_code ec;
auto result = resolver_.resolve(host, service, boost::asio::ip::resolver_query_base::numeric_host, ec);
if(!ec) { // if numeric resolve succeeds then we got raw ip address so TLS host name validation would never pass
use_tls_ = false;
boost::asio::post(io_context_, [this, ec, result](){ handle_resolve(ec, { result } ); } );
} else {
resolver_.async_resolve(host, service,
std::bind(&connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2));
}
LOG_NW << "Resolving hostname: " << host << '\n';
}
connection::~connection()
{
if(auto socket = utils::get_if<tls_socket>(&socket_)) {
boost::system::error_code ec;
// this sends close_notify for secure connection shutdown
(*socket)->async_shutdown([](const boost::system::error_code&) {} );
const char buffer[] = "";
// this write is needed to trigger immediate close instead of waiting for other side's close_notify
boost::asio::write(**socket, boost::asio::buffer(buffer, 0), ec);
}
}
void connection::handle_resolve(const boost::system::error_code& ec, results_type results)
{
if(ec) {
throw system_error(ec);
}
boost::asio::async_connect(socket_, results,
boost::asio::async_connect(*utils::get<raw_socket>(socket_), results,
std::bind(&connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2));
}
@ -104,11 +118,11 @@ void connection::handle_connect(const boost::system::error_code& ec, endpoint en
ERR_NW << "Tried all IPs. Giving up" << std::endl;
throw system_error(ec);
} else {
#if BOOST_VERSION >= 106600
LOG_NW << "Connected to " << endpoint.address() << '\n';
#else
LOG_NW << "Connected to " << endpoint->endpoint().address() << '\n';
#endif
if(endpoint.address().is_loopback()) {
use_tls_ = false;
}
handshake();
}
}
@ -116,30 +130,87 @@ void connection::handle_connect(const boost::system::error_code& ec, endpoint en
void connection::handshake()
{
static const uint32_t handshake = 0;
static const uint32_t tls_handshake = htonl(uint32_t(1));
boost::asio::async_write(socket_, boost::asio::buffer(reinterpret_cast<const char*>(&handshake), 4),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
boost::asio::async_write(
*utils::get<raw_socket>(socket_),
boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2)
);
boost::asio::async_read(socket_, boost::asio::buffer(&handshake_response_.binary, 4),
boost::asio::async_read(*utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
std::bind(&connection::handle_handshake, this, std::placeholders::_1));
}
void connection::handle_handshake(const boost::system::error_code& ec)
{
if(ec) {
if(ec == boost::asio::error::eof && use_tls_) {
// immediate disconnect likely means old server not supporting TLS handshake code
fallback_to_unencrypted();
return;
}
throw system_error(ec);
}
done_ = true;
if(use_tls_) {
if(handshake_response_.num == 0xFFFFFFFFU) {
use_tls_ = false;
handle_handshake(ec);
return;
}
if(handshake_response_.num == 0x00000000) {
tls_context_.set_default_verify_paths();
raw_socket s { std::move(utils::get<raw_socket>(socket_)) };
tls_socket ts { new tls_socket::element_type { std::move(*s), tls_context_ } };
socket_ = std::move(ts);
auto& socket { *utils::get<tls_socket>(socket_) };
socket.set_verify_mode(
boost::asio::ssl::verify_peer |
boost::asio::ssl::verify_fail_if_no_peer_cert
);
#if BOOST_VERSION >= 107300
socket.set_verify_callback(boost::asio::ssl::host_name_verification(host_));
#else
socket.set_verify_callback(boost::asio::ssl::rfc2818_verification(host_));
#endif
socket.async_handshake(boost::asio::ssl::stream_base::client, [this](const boost::system::error_code& ec) {
if(ec) {
throw system_error(ec);
}
done_ = true;
});
return;
}
fallback_to_unencrypted();
} else {
done_ = true;
}
}
void connection::fallback_to_unencrypted()
{
assert(use_tls_ == true);
use_tls_ = false;
boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_)->remote_endpoint() };
utils::get<raw_socket>(socket_)->close();
utils::get<raw_socket>(socket_)->async_connect(endpoint,
std::bind(&connection::handle_connect, this, std::placeholders::_1, endpoint));
}
void connection::transfer(const config& request, config& response)
{
#if BOOST_VERSION >= 106600
io_context_.restart();
#else
io_context_.reset();
#endif
done_ = false;
write_buf_.reset(new boost::asio::streambuf);
@ -154,19 +225,22 @@ void connection::transfer(const config& request, config& response)
auto bufs = split_buffer(write_buf_->data());
bufs.push_front(boost::asio::buffer(reinterpret_cast<const char*>(&payload_size_), 4));
boost::asio::async_write(socket_, bufs,
std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
utils::visit([this, &bufs, &response](auto&& socket) {
boost::asio::async_write(*socket, bufs,
std::bind(&connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
boost::asio::async_read(socket_, *read_buf_,
std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response)));
boost::asio::async_read(*socket, *read_buf_,
std::bind(&connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&connection::handle_read, this, std::placeholders::_1, std::placeholders::_2, std::ref(response)));
}, socket_);
}
void connection::cancel()
{
if(socket_.is_open()) {
boost::system::error_code ec;
utils::visit([](auto&& socket) {
if(socket->lowest_layer().is_open()) {
boost::system::error_code ec;
#ifdef _MSC_VER
// Silence warning about boost::asio::basic_socket<Protocol>::cancel always
@ -174,15 +248,16 @@ void connection::cancel()
#pragma warning(push)
#pragma warning(disable:4996)
#endif
socket_.cancel(ec);
socket->lowest_layer().cancel(ec);
#ifdef _MSC_VER
#pragma warning(pop)
#endif
if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
}
}
}
}, socket_);
bytes_to_write_ = 0;
bytes_written_ = 0;
bytes_to_read_ = 0;

View file

@ -31,14 +31,12 @@
#endif
#include "exceptions.hpp"
#include "utils/variant.hpp"
#if BOOST_VERSION >= 106600
#include <boost/asio/io_context.hpp>
#else
#include <boost/asio/io_service.hpp>
#endif
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/streambuf.hpp>
#include <boost/asio/ssl.hpp>
class config;
@ -68,6 +66,7 @@ public:
* @param service Service identifier such as "80" or "http"
*/
connection(const std::string& host, const std::string& service);
~connection();
void transfer(const config& request, config& response);
@ -103,6 +102,14 @@ public:
return done_;
}
/** True if connection is currently using TLS and thus is allowed to send cleartext passwords or auth tokens */
bool using_tls() const
{
// Calling this function before connection is ready may return wrong result
assert(done_);
return utils::holds_alternative<tls_socket>(socket_);
}
std::size_t bytes_to_write() const
{
return bytes_to_write_;
@ -124,30 +131,28 @@ public:
}
private:
#if BOOST_VERSION >= 106600
boost::asio::io_context io_context_;
#else
boost::asio::io_service io_context_;
#endif
std::string host_;
const std::string service_;
typedef boost::asio::ip::tcp::resolver resolver;
resolver resolver_;
typedef boost::asio::ip::tcp::socket socket;
socket socket_;
boost::asio::ssl::context tls_context_ { boost::asio::ssl::context::sslv23 };
typedef std::unique_ptr<boost::asio::ip::tcp::socket> raw_socket;
typedef std::unique_ptr<boost::asio::ssl::stream<raw_socket::element_type>> tls_socket;
typedef utils::variant<raw_socket, tls_socket> any_socket;
bool use_tls_;
any_socket socket_;
bool done_;
std::unique_ptr<boost::asio::streambuf> write_buf_;
std::unique_ptr<boost::asio::streambuf> read_buf_;
#if BOOST_VERSION >= 106600
using results_type = resolver::results_type;
using endpoint = const boost::asio::ip::tcp::endpoint&;
#else
using results_type = resolver::iterator;
using endpoint = resolver::iterator;
#endif
void handle_resolve(const boost::system::error_code& ec, results_type results);
void handle_connect(const boost::system::error_code& ec, endpoint endpoint);
@ -157,6 +162,8 @@ private:
data_union handshake_response_;
void fallback_to_unencrypted();
std::size_t is_write_complete(const boost::system::error_code& error, std::size_t bytes_transferred);
void handle_write(const boost::system::error_code& ec, std::size_t bytes_transferred);

View file

@ -466,45 +466,60 @@ void server::load_config()
user_handler_.reset(new fuh(user_handler));
}
#endif
load_tls_config(cfg_);
}
std::ostream& operator<<(std::ostream& o, const server::request& r)
{
o << '[' << r.addr << ' ' << r.cmd << "] ";
o << '[' << (utils::holds_alternative<tls_socket_ptr>(r.sock) ? "+" : "") << r.addr << ' ' << r.cmd << "] ";
return o;
}
void server::handle_new_client(tls_socket_ptr socket)
{
boost::asio::spawn(io_service_, [this, socket](boost::asio::yield_context yield) {
serve_requests(socket, yield);
});
}
void server::handle_new_client(socket_ptr socket)
{
boost::asio::spawn(io_service_, [this, socket](boost::asio::yield_context yield) {
while(true) {
boost::system::error_code ec;
auto doc { coro_receive_doc(socket, yield[ec]) };
if(check_error(ec, socket) || !doc) return;
serve_requests(socket, yield);
});
}
config data;
read(data, doc->output());
template<class Socket>
void server::serve_requests(Socket socket, boost::asio::yield_context yield)
{
while(true) {
boost::system::error_code ec;
auto doc { coro_receive_doc(socket, yield[ec]) };
if(check_error(ec, socket) || !doc) return;
config::all_children_iterator i = data.ordered_begin();
config data;
read(data, doc->output());
if(i != data.ordered_end()) {
// We only handle the first child.
const config::any_child& c = *i;
config::all_children_iterator i = data.ordered_begin();
request_handlers_table::const_iterator j
= handlers_.find(c.key);
if(i != data.ordered_end()) {
// We only handle the first child.
const config::any_child& c = *i;
if(j != handlers_.end()) {
// Call the handler.
request req{c.key, c.cfg, socket, yield};
auto st = service_timer(req);
j->second(this, req);
} else {
send_error("Unrecognized [" + c.key + "] request.",socket);
}
request_handlers_table::const_iterator j
= handlers_.find(c.key);
if(j != handlers_.end()) {
// Call the handler.
request req{c.key, c.cfg, socket, yield};
auto st = service_timer(req);
j->second(this, req);
} else {
send_error("Unrecognized [" + c.key + "] request.",socket);
}
}
});
}
}
#ifndef _WIN32
@ -803,24 +818,28 @@ bool server::ignore_address_stats(const std::string& addr) const
return false;
}
void server::send_message(const std::string& msg, socket_ptr sock)
void server::send_message(const std::string& msg, const any_socket_ptr& sock)
{
const auto& escaped_msg = simple_wml_escape(msg);
simple_wml::document doc;
doc.root().add_child("message").set_attr_dup("message", escaped_msg.c_str());
async_send_doc_queued(sock, doc);
utils::visit([this, &doc](auto&& sock) { async_send_doc_queued(sock, doc); }, sock);
}
void server::send_error(const std::string& msg, socket_ptr sock)
inline std::string client_address(const any_socket_ptr& sock) {
return utils::visit([](auto&& sock) { return client_address(sock); }, sock);
}
void server::send_error(const std::string& msg, const any_socket_ptr& sock)
{
ERR_CS << "[" << client_address(sock) << "] " << msg << '\n';
const auto& escaped_msg = simple_wml_escape(msg);
simple_wml::document doc;
doc.root().add_child("error").set_attr_dup("message", escaped_msg.c_str());
async_send_doc_queued(sock, doc);
utils::visit([this, &doc](auto&& sock) { async_send_doc_queued(sock, doc); }, sock);
}
void server::send_error(const std::string& msg, const std::string& extra_data, unsigned int status_code, socket_ptr sock)
void server::send_error(const std::string& msg, const std::string& extra_data, unsigned int status_code, const any_socket_ptr& sock)
{
const std::string& status_hex = formatter()
<< "0x" << std::setfill('0') << std::setw(2*sizeof(unsigned int)) << std::hex
@ -838,7 +857,7 @@ void server::send_error(const std::string& msg, const std::string& extra_data, u
err_cfg.set_attr_dup("extra_data", escaped_extra_data.c_str());
err_cfg.set_attr_dup("status_code", escaped_status_str.c_str());
async_send_doc_queued(sock, doc);
utils::visit([this, &doc](auto&& sock) { async_send_doc_queued(sock, doc); }, sock);
}
config& server::get_addon(const std::string& id)
@ -912,7 +931,7 @@ void server::handle_server_id(const server::request& req)
simple_wml::document doc(wml.c_str(), simple_wml::INIT_STATIC);
doc.compress();
async_send_doc_queued(req.sock, doc);
utils::visit([this, &doc](auto&& sock) { async_send_doc_queued(sock, doc); }, req.sock);
}
void server::handle_request_campaign_list(const server::request& req)
@ -1013,7 +1032,7 @@ void server::handle_request_campaign_list(const server::request& req)
simple_wml::document doc(wml.c_str(), simple_wml::INIT_STATIC);
doc.compress();
async_send_doc_queued(req.sock, doc);
utils::visit([this, &doc](auto&& sock) { async_send_doc_queued(sock, doc); }, req.sock);
}
void server::handle_request_campaign(const server::request& req)
@ -1125,9 +1144,11 @@ void server::handle_request_campaign(const server::request& req)
LOG_CS << req << "Sending add-on '" << name << "' version: " << from << " -> " << to << " (delta)\n";
boost::system::error_code ec;
coro_send_doc(req.sock, doc, req.yield[ec]);
if(check_error(ec, req.sock)) return;
if(utils::visit([this, &req, &doc](auto && sock) {
boost::system::error_code ec;
coro_send_doc(sock, doc, req.yield[ec]);
return check_error(ec, sock);
}, req.sock)) return;
full_pack_path.clear();
}
@ -1143,9 +1164,11 @@ void server::handle_request_campaign(const server::request& req)
}
LOG_CS << req << "Sending add-on '" << name << "' version: " << to << " size: " << full_pack_size / 1024 << " KiB\n";
boost::system::error_code ec;
coro_send_file(req.sock, full_pack_path, req.yield[ec]);
if(check_error(ec, req.sock)) return;
if(utils::visit([this, &req, &full_pack_path](auto&& socket) {
boost::system::error_code ec;
coro_send_file(socket, full_pack_path, req.yield[ec]);
return check_error(ec, socket);
}, req.sock)) return;
}
// Clients doing upgrades or some other specific thing shouldn't bump
@ -1197,9 +1220,11 @@ void server::handle_request_campaign_hash(const server::request& req)
}
LOG_CS << req << "Sending add-on hash index for '" << req.cfg["name"] << "' size: " << file_size / 1024 << " KiB\n";
boost::system::error_code ec;
coro_send_file(req.sock, path, req.yield[ec]);
if(check_error(ec, req.sock)) return;
if(utils::visit([this, &path, &req](auto&& socket) {
boost::system::error_code ec;
coro_send_file(socket, path, req.yield[ec]);
return check_error(ec, socket);
}, req.sock)) return;
}
}

View file

@ -57,7 +57,7 @@ public:
const std::string& cmd;
const config& cfg;
const socket_ptr sock;
const any_socket_ptr sock;
const std::string addr;
/**
@ -80,14 +80,15 @@ public:
* TO A CONST OBJECT, since some code may modify it directly for
* performance reasons.
*/
template<class Socket>
request(const std::string& reqcmd,
config& reqcfg,
socket_ptr reqsock,
Socket reqsock,
boost::asio::yield_context yield)
: cmd(reqcmd)
, cfg(reqcfg)
, sock(reqsock)
, addr(client_address(sock))
, addr(client_address(reqsock))
, yield(yield)
{}
};
@ -138,6 +139,10 @@ private:
boost::asio::basic_waitable_timer<std::chrono::steady_clock> flush_timer_;
void handle_new_client(socket_ptr socket);
void handle_new_client(tls_socket_ptr socket);
template<class Socket>
void serve_requests(Socket socket, boost::asio::yield_context yield);
#ifndef _WIN32
void handle_read_from_fifo(const boost::system::error_code& error, std::size_t bytes_transferred);
@ -240,7 +245,7 @@ private:
* The WML sent consists of a document containing a single @p [message]
* child with a @a message attribute holding the value of @a msg.
*/
void send_message(const std::string& msg, socket_ptr sock);
void send_message(const std::string& msg, const any_socket_ptr& sock);
/**
* Send a client an error message.
@ -250,7 +255,7 @@ private:
* sending the error to the client, a line with the client IP and message
* is recorded to the server log.
*/
void send_error(const std::string& msg, socket_ptr sock);
void send_error(const std::string& msg, const any_socket_ptr& sock);
/**
* Send a client an error message.
@ -262,7 +267,7 @@ private:
* addition to sending the error to the client, a line with the client IP
* and message is recorded to the server log.
*/
void send_error(const std::string& msg, const std::string& extra_data, unsigned int status_code, socket_ptr sock);
void send_error(const std::string& msg, const std::string& extra_data, unsigned int status_code, const any_socket_ptr& sock);
};
} // end namespace campaignd

View file

@ -15,6 +15,7 @@
#ifdef HAVE_MYSQLPP
#include "server/common/forum_user_handler.hpp"
#include "server/wesnothd/server.hpp"
#include "hash.hpp"
#include "log.hpp"
#include "config.hpp"
@ -207,10 +208,10 @@ std::string fuh::get_tournaments(){
return conn_.get_tournaments();
}
void fuh::async_get_and_send_game_history(boost::asio::io_service& io_service, server_base& s_base, socket_ptr player_socket, int player_id, int offset) {
boost::asio::post([this, &s_base, player_socket, player_id, offset, &io_service] {
boost::asio::post(io_service, [player_socket, &s_base, doc = conn_.get_game_history(player_id, offset)]{
s_base.async_send_doc_queued(player_socket, *doc);
void fuh::async_get_and_send_game_history(boost::asio::io_service& io_service, wesnothd::server& s, wesnothd::player_iterator player, int player_id, int offset) {
boost::asio::post([this, &s, player, player_id, offset, &io_service] {
boost::asio::post(io_service, [player, &s, doc = conn_.get_game_history(player_id, offset)]{
s.send_to_player(player, *doc);
});
});
}

View file

@ -125,12 +125,12 @@ public:
* The result is then posted back to the main boost::asio thread to be sent to the requesting player.
*
* @param io_service The boost io_service to use to post the query results back to the main boost::asio thread.
* @param s_base The server instance the player is connected to.
* @param player_socket The socket use to communicate with the player's client.
* @param s The server instance the player is connected to.
* @param player The player iterator used to communicate with the player's client.
* @param player_id The forum ID of the player to get the game history for.
* @param offset Where to start returning rows to the client from the query results.
*/
void async_get_and_send_game_history(boost::asio::io_service& io_service, server_base& s_base, socket_ptr player_socket, int player_id, int offset);
void async_get_and_send_game_history(boost::asio::io_service& io_service, wesnothd::server& s, wesnothd::player_iterator player, int player_id, int offset);
/**
* Inserts game related information.

View file

@ -15,6 +15,7 @@
#include "server/common/server_base.hpp"
#include "log.hpp"
#include "serialization/parser.hpp"
#include "filesystem.hpp"
#ifdef HAVE_CONFIG_H
@ -128,19 +129,50 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp::
DBG_SERVER << client_address(socket) << "\tnew connection tentatively accepted\n";
boost::shared_array<char> handshake(new char[4]);
async_read(*socket, boost::asio::buffer(handshake.get(), 4), yield[error]);
uint32_t protocol_version;
uint32_t handshake_response;
any_socket_ptr final_socket;
async_read(*socket, boost::asio::buffer(reinterpret_cast<std::byte*>(&protocol_version), 4), yield[error]);
if(check_error(error, socket))
return;
if(memcmp(handshake.get(), "\0\0\0\0", 4) != 0) {
ERR_SERVER << client_address(socket) << "\tincorrect handshake\n";
return;
switch(ntohl(protocol_version)) {
case 0:
async_write(*socket, boost::asio::buffer(handshake_response_.buf, 4), yield[error]);
if(check_error(error, socket)) return;
final_socket = socket;
break;
case 1:
if(!tls_enabled_) {
ERR_SERVER << client_address(socket) << "\tTLS requested by client but not enabled on server\n";
handshake_response = 0xFFFFFFFFU;
} else {
handshake_response = 0x00000000;
}
async_write(*socket, boost::asio::buffer(reinterpret_cast<const std::byte*>(&handshake_response), 4), yield[error]);
if(check_error(error, socket)) return;
if(!tls_enabled_) { // continue with unencrypted connection if TLS disabled
final_socket = socket;
break;
}
final_socket = tls_socket_ptr { new tls_socket_ptr::element_type(std::move(*socket), tls_context_) };
utils::get<tls_socket_ptr>(final_socket)->async_handshake(boost::asio::ssl::stream_base::server, yield[error]);
if(error) {
ERR_SERVER << "TLS handshake failed: " << error.message() << "\n";
return;
}
break;
default:
ERR_SERVER << client_address(socket) << "\tincorrect handshake\n";
return;
}
async_write(*socket, boost::asio::buffer(handshake_response_.buf, 4), yield[error]);
if(!check_error(error, socket)) {
utils::visit([this](auto&& socket) {
const std::string ip = client_address(socket);
const std::string reason = is_ip_banned(ip);
@ -153,10 +185,14 @@ void server_base::serve(boost::asio::yield_context yield, boost::asio::ip::tcp::
async_send_error(socket, "Too many connections from your IP.");
return;
} else {
DBG_SERVER << ip << "\tnew connection fully accepted\n";
if constexpr (utils::decayed_is_same<tls_socket_ptr, decltype(socket)>) {
DBG_SERVER << ip << "\tnew encrypted connection fully accepted\n";
} else {
DBG_SERVER << ip << "\tnew connection fully accepted\n";
}
this->handle_new_client(socket);
}
}
}, final_socket);
}
#ifndef _WIN32
@ -190,27 +226,28 @@ void server_base::run() {
}
}
std::string client_address(const socket_ptr socket)
template<class SocketPtr> std::string client_address(SocketPtr socket)
{
boost::system::error_code error;
std::string result = socket->remote_endpoint(error).address().to_string();
std::string result = socket->lowest_layer().remote_endpoint(error).address().to_string();
if(error)
return "<unknown address>";
else
return result;
}
bool check_error(const boost::system::error_code& error, socket_ptr socket)
template<class SocketPtr> bool check_error(const boost::system::error_code& error, SocketPtr socket)
{
if(error) {
if(error == boost::asio::error::eof)
LOG_SERVER << client_address(socket) << "\tconnection closed\n";
LOG_SERVER << log_address(socket) << "\tconnection closed\n";
else
ERR_SERVER << client_address(socket) << "\t" << error.message() << "\n";
ERR_SERVER << log_address(socket) << "\t" << error.message() << "\n";
return true;
}
return false;
}
template bool check_error<tls_socket_ptr>(const boost::system::error_code& error, tls_socket_ptr socket);
namespace {
@ -234,10 +271,10 @@ void info_table_into_simple_wml(simple_wml::document& doc, const std::string& pa
* @param doc
* @param yield The function will suspend on write operation using this yield context
*/
void server_base::coro_send_doc(socket_ptr socket, simple_wml::document& doc, boost::asio::yield_context yield)
template<class SocketPtr> void server_base::coro_send_doc(SocketPtr socket, simple_wml::document& doc, boost::asio::yield_context yield)
{
if(dump_wml) {
std::cout << "Sending WML to " << client_address(socket) << ": \n" << doc.output() << std::endl;
std::cout << "Sending WML to " << log_address(socket) << ": \n" << doc.output() << std::endl;
}
try {
@ -261,9 +298,39 @@ void server_base::coro_send_doc(socket_ptr socket, simple_wml::document& doc, bo
throw;
}
}
template void server_base::coro_send_doc<socket_ptr>(socket_ptr socket, simple_wml::document& doc, boost::asio::yield_context yield);
template void server_base::coro_send_doc<tls_socket_ptr>(tls_socket_ptr socket, simple_wml::document& doc, boost::asio::yield_context yield);
template<class SocketPtr> void coro_send_file_userspace(SocketPtr socket, const std::string& filename, boost::asio::yield_context yield)
{
std::size_t filesize { std::size_t(filesystem::file_size(filename)) };
union DataSize
{
uint32_t size;
char buf[4];
} data_size {};
data_size.size = htonl(filesize);
async_write(*socket, boost::asio::buffer(data_size.buf), yield);
auto ifs { filesystem::istream_file(filename) };
ifs->seekg(0);
while(ifs->good()) {
char buf[16384];
ifs->read(buf, sizeof(buf));
async_write(*socket, boost::asio::buffer(buf, ifs->gcount()), yield);
}
}
#ifdef HAVE_SENDFILE
void server_base::coro_send_file(tls_socket_ptr socket, const std::string& filename, boost::asio::yield_context yield)
{
// We fallback to userspace if using TLS socket because sendfile is not aware of TLS state
// TODO: keep in mind possibility of using KTLS instead. This seem to be available only in openssl3 branch for now
coro_send_file_userspace(socket, filename, yield);
}
void server_base::coro_send_file(socket_ptr socket, const std::string& filename, boost::asio::yield_context yield)
{
std::size_t filesize { std::size_t(filesystem::file_size(filename)) };
@ -321,9 +388,13 @@ void server_base::coro_send_file(socket_ptr socket, const std::string& filename,
#elif defined(_WIN32)
void server_base::coro_send_file(tls_socket_ptr socket, const std::string& filename, boost::asio::yield_context yield)
{
coro_send_file_userspace(socket, filename, yield);
}
void server_base::coro_send_file(socket_ptr socket, const std::string& filename, boost::asio::yield_context yield)
{
OVERLAPPED overlap;
std::vector<boost::asio::const_buffer> buffers;
@ -382,15 +453,19 @@ void server_base::coro_send_file(socket_ptr socket, const std::string& filename,
#else
void server_base::coro_send_file(tls_socket_ptr socket, const std::string& filename, boost::asio::yield_context yield)
{
coro_send_file_userspace(socket, filename, yield);
}
void server_base::coro_send_file(socket_ptr socket, const std::string& filename, boost::asio::yield_context yield)
{
// TODO: Implement this for systems without sendfile()
assert(false && "Not implemented yet");
coro_send_file_userspace(socket, filename, yield);
}
#endif
std::unique_ptr<simple_wml::document> server_base::coro_receive_doc(socket_ptr socket, boost::asio::yield_context yield)
template<class SocketPtr> std::unique_ptr<simple_wml::document> server_base::coro_receive_doc(SocketPtr socket, boost::asio::yield_context yield)
{
union DataSize
{
@ -403,13 +478,13 @@ std::unique_ptr<simple_wml::document> server_base::coro_receive_doc(socket_ptr s
if(size == 0) {
ERR_SERVER <<
client_address(socket) <<
log_address(socket) <<
"\treceived invalid packet with payload size 0" << std::endl;
return {};
}
if(size > simple_wml::document::document_size_limit) {
ERR_SERVER <<
client_address(socket) <<
log_address(socket) <<
"\treceived packet with payload size over size limit" << std::endl;
return {};
}
@ -422,18 +497,20 @@ std::unique_ptr<simple_wml::document> server_base::coro_receive_doc(socket_ptr s
return std::make_unique<simple_wml::document>(compressed_buf);
} catch (simple_wml::error& e) {
ERR_SERVER <<
client_address(socket) <<
log_address(socket) <<
"\tsimple_wml error in received data: " << e.message << std::endl;
async_send_error(socket, "Invalid WML received: " + e.message);
return {};
}
}
template std::unique_ptr<simple_wml::document> server_base::coro_receive_doc<socket_ptr>(socket_ptr socket, boost::asio::yield_context yield);
template std::unique_ptr<simple_wml::document> server_base::coro_receive_doc<tls_socket_ptr>(tls_socket_ptr socket, boost::asio::yield_context yield);
void server_base::async_send_doc_queued(socket_ptr socket, simple_wml::document& doc)
template<class SocketPtr> void server_base::async_send_doc_queued(SocketPtr socket, simple_wml::document& doc)
{
boost::asio::spawn(
io_service_, [this, doc_ptr = doc.clone(), socket](boost::asio::yield_context yield) mutable {
static std::map<socket_ptr, std::queue<std::unique_ptr<simple_wml::document>>> queues;
static std::map<SocketPtr, std::queue<std::unique_ptr<simple_wml::document>>> queues;
queues[socket].push(std::move(doc_ptr));
if(queues[socket].size() > 1) {
@ -449,7 +526,7 @@ void server_base::async_send_doc_queued(socket_ptr socket, simple_wml::document&
);
}
void server_base::async_send_error(socket_ptr socket, const std::string& msg, const char* error_code, const info_table& info)
template<class SocketPtr> void server_base::async_send_error(SocketPtr socket, const std::string& msg, const char* error_code, const info_table& info)
{
simple_wml::document doc;
doc.root().add_child("error").set_attr_dup("message", msg.c_str());
@ -460,8 +537,10 @@ void server_base::async_send_error(socket_ptr socket, const std::string& msg, co
async_send_doc_queued(socket, doc);
}
template void server_base::async_send_error<socket_ptr>(socket_ptr socket, const std::string& msg, const char* error_code, const info_table& info);
template void server_base::async_send_error<tls_socket_ptr>(tls_socket_ptr socket, const std::string& msg, const char* error_code, const info_table& info);
void server_base::async_send_warning(socket_ptr socket, const std::string& msg, const char* warning_code, const info_table& info)
template<class SocketPtr> void server_base::async_send_warning(SocketPtr socket, const std::string& msg, const char* warning_code, const info_table& info)
{
simple_wml::document doc;
doc.root().add_child("warning").set_attr_dup("message", msg.c_str());
@ -472,6 +551,25 @@ void server_base::async_send_warning(socket_ptr socket, const std::string& msg,
async_send_doc_queued(socket, doc);
}
template void server_base::async_send_warning<socket_ptr>(socket_ptr socket, const std::string& msg, const char* warning_code, const info_table& info);
template void server_base::async_send_warning<tls_socket_ptr>(tls_socket_ptr socket, const std::string& msg, const char* warning_code, const info_table& info);
void server_base::load_tls_config(const config& cfg)
{
tls_enabled_ = cfg["tls_enabled"].to_bool(false);
if(!tls_enabled_) return;
tls_context_.set_options(
boost::asio::ssl::context::default_workarounds
| boost::asio::ssl::context::no_sslv2
| boost::asio::ssl::context::no_sslv3
| boost::asio::ssl::context::single_dh_use
);
tls_context_.use_certificate_chain_file(cfg["tls_fullchain"].str());
tls_context_.use_private_key_file(cfg["tls_private_key"].str(), boost::asio::ssl::context::pem);
if(!cfg["tls_dh"].str().empty()) tls_context_.use_tmp_dh_file(cfg["tls_dh"].str());
}
// This is just here to get it to build without the deprecation_message function
#include "game_version.hpp"

View file

@ -22,6 +22,9 @@
#include "exceptions.hpp"
#include "server/common/simple_wml.hpp"
#include "utils/variant.hpp"
#include "utils/general.hpp"
#ifdef _WIN32
#include "serialization/unicode_cast.hpp"
#endif
@ -33,6 +36,7 @@
#endif
#include <boost/asio/signal_set.hpp>
#include <boost/asio/streambuf.hpp>
#include <boost/asio/ssl.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/shared_array.hpp>
@ -40,7 +44,11 @@
extern bool dump_wml;
class config;
typedef std::shared_ptr<boost::asio::ip::tcp::socket> socket_ptr;
typedef std::shared_ptr<boost::asio::ssl::stream<socket_ptr::element_type>> tls_socket_ptr;
typedef utils::variant<socket_ptr, tls_socket_ptr> any_socket_ptr;
struct server_shutdown : public game::error
{
@ -60,7 +68,7 @@ public:
* @param doc
* @param yield The function will suspend on write operation using this yield context
*/
void coro_send_doc(socket_ptr socket, simple_wml::document& doc, boost::asio::yield_context yield);
template<class SocketPtr> void coro_send_doc(SocketPtr socket, simple_wml::document& doc, boost::asio::yield_context yield);
/**
* Send contents of entire file directly to socket from within a coroutine
* @param socket
@ -68,12 +76,13 @@ public:
* @param yield The function will suspend on write operations using this yield context
*/
void coro_send_file(socket_ptr socket, const std::string& filename, boost::asio::yield_context yield);
void coro_send_file(tls_socket_ptr socket, const std::string& filename, boost::asio::yield_context yield);
/**
* Receive WML document from a coroutine
* @param socket
* @param yield The function will suspend on read operation using this yield context
*/
std::unique_ptr<simple_wml::document> coro_receive_doc(socket_ptr socket, boost::asio::yield_context yield);
template<class SocketPtr> std::unique_ptr<simple_wml::document> coro_receive_doc(SocketPtr socket, boost::asio::yield_context yield);
/**
* High level wrapper for sending a WML document
@ -83,18 +92,23 @@ public:
* @param socket
* @param doc Document to send. A copy of it will be made so there is no need to keep the reference live after the function returns.
*/
void async_send_doc_queued(socket_ptr socket, simple_wml::document& doc);
template<class SocketPtr> void async_send_doc_queued(SocketPtr socket, simple_wml::document& doc);
typedef std::map<std::string, std::string> info_table;
void async_send_error(socket_ptr socket, const std::string& msg, const char* error_code = "", const info_table& info = {});
void async_send_warning(socket_ptr socket, const std::string& msg, const char* warning_code = "", const info_table& info = {});
template<class SocketPtr> void async_send_error(SocketPtr socket, const std::string& msg, const char* error_code = "", const info_table& info = {});
template<class SocketPtr> void async_send_warning(SocketPtr socket, const std::string& msg, const char* warning_code = "", const info_table& info = {});
protected:
unsigned short port_;
bool keep_alive_;
boost::asio::io_service io_service_;
boost::asio::ssl::context tls_context_ { boost::asio::ssl::context::sslv23 };
bool tls_enabled_ { false };
boost::asio::ip::tcp::acceptor acceptor_v6_;
boost::asio::ip::tcp::acceptor acceptor_v4_;
void load_tls_config(const config& cfg);
void start_server();
void serve(boost::asio::yield_context yield, boost::asio::ip::tcp::acceptor& acceptor, boost::asio::ip::tcp::endpoint endpoint);
@ -104,6 +118,7 @@ protected:
} handshake_response_;
virtual void handle_new_client(socket_ptr socket) = 0;
virtual void handle_new_client(tls_socket_ptr socket) = 0;
virtual bool accepting_connections() const { return true; }
virtual std::string is_ip_banned(const std::string&) { return std::string(); }
@ -123,5 +138,6 @@ protected:
void handle_termination(const boost::system::error_code& error, int signal_number);
};
std::string client_address(socket_ptr socket);
bool check_error(const boost::system::error_code& error, socket_ptr socket);
template<class SocketPtr> std::string client_address(SocketPtr socket);
template<class SocketPtr> std::string log_address(SocketPtr socket) { return (utils::decayed_is_same<tls_socket_ptr, decltype(socket)> ? "+" : "") + client_address(socket); }
template<class SocketPtr> bool check_error(const boost::system::error_code& error, SocketPtr socket);

View file

@ -17,13 +17,19 @@
class config;
#include "exceptions.hpp"
#include "server/common/server_base.hpp"
#include <ctime>
#include <string>
#include <boost/asio/io_service.hpp>
#include "server/wesnothd/player_connection.hpp"
namespace wesnothd
{
class server;
}
/**
* An interface class to handle nick registration
* To activate it put a [user_handler] section into the
@ -135,7 +141,7 @@ public:
virtual std::string get_uuid() = 0;
virtual std::string get_tournaments() = 0;
virtual void async_get_and_send_game_history(boost::asio::io_service& io_service, server_base& s_base, socket_ptr player_socket, int player_id, int offset) =0;
virtual void async_get_and_send_game_history(boost::asio::io_service& io_service, wesnothd::server& s, wesnothd::player_iterator player, int player_id, int offset) =0;
virtual void db_insert_game_info(const std::string& uuid, int game_id, const std::string& version, const std::string& name, int reload, int observers, int is_public, int has_password) = 0;
virtual void db_update_game_end(const std::string& uuid, int game_id, const std::string& replay_location) = 0;
virtual void db_insert_game_player_info(const std::string& uuid, int game_id, const std::string& username, int side_number, int is_host, const std::string& faction, const std::string& version, const std::string& source, const std::string& current_user) = 0;

View file

@ -350,7 +350,7 @@ bool game::send_taken_side(simple_wml::document& cfg, const simple_wml::node* si
cfg.root().set_attr_dup("side", (*side)["side"]);
// Tell the host which side the new player should take.
server.async_send_doc_queued(owner_->socket(), cfg);
server.send_to_player(owner_, cfg);
return true;
}
@ -602,7 +602,7 @@ void game::change_controller(
// side_drop already.)
if(!player_left) {
response->root().child("change_controller")->set_attr("is_local", "yes");
server.async_send_doc_queued(player->socket(), *response.get());
server.send_to_player(player, *response.get());
}
}
@ -630,7 +630,7 @@ void game::notify_new_host()
cfg.root().add_child("host_transfer");
std::string message = owner_name + " has been chosen as the new host.";
server.async_send_doc_queued(owner_->socket(), cfg);
server.send_to_player(owner_, cfg);
send_and_record_server_message(message);
}
@ -772,7 +772,7 @@ void game::unmute_observer(const simple_wml::node& unmute, player_iterator unmut
void game::send_leave_game(player_iterator user) const
{
static simple_wml::document leave_game("[leave_game]\n[/leave_game]\n", simple_wml::INIT_COMPRESSED);
server.async_send_doc_queued(user->socket(), leave_game);
server.send_to_player(user, leave_game);
}
std::optional<player_iterator> game::kick_member(const simple_wml::node& kick, player_iterator kicker)
@ -986,7 +986,7 @@ bool game::process_turn(simple_wml::document& data, player_iterator user)
msg << "Removing illegal command '" << (*command).first_child().to_string() << "' from: " << username(user)
<< ". Current player is: " << username(*current_player()) << " (" << current_side_index_ + 1 << "/" << nsides_
<< ").";
LOG_GAME << msg.str() << " (socket: " << (*current_player())->socket() << ") (game id: " << id_ << ", " << db_id_ << ")\n";
LOG_GAME << msg.str() << " (game id: " << id_ << ", " << db_id_ << ")\n";
send_and_record_server_message(msg.str());
marked.push_back(index - marked.size());
@ -1187,7 +1187,7 @@ void game::handle_controller_choice(const simple_wml::node& req)
command.set_attr("dependent", "yes");
if(sides_[side_index]) {
server.async_send_doc_queued((*sides_[side_index])->socket(), *mdata);
server.send_to_player((*sides_[side_index]), *mdata);
}
change_controller_wml.set_attr("is_local", "no");
@ -1226,7 +1226,7 @@ void game::handle_choice(const simple_wml::node& data, player_iterator user)
}
DBG_GAME << "answering seed request " << request_id << " by player "
<< user->info().name() << "(" << user->socket() << ")" << std::endl;
<< user->info().name() << std::endl;
last_choice_request_id_ = request_id;
if(const simple_wml::node* rand = data.child("random_seed")) {
@ -1351,7 +1351,7 @@ void game::update_turn_data()
bool game::add_player(player_iterator player, bool observer)
{
if(is_member(player)) {
ERR_GAME << "ERROR: Player is already in this game. (socket: " << player->socket() << ")\n";
ERR_GAME << "ERROR: Player is already in this game.\n";
return false;
}
@ -1393,20 +1393,19 @@ bool game::add_player(player_iterator player, bool observer)
LOG_GAME
<< player->client_ip() << "\t" << user->info().name() << "\tjoined game:\t\""
<< name_ << "\" (" << id_ << ", " << db_id_ << ")" << (observer ? " as an observer" : "") << ". (socket: " << player->socket()
<< ")\n";
<< name_ << "\" (" << id_ << ", " << db_id_ << ")" << (observer ? " as an observer" : "") << ".\n";
user->info().mark_available(id_, name_);
user->info().set_status((observer) ? player::OBSERVING : player::PLAYING);
DBG_GAME << debug_player_info();
// Send the user the game data.
server.async_send_doc_queued(player->socket(), level_);
server.send_to_player(player, level_);
if(started_) {
// Tell this player that the game has started
static simple_wml::document start_game_doc("[start_game]\n[/start_game]\n", simple_wml::INIT_COMPRESSED);
server.async_send_doc_queued(player->socket(), start_game_doc);
server.send_to_player(player, start_game_doc);
// Send observer join of all the observers in the game to the new player
// only once the game started. The client forgets about it anyway otherwise.
@ -1435,7 +1434,7 @@ bool game::add_player(player_iterator player, bool observer)
bool game::remove_player(player_iterator player, const bool disconnect, const bool destruct)
{
if(!is_member(player)) {
ERR_GAME << "ERROR: User is not in this game. (socket: " << player->socket() << ")\n";
ERR_GAME << "ERROR: User is not in this game.\n";
return false;
}
@ -1462,8 +1461,7 @@ bool game::remove_player(player_iterator player, const bool disconnect, const bo
? " at turn: " + lexical_cast_default<std::string, std::size_t>(current_turn())
+ " with reason: '" + termination_reason() + "'"
: "")
<< (observer ? " as an observer" : "") << (disconnect ? " and disconnected" : "") << ". (socket: " << user->socket()
<< ")\n";
<< (observer ? " as an observer" : "") << (disconnect ? " and disconnected" : "") << ".\n";
if(game_ended && started_ && !(observer && destruct)) {
send_server_message_to_all(user->info().name() + " ended the game.", player);
@ -1529,7 +1527,7 @@ bool game::remove_player(player_iterator player, const bool disconnect, const bo
DBG_GAME << "*** sending side drop: \n" << drop.output() << std::endl;
server.async_send_doc_queued(owner_->socket(), drop);
server.send_to_player(owner_, drop);
}
if(ai_transfer) {
@ -1617,8 +1615,8 @@ void game::load_next_scenario(player_iterator user)
cfg_controller.set_attr("is_local", side_user == user ? "yes" : "no");
}
server.async_send_doc_queued(user->socket(), cfg_scenario);
server.async_send_doc_queued(user->socket(), doc_controllers);
server.send_to_player(user, cfg_scenario);
server.send_to_player(user, doc_controllers);
players_not_advanced_.erase(&*user);
@ -1634,7 +1632,7 @@ void game::send_to_players(simple_wml::document& data, const Container& players,
{
for(const auto& player : players) {
if(player != exclude) {
server.async_send_doc_queued(player->socket(), data);
server.send_to_player(player, data);
}
}
}
@ -1703,7 +1701,7 @@ void game::send_observerjoins(std::optional<player_iterator> player)
send_data(cfg, ob);
} else {
// Send to the (new) user.
server.async_send_doc_queued((*player)->socket(), cfg);
server.send_to_player(*player, cfg);
}
}
}
@ -1738,7 +1736,7 @@ void game::send_history(player_iterator player) const
auto doc = std::make_unique<simple_wml::document>(buf.c_str(), simple_wml::INIT_STATIC);
doc->compress();
server.async_send_doc_queued(player->socket(), *doc);
server.send_to_player(player, *doc);
history_.clear();
history_.push_back(std::move(doc));
@ -1911,8 +1909,7 @@ std::string game::debug_sides_info() const
<< "side " << (*s)["side"].to_int()
<< " :\t" << (*s)["controller"].to_string()
<< "\t, " << side_controllers_[(*s)["side"].to_int() - 1].to_cstring()
<< "\t( " << (*sides_[(*s)["side"].to_int() - 1])->socket()
<< ",\t" << (*s)["current_player"].to_string() << " )\n";
<< "\t( " << (*s)["current_player"].to_string() << " )\n";
}
return result.str();
@ -1971,7 +1968,7 @@ void game::send_server_message(const char* message, std::optional<player_iterato
}
if(player) {
server.async_send_doc_queued((*player)->socket(), doc);
server.send_to_player(*player, doc);
}
}

View file

@ -31,7 +31,8 @@ class game;
class player_record
{
public:
player_record(const socket_ptr socket, const player& player)
template<class SocketPtr>
player_record(const SocketPtr socket, const player& player)
: socket_(socket)
, player_(player)
, game_()
@ -39,7 +40,7 @@ public:
{
}
const socket_ptr socket() const
const any_socket_ptr socket() const
{
return socket_;
}
@ -70,7 +71,7 @@ public:
void enter_lobby();
private:
const socket_ptr socket_;
const any_socket_ptr socket_;
mutable player player_;
std::shared_ptr<game> game_;
std::string ip_address;
@ -84,7 +85,7 @@ namespace bmi = boost::multi_index;
using player_connections = bmi::multi_index_container<player_record, bmi::indexed_by<
bmi::ordered_unique<bmi::tag<socket_t>,
bmi::const_mem_fun<player_record, const socket_ptr, &player_record::socket>>,
bmi::const_mem_fun<player_record, const any_socket_ptr, &player_record::socket>>,
bmi::hashed_unique<bmi::tag<name_t>,
bmi::const_mem_fun<player_record, const std::string&, &player_record::name>>,
bmi::ordered_non_unique<bmi::tag<game_t>,

View file

@ -29,6 +29,7 @@
#include "serialization/string_utils.hpp"
#include "serialization/unicode.hpp"
#include <functional>
#include "utils/general.hpp"
#include "utils/iterable_pair.hpp"
#include "game_version.hpp"
@ -536,6 +537,8 @@ void server::load_config()
tournaments_ = user_handler_->get_tournaments();
}
#endif
load_tls_config(cfg_);
}
bool server::ip_exceeds_connection_limit(const std::string& ip) const
@ -546,7 +549,7 @@ bool server::ip_exceeds_connection_limit(const std::string& ip) const
std::size_t connections = 0;
for(const auto& player : player_connections_) {
if(client_address(player.socket()) == ip) {
if(player.client_ip() == ip) {
++connections;
}
}
@ -614,7 +617,13 @@ void server::handle_new_client(socket_ptr socket)
boost::asio::spawn(io_service_, [socket, this](boost::asio::yield_context yield) { login_client(yield, socket); });
}
void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
void server::handle_new_client(tls_socket_ptr socket)
{
boost::asio::spawn(io_service_, [socket, this](boost::asio::yield_context yield) { login_client(yield, socket); });
}
template<class SocketPtr>
void server::login_client(boost::asio::yield_context yield, SocketPtr socket)
{
boost::system::error_code ec;
@ -637,7 +646,7 @@ void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
std::bind(&utils::wildcard_string_match, client_version, std::placeholders::_1));
if(accepted_it != accepted_versions_.end()) {
LOG_SERVER << client_address(socket) << "\tplayer joined using accepted version " << client_version
LOG_SERVER << log_address(socket) << "\tplayer joined using accepted version " << client_version
<< ":\ttelling them to log in.\n";
coro_send_doc(socket, login_response_, yield[ec]);
if(check_error(ec, socket)) return;
@ -647,7 +656,7 @@ void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
// Check if it is a redirected version
for(const auto& redirect_version : redirected_versions_) {
if(utils::wildcard_string_match(client_version, redirect_version.first)) {
LOG_SERVER << client_address(socket) << "\tplayer joined using version " << client_version
LOG_SERVER << log_address(socket) << "\tplayer joined using version " << client_version
<< ":\tredirecting them to " << redirect_version.second["host"] << ":"
<< redirect_version.second["port"] << "\n";
@ -661,7 +670,7 @@ void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
}
}
LOG_SERVER << client_address(socket) << "\tplayer joined using unknown version " << client_version
LOG_SERVER << log_address(socket) << "\tplayer joined using unknown version " << client_version
<< ":\trejecting them\n";
// For compatibility with older clients
@ -673,7 +682,7 @@ void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
return;
}
} else {
LOG_SERVER << client_address(socket) << "\tclient didn't send its version: rejecting\n";
LOG_SERVER << log_address(socket) << "\tclient didn't send its version: rejecting\n";
return;
}
@ -718,7 +727,7 @@ void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
[this, socket, new_player](boost::asio::yield_context yield) { handle_player(yield, socket, new_player); }
);
LOG_SERVER << client_address(socket) << "\t" << username << "\thas logged on"
LOG_SERVER << log_address(socket) << "\t" << username << "\thas logged on"
<< (registered ? " to a registered account" : "") << "\n";
std::shared_ptr<game> last_sent;
@ -746,7 +755,7 @@ void server::login_client(boost::asio::yield_context yield, socket_ptr socket)
}
}
bool server::is_login_allowed(socket_ptr socket, const simple_wml::node* const login, const std::string& username, bool& registered, bool& is_moderator)
template<class SocketPtr> bool server::is_login_allowed(SocketPtr socket, const simple_wml::node* const login, const std::string& username, bool& registered, bool& is_moderator)
{
// Check if the username is valid (all alpha-numeric plus underscore and hyphen)
if(!utils::isvalid_username(username)) {
@ -833,7 +842,7 @@ bool server::is_login_allowed(socket_ptr socket, const simple_wml::node* const l
ban_reason += " (" + ban_duration + ")";
if(!is_moderator) {
LOG_SERVER << client_address(socket) << "\t" << username << "\tis banned by user_handler (" << ban_type_desc
LOG_SERVER << log_address(socket) << "\t" << username << "\tis banned by user_handler (" << ban_type_desc
<< ")\n";
if(auth_ban.duration) {
// Temporary ban
@ -844,7 +853,7 @@ bool server::is_login_allowed(socket_ptr socket, const simple_wml::node* const l
}
return false;
} else {
LOG_SERVER << client_address(socket) << "\t" << username << "\tis banned by user_handler (" << ban_type_desc
LOG_SERVER << log_address(socket) << "\t" << username << "\tis banned by user_handler (" << ban_type_desc
<< "), " << "ignoring due to moderator flag\n";
}
}
@ -866,8 +875,8 @@ bool server::is_login_allowed(socket_ptr socket, const simple_wml::node* const l
return true;
}
bool server::authenticate(
socket_ptr socket, const std::string& username, const std::string& password, bool name_taken, bool& registered)
template<class SocketPtr> bool server::authenticate(
SocketPtr socket, const std::string& username, const std::string& password, bool name_taken, bool& registered)
{
// Current login procedure for registered nicks is:
// - Client asks to log in with a particular nick
@ -956,7 +965,7 @@ bool server::authenticate(
}
// Log the failure
LOG_SERVER << client_address(socket) << "\t"
LOG_SERVER << log_address(socket) << "\t"
<< "Login attempt with incorrect password for nickname '" << username << "'.\n";
return false;
}
@ -973,7 +982,7 @@ bool server::authenticate(
return true;
}
void server::send_password_request(socket_ptr socket,
template<class SocketPtr> void server::send_password_request(SocketPtr socket,
const std::string& msg,
const std::string& user,
const char* error_code,
@ -1014,7 +1023,7 @@ void server::send_password_request(socket_ptr socket,
async_send_doc_queued(socket, doc);
}
void server::handle_player(boost::asio::yield_context yield, socket_ptr socket, const player& player_data)
template<class SocketPtr> void server::handle_player(boost::asio::yield_context yield, SocketPtr socket, const player& player_data)
{
if(lan_server_)
abort_lan_server_timer();
@ -1106,7 +1115,7 @@ void server::handle_whisper(player_iterator player, simple_wml::node& whisper)
simple_wml::INIT_COMPRESSED
);
async_send_doc_queued(player->socket(), data);
send_to_player(player, data);
return;
}
@ -1132,7 +1141,7 @@ void server::handle_whisper(player_iterator player, simple_wml::node& whisper)
const simple_wml::string_span& msg = trunc_whisper["message"];
chat_message::truncate_message(msg, trunc_whisper);
async_send_doc_queued(receiver_iter->socket(), cwhisper);
send_to_player(player_connections_.project<0>(receiver_iter), cwhisper);
}
void server::handle_query(player_iterator iter, simple_wml::node& query)
@ -1271,13 +1280,13 @@ void server::handle_create_game(player_iterator player, simple_wml::node& create
{
if(graceful_restart) {
static simple_wml::document leave_game_doc("[leave_game]\n[/leave_game]\n", simple_wml::INIT_COMPRESSED);
async_send_doc_queued(player->socket(), leave_game_doc);
send_to_player(player, leave_game_doc);
send_server_message(player,
"This server is shutting down. You aren't allowed to make new games. Please "
"reconnect to the new server.", "error");
async_send_doc_queued(player->socket(), games_and_users_list_);
send_to_player(player, games_and_users_list_);
return;
}
@ -1360,16 +1369,16 @@ void server::handle_join_game(player_iterator player, simple_wml::node& join)
if(!g) {
WRN_SERVER << player->client_ip() << "\t" << player->info().name()
<< "\tattempted to join unknown game:\t" << game_id << ".\n";
async_send_doc_queued(player->socket(), leave_game_doc);
send_to_player(player, leave_game_doc);
send_server_message(player, "Attempt to join unknown game.", "error");
async_send_doc_queued(player->socket(), games_and_users_list_);
send_to_player(player, games_and_users_list_);
return;
} else if(!g->level_init()) {
WRN_SERVER << player->client_ip() << "\t" << player->info().name()
<< "\tattempted to join uninitialized game:\t\"" << g->name() << "\" (" << game_id << ").\n";
async_send_doc_queued(player->socket(), leave_game_doc);
send_to_player(player, leave_game_doc);
send_server_message(player, "Attempt to join an uninitialized game.", "error");
async_send_doc_queued(player->socket(), games_and_users_list_);
send_to_player(player, games_and_users_list_);
return;
} else if(player->info().is_moderator()) {
// Admins are always allowed to join.
@ -1377,16 +1386,16 @@ void server::handle_join_game(player_iterator player, simple_wml::node& join)
DBG_SERVER << player->client_ip()
<< "\tReject banned player: " << player->info().name()
<< "\tfrom game:\t\"" << g->name() << "\" (" << game_id << ").\n";
async_send_doc_queued(player->socket(), leave_game_doc);
send_to_player(player, leave_game_doc);
send_server_message(player, "You are banned from this game.", "error");
async_send_doc_queued(player->socket(), games_and_users_list_);
send_to_player(player, games_and_users_list_);
return;
} else if(!g->password_matches(password)) {
WRN_SERVER << player->client_ip() << "\t" << player->info().name()
<< "\tattempted to join game:\t\"" << g->name() << "\" (" << game_id << ") with bad password\n";
async_send_doc_queued(player->socket(), leave_game_doc);
send_to_player(player, leave_game_doc);
send_server_message(player, "Incorrect password.", "error");
async_send_doc_queued(player->socket(), games_and_users_list_);
send_to_player(player, games_and_users_list_);
return;
}
@ -1395,13 +1404,13 @@ void server::handle_join_game(player_iterator player, simple_wml::node& join)
WRN_SERVER << player->client_ip() << "\t" << player->info().name()
<< "\tattempted to observe game:\t\"" << g->name() << "\" (" << game_id
<< ") which doesn't allow observers.\n";
async_send_doc_queued(player->socket(), leave_game_doc);
send_to_player(player, leave_game_doc);
send_server_message(player,
"Attempt to observe a game that doesn't allow observers. (You probably joined the "
"game shortly after it filled up.)", "error");
async_send_doc_queued(player->socket(), games_and_users_list_);
send_to_player(player, games_and_users_list_);
return;
}
@ -1533,7 +1542,7 @@ void server::handle_player_in_game(player_iterator p, simple_wml::document& data
// Everything below should only be processed if the game is already initialized.
} else if(!g.level_init()) {
WRN_SERVER << p->client_ip() << "\tReceived unknown data from: " << player.name()
<< " (socket:" << p->socket() << ") while the scenario wasn't yet initialized.\n"
<< " while the scenario wasn't yet initialized.\n"
<< data.output();
return;
// If the host is sending the next scenario data.
@ -1688,7 +1697,7 @@ void server::handle_player_in_game(player_iterator p, simple_wml::document& data
}
// Send the player who has quit the gamelist.
async_send_doc_queued(p->socket(), games_and_users_list_);
send_to_player(p, games_and_users_list_);
}
return;
@ -1763,7 +1772,7 @@ void server::handle_player_in_game(player_iterator p, simple_wml::document& data
send_to_lobby(gamelist_diff, p);
// Send the removed user the lobby game list.
async_send_doc_queued((*user)->socket(), games_and_users_list_);
send_to_player(*user, games_and_users_list_);
}
return;
@ -1834,7 +1843,7 @@ void server::handle_player_in_game(player_iterator p, simple_wml::document& data
if(player_id != 0) {
LOG_SERVER << "Querying game history requested by player `" << player.name() << "` for player id `" << player_id << "`." << std::endl;
user_handler_->async_get_and_send_game_history(io_service_, *this, p->socket(), player_id, offset);
user_handler_->async_get_and_send_game_history(io_service_, *this, p, player_id, offset);
}
}
return;
@ -1848,12 +1857,12 @@ void server::handle_player_in_game(player_iterator p, simple_wml::document& data
return;
}
WRN_SERVER << p->client_ip() << "\tReceived unknown data from: " << player.name() << " (socket:" << p->socket()
<< ") in game: \"" << g.name() << "\" (" << g.id() << ", " << g.db_id() << ")\n"
WRN_SERVER << p->client_ip() << "\tReceived unknown data from: " << player.name()
<< " in game: \"" << g.name() << "\" (" << g.id() << ", " << g.db_id() << ")\n"
<< data.output();
}
void server::send_server_message(socket_ptr socket, const std::string& message, const std::string& type)
template<class SocketPtr> void server::send_server_message(SocketPtr socket, const std::string& message, const std::string& type)
{
simple_wml::document server_message;
simple_wml::node& msg = server_message.root().add_child("message");
@ -1866,7 +1875,13 @@ void server::send_server_message(socket_ptr socket, const std::string& message,
void server::disconnect_player(player_iterator player)
{
player->socket()->shutdown(boost::asio::ip::tcp::socket::shutdown_receive);
utils::visit([](auto&& socket) {
if constexpr (utils::decayed_is_same<tls_socket_ptr, decltype(socket)>) {
socket->shutdown();
} else {
socket->lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_receive);
}
}, player->socket());
}
void server::remove_player(player_iterator iter)
@ -1918,7 +1933,7 @@ void server::send_to_lobby(simple_wml::document& data, std::optional<player_iter
for(const auto& p : player_connections_.get<game_t>().equal_range(0)) {
auto player { player_connections_.iterator_to(p) };
if(player != exclude) {
async_send_doc_queued(player->socket(), data);
send_to_player(player, data);
}
}
}
@ -2224,7 +2239,7 @@ void server::adminmsg_handler(
for(const auto& player : player_connections_) {
if(player.info().is_moderator()) {
++n;
async_send_doc_queued(player.socket(), data);
send_to_player(player_connections_.iterator_to(player), data);
}
}
@ -2279,7 +2294,7 @@ void server::pm_handler(
continue;
}
async_send_doc_queued(player.socket(), data);
send_to_player(player_connections_.iterator_to(player), data);
*out << "Message to " << receiver << " successfully sent.";
return;
}
@ -2587,7 +2602,7 @@ void server::kickban_handler(
for(auto user : users_to_kick) {
*out << "\nKicked " << user->info().name() << " (" << user->client_ip() << ").";
async_send_error(user->socket(), "You have been banned. Reason: " + reason);
utils::visit([this,reason](auto&& socket) { async_send_error(socket, "You have been banned. Reason: " + reason); }, user->socket());
disconnect_player(user);
}
}
@ -2729,7 +2744,7 @@ void server::kick_handler(const std::string& /*issuer_name*/,
*out << "Kicked " << player->name() << " (" << player->client_ip() << "). '"
<< kick_message << "'";
async_send_error(player->socket(), kick_message);
utils::visit([this, &kick_message](auto&& socket) { async_send_error(socket, kick_message); }, player->socket());
disconnect_player(player);
}
@ -2872,7 +2887,7 @@ void server::delete_game(int gameid, const std::string& reason)
if(make_change_diff(games_and_users_list_.root(), nullptr, "user", it->info().config_address(), udiff)) {
send_to_lobby(udiff);
} else {
ERR_SERVER << "ERROR: delete_game(): Could not find user in players_. (socket: " << it->socket() << ")\n";
ERR_SERVER << "ERROR: delete_game(): Could not find user in players_.\n";
}
}
@ -2887,14 +2902,15 @@ void server::delete_game(int gameid, const std::string& reason)
static simple_wml::document leave_game_doc("[leave_game]\n[/leave_game]\n", simple_wml::INIT_COMPRESSED);
for(const auto& it : range_vctor) {
player_iterator p { player_connections_.project<0>(it) };
if(reason != "") {
simple_wml::document leave_game_doc_reason("[leave_game]\n[/leave_game]\n", simple_wml::INIT_STATIC);
leave_game_doc_reason.child("leave_game")->set_attr_dup("reason", reason.c_str());
async_send_doc_queued(it->socket(), leave_game_doc_reason);
send_to_player(p, leave_game_doc_reason);
} else {
async_send_doc_queued(it->socket(), leave_game_doc);
send_to_player(p, leave_game_doc);
}
async_send_doc_queued(it->socket(), games_and_users_list_);
send_to_player(p, games_and_users_list_);
}
}

View file

@ -38,15 +38,16 @@ public:
private:
void handle_new_client(socket_ptr socket);
void handle_new_client(tls_socket_ptr socket);
void login_client(boost::asio::yield_context yield, socket_ptr socket);
bool is_login_allowed(socket_ptr socket, const simple_wml::node* const login, const std::string& username, bool& registered, bool& is_moderator);
bool authenticate(socket_ptr socket, const std::string& username, const std::string& password, bool name_taken, bool& registered);
void send_password_request(socket_ptr socket, const std::string& msg,
template<class SocketPtr> void login_client(boost::asio::yield_context yield, SocketPtr socket);
template<class SocketPtr> bool is_login_allowed(SocketPtr socket, const simple_wml::node* const login, const std::string& username, bool& registered, bool& is_moderator);
template<class SocketPtr> bool authenticate(SocketPtr socket, const std::string& username, const std::string& password, bool name_taken, bool& registered);
template<class SocketPtr> void send_password_request(SocketPtr socket, const std::string& msg,
const std::string& user, const char* error_code = "", bool force_confirmation = false);
bool accepting_connections() const { return !graceful_restart; }
void handle_player(boost::asio::yield_context yield, socket_ptr socket, const player& player);
template<class SocketPtr> void handle_player(boost::asio::yield_context yield, SocketPtr socket, const player& player);
void handle_player_in_lobby(player_iterator player, simple_wml::document& doc);
void handle_player_in_game(player_iterator player, simple_wml::document& doc);
void handle_whisper(player_iterator player, simple_wml::node& whisper);
@ -59,11 +60,21 @@ private:
void disconnect_player(player_iterator player);
void remove_player(player_iterator player);
void send_server_message(socket_ptr socket, const std::string& message, const std::string& type);
public:
template<class SocketPtr> void send_server_message(SocketPtr socket, const std::string& message, const std::string& type);
void send_server_message(player_iterator player, const std::string& message, const std::string& type) {
send_server_message(player->socket(), message, type);
utils::visit(
[this, &message, &type](auto&& socket) { send_server_message(socket, message, type); },
player->socket()
);
}
void send_to_lobby(simple_wml::document& data, std::optional<player_iterator> exclude = {});
void send_to_player(player_iterator player, simple_wml::document& data) {
utils::visit(
[this, &data](auto&& socket) { async_send_doc_queued(socket, data); },
player->socket()
);
}
void send_server_message_to_lobby(const std::string& message, std::optional<player_iterator> exclude = {});
void send_server_message_to_all(const std::string& message, std::optional<player_iterator> exclude = {});
@ -71,6 +82,7 @@ private:
return player->get_game() != nullptr;
}
private:
wesnothd::ban_manager ban_manager_;
struct connection_log
@ -103,7 +115,7 @@ private:
std::deque<login_log> failed_logins_;
std::unique_ptr<user_handler> user_handler_;
std::map<socket_ptr::element_type*, std::string> seeds_;
std::map<void*, std::string> seeds_;
std::mt19937 die_;

View file

@ -60,7 +60,11 @@ wesnothd_connection::wesnothd_connection(const std::string& host, const std::str
: worker_thread_()
, io_context_()
, resolver_(io_context_)
, socket_(io_context_)
, tls_context_(boost::asio::ssl::context::sslv23)
, host_(host)
, service_(service)
, use_tls_(true)
, socket_(raw_socket{ new raw_socket::element_type{io_context_} })
, last_error_()
, last_error_mutex_()
, handshake_finished_()
@ -76,12 +80,16 @@ wesnothd_connection::wesnothd_connection(const std::string& host, const std::str
, bytes_read_(0)
{
MPTEST_LOG;
#if BOOST_VERSION >= 106600
resolver_.async_resolve(host, service,
#else
resolver_.async_resolve(boost::asio::ip::tcp::resolver::query(host, service),
#endif
std::bind(&wesnothd_connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2));
error_code ec;
auto result = resolver_.resolve(host, service, boost::asio::ip::resolver_query_base::numeric_host, ec);
if(!ec) { // if numeric resolve succeeds then we got raw ip address so TLS host name validation would never pass
use_tls_ = false;
boost::asio::post(io_context_, [this, ec, result](){ handle_resolve(ec, { result } ); } );
} else {
resolver_.async_resolve(host, service,
std::bind(&wesnothd_connection::handle_resolve, this, std::placeholders::_1, std::placeholders::_2));
}
// Starts the worker thread. Do this *after* the above async_resolve call or it will just exit immediately!
worker_thread_ = std::thread([this]() {
@ -107,6 +115,14 @@ wesnothd_connection::~wesnothd_connection()
{
MPTEST_LOG;
if(auto socket = utils::get_if<tls_socket>(&socket_)) {
error_code ec;
// this sends close_notify for secure connection shutdown
(*socket)->async_shutdown([](const error_code&) {} );
const char buffer[] = "";
// this write is needed to trigger immediate close instead of waiting for other side's close_notify
boost::asio::write(**socket, boost::asio::buffer(buffer, 0), ec);
}
// Stop the io_service and wait for the worker thread to terminate.
stop();
worker_thread_.join();
@ -121,7 +137,7 @@ void wesnothd_connection::handle_resolve(const error_code& ec, results_type resu
throw system_error(ec);
}
boost::asio::async_connect(socket_, results,
boost::asio::async_connect(*utils::get<raw_socket>(socket_), results,
std::bind(&wesnothd_connection::handle_connect, this, std::placeholders::_1, std::placeholders::_2));
}
@ -133,11 +149,11 @@ void wesnothd_connection::handle_connect(const boost::system::error_code& ec, en
ERR_NW << "Tried all IPs. Giving up" << std::endl;
throw system_error(ec);
} else {
#if BOOST_VERSION >= 106600
LOG_NW << "Connected to " << endpoint.address() << '\n';
#else
LOG_NW << "Connected to " << endpoint->endpoint().address() << '\n';
#endif
if(endpoint.address().is_loopback()) {
use_tls_ = false;
}
handshake();
}
}
@ -147,11 +163,11 @@ void wesnothd_connection::handshake()
{
MPTEST_LOG;
static const uint32_t handshake = 0;
static const uint32_t tls_handshake = htonl(uint32_t(1));
boost::asio::async_write(socket_, boost::asio::buffer(reinterpret_cast<const char*>(&handshake), 4),
boost::asio::async_write(*utils::get<raw_socket>(socket_), boost::asio::buffer(use_tls_ ? reinterpret_cast<const char*>(&tls_handshake) : reinterpret_cast<const char*>(&handshake), 4),
[](const error_code& ec, std::size_t) { if(ec) { throw system_error(ec); } });
boost::asio::async_read(socket_, boost::asio::buffer(&handshake_response_.binary, 4),
boost::asio::async_read(*utils::get<raw_socket>(socket_), boost::asio::buffer(&handshake_response_.binary, 4),
std::bind(&wesnothd_connection::handle_handshake, this, std::placeholders::_1));
}
@ -160,12 +176,71 @@ void wesnothd_connection::handle_handshake(const error_code& ec)
{
MPTEST_LOG;
if(ec) {
if(ec == boost::asio::error::eof && use_tls_) {
// immediate disconnect likely means old server not supporting TLS handshake code
fallback_to_unencrypted();
return;
}
LOG_NW << __func__ << " Throwing: " << ec << "\n";
throw system_error(ec);
}
handshake_finished_.set_value();
recv();
if(use_tls_) {
if(handshake_response_.num == 0xFFFFFFFFU) {
use_tls_ = false;
handle_handshake(ec);
return;
}
if(handshake_response_.num == 0x00000000) {
tls_context_.set_default_verify_paths();
raw_socket s { std::move(utils::get<raw_socket>(socket_)) };
tls_socket ts { new tls_socket::element_type{std::move(*s), tls_context_} };
socket_ = std::move(ts);
auto& socket { *utils::get<tls_socket>(socket_) };
socket.set_verify_mode(
boost::asio::ssl::verify_peer |
boost::asio::ssl::verify_fail_if_no_peer_cert
);
#if BOOST_VERSION >= 107300
socket.set_verify_callback(boost::asio::ssl::host_name_verification(host_));
#else
socket.set_verify_callback(boost::asio::ssl::rfc2818_verification(host_));
#endif
socket.async_handshake(boost::asio::ssl::stream_base::client, [this](const error_code& ec) {
if(ec) {
LOG_NW << __func__ << " Throwing: " << ec << "\n";
throw system_error(ec);
}
handshake_finished_.set_value();
recv();
});
return;
}
fallback_to_unencrypted();
} else {
handshake_finished_.set_value();
recv();
}
}
// worker thread
void wesnothd_connection::fallback_to_unencrypted()
{
assert(use_tls_ == true);
use_tls_ = false;
boost::asio::ip::tcp::endpoint endpoint { utils::get<raw_socket>(socket_)->remote_endpoint() };
utils::get<raw_socket>(socket_)->close();
utils::get<raw_socket>(socket_)->async_connect(endpoint,
std::bind(&wesnothd_connection::handle_connect, this, std::placeholders::_1, endpoint));
}
// main thread
@ -195,21 +270,13 @@ void wesnothd_connection::send_data(const configr_of& request)
{
MPTEST_LOG;
#if BOOST_VERSION >= 106600
auto buf_ptr = std::make_unique<boost::asio::streambuf>();
#else
auto buf_ptr = std::make_shared<boost::asio::streambuf>();
#endif
std::ostream os(buf_ptr.get());
write_gz(os, request);
// No idea why io_context::post doesn't like this lambda while asio::post does.
#if BOOST_VERSION >= 106600
boost::asio::post(io_context_, [this, buf_ptr = std::move(buf_ptr)]() mutable {
#else
io_context_.post([this, buf_ptr]() {
#endif
DBG_NW << "In wesnothd_connection::send_data::lambda\n";
send_queue_.push(std::move(buf_ptr));
@ -223,8 +290,9 @@ void wesnothd_connection::send_data(const configr_of& request)
void wesnothd_connection::cancel()
{
MPTEST_LOG;
if(socket_.is_open()) {
boost::system::error_code ec;
utils::visit([](auto&& socket) {
if(socket->lowest_layer().is_open()) {
boost::system::error_code ec;
#ifdef _MSC_VER
// Silence warning about boost::asio::basic_socket<Protocol>::cancel always
@ -232,15 +300,16 @@ void wesnothd_connection::cancel()
#pragma warning(push)
#pragma warning(disable:4996)
#endif
socket_.cancel(ec);
socket->lowest_layer().cancel(ec);
#ifdef _MSC_VER
#pragma warning(pop)
#endif
if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
if(ec) {
WRN_NW << "Failed to cancel network operations: " << ec.message() << std::endl;
}
}
}
}, socket_);
}
// main thread
@ -384,9 +453,11 @@ void wesnothd_connection::send()
buf.data()
};
boost::asio::async_write(socket_, bufs,
std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
utils::visit([this, &bufs](auto&& socket) {
boost::asio::async_write(*socket, bufs,
std::bind(&wesnothd_connection::is_write_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_write, this, std::placeholders::_1, std::placeholders::_2));
}, socket_);
}
// worker thread
@ -394,9 +465,11 @@ void wesnothd_connection::recv()
{
MPTEST_LOG;
boost::asio::async_read(socket_, read_buf_,
std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2));
utils::visit([this](auto&& socket) {
boost::asio::async_read(*socket, read_buf_,
std::bind(&wesnothd_connection::is_read_complete, this, std::placeholders::_1, std::placeholders::_2),
std::bind(&wesnothd_connection::handle_read, this, std::placeholders::_1, std::placeholders::_2));
}, socket_);
}
// main thread

View file

@ -33,13 +33,10 @@
#include "configr_assign.hpp"
#include "wesnothd_connection_error.hpp"
#if BOOST_VERSION >= 106600
#include <boost/asio/io_context.hpp>
#else
#include <boost/asio/io_service.hpp>
#endif
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/streambuf.hpp>
#include <boost/asio/ssl.hpp>
#include <condition_variable>
#include <deque>
@ -101,6 +98,12 @@ public:
/** Waits until the server handshake is complete. */
void wait_for_handshake();
/** True if connection is currently using TLS and thus is allowed to send cleartext passwords or auth tokens */
bool using_tls() const
{
return utils::holds_alternative<tls_socket>(socket_);
}
void cancel();
void stop();
@ -138,17 +141,20 @@ public:
private:
std::thread worker_thread_;
#if BOOST_VERSION >= 106600
boost::asio::io_context io_context_;
#else
boost::asio::io_service io_context_;
#endif
typedef boost::asio::ip::tcp::resolver resolver;
resolver resolver_;
typedef boost::asio::ip::tcp::socket socket;
socket socket_;
boost::asio::ssl::context tls_context_;
std::string host_;
std::string service_;
typedef std::unique_ptr<boost::asio::ip::tcp::socket> raw_socket;
typedef std::unique_ptr<boost::asio::ssl::stream<raw_socket::element_type>> tls_socket;
typedef utils::variant<raw_socket, tls_socket> any_socket;
bool use_tls_;
any_socket socket_;
boost::system::error_code last_error_;
@ -158,13 +164,8 @@ private:
boost::asio::streambuf read_buf_;
#if BOOST_VERSION >= 106600
using results_type = resolver::results_type;
using endpoint = const boost::asio::ip::tcp::endpoint&;
#else
using results_type = resolver::iterator;
using endpoint = resolver::iterator;
#endif
void handle_resolve(const boost::system::error_code& ec, results_type results);
void handle_connect(const boost::system::error_code& ec, endpoint endpoint);
@ -174,6 +175,8 @@ private:
data_union handshake_response_;
void fallback_to_unencrypted();
std::size_t is_write_complete(const boost::system::error_code& error, std::size_t bytes_transferred);
void handle_write(const boost::system::error_code& ec, std::size_t bytes_transferred);
@ -186,11 +189,7 @@ private:
template<typename T>
using data_queue = std::queue<T, std::list<T>>;
#if BOOST_VERSION >= 106600
data_queue<std::unique_ptr<boost::asio::streambuf>> send_queue_;
#else
data_queue<std::shared_ptr<boost::asio::streambuf>> send_queue_;
#endif
data_queue<config> recv_queue_;
std::mutex recv_queue_mutex_;