Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/bitcoin/network/net/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include <atomic>
#include <memory>
#include <optional>
#include <variant>
#include <bitcoin/network/async/async.hpp>
#include <bitcoin/network/config/config.hpp>
#include <bitcoin/network/define.hpp>
Expand Down Expand Up @@ -291,6 +291,8 @@ class BCT_API socket
const count_handler& handler) NOEXCEPT;

protected:
using transport = std::variant<asio::socket, ws::websocket>;

socket(const logger& log, asio::io_context& service,
size_t maximum_request, const config::address& address,
const config::endpoint& endpoint, bool proxied, bool inbound) NOEXCEPT;
Expand All @@ -304,10 +306,9 @@ class BCT_API socket
std::atomic_bool stopped_{};

// These are protected by strand (see also handle_accept).
asio::socket socket_;
config::address address_;
config::endpoint endpoint_;
std::optional<ws::websocket> websocket_{};
transport transport_;
};

typedef std::function<void(const code&, const socket::ptr&)> socket_handler;
Expand Down
2 changes: 1 addition & 1 deletion src/messages/http_body.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void body::reader::finish(boost_code& ec) NOEXCEPT

void body::writer::init(boost_code& ec) NOEXCEPT
{
return std::visit(overload
std::visit(overload
{
[&] (std::monostate&) NOEXCEPT
{
Expand Down
85 changes: 58 additions & 27 deletions src/net/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <memory>
#include <utility>
#include <variant>
#include <bitcoin/network/async/async.hpp>
#include <bitcoin/network/config/config.hpp>
#include <bitcoin/network/define.hpp>
Expand Down Expand Up @@ -69,9 +70,9 @@ socket::socket(const logger& log, asio::io_context& service,
maximum_(maximum_request),
strand_(service.get_executor()),
service_(service),
socket_(strand_),
address_(address),
endpoint_(endpoint),
transport_(std::in_place_type<asio::socket>, strand_),
reporter(log),
tracker<socket>(log)
{
Expand Down Expand Up @@ -107,7 +108,8 @@ void socket::do_stop() NOEXCEPT
BC_ASSERT(stranded());

// Release the callback closure before shutdown/close.
if (websocket()) websocket_->control_callback();
if (std::holds_alternative<ws::websocket>(transport_))
std::get<ws::websocket>(transport_).control_callback();

boost_code ignore{};
auto& socket = get_transport();
Expand Down Expand Up @@ -140,7 +142,7 @@ void socket::do_async_stop() NOEXCEPT
{
BC_ASSERT(stranded());

if (!websocket())
if (!std::holds_alternative<ws::websocket>(transport_))
{
do_stop();
return;
Expand All @@ -152,7 +154,8 @@ void socket::do_async_stop() NOEXCEPT
// This will repost to the strand, but the iocontext is alive because this
// is not initiated by session callback invoking stop(). Any subsequent
// stop() call will terminate this listener by invoking socket.shutdown().
websocket_->async_close(beast::websocket::close_code::normal,
std::get<ws::websocket>(transport_).async_close(
beast::websocket::close_code::normal,
std::bind(&socket::do_stop, shared_from_this()));
}

Expand Down Expand Up @@ -221,7 +224,7 @@ void socket::do_cancel(const result_handler& handler) NOEXCEPT
{
// Causes connect, send, and receive calls to quit with
// asio::error::operation_aborted passed to handlers.
socket_.cancel();
get_transport().cancel();
}
catch (const std::exception& LOG_ONLY(e))
{
Expand All @@ -238,7 +241,8 @@ void socket::do_cancel(const result_handler& handler) NOEXCEPT
void socket::accept(asio::acceptor& acceptor,
result_handler&& handler) NOEXCEPT
{
BC_ASSERT_MSG(!socket_.is_open(), "accept on open socket");
BC_ASSERT_MSG(!std::get<asio::socket>(transport_).is_open(),
"accept on open socket");

// Closure of the acceptor, not the socket, releases this handler.
// The socket is not guarded during async_accept. This is required so the
Expand All @@ -252,7 +256,7 @@ void socket::accept(asio::acceptor& acceptor,
{
// Dispatches on the acceptor's strand (which should be network).
// Cannot move handler due to catch block invocation.
acceptor.async_accept(socket_,
acceptor.async_accept(std::get<asio::socket>(transport_),
std::bind(&socket::handle_accept,
shared_from_this(), _1, handler));
}
Expand Down Expand Up @@ -368,12 +372,13 @@ void socket::do_connect(const asio::endpoints& range,
{
BC_ASSERT(stranded());
BC_ASSERT_MSG(!websocket(), "socket is upgraded");
BC_ASSERT_MSG(!socket_.is_open(), "connect on open socket");
BC_ASSERT_MSG(!std::get<asio::socket>(transport_).is_open(),
"connect on open socket");

try
{
// Establishes a socket connection by trying each endpoint in sequence.
boost::asio::async_connect(socket_, range,
boost::asio::async_connect(std::get<asio::socket>(transport_), range,
std::bind(&socket::handle_connect,
shared_from_this(), _1, _2, handler));
}
Expand All @@ -395,7 +400,7 @@ void socket::do_read(const asio::mutable_buffer& out,
try
{
// This composed operation posts all intermediate handlers to strand.
boost::asio::async_read(socket_, out,
boost::asio::async_read(std::get<asio::socket>(transport_), out,
std::bind(&socket::handle_tcp,
shared_from_this(), _1, _2, handler));
}
Expand All @@ -414,7 +419,7 @@ void socket::do_write(const asio::const_buffer& in,
try
{
// This composed operation posts all intermediate handlers to strand.
boost::asio::async_write(socket_, in,
boost::asio::async_write(std::get<asio::socket>(transport_), in,
std::bind(&socket::handle_tcp,
shared_from_this(), _1, _2, handler));
}
Expand Down Expand Up @@ -490,7 +495,9 @@ void socket::do_ws_read(std::reference_wrapper<http::flat_buffer> out,

try
{
websocket_->async_read(out.get(),
auto& socket = std::get<ws::websocket>(transport_);

socket.async_read(out.get(),
std::bind(&socket::handle_ws_read,
shared_from_this(), _1, _2, handler));
}
Expand All @@ -509,12 +516,14 @@ void socket::do_ws_write(const asio::const_buffer& in, bool binary,

try
{
auto& socket = std::get<ws::websocket>(transport_);

if (binary)
websocket_->binary(true);
socket.binary(true);
else
websocket_->text(true);
socket.text(true);

websocket_->async_write(in,
socket.async_write(in,
std::bind(&socket::handle_ws_write,
shared_from_this(), _1, _2, handler));
}
Expand Down Expand Up @@ -565,8 +574,10 @@ void socket::do_http_read(std::reference_wrapper<http::flat_buffer> buffer,
// Causes http::error::header_limit on completion.
parser->header_limit(limit<uint32_t>(maximum_));

auto& socket = std::get<asio::socket>(transport_);

// This operation posts handler to the strand.
beast::http::async_read(socket_, buffer.get(), *parser,
beast::http::async_read(socket, buffer.get(), *parser,
std::bind(&socket::handle_http_read,
shared_from_this(), _1, _2, request, parser, handler));
}
Expand All @@ -591,8 +602,10 @@ void socket::do_http_write(

try
{
auto& socket = std::get<asio::socket>(transport_);

// This operation posts handler to the strand.
beast::http::async_write(socket_, response.get(),
beast::http::async_write(socket, response.get(),
std::bind(&socket::handle_http_write,
shared_from_this(), _1, _2, handler));
}
Expand All @@ -613,7 +626,10 @@ void socket::handle_accept(boost_code ec,
// This is running in the acceptor (not socket) execution context.
// socket_ and endpoint_ are not guarded here, see comments on accept.
if (!ec)
endpoint_ = { socket_.remote_endpoint(ec) };
{
const auto& socket = std::get<asio::socket>(transport_);
endpoint_ = { socket.remote_endpoint(ec) };
}

if (error::asio_is_canceled(ec))
{
Expand Down Expand Up @@ -777,7 +793,8 @@ void socket::handle_ws_event(ws::frame_type kind,
LOGX("WS pong [" << endpoint() << "] size: " << data.size());
break;
case ws::frame_type::close:
LOGX("WS close [" << endpoint() << "] " << websocket_->reason());
const auto& socket = std::get<ws::websocket>(transport_);
LOGX("WS close [" << endpoint() << "] " << socket.reason());
break;
}
}
Expand Down Expand Up @@ -877,7 +894,7 @@ asio::io_context& socket::service() const NOEXCEPT
bool socket::websocket() const NOEXCEPT
{
BC_ASSERT(stranded());
return websocket_.has_value();
return std::holds_alternative<ws::websocket>(transport_);
}

code socket::set_websocket(const http::request& request) NOEXCEPT
Expand All @@ -887,11 +904,14 @@ code socket::set_websocket(const http::request& request) NOEXCEPT

try
{
websocket_.emplace(std::move(socket_));
transport_.emplace<ws::websocket>(
std::move(std::get<asio::socket>(transport_)));

auto& socket = std::get<ws::websocket>(transport_);

// Causes websocket::error::message_too_big on completion.
websocket_->read_message_max(maximum_);
websocket_->set_option(ws::decorator
socket.read_message_max(maximum_);
socket.set_option(ws::decorator
{
[](http::fields& header) NOEXCEPT
{
Expand All @@ -901,11 +921,11 @@ code socket::set_websocket(const http::request& request) NOEXCEPT
});

// Handle ping, pong, close - must be cleared on stop.
websocket_->control_callback(std::bind(&socket::do_ws_event,
socket.control_callback(std::bind(&socket::do_ws_event,
shared_from_this(), _1, _2));

websocket_->binary(true);
websocket_->accept(request);
socket.binary(true);
socket.accept(request);
return error::upgraded;
}
catch (const std::exception& LOG_ONLY(e))
Expand All @@ -922,7 +942,18 @@ asio::socket& socket::get_transport() NOEXCEPT
{
BC_ASSERT(stranded());

return websocket() ? beast::get_lowest_layer(*websocket_) : socket_;
// Explicit returns required to prevent reference stripping.
return std::visit(overload
{
[&](asio::socket& arg) NOEXCEPT -> asio::socket&
{
return arg;
},
[&](ws::websocket& arg) NOEXCEPT -> asio::socket&
{
return beast::get_lowest_layer(arg);
}
}, transport_);
}

void socket::logx(const std::string& context,
Expand Down
8 changes: 5 additions & 3 deletions test/net/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class socket_accessor
return strand_;
}

const asio::socket& get_socket() const NOEXCEPT
const transport& get_transport() const NOEXCEPT
{
return socket_;
return transport_;
}

const config::endpoint& get_endpoint() const NOEXCEPT
Expand All @@ -58,9 +58,11 @@ BOOST_AUTO_TEST_CASE(socket__construct__default__closed_not_stopped_expected)
threadpool pool(1);
constexpr auto maximum = 42u;
const auto instance = std::make_shared<socket_accessor>(log, pool.service(), maximum);
const auto& transport = instance->get_transport();

BOOST_REQUIRE(!instance->stranded());
BOOST_REQUIRE(!instance->get_socket().is_open());
BOOST_REQUIRE(std::holds_alternative<asio::socket>(transport));
BOOST_REQUIRE(!std::get<asio::socket>(transport).is_open());
BOOST_REQUIRE(&instance->get_strand() == &instance->strand());
BOOST_REQUIRE(instance->get_endpoint() == instance->endpoint());
BOOST_REQUIRE(!instance->get_endpoint().is_address());
Expand Down
Loading