diff --git a/src/spider/CMakeLists.txt b/src/spider/CMakeLists.txt index da1e3f23..cd1195e8 100644 --- a/src/spider/CMakeLists.txt +++ b/src/spider/CMakeLists.txt @@ -68,6 +68,8 @@ set(SPIDER_WORKER_SOURCES worker/ChildPid.cpp worker/DllLoader.hpp worker/DllLoader.cpp + worker/ExecutorHandle.hpp + worker/ExecutorHandle.cpp worker/Process.hpp worker/Process.cpp worker/TaskExecutor.hpp diff --git a/src/spider/client/Job.hpp b/src/spider/client/Job.hpp index 36df1739..a3bf4526 100644 --- a/src/spider/client/Job.hpp +++ b/src/spider/client/Job.hpp @@ -84,7 +84,20 @@ class Job { * * @throw spider::ConnectionException */ - auto cancel(); + auto cancel() -> void { + std::variant, core::StorageErr> conn_result + = m_storage_factory->provide_storage_connection(); + if (std::holds_alternative(conn_result)) { + throw ConnectionException(std::get(conn_result).description); + } + auto conn = std::move(std::get>(conn_result)); + + core::StorageErr const err + = m_metadata_storage->cancel_job_by_user(*conn, m_id, "Job cancelled by user."); + if (!err.success()) { + throw ConnectionException(err.description); + } + } /** * @return Status of the job. @@ -146,16 +159,41 @@ class Job { } /** - * NOTE: It is undefined behavior to call this method for a job that is not in the `Failed` + * NOTE: It is undefined behavior to call this method for a job that is not in the `Cancelled` * state. * * @return A pair: - * - the name of the task function that failed. - * - the error message sent from the task through `TaskContext::abort` or from Spider. + * - the name of the task function that called `TaskContext::abort`, or "user" if job is + * cancelled by calling `Job::cancel`. + * - the error message sent from the task through `TaskContext::abort` or "Job cancelled by + * user." if job is cancelled through `Job::cancel`. 
* @throw spider::ConnectionException */ auto get_error() -> std::pair { - throw ConnectionException{"Not implemented"}; + if (nullptr == m_conn) { + std::variant, core::StorageErr> conn_result + = m_storage_factory->provide_storage_connection(); + if (std::holds_alternative(conn_result)) { + throw ConnectionException(std::get(conn_result).description); + } + auto conn = std::move(std::get>(conn_result)); + + std::pair res; + core::StorageErr const err + = m_metadata_storage->get_error_message(*conn, m_id, &res.first, &res.second); + if (false == err.success()) { + throw ConnectionException{err.description}; + } + return res; + } + + std::pair res; + core::StorageErr const err + = m_metadata_storage->get_error_message(*m_conn, m_id, &res.first, &res.second); + if (false == err.success()) { + throw ConnectionException{err.description}; + } + return res; } private: diff --git a/src/spider/client/TaskContext.cpp b/src/spider/client/TaskContext.cpp index 64f1d788..81fa140a 100644 --- a/src/spider/client/TaskContext.cpp +++ b/src/spider/client/TaskContext.cpp @@ -1,5 +1,6 @@ #include "TaskContext.hpp" +#include #include #include #include @@ -69,4 +70,19 @@ auto TaskContext::get_jobs() -> std::vector { } return job_ids; } + +auto TaskContext::abort(std::string const& message) -> void { + std::variant, core::StorageErr> conn_result + = m_storage_factory->provide_storage_connection(); + if (std::holds_alternative(conn_result)) { + throw ConnectionException(std::get(conn_result).description); + } + auto conn = std::move(std::get>(conn_result)); + + core::StorageErr const err = m_metadata_store->cancel_job_by_task(*conn, m_task_id, message); + if (!err.success()) { + throw ConnectionException(err.description); + } + std::quick_exit(1); +} } // namespace spider diff --git a/src/spider/storage/MetadataStorage.hpp b/src/spider/storage/MetadataStorage.hpp index 39a3e437..4fa33fa0 100644 --- a/src/spider/storage/MetadataStorage.hpp +++ b/src/spider/storage/MetadataStorage.hpp @@ 
-75,6 +75,47 @@ class MetadataStorage { std::vector* job_ids ) -> StorageErr = 0; + /** + * Cancel a job by user. Set the job state to CANCEL and set all tasks that have not finished + * or started to CANCEL. Set the error message of the job and offender to "user". + * @param conn + * @param id The job id. + * @param message The error message of the cancellation. + * @return The error code. + */ + virtual auto + cancel_job_by_user(StorageConnection& conn, boost::uuids::uuid id, std::string const& message) + -> StorageErr + = 0; + /** + * Cancel the job from the task. Set the job state to CANCEL and set all tasks that have not + * finished or started to CANCEL. Set the error message of the job and offender to the function + * name of the task. + * @param conn + * @param id The task id. + * @param message The error message of the cancellation. + * @return The error code. + */ + virtual auto + cancel_job_by_task(StorageConnection& conn, boost::uuids::uuid id, std::string const& message) + -> StorageErr + = 0; + /** + * Get the error message of a cancelled job. + * @param conn + * @param id The job id. + * @param offender The function name of the cancelling task if job cancelled by task, "user" if + * the job is cancelled by user. + * @param message The error message of the cancellation. + * @return The error code. 
+ */ + virtual auto get_error_message( + StorageConnection& conn, + boost::uuids::uuid id, + std::string* offender, + std::string* message + ) -> StorageErr + = 0; virtual auto remove_job(StorageConnection& conn, boost::uuids::uuid id) noexcept -> StorageErr = 0; virtual auto reset_job(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr = 0; @@ -96,6 +137,9 @@ class MetadataStorage { virtual auto set_task_state(StorageConnection& conn, boost::uuids::uuid id, TaskState state) -> StorageErr = 0; + virtual auto get_task_state(StorageConnection& conn, boost::uuids::uuid id, TaskState* state) + -> StorageErr + = 0; virtual auto set_task_running(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr = 0; virtual auto add_task_instance(StorageConnection& conn, TaskInstance const& instance) -> StorageErr diff --git a/src/spider/storage/mysql/MySqlStorage.cpp b/src/spider/storage/mysql/MySqlStorage.cpp index a7c321d9..c8aaf227 100644 --- a/src/spider/storage/mysql/MySqlStorage.cpp +++ b/src/spider/storage/mysql/MySqlStorage.cpp @@ -1076,6 +1076,138 @@ auto MySqlMetadataStorage::get_jobs_by_client_id( return StorageErr{}; } +auto MySqlMetadataStorage::cancel_job_by_user( + StorageConnection& conn, + boost::uuids::uuid const id, + std::string const& message +) -> StorageErr { + try { + // Set all pending/ready/running tasks from the job to cancelled + std::unique_ptr task_statement( + static_cast(conn)->prepareStatement( + "UPDATE `tasks` SET `state` = 'cancel' WHERE `job_id` = ? AND " + "`state` IN ('pending', 'ready', 'running')" + ) + ); + sql::bytes id_bytes = uuid_get_bytes(id); + task_statement->setBytes(1, &id_bytes); + task_statement->executeUpdate(); + // Set job state to cancelled + std::unique_ptr job_statement( + static_cast(conn)->prepareStatement( + "UPDATE `jobs` SET `state` = 'cancel' WHERE `id` = ?" 
+ ) + ); + job_statement->setBytes(1, &id_bytes); + job_statement->executeUpdate(); + // Set the cancel message + std::unique_ptr message_statement( + static_cast(conn)->prepareStatement( + "INSERT INTO `job_errors` (`job_id`, `offender`, `message`) VALUES (?, ?, " + "?) " + ) + ); + message_statement->setBytes(1, &id_bytes); + message_statement->setString(2, "user"); + message_statement->setString(3, message); + message_statement->executeUpdate(); + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + +auto MySqlMetadataStorage::cancel_job_by_task( + StorageConnection& conn, + boost::uuids::uuid id, + std::string const& message +) -> StorageErr { + try { + // Get job id + sql::bytes task_id_bytes = uuid_get_bytes(id); + std::unique_ptr statement( + static_cast(conn)->prepareStatement( + "SELECT `job_id`, `func_name` FROM `tasks` WHERE `id` = ?" + ) + ); + statement->setBytes(1, &task_id_bytes); + std::unique_ptr const res{statement->executeQuery()}; + if (res->rowsCount() == 0) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::KeyNotFoundErr, "No task with id"}; + } + res->next(); + boost::uuids::uuid const job_id = read_id(res->getBinaryStream("job_id")); + sql::bytes job_id_bytes = uuid_get_bytes(job_id); + std::string const function_name = get_sql_string(res->getString("func_name")); + // Set all pending/ready/running tasks from the job to cancelled + std::unique_ptr task_statement( + static_cast(conn)->prepareStatement( + "UPDATE `tasks` SET `state` = 'cancel' WHERE `job_id` = ? AND " + "`state` IN ('pending', 'ready', 'running')" + ) + ); + task_statement->setBytes(1, &job_id_bytes); + task_statement->executeUpdate(); + // Set job state to cancelled + std::unique_ptr job_statement( + static_cast(conn)->prepareStatement( + "UPDATE `jobs` SET `state` = 'cancel' WHERE `id` = ?" 
+ ) + ); + job_statement->setBytes(1, &job_id_bytes); + job_statement->executeUpdate(); + // Set the cancel message + std::unique_ptr message_statement( + static_cast(conn)->prepareStatement( + "INSERT INTO `job_errors` (`job_id`, `offender`, `message`) VALUES (?, ?, " + "?) " + ) + ); + message_statement->setBytes(1, &job_id_bytes); + message_statement->setString(2, function_name); + message_statement->setString(3, message); + message_statement->executeUpdate(); + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + +auto MySqlMetadataStorage::get_error_message( + StorageConnection& conn, + boost::uuids::uuid const id, + std::string* offender, + std::string* message +) -> StorageErr { + try { + std::unique_ptr statement{ + static_cast(conn)->prepareStatement( + "SELECT `offender`, `message` FROM `job_errors` WHERE `job_id` = ?" + ) + }; + sql::bytes id_bytes = uuid_get_bytes(id); + statement->setBytes(1, &id_bytes); + std::unique_ptr const res{statement->executeQuery()}; + if (res->rowsCount() == 0) { + static_cast(conn)->commit(); + return StorageErr{StorageErrType::KeyNotFoundErr, "No messages found"}; + } + res->next(); + *offender = get_sql_string(res->getString("offender")); + *message = get_sql_string(res->getString("message")); + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + auto MySqlMetadataStorage::remove_job(StorageConnection& conn, boost::uuids::uuid id) noexcept -> StorageErr { try { @@ -1422,6 +1554,39 @@ auto MySqlMetadataStorage::set_task_state( return StorageErr{}; } +auto MySqlMetadataStorage::get_task_state( + StorageConnection& conn, + boost::uuids::uuid const id, + TaskState* state +) -> StorageErr { + try { + // Get the state of the task + std::unique_ptr statement( + 
static_cast(conn)->prepareStatement( + "SELECT `state` FROM `tasks` WHERE `id` = ?" + ) + ); + sql::bytes id_bytes = uuid_get_bytes(id); + statement->setBytes(1, &id_bytes); + std::unique_ptr const res(statement->executeQuery()); + if (res->rowsCount() == 0) { + static_cast(conn)->commit(); + return StorageErr{ + StorageErrType::KeyNotFoundErr, + fmt::format("No task with id {} ", boost::uuids::to_string(id)) + }; + } + res->next(); + std::string const state_str = get_sql_string(res->getString("state")); + *state = string_to_task_state(state_str); + } catch (sql::SQLException& e) { + static_cast(conn)->rollback(); + return StorageErr{StorageErrType::OtherErr, e.what()}; + } + static_cast(conn)->commit(); + return StorageErr{}; +} + auto MySqlMetadataStorage::set_task_running(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr { try { @@ -1705,11 +1870,16 @@ auto MySqlMetadataStorage::task_fail( // Set the task fail if the last task instance fails std::unique_ptr const task_statement( static_cast(conn)->prepareStatement( - "UPDATE `tasks` SET `state` = 'fail' WHERE `id` = ?" + "UPDATE `tasks` SET `state` = 'fail' WHERE `id` = ? 
AND `state` = " + "'running'" ) ); task_statement->setBytes(1, &task_id_bytes); - task_statement->executeUpdate(); + int32_t const task_count = task_statement->executeUpdate(); + if (task_count == 0) { + static_cast(conn)->commit(); + return StorageErr{}; + } // Set the job fails std::unique_ptr const job_statement( static_cast(conn)->prepareStatement( diff --git a/src/spider/storage/mysql/MySqlStorage.hpp b/src/spider/storage/mysql/MySqlStorage.hpp index 90808dac..0987b6b5 100644 --- a/src/spider/storage/mysql/MySqlStorage.hpp +++ b/src/spider/storage/mysql/MySqlStorage.hpp @@ -73,6 +73,18 @@ class MySqlMetadataStorage : public MetadataStorage { boost::uuids::uuid client_id, std::vector* job_ids ) -> StorageErr override; + auto + cancel_job_by_user(StorageConnection& conn, boost::uuids::uuid id, std::string const& message) + -> StorageErr override; + auto + cancel_job_by_task(StorageConnection& conn, boost::uuids::uuid id, std::string const& message) + -> StorageErr override; + auto get_error_message( + StorageConnection& conn, + boost::uuids::uuid id, + std::string* offender, + std::string* message + ) -> StorageErr override; auto remove_job(StorageConnection& conn, boost::uuids::uuid id) noexcept -> StorageErr override; auto reset_job(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr override; auto add_child(StorageConnection& conn, boost::uuids::uuid parent_id, Task const& child) @@ -88,6 +100,8 @@ class MySqlMetadataStorage : public MetadataStorage { ) -> StorageErr override; auto set_task_state(StorageConnection& conn, boost::uuids::uuid id, TaskState state) -> StorageErr override; + auto get_task_state(StorageConnection& conn, boost::uuids::uuid id, TaskState* state) + -> StorageErr override; auto set_task_running(StorageConnection& conn, boost::uuids::uuid id) -> StorageErr override; auto add_task_instance(StorageConnection& conn, TaskInstance const& instance) -> StorageErr override; diff --git a/src/spider/storage/mysql/mysql_stmt.hpp 
b/src/spider/storage/mysql/mysql_stmt.hpp index f699bf51..5ab528bc 100644 --- a/src/spider/storage/mysql/mysql_stmt.hpp +++ b/src/spider/storage/mysql/mysql_stmt.hpp @@ -32,6 +32,14 @@ std::string const cCreateJobTable = R"(CREATE TABLE IF NOT EXISTS jobs ( PRIMARY KEY (`id`) ))"; +std::string const cCreateJobErrorTable = R"(CREATE TABLE IF NOT EXISTS `job_errors` ( + `job_id` BINARY(16) NOT NULL, + `offender` VARCHAR(64) NOT NULL, + `message` VARCHAR(999) NOT NULL, + CONSTRAINT `job_error_job_id` FOREIGN KEY (`job_id`) REFERENCES `jobs` (`id`) ON UPDATE NO ACTION ON DELETE CASCADE, + PRIMARY KEY (`job_id`) +))"; + std::string const cCreateTaskTable = R"(CREATE TABLE IF NOT EXISTS tasks ( `id` BINARY(16) NOT NULL, `job_id` BINARY(16) NOT NULL, @@ -167,10 +175,11 @@ std::string const cCreateTaskKVDataTable = R"(CREATE TABLE IF NOT EXISTS `task_k CONSTRAINT `kv_data_task_id` FOREIGN KEY (`task_id`) REFERENCES `tasks` (`id`) ON UPDATE NO ACTION ON DELETE CASCADE ))"; -std::array const cCreateStorage = { +std::array const cCreateStorage = { cCreateDriverTable, // drivers table must be created before data_ref_driver cCreateSchedulerTable, - cCreateJobTable, // jobs table must be created before task + cCreateJobTable, // jobs table must be created before task and job_error + cCreateJobErrorTable, cCreateTaskTable, // tasks table must be created before data_ref_task cCreateDataTable, // data table must be created before task_outputs cCreateDataLocalityTable, diff --git a/src/spider/worker/ExecutorHandle.cpp b/src/spider/worker/ExecutorHandle.cpp new file mode 100644 index 00000000..62dbbe80 --- /dev/null +++ b/src/spider/worker/ExecutorHandle.cpp @@ -0,0 +1,38 @@ +#include "ExecutorHandle.hpp" + +#include +#include + +#include + +#include "TaskExecutor.hpp" + +namespace spider::worker { +auto ExecutorHandle::get_task_id() -> std::optional { + std::lock_guard const lock_guard{m_mutex}; + if (nullptr != m_executor) { + return m_executor->get_task_id(); + } + return 
std::nullopt; +} + +auto ExecutorHandle::cancel_executor() -> void { + std::lock_guard const lock_guard{m_mutex}; + if (nullptr != m_executor) { + m_executor->cancel(); + } +} + +auto ExecutorHandle::set(TaskExecutor* executor) -> void { + std::lock_guard const lock_guard{m_mutex}; + m_executor = executor; +} + +auto ExecutorHandle::clear() -> void { + std::lock_guard const lock_guard{m_mutex}; + m_executor = nullptr; +} + +TaskExecutor* ExecutorHandle::m_executor = nullptr; +std::mutex ExecutorHandle::m_mutex; +} // namespace spider::worker diff --git a/src/spider/worker/ExecutorHandle.hpp b/src/spider/worker/ExecutorHandle.hpp new file mode 100644 index 00000000..279f124c --- /dev/null +++ b/src/spider/worker/ExecutorHandle.hpp @@ -0,0 +1,43 @@ +#ifndef SPIDER_WORKER_EXECUTORHANDLE_HPP +#define SPIDER_WORKER_EXECUTORHANDLE_HPP + +#include +#include + +#include + +#include "TaskExecutor.hpp" + +namespace spider::worker { +/** + * This singleton class acts as a handle for thread-safe access to the task executor and task id. + * It maintains a weak reference to the task executor to prevent multiple destructor calls and + * ensures that access remains valid only while the executor itself is valid. + */ +class ExecutorHandle { +public: + [[nodiscard]] static auto get_task_id() -> std::optional; + static auto cancel_executor() -> void; + static auto set(TaskExecutor* executor) -> void; + static auto clear() -> void; + + // Delete default constructor + ExecutorHandle() = delete; + // Delete copy constructor and assignment operator + ExecutorHandle(ExecutorHandle const&) = delete; + auto operator=(ExecutorHandle const&) -> ExecutorHandle& = delete; + // Delete move constructor and assignment operator + ExecutorHandle(ExecutorHandle&&) = delete; + auto operator=(ExecutorHandle&&) -> ExecutorHandle& = delete; + // Default destructor + ~ExecutorHandle() = default; + +private: + // Do not use std::shared_ptr to avoid calling destructor twice. 
+ static TaskExecutor* m_executor; + + static std::mutex m_mutex; +}; +} // namespace spider::worker + +#endif diff --git a/src/spider/worker/TaskExecutor.cpp b/src/spider/worker/TaskExecutor.cpp index 88c787c2..0d422270 100644 --- a/src/spider/worker/TaskExecutor.cpp +++ b/src/spider/worker/TaskExecutor.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include // IWYU pragma: keep @@ -17,6 +18,10 @@ #include namespace spider::worker { +auto TaskExecutor::get_task_id() const -> boost::uuids::uuid { + return m_task_id; +} + auto TaskExecutor::get_pid() const -> pid_t { return m_process->get_pid(); } @@ -32,16 +37,21 @@ auto TaskExecutor::waiting() -> bool { return TaskExecutorState::Waiting == m_state; } -auto TaskExecutor::succeed() -> bool { +auto TaskExecutor::succeeded() -> bool { std::lock_guard const lock(m_state_mutex); return TaskExecutorState::Succeed == m_state; } -auto TaskExecutor::error() -> bool { +auto TaskExecutor::errored() -> bool { std::lock_guard const lock(m_state_mutex); return TaskExecutorState::Error == m_state; } +auto TaskExecutor::cancelled() -> bool { + std::lock_guard const lock(m_state_mutex); + return TaskExecutorState::Cancelled == m_state; +} + void TaskExecutor::wait() { int const exit_code = m_process->wait(); if (exit_code != 0) { @@ -77,6 +87,12 @@ auto TaskExecutor::process_output_handler() -> boost::asio::awaitable { while (true) { std::optional const response_option = co_await receive_message_async(m_read_pipe); + { + std::lock_guard const lock(m_state_mutex); + if (m_state != TaskExecutorState::Waiting && m_state != TaskExecutorState::Running) { + co_return; + } + } if (!response_option.has_value()) { std::lock_guard const lock(m_state_mutex); m_state = TaskExecutorState::Error; diff --git a/src/spider/worker/TaskExecutor.hpp b/src/spider/worker/TaskExecutor.hpp index 3fc19caf..36a3d9e9 100644 --- a/src/spider/worker/TaskExecutor.hpp +++ b/src/spider/worker/TaskExecutor.hpp @@ -50,7 +50,8 @@ class TaskExecutor { 
boost::process::v2::environment::value> const& environment, Args&&... args ) - : m_read_pipe(context), + : m_task_id(task_id), + m_read_pipe(context), m_write_pipe(context) { std::vector process_args{ "--func", @@ -103,7 +104,8 @@ class TaskExecutor { boost::process::v2::environment::value> const& environment, std::vector const& args_buffers ) - : m_read_pipe(context), + : m_task_id(task_id), + m_read_pipe(context), m_write_pipe(context) { std::vector process_args{ "--func", @@ -157,13 +159,16 @@ class TaskExecutor { auto completed() -> bool; auto waiting() -> bool; - auto succeed() -> bool; - auto error() -> bool; + auto succeeded() -> bool; + auto errored() -> bool; + auto cancelled() -> bool; void wait(); void cancel(); + [[nodiscard]] auto get_task_id() const -> boost::uuids::uuid; + template auto get_result() const -> std::optional { return core::response_get_result(m_result_buffer); @@ -176,6 +181,8 @@ class TaskExecutor { private: auto process_output_handler() -> boost::asio::awaitable; + boost::uuids::uuid m_task_id; + std::mutex m_state_mutex; std::condition_variable m_complete_cv; TaskExecutorState m_state = TaskExecutorState::Running; diff --git a/src/spider/worker/worker.cpp b/src/spider/worker/worker.cpp index ade6ce6b..7245e0d2 100644 --- a/src/spider/worker/worker.cpp +++ b/src/spider/worker/worker.cpp @@ -47,6 +47,7 @@ #include #include #include +#include #include #include @@ -123,6 +124,36 @@ auto get_environment_variable() -> absl::flat_hash_map< return environment_variables; } +/** + * Checks if the task is cancelled. If the task state is set to cancelled, cancels the running task. + * @param conn The storage connection to use. 
+ * @param metadata_store + */ +auto check_task_cancel( + std::shared_ptr const& conn, + std::shared_ptr const& metadata_store +) -> void { + std::optional const optional_task_id + = spider::worker::ExecutorHandle::get_task_id(); + if (!optional_task_id.has_value()) { + return; + } + boost::uuids::uuid const task_id = optional_task_id.value(); + + spider::core::TaskState task_state = spider::core::TaskState::Running; + spider::core::StorageErr err = metadata_store->get_task_state(*conn, task_id, &task_state); + if (false == err.success()) { + spdlog::error("Failed to get task state: {}", err.description); + return; + } + + if (spider::core::TaskState::Canceled != task_state) { + return; + } + + spider::worker::ExecutorHandle::cancel_executor(); +} + auto heartbeat_loop( std::shared_ptr const& storage_factory, std::shared_ptr const& metadata_store, @@ -131,7 +162,8 @@ auto heartbeat_loop( int fail_count = 0; while (!spider::core::StopFlag::is_stop_requested()) { std::this_thread::sleep_for(std::chrono::seconds(1)); - spdlog::debug("Updating heartbeat"); + + // Getting a storage connection std::variant, spider::core::StorageErr> conn_result = storage_factory->provide_storage_connection(); if (std::holds_alternative(conn_result)) { @@ -142,10 +174,13 @@ auto heartbeat_loop( fail_count++; continue; } - auto conn = std::move( + std::shared_ptr const conn = std::move( std::get>(conn_result) ); + check_task_cancel(conn, metadata_store); + + spdlog::debug("Updating heartbeat"); spider::core::StorageErr const err = metadata_store->update_heartbeat(*conn, driver.get_id()); if (!err.success()) { @@ -283,7 +318,7 @@ auto handle_executor_result( } auto conn = std::move(std::get>(conn_result)); - if (!executor.succeed()) { + if (!executor.succeeded()) { spdlog::warn("Task {} failed", task.get_function_name()); metadata_store->task_fail( *conn, @@ -386,6 +421,8 @@ auto task_loop( arg_buffers }; + spider::worker::ExecutorHandle::set(&executor); + pid_t const pid = 
executor.get_pid(); spider::core::ChildPid::set_pid(pid); // Double check if stop token is set to avoid any missing signal @@ -397,8 +434,18 @@ auto task_loop( context.run(); executor.wait(); + spider::worker::ExecutorHandle::clear(); spider::core::ChildPid::set_pid(0); + if (executor.cancelled()) { + // If the task is cancelled by the user or other tasks, the states have already been + // updated in the storage, so there's no need to do anything. + // If the task is cancelled by calling `TaskContext::abort`, the storage has also been + // updated, so again, no further action is needed. + spdlog::debug("Task {} was cancelled", task.get_function_name()); + continue; + } + if (handle_executor_result(storage_factory, metadata_store, instance, task, executor)) { fail_task_id = std::nullopt; } else { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a04b1c58..da25d866 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -86,6 +86,18 @@ target_link_libraries( spider_client ) +add_executable(cancel_test) +target_sources(cancel_test PRIVATE client/cancel-test.cpp) +target_link_libraries( + cancel_test + PRIVATE + spider_core + spider_client + worker_test + Boost::program_options + spdlog::spdlog +) + add_custom_target(integrationTest ALL) add_custom_command( TARGET integrationTest @@ -99,4 +111,5 @@ add_dependencies( worker_test client_test signal_test + cancel_test ) diff --git a/tests/client/cancel-test.cpp b/tests/client/cancel-test.cpp new file mode 100644 index 00000000..7207cbe1 --- /dev/null +++ b/tests/client/cancel-test.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: keep +#include + +#include +#include +#include + +namespace { +auto parse_args(int const argc, char** argv) -> boost::program_options::variables_map { + boost::program_options::options_description desc; + desc.add_options()("help", "spider client test"); + 
desc.add_options()( + "storage_url", + boost::program_options::value(), + "storage server url" + ); + + boost::program_options::variables_map variables; + boost::program_options::store( + // NOLINTNEXTLINE(misc-include-cleaner) + boost::program_options::parse_command_line(argc, argv, desc), + variables + ); + boost::program_options::notify(variables); + return variables; +} + +constexpr int cCmdArgParseErr = 1; +constexpr int cJobNotCancelled = 2; +constexpr int cWrongErrorMessage = 3; +constexpr int cException = 4; + +auto test_user_cancel(spider::Driver& driver) -> int { + spider::Job sleep_job = driver.start(&sleep_test, 3); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + sleep_job.cancel(); + + sleep_job.wait_complete(); + if (spider::JobStatus::Cancelled != sleep_job.get_status()) { + spdlog::error("Sleep job status is not cancelled"); + return cJobNotCancelled; + } + + std::pair const job_errors = sleep_job.get_error(); + if ("user" != job_errors.first) { + spdlog::error("User job cancel wrong name"); + return cWrongErrorMessage; + } + if ("Job cancelled by user." 
!= job_errors.second) { + spdlog::error("User job cancel wrong error message"); + return cWrongErrorMessage; + } + + return 0; +} + +auto test_task_cancel(spider::Driver& driver) -> int { + spider::Job abort_job = driver.start(&abort_test, 0); + abort_job.wait_complete(); + if (spider::JobStatus::Cancelled != abort_job.get_status()) { + spdlog::error("Abort job status is not cancelled"); + return cJobNotCancelled; + } + std::pair const job_errors = abort_job.get_error(); + if ("abort_test" != job_errors.first) { + spdlog::error("Cancelled task wrong function name"); + return cWrongErrorMessage; + } + if ("Abort test" != job_errors.second) { + spdlog::error("Cancelled task wrong error message"); + return cWrongErrorMessage; + } + + return 0; +} +} // namespace + +auto main(int argc, char** argv) -> int { + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] [spider.client] %v"); +#ifndef NDEBUG + spdlog::set_level(spdlog::level::trace); +#endif + + boost::program_options::variables_map const args = parse_args(argc, argv); + + std::string storage_url; + try { + if (!args.contains("storage_url")) { + spdlog::error("storage_url is required"); + return cCmdArgParseErr; + } + storage_url = args["storage_url"].as(); + } catch (boost::bad_any_cast& e) { + return cCmdArgParseErr; + } catch (boost::program_options::error& e) { + return cCmdArgParseErr; + } + + try { + spider::Driver driver{storage_url}; + spdlog::debug("Driver created"); + + int result = test_user_cancel(driver); + if (0 != result) { + return result; + } + + result = test_task_cancel(driver); + if (0 != result) { + return result; + } + } catch (std::exception& e) { + spdlog::error("Exception: {}", e.what()); + return cException; + } + return 0; +} diff --git a/tests/integration/test_cancel.py b/tests/integration/test_cancel.py new file mode 100644 index 00000000..01d428b5 --- /dev/null +++ b/tests/integration/test_cancel.py @@ -0,0 +1,85 @@ +import subprocess +import time +import uuid +from pathlib import 
Path +from typing import Tuple + +import msgpack +import pytest + +from .client import ( + add_driver, + add_driver_data, + Data, + Driver, + g_storage_url, + get_task_outputs, + get_task_state, + remove_data, + remove_job, + storage, + submit_job, + Task, + TaskGraph, + TaskInput, + TaskOutput, +) +from .utils import g_scheduler_port + + +def start_scheduler_worker( + storage_url: str, scheduler_port: int +) -> Tuple[subprocess.Popen, subprocess.Popen]: + # Start the scheduler + dir_path = Path(__file__).resolve().parent + dir_path = dir_path / ".." / ".." / "src" / "spider" + scheduler_cmds = [ + str(dir_path / "spider_scheduler"), + "--host", + "127.0.0.1", + "--port", + str(scheduler_port), + "--storage_url", + storage_url, + ] + scheduler_process = subprocess.Popen(scheduler_cmds) + worker_cmds = [ + str(dir_path / "spider_worker"), + "--host", + "127.0.0.1", + "--storage_url", + storage_url, + "--libs", + "tests/libworker_test.so", + ] + worker_process = subprocess.Popen(worker_cmds) + return scheduler_process, worker_process + + +@pytest.fixture(scope="class") +def scheduler_worker(storage): + scheduler_process, worker_process = start_scheduler_worker( + storage_url=g_storage_url, scheduler_port=g_scheduler_port + ) + # Wait for 5 seconds to make sure the scheduler and worker are started + time.sleep(5) + yield + scheduler_process.kill() + worker_process.kill() + + +class TestCancel: + + # Test that the task can be cancelled by user and from the task. + # Execute the cancel_test client, which includes cancelling a running task + # and executing a task that cancels itself. + def test_task_cancel(self, scheduler_worker): + dir_path = Path(__file__).resolve().parent + dir_path = dir_path / ".." 
+ client_cmds = [ + str(dir_path / "cancel_test"), + "--storage_url", + g_storage_url, + ] + p = subprocess.run(client_cmds, timeout=20) + assert p.returncode == 0 diff --git a/tests/storage/test-MetadataStorage.cpp b/tests/storage/test-MetadataStorage.cpp index 40ce7472..1da37e61 100644 --- a/tests/storage/test-MetadataStorage.cpp +++ b/tests/storage/test-MetadataStorage.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -448,6 +449,143 @@ TEMPLATE_LIST_TEST_CASE("Task finish", "[storage]", spider::test::StorageFactory REQUIRE(storage->remove_job(*conn, job_id).success()); } +/** + * Creates a test job with a task dependency graph for cancellation tests. + * + * Task graph structure: + * parent_1 (p1) ──┐ + * ├──> child_task + * parent_2 (p2) ──┘ + * + * Task states after setup: + * - parent_1: Succeeded (with output "1.1") + * - parent_2: Ready + * - child_task: Pending (waiting for parent_2 to complete) + * + * @param storage + * @param conn + * @return A tuple containing (job_id, parent_1_id, parent_2_id, and child_task_id). 
+ */ +auto job_cancel_setup( + std::unique_ptr& storage, + std::unique_ptr& conn +) -> std::tuple { + boost::uuids::random_generator gen; + boost::uuids::uuid const job_id = gen(); + + spider::core::Task child_task{"child"}; + spider::core::Task parent_1{"p1"}; + spider::core::Task parent_2{"p2"}; + parent_1.add_input(spider::core::TaskInput{"1", "float"}); + parent_1.add_input(spider::core::TaskInput{"2", "float"}); + parent_2.add_input(spider::core::TaskInput{"3", "int"}); + parent_2.add_input(spider::core::TaskInput{"4", "int"}); + parent_1.add_output(spider::core::TaskOutput{"float"}); + parent_2.add_output(spider::core::TaskOutput{"int"}); + child_task.add_input(spider::core::TaskInput{parent_1.get_id(), 0, "float"}); + child_task.add_input(spider::core::TaskInput{parent_2.get_id(), 0, "int"}); + child_task.add_output(spider::core::TaskOutput{"float"}); + parent_1.set_max_retries(1); + parent_2.set_max_retries(1); + child_task.set_max_retries(1); + spider::core::TaskGraph graph; + graph.add_task(child_task); + graph.add_task(parent_1); + graph.add_task(parent_2); + graph.add_dependency(parent_2.get_id(), child_task.get_id()); + graph.add_dependency(parent_1.get_id(), child_task.get_id()); + graph.add_input_task(parent_1.get_id()); + graph.add_input_task(parent_2.get_id()); + graph.add_output_task(child_task.get_id()); + REQUIRE(storage->add_job(*conn, job_id, gen(), graph).success()); + + spider::core::TaskInstance const parent_1_instance{gen(), parent_1.get_id()}; + REQUIRE(storage->set_task_state(*conn, parent_1.get_id(), spider::core::TaskState::Running) + .success()); + REQUIRE(storage->task_finish( + *conn, + parent_1_instance, + {spider::core::TaskOutput{"1.1", "float"}} + ) + .success()); + + return std::make_tuple(job_id, parent_1.get_id(), parent_2.get_id(), child_task.get_id()); +} + +TEMPLATE_LIST_TEST_CASE("Job cancel by user", "[storage]", spider::test::StorageFactoryTypeList) { + std::unique_ptr storage_factory + = 
spider::test::create_storage_factory(); + std::unique_ptr storage + = storage_factory->provide_metadata_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + auto conn = std::move(std::get>(conn_result)); + + auto [job_id, parent_1_id, parent_2_id, child_id] = job_cancel_setup(storage, conn); + + std::string const error_message = "Job cancelled by user."; + REQUIRE(storage->cancel_job_by_user(*conn, job_id, error_message).success()); + spider::core::JobStatus job_status = spider::core::JobStatus::Running; + REQUIRE(storage->get_job_status(*conn, job_id, &job_status).success()); + REQUIRE(spider::core::JobStatus::Cancelled == job_status); + std::string error_task_res; + std::string error_message_res; + REQUIRE( + storage->get_error_message(*conn, job_id, &error_task_res, &error_message_res).success() + ); + REQUIRE("user" == error_task_res); + REQUIRE(error_message == error_message_res); + + spider::core::TaskState task_state = spider::core::TaskState::Running; + REQUIRE(storage->get_task_state(*conn, parent_1_id, &task_state).success()); + REQUIRE(spider::core::TaskState::Succeed == task_state); + REQUIRE(storage->get_task_state(*conn, parent_2_id, &task_state).success()); + REQUIRE(spider::core::TaskState::Canceled == task_state); + REQUIRE(storage->get_task_state(*conn, child_id, &task_state).success()); + REQUIRE(spider::core::TaskState::Canceled == task_state); + + REQUIRE(storage->remove_job(*conn, job_id).success()); +} + +TEMPLATE_LIST_TEST_CASE("Job cancel by task", "[storage]", spider::test::StorageFactoryTypeList) { + std::unique_ptr storage_factory + = spider::test::create_storage_factory(); + std::unique_ptr storage + = storage_factory->provide_metadata_storage(); + + std::variant, spider::core::StorageErr> + conn_result = storage_factory->provide_storage_connection(); + REQUIRE(std::holds_alternative>(conn_result)); + auto conn = 
std::move(std::get>(conn_result)); + + auto [job_id, parent_1_id, parent_2_id, child_id] = job_cancel_setup(storage, conn); + + std::string const error_message = "test error message"; + REQUIRE(storage->cancel_job_by_task(*conn, parent_2_id, error_message).success()); + spider::core::JobStatus job_status = spider::core::JobStatus::Running; + REQUIRE(storage->get_job_status(*conn, job_id, &job_status).success()); + REQUIRE(spider::core::JobStatus::Cancelled == job_status); + std::string error_task_res; + std::string error_message_res; + REQUIRE( + storage->get_error_message(*conn, job_id, &error_task_res, &error_message_res).success() + ); + REQUIRE("p2" == error_task_res); + REQUIRE(error_message == error_message_res); + + spider::core::TaskState task_state = spider::core::TaskState::Running; + REQUIRE(storage->get_task_state(*conn, parent_1_id, &task_state).success()); + REQUIRE(spider::core::TaskState::Succeed == task_state); + REQUIRE(storage->get_task_state(*conn, parent_2_id, &task_state).success()); + REQUIRE(spider::core::TaskState::Canceled == task_state); + REQUIRE(storage->get_task_state(*conn, child_id, &task_state).success()); + REQUIRE(spider::core::TaskState::Canceled == task_state); + + REQUIRE(storage->remove_job(*conn, job_id).success()); +} + TEMPLATE_LIST_TEST_CASE("Job reset", "[storage]", spider::test::StorageFactoryTypeList) { std::unique_ptr storage_factory = spider::test::create_storage_factory(); diff --git a/tests/worker/test-TaskExecutor.cpp b/tests/worker/test-TaskExecutor.cpp index d7d27939..4245a544 100644 --- a/tests/worker/test-TaskExecutor.cpp +++ b/tests/worker/test-TaskExecutor.cpp @@ -88,7 +88,7 @@ TEMPLATE_LIST_TEST_CASE( }; context.run(); executor.wait(); - REQUIRE(executor.succeed()); + REQUIRE(executor.succeeded()); std::optional const result_option = executor.get_result(); REQUIRE(result_option.has_value()); REQUIRE(5 == result_option.value_or(0)); @@ -119,7 +119,7 @@ TEMPLATE_LIST_TEST_CASE( }; context.run(); 
executor.wait(); - REQUIRE(executor.error()); + REQUIRE(executor.errored()); std::tuple error = executor.get_error(); REQUIRE(spider::core::FunctionInvokeError::WrongNumberOfArguments == std::get<0>(error)); } @@ -149,7 +149,7 @@ TEMPLATE_LIST_TEST_CASE( }; context.run(); executor.wait(); - REQUIRE(executor.error()); + REQUIRE(executor.errored()); std::tuple error = executor.get_error(); REQUIRE(spider::core::FunctionInvokeError::FunctionExecutionError == std::get<0>(error)); } @@ -211,7 +211,7 @@ TEMPLATE_LIST_TEST_CASE( }; context.run(); executor.wait(); - REQUIRE(executor.succeed()); + REQUIRE(executor.succeeded()); std::optional const optional_result = executor.get_result(); REQUIRE(optional_result.has_value()); if (optional_result.has_value()) { @@ -255,7 +255,7 @@ TEMPLATE_LIST_TEST_CASE( }; context.run(); executor.wait(); - REQUIRE(executor.succeed()); + REQUIRE(executor.succeeded()); std::optional const result_option = executor.get_result(); REQUIRE(result_option.has_value()); REQUIRE(input_1 + input_2 == result_option.value_or("")); diff --git a/tests/worker/worker-test.cpp b/tests/worker/worker-test.cpp index 1daf725f..21a96291 100644 --- a/tests/worker/worker-test.cpp +++ b/tests/worker/worker-test.cpp @@ -1,9 +1,11 @@ #include "worker-test.hpp" +#include #include #include #include #include +#include #include #include @@ -70,6 +72,16 @@ auto join_string_test( return input_1 + input_2; } +auto sleep_test([[maybe_unused]] spider::TaskContext& context, int milliseconds) -> int { + std::this_thread::sleep_for(std::chrono::milliseconds{milliseconds}); + return milliseconds; +} + +auto abort_test(spider::TaskContext& context, [[maybe_unused]] int x) -> int { + context.abort("Abort test"); + return 0; +} + // NOLINTBEGIN(cert-err58-cpp) SPIDER_REGISTER_TASK(sum_test); SPIDER_REGISTER_TASK(swap_test); @@ -79,4 +91,6 @@ SPIDER_REGISTER_TASK(random_fail_test); SPIDER_REGISTER_TASK(create_data_test); SPIDER_REGISTER_TASK(create_task_test); 
SPIDER_REGISTER_TASK(join_string_test); +SPIDER_REGISTER_TASK(sleep_test); +SPIDER_REGISTER_TASK(abort_test); // NOLINTEND(cert-err58-cpp) diff --git a/tests/worker/worker-test.hpp b/tests/worker/worker-test.hpp index f1624586..d50f03f6 100644 --- a/tests/worker/worker-test.hpp +++ b/tests/worker/worker-test.hpp @@ -27,4 +27,8 @@ auto join_string_test( std::string const& input_2 ) -> std::string; +auto sleep_test(spider::TaskContext& context, int milliseconds) -> int; + +auto abort_test(spider::TaskContext& context, int x) -> int; + #endif diff --git a/tools/scripts/storage/init_db.sql b/tools/scripts/storage/init_db.sql index d56f49c9..3575f56d 100644 --- a/tools/scripts/storage/init_db.sql +++ b/tools/scripts/storage/init_db.sql @@ -23,6 +23,13 @@ CREATE TABLE IF NOT EXISTS jobs INDEX idx_jobs_state (`state`), PRIMARY KEY (`id`) ); +CREATE TABLE IF NOT EXISTS `job_errors` ( + `job_id` BINARY(16) NOT NULL, + `offender` VARCHAR(64) NOT NULL, + `message` VARCHAR(999) NOT NULL, + CONSTRAINT `job_error_job_id` FOREIGN KEY (`job_id`) REFERENCES `jobs` (`id`) ON UPDATE NO ACTION ON DELETE CASCADE, + PRIMARY KEY (`job_id`) +); CREATE TABLE IF NOT EXISTS tasks ( `id` BINARY(16) NOT NULL,