diff --git a/example/timeouts.cpp b/example/timeouts.cpp index 7c913d428..f12b86185 100644 --- a/example/timeouts.cpp +++ b/example/timeouts.cpp @@ -12,32 +12,22 @@ #include #include #include -#include -#include #include +#include #include #include #include #include -#include -#include #include #include #include -#include -#if defined(BOOST_ASIO_HAS_CO_AWAIT) && !defined(BOOST_ASIO_USE_TS_EXECUTOR_AS_DEFAULT) +#if defined(BOOST_ASIO_HAS_CO_AWAIT) -#include - -using namespace boost::asio::experimental::awaitable_operators; -using boost::asio::use_awaitable; using boost::mysql::error_code; -constexpr std::chrono::milliseconds TIMEOUT(8000); - void print_employee(boost::mysql::row_view employee) { std::cout << "Employee '" << employee.at(0) << " " // first_name (string) @@ -46,62 +36,11 @@ void print_employee(boost::mysql::row_view employee) } /** - * Helper functions to check whether an async operation, launched in parallel with - * a timer, was successful, resulted in an error or timed out. The timer is always the first operation. - * If the variant holds the first alternative, the timer fired before - * the async operation completed, which means a timeout. We'll be using as_tuple with use_awaitable to be able - * to use boost::mysql::throw_on_error and include server diagnostics in the thrown exceptions. - */ -template -T check_error( - std::variant>&& op_result, - const boost::mysql::diagnostics& diag = {} -) -{ - if (op_result.index() == 0) - { - throw std::runtime_error("Operation timed out"); - } - auto [ec, res] = std::get<1>(std::move(op_result)); - boost::mysql::throw_on_error(ec, diag); - return res; -} - -void check_error( - const std::variant>& op_result, - const boost::mysql::diagnostics& diag -) -{ - if (op_result.index() == 0) - { - throw std::runtime_error("Operation timed out"); - } - auto [ec] = std::get<1>(op_result); - boost::mysql::throw_on_error(ec, diag); -} - -// Using this completion token instead of plain use_awaitable prevents -// co_await from throwing exceptions. Instead, co_await will return a std::tuple -// with a non-zero code on error. We will then use boost::mysql::throw_on_error -// to throw exceptions with embedded diagnostics, if available. If you -// employ plain use_awaitable, you will get boost::system::system_error exceptions -// instead of boost::mysql::error_with_diagnostics exceptions. This is a limitation of use_awaitable. -constexpr auto tuple_awaitable = boost::asio::as_tuple(boost::asio::use_awaitable); - -/** - * We use Boost.Asio's cancellation capabilities to implement timeouts for our - * asynchronous operations. This is not something specific to Boost.MySQL, and + * We use Boost.Asio's cancel_after completion token to cancel operations + * after a certain time has elapsed. This is not something specific to Boost.MySQL, and * can be used with any other asynchronous operation that follows Asio's model. - * - * Each time we invoke an asynchronous operation, we also call timer_type::async_wait. - * We then use Asio's overload for operator || to run the timer wait and the async operation - * in parallel. Once the first of them finishes, the other operation is cancelled - * (the behavior is similar to JavaScripts's Promise.race). - * If we co_await the awaitable returned by operator ||, we get a std::variant, - * where T is the async operation's result type. If the timer wait finishes first (we have a - * timeout), the variant will hold the std::monostate at index 0; otherwise, it will have the async - * operation's result at index 1. The function check_error throws an exception in the case of - * timeout and extracts the operation's result otherwise. + * If the operation times out, it will fail with a boost::asio::error::operation_aborted + * error code. * * If any of the MySQL specific operations result in a timeout, the connection is left * in an unspecified state. You should close it and re-open it to get it working again. @@ -109,49 +48,34 @@ constexpr auto tuple_awaitable = boost::asio::as_tuple(boost::asio::use_awaitabl boost::asio::awaitable coro_main( boost::mysql::tcp_ssl_connection& conn, boost::asio::ip::tcp::resolver& resolver, - boost::asio::steady_timer& timer, const boost::mysql::handshake_params& params, const char* hostname, const char* company_id ) { - boost::mysql::diagnostics diag; + using boost::asio::cancel_after; + constexpr std::chrono::seconds timeout(8); + + // TODO: thrown exceptions don't contain diagnostics. + // Should be solved by https://github.com/boostorg/mysql/issues/329 // Resolve hostname - timer.expires_after(TIMEOUT); - auto endpoints = check_error(co_await ( - timer.async_wait(use_awaitable) || - resolver.async_resolve(hostname, boost::mysql::default_port_string, tuple_awaitable) - )); - - // Connect to server. Note that we need to reset the timer before using it again. - timer.expires_after(TIMEOUT); - auto op_result = co_await ( - timer.async_wait(use_awaitable) || - conn.async_connect(*endpoints.begin(), params, diag, tuple_awaitable) - ); - check_error(op_result, diag); + auto endpoints = co_await resolver + .async_resolve(hostname, boost::mysql::default_port_string, cancel_after(timeout)); + + // Connect to server + co_await conn.async_connect(*endpoints.begin(), params, cancel_after(timeout)); // We will be using company_id, which is untrusted user input, so we will use a prepared // statement. - auto stmt_op_result = co_await ( - timer.async_wait(use_awaitable) || - conn.async_prepare_statement( - "SELECT first_name, last_name, salary FROM employee WHERE company_id = ?", - diag, - tuple_awaitable - ) + boost::mysql::statement stmt = co_await conn.async_prepare_statement( + "SELECT first_name, last_name, salary FROM employee WHERE company_id = ?", + cancel_after(timeout) ); - boost::mysql::statement stmt = check_error(std::move(stmt_op_result), diag); // Execute the statement boost::mysql::results result; - timer.expires_after(TIMEOUT); - op_result = co_await ( - timer.async_wait(use_awaitable) || - conn.async_execute(stmt.bind(company_id), result, diag, tuple_awaitable) - ); - check_error(op_result, diag); + co_await conn.async_execute(stmt.bind(company_id), result, cancel_after(timeout)); // Print all the obtained rows for (boost::mysql::row_view employee : result.rows()) @@ -160,8 +84,7 @@ boost::asio::awaitable coro_main( } // Notify the MySQL server we want to quit, then close the underlying connection. - op_result = co_await (timer.async_wait(use_awaitable) || conn.async_close(diag, tuple_awaitable)); - check_error(op_result, diag); + co_await conn.async_close(cancel_after(timeout)); } void main_impl(int argc, char** argv) @@ -182,7 +105,6 @@ void main_impl(int argc, char** argv) boost::asio::io_context ctx; boost::asio::ssl::context ssl_ctx(boost::asio::ssl::context::tls_client); boost::mysql::tcp_ssl_connection conn(ctx, ssl_ctx); - boost::asio::steady_timer timer(ctx.get_executor()); // Connection parameters boost::mysql::handshake_params params( @@ -197,8 +119,8 @@ void main_impl(int argc, char** argv) // The entry point. We pass in a function returning a boost::asio::awaitable, as required. boost::asio::co_spawn( ctx.get_executor(), - [&conn, &resolver, &timer, params, hostname, company_id] { - return coro_main(conn, resolver, timer, params, hostname, company_id); + [&conn, &resolver, params, hostname, company_id] { + return coro_main(conn, resolver, params, hostname, company_id); }, // If any exception is thrown in the coroutine body, rethrow it. [](std::exception_ptr ptr) { diff --git a/include/boost/mysql/connection_pool.hpp b/include/boost/mysql/connection_pool.hpp index 50372ba05..0999f47aa 100644 --- a/include/boost/mysql/connection_pool.hpp +++ b/include/boost/mysql/connection_pool.hpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -265,8 +266,10 @@ class connection_pool return std::chrono::seconds(30); } - struct initiate_run + struct initiate_run : detail::initiation_base { + using detail::initiation_base::initiation_base; + template void operator()(Handler&& h, std::shared_ptr self) { @@ -280,8 +283,10 @@ class connection_pool asio::any_completion_handler handler ); - struct initiate_get_connection + struct initiate_get_connection : detail::initiation_base { + using detail::initiation_base::initiation_base; + template void operator()( Handler&& h, @@ -309,7 +314,7 @@ class connection_pool CompletionToken&& token ) -> decltype(asio::async_initiate( - initiate_get_connection{}, + std::declval(), token, diag, impl_, @@ -318,7 +323,7 @@ class connection_pool { BOOST_ASSERT(valid()); return asio::async_initiate( - initiate_get_connection{}, + initiate_get_connection{get_executor()}, token, diag, impl_, @@ -538,12 +543,19 @@ class connection_pool * `~pooled_connection` and \ref pooled_connection::return_without_reset. */ template - auto async_run(CompletionToken&& token) BOOST_MYSQL_RETURN_TYPE( - decltype(asio::async_initiate(initiate_run{}, token, impl_)) - ) + auto async_run(CompletionToken&& token) + BOOST_MYSQL_RETURN_TYPE(decltype(asio::async_initiate( + std::declval(), + token, + impl_ + ))) { BOOST_ASSERT(valid()); - return asio::async_initiate(initiate_run{}, token, impl_); + return asio::async_initiate( + initiate_run{get_executor()}, + token, + impl_ + ); } /// \copydoc async_get_connection(diagnostics&,CompletionToken&&) diff --git a/include/boost/mysql/detail/async_helpers.hpp b/include/boost/mysql/detail/async_helpers.hpp new file mode 100644 index 000000000..cb35bf3d4 --- /dev/null +++ b/include/boost/mysql/detail/async_helpers.hpp @@ -0,0 +1,33 @@ +// +// Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// + +#ifndef BOOST_MYSQL_DETAIL_ASYNC_HELPERS_HPP +#define BOOST_MYSQL_DETAIL_ASYNC_HELPERS_HPP + +#include + +namespace boost { +namespace mysql { +namespace detail { + +// Base class for initiation objects. Includes a bound executor, so they're compatible +// with asio::cancel_after and similar +struct initiation_base +{ + asio::any_io_executor ex; + + initiation_base(asio::any_io_executor ex) noexcept : ex(std::move(ex)) {} + + using executor_type = asio::any_io_executor; + const executor_type& get_executor() const noexcept { return ex; } +}; + +} // namespace detail +} // namespace mysql +} // namespace boost + +#endif diff --git a/include/boost/mysql/detail/connection_impl.hpp b/include/boost/mysql/detail/connection_impl.hpp index 6281d1474..376dcc80e 100644 --- a/include/boost/mysql/detail/connection_impl.hpp +++ b/include/boost/mysql/detail/connection_impl.hpp @@ -22,11 +22,13 @@ #include #include +#include #include #include #include #include +#include #include #include @@ -125,6 +127,8 @@ class connection_impl std::unique_ptr engine_; std::unique_ptr st_; + asio::any_io_executor get_executor() const { return engine_->get_executor(); } + // Helper for execution requests template static auto make_request(T&& input, connection_state& st) @@ -202,8 +206,10 @@ class connection_impl async_run_impl(eng, st, params, diag, std::forward(handler), has_void_result{}); } - struct run_algo_initiation + struct run_algo_initiation : initiation_base { + using initiation_base::initiation_base; + template void operator()( Handler&& handler, @@ -232,8 +238,10 @@ class connection_impl } template - struct initiate_connect + struct initiate_connect : initiation_base { + using initiation_base::initiation_base; + template void operator()( Handler&& handler, @@ -249,8 +257,10 @@ class connection_impl } }; - struct initiate_connect_v2 + struct initiate_connect_v2 : initiation_base { + using initiation_base::initiation_base; + template void operator()( Handler&& handler, @@ -266,8 +276,10 @@ class connection_impl }; // execute - struct initiate_execute + struct initiate_execute : initiation_base { + using initiation_base::initiation_base; + template void operator()( Handler&& handler, @@ -289,8 +301,10 @@ class connection_impl }; // start execution - struct initiate_start_execution + struct initiate_start_execution : initiation_base { + using initiation_base::initiation_base; + template void operator()( Handler&& handler, @@ -347,7 +361,7 @@ class connection_impl template auto async_run(AlgoParams params, diagnostics& diag, CompletionToken&& token) -> decltype(asio::async_initiate>( - run_algo_initiation(), + run_algo_initiation(get_executor()), token, &diag, engine_.get(), @@ -356,7 +370,7 @@ class connection_impl )) { return asio::async_initiate>( - run_algo_initiation(), + run_algo_initiation(get_executor()), token, &diag, engine_.get(), @@ -392,7 +406,7 @@ class connection_impl CompletionToken&& token ) -> decltype(asio::async_initiate( - initiate_connect(), + initiate_connect(get_executor()), token, &diag, engine_.get(), @@ -402,7 +416,7 @@ class connection_impl )) { return asio::async_initiate( - initiate_connect(), + initiate_connect(get_executor()), token, &diag, engine_.get(), @@ -415,7 +429,7 @@ class connection_impl template auto async_connect_v2(const connect_params& params, diagnostics& diag, CompletionToken&& token) -> decltype(asio::async_initiate( - initiate_connect_v2(), + initiate_connect_v2(get_executor()), token, &diag, engine_.get(), @@ -424,7 +438,7 @@ class connection_impl )) { return asio::async_initiate( - initiate_connect_v2(), + initiate_connect_v2(get_executor()), token, &diag, engine_.get(), @@ -461,7 +475,7 @@ class connection_impl CompletionToken&& token ) -> decltype(asio::async_initiate( - initiate_execute(), + initiate_execute(get_executor()), token, &diag, engine_.get(), @@ -471,7 +485,7 @@ class connection_impl )) { return asio::async_initiate( - initiate_execute(), + initiate_execute(get_executor()), token, &diag, engine_.get(), @@ -508,7 +522,7 @@ class connection_impl CompletionToken&& token ) -> decltype(asio::async_initiate( - initiate_start_execution(), + initiate_start_execution(get_executor()), token, &diag, engine_.get(), @@ -518,7 +532,7 @@ class connection_impl )) { return asio::async_initiate( - initiate_start_execution(), + initiate_start_execution(get_executor()), token, &diag, engine_.get(), diff --git a/test/common/include/test_common/network_result.hpp b/test/common/include/test_common/network_result.hpp index 8610eb9cc..8064c7677 100644 --- a/test/common/include/test_common/network_result.hpp +++ b/test/common/include/test_common/network_result.hpp @@ -19,10 +19,13 @@ #include #include #include +#include +#include #include #include #include +#include #include #include @@ -308,11 +311,31 @@ class async_result template static return_type initiate(Initiation&& initiation, mysql::test::as_netresult_t token, Args&&... args) { - return do_initiate(std::move(initiation), token.slot, std::move(args)...); + using types = mp11::mp_list; + using diag_pos = mp11::mp_find; + constexpr std::size_t actual_pos = diag_pos::value == sizeof...(Args) ? 0u : diag_pos::value; + return do_initiate( + std::move(initiation), + token.slot, + std::get(std::tuple{args...}), + std::forward(args)... + ); + } + + // Common case optimization: diagnostics* is first + template + static return_type initiate( + Initiation&& initiation, + mysql::test::as_netresult_t token, + mysql::diagnostics* diag, + Args&&... args + ) + { + return do_initiate_impl(std::move(initiation), token.slot, diag, diag, std::forward(args)...); } private: - // initiate() is not allowed to inspect individual arguments + // A diagnostics* was found template static return_type do_initiate( Initiation&& initiation, @@ -321,29 +344,28 @@ class async_result Args&&... args ) { - // Verify that we correctly set diagnostics in all cases - *diag = mysql::test::create_server_diag("Diagnostics not cleared properly"); - - // Create the return type - mysql::test::runnable_network_result netres; - - // Record that we're initiating - mysql::test::initiation_guard guard; - - // Actually call the initiation function - std::move(initiation)( - mysql::test::test_detail::as_netres_handler(netres.impl->netres, diag, slot), - diag, - std::move(args)... - ); + return do_initiate_impl(std::move(initiation), slot, diag, std::forward(args)...); + } - return netres; + // No diagnostics* was found + template + static return_type do_initiate(Initiation&& initiation, asio::cancellation_slot slot, T&&, Args&&... args) + { + return do_initiate_impl(std::move(initiation), slot, nullptr, std::forward(args)...); } - // For functions without diagnostics template - static return_type do_initiate(Initiation&& initiation, asio::cancellation_slot slot, Args&&... args) + static return_type do_initiate_impl( + Initiation&& initiation, + asio::cancellation_slot slot, + mysql::diagnostics* diag, + Args&&... args + ) { + // Verify that we correctly set diagnostics in all cases + if (diag) + *diag = mysql::test::create_server_diag("Diagnostics not cleared properly"); + // Create the return type mysql::test::runnable_network_result netres; @@ -352,7 +374,7 @@ class async_result // Actually call the initiation function std::move(initiation)( - mysql::test::test_detail::as_netres_handler(netres.impl->netres, nullptr, slot), + mysql::test::test_detail::as_netres_handler(netres.impl->netres, diag, slot), std::move(args)... ); diff --git a/test/integration/test/any_connection.cpp b/test/integration/test/any_connection.cpp index 0e3421675..874bd8559 100644 --- a/test/integration/test/any_connection.cpp +++ b/test/integration/test/any_connection.cpp @@ -20,6 +20,7 @@ #include +#include #include #include #include @@ -199,4 +200,48 @@ BOOST_FIXTURE_TEST_CASE(using_non_connected_connection, any_connection_fixture) conn.async_ping(as_netresult).validate_any_error(); } +// Spotcheck: we can use cancel_after and other tokens +// that require initiations to have an associated executor +BOOST_FIXTURE_TEST_CASE(cancel_after, any_connection_fixture) +{ + // The token to use + const auto token = asio::cancel_after(std::chrono::seconds(10), asio::deferred); + + // Connect + conn.async_connect(connect_params_builder().build(), token)(as_netresult).validate_no_error(); + + // Execute + results result; + conn.async_execute("SELECT 'abc'", result, token)(as_netresult).validate_no_error(); + BOOST_TEST(result.rows() == makerows(1, "abc"), per_element()); + + // Start execution + execution_state st; + conn.async_start_execution("SELECT 'abc'", st, token)(as_netresult).validate_no_error(); + auto rws = conn.async_read_some_rows(st, token)(as_netresult).get(); + BOOST_TEST(rws == makerows(1, "abc"), per_element()); + conn.async_read_resultset_head(st, token)(as_netresult).validate_no_error(); + +#ifdef BOOST_MYSQL_CXX14 + // Start execution (static, for read_some_rows) + using tup_t = std::tuple; + static_execution_state st2; + std::array storage; + conn.async_start_execution("SELECT 'abc'", st2, token)(as_netresult).validate_no_error(); + std::size_t sz = conn.async_read_some_rows(st2, boost::span(storage), token)(as_netresult).get(); + BOOST_TEST(sz == 1u); +#endif + + // Prepare & close statement + auto stmt = conn.async_prepare_statement("SELECT ?", token)(as_netresult).get(); + conn.async_close_statement(stmt, token)(as_netresult).validate_no_error(); + + // Reset connection & ping + conn.async_reset_connection(token)(as_netresult).validate_no_error(); + conn.async_ping(token)(as_netresult).validate_no_error(); + + // Close + conn.async_close(token)(as_netresult).validate_no_error(); +} + BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/test/integration/test/connection_pool.cpp b/test/integration/test/connection_pool.cpp index 0e84597f2..2f1ef5068 100644 --- a/test/integration/test/connection_pool.cpp +++ b/test/integration/test/connection_pool.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -617,6 +618,28 @@ BOOST_FIXTURE_TEST_CASE(zero_timeuts, fixture) }); } +// Spotcheck: we can use completion tokens that require +// initiations to have a bound executor, like cancel_after +BOOST_FIXTURE_TEST_CASE(cancel_after, fixture) +{ + run_stackful_coro([&](asio::yield_context yield) { + constexpr std::chrono::seconds timeout(10); + + connection_pool pool(yield.get_executor(), create_pool_params()); + pool_guard grd(&pool); + pool.async_run(asio::cancel_after(timeout, check_err)); + + // Get a connection + auto conn = pool.async_get_connection(diag, asio::cancel_after(timeout, yield[ec])); + check_success(); + conn->async_ping(yield); + + // The overload with a timeout also works + conn = pool.async_get_connection(timeout, diag, asio::cancel_after(timeout, yield[ec])); + conn->async_ping(yield); + }); +} + // Spotcheck: constructing a connection_pool with invalid params throws BOOST_AUTO_TEST_CASE(invalid_params) {