From 6d421d53f04bb5f713447010fa4b74e20dde812e Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Wed, 13 May 2026 23:39:10 -0400 Subject: [PATCH 1/8] Add graceful shutdown & server runtime --- components/spider-storage/Cargo.toml | 2 +- components/spider-storage/src/state.rs | 2 + components/spider-storage/src/state/server.rs | 174 +++++++++ .../spider-storage/src/task_instance_pool.rs | 343 ++++++++++++++---- 4 files changed, 454 insertions(+), 67 deletions(-) create mode 100644 components/spider-storage/src/state/server.rs diff --git a/components/spider-storage/Cargo.toml b/components/spider-storage/Cargo.toml index c4fce35c..c8381ce2 100644 --- a/components/spider-storage/Cargo.toml +++ b/components/spider-storage/Cargo.toml @@ -31,6 +31,7 @@ tokio = { version = "1.50.0", features = [ "sync", "time" ] } +tokio-util = { version = "0.7.18", features = ["rt"] } tracing = { version = "0.1.44", features = ["attributes"] } uuid = { version = "1.19.0", features = ["serde"] } @@ -41,5 +42,4 @@ rand = "0.9.1" serial_test = { version = "3.2.0", features = ["file_locks"] } tabled = "0.20.0" tokio = { version = "1.50.0", features = ["macros", "rt-multi-thread", "sync"] } -tokio-util = { version = "0.7", features = ["rt"] } uuid = { version = "1.19.0", features = ["v4"] } diff --git a/components/spider-storage/src/state.rs b/components/spider-storage/src/state.rs index 4d573778..1afe82fa 100644 --- a/components/spider-storage/src/state.rs +++ b/components/spider-storage/src/state.rs @@ -1,9 +1,11 @@ pub mod error; pub mod job_cache; +pub mod server; pub mod service; pub use error::StorageServerError; pub use job_cache::JobCache; +pub use server::ServerRuntime; pub use service::ServiceState; #[cfg(test)] diff --git a/components/spider-storage/src/state/server.rs b/components/spider-storage/src/state/server.rs new file mode 100644 index 00000000..dcc93e7d --- /dev/null +++ b/components/spider-storage/src/state/server.rs @@ -0,0 +1,174 @@ +use std::time::Duration; + +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; + +use crate::{ + cache::error::{CacheError, InternalError}, + config::DatabaseConfig, + db::{MariaDbStorageConnector, SessionManagement}, + ready_queue::{ReadyQueueConfig, ReadyQueueSenderHandle, create_ready_queue}, + state::{JobCache, ServiceState, StorageServerError}, + task_instance_pool::{ + TaskInstancePoolConfig, + TaskInstancePoolHandle, + create_task_instance_pool, + }, +}; + +/// Production per-process storage server runtime. +pub struct ServerRuntime { + service_state: + ServiceState, + cancellation_token: CancellationToken, + task_instance_pool_task: JoinHandle>, +} + +impl ServerRuntime { + /// Creates a storage server runtime from the database configuration. + /// + /// # Returns + /// + /// A newly created [`ServerRuntime`] on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * Forwards [`MariaDbStorageConnector::connect`]'s return values on failure. + /// * Forwards [`create_ready_queue`]'s return values on failure. + pub async fn create(db_config: &DatabaseConfig) -> Result { + let cancellation_token = CancellationToken::new(); + let db = MariaDbStorageConnector::connect(db_config).await?; + let session_id = db.session_id(); + let (ready_queue_sender, ready_queue_receiver) = + create_ready_queue(ReadyQueueConfig::default()).map_err(CacheError::from)?; + let (task_instance_pool_connector, task_instance_pool_task) = create_task_instance_pool( + ready_queue_sender.clone(), + db.clone(), + cancellation_token.clone(), + TaskInstancePoolConfig::default(), + ); + let service_state = ServiceState::new( + db, + session_id, + JobCache::new(), + ready_queue_sender, + ready_queue_receiver, + task_instance_pool_connector, + ); + + Ok(Self { + service_state, + cancellation_token, + task_instance_pool_task, + }) + } + + /// # Returns + /// + /// A clone of the runtime's [`ServiceState`]. + #[must_use] + pub fn service_state( + &self, + ) -> ServiceState { + self.service_state.clone() + } + + /// Stops background tasks owned by the runtime. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::Stopping`] if the task instance pool does not stop before timeout. + /// * [`StorageServerError::Cache`] if the task instance pool task fails or cannot be joined. + pub async fn stop_background_tasks(self) -> Result<(), StorageServerError> { + stop_background_task( + self.cancellation_token, + self.task_instance_pool_task, + STOP_BACKGROUND_TASKS_TIMEOUT, + ) + .await + } +} + +const STOP_BACKGROUND_TASKS_TIMEOUT: Duration = Duration::from_secs(30); + +/// Stops a single cancellation-token-controlled background task. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * [`StorageServerError::Stopping`] if the task does not stop before `timeout`. +/// * [`StorageServerError::Cache`] if the task fails or cannot be joined. +async fn stop_background_task( + cancellation_token: CancellationToken, + task: JoinHandle>, + timeout: Duration, +) -> Result<(), StorageServerError> { + cancellation_token.cancel(); + let join_result = tokio::time::timeout(timeout, task).await.map_err(|_| { + StorageServerError::Stopping("task instance pool stop timed out".to_owned()) + })?; + let pool_result = join_result.map_err(|e| { + StorageServerError::Cache(CacheError::Internal( + InternalError::TaskInstancePoolCorrupted(format!("task join error: {e}")), + )) + })?; + pool_result.map_err(|e| StorageServerError::Cache(CacheError::Internal(e))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn stop_background_task_cancels_and_joins_task() -> anyhow::Result<()> { + let cancellation_token = CancellationToken::new(); + let task_cancellation_token = cancellation_token.clone(); + let task = tokio::spawn(async move { + task_cancellation_token.cancelled().await; + Ok(()) + }); + + stop_background_task(cancellation_token, task, Duration::from_secs(1)).await?; + Ok(()) + } + + #[tokio::test] + async fn stop_background_task_returns_stopping_on_timeout() -> anyhow::Result<()> { + let cancellation_token = CancellationToken::new(); + let task = tokio::spawn(async move { + tokio::time::sleep(Duration::from_mins(1)).await; + Ok(()) + }); + + let result = stop_background_task(cancellation_token, task, Duration::from_millis(1)).await; + + assert!( + matches!(result, Err(StorageServerError::Stopping(_))), + "timeout should return Stopping" + ); + Ok(()) + } + + #[tokio::test] + async fn stop_background_task_returns_cache_error_on_pool_error() -> anyhow::Result<()> { + let cancellation_token = CancellationToken::new(); + let task = tokio::spawn(async move { + Err(InternalError::TaskInstancePoolCorrupted( + "test failure".to_owned(), + )) + }); + + let result = stop_background_task(cancellation_token, task, Duration::from_secs(1)).await; + + assert!( + matches!(result, Err(StorageServerError::Cache(_))), + "pool task failure should return Cache error" + ); + Ok(()) + } +} diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index ace45ce6..e994a0e6 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -24,7 +24,11 @@ use std::{ use async_trait::async_trait; use spider_core::types::id::{ExecutionManagerId, JobId, ResourceGroupId, TaskInstanceId}; -use tokio::sync::mpsc; +use tokio::{ + sync::{mpsc, mpsc::error::TryRecvError}, + task::JoinHandle, +}; +use tokio_util::sync::CancellationToken; use crate::{ cache::{ @@ -32,6 +36,7 @@ use crate::{ error::InternalError, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, + db::{ExecutionManagerLivenessManagement, MariaDbStorageConnector}, ready_queue::ReadyQueueSender, }; @@ -72,16 +77,16 @@ pub trait ExecutionManagerLivenessStore: Clone + Send + Sync { id: &ExecutionManagerId, ) -> Result; - /// Returns the IDs of execution managers whose last heartbeat is before `stale_before`, after - /// marking them dead. + /// Returns the IDs of execution managers whose last heartbeat is older than + /// `stale_after_sec`, after marking them dead. /// /// This operation is atomic: once an execution manager is returned by this method, it will not /// be returned again in subsequent calls. /// /// # Parameters /// - /// * `stale_before` - The cutoff time; execution managers with no heartbeat after this time are - /// considered dead. + /// * `stale_after_sec` - The seconds after the last heartbeat which makes an execution manager + /// stale. /// /// # Returns /// @@ -94,7 +99,7 @@ pub trait ExecutionManagerLivenessStore: Clone + Send + Sync { /// * Forwards the underlying store's return values on failure. async fn get_dead_execution_managers( &self, - stale_before: SystemTime, + stale_after_sec: u64, ) -> Result, InternalError>; } @@ -167,53 +172,96 @@ pub struct TaskInstancePoolHandle { sender: mpsc::Sender, } -impl TaskInstancePoolHandle { - /// Creates a new task instance pool and returns a handle to it. - /// - /// # Type Parameters - /// - /// * `ReadyQueueSenderType` - The ready queue sender implementation for re-enqueue operations. - /// * `LivenessStoreType` - The execution manager liveness store implementation. - /// - /// # Returns - /// - /// A [`TaskInstancePoolHandle`] connected to the newly spawned pool coroutine. - #[must_use] - pub fn create< - ReadyQueueSenderType: ReadyQueueSender + 'static, - LivenessStoreType: ExecutionManagerLivenessStore + 'static, - >( - ready_queue_sender: ReadyQueueSenderType, - execution_manager_liveness_store: LivenessStoreType, - execution_manager_stale_cutoff: Duration, - gc_interval: Duration, - channel_size: usize, - ) -> Self { - let next_task_instance_id = Arc::new(AtomicU64::new(1)); - let (sender, receiver) = mpsc::channel(channel_size); - - let pool = TaskInstancePool { - ready_queue_sender, - execution_manager_liveness_store, - execution_manager_stale_cutoff, - instances: Vec::new(), - execution_manager_pool: HashSet::new(), - receiver, - }; - tokio::spawn(async move { - match pool.run(gc_interval).await { - Ok(()) => {} - Err(_e) => todo!("log this error and terminate the storage service"), - } - }); +/// Configuration for a task instance pool actor. +#[derive(Debug, Clone, Copy)] +pub struct TaskInstancePoolConfig { + pub execution_manager_stale_after_sec: u64, + pub gc_interval: Duration, + pub channel_size: usize, +} +impl Default for TaskInstancePoolConfig { + fn default() -> Self { Self { - next_task_instance_id, - sender, + execution_manager_stale_after_sec: 60, + gc_interval: Duration::from_secs(30), + channel_size: 128, } } } +/// Creates a task instance pool and returns the handle plus the spawned actor task. +/// +/// # Type Parameters +/// +/// * `ReadyQueueSenderType` - The ready queue sender implementation for re-enqueue operations. +/// * `LivenessStoreType` - The execution manager liveness store implementation. +/// +/// # Returns +/// +/// A [`TaskInstancePoolHandle`] and the spawned actor's [`JoinHandle`]. +pub fn create_task_instance_pool< + ReadyQueueSenderType: ReadyQueueSender + 'static, + LivenessStoreType: ExecutionManagerLivenessStore + 'static, +>( + ready_queue_sender: ReadyQueueSenderType, + execution_manager_liveness_store: LivenessStoreType, + cancellation_token: CancellationToken, + config: TaskInstancePoolConfig, +) -> ( + TaskInstancePoolHandle, + JoinHandle>, +) { + let next_task_instance_id = Arc::new(AtomicU64::new(1)); + let (sender, receiver) = mpsc::channel(config.channel_size); + + let pool = TaskInstancePool { + ready_queue_sender, + execution_manager_liveness_store, + execution_manager_stale_after_sec: config.execution_manager_stale_after_sec, + instances: Vec::new(), + execution_manager_pool: HashSet::new(), + receiver, + }; + let pool_task = + tokio::spawn(async move { pool.run(cancellation_token, config.gc_interval).await }); + let handle = TaskInstancePoolHandle { + next_task_instance_id, + sender, + }; + + (handle, pool_task) +} + +#[async_trait] +impl ExecutionManagerLivenessStore for MariaDbStorageConnector { + async fn is_execution_manager_alive( + &self, + id: &ExecutionManagerId, + ) -> Result { + ExecutionManagerLivenessManagement::is_execution_manager_alive(self, *id) + .await + .map_err(|e| { + InternalError::TaskInstancePoolCorrupted(format!( + "failed to check execution manager liveness: {e}" + )) + }) + } + + async fn get_dead_execution_managers( + &self, + stale_after_sec: u64, + ) -> Result, InternalError> { + ExecutionManagerLivenessManagement::get_dead_execution_managers(self, stale_after_sec) + .await + .map_err(|e| { + InternalError::TaskInstancePoolCorrupted(format!( + "failed to get dead execution managers: {e}" + )) + }) + } +} + #[async_trait] impl TaskInstancePoolConnector for TaskInstancePoolHandle { fn get_next_available_task_instance_id(&self) -> TaskInstanceId { @@ -319,7 +367,7 @@ struct TaskInstancePool< ready_queue_sender: ReadyQueueSenderType, execution_manager_liveness_store: LivenessStoreType, execution_manager_pool: HashSet, - execution_manager_stale_cutoff: Duration, + execution_manager_stale_after_sec: u64, instances: Vec, receiver: mpsc::Receiver, } @@ -335,16 +383,23 @@ impl Result<(), InternalError> { + async fn run( + mut self, + cancellation_token: CancellationToken, + gc_interval: Duration, + ) -> Result<(), InternalError> { let mut gc_interval = tokio::time::interval(gc_interval); // The first tick completes immediately; skip it so we don't GC right at startup. gc_interval.tick().await; loop { tokio::select! { + () = cancellation_token.cancelled() => { + self.drain_received_messages().await?; + return Ok(()); + } message = self.receiver.recv() => { let Some(message) = message else { - // TODO: log this exit return Ok(()); }; self.handle_message(message).await?; @@ -356,6 +411,22 @@ impl Result<(), InternalError> { + loop { + match self.receiver.try_recv() { + Ok(message) => self.handle_message(message).await?, + Err(TryRecvError::Empty | TryRecvError::Disconnected) => return Ok(()), + } + } + } + /// Handles a single pool message. /// /// # Errors @@ -423,11 +494,7 @@ impl Result<(), InternalError> { let dead_em_ids: Vec = self .execution_manager_liveness_store - .get_dead_execution_managers( - gc_started_at - .checked_sub(self.execution_manager_stale_cutoff) - .unwrap_or(SystemTime::UNIX_EPOCH), - ) + .get_dead_execution_managers(self.execution_manager_stale_after_sec) .await?; for execution_manager_id in &dead_em_ids { @@ -565,7 +632,7 @@ mod tests { async fn get_dead_execution_managers( &self, - _stale_before: SystemTime, + _stale_after_sec: u64, ) -> Result, InternalError> { Ok(self.dead_execution_managers.lock().await.clone()) } @@ -640,7 +707,7 @@ mod tests { async fn get_dead_execution_managers( &self, - _stale_before: SystemTime, + _stale_after_sec: u64, ) -> Result, InternalError> { Ok(Vec::new()) } @@ -709,7 +776,7 @@ mod tests { ready_queue_sender, execution_manager_liveness_store: liveness_store, execution_manager_pool: HashSet::new(), - execution_manager_stale_cutoff, + execution_manager_stale_after_sec: execution_manager_stale_cutoff.as_secs(), instances: Vec::new(), receiver, } @@ -752,12 +819,16 @@ mod tests { #[tokio::test] async fn dead_execution_manager_registration_triggers_recovery() { let ready_queue_sender = MockReadyQueueSender::default(); - let pool = TaskInstancePoolHandle::create( + let cancellation_token = CancellationToken::new(); + let (pool, pool_task) = create_task_instance_pool( ready_queue_sender.clone(), RejectAllLivenessStore, - Duration::from_mins(1), - Duration::from_mins(1), - DEFAULT_CHANNEL_SIZE, + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: Duration::from_mins(1), + channel_size: DEFAULT_CHANNEL_SIZE, + }, ); let tcb = build_single_task_tcb().await; let task_instance_id = 1; @@ -785,18 +856,28 @@ mod tests { messages.contains(&ReadyMessage::Task(job_id, 0)), "task should be re-enqueued for dead EM, got: {messages:?}" ); + cancellation_token.cancel(); + tokio::time::timeout(Duration::from_secs(1), pool_task) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); } #[tokio::test] async fn valid_em_is_cached_and_subsequent_registrations_skip_verify() { let ready_queue_sender = MockReadyQueueSender::default(); let liveness_store = MockExecutionManagerLivenessStore::default(); - let pool = TaskInstancePoolHandle::create( + let cancellation_token = CancellationToken::new(); + let (pool, pool_task) = create_task_instance_pool( ready_queue_sender, liveness_store.clone(), - Duration::from_mins(1), - Duration::from_mins(1), - DEFAULT_CHANNEL_SIZE, + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: Duration::from_mins(1), + channel_size: DEFAULT_CHANNEL_SIZE, + }, ); let execution_manager_id = ExecutionManagerId::new(); @@ -828,6 +909,136 @@ mod tests { 1, "liveness store should be called exactly once for two registrations with the same EM" ); + cancellation_token.cancel(); + tokio::time::timeout(Duration::from_secs(1), pool_task) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); + } + + #[tokio::test] + async fn spawned_pool_exits_when_cancelled() { + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let (_pool, pool_task) = create_task_instance_pool( + MockReadyQueueSender::default(), + MockExecutionManagerLivenessStore::default(), + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: Duration::from_mins(1), + channel_size: DEFAULT_CHANNEL_SIZE, + }, + ); + + cancellation_token.cancel(); + + tokio::time::timeout(Duration::from_secs(1), pool_task) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); + } + + #[tokio::test] + async fn spawned_pool_processes_registration_before_shutdown() { + let ready_queue_sender = MockReadyQueueSender::default(); + let cancellation_token = tokio_util::sync::CancellationToken::new(); + let (pool, pool_task) = create_task_instance_pool( + ready_queue_sender.clone(), + RejectAllLivenessStore, + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: Duration::from_mins(1), + channel_size: DEFAULT_CHANNEL_SIZE, + }, + ); + let tcb = build_single_task_tcb().await; + let task_instance_id = 1; + let _ = tcb + .register_task_instance(task_instance_id) + .await + .expect("TCB registration should succeed"); + let metadata = make_task_instance_metadata( + TaskId::Index(0), + task_instance_id, + ExecutionManagerId::new(), + SystemTime::now(), + ); + let job_id = metadata.job_id; + + pool.register_task_instance(tcb, metadata) + .await + .expect("registration should be sent"); + tokio::time::sleep(Duration::from_millis(100)).await; + cancellation_token.cancel(); + tokio::time::timeout(Duration::from_secs(1), pool_task) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); + + let messages = ready_queue_sender.sent_messages.lock().await.clone(); + assert!( + messages.contains(&ReadyMessage::Task(job_id, 0)), + "registration should be processed before shutdown, got: {messages:?}" + ); + } + + #[tokio::test] + async fn run_drains_queued_registrations_when_already_cancelled() { + let ready_queue_sender = MockReadyQueueSender::default(); + let cancellation_token = CancellationToken::new(); + cancellation_token.cancel(); + let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + let mut expected_messages: Vec = Vec::new(); + + for task_index in 0..3 { + let tcb = build_single_task_tcb().await; + let task_instance_id = task_index as TaskInstanceId + 1; + let _ = tcb + .register_task_instance(task_instance_id) + .await + .expect("TCB registration should succeed"); + let metadata = make_task_instance_metadata( + TaskId::Index(task_index), + task_instance_id, + ExecutionManagerId::new(), + SystemTime::now(), + ); + expected_messages.push(ReadyMessage::Task(metadata.job_id, task_index)); + sender + .send(PoolMessage::Register { + tcb: Tcb::Task(tcb), + metadata, + }) + .await + .expect("pool message should be queued"); + } + + drop(sender); + let pool = TaskInstancePool { + ready_queue_sender: ready_queue_sender.clone(), + execution_manager_liveness_store: RejectAllLivenessStore, + execution_manager_pool: HashSet::new(), + execution_manager_stale_after_sec: 60, + instances: Vec::new(), + receiver, + }; + + pool.run(cancellation_token, Duration::from_mins(1)) + .await + .expect("pool should stop cleanly"); + + let messages = ready_queue_sender.sent_messages.lock().await.clone(); + assert_eq!(messages.len(), expected_messages.len(), "got: {messages:?}"); + for expected in &expected_messages { + assert!( + messages.contains(expected), + "missing drained registration {expected:?}, got: {messages:?}" + ); + } } #[tokio::test] From fc6e11e9dc2129f4d2054250f7cfd22fb4e92e47 Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 12:16:18 -0400 Subject: [PATCH 2/8] Merge two em liveness trait into one --- components/spider-storage/src/cache/error.rs | 3 + .../spider-storage/src/task_instance_pool.rs | 187 +++++++----------- 2 files changed, 74 insertions(+), 116 deletions(-) diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index 8bfcadf3..af1ae54c 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -83,6 +83,9 @@ pub enum InternalError { #[error(transparent)] WireError(#[from] WireError), + + #[error(transparent)] + Db(#[from] crate::db::DbError), } /// Enums for all errors representing operations that are rejected due to stale cache state. diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index e994a0e6..e8407810 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -7,7 +7,7 @@ //! the task so a new instance can be scheduled, while the original instance remains live until it //! completes or is force-removed. //! * **Dead-execution-manager recovery**: During each GC cycle, the pool queries the -//! [`ExecutionManagerLivenessStore`] to detect dead execution managers, force-removes their +//! [`ExecutionManagerLivenessManagement`] to detect dead execution managers, force-removes their //! instances from the task control blocks, and re-enqueues the corresponding tasks. //! //! Internally, the pool runs as a single-owner coroutine: a tokio task owns the mutable state @@ -36,7 +36,7 @@ use crate::{ error::InternalError, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, - db::{ExecutionManagerLivenessManagement, MariaDbStorageConnector}, + db::ExecutionManagerLivenessManagement, ready_queue::ReadyQueueSender, }; @@ -51,58 +51,6 @@ pub struct TaskInstanceMetadata { pub soft_timeout_ddl: Option, } -/// Store for tracking execution manager liveness state. -/// -/// Implementations persist execution manager heartbeat state durably and provide an atomic -/// operation to detect and mark disconnected execution managers as dead. -#[async_trait] -pub trait ExecutionManagerLivenessStore: Clone + Send + Sync { - /// Checks whether the execution manager with the given ID is alive. - /// - /// # Parameters - /// - /// * `id` - The execution manager ID to check. - /// - /// # Returns - /// - /// Whether the execution manager is alive on success. - /// - /// # Errors - /// - /// Returns an error if: - /// - /// * Forwards the underlying store's return values on failure. - async fn is_execution_manager_alive( - &self, - id: &ExecutionManagerId, - ) -> Result; - - /// Returns the IDs of execution managers whose last heartbeat is older than - /// `stale_after_sec`, after marking them dead. - /// - /// This operation is atomic: once an execution manager is returned by this method, it will not - /// be returned again in subsequent calls. - /// - /// # Parameters - /// - /// * `stale_after_sec` - The seconds after the last heartbeat which makes an execution manager - /// stale. - /// - /// # Returns - /// - /// A vector of dead execution manager IDs on success. - /// - /// # Errors - /// - /// Returns an error if: - /// - /// * Forwards the underlying store's return values on failure. - async fn get_dead_execution_managers( - &self, - stale_after_sec: u64, - ) -> Result, InternalError>; -} - /// Connector for creating and registering task instances in the task instance pool. /// /// This trait is invoked by the cache layer to allocate task instance IDs and register newly @@ -173,10 +121,15 @@ pub struct TaskInstancePoolHandle { } /// Configuration for a task instance pool actor. +/// +/// Controls GC timing, channel buffering, and execution manager staleness detection. #[derive(Debug, Clone, Copy)] pub struct TaskInstancePoolConfig { + /// Seconds without a heartbeat after which an execution manager is considered stale. pub execution_manager_stale_after_sec: u64, + /// Interval between GC cycles that check for dead execution managers. pub gc_interval: Duration, + /// Maximum number of pending registration messages in the pool channel. pub channel_size: usize, } @@ -202,7 +155,7 @@ impl Default for TaskInstancePoolConfig { /// A [`TaskInstancePoolHandle`] and the spawned actor's [`JoinHandle`]. pub fn create_task_instance_pool< ReadyQueueSenderType: ReadyQueueSender + 'static, - LivenessStoreType: ExecutionManagerLivenessStore + 'static, + LivenessStoreType: ExecutionManagerLivenessManagement + 'static, >( ready_queue_sender: ReadyQueueSenderType, execution_manager_liveness_store: LivenessStoreType, @@ -233,35 +186,6 @@ pub fn create_task_instance_pool< (handle, pool_task) } -#[async_trait] -impl ExecutionManagerLivenessStore for MariaDbStorageConnector { - async fn is_execution_manager_alive( - &self, - id: &ExecutionManagerId, - ) -> Result { - ExecutionManagerLivenessManagement::is_execution_manager_alive(self, *id) - .await - .map_err(|e| { - InternalError::TaskInstancePoolCorrupted(format!( - "failed to check execution manager liveness: {e}" - )) - }) - } - - async fn get_dead_execution_managers( - &self, - stale_after_sec: u64, - ) -> Result, InternalError> { - ExecutionManagerLivenessManagement::get_dead_execution_managers(self, stale_after_sec) - .await - .map_err(|e| { - InternalError::TaskInstancePoolCorrupted(format!( - "failed to get dead execution managers: {e}" - )) - }) - } -} - #[async_trait] impl TaskInstancePoolConnector for TaskInstancePoolHandle { fn get_next_available_task_instance_id(&self) -> TaskInstanceId { @@ -362,7 +286,7 @@ enum PoolMessage { /// * `LivenessStoreType` - The execution manager liveness store implementation. struct TaskInstancePool< ReadyQueueSenderType: ReadyQueueSender, - LivenessStoreType: ExecutionManagerLivenessStore, + LivenessStoreType: ExecutionManagerLivenessManagement, > { ready_queue_sender: ReadyQueueSenderType, execution_manager_liveness_store: LivenessStoreType, @@ -372,7 +296,7 @@ struct TaskInstancePool< receiver: mpsc::Receiver, } -impl +impl TaskInstancePool { /// Runs the coroutine loop, processing messages and GC timer ticks. @@ -433,15 +357,15 @@ impl Result<(), InternalError> { match message { PoolMessage::Register { tcb, metadata } => { - let em_id = &metadata.execution_manager_id; - if !self.execution_manager_pool.contains(em_id) { + let em_id = metadata.execution_manager_id; + if !self.execution_manager_pool.contains(&em_id) { if !self .execution_manager_liveness_store .is_execution_manager_alive(em_id) @@ -461,7 +385,7 @@ impl Result<(), InternalError> { let dead_em_ids: Vec = self @@ -587,7 +511,10 @@ impl>>, alive_call_count: Arc, } #[async_trait] - impl ExecutionManagerLivenessStore for MockExecutionManagerLivenessStore { + impl ExecutionManagerLivenessManagement for MockExecutionManagerLivenessManagement { + async fn register_execution_manager( + &self, + _ip_address: IpAddr, + ) -> Result { + unimplemented!("not needed by pool tests") + } + + async fn update_execution_manager_heartbeat( + &self, + _execution_manager_id: ExecutionManagerId, + ) -> Result<(), DbError> { + unimplemented!("not needed by pool tests") + } + async fn is_execution_manager_alive( &self, - _id: &ExecutionManagerId, - ) -> Result { + _execution_manager_id: ExecutionManagerId, + ) -> Result { self.alive_call_count .fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(true) @@ -633,7 +574,7 @@ mod tests { async fn get_dead_execution_managers( &self, _stale_after_sec: u64, - ) -> Result, InternalError> { + ) -> Result, DbError> { Ok(self.dead_execution_managers.lock().await.clone()) } } @@ -692,23 +633,37 @@ mod tests { } } - /// A [`ExecutionManagerLivenessStore`] where all EMs are reported as dead. + /// A [`ExecutionManagerLivenessManagement`] where all EMs are reported as dead. #[derive(Clone, Default)] struct RejectAllLivenessStore; #[async_trait] - impl ExecutionManagerLivenessStore for RejectAllLivenessStore { + impl ExecutionManagerLivenessManagement for RejectAllLivenessStore { + async fn register_execution_manager( + &self, + _ip_address: IpAddr, + ) -> Result { + unimplemented!("not needed by pool tests") + } + + async fn update_execution_manager_heartbeat( + &self, + _execution_manager_id: ExecutionManagerId, + ) -> Result<(), DbError> { + unimplemented!("not needed by pool tests") + } + async fn is_execution_manager_alive( &self, - _id: &ExecutionManagerId, - ) -> Result { + _execution_manager_id: ExecutionManagerId, + ) -> Result { Ok(false) } async fn get_dead_execution_managers( &self, _stale_after_sec: u64, - ) -> Result, InternalError> { + ) -> Result, DbError> { Ok(Vec::new()) } } @@ -768,9 +723,9 @@ mod tests { /// sender is dropped immediately. fn build_test_pool( ready_queue_sender: MockReadyQueueSender, - liveness_store: MockExecutionManagerLivenessStore, + liveness_store: MockExecutionManagerLivenessManagement, execution_manager_stale_cutoff: Duration, - ) -> TaskInstancePool { + ) -> TaskInstancePool { let (_sender, receiver) = mpsc::channel(1); TaskInstancePool { ready_queue_sender, @@ -789,7 +744,7 @@ mod tests { /// /// The job ID assigned to the task, so callers can match it against re-enqueue messages. async fn register_task_in_pool( - pool: &mut TaskInstancePool, + pool: &mut TaskInstancePool, tcb: &SharedTaskControlBlock, task_id: TaskId, task_instance_id: TaskInstanceId, @@ -867,7 +822,7 @@ mod tests { #[tokio::test] async fn valid_em_is_cached_and_subsequent_registrations_skip_verify() { let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let cancellation_token = CancellationToken::new(); let (pool, pool_task) = create_task_instance_pool( ready_queue_sender, @@ -922,7 +877,7 @@ mod tests { let cancellation_token = tokio_util::sync::CancellationToken::new(); let (_pool, pool_task) = create_task_instance_pool( MockReadyQueueSender::default(), - MockExecutionManagerLivenessStore::default(), + MockExecutionManagerLivenessManagement::default(), cancellation_token.clone(), TaskInstancePoolConfig { execution_manager_stale_after_sec: 60, @@ -1046,7 +1001,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store, @@ -1090,7 +1045,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store, @@ -1148,7 +1103,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store.clone(), @@ -1206,7 +1161,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store.clone(), @@ -1264,7 +1219,7 @@ mod tests { // index 3: dead EM, terminated -> removed, no re-enqueue (terminal wins) // index 4: dead EM, on-going -> removed, re-enqueued let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store.clone(), From 6a3be1e5a6fd46705b8b646721b3ebab09086a21 Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 14:15:04 -0400 Subject: [PATCH 3/8] Use u64 instead of duration --- .../spider-storage/src/task_instance_pool.rs | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index e8407810..dc251d57 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -127,8 +127,8 @@ pub struct TaskInstancePoolHandle { pub struct TaskInstancePoolConfig { /// Seconds without a heartbeat after which an execution manager is considered stale. pub execution_manager_stale_after_sec: u64, - /// Interval between GC cycles that check for dead execution managers. - pub gc_interval: Duration, + /// Interval in seconds between GC cycles that check for dead execution managers. + pub gc_interval: u64, /// Maximum number of pending registration messages in the pool channel. pub channel_size: usize, } @@ -137,7 +137,7 @@ impl Default for TaskInstancePoolConfig { fn default() -> Self { Self { execution_manager_stale_after_sec: 60, - gc_interval: Duration::from_secs(30), + gc_interval: 30, channel_size: 128, } } @@ -310,9 +310,9 @@ impl Result<(), InternalError> { - let mut gc_interval = tokio::time::interval(gc_interval); + let mut gc_interval = tokio::time::interval(Duration::from_secs(gc_interval)); // The first tick completes immediately; skip it so we don't GC right at startup. gc_interval.tick().await; @@ -324,6 +324,7 @@ impl { let Some(message) = message else { + // TODO: log this exit return Ok(()); }; self.handle_message(message).await?; @@ -781,7 +782,7 @@ mod tests { cancellation_token.clone(), TaskInstancePoolConfig { execution_manager_stale_after_sec: 60, - gc_interval: Duration::from_mins(1), + gc_interval: 60, channel_size: DEFAULT_CHANNEL_SIZE, }, ); @@ -830,7 +831,7 @@ mod tests { cancellation_token.clone(), TaskInstancePoolConfig { execution_manager_stale_after_sec: 60, - gc_interval: Duration::from_mins(1), + gc_interval: 60, channel_size: DEFAULT_CHANNEL_SIZE, }, ); @@ -881,7 +882,7 @@ mod tests { cancellation_token.clone(), TaskInstancePoolConfig { execution_manager_stale_after_sec: 60, - gc_interval: Duration::from_mins(1), + gc_interval: 60, channel_size: DEFAULT_CHANNEL_SIZE, }, ); @@ -905,7 +906,7 @@ mod tests { cancellation_token.clone(), TaskInstancePoolConfig { execution_manager_stale_after_sec: 60, - gc_interval: Duration::from_mins(1), + gc_interval: 60, channel_size: DEFAULT_CHANNEL_SIZE, }, ); @@ -982,7 +983,7 @@ mod tests { receiver, }; - pool.run(cancellation_token, Duration::from_mins(1)) + pool.run(cancellation_token, 60) .await .expect("pool should stop cleanly"); From 6bba976368322d12f9a42b44a1995491622e0e20 Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 14:28:54 -0400 Subject: [PATCH 4/8] Fix test --- .../spider-storage/src/task_instance_pool.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index dc251d57..14dc0828 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -802,7 +802,7 @@ mod tests { pool.register_task_instance(tcb.clone(), metadata) .await - .unwrap(); + .expect("registration should be sent"); // Give the pool coroutine time to process the message. tokio::time::sleep(Duration::from_millis(100)).await; @@ -844,7 +844,9 @@ mod tests { execution_manager_id, SystemTime::now(), ); - pool.register_task_instance(tcb1, metadata1).await.unwrap(); + pool.register_task_instance(tcb1, metadata1) + .await + .expect("first registration should succeed"); let tcb2 = build_single_task_tcb().await; let metadata2 = make_task_instance_metadata( @@ -853,7 +855,9 @@ mod tests { execution_manager_id, SystemTime::now(), ); - pool.register_task_instance(tcb2, metadata2).await.unwrap(); + pool.register_task_instance(tcb2, metadata2) + .await + .expect("second registration should succeed"); // Give the pool coroutine time to process both messages. tokio::time::sleep(Duration::from_millis(100)).await; @@ -875,7 +879,7 @@ mod tests { #[tokio::test] async fn spawned_pool_exits_when_cancelled() { - let cancellation_token = tokio_util::sync::CancellationToken::new(); + let cancellation_token = CancellationToken::new(); let (_pool, pool_task) = create_task_instance_pool( MockReadyQueueSender::default(), MockExecutionManagerLivenessManagement::default(), @@ -899,7 +903,7 @@ mod tests { #[tokio::test] async fn spawned_pool_processes_registration_before_shutdown() { let ready_queue_sender = MockReadyQueueSender::default(); - let cancellation_token = tokio_util::sync::CancellationToken::new(); + let cancellation_token = CancellationToken::new(); let (pool, pool_task) = create_task_instance_pool( ready_queue_sender.clone(), RejectAllLivenessStore, From 5ccfa8de211f87b719a2921ca654d956e5315675 Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 14:48:06 -0400 Subject: [PATCH 5/8] Rename variable --- components/spider-storage/src/state/server.rs | 19 +++++++++--------- .../spider-storage/src/task_instance_pool.rs | 20 +++++++++---------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/components/spider-storage/src/state/server.rs b/components/spider-storage/src/state/server.rs index dcc93e7d..492a4347 100644 --- a/components/spider-storage/src/state/server.rs +++ b/components/spider-storage/src/state/server.rs @@ -21,7 +21,7 @@ pub struct ServerRuntime { service_state: ServiceState, cancellation_token: CancellationToken, - task_instance_pool_task: JoinHandle>, + task_instance_pool_join_handle: JoinHandle>, } impl ServerRuntime { @@ -43,12 +43,13 @@ impl ServerRuntime { let session_id = db.session_id(); let (ready_queue_sender, ready_queue_receiver) = create_ready_queue(ReadyQueueConfig::default()).map_err(CacheError::from)?; - let (task_instance_pool_connector, task_instance_pool_task) = create_task_instance_pool( - ready_queue_sender.clone(), - db.clone(), - cancellation_token.clone(), - TaskInstancePoolConfig::default(), - ); + let (task_instance_pool_connector, task_instance_pool_join_handle) = + create_task_instance_pool( + ready_queue_sender.clone(), + db.clone(), + cancellation_token.clone(), + TaskInstancePoolConfig::default(), + ); let service_state = ServiceState::new( db, session_id, @@ -61,7 +62,7 @@ impl ServerRuntime { Ok(Self { service_state, cancellation_token, - task_instance_pool_task, + task_instance_pool_join_handle, }) } @@ -86,7 +87,7 @@ impl ServerRuntime { pub async fn stop_background_tasks(self) -> Result<(), StorageServerError> { stop_background_task( self.cancellation_token, - self.task_instance_pool_task, + self.task_instance_pool_join_handle, STOP_BACKGROUND_TASKS_TIMEOUT, ) .await diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index 14dc0828..4560d83f 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -176,14 +176,14 @@ pub fn create_task_instance_pool< execution_manager_pool: HashSet::new(), receiver, }; - let pool_task = + let pool_join_handle = tokio::spawn(async move { pool.run(cancellation_token, config.gc_interval).await }); let handle = TaskInstancePoolHandle { next_task_instance_id, sender, }; - (handle, pool_task) + (handle, pool_join_handle) } #[async_trait] @@ -776,7 +776,7 @@ mod tests { async fn dead_execution_manager_registration_triggers_recovery() { let ready_queue_sender = MockReadyQueueSender::default(); let cancellation_token = CancellationToken::new(); - let (pool, pool_task) = create_task_instance_pool( + let (pool, pool_join_handle) = create_task_instance_pool( ready_queue_sender.clone(), RejectAllLivenessStore, cancellation_token.clone(), @@ -813,7 +813,7 @@ mod tests { "task should be re-enqueued for dead EM, got: {messages:?}" ); cancellation_token.cancel(); - tokio::time::timeout(Duration::from_secs(1), pool_task) + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) .await .expect("pool task should exit before timeout") .expect("pool task should join successfully") @@ -825,7 +825,7 @@ mod tests { let ready_queue_sender = MockReadyQueueSender::default(); let liveness_store = MockExecutionManagerLivenessManagement::default(); let cancellation_token = CancellationToken::new(); - let (pool, pool_task) = create_task_instance_pool( + let (pool, pool_join_handle) = create_task_instance_pool( ready_queue_sender, liveness_store.clone(), cancellation_token.clone(), @@ -870,7 +870,7 @@ mod tests { "liveness store should be called exactly once for two registrations with the same EM" ); cancellation_token.cancel(); - tokio::time::timeout(Duration::from_secs(1), pool_task) + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) .await .expect("pool task should exit before timeout") .expect("pool task should join successfully") @@ -880,7 +880,7 @@ mod tests { #[tokio::test] async fn spawned_pool_exits_when_cancelled() { let cancellation_token = CancellationToken::new(); - let (_pool, pool_task) = create_task_instance_pool( + let (_pool, pool_join_handle) = create_task_instance_pool( MockReadyQueueSender::default(), MockExecutionManagerLivenessManagement::default(), cancellation_token.clone(), @@ -893,7 +893,7 @@ mod tests { cancellation_token.cancel(); - tokio::time::timeout(Duration::from_secs(1), pool_task) + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) .await .expect("pool task should exit before timeout") .expect("pool task should join successfully") @@ -904,7 +904,7 @@ mod tests { async fn spawned_pool_processes_registration_before_shutdown() { let ready_queue_sender = MockReadyQueueSender::default(); let cancellation_token = CancellationToken::new(); - let (pool, pool_task) = create_task_instance_pool( + let (pool, pool_join_handle) = create_task_instance_pool( ready_queue_sender.clone(), RejectAllLivenessStore, cancellation_token.clone(), @@ -933,7 +933,7 @@ mod tests { .expect("registration should be sent"); tokio::time::sleep(Duration::from_millis(100)).await; cancellation_token.cancel(); - tokio::time::timeout(Duration::from_secs(1), pool_task) + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) .await .expect("pool task should exit before timeout") .expect("pool task should join successfully") From e772967069c39803a14a2fea526fc35de4da536d Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 15:42:25 -0400 Subject: [PATCH 6/8] Refactor server runtime --- components/spider-storage/src/state.rs | 2 +- components/spider-storage/src/state/server.rs | 258 +++++++++++------- 2 files changed, 162 insertions(+), 98 deletions(-) diff --git a/components/spider-storage/src/state.rs b/components/spider-storage/src/state.rs index 1afe82fa..7610878c 100644 --- a/components/spider-storage/src/state.rs +++ b/components/spider-storage/src/state.rs @@ -5,7 +5,7 @@ pub mod service; pub use error::StorageServerError; pub use job_cache::JobCache; -pub use server::ServerRuntime; +pub use server::{ServerRuntime, create_server_runtime}; pub use service::ServiceState; #[cfg(test)] diff --git a/components/spider-storage/src/state/server.rs b/components/spider-storage/src/state/server.rs index 492a4347..6ad6e2ae 100644 --- a/components/spider-storage/src/state/server.rs +++ b/components/spider-storage/src/state/server.rs @@ -1,69 +1,116 @@ -use std::time::Duration; - use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use crate::{ cache::error::{CacheError, InternalError}, config::DatabaseConfig, - db::{MariaDbStorageConnector, SessionManagement}, - ready_queue::{ReadyQueueConfig, ReadyQueueSenderHandle, create_ready_queue}, + db::{DbStorage, MariaDbStorageConnector, SessionManagement}, + ready_queue::{ReadyQueueConfig, ReadyQueueSender, ReadyQueueSenderHandle, create_ready_queue}, state::{JobCache, ServiceState, StorageServerError}, task_instance_pool::{ TaskInstancePoolConfig, + TaskInstancePoolConnector, TaskInstancePoolHandle, create_task_instance_pool, }, }; -/// Production per-process storage server runtime. -pub struct ServerRuntime { +/// Per-process storage server runtime. +/// +/// # Type Parameters +/// +/// * `ReadyQueueSenderType` - The ready queue sender type. +/// * `DbConnectorType` - The database connector type. +/// * `TaskInstancePoolConnectorType` - The task instance pool connector type. +pub struct ServerRuntime< + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> { service_state: - ServiceState, + ServiceState, cancellation_token: CancellationToken, task_instance_pool_join_handle: JoinHandle>, + stop_timeout_sec: u64, } -impl ServerRuntime { - /// Creates a storage server runtime from the database configuration. - /// - /// # Returns - /// - /// A newly created [`ServerRuntime`] on success. +/// Creates a storage server runtime from the database configuration. +/// +/// # Returns +/// +/// A newly created [`ServerRuntime`] on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`MariaDbStorageConnector::connect`]'s return values on failure. +/// * Forwards [`create_ready_queue`]'s return values on failure. +pub async fn create_server_runtime( + db_config: &DatabaseConfig, +) -> Result< + ServerRuntime, + StorageServerError, +> { + let cancellation_token = CancellationToken::new(); + let db = MariaDbStorageConnector::connect(db_config).await?; + let session_id = db.session_id(); + let (ready_queue_sender, ready_queue_receiver) = + create_ready_queue(ReadyQueueConfig::default()).map_err(CacheError::from)?; + let (task_instance_pool_connector, task_instance_pool_join_handle) = create_task_instance_pool( + ready_queue_sender.clone(), + db.clone(), + cancellation_token.clone(), + TaskInstancePoolConfig::default(), + ); + let service_state = ServiceState::new( + db, + session_id, + JobCache::new(), + ready_queue_sender, + ready_queue_receiver, + task_instance_pool_connector, + ); + + Ok(ServerRuntime { + service_state, + cancellation_token, + task_instance_pool_join_handle, + stop_timeout_sec: STOP_BACKGROUND_TASKS_TIMEOUT_SEC, + }) +} + +impl + ServerRuntime +where + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +{ + /// Stops background tasks owned by the runtime. /// /// # Errors /// /// Returns an error if: /// - /// * Forwards [`MariaDbStorageConnector::connect`]'s return values on failure. - /// * Forwards [`create_ready_queue`]'s return values on failure. - pub async fn create(db_config: &DatabaseConfig) -> Result { - let cancellation_token = CancellationToken::new(); - let db = MariaDbStorageConnector::connect(db_config).await?; - let session_id = db.session_id(); - let (ready_queue_sender, ready_queue_receiver) = - create_ready_queue(ReadyQueueConfig::default()).map_err(CacheError::from)?; - let (task_instance_pool_connector, task_instance_pool_join_handle) = - create_task_instance_pool( - ready_queue_sender.clone(), - db.clone(), - cancellation_token.clone(), - TaskInstancePoolConfig::default(), - ); - let service_state = ServiceState::new( - db, - session_id, - JobCache::new(), - ready_queue_sender, - ready_queue_receiver, - task_instance_pool_connector, - ); - - Ok(Self { - service_state, - cancellation_token, - task_instance_pool_join_handle, - }) + /// * [`StorageServerError::Stopping`] if the task instance pool does not stop before timeout. + /// * [`StorageServerError::Cache`] if the task instance pool task fails or cannot be joined. + pub async fn stop_background_tasks(self) -> Result<(), StorageServerError> { + self.cancellation_token.cancel(); + let join_result = tokio::time::timeout( + std::time::Duration::from_secs(self.stop_timeout_sec), + self.task_instance_pool_join_handle, + ) + .await + .map_err(|_| { + StorageServerError::Stopping("task instance pool stop timed out".to_owned()) + })?; + let pool_result = join_result.map_err(|e| { + StorageServerError::Cache(CacheError::Internal( + InternalError::TaskInstancePoolCorrupted(format!("task join error: {e}")), + )) + })?; + pool_result.map_err(|e| StorageServerError::Cache(CacheError::Internal(e))) } /// # Returns @@ -72,81 +119,93 @@ impl ServerRuntime { #[must_use] pub fn service_state( &self, - ) -> ServiceState { + ) -> ServiceState { self.service_state.clone() } - - /// Stops background tasks owned by the runtime. - /// - /// # Errors - /// - /// Returns an error if: - /// - /// * [`StorageServerError::Stopping`] if the task instance pool does not stop before timeout. - /// * [`StorageServerError::Cache`] if the task instance pool task fails or cannot be joined. - pub async fn stop_background_tasks(self) -> Result<(), StorageServerError> { - stop_background_task( - self.cancellation_token, - self.task_instance_pool_join_handle, - STOP_BACKGROUND_TASKS_TIMEOUT, - ) - .await - } } -const STOP_BACKGROUND_TASKS_TIMEOUT: Duration = Duration::from_secs(30); - -/// Stops a single cancellation-token-controlled background task. -/// -/// # Errors -/// -/// Returns an error if: -/// -/// * [`StorageServerError::Stopping`] if the task does not stop before `timeout`. -/// * [`StorageServerError::Cache`] if the task fails or cannot be joined. -async fn stop_background_task( - cancellation_token: CancellationToken, - task: JoinHandle>, - timeout: Duration, -) -> Result<(), StorageServerError> { - cancellation_token.cancel(); - let join_result = tokio::time::timeout(timeout, task).await.map_err(|_| { - StorageServerError::Stopping("task instance pool stop timed out".to_owned()) - })?; - let pool_result = join_result.map_err(|e| { - StorageServerError::Cache(CacheError::Internal( - InternalError::TaskInstancePoolCorrupted(format!("task join error: {e}")), - )) - })?; - pool_result.map_err(|e| StorageServerError::Cache(CacheError::Internal(e))) -} +const STOP_BACKGROUND_TASKS_TIMEOUT_SEC: u64 = 30; #[cfg(test)] mod tests { - use super::*; + use std::time::Duration; + + use tokio::task::JoinHandle; + use tokio_util::sync::CancellationToken; + + use super::ServerRuntime; + use crate::{ + cache::error::InternalError, + db::SessionManagement, + ready_queue::{ReadyQueueConfig, ReadyQueueSenderHandle, create_ready_queue}, + state::{ + JobCache, + ServiceState, + StorageServerError, + test_utils::{MockDbConnector, MockTaskInstancePoolConnector}, + }, + }; + + type TestServerRuntime = + ServerRuntime; + + fn create_test_server_runtime( + cancellation_token: CancellationToken, + task: JoinHandle>, + stop_timeout_sec: u64, + ) -> TestServerRuntime { + let db = MockDbConnector::default(); + let session_id = db.session_id(); + let (sender, receiver) = + create_ready_queue(ReadyQueueConfig::default()).expect("ready queue creation"); + let service_state = ServiceState::new( + db, + session_id, + JobCache::new(), + sender, + receiver, + MockTaskInstancePoolConnector, + ); + + ServerRuntime { + service_state, + cancellation_token, + task_instance_pool_join_handle: task, + stop_timeout_sec, + } + } #[tokio::test] - async fn stop_background_task_cancels_and_joins_task() -> anyhow::Result<()> { + async fn stop_background_tasks_cancels_and_joins_task() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); let task_cancellation_token = cancellation_token.clone(); - let task = tokio::spawn(async move { + let task: JoinHandle> = tokio::spawn(async move { task_cancellation_token.cancelled().await; Ok(()) }); - stop_background_task(cancellation_token, task, Duration::from_secs(1)).await?; + let runtime = create_test_server_runtime( + cancellation_token, + task, + super::STOP_BACKGROUND_TASKS_TIMEOUT_SEC, + ); + runtime + .stop_background_tasks() + .await + .expect("stop_background_tasks should succeed"); Ok(()) } #[tokio::test] - async fn stop_background_task_returns_stopping_on_timeout() -> anyhow::Result<()> { + async fn stop_background_tasks_returns_stopping_on_timeout() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); - let task = tokio::spawn(async move { - tokio::time::sleep(Duration::from_mins(1)).await; + let task: JoinHandle> = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(10)).await; Ok(()) }); - let result = stop_background_task(cancellation_token, task, Duration::from_millis(1)).await; + let runtime = create_test_server_runtime(cancellation_token, task, 0); + let result = runtime.stop_background_tasks().await; assert!( matches!(result, Err(StorageServerError::Stopping(_))), @@ -156,15 +215,20 @@ mod tests { } #[tokio::test] - async fn stop_background_task_returns_cache_error_on_pool_error() -> anyhow::Result<()> { + async fn stop_background_tasks_returns_cache_error_on_pool_error() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); - let task = tokio::spawn(async move { + let task: JoinHandle> = tokio::spawn(async move { Err(InternalError::TaskInstancePoolCorrupted( "test failure".to_owned(), )) }); - let result = stop_background_task(cancellation_token, task, Duration::from_secs(1)).await; + let runtime = create_test_server_runtime( + cancellation_token, + task, + super::STOP_BACKGROUND_TASKS_TIMEOUT_SEC, + ); + let result = runtime.stop_background_tasks().await; assert!( matches!(result, Err(StorageServerError::Cache(_))), From 5b8eccf988d58e6f3c01fe99da1ddea06c0866c0 Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 15:52:27 -0400 Subject: [PATCH 7/8] Fix import --- components/spider-storage/src/state/server.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/components/spider-storage/src/state/server.rs b/components/spider-storage/src/state/server.rs index 6ad6e2ae..deebb68d 100644 --- a/components/spider-storage/src/state/server.rs +++ b/components/spider-storage/src/state/server.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; @@ -98,7 +100,7 @@ where pub async fn stop_background_tasks(self) -> Result<(), StorageServerError> { self.cancellation_token.cancel(); let join_result = tokio::time::timeout( - std::time::Duration::from_secs(self.stop_timeout_sec), + Duration::from_secs(self.stop_timeout_sec), self.task_instance_pool_join_handle, ) .await From 76dafa0f7dc6c98b6fb709bbc0a024e820ded2f2 Mon Sep 17 00:00:00 2001 From: Sitao Wang Date: Thu, 14 May 2026 16:20:00 -0400 Subject: [PATCH 8/8] Address comment --- components/spider-storage/src/cache/error.rs | 3 ++ components/spider-storage/src/state/server.rs | 21 +++++----- .../spider-storage/src/task_instance_pool.rs | 38 +++++++++++++++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index af1ae54c..5f2f2fb2 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -78,6 +78,9 @@ pub enum InternalError { #[error("invalid config: {0}")] ReadyQueueInvalidConfig(&'static str), + #[error("invalid config: {0}")] + TaskInstancePoolInvalidConfig(&'static str), + #[error("ready queue channel is closed")] ReadyQueueChannelClosed, diff --git a/components/spider-storage/src/state/server.rs b/components/spider-storage/src/state/server.rs index deebb68d..bc9aa829 100644 --- a/components/spider-storage/src/state/server.rs +++ b/components/spider-storage/src/state/server.rs @@ -97,17 +97,18 @@ where /// /// * [`StorageServerError::Stopping`] if the task instance pool does not stop before timeout. /// * [`StorageServerError::Cache`] if the task instance pool task fails or cannot be joined. - pub async fn stop_background_tasks(self) -> Result<(), StorageServerError> { + pub async fn stop_background_tasks(mut self) -> Result<(), StorageServerError> { self.cancellation_token.cancel(); - let join_result = tokio::time::timeout( - Duration::from_secs(self.stop_timeout_sec), - self.task_instance_pool_join_handle, - ) - .await - .map_err(|_| { - StorageServerError::Stopping("task instance pool stop timed out".to_owned()) - })?; - let pool_result = join_result.map_err(|e| { + let result = tokio::select! { + result = &mut self.task_instance_pool_join_handle => result, + () = tokio::time::sleep(Duration::from_secs(self.stop_timeout_sec)) => { + self.task_instance_pool_join_handle.abort(); + return Err(StorageServerError::Stopping( + "task instance pool stop timed out".to_owned(), + )); + } + }; + let pool_result = result.map_err(|e| { StorageServerError::Cache(CacheError::Internal( InternalError::TaskInstancePoolCorrupted(format!("task join error: {e}")), )) diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index 4560d83f..24c914b7 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -133,6 +133,44 @@ pub struct TaskInstancePoolConfig { pub channel_size: usize, } +impl TaskInstancePoolConfig { + /// Creates a new [`TaskInstancePoolConfig`] with validation. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * `execution_manager_stale_after_sec` is zero. + /// * `gc_interval` is zero. + /// * `channel_size` is zero. + pub const fn new( + execution_manager_stale_after_sec: u64, + gc_interval: u64, + channel_size: usize, + ) -> Result { + if execution_manager_stale_after_sec == 0 { + return Err(InternalError::TaskInstancePoolInvalidConfig( + "execution_manager_stale_after_sec must be greater than zero", + )); + } + if gc_interval == 0 { + return Err(InternalError::TaskInstancePoolInvalidConfig( + "gc_interval must be greater than zero", + )); + } + if channel_size == 0 { + return Err(InternalError::TaskInstancePoolInvalidConfig( + "channel_size must be greater than zero", + )); + } + Ok(Self { + execution_manager_stale_after_sec, + gc_interval, + channel_size, + }) + } +} + impl Default for TaskInstancePoolConfig { fn default() -> Self { Self {