diff --git a/include/bitcoin/network/net/socket.hpp b/include/bitcoin/network/net/socket.hpp index f84ebc86b..2426aef26 100644 --- a/include/bitcoin/network/net/socket.hpp +++ b/include/bitcoin/network/net/socket.hpp @@ -21,7 +21,7 @@ #include #include -#include +#include #include #include #include @@ -291,6 +291,8 @@ class BCT_API socket const count_handler& handler) NOEXCEPT; protected: + using transport = std::variant; + socket(const logger& log, asio::io_context& service, size_t maximum_request, const config::address& address, const config::endpoint& endpoint, bool proxied, bool inbound) NOEXCEPT; @@ -304,10 +306,9 @@ class BCT_API socket std::atomic_bool stopped_{}; // These are protected by strand (see also handle_accept). - asio::socket socket_; config::address address_; config::endpoint endpoint_; - std::optional websocket_{}; + transport transport_; }; typedef std::function socket_handler; diff --git a/src/messages/http_body.cpp b/src/messages/http_body.cpp index d45e527c6..e77167d97 100644 --- a/src/messages/http_body.cpp +++ b/src/messages/http_body.cpp @@ -107,7 +107,7 @@ void body::reader::finish(boost_code& ec) NOEXCEPT void body::writer::init(boost_code& ec) NOEXCEPT { - return std::visit(overload + std::visit(overload { [&] (std::monostate&) NOEXCEPT { diff --git a/src/net/socket.cpp b/src/net/socket.cpp index 5ed11d697..2f19e2fda 100644 --- a/src/net/socket.cpp +++ b/src/net/socket.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -69,9 +70,9 @@ socket::socket(const logger& log, asio::io_context& service, maximum_(maximum_request), strand_(service.get_executor()), service_(service), - socket_(strand_), address_(address), endpoint_(endpoint), + transport_(std::in_place_type, strand_), reporter(log), tracker(log) { @@ -107,7 +108,8 @@ void socket::do_stop() NOEXCEPT BC_ASSERT(stranded()); // Release the callback closure before shutdown/close. - if (websocket()) websocket_->control_callback(); + if (std::holds_alternative(transport_)) + std::get(transport_).control_callback(); boost_code ignore{}; auto& socket = get_transport(); @@ -140,7 +142,7 @@ void socket::do_async_stop() NOEXCEPT { BC_ASSERT(stranded()); - if (!websocket()) + if (!std::holds_alternative(transport_)) { do_stop(); return; @@ -152,7 +154,8 @@ void socket::do_async_stop() NOEXCEPT // This will repost to the strand, but the iocontext is alive because this // is not initiated by session callback invoking stop(). Any subsequent // stop() call will terminate this listener by invoking socket.shutdown(). - websocket_->async_close(beast::websocket::close_code::normal, + std::get(transport_).async_close( + beast::websocket::close_code::normal, std::bind(&socket::do_stop, shared_from_this())); } @@ -221,7 +224,7 @@ void socket::do_cancel(const result_handler& handler) NOEXCEPT { // Causes connect, send, and receive calls to quit with // asio::error::operation_aborted passed to handlers. - socket_.cancel(); + get_transport().cancel(); } catch (const std::exception& LOG_ONLY(e)) { @@ -238,7 +241,8 @@ void socket::do_cancel(const result_handler& handler) NOEXCEPT void socket::accept(asio::acceptor& acceptor, result_handler&& handler) NOEXCEPT { - BC_ASSERT_MSG(!socket_.is_open(), "accept on open socket"); + BC_ASSERT_MSG(!std::get(transport_).is_open(), + "accept on open socket"); // Closure of the acceptor, not the socket, releases this handler. // The socket is not guarded during async_accept. This is required so the @@ -252,7 +256,7 @@ void socket::accept(asio::acceptor& acceptor, { // Dispatches on the acceptor's strand (which should be network). // Cannot move handler due to catch block invocation. - acceptor.async_accept(socket_, + acceptor.async_accept(std::get(transport_), std::bind(&socket::handle_accept, shared_from_this(), _1, handler)); } @@ -368,12 +372,13 @@ void socket::do_connect(const asio::endpoints& range, { BC_ASSERT(stranded()); BC_ASSERT_MSG(!websocket(), "socket is upgraded"); - BC_ASSERT_MSG(!socket_.is_open(), "connect on open socket"); + BC_ASSERT_MSG(!std::get(transport_).is_open(), + "connect on open socket"); try { // Establishes a socket connection by trying each endpoint in sequence. - boost::asio::async_connect(socket_, range, + boost::asio::async_connect(std::get(transport_), range, std::bind(&socket::handle_connect, shared_from_this(), _1, _2, handler)); } @@ -395,7 +400,7 @@ void socket::do_read(const asio::mutable_buffer& out, try { // This composed operation posts all intermediate handlers to strand. - boost::asio::async_read(socket_, out, + boost::asio::async_read(std::get(transport_), out, std::bind(&socket::handle_tcp, shared_from_this(), _1, _2, handler)); } @@ -414,7 +419,7 @@ void socket::do_write(const asio::const_buffer& in, try { // This composed operation posts all intermediate handlers to strand. - boost::asio::async_write(socket_, in, + boost::asio::async_write(std::get(transport_), in, std::bind(&socket::handle_tcp, shared_from_this(), _1, _2, handler)); } @@ -490,7 +495,9 @@ void socket::do_ws_read(std::reference_wrapper out, try { - websocket_->async_read(out.get(), + auto& socket = std::get(transport_); + + socket.async_read(out.get(), std::bind(&socket::handle_ws_read, shared_from_this(), _1, _2, handler)); } @@ -509,12 +516,14 @@ void socket::do_ws_write(const asio::const_buffer& in, bool binary, try { + auto& socket = std::get(transport_); + if (binary) - websocket_->binary(true); + socket.binary(true); else - websocket_->text(true); + socket.text(true); - websocket_->async_write(in, + socket.async_write(in, std::bind(&socket::handle_ws_write, shared_from_this(), _1, _2, handler)); } @@ -565,8 +574,10 @@ void socket::do_http_read(std::reference_wrapper buffer, // Causes http::error::header_limit on completion. parser->header_limit(limit(maximum_)); + auto& socket = std::get(transport_); + // This operation posts handler to the strand. - beast::http::async_read(socket_, buffer.get(), *parser, + beast::http::async_read(socket, buffer.get(), *parser, std::bind(&socket::handle_http_read, shared_from_this(), _1, _2, request, parser, handler)); } @@ -591,8 +602,10 @@ void socket::do_http_write( try { + auto& socket = std::get(transport_); + // This operation posts handler to the strand. - beast::http::async_write(socket_, response.get(), + beast::http::async_write(socket, response.get(), std::bind(&socket::handle_http_write, shared_from_this(), _1, _2, handler)); } @@ -613,7 +626,10 @@ void socket::handle_accept(boost_code ec, // This is running in the acceptor (not socket) execution context. // socket_ and endpoint_ are not guarded here, see comments on accept. if (!ec) - endpoint_ = { socket_.remote_endpoint(ec) }; + { + const auto& socket = std::get(transport_); + endpoint_ = { socket.remote_endpoint(ec) }; + } if (error::asio_is_canceled(ec)) { @@ -777,7 +793,8 @@ void socket::handle_ws_event(ws::frame_type kind, LOGX("WS pong [" << endpoint() << "] size: " << data.size()); break; case ws::frame_type::close: - LOGX("WS close [" << endpoint() << "] " << websocket_->reason()); + const auto& socket = std::get(transport_); + LOGX("WS close [" << endpoint() << "] " << socket.reason()); break; } } @@ -877,7 +894,7 @@ asio::io_context& socket::service() const NOEXCEPT bool socket::websocket() const NOEXCEPT { BC_ASSERT(stranded()); - return websocket_.has_value(); + return std::holds_alternative(transport_); } code socket::set_websocket(const http::request& request) NOEXCEPT @@ -887,11 +904,14 @@ code socket::set_websocket(const http::request& request) NOEXCEPT try { - websocket_.emplace(std::move(socket_)); + transport_.emplace( + std::move(std::get(transport_))); + + auto& socket = std::get(transport_); // Causes websocket::error::message_too_big on completion. - websocket_->read_message_max(maximum_); - websocket_->set_option(ws::decorator + socket.read_message_max(maximum_); + socket.set_option(ws::decorator { [](http::fields& header) NOEXCEPT { @@ -901,11 +921,11 @@ code socket::set_websocket(const http::request& request) NOEXCEPT }); // Handle ping, pong, close - must be cleared on stop. - websocket_->control_callback(std::bind(&socket::do_ws_event, + socket.control_callback(std::bind(&socket::do_ws_event, shared_from_this(), _1, _2)); - websocket_->binary(true); - websocket_->accept(request); + socket.binary(true); + socket.accept(request); return error::upgraded; } catch (const std::exception& LOG_ONLY(e)) @@ -922,7 +942,18 @@ asio::socket& socket::get_transport() NOEXCEPT { BC_ASSERT(stranded()); - return websocket() ? beast::get_lowest_layer(*websocket_) : socket_; + // Explicit returns required to prevent reference stripping. + return std::visit(overload + { + [&](asio::socket& arg) NOEXCEPT -> asio::socket& + { + return arg; + }, + [&](ws::websocket& arg) NOEXCEPT -> asio::socket& + { + return beast::get_lowest_layer(arg); + } + }, transport_); } void socket::logx(const std::string& context, diff --git a/test/net/socket.cpp b/test/net/socket.cpp index 034bd6967..95c73703f 100644 --- a/test/net/socket.cpp +++ b/test/net/socket.cpp @@ -31,9 +31,9 @@ class socket_accessor return strand_; } - const asio::socket& get_socket() const NOEXCEPT + const transport& get_transport() const NOEXCEPT { - return socket_; + return transport_; } const config::endpoint& get_endpoint() const NOEXCEPT @@ -58,9 +58,11 @@ BOOST_AUTO_TEST_CASE(socket__construct__default__closed_not_stopped_expected) threadpool pool(1); constexpr auto maximum = 42u; const auto instance = std::make_shared(log, pool.service(), maximum); + const auto& transport = instance->get_transport(); BOOST_REQUIRE(!instance->stranded()); - BOOST_REQUIRE(!instance->get_socket().is_open()); + BOOST_REQUIRE(std::holds_alternative(transport)); + BOOST_REQUIRE(!std::get(transport).is_open()); BOOST_REQUIRE(&instance->get_strand() == &instance->strand()); BOOST_REQUIRE(instance->get_endpoint() == instance->endpoint()); BOOST_REQUIRE(!instance->get_endpoint().is_address());