From 92470aeb18d45c5653071c30c211712a677b8185 Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Thu, 12 Mar 2026 20:50:37 -0400 Subject: [PATCH 1/8] WIP --- Cargo.lock | 18 ++ components/spider-core/src/job.rs | 1 + components/spider-core/src/types/id.rs | 6 +- components/spider-core/src/types/io.rs | 11 +- components/spider-storage/Cargo.toml | 3 + components/spider-storage/src/cache.rs | 1 + components/spider-storage/src/cache/task.rs | 243 ++++++++++++++++++++ components/spider-storage/src/lib.rs | 1 + components/spider-storage/src/protocol.rs | 155 ------------- 9 files changed, 277 insertions(+), 162 deletions(-) create mode 100644 components/spider-storage/src/cache.rs create mode 100644 components/spider-storage/src/cache/task.rs diff --git a/Cargo.lock b/Cargo.lock index 7f976ea1..3b982aa4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,6 +113,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + [[package]] name = "proc-macro2" version = "1.0.106" @@ -231,7 +237,10 @@ name = "spider-storage" version = "0.1.0" dependencies = [ "async-trait", + "serde", "spider-core", + "thiserror", + "tokio", ] [[package]] @@ -294,6 +303,15 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "pin-project-lite", +] + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/components/spider-core/src/job.rs b/components/spider-core/src/job.rs index a8c74dfb..9b093d45 100644 --- a/components/spider-core/src/job.rs +++ b/components/spider-core/src/job.rs 
@@ -2,6 +2,7 @@ pub struct Job {} /// Enum for all possible states of a job. +#[derive(Debug)] pub enum JobState { Running, PendingRetry, diff --git a/components/spider-core/src/types/id.rs b/components/spider-core/src/types/id.rs index a91099c0..8e8b115b 100644 --- a/components/spider-core/src/types/id.rs +++ b/components/spider-core/src/types/id.rs @@ -66,9 +66,7 @@ pub type WorkerId = Id; pub enum SchedulerIdMarker {} pub type SchedulerId = Id; -#[derive(Debug, PartialEq, Eq)] -pub enum TaskInstanceIdMarker {} -pub type TaskInstanceId = Id; +pub type TaskInstanceId = u64; /// Represents a signed ID. /// @@ -122,8 +120,6 @@ pub type SignedJobId = SignedId; pub type SignedTaskId = SignedId; -pub type SignedTaskInstanceId = SignedId; - #[cfg(test)] mod tests { use std::any::TypeId; diff --git a/components/spider-core/src/types/io.rs b/components/spider-core/src/types/io.rs index 37ab3f0f..1ddd70c8 100644 --- a/components/spider-core/src/types/io.rs +++ b/components/spider-core/src/types/io.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; + /// Represents a value object. pub struct Value {} @@ -5,7 +7,12 @@ pub struct Value {} pub struct Data {} /// Represents an input of a task. -pub struct TaskInput {} +#[derive(Serialize, Deserialize, Debug)] +pub enum TaskInput { + ValuePayload(Vec), +} /// Represents an output of a task. 
-pub struct TaskOutput {} +pub enum TaskOutput { + ValuePayload(Vec), +} diff --git a/components/spider-storage/Cargo.toml b/components/spider-storage/Cargo.toml index f086a955..528e88f7 100644 --- a/components/spider-storage/Cargo.toml +++ b/components/spider-storage/Cargo.toml @@ -10,3 +10,6 @@ path = "src/lib.rs" [dependencies] async-trait = "0.1.89" spider-core = { path = "../spider-core" } +thiserror = "2.0.18" +tokio = { version = "1.49.0", features = ["rt-multi-thread", "sync"] } +serde = { version = "1.0.228", features = ["derive"] } \ No newline at end of file diff --git a/components/spider-storage/src/cache.rs b/components/spider-storage/src/cache.rs new file mode 100644 index 00000000..124563c5 --- /dev/null +++ b/components/spider-storage/src/cache.rs @@ -0,0 +1 @@ +mod task; \ No newline at end of file diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs new file mode 100644 index 00000000..ecdfa6e1 --- /dev/null +++ b/components/spider-storage/src/cache/task.rs @@ -0,0 +1,243 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use serde::Serialize; +use spider_core::job::JobState; +use spider_core::task::{DataflowDependencyIndex, Task, TaskIndex}; +use spider_core::types::id::{JobId, TaskInstanceId}; +use spider_core::types::io::{TaskInput, TaskOutput}; + +/// Enum for all possible states of a task. 
+#[derive(Eq, PartialEq, Debug, Clone)] +pub enum TaskState { + Pending, + Ready, + Running, + Succeeded, + Failed(String), + Cancelled, +} + +impl TaskState { + pub fn is_terminal(&self) -> bool { + matches!(self, TaskState::Succeeded | TaskState::Failed(_) | TaskState::Cancelled) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("invalid task output")] + InvalidTaskOutput, + + #[error("task output already written")] + TaskOutputDuplicateWrite, + + #[error("task input not ready")] + TaskInputNotReady, + + #[error("task outputs length mismatch: expected {0}, got {1}")] + TaskOutputsLengthMismatch(usize, usize), + + #[error("task index {0} is out of bounds")] + TaskIndexOutOfBound(TaskIndex), + + #[error("task is already in a terminal state: {0:?}")] + TaskAlreadyTerminal(TaskState), + + #[error("job is already in a terminal state: {0:?}")] + JobAlreadyTerminal(JobState), + + #[error("task is still pending")] + TaskStillPending, + + #[error("task instance {0} is not registered")] + TaskInstanceNotRegistered(TaskInstanceId), + + #[error("failed to send ready task to the queue: {0}")] + TokioSendError(#[from] tokio::sync::mpsc::error::SendError<(JobId, TaskIndex)>), +} + +#[derive(Serialize, Clone)] +pub struct TdlContext { + package: String, + func: String, +} + +#[derive(Serialize)] +pub struct ExecutionContext { + pub task_instance_id: TaskInstanceId, + pub tdl_context: TdlContext, + pub inputs: Vec, +} + +/// Internal representation of a data dependency. +enum Data { + Value(Option>), + Channel, +} + +/// A shareable reference to a data object, allowing multiple tasks to read/write the same data +/// concurrently. 
+struct DataRef { + data: Arc>, +} + +impl DataRef { + fn new_value(value: Vec) -> Self { + Self { + data: Arc::new(std::sync::RwLock::new(Data::Value(Some(value)))), + } + } + + fn new_null_value() -> Self { + Self { + data: Arc::new(std::sync::RwLock::new(Data::Value(None))), + } + } + + fn write_task_output(&self, task_output: TaskOutput) -> Result<(), Error> { + match task_output { + TaskOutput::ValuePayload(payload) => { + match &mut *self.data.write().expect("rw lock poisoned") { + Data::Value(optional_value) => { + if optional_value.is_some() { + return Err(Error::TaskOutputDuplicateWrite); + } + *optional_value = Some(payload); + } + Data::Channel => { + return Err(Error::InvalidTaskOutput); + } + } + } + } + Ok(()) + } + + fn as_task_input(&self) -> Result { + match &*self.data.read().expect("rw lock poisoned") { + Data::Value(optional_value) => { + Ok(TaskInput::ValuePayload(optional_value.clone().ok_or(Error::TaskInputNotReady)?)) + } + Data::Channel => { + Err(Error::InvalidTaskOutput) + } + } + } +} + +struct TaskMetadata { + state: TaskState, + tdl_context: TdlContext, + registered_instances: HashSet, + num_unfinished_parents: usize, + inputs: Vec, + outputs: Vec, + children: Vec, +} + +impl TaskMetadata { + fn register(&mut self, task_instance_id: TaskInstanceId) -> Result { + if self.state.is_terminal() { + return Err(Error::TaskAlreadyTerminal(self.state.clone())); + } + if self.state != TaskState::Ready || self.state != TaskState::Running { + return Err(Error::TaskStillPending); + } + self.state = TaskState::Running; + self.registered_instances.insert(task_instance_id); + Ok(ExecutionContext { + task_instance_id, + tdl_context: self.tdl_context.clone(), + inputs: self.fetch_inputs()?, + }) + } + + fn complete(&mut self, task_instance_id: TaskInstanceId, task_outputs: Vec) -> Result<(), Error> { + if !self.registered_instances.contains(&task_instance_id) { + return Err(Error::TaskInstanceNotRegistered(task_instance_id)); + } + if 
self.state.is_terminal() { + return Err(Error::TaskAlreadyTerminal(self.state.clone())); + } + self.write_outputs(task_outputs)?; + self.state = TaskState::Succeeded; + Ok(()) + } + + fn write_outputs(&self, task_outputs: Vec) -> Result<(), Error> { + if task_outputs.len() != self.outputs.len() { + return Err(Error::TaskOutputsLengthMismatch(self.outputs.len(), task_outputs.len())); + } + for (output_ref, output) in self.outputs.iter().zip(task_outputs.into_iter()) { + output_ref.write_task_output(output)?; + } + Ok(()) + } + + fn fetch_inputs(&self) -> Result, Error> { + self.inputs.iter().map(|input_ref| input_ref.as_task_input()).collect() + } +} + +struct TaskGraph { + tasks: Vec>, +} + +struct JobMetadata { + state: JobState, + task_graph: TaskGraph, + num_unfinished_tasks: AtomicUsize, +} + +pub struct Job { + id: JobId, + metadata: std::sync::RwLock, + ready_queue_sender: tokio::sync::mpsc::Sender<(JobId, TaskIndex)>, +} + +impl Job { + pub fn register_task_instance(&self, task_instance_id: TaskInstanceId, task_index: TaskIndex) -> Result { + let job_metadata = self.metadata.read().expect("rw lock poisoned"); + let mut task_metadata = job_metadata.task_graph.tasks.get(task_index) + .ok_or(Error::TaskIndexOutOfBound(task_index))? + .lock() + .expect("mutex poisoned"); + task_metadata.register(task_instance_id) + } + + pub async fn complete_task_instance(&self, task_instance_id: TaskInstanceId, task_index: TaskIndex, task_outputs: Vec) -> Result<(), Error> { + let job_metadata = self.metadata.read().expect("rw lock poisoned"); + + // Update the task metadata + let mut task_metadata = job_metadata.task_graph.tasks.get(task_index) + .ok_or(Error::TaskIndexOutOfBound(task_index))? + .lock() + .expect("mutex poisoned"); + task_metadata.complete(task_instance_id, task_outputs)?; + for child_idx in &task_metadata.children { + let mut child_metadata = job_metadata.task_graph.tasks.get(*child_idx) + .ok_or(Error::TaskIndexOutOfBound(*child_idx))? 
+ .lock() + .expect("mutex poisoned"); + child_metadata.num_unfinished_parents -= 1; + if child_metadata.num_unfinished_parents == 0 { + child_metadata.state = TaskState::Ready; + self.ready_queue_sender.send((self.id, *child_idx)).await?; + } + } + let num_unfinished_tasks = job_metadata.num_unfinished_tasks.fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - 1; + drop(task_metadata); + drop(job_metadata); + + if num_unfinished_tasks > 0 { + return Ok(()); + } + + // Atomic decrement guarantees that only one thread's control flow can reach here. + let job_metadata = self.metadata.write().expect("rw lock poisoned"); + + + Ok(()) + } +} diff --git a/components/spider-storage/src/lib.rs b/components/spider-storage/src/lib.rs index 595d39d6..1962873a 100644 --- a/components/spider-storage/src/lib.rs +++ b/components/spider-storage/src/lib.rs @@ -1,4 +1,5 @@ mod error; pub mod protocol; +pub mod cache; pub use error::StorageError; diff --git a/components/spider-storage/src/protocol.rs b/components/spider-storage/src/protocol.rs index 1d936fbc..0e726352 100644 --- a/components/spider-storage/src/protocol.rs +++ b/components/spider-storage/src/protocol.rs @@ -194,161 +194,6 @@ pub trait JobOrchestration { async fn delete_job(&self, signed_id: SignedJobId) -> Result<(), StorageError>; } -/// Defines the storage interface for task orchestration. -/// -/// In the Spider scheduling framework, every task is associated with a resource group ID. -/// Orchestration operations may only be performed when the provided resource group ID matches the -/// one associated with the target task. -/// -/// # NOTE -/// -/// All operations defined by this trait **must be transactional**. Implementations are required to -/// guarantee atomicity and consistency for each operation. -#[async_trait] -pub trait TaskOrchestration { - /// Retrieves the input data for a task. - /// - /// # Parameters - /// - /// * `signed_id` - The signed ID of the target task. 
- /// - /// # Returns - /// - /// A vector of task inputs on success. - /// - /// # Errors - /// - /// Returns a [`StorageError`] instance indicating the failures. - /// - /// Implementations **must document** the specific error variants they may return and the - /// conditions under which those errors occur. - async fn get_task_inputs( - &self, - signed_id: SignedTaskId, - ) -> Result, StorageError>; - - /// Retrieves the output data for a task. - /// - /// # Parameters - /// - /// * `signed_id` - The signed ID of the target task. - /// - /// # Returns - /// - /// A vector of task outputs on success. - /// - /// # Errors - /// - /// Returns a [`StorageError`] instance indicating the failures. - /// - /// Implementations **must document** the specific error variants they may return and the - /// conditions under which those errors occur. - async fn get_task_outputs( - &self, - signed_id: SignedTaskId, - ) -> Result, StorageError>; - - /// Creates a new task instance for execution. - /// - /// This method is typically invoked by the scheduler when a task is ready to be executed. If - /// the task is in [`TaskState::Ready`], this method will transition it to - /// [`TaskState::Running`]. - /// - /// # Parameters - /// - /// * `signed_id` - The signed ID of the target task. - /// - /// # Returns - /// - /// The ID of the created task instance on success. - /// - /// # Errors - /// - /// Returns a [`StorageError`] instance indicating the failures. - /// - /// Implementations **must document** the specific error variants they may return and the - /// conditions under which those errors occur. - async fn create_task_instance( - &self, - signed_id: SignedTaskId, - ) -> Result; - - /// Completes a task instance and uploads its outputs. - /// - /// On success, this method will transition the task instance to [`TaskState::Succeeded`]. - /// - /// # Parameters - /// - /// * `signed_id` - The signed ID of the target task instance. 
- /// * `outputs` - A vector of task outputs produced by the completed task instance. - /// - /// # Returns - /// - /// `Ok(())` on success. - /// - /// # Errors - /// - /// Returns a [`StorageError`] instance indicating the failures. - /// - /// Implementations **must document** the specific error variants they may return and the - /// conditions under which those errors occur. - async fn complete_task_instance( - &self, - signed_id: SignedTaskInstanceId, - outputs: Vec, - ) -> Result<(), StorageError>; - - /// Cancels a task instance. - /// - /// If the cancelled instance is the only task instance associated with the task, this method - /// will also transition the task to [`TaskState::Cancelled`]. - /// - /// # Parameters - /// - /// * `signed_id` - The signed ID of the target task instance. - /// - /// # Returns - /// - /// `Ok(())` on success. - /// - /// # Errors - /// - /// Returns a [`StorageError`] instance indicating the failures. - /// - /// Implementations **must document** the specific error variants they may return and the - /// conditions under which those errors occur. - async fn cancel_task_instance( - &self, - signed_id: SignedTaskInstanceId, - ) -> Result<(), StorageError>; - - /// Marks a task instance as failed and records the error message. - /// - /// If the failed instance is the only task instance associated with the task, this method - /// will also transition the task to [`TaskState::Failed`]. - /// - /// # Parameters - /// - /// * `signed_id` - The signed ID of the target task instance. - /// * `error_message` - A description of the error that caused the task instance to fail. - /// - /// # Returns - /// - /// `Ok(())` on success. - /// - /// # Errors - /// - /// Returns a [`StorageError`] instance indicating the failures. - /// - /// Implementations **must document** the specific error variants they may return and the - /// conditions under which those errors occur. 
- async fn fail_task_instance( - &self, - signed_id: SignedTaskInstanceId, - error_message: String, - ) -> Result<(), StorageError>; -} - /// Defines the storage interface for data management. /// /// In the Spider scheduling framework, a data object is a shareable value holder that can be shared From b622746dab6f58e158d4c5f5fdc15fbc7558a59d Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Sat, 14 Mar 2026 22:38:28 -0400 Subject: [PATCH 2/8] WIP. --- components/spider-core/src/task.rs | 12 +- components/spider-core/src/types/io.rs | 4 +- components/spider-storage/src/cache.rs | 3 + components/spider-storage/src/cache/error.rs | 67 +++ components/spider-storage/src/cache/job.rs | 96 ++++ components/spider-storage/src/cache/task.rs | 493 +++++++++++-------- components/spider-storage/src/cache/types.rs | 60 +++ 7 files changed, 537 insertions(+), 198 deletions(-) create mode 100644 components/spider-storage/src/cache/error.rs create mode 100644 components/spider-storage/src/cache/job.rs create mode 100644 components/spider-storage/src/cache/types.rs diff --git a/components/spider-core/src/task.rs b/components/spider-core/src/task.rs index 41a58be0..6eaf94d9 100644 --- a/components/spider-core/src/task.rs +++ b/components/spider-core/src/task.rs @@ -24,8 +24,9 @@ pub enum Error { } /// Enum for all possible states of a task. +#[derive(Eq, PartialEq, Debug, Clone)] pub enum TaskState { - PENDING, + Pending, Ready, Running, Succeeded, @@ -33,5 +34,14 @@ pub enum TaskState { Cancelled, } +impl TaskState { + pub fn is_terminal(&self) -> bool { + matches!( + self, + TaskState::Succeeded | TaskState::Failed(_) | TaskState::Cancelled + ) + } +} + /// Represents metadata associated with a task. 
pub struct TaskMetadata {} diff --git a/components/spider-core/src/types/io.rs b/components/spider-core/src/types/io.rs index 1ddd70c8..4df423f9 100644 --- a/components/spider-core/src/types/io.rs +++ b/components/spider-core/src/types/io.rs @@ -13,6 +13,4 @@ pub enum TaskInput { } /// Represents an output of a task. -pub enum TaskOutput { - ValuePayload(Vec), -} +pub type TaskOutput = Vec; diff --git a/components/spider-storage/src/cache.rs b/components/spider-storage/src/cache.rs index 297403ba..5a2872f5 100644 --- a/components/spider-storage/src/cache.rs +++ b/components/spider-storage/src/cache.rs @@ -1 +1,4 @@ +pub mod error; +mod job; mod task; +mod types; diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs new file mode 100644 index 00000000..3a5afbdb --- /dev/null +++ b/components/spider-storage/src/cache/error.rs @@ -0,0 +1,67 @@ +use spider_core::{ + task::{TaskIndex, TaskState}, + types::id::JobId, +}; + +/// Enums for all possible errors that can happen in the cache. +pub enum CacheError { + Internal(InternalError), + Rejection(RejectionError), +} + +/// Enums for all internal errors. When these error happens, it is considered that the system is in +/// an inconsistent state and cannot continue to service requests. A restart is needed to recover +/// the cache from the storage. 
+#[derive(thiserror::Error, Debug)] +pub enum InternalError { + #[error("task output already written by a previous successful task instance")] + TaskOutputDuplicateWrite, + + #[error("task input not ready when attempting to register a task instance")] + TaskInputNotReady, + + #[error("out-of-bound task access detected")] + TaskIndexOutOfBound, + + #[error("task not ready when attempting to register a task instance")] + TaskNotReady, + + #[error("task graph corrupted: {0}")] + TaskGraphCorrupted(String), + + #[error("failed to send scheduling context into the channel")] + TokioSendError(#[from] tokio::sync::mpsc::error::SendError<(JobId, TaskIndex)>), + + #[error("task outputs length mismatch: expected {0}, got {1}")] + TaskOutputsLengthMismatch(usize, usize), +} + +impl From for CacheError { + fn from(e: InternalError) -> Self { + CacheError::Internal(e) + } +} + +/// Enums for all rejection errors. When these error happens, it is considered that the request is +/// valid, but cannot be processed due to the current state of the cache. These errors should be +/// forwarded to the client for notification. 
+#[derive(thiserror::Error, Debug)] +pub enum RejectionError { + #[error("task instance ID is not registered")] + InvalidTaskInstanceId, + + #[error("task is already in a terminal state: {0:?}")] + TaskAlreadyTerminated(TaskState), + + #[error("the number of living task instances has reached the upper limit")] + TaskInstanceLimitExceeded, + + #[error("task output not ready")] + TaskOutputNotReady, +} + +impl From for CacheError { + fn from(e: RejectionError) -> Self { + CacheError::Rejection(e) + } +} diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs new file mode 100644 index 00000000..c522b5f5 --- /dev/null +++ b/components/spider-storage/src/cache/job.rs @@ -0,0 +1,96 @@ +use std::sync::atomic::AtomicUsize; + +use spider_core::{ + job::JobState, + task::TaskIndex, + types::{ + id::{JobId, ResourceGroupId, TaskInstanceId}, + io::TaskOutput, + }, +}; + +use crate::{ + cache::{ + error::CacheError, + task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock, TaskGraph}, + types::{ExecutionContext, TaskId}, + }, + db::DbStorage, +}; + +pub struct JobControlBlock< + ReadyQueueSenderType: ReadyQueueConnector, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> { + id: JobId, + owner_id: ResourceGroupId, + job: tokio::sync::RwLock, + ready_queue_connector: ReadyQueueSenderType, + db_connector: DbConnectorType, + task_instance_pool_connector: TaskInstancePoolConnectorType, +} + +impl< + ReadyQueueSenderType: ReadyQueueConnector, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> JobControlBlock +{ + pub async fn create_task_instance( + &self, + task_id: TaskId, + ) -> Result { + todo!("Implement this!") + } + + pub async fn complete_task_instance( + &self, + task_instance_id: TaskInstanceId, + task_id: TaskId, + task_outputs: Vec, + ) -> Result { + todo!("Implement this!") + } + + pub async fn fail_task_instance( + &self, + 
task_instance_id: TaskInstanceId, + task_id: TaskId, + ) -> Result { + todo!("Implement this!") + } +} + +struct Job { + state: JobState, + task_graph: TaskGraph, + num_unfinished_tasks: AtomicUsize, +} + +#[async_trait::async_trait] +pub trait ReadyQueueConnector { + async fn send_task_ready(&self, job_id: JobId, task_ids: Vec) + -> Result<(), CacheError>; + + async fn send_commit_ready(&self, job_id: JobId) -> Result<(), CacheError>; + + async fn send_cleanup_ready(&self, job_id: JobId) -> Result<(), CacheError>; +} + +#[async_trait::async_trait] +pub trait TaskInstancePoolConnector { + fn get_next_available_task_instance_id(&self) -> TaskInstanceId; + + async fn register_task_instance( + &self, + task_instance_id: TaskInstanceId, + task: SharedTaskControlBlock, + ) -> Result<(), CacheError>; + + async fn register_termination_task_instance( + &self, + task_instance_id: TaskInstanceId, + termination_task: SharedTerminationTaskControlBlock, + ) -> Result<(), CacheError>; +} diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs index 14e74b80..8c47f8ad 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -1,280 +1,385 @@ use std::{ collections::{HashMap, HashSet}, + future::Ready, sync::{Arc, atomic::AtomicUsize}, }; use serde::Serialize; use spider_core::{ job::JobState, - task::{DataflowDependencyIndex, Task, TaskIndex}, + task::{DataflowDependencyIndex, Task, TaskIndex, TaskState}, types::{ id::{JobId, TaskInstanceId}, io::{TaskInput, TaskOutput}, }, }; -/// Enum for all possible states of a task. 
-#[derive(Eq, PartialEq, Debug, Clone)] -pub enum TaskState { - Pending, - Ready, - Running, - Succeeded, - Failed(String), - Cancelled, +use crate::cache::{ + error::{CacheError, CacheError::Internal, InternalError, RejectionError}, + types::{ExecutionContext, Reader, TdlContext, Writer}, +}; + +pub struct TaskGraph { + tasks: Vec, + outputs: Vec, + commit_task: Option, + cleanup_task: Option, } -impl TaskState { - pub fn is_terminal(&self) -> bool { - matches!( - self, - TaskState::Succeeded | TaskState::Failed(_) | TaskState::Cancelled - ) +impl TaskGraph { + pub fn get_task(&self, task_index: TaskIndex) -> Option { + self.tasks.get(task_index).cloned() + } + + pub async fn get_outputs(&self) -> Result, RejectionError> { + let mut outputs = Vec::with_capacity(self.outputs.len()); + for output_reader in &self.outputs { + let output_guard = output_reader.read().await; + if let Some(output) = &*output_guard { + outputs.push(output.clone()); + } else { + return Err(RejectionError::TaskOutputNotReady.into()); + } + } + Ok(outputs) } -} -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("invalid task output")] - InvalidTaskOutput, + pub fn get_commit_task(&self) -> Option { + self.commit_task.clone() + } - #[error("task output already written")] - TaskOutputDuplicateWrite, + pub fn get_cleanup_task(&self) -> Option { + self.cleanup_task.clone() + } +} - #[error("task input not ready")] - TaskInputNotReady, +#[derive(Clone)] +pub struct SharedTaskControlBlock { + inner: Arc>, +} - #[error("task outputs length mismatch: expected {0}, got {1}")] - TaskOutputsLengthMismatch(usize, usize), +impl SharedTaskControlBlock { + pub async fn register_task_instance( + &self, + task_instance_id: TaskInstanceId, + ) -> Result { + let mut tcb = self.inner.lock().await; + tcb.base.register_task_instance(task_instance_id)?; + + // NOTE: The following execution can only fail due to internal errors. 
+ let result: Result<_, InternalError> = { + let inputs = tcb.fetch_inputs().await?; + let execution_context = ExecutionContext { + task_instance_id, + tdl_context: tcb.base.tdl_context.clone(), + inputs, + }; + Ok(execution_context) + }; + result.map_err(CacheError::from) + } - #[error("task index {0} is out of bounds")] - TaskIndexOutOfBound(TaskIndex), + pub async fn complete_task_instance( + &self, + task_instance_id: TaskInstanceId, + task_outputs: Vec, + ) -> Result, CacheError> { + let mut tcb = self.inner.lock().await; + tcb.base.complete_task_instance(task_instance_id)?; + + // NOTE: The following execution can only fail due to internal errors. + let result: Result<_, InternalError> = { + tcb.write_outputs(task_outputs).await?; + let mut ready_child_indices = Vec::new(); + for child in &tcb.children { + let mut child_tcb = child.inner.lock().await; + if child_tcb.num_parents == 0 { + return Err(InternalError::TaskGraphCorrupted( + "the child has no unfinished parent, but it is still updated as if one of \ + its parent just completed." + .to_owned(), + ) + .into()); + } + child_tcb.num_unfinished_parents -= 1; + if child_tcb.num_unfinished_parents != 0 { + continue; + } - #[error("task is already in a terminal state: {0:?}")] - TaskAlreadyTerminal(TaskState), + // In practice, this update is guarded by a read lock on the task graph, which + // guarantees that the child tasks shouldn't be terminated, as the parent is + // not. + if child_tcb.base.state.is_terminal() { + return Err(InternalError::TaskGraphCorrupted( + "a child task is in a terminal state, but it is still updated as if one \ + of its parent just completed." 
+ .to_owned(), + ) + .into()); + } + child_tcb.base.state = TaskState::Ready; + ready_child_indices.push(child_tcb.index); + } - #[error("job is already in a terminal state: {0:?}")] - JobAlreadyTerminal(JobState), + Ok(ready_child_indices) + }; + result.map_err(CacheError::from) + } - #[error("task is still pending")] - TaskStillPending, + pub async fn fail_task_instance( + &self, + task_instance_id: TaskInstanceId, + error_message: String, + ) -> Result { + let mut tcb = self.inner.lock().await; + tcb.base + .fail_task_instance(task_instance_id, error_message) + .map_err(CacheError::from) + } - #[error("task instance {0} is not registered")] - TaskInstanceNotRegistered(TaskInstanceId), + pub async fn reset(&self) { + let mut tcb = self.inner.lock().await; + tcb.base.instance_ids.clear(); - #[error("failed to send ready task to the queue: {0}")] - TokioSendError(#[from] tokio::sync::mpsc::error::SendError<(JobId, TaskIndex)>), -} + // Reset outputs + for output_writer in &tcb.outputs { + let mut output = output_writer.write().await; + *output = None; + } -#[derive(Serialize, Clone)] -pub struct TdlContext { - package: String, - func: String, -} + tcb.base.retry_counter.reset(); -#[derive(Serialize)] -pub struct ExecutionContext { - pub task_instance_id: TaskInstanceId, - pub tdl_context: TdlContext, - pub inputs: Vec, -} + tcb.num_unfinished_parents = tcb.num_parents; + tcb.base.state = if tcb.num_unfinished_parents == 0 { + TaskState::Ready + } else { + TaskState::Pending + }; + } -/// Internal representation of a data dependency. -enum Data { - Value(Option>), - Channel, + pub async fn force_remove_task_instance(&self, task_instance_id: TaskInstanceId) -> bool { + let mut tcb = self.inner.lock().await; + tcb.base.force_remove_task_instance(task_instance_id) + } } -/// A shareable reference to a data object, allowing multiple tasks to read/write the same data -/// concurrently. 
-struct DataRef { - data: Arc>, +#[derive(Clone)] +pub struct SharedTerminationTaskControlBlock { + inner: Arc>, } -impl DataRef { - fn new_value(value: Vec) -> Self { - Self { - data: Arc::new(std::sync::RwLock::new(Data::Value(Some(value)))), - } +impl SharedTerminationTaskControlBlock { + pub fn register_termination_task_instance( + &self, + task_instance_id: TaskInstanceId, + ) -> Result { + let mut tcb = self.inner.blocking_lock(); + tcb.base.register_task_instance(task_instance_id)?; + Ok(tcb.base.tdl_context.clone()) } - fn new_null_value() -> Self { - Self { - data: Arc::new(std::sync::RwLock::new(Data::Value(None))), - } + pub fn complete_termination_task_instance( + &self, + task_instance_id: TaskInstanceId, + ) -> Result<(), CacheError> { + let mut tcb = self.inner.blocking_lock(); + tcb.base.complete_task_instance(task_instance_id) } - fn write_task_output(&self, task_output: TaskOutput) -> Result<(), Error> { - match task_output { - TaskOutput::ValuePayload(payload) => { - match &mut *self.data.write().expect("rw lock poisoned") { - Data::Value(optional_value) => { - if optional_value.is_some() { - return Err(Error::TaskOutputDuplicateWrite); - } - *optional_value = Some(payload); - } - Data::Channel => { - return Err(Error::InvalidTaskOutput); - } - } - } - } - Ok(()) + pub fn fail_termination_task_instance( + &self, + task_instance_id: TaskInstanceId, + error_message: String, + ) -> Result { + let mut tcb = self.inner.blocking_lock(); + tcb.base + .fail_task_instance(task_instance_id, error_message) + .map_err(CacheError::from) } - fn as_task_input(&self) -> Result { - match &*self.data.read().expect("rw lock poisoned") { - Data::Value(optional_value) => Ok(TaskInput::ValuePayload( - optional_value.clone().ok_or(Error::TaskInputNotReady)?, - )), - Data::Channel => Err(Error::InvalidTaskOutput), - } + pub async fn force_remove_task_instance(&self, task_instance_id: TaskInstanceId) -> bool { + let mut tcb = self.inner.lock().await; + 
tcb.base.force_remove_task_instance(task_instance_id) } } -struct TaskMetadata { +struct BaseTaskControlBlock { state: TaskState, tdl_context: TdlContext, - registered_instances: HashSet, - num_unfinished_parents: usize, - inputs: Vec, - outputs: Vec, - children: Vec, + instance_ids: HashSet, + max_num_instances: usize, + retry_counter: RetryCounter, } -impl TaskMetadata { - fn register(&mut self, task_instance_id: TaskInstanceId) -> Result { +impl BaseTaskControlBlock { + fn register_task_instance( + &mut self, + task_instance_id: TaskInstanceId, + ) -> Result<(), CacheError> { if self.state.is_terminal() { - return Err(Error::TaskAlreadyTerminal(self.state.clone())); + return Err(RejectionError::TaskAlreadyTerminated(self.state.clone()).into()); + } + if !matches!(self.state, TaskState::Ready | TaskState::Running) { + return Err(InternalError::TaskNotReady.into()); } - if self.state != TaskState::Ready || self.state != TaskState::Running { - return Err(Error::TaskStillPending); + if self.instance_ids.len() >= self.max_num_instances { + return Err(RejectionError::TaskInstanceLimitExceeded.into()); } + self.instance_ids.insert(task_instance_id); self.state = TaskState::Running; - self.registered_instances.insert(task_instance_id); - Ok(ExecutionContext { - task_instance_id, - tdl_context: self.tdl_context.clone(), - inputs: self.fetch_inputs()?, - }) + Ok(()) } - fn complete( + fn complete_task_instance( &mut self, task_instance_id: TaskInstanceId, - task_outputs: Vec, - ) -> Result<(), Error> { - if !self.registered_instances.contains(&task_instance_id) { - return Err(Error::TaskInstanceNotRegistered(task_instance_id)); + ) -> Result<(), CacheError> { + if !self.instance_ids.remove(&task_instance_id) { + return Err(RejectionError::InvalidTaskInstanceId.into()); } if self.state.is_terminal() { - return Err(Error::TaskAlreadyTerminal(self.state.clone())); + return Err(RejectionError::TaskAlreadyTerminated(self.state.clone()).into()); } - 
self.write_outputs(task_outputs)?; self.state = TaskState::Succeeded; Ok(()) } - fn write_outputs(&self, task_outputs: Vec) -> Result<(), Error> { + fn fail_task_instance( + &mut self, + task_instance_id: TaskInstanceId, + error_message: String, + ) -> Result { + if !self.instance_ids.remove(&task_instance_id) { + return Err(RejectionError::InvalidTaskInstanceId.into()); + } + if self.state.is_terminal() { + return Err(RejectionError::TaskAlreadyTerminated(self.state.clone()).into()); + } + + if self.retry_counter.retry() == 0 { + self.state = if self.instance_ids.len() == 0 { + TaskState::Running + } else { + TaskState::Ready + }; + } else { + self.state = TaskState::Failed(error_message); + } + Ok(self.state.clone()) + } + + fn force_remove_task_instance(&mut self, task_instance_id: TaskInstanceId) -> bool { + let existed = self.instance_ids.remove(&task_instance_id); + if existed && self.state == TaskState::Running { + self.state = TaskState::Ready; + } + existed + } +} + +struct TaskControlBlock { + base: BaseTaskControlBlock, + index: TaskIndex, + num_parents: usize, + num_unfinished_parents: usize, + inputs: Vec, + outputs: Vec, + children: Vec, +} + +impl TaskControlBlock { + async fn write_outputs(&self, task_outputs: Vec) -> Result<(), InternalError> { if task_outputs.len() != self.outputs.len() { - return Err(Error::TaskOutputsLengthMismatch( + return Err(InternalError::TaskOutputsLengthMismatch( self.outputs.len(), task_outputs.len(), )); } - for (output_ref, output) in self.outputs.iter().zip(task_outputs.into_iter()) { - output_ref.write_task_output(output)?; + + // Write task outputs + // NOTE: Currently, there is only one possible task output type (value payload) and thus we + // do not need to validate the type. In the future, when more task output types are + // supported, type validation should be done before any writes happens to avoid partial + // writes. 
+ for (output_writer, task_output) in self.outputs.iter().zip(task_outputs.into_iter()) { + let mut output = output_writer.write().await; + if output.is_some() { + return Err(InternalError::TaskOutputDuplicateWrite); + } + *output = Some(task_output); } + Ok(()) } - fn fetch_inputs(&self) -> Result, Error> { - self.inputs - .iter() - .map(|input_ref| input_ref.as_task_input()) - .collect() + async fn fetch_inputs(&self) -> Result, CacheError> { + let mut inputs = Vec::with_capacity(self.inputs.len()); + for input_reader in &self.inputs { + inputs.push(input_reader.read_as_task_input().await?); + } + Ok(inputs) } } -struct TaskGraph { - tasks: Vec>, +struct TerminationTaskControlBlock { + base: BaseTaskControlBlock, } -struct JobMetadata { - state: JobState, - task_graph: TaskGraph, - num_unfinished_tasks: AtomicUsize, -} +type ValuePayload = Option>; -pub struct Job { - id: JobId, - metadata: std::sync::RwLock, - ready_queue_sender: tokio::sync::mpsc::Sender<(JobId, TaskIndex)>, -} +#[derive(Clone)] +struct Channel {} -impl Job { - pub fn register_task_instance( - &self, - task_instance_id: TaskInstanceId, - task_index: TaskIndex, - ) -> Result { - let job_metadata = self.metadata.read().expect("rw lock poisoned"); - let mut task_metadata = job_metadata - .task_graph - .tasks - .get(task_index) - .ok_or(Error::TaskIndexOutOfBound(task_index))? - .lock() - .expect("mutex poisoned"); - task_metadata.register(task_instance_id) - } +enum InputReader { + Value(Reader), + Channel(Channel), +} - pub async fn complete_task_instance( - &self, - task_instance_id: TaskInstanceId, - task_index: TaskIndex, - task_outputs: Vec, - ) -> Result<(), Error> { - let job_metadata = self.metadata.read().expect("rw lock poisoned"); - - // Update the task metadata - let mut task_metadata = job_metadata - .task_graph - .tasks - .get(task_index) - .ok_or(Error::TaskIndexOutOfBound(task_index))? 
- .lock() - .expect("mutex poisoned"); - task_metadata.complete(task_instance_id, task_outputs)?; - for child_idx in &task_metadata.children { - let mut child_metadata = job_metadata - .task_graph - .tasks - .get(*child_idx) - .ok_or(Error::TaskIndexOutOfBound(*child_idx))? - .lock() - .expect("mutex poisoned"); - child_metadata.num_unfinished_parents -= 1; - if child_metadata.num_unfinished_parents == 0 { - child_metadata.state = TaskState::Ready; - self.ready_queue_sender.send((self.id, *child_idx)).await?; +impl InputReader { + async fn read_as_task_input(&self) -> Result { + match self { + InputReader::Value(value_payload) => { + let value_guard = value_payload.read().await; + if let Some(value) = &*value_guard { + Ok(TaskInput::ValuePayload(value.clone())) + } else { + Err(InternalError::TaskInputNotReady.into()) + } } + InputReader::Channel(_) => unimplemented!("channel input is not supported yet"), } - let num_unfinished_tasks = job_metadata - .num_unfinished_tasks - .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - - 1; - drop(task_metadata); - drop(job_metadata); - - if num_unfinished_tasks > 0 { - return Ok(()); + } +} + +type OutputReader = Reader; + +type OutputWriter = Writer; + +struct RetryCounter { + max_num_retries_allowed: usize, + retry_count: usize, +} + +impl RetryCounter { + fn new(max_num_retries_allowed: usize) -> Self { + Self { + max_num_retries_allowed, + retry_count: max_num_retries_allowed, } + } - // Atomic decrement guarantees that only one thread's control flow can reach here. - let job_metadata = self.metadata.write().expect("rw lock poisoned"); + fn retry(&mut self) -> usize { + if self.retry_count == 0 { + // In practice, this is possible if the total number of task instances creates are + // greater than the number of retries allowed. 
+ return 0; + } + let num_retries_left = self.retry_count; + self.retry_count -= 1; + num_retries_left + } - Ok(()) + fn reset(&mut self) { + self.retry_count = self.max_num_retries_allowed; } } diff --git a/components/spider-storage/src/cache/types.rs b/components/spider-storage/src/cache/types.rs new file mode 100644 index 00000000..f2a393df --- /dev/null +++ b/components/spider-storage/src/cache/types.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use serde::Serialize; +use spider_core::{ + task::TaskIndex, + types::{id::TaskInstanceId, io::TaskInput}, +}; +use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +pub type Shared = Arc>; + +#[derive(Clone)] +pub struct Reader { + inner: Shared, +} + +impl Reader { + pub fn new(inner: Shared) -> Reader { + Reader { inner } + } + + pub async fn read(&self) -> RwLockReadGuard<'_, Type> { + self.inner.read().await + } +} + +#[derive(Clone)] +pub struct Writer { + inner: Shared, +} + +impl Writer { + pub fn new(inner: Shared) -> Writer { + Writer { inner } + } + + pub async fn write(&self) -> RwLockWriteGuard<'_, Type> { + self.inner.write().await + } +} + +#[derive(Serialize, Clone)] +pub struct TdlContext { + package: String, + func: String, +} + +#[derive(Serialize)] +pub struct ExecutionContext { + pub task_instance_id: TaskInstanceId, + pub tdl_context: TdlContext, + pub inputs: Vec, +} + +#[derive(Serialize, Clone)] +pub enum TaskId { + TaskIndex(TaskIndex), + Commit, + Cleanup, +} From 18d10e90579b49ac2bdd4c30f579ebc446fd02a3 Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Tue, 17 Mar 2026 10:48:11 -0400 Subject: [PATCH 3/8] WIP. 
--- components/spider-core/src/job.rs | 7 ++ components/spider-storage/src/cache/error.rs | 28 +++++ components/spider-storage/src/cache/job.rs | 118 +++++++++++++++++-- components/spider-storage/src/cache/task.rs | 14 +-- components/spider-storage/src/cache/types.rs | 2 +- 5 files changed, 150 insertions(+), 19 deletions(-) diff --git a/components/spider-core/src/job.rs b/components/spider-core/src/job.rs index 7eaf94b9..bd57b457 100644 --- a/components/spider-core/src/job.rs +++ b/components/spider-core/src/job.rs @@ -27,6 +27,13 @@ impl JobState { matches!(self, Self::Succeeded | Self::Failed | Self::Cancelled) } + /// # Returns + /// + /// Whether the job is in [`JobState::Running`] state. + pub const fn is_running(&self) -> bool { + return matches!(self, Self::Running); + } + /// # Returns /// /// Whether the state transition `from` -> `to` is valid. diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index 3a5afbdb..8e4caaed 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -1,4 +1,5 @@ use spider_core::{ + job::JobState, task::{TaskIndex, TaskState}, types::id::JobId, }; @@ -29,6 +30,24 @@ pub enum InternalError { #[error("task graph corrupted: {0}")] TaskGraphCorrupted(String), + #[error("job has no been started")] + JobNotStarted, + + #[error("job does not have a commit task")] + JobNoCommit, + + #[error("job does not have a cleanup task")] + JobNoCleanup, + + #[error("unexpected job state: current {current}, expected {expected}")] + UnexpectedJobState { + current: JobState, + expected: JobState, + }, + + #[error("job outputs are not ready")] + JobOutputsNotReady, + #[error("failed to send scheduling context into the channel")] TokioSendError(#[from] tokio::sync::mpsc::error::SendError<(JobId, TaskIndex)>), @@ -53,6 +72,15 @@ pub enum RejectionError { #[error("task is already in a terminal state: {0:?}")] TaskAlreadyTerminated(TaskState), + 
#[error("job is no longer in the running state: {0}")] + JobNoLongerRunning(JobState), + + #[error("job is no longer in the commit-ready state: {0}")] + JobNoLongerCommitReady(JobState), + + #[error("job is no longer in the cleanup-ready state: {0}")] + JobNoLongerCleanupReady(JobState), + #[error("the number of living task instances has reached the upper limit")] TaskInstanceLimitExceeded, diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index c522b5f5..25d8dc4a 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -11,16 +11,21 @@ use spider_core::{ use crate::{ cache::{ - error::CacheError, + error::{ + CacheError, + InternalError, + RejectionError, + RejectionError::{JobNoLongerCleanupReady, JobNoLongerCommitReady, JobNoLongerRunning}, + }, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock, TaskGraph}, types::{ExecutionContext, TaskId}, }, - db::DbStorage, + db::InternalJobOrchestration, }; pub struct JobControlBlock< ReadyQueueSenderType: ReadyQueueConnector, - DbConnectorType: DbStorage, + DbConnectorType: InternalJobOrchestration, TaskInstancePoolConnectorType: TaskInstancePoolConnector, > { id: JobId, @@ -33,7 +38,7 @@ pub struct JobControlBlock< impl< ReadyQueueSenderType: ReadyQueueConnector, - DbConnectorType: DbStorage, + DbConnectorType: InternalJobOrchestration, TaskInstancePoolConnectorType: TaskInstancePoolConnector, > JobControlBlock { @@ -41,7 +46,95 @@ impl< &self, task_id: TaskId, ) -> Result { - todo!("Implement this!") + let job = self.job.read().await; + + let execution_context = match task_id { + TaskId::TaskIndex(task_index) => { + if job.state == JobState::Ready { + return Err(InternalError::JobNotStarted.into()); + } + if !job.state.is_running() { + return Err(RejectionError::JobNoLongerRunning(job.state).into()); + } + let tcb = job + .task_graph + .get_task(task_index) + .ok_or(InternalError::TaskIndexOutOfBound)?; + 
let task_instance_id = self + .task_instance_pool_connector + .get_next_available_task_instance_id(); + let execution_context = tcb.register_task_instance(task_instance_id).await?; + self.task_instance_pool_connector + .register_task_instance(task_instance_id, tcb) + .await?; + execution_context + } + + TaskId::Commit => { + if job.state.is_terminal() || job.state == JobState::CleanupReady { + return Err(JobNoLongerCommitReady(job.state).into()); + } + if job.state != JobState::CommitReady { + return Err(InternalError::UnexpectedJobState { + expected: JobState::CommitReady, + current: job.state, + } + .into()); + } + let commit_tcb = job + .task_graph + .get_commit_task() + .ok_or(InternalError::JobNoCommit)?; + let task_instance_id = self + .task_instance_pool_connector + .get_next_available_task_instance_id(); + let tdl_context = commit_tcb + .register_termination_task_instance(task_instance_id) + .await?; + self.task_instance_pool_connector + .register_termination_task_instance(task_instance_id, commit_tcb) + .await?; + ExecutionContext { + task_instance_id, + tdl_context, + // TODO: Question, what's the input for the commit task? 
+ inputs: None, + } + } + + TaskId::Cleanup => { + if job.state.is_terminal() { + return Err(JobNoLongerCleanupReady(job.state).into()); + } + if job.state != JobState::CleanupReady { + return Err(InternalError::UnexpectedJobState { + expected: JobState::CleanupReady, + current: job.state, + } + .into()); + } + let commit_tcb = job + .task_graph + .get_commit_task() + .ok_or(InternalError::JobNoCommit)?; + let task_instance_id = self + .task_instance_pool_connector + .get_next_available_task_instance_id(); + let tdl_context = commit_tcb + .register_termination_task_instance(task_instance_id) + .await?; + self.task_instance_pool_connector + .register_termination_task_instance(task_instance_id, commit_tcb) + .await?; + ExecutionContext { + task_instance_id, + tdl_context, + inputs: None, + } + } + }; + + Ok(execution_context) } pub async fn complete_task_instance( @@ -70,12 +163,15 @@ struct Job { #[async_trait::async_trait] pub trait ReadyQueueConnector { - async fn send_task_ready(&self, job_id: JobId, task_ids: Vec) - -> Result<(), CacheError>; + async fn send_task_ready( + &self, + job_id: JobId, + task_ids: Vec, + ) -> Result<(), InternalError>; - async fn send_commit_ready(&self, job_id: JobId) -> Result<(), CacheError>; + async fn send_commit_ready(&self, job_id: JobId) -> Result<(), InternalError>; - async fn send_cleanup_ready(&self, job_id: JobId) -> Result<(), CacheError>; + async fn send_cleanup_ready(&self, job_id: JobId) -> Result<(), InternalError>; } #[async_trait::async_trait] @@ -86,11 +182,11 @@ pub trait TaskInstancePoolConnector { &self, task_instance_id: TaskInstanceId, task: SharedTaskControlBlock, - ) -> Result<(), CacheError>; + ) -> Result<(), InternalError>; async fn register_termination_task_instance( &self, task_instance_id: TaskInstanceId, termination_task: SharedTerminationTaskControlBlock, - ) -> Result<(), CacheError>; + ) -> Result<(), InternalError>; } diff --git a/components/spider-storage/src/cache/task.rs 
b/components/spider-storage/src/cache/task.rs index 8c47f8ad..42b3d4ee 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -72,7 +72,7 @@ impl SharedTaskControlBlock { let execution_context = ExecutionContext { task_instance_id, tdl_context: tcb.base.tdl_context.clone(), - inputs, + inputs: Some(inputs), }; Ok(execution_context) }; @@ -169,29 +169,29 @@ pub struct SharedTerminationTaskControlBlock { } impl SharedTerminationTaskControlBlock { - pub fn register_termination_task_instance( + pub async fn register_termination_task_instance( &self, task_instance_id: TaskInstanceId, ) -> Result { - let mut tcb = self.inner.blocking_lock(); + let mut tcb = self.inner.lock().await; tcb.base.register_task_instance(task_instance_id)?; Ok(tcb.base.tdl_context.clone()) } - pub fn complete_termination_task_instance( + pub async fn complete_termination_task_instance( &self, task_instance_id: TaskInstanceId, ) -> Result<(), CacheError> { - let mut tcb = self.inner.blocking_lock(); + let mut tcb = self.inner.lock().await; tcb.base.complete_task_instance(task_instance_id) } - pub fn fail_termination_task_instance( + pub async fn fail_termination_task_instance( &self, task_instance_id: TaskInstanceId, error_message: String, ) -> Result { - let mut tcb = self.inner.blocking_lock(); + let mut tcb = self.inner.lock().await; tcb.base .fail_task_instance(task_instance_id, error_message) .map_err(CacheError::from) diff --git a/components/spider-storage/src/cache/types.rs b/components/spider-storage/src/cache/types.rs index f2a393df..8ecae1f1 100644 --- a/components/spider-storage/src/cache/types.rs +++ b/components/spider-storage/src/cache/types.rs @@ -49,7 +49,7 @@ pub struct TdlContext { pub struct ExecutionContext { pub task_instance_id: TaskInstanceId, pub tdl_context: TdlContext, - pub inputs: Vec, + pub inputs: Option>, } #[derive(Serialize, Clone)] From a00185dde31f8ebc986c4d9bc1763cfab50659df Mon Sep 17 00:00:00 2001 From: 
LinZhihao-723 Date: Tue, 17 Mar 2026 13:20:23 -0400 Subject: [PATCH 4/8] Still WIP... --- components/spider-core/src/types/id.rs | 2 +- components/spider-storage/src/cache/job.rs | 66 ++++++++++++++++++++-- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/components/spider-core/src/types/id.rs b/components/spider-core/src/types/id.rs index 8e8b115b..cda5d7b7 100644 --- a/components/spider-core/src/types/id.rs +++ b/components/spider-core/src/types/id.rs @@ -50,7 +50,7 @@ pub type ResourceGroupId = Id; pub enum TaskIdMarker {} pub type TaskId = Id; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum JobIdMarker {} pub type JobId = Id; diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index 25d8dc4a..eeab0e49 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicUsize, Ordering}; use spider_core::{ job::JobState, @@ -15,7 +15,7 @@ use crate::{ CacheError, InternalError, RejectionError, - RejectionError::{JobNoLongerCleanupReady, JobNoLongerCommitReady, JobNoLongerRunning}, + RejectionError::{JobNoLongerCleanupReady, JobNoLongerCommitReady}, }, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock, TaskGraph}, types::{ExecutionContext, TaskId}, @@ -140,8 +140,66 @@ impl< pub async fn complete_task_instance( &self, task_instance_id: TaskInstanceId, - task_id: TaskId, + task_index: TaskIndex, task_outputs: Vec, + ) -> Result { + let job = self.job.read().await; + if job.state == JobState::Ready { + return Err(InternalError::JobNotStarted.into()); + } + if !job.state.is_running() { + return Err(RejectionError::JobNoLongerRunning(job.state).into()); + } + let tcb = job + .task_graph + .get_task(task_index) + .ok_or(InternalError::TaskIndexOutOfBound)?; + let ready_task_ids = tcb.complete_task_instance(task_instance_id, 
task_outputs).await?; + let num_incompleted_task = job.num_incompleted_tasks.fetch_sub(1, Ordering::Relaxed); + if !ready_task_ids.is_empty() { + if num_incompleted_task == 0 { + return Err( + InternalError::TaskGraphCorrupted( + "no incompleted tasks while new ready task IDs are generated".to_owned() + ).into()); + } + self.ready_queue_connector + .send_task_ready(self.id.clone(), ready_task_ids) + .await?; + return Ok(job.state); + } + if num_incompleted_task != 0 { + return Ok(job.state); + } + drop(job); + + let job_state = self.commit_job_outputs().await?; + if matches!(job_state, JobState::CommitReady) { + self.ready_queue_connector.send_commit_ready(self.id.clone()).await?; + } + Ok(job_state) + } + + pub async fn commit_job_outputs(&self) -> Result { + let mut job = self.job.write().await; + if !job.state.is_running() { + return Err(RejectionError::JobNoLongerRunning(job.state).into()); + } + let outputs = job.task_graph.get_outputs().await.map_err(|_| InternalError::JobOutputsNotReady)?; + + Ok(job.state) + } + + pub async fn complete_commit_task_instance( + &self, + task_instance_id: TaskInstanceId, + ) -> Result { + todo!("Implement this!") + } + + pub async fn complete_cleanup_task_instance( + &self, + task_instance_id: TaskInstanceId, ) -> Result { todo!("Implement this!") } @@ -158,7 +216,7 @@ impl< struct Job { state: JobState, task_graph: TaskGraph, - num_unfinished_tasks: AtomicUsize, + num_incompleted_tasks: AtomicUsize, } #[async_trait::async_trait] From 61b59789e61fdf9aa011b20166bb84333ab5023d Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Tue, 17 Mar 2026 17:32:06 -0400 Subject: [PATCH 5/8] Still WIP... 
--- components/spider-storage/src/cache/error.rs | 13 + components/spider-storage/src/cache/job.rs | 309 +++++++++++++++---- components/spider-storage/src/cache/task.rs | 10 +- 3 files changed, 272 insertions(+), 60 deletions(-) diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index 8e4caaed..4043c273 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -8,6 +8,7 @@ use spider_core::{ pub enum CacheError { Internal(InternalError), Rejection(RejectionError), + DbError(crate::db::DbError), } /// Enums for all internal errors. When these error happens, it is considered that the system is in @@ -48,6 +49,9 @@ pub enum InternalError { #[error("job outputs are not ready")] JobOutputsNotReady, + #[error("job terminated unexpectedly")] + JobTerminatedUnexpectedly, + #[error("failed to send scheduling context into the channel")] TokioSendError(#[from] tokio::sync::mpsc::error::SendError<(JobId, TaskIndex)>), @@ -81,6 +85,9 @@ pub enum RejectionError { #[error("job is no longer in the cleanup-ready state: {0}")] JobNoLongerCleanupReady(JobState), + #[error("job is already in a terminal state: {0}")] + JobAlreadyTerminated(JobState), + #[error("the number of living task instances has reached the upper limit")] TaskInstanceLimitExceeded, @@ -93,3 +100,9 @@ impl From for CacheError { CacheError::Rejection(e) } } + +impl From for CacheError { + fn from(e: crate::db::DbError) -> Self { + CacheError::DbError(e) + } +} diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index eeab0e49..1258808a 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -2,17 +2,19 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use spider_core::{ job::JobState, - task::TaskIndex, + task::{TaskIndex, TaskState}, types::{ id::{JobId, ResourceGroupId, TaskInstanceId}, io::TaskOutput, }, }; 
+use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use crate::{ cache::{ error::{ CacheError, + CacheError::Internal, InternalError, RejectionError, RejectionError::{JobNoLongerCleanupReady, JobNoLongerCommitReady}, @@ -30,7 +32,7 @@ pub struct JobControlBlock< > { id: JobId, owner_id: ResourceGroupId, - job: tokio::sync::RwLock, + job: RwJob, ready_queue_connector: ReadyQueueSenderType, db_connector: DbConnectorType, task_instance_pool_connector: TaskInstancePoolConnectorType, @@ -46,16 +48,9 @@ impl< &self, task_id: TaskId, ) -> Result { - let job = self.job.read().await; - let execution_context = match task_id { TaskId::TaskIndex(task_index) => { - if job.state == JobState::Ready { - return Err(InternalError::JobNotStarted.into()); - } - if !job.state.is_running() { - return Err(RejectionError::JobNoLongerRunning(job.state).into()); - } + let job = self.job.read_if_running().await?; let tcb = job .task_graph .get_task(task_index) @@ -71,16 +66,7 @@ impl< } TaskId::Commit => { - if job.state.is_terminal() || job.state == JobState::CleanupReady { - return Err(JobNoLongerCommitReady(job.state).into()); - } - if job.state != JobState::CommitReady { - return Err(InternalError::UnexpectedJobState { - expected: JobState::CommitReady, - current: job.state, - } - .into()); - } + let job = self.job.read_if_commit_ready().await?; let commit_tcb = job .task_graph .get_commit_task() @@ -103,16 +89,7 @@ impl< } TaskId::Cleanup => { - if job.state.is_terminal() { - return Err(JobNoLongerCleanupReady(job.state).into()); - } - if job.state != JobState::CleanupReady { - return Err(InternalError::UnexpectedJobState { - expected: JobState::CleanupReady, - current: job.state, - } - .into()); - } + let job = self.job.read_if_cleanup_ready().await?; let commit_tcb = job .task_graph .get_commit_task() @@ -143,73 +120,180 @@ impl< task_index: TaskIndex, task_outputs: Vec, ) -> Result { - let job = self.job.read().await; - if job.state == JobState::Ready { - return 
Err(InternalError::JobNotStarted.into()); - } - if !job.state.is_running() { - return Err(RejectionError::JobNoLongerRunning(job.state).into()); - } + let job = self.job.read_if_running().await?; let tcb = job .task_graph .get_task(task_index) .ok_or(InternalError::TaskIndexOutOfBound)?; - let ready_task_ids = tcb.complete_task_instance(task_instance_id, task_outputs).await?; + let ready_task_ids = tcb + .complete_task_instance(task_instance_id, task_outputs) + .await?; let num_incompleted_task = job.num_incompleted_tasks.fetch_sub(1, Ordering::Relaxed); + if !ready_task_ids.is_empty() { if num_incompleted_task == 0 { - return Err( - InternalError::TaskGraphCorrupted( - "no incompleted tasks while new ready task IDs are generated".to_owned() - ).into()); + return Err(InternalError::TaskGraphCorrupted( + "no incompleted tasks while new ready task IDs are generated".to_owned(), + ) + .into()); } self.ready_queue_connector .send_task_ready(self.id.clone(), ready_task_ids) .await?; return Ok(job.state); } + if num_incompleted_task != 0 { return Ok(job.state); } - drop(job); - let job_state = self.commit_job_outputs().await?; - if matches!(job_state, JobState::CommitReady) { - self.ready_queue_connector.send_commit_ready(self.id.clone()).await?; + drop(job); + let job_state = self.commit_outputs().await?; + match job_state { + JobState::CommitReady => { + if !self.job.has_commit_task().await { + return Err(InternalError::JobNoCommit.into()); + } + self.ready_queue_connector + .send_commit_ready(self.id.clone()) + .await?; + } + JobState::Succeeded => {} + other => unreachable!( + "unexpected job state after committing job outputs: {:?}", + other + ), } Ok(job_state) } - pub async fn commit_job_outputs(&self) -> Result { - let mut job = self.job.write().await; - if !job.state.is_running() { - return Err(RejectionError::JobNoLongerRunning(job.state).into()); - } - let outputs = job.task_graph.get_outputs().await.map_err(|_| InternalError::JobOutputsNotReady)?; - - 
Ok(job.state) - } - pub async fn complete_commit_task_instance( &self, task_instance_id: TaskInstanceId, ) -> Result { - todo!("Implement this!") + let mut job = self.job.write_if_commit_ready().await?; + job.task_graph + .get_commit_task() + .ok_or(InternalError::JobNoCommit)? + .complete_termination_task_instance(task_instance_id) + .await?; + self.db_connector + .set_state(self.id.clone(), JobState::Succeeded) + .await?; + job.state = JobState::Succeeded; + Ok(JobState::Succeeded) } pub async fn complete_cleanup_task_instance( &self, task_instance_id: TaskInstanceId, ) -> Result { - todo!("Implement this!") + let mut job = self.job.write_if_cleanup_ready().await?; + job.task_graph + .get_cleanup_task() + .ok_or(InternalError::JobNoCleanup)? + .complete_termination_task_instance(task_instance_id) + .await?; + self.db_connector + .set_state(self.id.clone(), JobState::Cancelled) + .await?; + job.state = JobState::Cancelled; + Ok(JobState::Cancelled) } pub async fn fail_task_instance( &self, task_instance_id: TaskInstanceId, task_id: TaskId, + error_message: String, ) -> Result { - todo!("Implement this!") + match task_id { + TaskId::TaskIndex(task_index) => { + let job = self.job.read_if_running().await?; + let task_state = job + .task_graph + .get_task(task_index) + .ok_or(InternalError::TaskIndexOutOfBound)? + .fail_task_instance(task_instance_id, error_message.clone()) + .await?; + if matches!(task_state, TaskState::Ready | TaskState::Running) { + self.ready_queue_connector + .send_task_ready(self.id.clone(), vec![task_index]) + .await?; + return Ok(job.state); + } + } + TaskId::Commit => { + let job = self.job.read_if_commit_ready().await?; + let task_state = job + .task_graph + .get_commit_task() + .ok_or(InternalError::JobNoCommit)? 
+ .fail_termination_task_instance(task_instance_id, error_message.clone()) + .await?; + if matches!(task_state, TaskState::Ready | TaskState::Running) { + self.ready_queue_connector + .send_commit_ready(self.id.clone()) + .await?; + return Ok(job.state); + } + } + TaskId::Cleanup => { + let job = self.job.read_if_cleanup_ready().await?; + let task_state = job + .task_graph + .get_cleanup_task() + .ok_or(InternalError::JobNoCleanup)? + .fail_termination_task_instance(task_instance_id, error_message.clone()) + .await?; + if matches!(task_state, TaskState::Ready | TaskState::Running) { + self.ready_queue_connector + .send_cleanup_ready(self.id.clone()) + .await?; + return Ok(job.state); + } + } + }; + + let mut job = self.job.write_if_non_terminated().await.map_err(|e| { + match &e { + CacheError::Rejection(RejectionError::JobAlreadyTerminated(state)) => { + if *state == JobState::Failed { + // Already failed by others + return e; + } + InternalError::JobTerminatedUnexpectedly.into() + } + _ => InternalError::JobTerminatedUnexpectedly.into(), + } + })?; + self.db_connector + .fail(self.id.clone(), error_message) + .await?; + job.state = JobState::Failed; + Ok(JobState::Failed) + } + + async fn commit_outputs(&self) -> Result { + let mut job = self.job.write_if_running().await?; + let outputs = job + .task_graph + .get_outputs() + .await + .map_err(|_| InternalError::JobOutputsNotReady)?; + job.state = self + .db_connector + .commit_outputs(self.id.clone(), outputs) + .await?; + Ok(job.state) + } + + async fn cancel(&self) -> Result { + todo!( + "Implement this. The job table must be locked for write, and the state of all tasks \ + must be checked to ensure if any of them are failed already, the cancellation \ + shouldn't go through." 
+ ) } } @@ -219,6 +303,113 @@ struct Job { num_incompleted_tasks: AtomicUsize, } +struct RwJob { + inner: RwLock, +} + +impl RwJob { + async fn read_checked( + &self, + check: fn(&Job) -> Result<(), CacheError>, + ) -> Result, CacheError> { + let guard = self.inner.read().await; + check(&*guard)?; + Ok(guard) + } + + async fn write_checked( + &self, + check: fn(&Job) -> Result<(), CacheError>, + ) -> Result, CacheError> { + let guard = self.inner.write().await; + check(&*guard)?; + Ok(guard) + } + + pub async fn read_if_running(&self) -> Result, CacheError> { + self.read_checked(Job::assumed_running).await + } + + pub async fn write_if_running(&self) -> Result, CacheError> { + self.write_checked(Job::assumed_running).await + } + + pub async fn read_if_commit_ready(&self) -> Result, CacheError> { + self.read_checked(Job::assumed_commit_ready).await + } + + pub async fn write_if_commit_ready(&self) -> Result, CacheError> { + self.write_checked(Job::assumed_commit_ready).await + } + + pub async fn read_if_cleanup_ready(&self) -> Result, CacheError> { + self.read_checked(Job::assumed_cleanup_ready).await + } + + pub async fn write_if_cleanup_ready(&self) -> Result, CacheError> { + self.write_checked(Job::assumed_cleanup_ready).await + } + + pub async fn write_if_non_terminated(&self) -> Result, CacheError> { + self.write_checked(Job::assumed_non_terminated).await + } + + pub async fn has_commit_task(&self) -> bool { + self.inner.read().await.task_graph.has_commit_task() + } + + pub async fn has_cleanup_task(&self) -> bool { + self.inner.read().await.task_graph.has_cleanup_task() + } +} + +impl Job { + fn assumed_running(&self) -> Result<(), CacheError> { + if !self.state.is_running() { + if self.state == JobState::Ready { + return Err(InternalError::JobNotStarted.into()); + } + return Err(RejectionError::JobNoLongerRunning(self.state).into()); + } + Ok(()) + } + + fn assumed_commit_ready(&self) -> Result<(), CacheError> { + if self.state != JobState::CommitReady { + if 
self.state.is_terminal() || self.state == JobState::CleanupReady { + return Err(JobNoLongerCommitReady(self.state).into()); + } + return Err(InternalError::UnexpectedJobState { + expected: JobState::CommitReady, + current: self.state, + } + .into()); + } + Ok(()) + } + + fn assumed_cleanup_ready(&self) -> Result<(), CacheError> { + if self.state != JobState::CleanupReady { + if self.state.is_terminal() { + return Err(JobNoLongerCleanupReady(self.state).into()); + } + return Err(InternalError::UnexpectedJobState { + expected: JobState::CleanupReady, + current: self.state, + } + .into()); + } + Ok(()) + } + + fn assumed_non_terminated(&self) -> Result<(), CacheError> { + if self.state.is_terminal() { + return Err(RejectionError::JobNoLongerRunning(self.state).into()); + } + Ok(()) + } +} + #[async_trait::async_trait] pub trait ReadyQueueConnector { async fn send_task_ready( diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs index 42b3d4ee..f1b7968f 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -32,7 +32,7 @@ impl TaskGraph { } pub async fn get_outputs(&self) -> Result, RejectionError> { - let mut outputs = Vec::with_capacity(self.outputs.len()); + let mut outputs: Vec = Vec::with_capacity(self.outputs.len()); for output_reader in &self.outputs { let output_guard = output_reader.read().await; if let Some(output) = &*output_guard { @@ -44,6 +44,14 @@ impl TaskGraph { Ok(outputs) } + pub fn has_commit_task(&self) -> bool { + self.commit_task.is_some() + } + + pub fn has_cleanup_task(&self) -> bool { + self.cleanup_task.is_some() + } + pub fn get_commit_task(&self) -> Option { self.commit_task.clone() } From b9da8570837d07ffc2d5c6b7a33d1c43be62dbfd Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Tue, 17 Mar 2026 21:26:16 -0400 Subject: [PATCH 6/8] Claude was working, lol --- Cargo.lock | 12 + components/spider-core/src/job.rs | 3 +- 
components/spider-core/src/task.rs | 28 +- components/spider-core/src/task/task_graph.rs | 142 ++++- components/spider-core/src/types/io.rs | 2 +- .../tests/test_task_graph_serde.rs | 60 +- components/spider-storage/Cargo.toml | 3 + components/spider-storage/src/cache.rs | 43 ++ components/spider-storage/src/cache/error.rs | 22 +- .../spider-storage/src/cache/factory.rs | 246 +++++++++ components/spider-storage/src/cache/job.rs | 69 ++- components/spider-storage/src/cache/task.rs | 110 ++-- components/spider-storage/src/cache/tests.rs | 517 ++++++++++++++++++ components/spider-storage/src/cache/types.rs | 12 +- components/spider-storage/src/protocol.rs | 12 +- 15 files changed, 1170 insertions(+), 111 deletions(-) create mode 100644 components/spider-storage/src/cache/factory.rs create mode 100644 components/spider-storage/src/cache/tests.rs diff --git a/Cargo.lock b/Cargo.lock index cd38cfd3..8c3c3463 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1431,6 +1431,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", ] [[package]] diff --git a/components/spider-core/src/job.rs b/components/spider-core/src/job.rs index bd57b457..3b9eadf2 100644 --- a/components/spider-core/src/job.rs +++ b/components/spider-core/src/job.rs @@ -30,8 +30,9 @@ impl JobState { /// # Returns /// /// Whether the job is in [`JobState::Running`] state. 
+ #[must_use] pub const fn is_running(&self) -> bool { - return matches!(self, Self::Running); + matches!(self, Self::Running) } /// # Returns diff --git a/components/spider-core/src/task.rs b/components/spider-core/src/task.rs index 6eaf94d9..0295d01d 100644 --- a/components/spider-core/src/task.rs +++ b/components/spider-core/src/task.rs @@ -1,6 +1,7 @@ mod task_graph; mod type_descriptor; +use serde::{Deserialize, Serialize}; pub use task_graph::*; use thiserror::Error; pub use type_descriptor::*; @@ -35,13 +36,30 @@ pub enum TaskState { } impl TaskState { - pub fn is_terminal(&self) -> bool { - matches!( - self, - TaskState::Succeeded | TaskState::Failed(_) | TaskState::Cancelled - ) + #[must_use] + pub const fn is_terminal(&self) -> bool { + matches!(self, Self::Succeeded | Self::Failed(_) | Self::Cancelled) } } /// Represents metadata associated with a task. pub struct TaskMetadata {} + +/// Execution policy for a task, controlling concurrency and retry behavior. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ExecutionPolicy { + /// The maximum number of concurrent instances allowed for this task. + pub max_num_instances: usize, + + /// The maximum number of retries allowed for this task on failure. 
+ pub max_num_retries: usize, +} + +impl Default for ExecutionPolicy { + fn default() -> Self { + Self { + max_num_instances: 1, + max_num_retries: 0, + } + } +} diff --git a/components/spider-core/src/task/task_graph.rs b/components/spider-core/src/task/task_graph.rs index f9513676..5ed1b912 100644 --- a/components/spider-core/src/task/task_graph.rs +++ b/components/spider-core/src/task/task_graph.rs @@ -5,10 +5,10 @@ use serde::{ de::{self, MapAccess, Visitor}, ser::{SerializeMap, SerializeSeq, Serializer}, }; -use strum::{EnumCount, IntoEnumIterator}; -use strum_macros::{EnumCount, EnumIter}; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; -use crate::task::{DataTypeDescriptor, Error}; +use crate::task::{DataTypeDescriptor, Error, ExecutionPolicy}; /// A unique identifier for a task within a task graph, assigned based on insertion order. /// @@ -51,6 +51,7 @@ pub struct Task { child_indices: Vec, input_dep_indices: Vec, output_dep_indices: Vec, + execution_policy: ExecutionPolicy, } impl Task { @@ -117,6 +118,11 @@ impl Task { self.tdl_function.as_str() } + #[must_use] + pub const fn get_execution_policy(&self) -> &ExecutionPolicy { + &self.execution_policy + } + const fn new( idx: TaskIndex, tdl_package: String, @@ -124,6 +130,7 @@ impl Task { input_dep_indices: Vec, output_dep_indices: Vec, parent_indices: Vec, + execution_policy: ExecutionPolicy, ) -> Self { Self { idx, @@ -133,6 +140,7 @@ impl Task { child_indices: Vec::new(), input_dep_indices, output_dep_indices, + execution_policy, } } @@ -232,6 +240,22 @@ pub struct TaskDescriptor { /// * `None`: All inputs are graph inputs (i.e., external inputs with no source tasks). This /// indicates the task is an input task to the graph. pub input_sources: Option>, + + /// The execution policy for this task. + pub execution_policy: ExecutionPolicy, +} + +/// A descriptor for a termination task (commit or cleanup) that runs after the main task graph. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TerminationTaskDescriptor { + /// The TDL package containing the termination task function. + pub tdl_package: String, + + /// The TDL function name for the termination task. + pub tdl_function: String, + + /// The execution policy for this termination task. + pub execution_policy: ExecutionPolicy, } /// An in-memory representation of a directed acyclic graph (DAG) of tasks and their dependencies. @@ -239,6 +263,8 @@ pub struct TaskDescriptor { pub struct TaskGraph { dataflow_deps: Vec, tasks: Vec, + commit_task: Option, + cleanup_task: Option, } impl TaskGraph { @@ -341,6 +367,7 @@ impl TaskGraph { input_dep_indices, output_dep_indices, parent_indices, + task_descriptor.execution_policy, )); Ok(task_idx) } @@ -391,6 +418,39 @@ impl TaskGraph { self.tasks.len() } + #[must_use] + pub fn get_tasks(&self) -> &[Task] { + &self.tasks + } + + #[must_use] + pub fn get_dataflow_dep(&self, index: DataflowDependencyIndex) -> Option<&DataflowDependency> { + self.dataflow_deps.get(index) + } + + #[must_use] + pub const fn get_num_dataflow_deps(&self) -> usize { + self.dataflow_deps.len() + } + + #[must_use] + pub const fn get_commit_task_descriptor(&self) -> Option<&TerminationTaskDescriptor> { + self.commit_task.as_ref() + } + + #[must_use] + pub const fn get_cleanup_task_descriptor(&self) -> Option<&TerminationTaskDescriptor> { + self.cleanup_task.as_ref() + } + + pub fn set_commit_task(&mut self, descriptor: TerminationTaskDescriptor) { + self.commit_task = Some(descriptor); + } + + pub fn set_cleanup_task(&mut self, descriptor: TerminationTaskDescriptor) { + self.cleanup_task = Some(descriptor); + } + /// Computes the input data-flow dependencies and parent task indices for a task based on its /// inputs. 
/// @@ -598,6 +658,7 @@ impl TaskGraph { inputs, outputs, input_sources, + execution_policy: task.execution_policy.clone(), }; sequence.serialize_element(&task_descriptor)?; } @@ -610,7 +671,15 @@ impl Serialize for TaskGraph { &self, serializer: SerializerImpl, ) -> Result { - let mut map = serializer.serialize_map(Some(SerializableTaskGraphField::COUNT))?; + // Count required fields (schema_version + tasks) plus optional fields. + let mut num_fields = 2; + if self.commit_task.is_some() { + num_fields += 1; + } + if self.cleanup_task.is_some() { + num_fields += 1; + } + let mut map = serializer.serialize_map(Some(num_fields))?; // Iterate the field enum to ensure all fields are serialized and only once. for field in SerializableTaskGraphField::iter() { match field { @@ -622,6 +691,22 @@ impl Serialize for TaskGraph { SerializableTaskGraphField::SchemaVersion.as_str(), TASK_GRAPH_SCHEMA_VERSION, )?, + SerializableTaskGraphField::CommitTask => { + if let Some(commit_task) = &self.commit_task { + map.serialize_entry( + SerializableTaskGraphField::CommitTask.as_str(), + commit_task, + )?; + } + } + SerializableTaskGraphField::CleanupTask => { + if let Some(cleanup_task) = &self.cleanup_task { + map.serialize_entry( + SerializableTaskGraphField::CleanupTask.as_str(), + cleanup_task, + )?; + } + } } } @@ -650,11 +735,13 @@ static TASK_GRAPH_SCHEMA_COMPATIBLE_VERSION_REQUIREMENT: std::sync::LazyLock "schema_version", Self::Tasks => "tasks", + Self::CommitTask => "commit_task", + Self::CleanupTask => "cleanup_task", } } } @@ -685,6 +774,8 @@ impl<'deserializer_lifetime> Visitor<'deserializer_lifetime> for TaskGraphVisito ) -> Result { let mut schema_version_raw: Option = None; let mut tasks_result: Option, _>> = None; + let mut commit_task: Option = None; + let mut cleanup_task: Option = None; while let Some(key) = map.next_key::()? 
{ match key { @@ -706,6 +797,22 @@ impl<'deserializer_lifetime> Visitor<'deserializer_lifetime> for TaskGraphVisito // but defer the dispatching. tasks_result = Some(map.next_value()); } + SerializableTaskGraphField::CommitTask => { + if commit_task.is_some() { + return Err(de::Error::duplicate_field( + SerializableTaskGraphField::CommitTask.as_str(), + )); + } + commit_task = Some(map.next_value()?); + } + SerializableTaskGraphField::CleanupTask => { + if cleanup_task.is_some() { + return Err(de::Error::duplicate_field( + SerializableTaskGraphField::CleanupTask.as_str(), + )); + } + cleanup_task = Some(map.next_value()?); + } } } @@ -743,6 +850,9 @@ impl<'deserializer_lifetime> Visitor<'deserializer_lifetime> for TaskGraphVisito } } + graph.commit_task = commit_task; + graph.cleanup_task = cleanup_task; + Ok(graph) } } @@ -912,6 +1022,7 @@ mod tests { inputs: vec![int32_type.clone(), float64_type.clone()], outputs: vec![int64_type.clone(), bool_type.clone()], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_0 insertion should succeed"); @@ -924,6 +1035,7 @@ mod tests { inputs: vec![bytes_type.clone()], outputs: vec![list_int32_type.clone(), bytes_type.clone()], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_1 insertion should succeed"); @@ -939,6 +1051,7 @@ mod tests { task_idx: 0, position: 0, }]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_2 insertion should succeed"); @@ -960,6 +1073,7 @@ mod tests { position: 1, }, ]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_3 insertion should succeed"); @@ -981,6 +1095,7 @@ mod tests { position: 0, }, ]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_4 insertion should succeed"); @@ -996,6 +1111,7 @@ mod tests { task_idx: 3, position: 0, }]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_5 insertion should succeed"); @@ -1030,6 +1146,7 @@ mod tests { position: 1, }, 
]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_6 insertion should succeed"); @@ -1045,6 +1162,7 @@ mod tests { task_idx: 5, position: 1, }]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_7 insertion should succeed"); @@ -1066,6 +1184,7 @@ mod tests { position: 0, }, ]), + execution_policy: ExecutionPolicy::default(), }) .expect("task_8 insertion should succeed"); @@ -1078,6 +1197,7 @@ mod tests { inputs: vec![], outputs: vec![], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_9 insertion should succeed"); @@ -1867,6 +1987,7 @@ mod tests { inputs: vec![int32_type.clone()], outputs: vec![float64_type.clone(), bool_type.clone()], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_0 insertion should succeed"); @@ -1888,6 +2009,7 @@ mod tests { position: 1, }, ]), + execution_policy: ExecutionPolicy::default(), })); // Attempt to create task_1 with 1 input but 0 input sources (mismatched count) @@ -1897,6 +2019,7 @@ mod tests { inputs: vec![float64_type], outputs: vec![int32_type.clone()], input_sources: Some(vec![]), + execution_policy: ExecutionPolicy::default(), })); // Attempt to create task_1 with 0 input but 1 input sources (mismatched count) @@ -1909,6 +2032,7 @@ mod tests { task_idx: 0, position: 0, }]), + execution_policy: ExecutionPolicy::default(), })); // Verify graph state is unchanged @@ -1936,6 +2060,7 @@ mod tests { inputs: vec![], outputs: vec![int32_type.clone()], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_0 insertion should succeed"); @@ -1948,6 +2073,7 @@ mod tests { inputs: vec![], outputs: vec![int32_type], input_sources: Some(vec![]), + execution_policy: ExecutionPolicy::default(), })); } @@ -1969,6 +2095,7 @@ mod tests { inputs: vec![int32_type.clone()], outputs: vec![float64_type, bool_type], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_0 insertion 
should succeed"); @@ -1984,6 +2111,7 @@ mod tests { task_idx: 0, position: 0, }]), + execution_policy: ExecutionPolicy::default(), })); } @@ -2003,6 +2131,7 @@ mod tests { inputs: vec![int32_type.clone()], outputs: vec![float64_type.clone()], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_0 insertion should succeed"); @@ -2018,6 +2147,7 @@ mod tests { task_idx: 5, position: 0, }]), + execution_policy: ExecutionPolicy::default(), })); } @@ -2038,6 +2168,7 @@ mod tests { inputs: vec![int32_type.clone()], outputs: vec![float64_type.clone(), bool_type], input_sources: None, + execution_policy: ExecutionPolicy::default(), }) .expect("task_0 insertion should succeed"); @@ -2053,6 +2184,7 @@ mod tests { task_idx: 0, position: 2, }]), + execution_policy: ExecutionPolicy::default(), })); } diff --git a/components/spider-core/src/types/io.rs b/components/spider-core/src/types/io.rs index 4df423f9..26e860f0 100644 --- a/components/spider-core/src/types/io.rs +++ b/components/spider-core/src/types/io.rs @@ -7,7 +7,7 @@ pub struct Value {} pub struct Data {} /// Represents an input of a task. 
-#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub enum TaskInput { ValuePayload(Vec), } diff --git a/components/spider-core/tests/test_task_graph_serde.rs b/components/spider-core/tests/test_task_graph_serde.rs index 1d47c04c..fe8087f4 100644 --- a/components/spider-core/tests/test_task_graph_serde.rs +++ b/components/spider-core/tests/test_task_graph_serde.rs @@ -141,7 +141,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ } } ], - "input_sources": null + "input_sources": null, + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -169,7 +173,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ } } ], - "input_sources": null + "input_sources": null, + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -209,7 +217,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 0, "position": 0 } - ] + ], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -253,7 +265,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 0, "position": 1 } - ] + ], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -306,7 +322,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 1, "position": 0 } - ] + ], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -339,7 +359,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 3, "position": 0 } - ] + ], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -396,7 +420,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 4, "position": 1 } - ] + ], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -424,7 +452,11 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 5, "position": 1 } - ] + 
], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", @@ -463,14 +495,22 @@ const TASK_GRAPH_IN_JSON: &str = r#"{ "task_idx": 1, "position": 0 } - ] + ], + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } }, { "tdl_package": "test_pkg", "tdl_function": "fn_10", "inputs": [], "outputs": [], - "input_sources": null + "input_sources": null, + "execution_policy": { + "max_num_instances": 1, + "max_num_retries": 0 + } } ] }"#; diff --git a/components/spider-storage/Cargo.toml b/components/spider-storage/Cargo.toml index 54e29513..ded28e3c 100644 --- a/components/spider-storage/Cargo.toml +++ b/components/spider-storage/Cargo.toml @@ -14,3 +14,6 @@ sqlx = "0.8.6" spider-core = { path = "../spider-core" } thiserror = "2.0.18" tokio = { version = "1.49.0", features = ["rt-multi-thread", "sync"] } + +[dev-dependencies] +tokio = { version = "1.49.0", features = ["rt-multi-thread", "sync", "macros"] } diff --git a/components/spider-storage/src/cache.rs b/components/spider-storage/src/cache.rs index 5a2872f5..f9928cc6 100644 --- a/components/spider-storage/src/cache.rs +++ b/components/spider-storage/src/cache.rs @@ -1,4 +1,47 @@ +// TODO(spider-storage): Address these clippy lints when stabilizing the cache layer. 
+#[allow( + clippy::future_not_send, + clippy::significant_drop_tightening, + clippy::option_if_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc +)] pub mod error; +#[allow( + clippy::future_not_send, + clippy::significant_drop_tightening, + clippy::option_if_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc +)] +mod factory; +#[allow( + clippy::future_not_send, + clippy::significant_drop_tightening, + clippy::option_if_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc +)] mod job; +#[allow( + clippy::future_not_send, + clippy::significant_drop_tightening, + clippy::option_if_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc +)] mod task; +#[allow( + clippy::future_not_send, + clippy::significant_drop_tightening, + clippy::option_if_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc +)] mod types; + +pub use factory::*; +pub use job::{JobControlBlock, ReadyQueueConnector, TaskInstancePoolConnector}; + +#[cfg(test)] +mod tests; diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index 4043c273..65e71041 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -5,15 +5,17 @@ use spider_core::{ }; /// Enums for all possible errors that can happen in the cache. +#[derive(Debug)] pub enum CacheError { Internal(InternalError), Rejection(RejectionError), DbError(crate::db::DbError), } -/// Enums for all internal errors. When these error happens, it is considered that the system is in -/// an inconsistent state and cannot continue to service requests. A restart is needed to recover -/// the cache from the storage. +/// Enums for all internal errors. +/// +/// When these errors happen, it is considered that the system is in an inconsistent state and +/// cannot continue to service requests. A restart is needed to recover the cache from the storage.
#[derive(thiserror::Error, Debug)] pub enum InternalError { #[error("task output already written by a previous successful task instance")] @@ -61,13 +63,15 @@ pub enum InternalError { impl From for CacheError { fn from(e: InternalError) -> Self { - CacheError::Internal(e) + Self::Internal(e) } } -/// Enums for all rejection errors. When these error happens, it is considered that the request is -/// valid, but cannot be processed due to the current state of the cache. These errors should be -/// forwarded to the client for notification. +/// Enums for all rejection errors. +/// +/// When these errors happen, it is considered that the request is valid, but cannot be processed +/// due to the current state of the cache. These errors should be forwarded to the client for +/// notification. #[derive(thiserror::Error, Debug)] pub enum RejectionError { #[error("task instance ID is not registered")] @@ -97,12 +101,12 @@ pub enum RejectionError { impl From for CacheError { fn from(e: RejectionError) -> Self { - CacheError::Rejection(e) + Self::Rejection(e) } } impl From for CacheError { fn from(e: crate::db::DbError) -> Self { - CacheError::DbError(e) + Self::DbError(e) } } diff --git a/components/spider-storage/src/cache/factory.rs b/components/spider-storage/src/cache/factory.rs new file mode 100644 index 00000000..c3900dd1 --- /dev/null +++ b/components/spider-storage/src/cache/factory.rs @@ -0,0 +1,246 @@ +use std::{collections::HashSet, sync::Arc}; + +use spider_core::{ + task::{self as core_task, TaskIndex, TaskState}, + types::{ + id::{JobId, ResourceGroupId}, + io::TaskInput, + }, +}; +use tokio::sync::RwLock; + +use crate::{ + cache::{ + error::{CacheError, InternalError}, + job::{Job, JobControlBlock, ReadyQueueConnector, RwJob, TaskInstancePoolConnector}, + task::{ + BaseTaskControlBlock, + InputReader, + OutputReader, + RetryCounter, + SharedTaskControlBlock, + SharedTerminationTaskControlBlock, + TaskControlBlock, + TaskGraph, + TerminationTaskControlBlock, +
ValuePayload, + }, + types::{Reader, Shared, TdlContext, Writer}, + }, + db::InternalJobOrchestration, +}; + +/// The result type of [`build_job`]. +type BuildJobResult = + Result<(JobControlBlock, Vec), CacheError>; + +/// Builds a [`JobControlBlock`] from a user-facing [`core_task::TaskGraph`] and job inputs. +/// +/// Returns the job control block and a list of initially ready task indices (input tasks). +/// +/// # Errors +/// +/// Returns [`CacheError`] if the task graph is corrupted or the job inputs do not match the +/// expected graph inputs. +/// +/// # Panics +/// +/// Panics if a task control block mutex cannot be acquired during construction (indicates a bug). +pub fn build_job< + ReadyQueueConnectorType: ReadyQueueConnector, + DbConnectorType: InternalJobOrchestration, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +>( + job_id: JobId, + owner_id: ResourceGroupId, + core_graph: &core_task::TaskGraph, + job_inputs: Vec, + ready_queue: ReadyQueueConnectorType, + db: DbConnectorType, + pool: TaskInstancePoolConnectorType, +) -> BuildJobResult { + let data_buffers = create_data_buffers(core_graph, job_inputs)?; + let cache_tcbs = build_task_control_blocks(core_graph, &data_buffers); + populate_children(core_graph, &cache_tcbs); + let output_readers = collect_job_outputs(core_graph, &data_buffers); + + let commit_task = core_graph + .get_commit_task_descriptor() + .map(build_termination_tcb); + let cleanup_task = core_graph + .get_cleanup_task_descriptor() + .map(build_termination_tcb); + + let cache_task_graph = TaskGraph { + tasks: cache_tcbs, + outputs: output_readers, + commit_task, + cleanup_task, + }; + + let num_tasks = core_graph.get_num_tasks(); + let job = Job::new( + spider_core::job::JobState::Running, + cache_task_graph, + num_tasks, + ); + let rw_job = RwJob::new(job); + let jcb = JobControlBlock::new(job_id, owner_id, rw_job, ready_queue, db, pool); + + let ready_indices: Vec = core_graph + .get_tasks() + .iter() + .filter(|t| 
t.is_input_task()) + .map(spider_core::task::Task::get_index) + .collect(); + + Ok((jcb, ready_indices)) +} + +/// Creates shared data buffers for all dataflow dependencies and pre-populates job inputs. +fn create_data_buffers( + core_graph: &core_task::TaskGraph, + job_inputs: Vec, +) -> Result>, CacheError> { + let num_deps = core_graph.get_num_dataflow_deps(); + let mut data_buffers: Vec> = + (0..num_deps).map(|_| Arc::new(RwLock::new(None))).collect(); + + let mut graph_input_dep_indices: Vec = Vec::new(); + for dep_idx in 0..num_deps { + let dep = core_graph.get_dataflow_dep(dep_idx).ok_or_else(|| { + InternalError::TaskGraphCorrupted("dataflow dep index out of bounds".to_owned()) + })?; + if dep.get_src().is_none() { + graph_input_dep_indices.push(dep_idx); + } + } + + if graph_input_dep_indices.len() != job_inputs.len() { + return Err(InternalError::TaskGraphCorrupted(format!( + "expected {} graph inputs, got {} job inputs", + graph_input_dep_indices.len(), + job_inputs.len() + )) + .into()); + } + + for (dep_idx, job_input) in graph_input_dep_indices.iter().zip(job_inputs.into_iter()) { + let TaskInput::ValuePayload(payload) = job_input; + data_buffers[*dep_idx] = Arc::new(RwLock::new(Some(payload))); + } + + Ok(data_buffers) +} + +/// Builds `SharedTaskControlBlock`s for each task (without children populated). 
+fn build_task_control_blocks( + core_graph: &core_task::TaskGraph, + data_buffers: &[Shared], +) -> Vec { + let core_tasks = core_graph.get_tasks(); + let mut cache_tcbs: Vec = Vec::with_capacity(core_tasks.len()); + + for core_task in core_tasks { + let inputs: Vec = core_task + .get_input_dep_indices() + .iter() + .map(|&dep_idx| InputReader::Value(Reader::new(data_buffers[dep_idx].clone()))) + .collect(); + + let outputs: Vec<_> = core_task + .get_output_dep_indices() + .iter() + .map(|&dep_idx| Writer::new(data_buffers[dep_idx].clone())) + .collect(); + + let num_parents = core_task.get_num_parents(); + let state = if num_parents == 0 { + TaskState::Ready + } else { + TaskState::Pending + }; + + let execution_policy = core_task.get_execution_policy(); + + let tcb = TaskControlBlock { + base: BaseTaskControlBlock { + state, + tdl_context: TdlContext { + package: core_task.get_tdl_package().to_owned(), + func: core_task.get_tdl_function().to_owned(), + }, + instance_ids: HashSet::new(), + max_num_instances: execution_policy.max_num_instances, + retry_counter: RetryCounter::new(execution_policy.max_num_retries), + }, + index: core_task.get_index(), + num_parents, + num_unfinished_parents: num_parents, + inputs, + outputs, + children: Vec::new(), + }; + + cache_tcbs.push(SharedTaskControlBlock::new(tcb)); + } + + cache_tcbs +} + +/// Populates child references for each task control block (second pass). +/// +/// # Panics +/// +/// Panics if a mutex cannot be acquired (should be impossible during single-threaded construction). 
+fn populate_children(core_graph: &core_task::TaskGraph, cache_tcbs: &[SharedTaskControlBlock]) { + for core_task in core_graph.get_tasks() { + let children: Vec = core_task + .get_child_indices() + .iter() + .map(|&child_idx| cache_tcbs[child_idx].clone()) + .collect(); + + if !children.is_empty() { + let mut tcb_guard = cache_tcbs[core_task.get_index()] + .try_lock_for_construction() + .expect("lock should not be contended during construction"); + tcb_guard.children = children; + } + } +} + +/// Collects job-level outputs (dangling dataflow outputs not consumed by any task). +fn collect_job_outputs( + core_graph: &core_task::TaskGraph, + data_buffers: &[Shared], +) -> Vec { + let mut output_readers: Vec = Vec::new(); + for (dep_idx, buffer) in data_buffers.iter().enumerate() { + if let Some(dep) = core_graph.get_dataflow_dep(dep_idx) + && dep.get_src().is_some() + && dep.get_dst().is_empty() + { + output_readers.push(Reader::new(buffer.clone())); + } + } + output_readers +} + +fn build_termination_tcb( + desc: &core_task::TerminationTaskDescriptor, +) -> SharedTerminationTaskControlBlock { + let tcb = TerminationTaskControlBlock { + base: BaseTaskControlBlock { + state: TaskState::Ready, + tdl_context: TdlContext { + package: desc.tdl_package.clone(), + func: desc.tdl_function.clone(), + }, + instance_ids: HashSet::new(), + max_num_instances: desc.execution_policy.max_num_instances, + retry_counter: RetryCounter::new(desc.execution_policy.max_num_retries), + }, + }; + SharedTerminationTaskControlBlock::new(tcb) +} diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index 1258808a..84cafdff 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -14,7 +14,6 @@ use crate::{ cache::{ error::{ CacheError, - CacheError::Internal, InternalError, RejectionError, RejectionError::{JobNoLongerCleanupReady, JobNoLongerCommitReady}, @@ -25,6 +24,7 @@ use crate::{ 
db::InternalJobOrchestration, }; +#[allow(dead_code)] pub struct JobControlBlock< ReadyQueueSenderType: ReadyQueueConnector, DbConnectorType: InternalJobOrchestration, @@ -44,6 +44,24 @@ impl< TaskInstancePoolConnectorType: TaskInstancePoolConnector, > JobControlBlock { + pub(super) const fn new( + id: JobId, + owner_id: ResourceGroupId, + job: RwJob, + ready_queue_connector: ReadyQueueSenderType, + db_connector: DbConnectorType, + task_instance_pool_connector: TaskInstancePoolConnectorType, + ) -> Self { + Self { + id, + owner_id, + job, + ready_queue_connector, + db_connector, + task_instance_pool_connector, + } + } + pub async fn create_task_instance( &self, task_id: TaskId, @@ -83,25 +101,24 @@ impl< ExecutionContext { task_instance_id, tdl_context, - // TODO: Question, what's the input for the commit task? inputs: None, } } TaskId::Cleanup => { let job = self.job.read_if_cleanup_ready().await?; - let commit_tcb = job + let cleanup_tcb = job .task_graph - .get_commit_task() - .ok_or(InternalError::JobNoCommit)?; + .get_cleanup_task() + .ok_or(InternalError::JobNoCleanup)?; let task_instance_id = self .task_instance_pool_connector .get_next_available_task_instance_id(); - let tdl_context = commit_tcb + let tdl_context = cleanup_tcb .register_termination_task_instance(task_instance_id) .await?; self.task_instance_pool_connector - .register_termination_task_instance(task_instance_id, commit_tcb) + .register_termination_task_instance(task_instance_id, cleanup_tcb) .await?; ExecutionContext { task_instance_id, @@ -128,8 +145,10 @@ impl< let ready_task_ids = tcb .complete_task_instance(task_instance_id, task_outputs) .await?; - let num_incompleted_task = job.num_incompleted_tasks.fetch_sub(1, Ordering::Relaxed); + let num_incompleted_task = job.num_incompleted_tasks.fetch_sub(1, Ordering::Relaxed) - 1; + // NOTE: `fetch_sub` returns the previous value, so 1 is subtracted to make + // `num_incompleted_task` the count *after* decrementing.
if !ready_task_ids.is_empty() { if num_incompleted_task == 0 { return Err(InternalError::TaskGraphCorrupted( @@ -253,7 +272,7 @@ impl< return Ok(job.state); } } - }; + } let mut job = self.job.write_if_non_terminated().await.map_err(|e| { match &e { @@ -288,6 +307,7 @@ impl< Ok(job.state) } + #[allow(clippy::unused_async, dead_code)] async fn cancel(&self) -> Result { todo!( "Implement this. The job table must be locked for write, and the state of all tasks \ @@ -297,23 +317,39 @@ impl< } } -struct Job { - state: JobState, - task_graph: TaskGraph, - num_incompleted_tasks: AtomicUsize, +pub(super) struct Job { + pub(super) state: JobState, + pub(super) task_graph: TaskGraph, + pub(super) num_incompleted_tasks: AtomicUsize, } -struct RwJob { +impl Job { + pub(super) const fn new(state: JobState, task_graph: TaskGraph, num_tasks: usize) -> Self { + Self { + state, + task_graph, + num_incompleted_tasks: AtomicUsize::new(num_tasks), + } + } +} + +pub(super) struct RwJob { inner: RwLock, } impl RwJob { + pub(super) fn new(job: Job) -> Self { + Self { + inner: RwLock::new(job), + } + } + async fn read_checked( &self, check: fn(&Job) -> Result<(), CacheError>, ) -> Result, CacheError> { let guard = self.inner.read().await; - check(&*guard)?; + check(&guard)?; Ok(guard) } @@ -322,7 +358,7 @@ impl RwJob { check: fn(&Job) -> Result<(), CacheError>, ) -> Result, CacheError> { let guard = self.inner.write().await; - check(&*guard)?; + check(&guard)?; Ok(guard) } @@ -358,6 +394,7 @@ impl RwJob { self.inner.read().await.task_graph.has_commit_task() } + #[allow(dead_code)] pub async fn has_cleanup_task(&self) -> bool { self.inner.read().await.task_graph.has_cleanup_task() } diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs index f1b7968f..ea320244 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -1,29 +1,23 @@ -use std::{ - collections::{HashMap, HashSet}, - 
future::Ready, - sync::{Arc, atomic::AtomicUsize}, -}; +use std::{collections::HashSet, sync::Arc}; -use serde::Serialize; use spider_core::{ - job::JobState, - task::{DataflowDependencyIndex, Task, TaskIndex, TaskState}, + task::{TaskIndex, TaskState}, types::{ - id::{JobId, TaskInstanceId}, + id::TaskInstanceId, io::{TaskInput, TaskOutput}, }, }; use crate::cache::{ - error::{CacheError, CacheError::Internal, InternalError, RejectionError}, + error::{CacheError, InternalError, RejectionError}, types::{ExecutionContext, Reader, TdlContext, Writer}, }; pub struct TaskGraph { - tasks: Vec, - outputs: Vec, - commit_task: Option, - cleanup_task: Option, + pub(super) tasks: Vec, + pub(super) outputs: Vec, + pub(super) commit_task: Option, + pub(super) cleanup_task: Option, } impl TaskGraph { @@ -38,17 +32,18 @@ impl TaskGraph { if let Some(output) = &*output_guard { outputs.push(output.clone()); } else { - return Err(RejectionError::TaskOutputNotReady.into()); + return Err(RejectionError::TaskOutputNotReady); } } Ok(outputs) } - pub fn has_commit_task(&self) -> bool { + pub const fn has_commit_task(&self) -> bool { self.commit_task.is_some() } - pub fn has_cleanup_task(&self) -> bool { + #[allow(dead_code)] + pub const fn has_cleanup_task(&self) -> bool { self.cleanup_task.is_some() } @@ -67,6 +62,20 @@ pub struct SharedTaskControlBlock { } impl SharedTaskControlBlock { + pub(super) fn new(inner: TaskControlBlock) -> Self { + Self { + inner: Arc::new(tokio::sync::Mutex::new(inner)), + } + } + + /// Attempts to lock the inner mutex without blocking. Only intended for use during + /// construction when no contention is possible. 
+ pub(super) fn try_lock_for_construction( + &self, + ) -> Result, ()> { + self.inner.try_lock().map_err(|_| ()) + } + pub async fn register_task_instance( &self, task_instance_id: TaskInstanceId, @@ -177,6 +186,12 @@ pub struct SharedTerminationTaskControlBlock { } impl SharedTerminationTaskControlBlock { + pub(super) fn new(inner: TerminationTaskControlBlock) -> Self { + Self { + inner: Arc::new(tokio::sync::Mutex::new(inner)), + } + } + pub async fn register_termination_task_instance( &self, task_instance_id: TaskInstanceId, @@ -211,12 +226,12 @@ impl SharedTerminationTaskControlBlock { } } -struct BaseTaskControlBlock { - state: TaskState, - tdl_context: TdlContext, - instance_ids: HashSet, - max_num_instances: usize, - retry_counter: RetryCounter, +pub(super) struct BaseTaskControlBlock { + pub(super) state: TaskState, + pub(super) tdl_context: TdlContext, + pub(super) instance_ids: HashSet, + pub(super) max_num_instances: usize, + pub(super) retry_counter: RetryCounter, } impl BaseTaskControlBlock { @@ -258,14 +273,14 @@ impl BaseTaskControlBlock { error_message: String, ) -> Result { if !self.instance_ids.remove(&task_instance_id) { - return Err(RejectionError::InvalidTaskInstanceId.into()); + return Err(RejectionError::InvalidTaskInstanceId); } if self.state.is_terminal() { - return Err(RejectionError::TaskAlreadyTerminated(self.state.clone()).into()); + return Err(RejectionError::TaskAlreadyTerminated(self.state.clone())); } if self.retry_counter.retry() == 0 { - self.state = if self.instance_ids.len() == 0 { + self.state = if self.instance_ids.is_empty() { TaskState::Running } else { TaskState::Ready @@ -285,14 +300,14 @@ impl BaseTaskControlBlock { } } -struct TaskControlBlock { - base: BaseTaskControlBlock, - index: TaskIndex, - num_parents: usize, - num_unfinished_parents: usize, - inputs: Vec, - outputs: Vec, - children: Vec, +pub(super) struct TaskControlBlock { + pub(super) base: BaseTaskControlBlock, + pub(super) index: TaskIndex, + pub(super) 
num_parents: usize, + pub(super) num_unfinished_parents: usize, + pub(super) inputs: Vec, + pub(super) outputs: Vec, + pub(super) children: Vec, } impl TaskControlBlock { @@ -329,16 +344,17 @@ impl TaskControlBlock { } } -struct TerminationTaskControlBlock { - base: BaseTaskControlBlock, +pub(super) struct TerminationTaskControlBlock { + pub(super) base: BaseTaskControlBlock, } -type ValuePayload = Option>; +pub(super) type ValuePayload = Option>; #[derive(Clone)] -struct Channel {} +pub(super) struct Channel {} -enum InputReader { +#[allow(dead_code)] +pub(super) enum InputReader { Value(Reader), Channel(Channel), } @@ -346,7 +362,7 @@ enum InputReader { impl InputReader { async fn read_as_task_input(&self) -> Result { match self { - InputReader::Value(value_payload) => { + Self::Value(value_payload) => { let value_guard = value_payload.read().await; if let Some(value) = &*value_guard { Ok(TaskInput::ValuePayload(value.clone())) @@ -354,29 +370,29 @@ impl InputReader { Err(InternalError::TaskInputNotReady.into()) } } - InputReader::Channel(_) => unimplemented!("channel input is not supported yet"), + Self::Channel(_) => unimplemented!("channel input is not supported yet"), } } } -type OutputReader = Reader; +pub(super) type OutputReader = Reader; -type OutputWriter = Writer; +pub(super) type OutputWriter = Writer; -struct RetryCounter { +pub(super) struct RetryCounter { max_num_retries_allowed: usize, retry_count: usize, } impl RetryCounter { - fn new(max_num_retries_allowed: usize) -> Self { + pub(super) const fn new(max_num_retries_allowed: usize) -> Self { Self { max_num_retries_allowed, retry_count: max_num_retries_allowed, } } - fn retry(&mut self) -> usize { + const fn retry(&mut self) -> usize { if self.retry_count == 0 { // In practice, this is possible if the total number of task instances creates are // greater than the number of retries allowed. 
@@ -387,7 +403,7 @@ impl RetryCounter { num_retries_left } - fn reset(&mut self) { + const fn reset(&mut self) { self.retry_count = self.max_num_retries_allowed; } } diff --git a/components/spider-storage/src/cache/tests.rs b/components/spider-storage/src/cache/tests.rs new file mode 100644 index 00000000..90891135 --- /dev/null +++ b/components/spider-storage/src/cache/tests.rs @@ -0,0 +1,517 @@ +use std::sync::{ + Arc, + atomic::{AtomicU64, Ordering}, +}; + +use async_trait::async_trait; +use spider_core::{ + job::JobState, + task::{ + BytesTypeDescriptor, + DataTypeDescriptor, + ExecutionPolicy, + TaskDescriptor, + TaskGraph as CoreTaskGraph, + TaskIndex, + TerminationTaskDescriptor, + ValueTypeDescriptor, + }, + types::{ + id::{JobId, ResourceGroupId, TaskInstanceId}, + io::{TaskInput, TaskOutput}, + }, +}; +use tokio::sync::Mutex; + +use crate::{ + cache::{ + build_job, + error::InternalError, + job::{ReadyQueueConnector, TaskInstancePoolConnector}, + task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, + types::TaskId, + }, + db::{DbError, InternalJobOrchestration}, +}; + +// --- Mock implementations --- + +type ReadyTaskList = Arc)>>>; + +struct MockReadyQueue { + ready_tasks: ReadyTaskList, + commit_ready_count: Arc, + cleanup_ready_count: Arc, +} + +impl MockReadyQueue { + fn new() -> Self { + Self { + ready_tasks: Arc::new(Mutex::new(Vec::new())), + commit_ready_count: Arc::new(AtomicU64::new(0)), + cleanup_ready_count: Arc::new(AtomicU64::new(0)), + } + } +} + +#[async_trait] +impl ReadyQueueConnector for MockReadyQueue { + async fn send_task_ready( + &self, + job_id: JobId, + task_ids: Vec, + ) -> Result<(), InternalError> { + self.ready_tasks.lock().await.push((job_id, task_ids)); + Ok(()) + } + + async fn send_commit_ready(&self, _job_id: JobId) -> Result<(), InternalError> { + self.commit_ready_count.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + + async fn send_cleanup_ready(&self, _job_id: JobId) -> Result<(), InternalError> { + 
self.cleanup_ready_count.fetch_add(1, Ordering::Relaxed); + Ok(()) + } +} + +struct MockDb { + has_commit_task: bool, +} + +impl MockDb { + fn new(has_commit_task: bool) -> Self { + Self { has_commit_task } + } +} + +#[async_trait] +impl InternalJobOrchestration for MockDb { + async fn set_state(&self, _job_id: JobId, _state: JobState) -> Result<(), DbError> { + Ok(()) + } + + async fn commit_outputs( + &self, + _job_id: JobId, + _job_outputs: Vec, + ) -> Result { + if self.has_commit_task { + Ok(JobState::CommitReady) + } else { + Ok(JobState::Succeeded) + } + } + + async fn cancel(&self, _job_id: JobId) -> Result { + Ok(JobState::Cancelled) + } + + async fn fail(&self, _job_id: JobId, _error_message: String) -> Result<(), DbError> { + Ok(()) + } + + async fn delete_expired_terminated_jobs( + &self, + _expire_after: std::time::Duration, + ) -> Result, DbError> { + Ok(Vec::new()) + } +} + +struct MockInstancePool { + next_id: AtomicU64, +} + +impl MockInstancePool { + fn new() -> Self { + Self { + next_id: AtomicU64::new(1), + } + } +} + +#[async_trait] +impl TaskInstancePoolConnector for MockInstancePool { + fn get_next_available_task_instance_id(&self) -> TaskInstanceId { + self.next_id.fetch_add(1, Ordering::Relaxed) + } + + async fn register_task_instance( + &self, + _task_instance_id: TaskInstanceId, + _task: SharedTaskControlBlock, + ) -> Result<(), InternalError> { + Ok(()) + } + + async fn register_termination_task_instance( + &self, + _task_instance_id: TaskInstanceId, + _termination_task: SharedTerminationTaskControlBlock, + ) -> Result<(), InternalError> { + Ok(()) + } +} + +// --- Helper: simple byte type descriptor --- + +fn bytes_type() -> DataTypeDescriptor { + DataTypeDescriptor::Value(ValueTypeDescriptor::Bytes(BytesTypeDescriptor {})) +} + +// --- Tests --- + +/// Tests the factory and end-to-end execution with a simple linear chain: A -> B -> C. 
+/// +/// Graph topology: +/// ```text +/// [job_input] -> A -> B -> C -> [job_output] +/// ``` +/// +/// Verifies: +/// - `build_job` correctly identifies only A as initially ready (B and C depend on predecessors). +/// - Job inputs are pre-populated: A receives the original `b"hello"` bytes. +/// - Dataflow wiring works across the chain: A's output `b"world"` is delivered as B's input, and +/// B's output `b"done"` is delivered as C's input. +/// - The job remains in `Running` state while tasks are still incomplete. +/// - The job transitions to `Succeeded` once the final task (C) completes, since there is no commit +/// task configured. +#[tokio::test] +async fn test_factory_linear_chain() { + let mut graph = CoreTaskGraph::default(); + let task_a = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_a".into(), + inputs: vec![bytes_type()], + outputs: vec![bytes_type()], + input_sources: None, + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + let task_b = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_b".into(), + inputs: vec![bytes_type()], + outputs: vec![bytes_type()], + input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + task_idx: task_a, + position: 0, + }]), + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + let task_c = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_c".into(), + inputs: vec![bytes_type()], + outputs: vec![bytes_type()], + input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + task_idx: task_b, + position: 0, + }]), + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + + let job_inputs = vec![TaskInput::ValuePayload(b"hello".to_vec())]; + + let (jcb, ready_indices) = build_job( + JobId::new(), + ResourceGroupId::new(), + &graph, + job_inputs, + MockReadyQueue::new(), + MockDb::new(false), + MockInstancePool::new(), + ) + .unwrap(); + + // Only task 
A should be ready initially. + assert_eq!(ready_indices, vec![task_a]); + + // Execute task A. + let ctx_a = jcb + .create_task_instance(TaskId::TaskIndex(task_a)) + .await + .unwrap(); + assert!(ctx_a.inputs.is_some()); + let inputs_a = ctx_a.inputs.unwrap(); + assert_eq!(inputs_a.len(), 1); + assert_eq!( + inputs_a[0], + TaskInput::ValuePayload(b"hello".to_vec()), + "task A should receive the job input" + ); + + let state = jcb + .complete_task_instance(ctx_a.task_instance_id, task_a, vec![b"world".to_vec()]) + .await + .unwrap(); + assert_eq!(state, JobState::Running); + + // Execute task B. + let ctx_b = jcb + .create_task_instance(TaskId::TaskIndex(task_b)) + .await + .unwrap(); + let inputs_b = ctx_b.inputs.unwrap(); + assert_eq!( + inputs_b[0], + TaskInput::ValuePayload(b"world".to_vec()), + "task B should receive task A's output" + ); + + let state = jcb + .complete_task_instance(ctx_b.task_instance_id, task_b, vec![b"done".to_vec()]) + .await + .unwrap(); + assert_eq!(state, JobState::Running); + + // Execute task C — the last task. + let ctx_c = jcb + .create_task_instance(TaskId::TaskIndex(task_c)) + .await + .unwrap(); + let inputs_c = ctx_c.inputs.unwrap(); + assert_eq!(inputs_c[0], TaskInput::ValuePayload(b"done".to_vec())); + + let state = jcb + .complete_task_instance(ctx_c.task_instance_id, task_c, vec![b"final".to_vec()]) + .await + .unwrap(); + assert_eq!( + state, + JobState::Succeeded, + "job should succeed after all tasks complete" + ); +} + +/// Tests the factory and end-to-end execution with a diamond DAG that exercises fan-out and +/// fan-in. +/// +/// Graph topology: +/// ```text +/// ┌─> B ─┐ +/// [job_input] -> A -> D -> [job_output] +/// └─> C ─┘ +/// ``` +/// +/// A has two outputs; B consumes output 0, C consumes output 1. D consumes both B's and C's +/// outputs (fan-in from two parents). +/// +/// Verifies: +/// - Only A is initially ready. 
+/// - Completing A unblocks both B and C simultaneously — both appear in a single `send_task_ready` +/// call to the ready queue. +/// - D is not unblocked until *both* B and C have completed (fan-in gate). +/// - The job transitions to `Succeeded` once the sink task (D) completes. +#[tokio::test] +#[allow(clippy::too_many_lines)] +async fn test_factory_diamond_dag() { + let mut graph = CoreTaskGraph::default(); + let task_a = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_a".into(), + inputs: vec![bytes_type()], + outputs: vec![bytes_type(), bytes_type()], + input_sources: None, + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + let task_b = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_b".into(), + inputs: vec![bytes_type()], + outputs: vec![bytes_type()], + input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + task_idx: task_a, + position: 0, + }]), + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + let task_c = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_c".into(), + inputs: vec![bytes_type()], + outputs: vec![bytes_type()], + input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + task_idx: task_a, + position: 1, + }]), + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + let task_d = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_d".into(), + inputs: vec![bytes_type(), bytes_type()], + outputs: vec![bytes_type()], + input_sources: Some(vec![ + spider_core::task::TaskInputOutputIndex { + task_idx: task_b, + position: 0, + }, + spider_core::task::TaskInputOutputIndex { + task_idx: task_c, + position: 0, + }, + ]), + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + + let job_inputs = vec![TaskInput::ValuePayload(b"input".to_vec())]; + let ready_queue = MockReadyQueue::new(); + let ready_tasks_ref = 
ready_queue.ready_tasks.clone(); + + let (jcb, ready_indices) = build_job( + JobId::new(), + ResourceGroupId::new(), + &graph, + job_inputs, + ready_queue, + MockDb::new(false), + MockInstancePool::new(), + ) + .unwrap(); + + assert_eq!(ready_indices, vec![task_a]); + + // Complete A. + let ctx_a = jcb + .create_task_instance(TaskId::TaskIndex(task_a)) + .await + .unwrap(); + let state = jcb + .complete_task_instance( + ctx_a.task_instance_id, + task_a, + vec![b"out_b".to_vec(), b"out_c".to_vec()], + ) + .await + .unwrap(); + assert_eq!(state, JobState::Running); + + // Check that B and C were enqueued as ready. + let queued = ready_tasks_ref.lock().await; + assert_eq!(queued.len(), 1); + let (_, ref task_ids) = queued[0]; + assert!(task_ids.contains(&task_b)); + assert!(task_ids.contains(&task_c)); + drop(queued); + + // Complete B and C. + let ctx_b = jcb + .create_task_instance(TaskId::TaskIndex(task_b)) + .await + .unwrap(); + jcb.complete_task_instance(ctx_b.task_instance_id, task_b, vec![b"b_out".to_vec()]) + .await + .unwrap(); + + let ctx_c = jcb + .create_task_instance(TaskId::TaskIndex(task_c)) + .await + .unwrap(); + jcb.complete_task_instance(ctx_c.task_instance_id, task_c, vec![b"c_out".to_vec()]) + .await + .unwrap(); + + // D should now be ready. Complete it. + let ctx_d = jcb + .create_task_instance(TaskId::TaskIndex(task_d)) + .await + .unwrap(); + let state = jcb + .complete_task_instance(ctx_d.task_instance_id, task_d, vec![b"final".to_vec()]) + .await + .unwrap(); + assert_eq!(state, JobState::Succeeded); +} + +/// Tests the commit task lifecycle: job transitions through `CommitReady` before `Succeeded`. +/// +/// Graph topology: +/// ```text +/// A -> [job_output] +/// (commit task: commit_fn) +/// ``` +/// +/// A single task (A) with no inputs and one dangling output. A `TerminationTaskDescriptor` is +/// attached as the commit task. 
The mock DB is configured to return `CommitReady` on +/// `commit_outputs` (simulating a job that has a commit task). +/// +/// Verifies: +/// - After A completes and outputs are committed, the job transitions to `CommitReady` (not +/// directly to `Succeeded`). +/// - The ready queue receives exactly one `send_commit_ready` notification. +/// - The commit task can be registered via `TaskId::Commit` and returns no inputs (`inputs: None`), +/// since termination tasks do not consume dataflow outputs. +/// - After the commit task instance completes, the job transitions to `Succeeded`. +#[tokio::test] +async fn test_factory_with_commit_task() { + let mut graph = CoreTaskGraph::default(); + let task_a = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "fn_a".into(), + inputs: vec![], + outputs: vec![bytes_type()], + input_sources: None, + execution_policy: ExecutionPolicy::default(), + }) + .unwrap(); + + graph.set_commit_task(TerminationTaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: "commit_fn".into(), + execution_policy: ExecutionPolicy::default(), + }); + + let ready_queue = MockReadyQueue::new(); + let commit_count = ready_queue.commit_ready_count.clone(); + + let (jcb, ready_indices) = build_job( + JobId::new(), + ResourceGroupId::new(), + &graph, + vec![], + ready_queue, + MockDb::new(true), + MockInstancePool::new(), + ) + .unwrap(); + + assert_eq!(ready_indices, vec![task_a]); + + // Complete task A. + let ctx_a = jcb + .create_task_instance(TaskId::TaskIndex(task_a)) + .await + .unwrap(); + let state = jcb + .complete_task_instance(ctx_a.task_instance_id, task_a, vec![b"output".to_vec()]) + .await + .unwrap(); + assert_eq!(state, JobState::CommitReady); + assert_eq!(commit_count.load(Ordering::Relaxed), 1); + + // Execute commit task. 
+ let ctx_commit = jcb.create_task_instance(TaskId::Commit).await.unwrap(); + assert!(ctx_commit.inputs.is_none()); + let state = jcb + .complete_commit_task_instance(ctx_commit.task_instance_id) + .await + .unwrap(); + assert_eq!(state, JobState::Succeeded); +} diff --git a/components/spider-storage/src/cache/types.rs b/components/spider-storage/src/cache/types.rs index 8ecae1f1..23b5cb82 100644 --- a/components/spider-storage/src/cache/types.rs +++ b/components/spider-storage/src/cache/types.rs @@ -15,8 +15,8 @@ pub struct Reader { } impl Reader { - pub fn new(inner: Shared) -> Reader { - Reader { inner } + pub const fn new(inner: Shared) -> Self { + Self { inner } } pub async fn read(&self) -> RwLockReadGuard<'_, Type> { @@ -30,8 +30,8 @@ pub struct Writer { } impl Writer { - pub fn new(inner: Shared) -> Writer { - Writer { inner } + pub const fn new(inner: Shared) -> Self { + Self { inner } } pub async fn write(&self) -> RwLockWriteGuard<'_, Type> { @@ -41,8 +41,8 @@ impl Writer { #[derive(Serialize, Clone)] pub struct TdlContext { - package: String, - func: String, + pub(super) package: String, + pub(super) func: String, } #[derive(Serialize)] diff --git a/components/spider-storage/src/protocol.rs b/components/spider-storage/src/protocol.rs index 392cab7f..ef6bcbc8 100644 --- a/components/spider-storage/src/protocol.rs +++ b/components/spider-storage/src/protocol.rs @@ -5,17 +5,7 @@ use spider_core::{ job::JobState, task::{TaskGraph, TaskMetadata}, types::{ - id::{ - DataId, - JobId, - ResourceGroupId, - SchedulerId, - SignedJobId, - SignedTaskId, - TaskId, - TaskInstanceId, - WorkerId, - }, + id::{DataId, JobId, ResourceGroupId, SchedulerId, SignedJobId, TaskId, WorkerId}, io::{Data, TaskInput, TaskOutput}, }, }; From 1947b08b198383b9a6006e63c4615d827034d008 Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Tue, 17 Mar 2026 23:01:09 -0400 Subject: [PATCH 7/8] Test framework implemented by claude. 
--- components/spider-storage/src/cache/task.rs | 2 +- components/spider-storage/src/cache/tests.rs | 1256 ++++++++++++++++-- 2 files changed, 1172 insertions(+), 86 deletions(-) diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs index ea320244..929bd14c 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -279,7 +279,7 @@ impl BaseTaskControlBlock { return Err(RejectionError::TaskAlreadyTerminated(self.state.clone())); } - if self.retry_counter.retry() == 0 { + if self.retry_counter.retry() != 0 { self.state = if self.instance_ids.is_empty() { TaskState::Running } else { diff --git a/components/spider-storage/src/cache/tests.rs b/components/spider-storage/src/cache/tests.rs index 90891135..d102d2e4 100644 --- a/components/spider-storage/src/cache/tests.rs +++ b/components/spider-storage/src/cache/tests.rs @@ -1,19 +1,18 @@ -use std::sync::{ - Arc, - atomic::{AtomicU64, Ordering}, +use std::{ + collections::HashSet, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + time::{Duration, Instant}, }; use async_trait::async_trait; use spider_core::{ job::JobState, task::{ - BytesTypeDescriptor, - DataTypeDescriptor, - ExecutionPolicy, - TaskDescriptor, - TaskGraph as CoreTaskGraph, - TaskIndex, - TerminationTaskDescriptor, + BytesTypeDescriptor, DataTypeDescriptor, ExecutionPolicy, TaskDescriptor, + TaskGraph as CoreTaskGraph, TaskIndex, TaskInputOutputIndex, TerminationTaskDescriptor, ValueTypeDescriptor, }, types::{ @@ -26,22 +25,30 @@ use tokio::sync::Mutex; use crate::{ cache::{ build_job, - error::InternalError, + error::{CacheError, InternalError, RejectionError}, job::{ReadyQueueConnector, TaskInstancePoolConnector}, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, - types::TaskId, + types::{ExecutionContext, TaskId}, }, db::{DbError, InternalJobOrchestration}, }; -// --- Mock implementations --- +// 
============================================================================= +// Mock implementations +// ============================================================================= type ReadyTaskList = Arc)>>>; +/// A mock ready queue that records all ready-task notifications. +/// +/// When `worker_txs` is set, it round-robin dispatches newly-ready task indices across +/// per-worker channels so that workers can pick them up without contention on a shared receiver. struct MockReadyQueue { ready_tasks: ReadyTaskList, commit_ready_count: Arc, cleanup_ready_count: Arc, + worker_txs: Option>>, + round_robin_counter: AtomicU64, } impl MockReadyQueue { @@ -50,6 +57,18 @@ impl MockReadyQueue { ready_tasks: Arc::new(Mutex::new(Vec::new())), commit_ready_count: Arc::new(AtomicU64::new(0)), cleanup_ready_count: Arc::new(AtomicU64::new(0)), + worker_txs: None, + round_robin_counter: AtomicU64::new(0), + } + } + + fn with_worker_channels(txs: Vec>) -> Self { + Self { + ready_tasks: Arc::new(Mutex::new(Vec::new())), + commit_ready_count: Arc::new(AtomicU64::new(0)), + cleanup_ready_count: Arc::new(AtomicU64::new(0)), + worker_txs: Some(txs), + round_robin_counter: AtomicU64::new(0), } } } @@ -61,6 +80,14 @@ impl ReadyQueueConnector for MockReadyQueue { job_id: JobId, task_ids: Vec, ) -> Result<(), InternalError> { + if let Some(txs) = &self.worker_txs { + let num_workers = txs.len(); + for &idx in &task_ids { + let worker = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize + % num_workers; + let _ = txs[worker].send(idx); + } + } self.ready_tasks.lock().await.push((job_id, task_ids)); Ok(()) } @@ -155,22 +182,654 @@ impl TaskInstancePoolConnector for MockInstancePool { } } -// --- Helper: simple byte type descriptor --- +// ============================================================================= +// Test execution framework +// ============================================================================= + +/// Pluggable per-task behavior during test 
execution. +/// +/// Implementations control what happens when a worker picks up a task: they can complete it +/// immediately, inject failures, add delays, or register multiple instances. +#[async_trait] +trait TaskHandler: Send + Sync { + /// Returns the number of concurrent instances to register per task. + fn num_instances(&self) -> usize { + 1 + } + + /// Called for each task instance. Returns outputs to submit on success, or an error message + /// to fail the instance. + async fn handle_instance( + &self, + task_index: TaskIndex, + instance_index: usize, + ctx: &ExecutionContext, + ) -> Result, String>; +} + +/// Default handler: immediately completes each task with 1KB outputs (1 instance per task). +struct ImmediateCompletionHandler { + num_outputs_per_task: usize, +} + +#[async_trait] +impl TaskHandler for ImmediateCompletionHandler { + async fn handle_instance( + &self, + _task_index: TaskIndex, + _instance_index: usize, + _ctx: &ExecutionContext, + ) -> Result, String> { + Ok((0..self.num_outputs_per_task) + .map(|_| make_1kb_payload()) + .collect()) + } +} + +/// Handler that runs 3 instances per task: 2 succeed and 1 fails. +/// +/// The last instance always fails; the first two succeed. Since instances run concurrently, the +/// first completion wins and subsequent completions/failures receive rejection errors (e.g. +/// `TaskAlreadyTerminated`), which are handled gracefully by the worker loop. 
+struct MultiInstancePartialFailHandler { + num_outputs_per_task: usize, + num_instances: usize, +} + +#[async_trait] +impl TaskHandler for MultiInstancePartialFailHandler { + fn num_instances(&self) -> usize { + self.num_instances + } + + async fn handle_instance( + &self, + _task_index: TaskIndex, + instance_index: usize, + _ctx: &ExecutionContext, + ) -> Result, String> { + if instance_index < self.num_instances - 1 { + Ok((0..self.num_outputs_per_task) + .map(|_| make_1kb_payload()) + .collect()) + } else { + Err(format!("simulated failure for instance {instance_index}")) + } + } +} + +/// Handler where every instance always fails. Used to test retry exhaustion. +struct AlwaysFailHandler; + +#[async_trait] +impl TaskHandler for AlwaysFailHandler { + async fn handle_instance( + &self, + task_index: TaskIndex, + instance_index: usize, + _ctx: &ExecutionContext, + ) -> Result, String> { + Err(format!( + "permanent failure for task {task_index} instance {instance_index}" + )) + } +} + +/// Returns true if the error is a rejection (expected under concurrency), false if internal. +/// Internal errors are unexpected and should be propagated. +fn is_rejection(err: &CacheError) -> bool { + matches!(err, CacheError::Rejection(_)) +} + +/// A single instance's timing: which task, which instance, how long from register to complete/fail. +#[derive(Debug)] +struct InstanceLatency { + task_index: TaskIndex, + duration: Duration, +} + +/// Collected results from a test run. +struct TestResult { + total_execution_time: Duration, + /// Per-instance latencies (one entry per `create_task_instance` → `complete/fail` cycle). + /// For single-instance tests, there is one entry per task. For multi-instance tests, there + /// are multiple entries per task. + instance_latencies: Vec, + /// Number of unique tasks that were dispatched to workers. 
+ tasks_dispatched: usize, + final_state: JobState, + ready_queue_call_count: usize, + total_tasks_reported_ready: usize, +} + +impl TestResult { + fn report( + &self, + test_name: &str, + num_workers: usize, + graph_construction_time: Duration, + build_job_time: Duration, + ) { + let mut sorted: Vec = self + .instance_latencies + .iter() + .map(|l| l.duration.as_secs_f64() * 1000.0) + .collect(); + sorted.sort_by(|a, b| a.partial_cmp(b).expect("latencies should be comparable")); + + let num_instances = sorted.len(); + let avg = if num_instances > 0 { + sorted.iter().sum::() / num_instances as f64 + } else { + 0.0 + }; + let p50 = percentile(&sorted, 50.0); + let p95 = percentile(&sorted, 95.0); + let p99 = percentile(&sorted, 99.0); + + eprintln!(); + eprintln!("=== {test_name} ({num_workers} workers) ==="); + eprintln!( + " graph_construction: {:>10.2} ms", + graph_construction_time.as_secs_f64() * 1000.0 + ); + eprintln!( + " build_job: {:>10.2} ms", + build_job_time.as_secs_f64() * 1000.0 + ); + eprintln!( + " total_execution: {:>10.2} ms", + self.total_execution_time.as_secs_f64() * 1000.0 + ); + eprintln!(" tasks_dispatched: {:>10}", self.tasks_dispatched); + eprintln!(" instances_measured: {:>10}", num_instances); + eprintln!(" avg_per_instance_latency: {avg:>10.3} ms"); + eprintln!(" p50_per_instance_latency: {p50:>10.3} ms"); + eprintln!(" p95_per_instance_latency: {p95:>10.3} ms"); + eprintln!(" p99_per_instance_latency: {p99:>10.3} ms"); + eprintln!( + " ready_queue_calls: {:>10}", + self.ready_queue_call_count + ); + eprintln!( + " total_tasks_reported_ready: {:>10}", + self.total_tasks_reported_ready + ); + eprintln!(); + } +} + +fn percentile(sorted: &[f64], pct: f64) -> f64 { + if sorted.is_empty() { + return 0.0; + } + let idx = (pct / 100.0 * (sorted.len() - 1) as f64).round() as usize; + sorted[idx.min(sorted.len() - 1)] +} + +/// Full entry point for a scheduled test: builds the job, runs workers, returns results. 
+/// +/// Each worker gets its own dedicated channel. The `MockReadyQueue` round-robins newly-ready +/// task indices across worker channels, eliminating contention on a shared receiver. +/// +/// When `task_handler.num_instances() > 1`, the worker dispatches multiple concurrent instances +/// per task. Each instance calls `create_task_instance` independently; one succeeds, the rest +/// may get rejection errors (e.g. `TaskAlreadyTerminated`) which are handled gracefully. +#[allow(clippy::too_many_lines)] +async fn run_scheduled_test( + graph: &CoreTaskGraph, + job_inputs: Vec, + num_workers: usize, + task_handler: Arc, +) -> (TestResult, Duration) { + // Create per-worker channels. + let mut worker_txs = Vec::with_capacity(num_workers); + let mut worker_rxs = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + worker_txs.push(tx); + worker_rxs.push(rx); + } + + let ready_queue = MockReadyQueue::with_worker_channels(worker_txs.clone()); + let ready_tasks_ref = ready_queue.ready_tasks.clone(); + + let build_start = Instant::now(); + let (jcb, initial_ready) = build_job( + JobId::new(), + ResourceGroupId::new(), + graph, + job_inputs, + ready_queue, + MockDb::new(false), + MockInstancePool::new(), + ) + .expect("build_job should succeed"); + let build_job_time = build_start.elapsed(); + + // Seed initial ready tasks round-robin across worker channels. + for (i, &idx) in initial_ready.iter().enumerate() { + worker_txs[i % num_workers] + .send(idx) + .expect("worker channel should be open during seeding"); + } + + let jcb = Arc::new(jcb); + let latencies: Arc>> = Arc::new(Mutex::new(Vec::new())); + let tasks_dispatched = Arc::new(AtomicU64::new(0)); + let done = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let num_instances_per_task = task_handler.num_instances(); + let num_tasks = graph.get_num_tasks(); + + let exec_start = Instant::now(); + + // Spawn workers, each with its own receiver. 
+ let mut worker_handles = Vec::with_capacity(num_workers); + for mut rx in worker_rxs { + let jcb = Arc::clone(&jcb); + let handler = Arc::clone(&task_handler); + let latencies = Arc::clone(&latencies); + let tasks_dispatched = Arc::clone(&tasks_dispatched); + let done = Arc::clone(&done); + + worker_handles.push(tokio::spawn(async move { + loop { + if done.load(Ordering::Relaxed) { + break; + } + + let task_idx = match rx.try_recv() { + Ok(idx) => idx, + Err(_) => { + tokio::task::yield_now().await; + continue; + } + }; + + tasks_dispatched.fetch_add(1, Ordering::Relaxed); + + let (terminal, mut instance_lats) = if num_instances_per_task == 1 { + execute_single_instance(&jcb, &*handler, task_idx).await + } else { + execute_multi_instance( + &jcb, &*handler, task_idx, num_instances_per_task, + ).await + }; + + latencies.lock().await.append(&mut instance_lats); + + if terminal { + done.store(true, Ordering::Relaxed); + break; + } + } + })); + } + + drop(worker_txs); + + for handle in worker_handles { + handle.await.expect("worker task should not panic"); + } + + let total_execution_time = exec_start.elapsed(); + + let ready_queue_snapshot = ready_tasks_ref.lock().await; + let ready_queue_call_count = ready_queue_snapshot.len(); + let total_tasks_reported_ready: usize = ready_queue_snapshot + .iter() + .map(|(_, ids)| ids.len()) + .sum(); + drop(ready_queue_snapshot); + + let instance_latencies = Arc::try_unwrap(latencies) + .expect("all workers should have finished by now") + .into_inner(); + + let tasks_dispatched = tasks_dispatched.load(Ordering::Relaxed) as usize; + let final_state = if tasks_dispatched == num_tasks { + JobState::Succeeded + } else { + JobState::Failed + }; + + let result = TestResult { + total_execution_time, + instance_latencies, + tasks_dispatched, + final_state, + ready_queue_call_count, + total_tasks_reported_ready, + }; + + (result, build_job_time) +} + +type JcbType = crate::cache::job::JobControlBlock; + +/// Executes a single instance 
for the given task. +/// Returns `(is_terminal, instance_latencies)`. +async fn execute_single_instance( + jcb: &Arc, + handler: &dyn TaskHandler, + task_idx: TaskIndex, +) -> (bool, Vec) { + let inst_start = Instant::now(); + + let ctx = match jcb.create_task_instance(TaskId::TaskIndex(task_idx)).await { + Ok(ctx) => ctx, + Err(e) => { + assert!( + is_rejection(&e), + "create_task_instance for task {task_idx} returned unexpected error: {e:?}" + ); + return (false, Vec::new()); + } + }; + + let terminal = match handler.handle_instance(task_idx, 0, &ctx).await { + Ok(outputs) => { + match jcb + .complete_task_instance(ctx.task_instance_id, task_idx, outputs) + .await + { + Ok(state) => state.is_terminal(), + Err(e) => { + assert!( + is_rejection(&e), + "complete_task_instance for task {task_idx} returned unexpected error: {e:?}" + ); + false + } + } + } + Err(error_message) => { + match jcb + .fail_task_instance( + ctx.task_instance_id, + TaskId::TaskIndex(task_idx), + error_message, + ) + .await + { + Ok(state) => state.is_terminal(), + Err(e) => { + assert!( + is_rejection(&e), + "fail_task_instance for task {task_idx} returned unexpected error: {e:?}" + ); + false + } + } + } + }; + + let lat = InstanceLatency { + task_index: task_idx, + duration: inst_start.elapsed(), + }; + (terminal, vec![lat]) +} + +/// Executes multiple concurrent instances for the given task. +/// Returns `(is_terminal, instance_latencies)` with one latency entry per instance. +/// Each instance's timing covers the full cycle: `create_task_instance` → handler → +/// `complete/fail_task_instance`. +async fn execute_multi_instance( + jcb: &Arc, + handler: &dyn TaskHandler, + task_idx: TaskIndex, + num_instances: usize, +) -> (bool, Vec) { + // Pre-compute each instance's outcome so we can move it into the spawned task. + // We use a dummy ExecutionContext for the handler since the real one is created inside + // the coroutine. 
+ let dummy_ctx = ExecutionContext { + task_instance_id: 0, + tdl_context: crate::cache::types::TdlContext { + package: String::new(), + func: String::new(), + }, + inputs: None, + }; + let mut outcomes: Vec, String>> = Vec::with_capacity(num_instances); + for i in 0..num_instances { + outcomes.push(handler.handle_instance(task_idx, i, &dummy_ctx).await); + } + + // Spawn one coroutine per instance. Each coroutine does the full cycle: + // create_task_instance → complete/fail → record latency. + let mut handles = Vec::with_capacity(num_instances); + for (instance_index, outcome) in outcomes.into_iter().enumerate() { + let jcb = Arc::clone(jcb); + + handles.push(tokio::spawn(async move { + let inst_start = Instant::now(); + + let ctx = match jcb.create_task_instance(TaskId::TaskIndex(task_idx)).await { + Ok(ctx) => ctx, + Err(e) => { + assert!( + is_rejection(&e), + "create_task_instance for task {task_idx} instance {instance_index} \ + returned unexpected error: {e:?}" + ); + let lat = InstanceLatency { + task_index: task_idx, + duration: inst_start.elapsed(), + }; + return (false, lat); + } + }; + + let terminal = match outcome { + Ok(outputs) => { + match jcb + .complete_task_instance(ctx.task_instance_id, task_idx, outputs) + .await + { + Ok(state) => state.is_terminal(), + Err(e) => { + assert!( + is_rejection(&e), + "complete_task_instance for task {task_idx} instance \ + {instance_index} returned unexpected error: {e:?}" + ); + false + } + } + } + Err(error_message) => { + match jcb + .fail_task_instance( + ctx.task_instance_id, + TaskId::TaskIndex(task_idx), + error_message, + ) + .await + { + Ok(state) => state.is_terminal(), + Err(e) => { + assert!( + is_rejection(&e), + "fail_task_instance for task {task_idx} instance \ + {instance_index} returned unexpected error: {e:?}" + ); + false + } + } + } + }; + + let lat = InstanceLatency { + task_index: task_idx, + duration: inst_start.elapsed(), + }; + (terminal, lat) + })); + } + + let mut terminal = false; + 
let mut lats = Vec::with_capacity(handles.len()); + for handle in handles { + let (t, lat) = handle.await.expect("instance task should not panic"); + if t { + terminal = true; + } + lats.push(lat); + } + (terminal, lats) +} + +// ============================================================================= +// Graph builders +// ============================================================================= fn bytes_type() -> DataTypeDescriptor { DataTypeDescriptor::Value(ValueTypeDescriptor::Bytes(BytesTypeDescriptor {})) } -// --- Tests --- +fn make_1kb_payload() -> Vec { + vec![0xAB_u8; 1024] +} + +/// Builds a flat graph of `num_tasks` independent tasks, each with `num_inputs` graph-level +/// inputs and `num_outputs` outputs. +fn build_flat_graph( + num_tasks: usize, + num_inputs_per_task: usize, + num_outputs_per_task: usize, +) -> (CoreTaskGraph, Vec) { + build_flat_graph_with_policy( + num_tasks, + num_inputs_per_task, + num_outputs_per_task, + ExecutionPolicy::default(), + ) +} + +fn build_flat_graph_with_policy( + num_tasks: usize, + num_inputs_per_task: usize, + num_outputs_per_task: usize, + policy: ExecutionPolicy, +) -> (CoreTaskGraph, Vec) { + let mut graph = CoreTaskGraph::default(); + for i in 0..num_tasks { + graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: format!("fn_{i}"), + inputs: vec![bytes_type(); num_inputs_per_task], + outputs: vec![bytes_type(); num_outputs_per_task], + input_sources: None, + execution_policy: policy.clone(), + }) + .expect("flat graph task insertion should succeed"); + } + let job_inputs: Vec = (0..num_tasks * num_inputs_per_task) + .map(|_| TaskInput::ValuePayload(make_1kb_payload())) + .collect(); + (graph, job_inputs) +} + +/// Builds a layered neural-network-style graph. +/// +/// Returns `(graph, job_inputs, layers)` where `layers[i]` contains the task indices for layer +/// `i`. Layer 0 tasks are input tasks with `fan_in` graph-level inputs each. 
Tasks in +/// subsequent layers receive outputs from `fan_in` tasks in the previous layer using circular +/// connectivity: task at position `p` in layer `L` receives outputs from positions +/// `(p - fan_in/2) % width .. (p - fan_in/2 + fan_in - 1) % width` in layer `L-1`. +fn build_neural_net_graph( + num_layers: usize, + width: usize, + fan_in: usize, +) -> (CoreTaskGraph, Vec, Vec>) { + build_neural_net_graph_with_policy( + num_layers, + width, + fan_in, + ExecutionPolicy::default(), + ) +} + +fn build_neural_net_graph_with_policy( + num_layers: usize, + width: usize, + fan_in: usize, + policy: ExecutionPolicy, +) -> (CoreTaskGraph, Vec, Vec>) { + let mut graph = CoreTaskGraph::default(); + let mut layers: Vec> = Vec::with_capacity(num_layers); + + let mut layer_0 = Vec::with_capacity(width); + for i in 0..width { + let idx = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: format!("L0_{i}"), + inputs: vec![bytes_type(); fan_in], + outputs: vec![bytes_type()], + input_sources: None, + execution_policy: policy.clone(), + }) + .expect("neural net layer 0 task insertion should succeed"); + layer_0.push(idx); + } + layers.push(layer_0); + + let half = fan_in / 2; + for layer_idx in 1..num_layers { + let prev_layer = &layers[layer_idx - 1]; + let mut current_layer = Vec::with_capacity(width); + + for p in 0..width { + let input_sources: Vec = (0..fan_in) + .map(|k| { + let src_pos = (p + width - half + k) % width; + TaskInputOutputIndex { + task_idx: prev_layer[src_pos], + position: 0, + } + }) + .collect(); + + let idx = graph + .insert_task(TaskDescriptor { + tdl_package: "pkg".into(), + tdl_function: format!("L{layer_idx}_{p}"), + inputs: vec![bytes_type(); fan_in], + outputs: vec![bytes_type()], + input_sources: Some(input_sources), + execution_policy: policy.clone(), + }) + .expect("neural net layer task insertion should succeed"); + current_layer.push(idx); + } + layers.push(current_layer); + } + + let job_inputs: Vec = 
(0..width * fan_in) + .map(|_| TaskInput::ValuePayload(make_1kb_payload())) + .collect(); + + (graph, job_inputs, layers) +} + +// ============================================================================= +// Stage 1 tests (existing) +// ============================================================================= /// Tests the factory and end-to-end execution with a simple linear chain: A -> B -> C. /// -/// Graph topology: +/// # Graph topology +/// /// ```text /// [job_input] -> A -> B -> C -> [job_output] /// ``` /// -/// Verifies: +/// # Verifies +/// /// - `build_job` correctly identifies only A as initially ready (B and C depend on predecessors). /// - Job inputs are pre-populated: A receives the original `b"hello"` bytes. /// - Dataflow wiring works across the chain: A's output `b"world"` is delivered as B's input, and @@ -190,33 +849,33 @@ async fn test_factory_linear_chain() { input_sources: None, execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task A insertion should succeed"); let task_b = graph .insert_task(TaskDescriptor { tdl_package: "pkg".into(), tdl_function: "fn_b".into(), inputs: vec![bytes_type()], outputs: vec![bytes_type()], - input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + input_sources: Some(vec![TaskInputOutputIndex { task_idx: task_a, position: 0, }]), execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task B insertion should succeed"); let task_c = graph .insert_task(TaskDescriptor { tdl_package: "pkg".into(), tdl_function: "fn_c".into(), inputs: vec![bytes_type()], outputs: vec![bytes_type()], - input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + input_sources: Some(vec![TaskInputOutputIndex { task_idx: task_b, position: 0, }]), execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task C insertion should succeed"); let job_inputs = vec![TaskInput::ValuePayload(b"hello".to_vec())]; @@ -229,18 +888,15 @@ async fn 
test_factory_linear_chain() { MockDb::new(false), MockInstancePool::new(), ) - .unwrap(); + .expect("build_job should succeed for linear chain"); - // Only task A should be ready initially. assert_eq!(ready_indices, vec![task_a]); - // Execute task A. let ctx_a = jcb .create_task_instance(TaskId::TaskIndex(task_a)) .await - .unwrap(); - assert!(ctx_a.inputs.is_some()); - let inputs_a = ctx_a.inputs.unwrap(); + .expect("create instance for task A should succeed"); + let inputs_a = ctx_a.inputs.expect("task A should have inputs"); assert_eq!(inputs_a.len(), 1); assert_eq!( inputs_a[0], @@ -251,15 +907,14 @@ async fn test_factory_linear_chain() { let state = jcb .complete_task_instance(ctx_a.task_instance_id, task_a, vec![b"world".to_vec()]) .await - .unwrap(); + .expect("complete task A should succeed"); assert_eq!(state, JobState::Running); - // Execute task B. let ctx_b = jcb .create_task_instance(TaskId::TaskIndex(task_b)) .await - .unwrap(); - let inputs_b = ctx_b.inputs.unwrap(); + .expect("create instance for task B should succeed"); + let inputs_b = ctx_b.inputs.expect("task B should have inputs"); assert_eq!( inputs_b[0], TaskInput::ValuePayload(b"world".to_vec()), @@ -269,21 +924,20 @@ async fn test_factory_linear_chain() { let state = jcb .complete_task_instance(ctx_b.task_instance_id, task_b, vec![b"done".to_vec()]) .await - .unwrap(); + .expect("complete task B should succeed"); assert_eq!(state, JobState::Running); - // Execute task C — the last task. 
let ctx_c = jcb .create_task_instance(TaskId::TaskIndex(task_c)) .await - .unwrap(); - let inputs_c = ctx_c.inputs.unwrap(); + .expect("create instance for task C should succeed"); + let inputs_c = ctx_c.inputs.expect("task C should have inputs"); assert_eq!(inputs_c[0], TaskInput::ValuePayload(b"done".to_vec())); let state = jcb .complete_task_instance(ctx_c.task_instance_id, task_c, vec![b"final".to_vec()]) .await - .unwrap(); + .expect("complete task C should succeed"); assert_eq!( state, JobState::Succeeded, @@ -294,22 +948,20 @@ async fn test_factory_linear_chain() { /// Tests the factory and end-to-end execution with a diamond DAG that exercises fan-out and /// fan-in. /// -/// Graph topology: +/// # Graph topology +/// /// ```text /// ┌─> B ─┐ /// [job_input] -> A -> D -> [job_output] /// └─> C ─┘ /// ``` /// -/// A has two outputs; B consumes output 0, C consumes output 1. D consumes both B's and C's -/// outputs (fan-in from two parents). +/// # Verifies /// -/// Verifies: /// - Only A is initially ready. -/// - Completing A unblocks both B and C simultaneously — both appear in a single `send_task_ready` -/// call to the ready queue. +/// - Completing A unblocks both B and C simultaneously. /// - D is not unblocked until *both* B and C have completed (fan-in gate). -/// - The job transitions to `Succeeded` once the sink task (D) completes. +/// - The job transitions to `Succeeded` once D completes. 
#[tokio::test] #[allow(clippy::too_many_lines)] async fn test_factory_diamond_dag() { @@ -323,33 +975,33 @@ async fn test_factory_diamond_dag() { input_sources: None, execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task A insertion should succeed"); let task_b = graph .insert_task(TaskDescriptor { tdl_package: "pkg".into(), tdl_function: "fn_b".into(), inputs: vec![bytes_type()], outputs: vec![bytes_type()], - input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + input_sources: Some(vec![TaskInputOutputIndex { task_idx: task_a, position: 0, }]), execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task B insertion should succeed"); let task_c = graph .insert_task(TaskDescriptor { tdl_package: "pkg".into(), tdl_function: "fn_c".into(), inputs: vec![bytes_type()], outputs: vec![bytes_type()], - input_sources: Some(vec![spider_core::task::TaskInputOutputIndex { + input_sources: Some(vec![TaskInputOutputIndex { task_idx: task_a, position: 1, }]), execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task C insertion should succeed"); let task_d = graph .insert_task(TaskDescriptor { tdl_package: "pkg".into(), @@ -357,18 +1009,18 @@ async fn test_factory_diamond_dag() { inputs: vec![bytes_type(), bytes_type()], outputs: vec![bytes_type()], input_sources: Some(vec![ - spider_core::task::TaskInputOutputIndex { + TaskInputOutputIndex { task_idx: task_b, position: 0, }, - spider_core::task::TaskInputOutputIndex { + TaskInputOutputIndex { task_idx: task_c, position: 0, }, ]), execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task D insertion should succeed"); let job_inputs = vec![TaskInput::ValuePayload(b"input".to_vec())]; let ready_queue = MockReadyQueue::new(); @@ -383,15 +1035,14 @@ async fn test_factory_diamond_dag() { MockDb::new(false), MockInstancePool::new(), ) - .unwrap(); + .expect("build_job should succeed for diamond DAG"); assert_eq!(ready_indices, 
vec![task_a]); - // Complete A. let ctx_a = jcb .create_task_instance(TaskId::TaskIndex(task_a)) .await - .unwrap(); + .expect("create instance for task A should succeed"); let state = jcb .complete_task_instance( ctx_a.task_instance_id, @@ -399,10 +1050,9 @@ async fn test_factory_diamond_dag() { vec![b"out_b".to_vec(), b"out_c".to_vec()], ) .await - .unwrap(); + .expect("complete task A should succeed"); assert_eq!(state, JobState::Running); - // Check that B and C were enqueued as ready. let queued = ready_tasks_ref.lock().await; assert_eq!(queued.len(), 1); let (_, ref task_ids) = queued[0]; @@ -410,54 +1060,46 @@ async fn test_factory_diamond_dag() { assert!(task_ids.contains(&task_c)); drop(queued); - // Complete B and C. let ctx_b = jcb .create_task_instance(TaskId::TaskIndex(task_b)) .await - .unwrap(); + .expect("create instance for task B should succeed"); jcb.complete_task_instance(ctx_b.task_instance_id, task_b, vec![b"b_out".to_vec()]) .await - .unwrap(); + .expect("complete task B should succeed"); let ctx_c = jcb .create_task_instance(TaskId::TaskIndex(task_c)) .await - .unwrap(); + .expect("create instance for task C should succeed"); jcb.complete_task_instance(ctx_c.task_instance_id, task_c, vec![b"c_out".to_vec()]) .await - .unwrap(); + .expect("complete task C should succeed"); - // D should now be ready. Complete it. let ctx_d = jcb .create_task_instance(TaskId::TaskIndex(task_d)) .await - .unwrap(); + .expect("create instance for task D should succeed"); let state = jcb .complete_task_instance(ctx_d.task_instance_id, task_d, vec![b"final".to_vec()]) .await - .unwrap(); + .expect("complete task D should succeed"); assert_eq!(state, JobState::Succeeded); } /// Tests the commit task lifecycle: job transitions through `CommitReady` before `Succeeded`. /// -/// Graph topology: +/// # Graph topology +/// /// ```text /// A -> [job_output] /// (commit task: commit_fn) /// ``` /// -/// A single task (A) with no inputs and one dangling output. 
A `TerminationTaskDescriptor` is -/// attached as the commit task. The mock DB is configured to return `CommitReady` on -/// `commit_outputs` (simulating a job that has a commit task). -/// -/// Verifies: -/// - After A completes and outputs are committed, the job transitions to `CommitReady` (not -/// directly to `Succeeded`). -/// - The ready queue receives exactly one `send_commit_ready` notification. -/// - The commit task can be registered via `TaskId::Commit` and returns no inputs (`inputs: None`), -/// since termination tasks do not consume dataflow outputs. -/// - After the commit task instance completes, the job transitions to `Succeeded`. +/// # Verifies +/// +/// - After A completes, the job transitions to `CommitReady`. +/// - The commit task can be registered and completed, transitioning to `Succeeded`. #[tokio::test] async fn test_factory_with_commit_task() { let mut graph = CoreTaskGraph::default(); @@ -470,7 +1112,7 @@ async fn test_factory_with_commit_task() { input_sources: None, execution_policy: ExecutionPolicy::default(), }) - .unwrap(); + .expect("task A insertion should succeed"); graph.set_commit_task(TerminationTaskDescriptor { tdl_package: "pkg".into(), @@ -490,28 +1132,472 @@ async fn test_factory_with_commit_task() { MockDb::new(true), MockInstancePool::new(), ) - .unwrap(); + .expect("build_job should succeed for commit task test"); assert_eq!(ready_indices, vec![task_a]); - // Complete task A. let ctx_a = jcb .create_task_instance(TaskId::TaskIndex(task_a)) .await - .unwrap(); + .expect("create instance for task A should succeed"); let state = jcb .complete_task_instance(ctx_a.task_instance_id, task_a, vec![b"output".to_vec()]) .await - .unwrap(); + .expect("complete task A should succeed"); assert_eq!(state, JobState::CommitReady); assert_eq!(commit_count.load(Ordering::Relaxed), 1); - // Execute commit task. 
- let ctx_commit = jcb.create_task_instance(TaskId::Commit).await.unwrap(); + let ctx_commit = jcb + .create_task_instance(TaskId::Commit) + .await + .expect("create commit instance should succeed"); assert!(ctx_commit.inputs.is_none()); let state = jcb .complete_commit_task_instance(ctx_commit.task_instance_id) .await - .unwrap(); + .expect("complete commit task should succeed"); assert_eq!(state, JobState::Succeeded); } + +// ============================================================================= +// Stage 2 tests: scheduler smoke tests +// ============================================================================= + +/// Smoke test: validates the test scheduler with a small independent graph (10 tasks, 2 workers). +/// +/// # Purpose +/// +/// Ensures the `run_scheduled_test` infrastructure works correctly before scaling to 10k tasks. +/// +/// # Verifies +/// +/// - All 10 tasks complete. +/// - Final job state is `Succeeded`. +/// - No ready-queue propagation (no dependencies). +#[tokio::test] +async fn test_scheduler_smoke_independent() { + let (graph, job_inputs) = build_flat_graph(10, 1, 1); + let handler: Arc = Arc::new(ImmediateCompletionHandler { + num_outputs_per_task: 1, + }); + let (result, _) = run_scheduled_test(&graph, job_inputs, 2, handler).await; + + assert_eq!(result.tasks_dispatched, 10, "all 10 tasks should complete"); + assert_eq!(result.final_state, JobState::Succeeded); + assert_eq!(result.ready_queue_call_count, 0); +} + +/// Smoke test: validates the test scheduler with a small layered graph (3 layers × 5 tasks, +/// fan-in=2, 2 workers). +/// +/// # Verifies +/// +/// - All 15 tasks complete. +/// - Final job state is `Succeeded`. +/// - Tasks in layers 1 and 2 (10 total) are reported ready via the ready queue. 
+#[tokio::test] +async fn test_scheduler_smoke_layered() { + let (graph, job_inputs, _layers) = build_neural_net_graph(3, 5, 2); + assert_eq!(graph.get_num_tasks(), 15); + + let handler: Arc = Arc::new(ImmediateCompletionHandler { + num_outputs_per_task: 1, + }); + let (result, _) = run_scheduled_test(&graph, job_inputs, 2, handler).await; + + assert_eq!(result.tasks_dispatched, 15, "all 15 tasks should complete"); + assert_eq!(result.final_state, JobState::Succeeded); + assert_eq!( + result.total_tasks_reported_ready, 10, + "layers 1 and 2 (5+5 tasks) should be reported ready" + ); +} + +/// Smoke test: validates multi-instance execution with a small graph (5 tasks, 2 workers, +/// 3 instances per task where 2 succeed and 1 fails). +/// +/// # Purpose +/// +/// Ensures the multi-instance path handles concurrent instance registration, completion, and +/// rejection errors (e.g. `TaskAlreadyTerminated`) gracefully before scaling up. +/// +/// # Verifies +/// +/// - All 5 tasks complete despite 1 out of 3 instances failing per task. +/// - Final job state is `Succeeded`. +#[tokio::test] +async fn test_scheduler_smoke_multi_instance() { + let policy = ExecutionPolicy { + max_num_instances: 3, + max_num_retries: 2, + }; + let (graph, job_inputs) = build_flat_graph_with_policy(5, 1, 1, policy); + let handler: Arc = Arc::new(MultiInstancePartialFailHandler { + num_outputs_per_task: 1, + num_instances: 3, + }); + let (result, _) = run_scheduled_test(&graph, job_inputs, 2, handler).await; + + assert_eq!(result.tasks_dispatched, 5, "all 5 tasks should complete"); + assert_eq!(result.final_state, JobState::Succeeded); +} + +// ============================================================================= +// Stage 2 tests: large-scale performance +// ============================================================================= + +/// Large-scale performance baseline: 10,000 independent tasks with zero dependencies. 
+/// +/// # Purpose +/// +/// Establishes a baseline for cache-layer throughput with no dependency overhead. +/// +/// # Graph topology +/// +/// ```text +/// [20,000 job_inputs (1KB each)] +/// T_0 (2 in, 1 out) -> [job_output_0] +/// ... +/// T_9999 (2 in, 1 out) -> [job_output_9999] +/// ``` +/// +/// # Metrics captured +/// +/// Graph construction, `build_job`, total execution (128 workers), per-task latency (avg/p50/p95/p99). +/// +/// # How to interpret results +/// +/// Compare per-task latency against the neural-network test. The difference isolates +/// dependency-tracking overhead. +/// +/// Run with `cargo test test_scale_10k_independent -- --nocapture` to see timing output. +#[tokio::test] +async fn test_scale_10k_independent() { + const NUM_TASKS: usize = 10_000; + const NUM_WORKERS: usize = 128; + + let graph_start = Instant::now(); + let (graph, job_inputs) = build_flat_graph(NUM_TASKS, 2, 1); + let graph_time = graph_start.elapsed(); + + let handler: Arc = Arc::new(ImmediateCompletionHandler { + num_outputs_per_task: 1, + }); + let (result, build_job_time) = + run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; + + assert_eq!( + result.tasks_dispatched, + NUM_TASKS, + "all tasks should have completed" + ); + assert_eq!(result.final_state, JobState::Succeeded); + assert_eq!( + result.ready_queue_call_count, 0, + "no ready-queue propagation for independent tasks" + ); + + result.report("10k Independent Tasks", NUM_WORKERS, graph_time, build_job_time); +} + +/// Large-scale dependency test: 10,000 tasks in a 10-layer × 1000-wide neural-network topology. +/// +/// # Purpose +/// +/// Measures dependency-tracking overhead with fan-in=10, fan-out=10. +/// +/// # Graph topology +/// +/// ```text +/// Layer 0: T_0..T_999 [10 graph inputs, 1KB] -> 1 output +/// ... +/// Layer 9: T_9000..T_9999 [10 inputs from layer 8] -> [job outputs] +/// Circular connectivity: (p-5)%1000 .. (p+4)%1000. 
+/// ``` +/// +/// # How to interpret results +/// +/// Compare against 10k independent test. Difference = dependency overhead. +/// +/// Run with `cargo test test_scale_10k_neural_net -- --nocapture` to see timing output. +#[tokio::test] +async fn test_scale_10k_neural_net() { + const NUM_LAYERS: usize = 10; + const WIDTH: usize = 1000; + const FAN_IN: usize = 10; + const NUM_WORKERS: usize = 128; + + let graph_start = Instant::now(); + let (graph, job_inputs, _layers) = build_neural_net_graph(NUM_LAYERS, WIDTH, FAN_IN); + let graph_time = graph_start.elapsed(); + + assert_eq!(graph.get_num_tasks(), NUM_LAYERS * WIDTH); + + let handler: Arc = Arc::new(ImmediateCompletionHandler { + num_outputs_per_task: 1, + }); + let (result, build_job_time) = + run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; + + assert_eq!( + result.tasks_dispatched, + NUM_LAYERS * WIDTH, + "all tasks should have completed" + ); + assert_eq!(result.final_state, JobState::Succeeded); + assert_eq!( + result.total_tasks_reported_ready, + (NUM_LAYERS - 1) * WIDTH, + "9 layers × 1000 tasks should be reported ready" + ); + + result.report( + "10k Neural Net (10x1000, fan=10)", + NUM_WORKERS, + graph_time, + build_job_time, + ); +} + +/// Large-scale multi-instance test: 10,000 independent tasks, 3 instances per task (2 succeed, +/// 1 fails), 128 workers. +/// +/// # Purpose +/// +/// Measures the overhead of concurrent multi-instance registration and the rejection-error path +/// when instances race to complete/fail the same task. +/// +/// # How to interpret results +/// +/// Compare against the single-instance 10k independent test. The difference shows the cost of +/// instance contention, registration overhead, and rejection handling. +/// +/// Run with `cargo test test_scale_10k_multi_instance -- --nocapture` to see timing output. 
+#[tokio::test] +async fn test_scale_10k_multi_instance() { + const NUM_TASKS: usize = 10_000; + const NUM_WORKERS: usize = 128; + + let policy = ExecutionPolicy { + max_num_instances: 3, + max_num_retries: 2, + }; + let graph_start = Instant::now(); + let (graph, job_inputs) = build_flat_graph_with_policy(NUM_TASKS, 2, 1, policy); + let graph_time = graph_start.elapsed(); + + let handler: Arc = Arc::new(MultiInstancePartialFailHandler { + num_outputs_per_task: 1, + num_instances: 3, + }); + let (result, build_job_time) = + run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; + + assert_eq!( + result.tasks_dispatched, + NUM_TASKS, + "all tasks should have completed" + ); + assert_eq!(result.final_state, JobState::Succeeded); + + result.report( + "10k Independent (3 instances, 2 success + 1 fail)", + NUM_WORKERS, + graph_time, + build_job_time, + ); +} + +/// Large-scale multi-instance dependency test: 10,000 tasks in a 10-layer × 1000-wide +/// neural-network topology, 3 instances per task (2 succeed, 1 fails), 128 workers. +/// +/// # Purpose +/// +/// Combines dependency-tracking overhead with multi-instance contention. Each task completion +/// acquires locks on 10 child TCBs while also racing against sibling instances that may +/// complete or fail concurrently. +/// +/// # Graph topology +/// +/// Same as [`test_scale_10k_neural_net`] (10 layers × 1000, fan-in=10, circular connectivity), +/// but with `max_num_instances=3` and `max_num_retries=2`. +/// +/// # How to interpret results +/// +/// Compare against the single-instance neural-net test to isolate multi-instance overhead in a +/// dependency-heavy graph. +/// +/// Run with `cargo test test_scale_10k_neural_net_multi_instance -- --nocapture` to see timing. 
+#[tokio::test] +async fn test_scale_10k_neural_net_multi_instance() { + const NUM_LAYERS: usize = 10; + const WIDTH: usize = 1000; + const FAN_IN: usize = 10; + const NUM_WORKERS: usize = 128; + + let policy = ExecutionPolicy { + max_num_instances: 3, + max_num_retries: 2, + }; + let graph_start = Instant::now(); + let (graph, job_inputs, _layers) = + build_neural_net_graph_with_policy(NUM_LAYERS, WIDTH, FAN_IN, policy); + let graph_time = graph_start.elapsed(); + + assert_eq!(graph.get_num_tasks(), NUM_LAYERS * WIDTH); + + let handler: Arc = Arc::new(MultiInstancePartialFailHandler { + num_outputs_per_task: 1, + num_instances: 3, + }); + let (result, build_job_time) = + run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; + + assert_eq!( + result.tasks_dispatched, + NUM_LAYERS * WIDTH, + "all tasks should have completed" + ); + assert_eq!(result.final_state, JobState::Succeeded); + assert_eq!( + result.total_tasks_reported_ready, + (NUM_LAYERS - 1) * WIDTH, + "9 layers x 1000 tasks should be reported ready" + ); + + result.report( + "10k Neural Net (10x1000, fan=10, 3 instances, 2 success + 1 fail)", + NUM_WORKERS, + graph_time, + build_job_time, + ); +} + +/// Tests that a job fails when all task instances always fail and retries are exhausted. +/// +/// # Purpose +/// +/// Validates the retry-exhaustion → job failure path. Every instance of every task fails, +/// consuming all retries. The first task to exhaust its retries causes the entire job to fail. +/// +/// # Graph topology +/// +/// ```text +/// T_0..T_9 (1 input, 1 output, all independent) +/// ExecutionPolicy: max_num_instances=1, max_num_retries=2 +/// ``` +/// +/// # Verifies +/// +/// - The job reaches `Failed` state (not `Succeeded`). +/// - Not all tasks complete (the job fails early once one task exhausts retries). 
+#[tokio::test] +async fn test_always_fail_exhausts_retries() { + const NUM_TASKS: usize = 10; + const NUM_WORKERS: usize = 4; + + let policy = ExecutionPolicy { + max_num_instances: 1, + max_num_retries: 2, + }; + let (graph, job_inputs) = build_flat_graph_with_policy(NUM_TASKS, 1, 1, policy); + + let handler: Arc = Arc::new(AlwaysFailHandler); + let (result, _) = run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; + + assert_eq!( + result.final_state, + JobState::Failed, + "job should fail when all instances always fail and retries are exhausted" + ); +} + +// ============================================================================= +// Stage 2 tests: ready-queue correctness +// ============================================================================= + +/// Correctness test for the neural-network ready-queue propagation. +/// +/// # Purpose +/// +/// Verifies that the ready queue receives exactly the right task indices after each layer +/// completes. Uses sequential execution for determinism. +/// +/// # Verifies +/// +/// - After completing layer L, exactly the 1000 tasks in layer L+1 are reported ready. +/// - No duplicates. +/// - The final layer triggers no further ready notifications. 
+#[tokio::test] +#[allow(clippy::too_many_lines)] +async fn test_neural_net_ready_queue_correctness() { + const NUM_LAYERS: usize = 10; + const WIDTH: usize = 1000; + const FAN_IN: usize = 10; + + let (graph, job_inputs, layers) = build_neural_net_graph(NUM_LAYERS, WIDTH, FAN_IN); + + let ready_queue = MockReadyQueue::new(); + let ready_tasks_ref = ready_queue.ready_tasks.clone(); + + let (jcb, initial_ready) = build_job( + JobId::new(), + ResourceGroupId::new(), + &graph, + job_inputs, + ready_queue, + MockDb::new(false), + MockInstancePool::new(), + ) + .expect("build_job should succeed for correctness test"); + + let initial_set: HashSet = initial_ready.into_iter().collect(); + let expected_layer_0: HashSet = layers[0].iter().copied().collect(); + assert_eq!(initial_set, expected_layer_0, "only layer 0 should be initially ready"); + + for (layer_idx, layer) in layers.iter().enumerate() { + ready_tasks_ref.lock().await.clear(); + + for &task_idx in layer { + let ctx = jcb + .create_task_instance(TaskId::TaskIndex(task_idx)) + .await + .expect("create instance should succeed in correctness test"); + jcb.complete_task_instance(ctx.task_instance_id, task_idx, vec![make_1kb_payload()]) + .await + .expect("complete task should succeed in correctness test"); + } + + if layer_idx < NUM_LAYERS - 1 { + let snapshot = ready_tasks_ref.lock().await; + let mut reported_ready: Vec = snapshot + .iter() + .flat_map(|(_, ids)| ids.iter().copied()) + .collect(); + drop(snapshot); + + let unique: HashSet = reported_ready.iter().copied().collect(); + assert_eq!( + unique.len(), + reported_ready.len(), + "layer {layer_idx}: no task should be reported ready more than once" + ); + + reported_ready.sort_unstable(); + let mut expected: Vec = layers[layer_idx + 1].clone(); + expected.sort_unstable(); + assert_eq!( + reported_ready, expected, + "layer {layer_idx}: reported ready tasks should match layer {}", + layer_idx + 1 + ); + } else { + let snapshot = ready_tasks_ref.lock().await; + 
let reported_count: usize = snapshot.iter().map(|(_, ids)| ids.len()).sum(); + assert_eq!( + reported_count, 0, + "last layer should not trigger any ready notifications" + ); + } + } +} From 3bfee795bdfe82aa07cb49ccebcc4f58ee119f26 Mon Sep 17 00:00:00 2001 From: LinZhihao-723 Date: Wed, 18 Mar 2026 15:09:35 -0400 Subject: [PATCH 8/8] Remove multi-instance testing; Improve perf instrument. --- components/spider-storage/src/cache.rs | 14 + components/spider-storage/src/cache/tests.rs | 508 +++++-------------- 2 files changed, 152 insertions(+), 370 deletions(-) diff --git a/components/spider-storage/src/cache.rs b/components/spider-storage/src/cache.rs index f9928cc6..73379607 100644 --- a/components/spider-storage/src/cache.rs +++ b/components/spider-storage/src/cache.rs @@ -44,4 +44,18 @@ pub use factory::*; pub use job::{JobControlBlock, ReadyQueueConnector, TaskInstancePoolConnector}; #[cfg(test)] +#[allow( + clippy::future_not_send, + clippy::significant_drop_tightening, + clippy::option_if_let_else, + clippy::missing_errors_doc, + clippy::missing_panics_doc, + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_sign_loss, + clippy::similar_names, + clippy::needless_pass_by_value, + clippy::too_many_lines, + clippy::manual_let_else +)] mod tests; diff --git a/components/spider-storage/src/cache/tests.rs b/components/spider-storage/src/cache/tests.rs index d102d2e4..c0a8467a 100644 --- a/components/spider-storage/src/cache/tests.rs +++ b/components/spider-storage/src/cache/tests.rs @@ -11,8 +11,14 @@ use async_trait::async_trait; use spider_core::{ job::JobState, task::{ - BytesTypeDescriptor, DataTypeDescriptor, ExecutionPolicy, TaskDescriptor, - TaskGraph as CoreTaskGraph, TaskIndex, TaskInputOutputIndex, TerminationTaskDescriptor, + BytesTypeDescriptor, + DataTypeDescriptor, + ExecutionPolicy, + TaskDescriptor, + TaskGraph as CoreTaskGraph, + TaskIndex, + TaskInputOutputIndex, + TerminationTaskDescriptor, ValueTypeDescriptor, }, 
types::{ @@ -25,7 +31,7 @@ use tokio::sync::Mutex; use crate::{ cache::{ build_job, - error::{CacheError, InternalError, RejectionError}, + error::{CacheError, InternalError}, job::{ReadyQueueConnector, TaskInstancePoolConnector}, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, types::{ExecutionContext, TaskId}, @@ -83,8 +89,8 @@ impl ReadyQueueConnector for MockReadyQueue { if let Some(txs) = &self.worker_txs { let num_workers = txs.len(); for &idx in &task_ids { - let worker = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize - % num_workers; + let worker = + self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize % num_workers; let _ = txs[worker].send(idx); } } @@ -189,35 +195,28 @@ impl TaskInstancePoolConnector for MockInstancePool { /// Pluggable per-task behavior during test execution. /// /// Implementations control what happens when a worker picks up a task: they can complete it -/// immediately, inject failures, add delays, or register multiple instances. +/// immediately, inject failures, or add delays. #[async_trait] trait TaskHandler: Send + Sync { - /// Returns the number of concurrent instances to register per task. - fn num_instances(&self) -> usize { - 1 - } - - /// Called for each task instance. Returns outputs to submit on success, or an error message - /// to fail the instance. - async fn handle_instance( + /// Called for each task. Returns outputs to submit on success, or an error message to fail + /// the instance. + async fn handle_task( &self, task_index: TaskIndex, - instance_index: usize, ctx: &ExecutionContext, ) -> Result, String>; } -/// Default handler: immediately completes each task with 1KB outputs (1 instance per task). +/// Default handler: immediately completes each task with 1KB outputs. 
struct ImmediateCompletionHandler { num_outputs_per_task: usize, } #[async_trait] impl TaskHandler for ImmediateCompletionHandler { - async fn handle_instance( + async fn handle_task( &self, _task_index: TaskIndex, - _instance_index: usize, _ctx: &ExecutionContext, ) -> Result, String> { Ok((0..self.num_outputs_per_task) @@ -226,52 +225,17 @@ impl TaskHandler for ImmediateCompletionHandler { } } -/// Handler that runs 3 instances per task: 2 succeed and 1 fails. -/// -/// The last instance always fails; the first two succeed. Since instances run concurrently, the -/// first completion wins and subsequent completions/failures receive rejection errors (e.g. -/// `TaskAlreadyTerminated`), which are handled gracefully by the worker loop. -struct MultiInstancePartialFailHandler { - num_outputs_per_task: usize, - num_instances: usize, -} - -#[async_trait] -impl TaskHandler for MultiInstancePartialFailHandler { - fn num_instances(&self) -> usize { - self.num_instances - } - - async fn handle_instance( - &self, - _task_index: TaskIndex, - instance_index: usize, - _ctx: &ExecutionContext, - ) -> Result, String> { - if instance_index < self.num_instances - 1 { - Ok((0..self.num_outputs_per_task) - .map(|_| make_1kb_payload()) - .collect()) - } else { - Err(format!("simulated failure for instance {instance_index}")) - } - } -} - -/// Handler where every instance always fails. Used to test retry exhaustion. +/// Handler where every task always fails. Used to test retry exhaustion. 
struct AlwaysFailHandler; #[async_trait] impl TaskHandler for AlwaysFailHandler { - async fn handle_instance( + async fn handle_task( &self, task_index: TaskIndex, - instance_index: usize, _ctx: &ExecutionContext, ) -> Result, String> { - Err(format!( - "permanent failure for task {task_index} instance {instance_index}" - )) + Err(format!("permanent failure for task {task_index}")) } } @@ -281,20 +245,23 @@ fn is_rejection(err: &CacheError) -> bool { matches!(err, CacheError::Rejection(_)) } -/// A single instance's timing: which task, which instance, how long from register to complete/fail. +/// Per-task timing breakdown: time spent in `create_task_instance` (pre-execution) and +/// `complete/fail_task_instance` (post-execution). #[derive(Debug)] -struct InstanceLatency { +#[allow(dead_code)] +struct TaskLatency { task_index: TaskIndex, - duration: Duration, + /// Duration of `create_task_instance` call. + create_duration: Duration, + /// Duration of `complete_task_instance` or `fail_task_instance` call. + complete_duration: Duration, } /// Collected results from a test run. struct TestResult { total_execution_time: Duration, - /// Per-instance latencies (one entry per `create_task_instance` → `complete/fail` cycle). - /// For single-instance tests, there is one entry per task. For multi-instance tests, there - /// are multiple entries per task. - instance_latencies: Vec, + /// Per-task latencies with create/complete breakdown. One entry per dispatched task. + task_latencies: Vec, /// Number of unique tasks that were dispatched to workers. 
tasks_dispatched: usize, final_state: JobState, @@ -310,43 +277,71 @@ impl TestResult { graph_construction_time: Duration, build_job_time: Duration, ) { - let mut sorted: Vec = self - .instance_latencies + let to_ms = |d: &Duration| d.as_secs_f64() * 1000.0; + + let mut create_ms: Vec = self + .task_latencies .iter() - .map(|l| l.duration.as_secs_f64() * 1000.0) + .map(|l| to_ms(&l.create_duration)) .collect(); - sorted.sort_by(|a, b| a.partial_cmp(b).expect("latencies should be comparable")); + let mut complete_ms: Vec = self + .task_latencies + .iter() + .map(|l| to_ms(&l.complete_duration)) + .collect(); + create_ms.sort_by(|a, b| a.partial_cmp(b).expect("latencies should be comparable")); + complete_ms.sort_by(|a, b| a.partial_cmp(b).expect("latencies should be comparable")); - let num_instances = sorted.len(); - let avg = if num_instances > 0 { - sorted.iter().sum::() / num_instances as f64 - } else { - 0.0 - }; - let p50 = percentile(&sorted, 50.0); - let p95 = percentile(&sorted, 95.0); - let p99 = percentile(&sorted, 99.0); + let _count = self.task_latencies.len(); + let avg_create = avg_of(&create_ms); + let avg_complete = avg_of(&complete_ms); eprintln!(); eprintln!("=== {test_name} ({num_workers} workers) ==="); eprintln!( " graph_construction: {:>10.2} ms", - graph_construction_time.as_secs_f64() * 1000.0 + to_ms(&graph_construction_time) ); eprintln!( " build_job: {:>10.2} ms", - build_job_time.as_secs_f64() * 1000.0 + to_ms(&build_job_time) ); eprintln!( " total_execution: {:>10.2} ms", - self.total_execution_time.as_secs_f64() * 1000.0 + to_ms(&self.total_execution_time) + ); + eprintln!( + " tasks_dispatched: {:>10}", + self.tasks_dispatched + ); + eprintln!(" --- create_task_instance ---"); + eprintln!(" avg: {avg_create:>10.3} ms"); + eprintln!( + " p50: {:>10.3} ms", + percentile(&create_ms, 50.0) + ); + eprintln!( + " p95: {:>10.3} ms", + percentile(&create_ms, 95.0) + ); + eprintln!( + " p99: {:>10.3} ms", + percentile(&create_ms, 99.0) + ); + 
eprintln!(" --- complete/fail_task_instance ---"); + eprintln!(" avg: {avg_complete:>10.3} ms"); + eprintln!( + " p50: {:>10.3} ms", + percentile(&complete_ms, 50.0) + ); + eprintln!( + " p95: {:>10.3} ms", + percentile(&complete_ms, 95.0) + ); + eprintln!( + " p99: {:>10.3} ms", + percentile(&complete_ms, 99.0) ); - eprintln!(" tasks_dispatched: {:>10}", self.tasks_dispatched); - eprintln!(" instances_measured: {:>10}", num_instances); - eprintln!(" avg_per_instance_latency: {avg:>10.3} ms"); - eprintln!(" p50_per_instance_latency: {p50:>10.3} ms"); - eprintln!(" p95_per_instance_latency: {p95:>10.3} ms"); - eprintln!(" p99_per_instance_latency: {p99:>10.3} ms"); eprintln!( " ready_queue_calls: {:>10}", self.ready_queue_call_count @@ -367,6 +362,13 @@ fn percentile(sorted: &[f64], pct: f64) -> f64 { sorted[idx.min(sorted.len() - 1)] } +fn avg_of(values: &[f64]) -> f64 { + if values.is_empty() { + return 0.0; + } + values.iter().sum::() / values.len() as f64 +} + /// Full entry point for a scheduled test: builds the job, runs workers, returns results. /// /// Each worker gets its own dedicated channel. 
The `MockReadyQueue` round-robins newly-ready @@ -415,10 +417,9 @@ async fn run_scheduled_test( } let jcb = Arc::new(jcb); - let latencies: Arc>> = Arc::new(Mutex::new(Vec::new())); + let latencies: Arc>> = Arc::new(Mutex::new(Vec::new())); let tasks_dispatched = Arc::new(AtomicU64::new(0)); let done = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let num_instances_per_task = task_handler.num_instances(); let num_tasks = graph.get_num_tasks(); let exec_start = Instant::now(); @@ -438,25 +439,20 @@ async fn run_scheduled_test( break; } - let task_idx = match rx.try_recv() { - Ok(idx) => idx, - Err(_) => { - tokio::task::yield_now().await; - continue; - } + let task_idx = if let Ok(idx) = rx.try_recv() { + idx + } else { + tokio::task::yield_now().await; + continue; }; tasks_dispatched.fetch_add(1, Ordering::Relaxed); - let (terminal, mut instance_lats) = if num_instances_per_task == 1 { - execute_single_instance(&jcb, &*handler, task_idx).await - } else { - execute_multi_instance( - &jcb, &*handler, task_idx, num_instances_per_task, - ).await - }; + let (terminal, lat) = execute_single_instance(&jcb, &*handler, task_idx).await; - latencies.lock().await.append(&mut instance_lats); + if let Some(lat) = lat { + latencies.lock().await.push(lat); + } if terminal { done.store(true, Ordering::Relaxed); @@ -476,13 +472,11 @@ async fn run_scheduled_test( let ready_queue_snapshot = ready_tasks_ref.lock().await; let ready_queue_call_count = ready_queue_snapshot.len(); - let total_tasks_reported_ready: usize = ready_queue_snapshot - .iter() - .map(|(_, ids)| ids.len()) - .sum(); + let total_tasks_reported_ready: usize = + ready_queue_snapshot.iter().map(|(_, ids)| ids.len()).sum(); drop(ready_queue_snapshot); - let instance_latencies = Arc::try_unwrap(latencies) + let task_latencies = Arc::try_unwrap(latencies) .expect("all workers should have finished by now") .into_inner(); @@ -495,7 +489,7 @@ async fn run_scheduled_test( let result = TestResult { 
total_execution_time, - instance_latencies, + task_latencies, tasks_dispatched, final_state, ready_queue_call_count, @@ -508,14 +502,15 @@ async fn run_scheduled_test( type JcbType = crate::cache::job::JobControlBlock; /// Executes a single instance for the given task. -/// Returns `(is_terminal, instance_latencies)`. +/// Returns `(is_terminal, Option)`. The latency is `None` if registration was +/// rejected (e.g. task already terminated by another worker). async fn execute_single_instance( jcb: &Arc, handler: &dyn TaskHandler, task_idx: TaskIndex, -) -> (bool, Vec) { - let inst_start = Instant::now(); - +) -> (bool, Option) { + // Time the create_task_instance call. + let create_start = Instant::now(); let ctx = match jcb.create_task_instance(TaskId::TaskIndex(task_idx)).await { Ok(ctx) => ctx, Err(e) => { @@ -523,11 +518,17 @@ async fn execute_single_instance( is_rejection(&e), "create_task_instance for task {task_idx} returned unexpected error: {e:?}" ); - return (false, Vec::new()); + return (false, None); } }; + let create_duration = create_start.elapsed(); + + // Handler decides outputs or failure (not timed — simulates external execution). + let outcome = handler.handle_task(task_idx, &ctx).await; - let terminal = match handler.handle_instance(task_idx, 0, &ctx).await { + // Time the complete/fail call. 
+ let complete_start = Instant::now(); + let terminal = match outcome { Ok(outputs) => { match jcb .complete_task_instance(ctx.task_instance_id, task_idx, outputs) @@ -537,7 +538,8 @@ async fn execute_single_instance( Err(e) => { assert!( is_rejection(&e), - "complete_task_instance for task {task_idx} returned unexpected error: {e:?}" + "complete_task_instance for task {task_idx} returned unexpected error: \ + {e:?}" ); false } @@ -563,122 +565,14 @@ async fn execute_single_instance( } } }; + let complete_duration = complete_start.elapsed(); - let lat = InstanceLatency { + let lat = TaskLatency { task_index: task_idx, - duration: inst_start.elapsed(), + create_duration, + complete_duration, }; - (terminal, vec![lat]) -} - -/// Executes multiple concurrent instances for the given task. -/// Returns `(is_terminal, instance_latencies)` with one latency entry per instance. -/// Each instance's timing covers the full cycle: `create_task_instance` → handler → -/// `complete/fail_task_instance`. -async fn execute_multi_instance( - jcb: &Arc, - handler: &dyn TaskHandler, - task_idx: TaskIndex, - num_instances: usize, -) -> (bool, Vec) { - // Pre-compute each instance's outcome so we can move it into the spawned task. - // We use a dummy ExecutionContext for the handler since the real one is created inside - // the coroutine. - let dummy_ctx = ExecutionContext { - task_instance_id: 0, - tdl_context: crate::cache::types::TdlContext { - package: String::new(), - func: String::new(), - }, - inputs: None, - }; - let mut outcomes: Vec, String>> = Vec::with_capacity(num_instances); - for i in 0..num_instances { - outcomes.push(handler.handle_instance(task_idx, i, &dummy_ctx).await); - } - - // Spawn one coroutine per instance. Each coroutine does the full cycle: - // create_task_instance → complete/fail → record latency. 
- let mut handles = Vec::with_capacity(num_instances); - for (instance_index, outcome) in outcomes.into_iter().enumerate() { - let jcb = Arc::clone(jcb); - - handles.push(tokio::spawn(async move { - let inst_start = Instant::now(); - - let ctx = match jcb.create_task_instance(TaskId::TaskIndex(task_idx)).await { - Ok(ctx) => ctx, - Err(e) => { - assert!( - is_rejection(&e), - "create_task_instance for task {task_idx} instance {instance_index} \ - returned unexpected error: {e:?}" - ); - let lat = InstanceLatency { - task_index: task_idx, - duration: inst_start.elapsed(), - }; - return (false, lat); - } - }; - - let terminal = match outcome { - Ok(outputs) => { - match jcb - .complete_task_instance(ctx.task_instance_id, task_idx, outputs) - .await - { - Ok(state) => state.is_terminal(), - Err(e) => { - assert!( - is_rejection(&e), - "complete_task_instance for task {task_idx} instance \ - {instance_index} returned unexpected error: {e:?}" - ); - false - } - } - } - Err(error_message) => { - match jcb - .fail_task_instance( - ctx.task_instance_id, - TaskId::TaskIndex(task_idx), - error_message, - ) - .await - { - Ok(state) => state.is_terminal(), - Err(e) => { - assert!( - is_rejection(&e), - "fail_task_instance for task {task_idx} instance \ - {instance_index} returned unexpected error: {e:?}" - ); - false - } - } - } - }; - - let lat = InstanceLatency { - task_index: task_idx, - duration: inst_start.elapsed(), - }; - (terminal, lat) - })); - } - - let mut terminal = false; - let mut lats = Vec::with_capacity(handles.len()); - for handle in handles { - let (t, lat) = handle.await.expect("instance task should not panic"); - if t { - terminal = true; - } - lats.push(lat); - } - (terminal, lats) + (terminal, Some(lat)) } // ============================================================================= @@ -690,7 +584,7 @@ fn bytes_type() -> DataTypeDescriptor { } fn make_1kb_payload() -> Vec { - vec![0xAB_u8; 1024] + vec![0xab_u8; 1024] } /// Builds a flat graph of 
`num_tasks` independent tasks, each with `num_inputs` graph-level @@ -745,12 +639,7 @@ fn build_neural_net_graph( width: usize, fan_in: usize, ) -> (CoreTaskGraph, Vec, Vec>) { - build_neural_net_graph_with_policy( - num_layers, - width, - fan_in, - ExecutionPolicy::default(), - ) + build_neural_net_graph_with_policy(num_layers, width, fan_in, ExecutionPolicy::default()) } fn build_neural_net_graph_with_policy( @@ -1213,35 +1102,6 @@ async fn test_scheduler_smoke_layered() { ); } -/// Smoke test: validates multi-instance execution with a small graph (5 tasks, 2 workers, -/// 3 instances per task where 2 succeed and 1 fails). -/// -/// # Purpose -/// -/// Ensures the multi-instance path handles concurrent instance registration, completion, and -/// rejection errors (e.g. `TaskAlreadyTerminated`) gracefully before scaling up. -/// -/// # Verifies -/// -/// - All 5 tasks complete despite 1 out of 3 instances failing per task. -/// - Final job state is `Succeeded`. -#[tokio::test] -async fn test_scheduler_smoke_multi_instance() { - let policy = ExecutionPolicy { - max_num_instances: 3, - max_num_retries: 2, - }; - let (graph, job_inputs) = build_flat_graph_with_policy(5, 1, 1, policy); - let handler: Arc = Arc::new(MultiInstancePartialFailHandler { - num_outputs_per_task: 1, - num_instances: 3, - }); - let (result, _) = run_scheduled_test(&graph, job_inputs, 2, handler).await; - - assert_eq!(result.tasks_dispatched, 5, "all 5 tasks should complete"); - assert_eq!(result.final_state, JobState::Succeeded); -} - // ============================================================================= // Stage 2 tests: large-scale performance // ============================================================================= @@ -1263,7 +1123,8 @@ async fn test_scheduler_smoke_multi_instance() { /// /// # Metrics captured /// -/// Graph construction, `build_job`, total execution (128 workers), per-task latency (avg/p50/p95/p99). 
+/// Graph construction, `build_job`, total execution (128 workers), per-task latency +/// (avg/p50/p95/p99). /// /// # How to interpret results /// @@ -1287,8 +1148,7 @@ async fn test_scale_10k_independent() { run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; assert_eq!( - result.tasks_dispatched, - NUM_TASKS, + result.tasks_dispatched, NUM_TASKS, "all tasks should have completed" ); assert_eq!(result.final_state, JobState::Succeeded); @@ -1297,7 +1157,12 @@ async fn test_scale_10k_independent() { "no ready-queue propagation for independent tasks" ); - result.report("10k Independent Tasks", NUM_WORKERS, graph_time, build_job_time); + result.report( + "10k Independent Tasks", + NUM_WORKERS, + graph_time, + build_job_time, + ); } /// Large-scale dependency test: 10,000 tasks in a 10-layer × 1000-wide neural-network topology. @@ -1372,107 +1237,6 @@ async fn test_scale_10k_neural_net() { /// Compare against the single-instance 10k independent test. The difference shows the cost of /// instance contention, registration overhead, and rejection handling. /// -/// Run with `cargo test test_scale_10k_multi_instance -- --nocapture` to see timing output. 
-#[tokio::test] -async fn test_scale_10k_multi_instance() { - const NUM_TASKS: usize = 10_000; - const NUM_WORKERS: usize = 128; - - let policy = ExecutionPolicy { - max_num_instances: 3, - max_num_retries: 2, - }; - let graph_start = Instant::now(); - let (graph, job_inputs) = build_flat_graph_with_policy(NUM_TASKS, 2, 1, policy); - let graph_time = graph_start.elapsed(); - - let handler: Arc = Arc::new(MultiInstancePartialFailHandler { - num_outputs_per_task: 1, - num_instances: 3, - }); - let (result, build_job_time) = - run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; - - assert_eq!( - result.tasks_dispatched, - NUM_TASKS, - "all tasks should have completed" - ); - assert_eq!(result.final_state, JobState::Succeeded); - - result.report( - "10k Independent (3 instances, 2 success + 1 fail)", - NUM_WORKERS, - graph_time, - build_job_time, - ); -} - -/// Large-scale multi-instance dependency test: 10,000 tasks in a 10-layer × 1000-wide -/// neural-network topology, 3 instances per task (2 succeed, 1 fails), 128 workers. -/// -/// # Purpose -/// -/// Combines dependency-tracking overhead with multi-instance contention. Each task completion -/// acquires locks on 10 child TCBs while also racing against sibling instances that may -/// complete or fail concurrently. -/// -/// # Graph topology -/// -/// Same as [`test_scale_10k_neural_net`] (10 layers × 1000, fan-in=10, circular connectivity), -/// but with `max_num_instances=3` and `max_num_retries=2`. -/// -/// # How to interpret results -/// -/// Compare against the single-instance neural-net test to isolate multi-instance overhead in a -/// dependency-heavy graph. -/// -/// Run with `cargo test test_scale_10k_neural_net_multi_instance -- --nocapture` to see timing. 
-#[tokio::test] -async fn test_scale_10k_neural_net_multi_instance() { - const NUM_LAYERS: usize = 10; - const WIDTH: usize = 1000; - const FAN_IN: usize = 10; - const NUM_WORKERS: usize = 128; - - let policy = ExecutionPolicy { - max_num_instances: 3, - max_num_retries: 2, - }; - let graph_start = Instant::now(); - let (graph, job_inputs, _layers) = - build_neural_net_graph_with_policy(NUM_LAYERS, WIDTH, FAN_IN, policy); - let graph_time = graph_start.elapsed(); - - assert_eq!(graph.get_num_tasks(), NUM_LAYERS * WIDTH); - - let handler: Arc = Arc::new(MultiInstancePartialFailHandler { - num_outputs_per_task: 1, - num_instances: 3, - }); - let (result, build_job_time) = - run_scheduled_test(&graph, job_inputs, NUM_WORKERS, handler).await; - - assert_eq!( - result.tasks_dispatched, - NUM_LAYERS * WIDTH, - "all tasks should have completed" - ); - assert_eq!(result.final_state, JobState::Succeeded); - assert_eq!( - result.total_tasks_reported_ready, - (NUM_LAYERS - 1) * WIDTH, - "9 layers x 1000 tasks should be reported ready" - ); - - result.report( - "10k Neural Net (10x1000, fan=10, 3 instances, 2 success + 1 fail)", - NUM_WORKERS, - graph_time, - build_job_time, - ); -} - /// Tests that a job fails when all task instances always fail and retries are exhausted. 
/// /// # Purpose @@ -1553,7 +1317,10 @@ async fn test_neural_net_ready_queue_correctness() { let initial_set: HashSet = initial_ready.into_iter().collect(); let expected_layer_0: HashSet = layers[0].iter().copied().collect(); - assert_eq!(initial_set, expected_layer_0, "only layer 0 should be initially ready"); + assert_eq!( + initial_set, expected_layer_0, + "only layer 0 should be initially ready" + ); for (layer_idx, layer) in layers.iter().enumerate() { ready_tasks_ref.lock().await.clear(); @@ -1587,7 +1354,8 @@ async fn test_neural_net_ready_queue_correctness() { let mut expected: Vec = layers[layer_idx + 1].clone(); expected.sort_unstable(); assert_eq!( - reported_ready, expected, + reported_ready, + expected, "layer {layer_idx}: reported ready tasks should match layer {}", layer_idx + 1 );