diff --git a/crates/cache/src/lib.rs b/crates/cache/src/lib.rs index 5dd4c36ce..2175ada8c 100644 --- a/crates/cache/src/lib.rs +++ b/crates/cache/src/lib.rs @@ -59,6 +59,15 @@ pub trait Cache: ReadOnlyCache + Send + Sync + std::fmt::Debug { /// Mark a token as registered. async fn mark_token_registered(&self, token_id: TokenId); + /// Rebuild the token-registration registry from committed storage. + /// + /// Used after a rollback: `mark_token_registered` mutates the registry before SQL + /// commits, so a rolled-back chunk can leave the cache claiming a token is + /// registered when the storage row is gone. Resetting from `storage.token_ids()` + /// restores the cache to the last committed state without losing tokens that + /// previous chunks did register successfully. + async fn reset_token_registry(&self) -> Result<(), CacheError>; + /// Clear the balances diff. async fn clear_balances_diff(&self); @@ -134,6 +143,10 @@ impl Cache for InMemoryCache { self.erc_cache.mark_token_registered(token_id).await } + async fn reset_token_registry(&self) -> Result<(), CacheError> { + self.erc_cache.reset_token_registry().await + } + async fn clear_balances_diff(&self) { self.erc_cache.balances_diff.clear(); self.erc_cache.balances_diff.shrink_to_fit(); @@ -252,23 +265,40 @@ pub struct ErcCache { // the registry is a map of token_id to a mutex that is used to track if the token is registered // we need a mutex for the token state to prevent race conditions in case of multiple token regs pub token_id_registry: DashMap, + storage: Arc, } impl ErcCache { pub async fn new(storage: Arc) -> Result { - // read existing token_id's from balances table and cache them - let token_id_registry: HashSet = storage.token_ids().await?; + let token_id_registry = Self::load_token_registry(&*storage).await?; Ok(Self { balances_diff: DashMap::new(), total_supply_diff: DashMap::new(), - token_id_registry: token_id_registry - .into_iter() - .map(|token_id| (token_id, TokenState::Registered)) - .collect(), + token_id_registry, + storage, }) } + async fn load_token_registry( + storage: &dyn ReadOnlyStorage, + ) -> Result, Error> { + let token_ids: HashSet = storage.token_ids().await?; + Ok(token_ids + .into_iter() + .map(|token_id| (token_id, TokenState::Registered)) + .collect()) + } + + pub async fn reset_token_registry(&self) -> Result<(), Error> { + let rebuilt = Self::load_token_registry(&*self.storage).await?; + self.token_id_registry.clear(); + for (token_id, state) in rebuilt { + self.token_id_registry.insert(token_id, state); + } + Ok(()) + } + pub async fn get_token_registration_lock(&self, token_id: TokenId) -> Option>> { let entry = self.token_id_registry.entry(token_id); match entry { @@ -486,3 +516,209 @@ pub fn get_entrypoint_name_from_class(class: &ClassAbi, selector: Felt) -> Optio }), } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + use torii_proto::{ + Achievement, AchievementQuery, Activity, ActivityQuery, AggregationEntry, AggregationQuery, + Contract, ContractQuery, Controller, ControllerQuery, Event, EventQuery, Page, + PlayerAchievementEntry, PlayerAchievementQuery, Query, SearchQuery, SearchResponse, Token, + TokenBalance, TokenBalanceQuery, TokenContract, TokenContractQuery, TokenQuery, + TokenTransfer, TokenTransferQuery, Transaction, TransactionQuery, + }; + use torii_storage::StorageError; + + /// Test stub that only services `token_ids` and `models`. The rest of `ReadOnlyStorage` + /// panics — these tests do not exercise any of those methods. + #[derive(Debug)] + struct StubStorage { + token_ids: Mutex>, + models: Mutex>, + } + + impl StubStorage { + fn new() -> Arc { + Arc::new(Self { + token_ids: Mutex::new(HashSet::new()), + models: Mutex::new(Vec::new()), + }) + } + + fn set_committed_tokens(&self, ids: Vec) { + *self.token_ids.lock().unwrap() = ids.into_iter().collect(); + } + } + + #[async_trait] + impl ReadOnlyStorage for StubStorage { + fn as_read_only(&self) -> &dyn ReadOnlyStorage { + self + } + + async fn model(&self, _world: Felt, _selector: Felt) -> Result { + unimplemented!() + } + + async fn model_optional( + &self, + _world: Felt, + _selector: Felt, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn models( + &self, + _world_addresses: &[Felt], + _selectors: &[Felt], + ) -> Result, StorageError> { + Ok(self.models.lock().unwrap().clone()) + } + + async fn token_ids(&self) -> Result, StorageError> { + Ok(self.token_ids.lock().unwrap().clone()) + } + + async fn controllers( + &self, + _query: &ControllerQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn contracts(&self, _query: &ContractQuery) -> Result, StorageError> { + unimplemented!() + } + async fn tokens(&self, _query: &TokenQuery) -> Result, StorageError> { + unimplemented!() + } + async fn token_balances( + &self, + _query: &TokenBalanceQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn token_contracts( + &self, + _query: &TokenContractQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn token_transfers( + &self, + _query: &TokenTransferQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn transactions( + &self, + _query: &TransactionQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn events(&self, _query: EventQuery) -> Result, StorageError> { + unimplemented!() + } + async fn entities(&self, _query: &Query) -> Result, StorageError> { + unimplemented!() + } + async fn event_messages(&self, _query: &Query) -> Result, StorageError> { + unimplemented!() + } + async fn entity_model( + &self, + _world: Felt, + _entity_id: Felt, + _model_selector: Felt, + ) -> Result, StorageError> { + unimplemented!() + } + async fn aggregations( + &self, + _query: &AggregationQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn activities(&self, _query: &ActivityQuery) -> Result, StorageError> { + unimplemented!() + } + async fn achievements( + &self, + _query: &AchievementQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn player_achievements( + &self, + _query: &PlayerAchievementQuery, + ) -> Result, StorageError> { + unimplemented!() + } + async fn search(&self, _query: &SearchQuery) -> Result { + unimplemented!() + } + } + + use torii_proto::schema::Entity; + + fn token_id(byte: u8) -> TokenId { + TokenId::Contract(Felt::from(byte)) + } + + /// Mirrors the production hazard: `mark_token_registered` runs before SQL commit; + /// rollback drops the SQL but leaves the cache claiming the token is registered. + /// `reset_token_registry` must restore the cache to "what storage actually has". + #[tokio::test] + async fn reset_token_registry_drops_uncommitted_marks() { + let storage = StubStorage::new(); + // T1 was registered in a previous (committed) chunk. + storage.set_committed_tokens(vec![token_id(1)]); + + let cache = InMemoryCache::new(storage.clone()).await.unwrap(); + + // T2 gets marked in this chunk but the SQL row never lands (rollback). + cache.mark_token_registered(token_id(2)).await; + assert!(cache.is_token_registered(&token_id(1)).await); + assert!(cache.is_token_registered(&token_id(2)).await); + + cache.reset_token_registry().await.unwrap(); + + assert!(cache.is_token_registered(&token_id(1)).await); + assert!(!cache.is_token_registered(&token_id(2).clone()).await); + } + + /// `clear_models` is the model-cache half of the rollback recovery. After clearing, + /// `cache.model()` must report missing — callers fall through to storage which + /// reflects the committed (rolled-back) state. + #[tokio::test] + async fn clear_models_empties_the_cache() { + let storage = StubStorage::new(); + let cache = InMemoryCache::new(storage).await.unwrap(); + + let world = Felt::from(0xa); + let selector = Felt::from(0xb); + let model = Model { + world_address: world, + namespace: "ns".into(), + name: "M".into(), + selector, + class_hash: Felt::ZERO, + contract_address: Felt::ZERO, + packed_size: 0, + unpacked_size: 0, + layout: dojo_world::contracts::abigen::model::Layout::Fixed(vec![]), + schema: dojo_types::schema::Ty::Tuple(vec![]), + use_legacy_store: true, + }; + cache.register_model(world, selector, model).await; + assert!(cache.model(world, selector).await.is_ok()); + + cache.clear_models().await; + + assert!(matches!( + cache.model(world, selector).await, + Err(CacheError::ModelNotFound(s)) if s == selector + )); + } +} diff --git a/crates/indexer/engine/src/engine.rs b/crates/indexer/engine/src/engine.rs index ee2aaca11..9e6185c63 100644 --- a/crates/indexer/engine/src/engine.rs +++ b/crates/indexer/engine/src/engine.rs @@ -233,6 +233,9 @@ impl Engine

{ error!(target: LOG_TARGET, error = ?e, "Processing fetched data."); processing_erroring_out = true; self.storage.rollback().await?; + self.cache.clear_balances_diff().await; + self.cache.clear_models().await; + self.cache.reset_token_registry().await?; self.task_manager.clear_tasks(); gauge!("torii_indexer_backoff_delay_seconds", "operation" => "process").set(processing_backoff_delay.as_secs_f64()); sleep(processing_backoff_delay).await; diff --git a/crates/indexer/engine/src/test.rs b/crates/indexer/engine/src/test.rs index ed7f16462..d64d72050 100644 --- a/crates/indexer/engine/src/test.rs +++ b/crates/indexer/engine/src/test.rs @@ -1,10 +1,16 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use async_trait::async_trait; use cainome::cairo_serde::{ByteArray, CairoSerde, ContractAddress}; use dojo_test_utils::migration::copy_spawn_and_move_db; use dojo_test_utils::setup::TestSetup; +use dojo_types::primitive::Primitive; +use dojo_types::schema::{Member, Struct, Ty}; use dojo_utils::{TransactionExt, TransactionWaiter, TxnConfig}; +use dojo_world::contracts::abigen::model::Layout; use dojo_world::contracts::naming::{compute_bytearray_hash, compute_selector_from_names}; use dojo_world::contracts::world::WorldContract; use katana_runner::RunnerCtx; @@ -12,15 +18,17 @@ use num_traits::ToPrimitive; use scarb_interop::Profile; use scarb_metadata_ext::MetadataDojoExt; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; +use sqlx::Row; use starknet::accounts::Account; -use starknet::core::types::{BlockId, BlockTag, Call, Felt, FunctionCall, U256}; +use starknet::core::types::{BlockId, BlockTag, Call, Event, Felt, FunctionCall, U256}; use starknet::core::utils::get_selector_from_name; +use starknet::macros::selector; use starknet::providers::jsonrpc::HttpTransport; -use starknet::providers::{JsonRpcClient, Provider}; +use starknet::providers::{JsonRpcClient, Provider, Url}; use starknet_crypto::poseidon_hash_many; use tempfile::NamedTempFile; use tokio::sync::broadcast; -use torii_cache::{Cache, InMemoryCache}; +use torii_cache::{Cache, InMemoryCache, ReadOnlyCache}; use torii_sqlite::executor::Executor; use torii_sqlite::types::Token; use torii_sqlite::utils::{felt_and_u256_to_sql_string, felt_to_sql_string, u256_to_sql_string}; @@ -30,8 +38,13 @@ use torii_storage::utils::format_world_scoped_id; use torii_storage::Storage; use crate::engine::{Engine, EngineConfig}; -use torii_indexer_fetcher::{Fetcher, FetcherConfig}; -use torii_processors::processors::Processors; +use torii_indexer_fetcher::{ + FetchRangeBlock, FetchRangeResult, FetchResult, FetchTransaction, Fetcher, FetcherConfig, +}; +use torii_processors::error::Error as ProcessorError; +use torii_processors::{ + EventProcessor, EventProcessorContext, Processors, Result as ProcessorResult, +}; pub async fn bootstrap_engine

( db: Sql, @@ -1394,6 +1407,574 @@ async fn count_table(table_name: &str, pool: &sqlx::Pool) -> i64 { count.0 } +fn hashed_task_identifier(from_address: Felt, discriminator: Felt) -> u64 { + let mut hasher = DefaultHasher::new(); + from_address.hash(&mut hasher); + discriminator.hash(&mut hasher); + hasher.finish() +} + +fn schema_has_member(schema: &Ty, member_name: &str) -> bool { + match schema { + Ty::Struct(struct_ty) => struct_ty + .children + .iter() + .any(|member| member.name == member_name), + _ => false, + } +} + +async fn table_has_column( + pool: &sqlx::Pool, + table_name: &str, + column_name: &str, +) -> bool { + let query = format!("PRAGMA table_info([{table_name}])"); + let rows = sqlx::query(&query).fetch_all(pool).await.unwrap(); + + rows.into_iter() + .any(|row| row.try_get::("name").unwrap() == column_name) +} + +#[derive(Debug)] +struct SyntheticModelUpgradeProcessor { + selector: Felt, + added_member: &'static str, +} + +#[async_trait] +impl

EventProcessor

for SyntheticModelUpgradeProcessor +where + P: Provider + Send + Sync + Clone + std::fmt::Debug + 'static, +{ + fn event_key(&self) -> String { + "SyntheticModelUpgrade".to_string() + } + + fn validate(&self, event: &Event) -> bool { + event.keys.len() == 2 && event.keys[1] == self.selector + } + + fn task_identifier(&self, event: &Event) -> u64 { + hashed_task_identifier(event.from_address, event.keys[1]) + } + + async fn process(&self, ctx: &EventProcessorContext

) -> ProcessorResult<()> { + let current = ctx + .storage + .model_optional(ctx.contract_address, self.selector) + .await? + .expect("seeded model must exist"); + + let mut upgraded_schema = current.schema.clone(); + let struct_ty = match &mut upgraded_schema { + Ty::Struct(struct_ty) => struct_ty, + _ => panic!("synthetic test expects a struct model"), + }; + + if struct_ty + .children + .iter() + .all(|member| member.name != self.added_member) + { + struct_ty.children.push(Member { + name: self.added_member.to_string(), + ty: Ty::Primitive(Primitive::U32(None)), + key: false, + }); + } + + let Some(schema_diff) = upgraded_schema.diff(¤t.schema) else { + return Ok(()); + }; + let upgrade_diff = current.schema.diff(&upgraded_schema); + let packed_size = current.packed_size.saturating_add(1); + let unpacked_size = current.unpacked_size.saturating_add(1); + + ctx.storage + .register_model( + ctx.contract_address, + self.selector, + &upgraded_schema, + ¤t.layout, + current.class_hash, + current.contract_address, + packed_size, + unpacked_size, + ctx.block_timestamp, + Some(&schema_diff), + upgrade_diff.as_ref(), + current.use_legacy_store, + ) + .await?; + + ctx.cache + .register_model( + ctx.contract_address, + self.selector, + torii_storage::proto::Model { + world_address: ctx.contract_address, + namespace: current.namespace, + name: current.name, + selector: self.selector, + class_hash: current.class_hash, + contract_address: current.contract_address, + packed_size, + unpacked_size, + layout: current.layout, + schema: upgraded_schema, + use_legacy_store: current.use_legacy_store, + }, + ) + .await; + + Ok(()) + } +} + +#[derive(Debug)] +struct OneShotFailProcessor { + invocations: Arc, +} + +#[async_trait] +impl

EventProcessor

for OneShotFailProcessor +where + P: Provider + Send + Sync + Clone + std::fmt::Debug + 'static, +{ + fn event_key(&self) -> String { + "Fail".to_string() + } + + fn validate(&self, event: &Event) -> bool { + event.keys.len() == 3 + } + + fn task_identifier(&self, event: &Event) -> u64 { + let mut hasher = DefaultHasher::new(); + event.from_address.hash(&mut hasher); + let canonical_pair = std::cmp::max(event.keys[1], event.keys[2]); + canonical_pair.hash(&mut hasher); + hasher.finish() + } + + async fn process(&self, _ctx: &EventProcessorContext

) -> ProcessorResult<()> { + let attempt = self.invocations.fetch_add(1, Ordering::SeqCst); + if attempt == 0 { + Err(ProcessorError::UriMalformed) + } else { + Ok(()) + } + } +} + +#[derive(Debug)] +struct OneShotModelFailProcessor { + invocations: Arc, + selector: Felt, +} + +#[async_trait] +impl

EventProcessor

for OneShotModelFailProcessor +where + P: Provider + Send + Sync + Clone + std::fmt::Debug + 'static, +{ + fn event_key(&self) -> String { + "SyntheticModelFail".to_string() + } + + fn validate(&self, event: &Event) -> bool { + event.keys.len() == 2 && event.keys[1] == self.selector + } + + fn task_identifier(&self, event: &Event) -> u64 { + hashed_task_identifier(event.from_address, event.keys[1]) + } + + async fn process(&self, _ctx: &EventProcessorContext

) -> ProcessorResult<()> { + let attempt = self.invocations.fetch_add(1, Ordering::SeqCst); + if attempt == 0 { + Err(ProcessorError::UriMalformed) + } else { + Ok(()) + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_rollback_replays_model_upgrade_after_cache_reset() { + let provider = Arc::new(JsonRpcClient::new(HttpTransport::new( + Url::parse("http://127.0.0.1:0").unwrap(), + ))); + + let tempfile = NamedTempFile::new().unwrap(); + let path = tempfile.path().to_string_lossy(); + let options = SqliteConnectOptions::from_str(&path) + .unwrap() + .create_if_missing(true); + let pool = SqlitePoolOptions::new() + .connect_with(options) + .await + .unwrap(); + sqlx::migrate!("../../migrations").run(&pool).await.unwrap(); + + let (shutdown_tx, _) = broadcast::channel(1); + let (mut executor, sender) = + Executor::new(pool.clone(), shutdown_tx.clone(), Arc::clone(&provider)) + .await + .unwrap(); + tokio::spawn(async move { + executor.run().await.unwrap(); + }); + + let world_address = Felt::from(0x111_u64); + let model_selector = Felt::from(0x222_u64); + let initial_schema = Ty::Struct(Struct { + name: "ns-RollbackProbe".to_string(), + children: vec![ + Member { + name: "player".to_string(), + ty: Ty::Primitive(Primitive::ContractAddress(None)), + key: true, + }, + Member { + name: "score".to_string(), + ty: Ty::Primitive(Primitive::U32(None)), + key: false, + }, + ], + }); + let layout = Layout::Fixed(vec![]); + + let contracts = vec![ContractDefinition { + address: world_address, + r#type: ContractType::WORLD, + starting_block: None, + }]; + + let db = Sql::new(pool.clone(), sender, &contracts).await.unwrap(); + let cache = Arc::new(InMemoryCache::new(Arc::new(db.clone())).await.unwrap()); + let db = db.with_cache(cache.clone()); + + db.register_model( + world_address, + model_selector, + &initial_schema, + &layout, + Felt::from(0x333_u64), + Felt::from(0x444_u64), + 1, + 2, + 1_715_000_000, + None, + None, + true, + ) + .await + .unwrap(); + db.execute().await.unwrap(); + + let table_name = initial_schema.name(); + let added_member = "rollback_probe"; + assert!(!table_has_column(&pool, &table_name, added_member).await); + assert!(matches!( + cache.model(world_address, model_selector).await, + Err(torii_cache::error::Error::ModelNotFound(selector)) if selector == model_selector + )); + + let fail_invocations = Arc::new(AtomicUsize::new(0)); + let mut processors = Processors::>>::default(); + processors + .event_processors + .get_mut(&ContractType::WORLD) + .unwrap() + .entry(selector!("SyntheticModelUpgrade")) + .or_default() + .push(Box::new(SyntheticModelUpgradeProcessor { + selector: model_selector, + added_member, + })); + processors + .event_processors + .get_mut(&ContractType::WORLD) + .unwrap() + .entry(selector!("SyntheticModelFail")) + .or_default() + .push(Box::new(OneShotModelFailProcessor { + invocations: fail_invocations.clone(), + selector: model_selector, + })); + let processors = Arc::new(processors); + + let upgrade_event = Event { + from_address: world_address, + keys: vec![selector!("SyntheticModelUpgrade"), model_selector], + data: vec![], + }; + let fail_event = Event { + from_address: world_address, + keys: vec![selector!("SyntheticModelFail"), model_selector], + data: vec![], + }; + + let tx_hash = Felt::from(0x999_u64); + let block_number = 1_u64; + let block_timestamp = 1_715_000_123_u64; + let fetch_result = FetchResult { + range: FetchRangeResult { + blocks: std::collections::BTreeMap::from([( + block_number, + FetchRangeBlock { + block_hash: Some(Felt::from(0x1234_u64)), + timestamp: block_timestamp, + transactions: vec![( + tx_hash, + FetchTransaction { + transaction: None, + events: vec![upgrade_event.clone(), fail_event.clone()], + receipt: None, + }, + )] + .into_iter() + .collect(), + }, + )]), + }, + preconfirmed_block: None, + cursors: torii_indexer_fetcher::Cursors { + cursor_transactions: std::collections::HashMap::new(), + cursors: std::collections::HashMap::from([( + world_address, + torii_storage::proto::ContractCursor { + contract_address: world_address, + head: Some(0), + last_block_timestamp: None, + last_pending_block_tx: None, + }, + )]), + }, + }; + let contract_types = std::collections::HashMap::from([(world_address, ContractType::WORLD)]); + + let mut engine = Engine::new( + Arc::new(db.clone()), + cache.clone(), + provider.clone(), + processors.clone(), + EngineConfig::default(), + shutdown_tx.clone(), + ); + + assert!(engine + .process(&fetch_result, &contract_types) + .await + .is_err()); + + let poisoned_model = cache.model(world_address, model_selector).await.unwrap(); + assert!(schema_has_member(&poisoned_model.schema, added_member)); + assert!(!table_has_column(&pool, &table_name, added_member).await); + + db.rollback().await.unwrap(); + cache.clear_balances_diff().await; + cache.clear_models().await; + cache.reset_token_registry().await.unwrap(); + + assert!(matches!( + cache.model(world_address, model_selector).await, + Err(torii_cache::error::Error::ModelNotFound(selector)) if selector == model_selector + )); + + let mut retry_engine = Engine::new( + Arc::new(db.clone()), + cache.clone(), + provider, + processors, + EngineConfig::default(), + shutdown_tx, + ); + retry_engine + .process(&fetch_result, &contract_types) + .await + .unwrap(); + db.execute().await.unwrap(); + + assert!(table_has_column(&pool, &table_name, added_member).await); + let upgraded_model = cache.model(world_address, model_selector).await.unwrap(); + assert!(schema_has_member(&upgraded_model.schema, added_member)); + assert_eq!(fail_invocations.load(Ordering::SeqCst), 2); +} + +#[tokio::test(flavor = "multi_thread")] +#[katana_runner::test(accounts = 10, db_dir = copy_spawn_and_move_db().as_str())] +async fn test_rollback_resets_token_registry_for_retry(sequencer: &RunnerCtx) { + let setup = TestSetup::from_examples("/tmp", "../../../examples/"); + let metadata = setup.load_metadata("spawn-and-move", Profile::DEV); + + let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(sequencer.url()))); + let manifest = metadata.read_dojo_manifest_profile().unwrap().unwrap(); + let token_address = manifest + .external_contracts + .iter() + .find(|c| c.tag == "ns-WoodToken") + .unwrap() + .address; + + let tempfile = NamedTempFile::new().unwrap(); + let path = tempfile.path().to_string_lossy(); + let options = SqliteConnectOptions::from_str(&path) + .unwrap() + .create_if_missing(true); + let pool = SqlitePoolOptions::new() + .connect_with(options) + .await + .unwrap(); + sqlx::migrate!("../../migrations").run(&pool).await.unwrap(); + + let (shutdown_tx, _) = broadcast::channel(1); + let (mut executor, sender) = + Executor::new(pool.clone(), shutdown_tx.clone(), Arc::clone(&provider)) + .await + .unwrap(); + tokio::spawn(async move { + executor.run().await.unwrap(); + }); + + let contracts = vec![ContractDefinition { + address: token_address, + r#type: ContractType::ERC20, + starting_block: None, + }]; + + let db = Sql::new(pool.clone(), sender, &contracts).await.unwrap(); + let cache = Arc::new(InMemoryCache::new(Arc::new(db.clone())).await.unwrap()); + let db = db.with_cache(cache.clone()); + + let fail_invocations = Arc::new(AtomicUsize::new(0)); + let mut processors = Processors::>>::default(); + processors + .event_processors + .get_mut(&ContractType::ERC20) + .unwrap() + .entry(selector!("Fail")) + .or_default() + .push(Box::new(OneShotFailProcessor { + invocations: fail_invocations.clone(), + })); + let processors = Arc::new(processors); + + let transfer_event = Event { + from_address: token_address, + keys: vec![ + selector!("Transfer"), + Felt::from(0xabc_u64), + Felt::from(0xdef_u64), + ], + data: vec![Felt::from(12345_u64), Felt::ZERO], + }; + let fail_event = Event { + from_address: token_address, + keys: vec![ + selector!("Fail"), + Felt::from(0xabc_u64), + Felt::from(0xdef_u64), + ], + data: vec![], + }; + + let tx_hash = Felt::from(0x999_u64); + let block_number = 1_u64; + let block_timestamp = 1_715_000_000_u64; + let fetch_result = FetchResult { + range: FetchRangeResult { + blocks: std::collections::BTreeMap::from([( + block_number, + FetchRangeBlock { + block_hash: Some(Felt::from(0x1234_u64)), + timestamp: block_timestamp, + transactions: vec![( + tx_hash, + FetchTransaction { + transaction: None, + events: vec![transfer_event.clone(), fail_event.clone()], + receipt: None, + }, + )] + .into_iter() + .collect(), + }, + )]), + }, + preconfirmed_block: None, + cursors: torii_indexer_fetcher::Cursors { + cursor_transactions: std::collections::HashMap::new(), + cursors: std::collections::HashMap::from([( + token_address, + torii_storage::proto::ContractCursor { + contract_address: token_address, + head: Some(0), + last_block_timestamp: None, + last_pending_block_tx: None, + }, + )]), + }, + }; + let contract_types = std::collections::HashMap::from([(token_address, ContractType::ERC20)]); + + let mut engine = Engine::new( + Arc::new(db.clone()), + cache.clone(), + provider.clone(), + processors.clone(), + EngineConfig::default(), + shutdown_tx.clone(), + ); + + assert!(engine + .process(&fetch_result, &contract_types) + .await + .is_err()); + assert_eq!(cache.erc_cache.token_id_registry.len(), 1); + assert_eq!( + sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM tokens WHERE contract_address = ?") + .bind(felt_to_sql_string(&token_address)) + .fetch_one(&pool) + .await + .unwrap(), + 0 + ); + + db.rollback().await.unwrap(); + cache.clear_balances_diff().await; + cache.clear_models().await; + cache.reset_token_registry().await.unwrap(); + + assert_eq!(cache.erc_cache.token_id_registry.len(), 0); + + let mut retry_engine = Engine::new( + Arc::new(db.clone()), + cache.clone(), + provider, + processors, + EngineConfig::default(), + shutdown_tx, + ); + retry_engine + .process(&fetch_result, &contract_types) + .await + .unwrap(); + db.execute().await.unwrap(); + + assert_eq!( + sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM tokens WHERE contract_address = ?") + .bind(felt_to_sql_string(&token_address)) + .fetch_one(&pool) + .await + .unwrap(), + 1 + ); + assert_eq!(fail_invocations.load(Ordering::SeqCst), 2); +} + #[tokio::test(flavor = "multi_thread")] #[katana_runner::test(accounts = 10, db_dir = copy_spawn_and_move_db().as_str())] async fn test_erc20_total_supply_tracking(sequencer: &RunnerCtx) { diff --git a/crates/processors/src/error.rs b/crates/processors/src/error.rs index 308a733ab..f0c1bce08 100644 --- a/crates/processors/src/error.rs +++ b/crates/processors/src/error.rs @@ -1,3 +1,4 @@ +use starknet::core::types::Felt; use starknet::core::utils::ParseCairoShortStringError; use thiserror::Error; @@ -13,6 +14,8 @@ pub enum Error { ModelError(#[from] dojo_world::contracts::model::ModelError), #[error(transparent)] PrimitiveError(#[from] dojo_types::primitive::PrimitiveError), + #[error("Model not found: {0:#x}")] + ModelNotFound(Felt), #[error("Model member not found: {0}")] ModelMemberNotFound(String), #[error("Uri is malformed")] diff --git a/crates/processors/src/lib.rs b/crates/processors/src/lib.rs index 1dd04035e..44ccbd286 100644 --- a/crates/processors/src/lib.rs +++ b/crates/processors/src/lib.rs @@ -6,7 +6,9 @@ use starknet::core::types::{Event, Felt, TransactionContent}; use starknet::providers::Provider; use tokio::sync::Semaphore; use torii_cache::{Cache, ContractClassCache}; +use torii_proto::Model; use torii_storage::Storage; +use tracing::debug; mod constants; pub mod erc; @@ -37,6 +39,33 @@ pub struct EventProcessorContext { pub is_at_head: bool, } +impl EventProcessorContext

{ + /// Looks up a model registered for the current contract. + /// + /// `Ok(Some(model))` — model exists; `Ok(None)` — model is unknown but namespace + /// filtering is configured, so the caller should skip silently; + /// `Err(Error::ModelNotFound)` — model is unknown and namespace filtering is off, + /// so the caller should fail loud. + pub async fn resolve_model_or_skip(&self, selector: Felt) -> Result> { + match self + .storage + .model_optional(self.contract_address, selector) + .await? + { + Some(model) => Ok(Some(model)), + None if !self.config.namespaces.is_empty() => { + debug!( + target: "torii::processors::resolve_model", + selector = %selector, + "Model not found in storage, skipping. This can happen if only specific namespaces are indexed." + ); + Ok(None) + } + None => Err(Error::ModelNotFound(selector)), + } + } +} + #[derive(Clone, Debug)] pub struct EventProcessorConfig { pub namespaces: HashSet, diff --git a/crates/processors/src/processors/event_message.rs b/crates/processors/src/processors/event_message.rs index a1786c06c..2169ce841 100644 --- a/crates/processors/src/processors/event_message.rs +++ b/crates/processors/src/processors/event_message.rs @@ -6,8 +6,7 @@ use dojo_world::contracts::abigen::world::Event as WorldEvent; use starknet::core::types::{Event, Felt}; use starknet::providers::Provider; use starknet_crypto::poseidon_hash_many; -use torii_cache::CacheError; -use tracing::{debug, info}; +use tracing::info; use crate::error::Error; use crate::task_manager::TaskId; @@ -81,18 +80,8 @@ where } }; - // silently ignore if the model is not found - let model = match ctx.cache.model(ctx.contract_address, event.selector).await { - Ok(model) => model, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %event.selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => return Err(e.into()), + let Some(model) = ctx.resolve_model_or_skip(event.selector).await? else { + return Ok(()); }; info!( @@ -103,7 +92,7 @@ where "Store event message." ); - let mut keys_and_unpacked = [event.keys.clone(), event.values].concat(); + let mut keys_and_unpacked = [event.keys.as_slice(), event.values.as_slice()].concat(); let mut entity = model.schema.clone(); entity.deserialize(&mut keys_and_unpacked, model.use_legacy_store)?; diff --git a/crates/processors/src/processors/store_del_record.rs b/crates/processors/src/processors/store_del_record.rs index 50c4abed5..0070de033 100644 --- a/crates/processors/src/processors/store_del_record.rs +++ b/crates/processors/src/processors/store_del_record.rs @@ -3,8 +3,7 @@ use dojo_world::contracts::abigen::world::Event as WorldEvent; use starknet::core::types::Event; use starknet::providers::Provider; use std::hash::{DefaultHasher, Hash, Hasher}; -use torii_cache::CacheError; -use tracing::{debug, info}; +use tracing::info; use crate::error::Error; use crate::task_manager::TaskId; @@ -70,19 +69,8 @@ where } }; - // If the model does not exist, silently ignore it. - // This can happen if only specific namespaces are indexed. - let model = match ctx.cache.model(ctx.contract_address, event.selector).await { - Ok(m) => m, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %event.selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => return Err(e.into()), + let Some(model) = ctx.resolve_model_or_skip(event.selector).await? else { + return Ok(()); }; info!( diff --git a/crates/processors/src/processors/store_set_record.rs b/crates/processors/src/processors/store_set_record.rs index ce978f885..151f1484e 100644 --- a/crates/processors/src/processors/store_set_record.rs +++ b/crates/processors/src/processors/store_set_record.rs @@ -4,8 +4,7 @@ use async_trait::async_trait; use dojo_world::contracts::abigen::world::Event as WorldEvent; use starknet::core::types::Event; use starknet::providers::Provider; -use torii_cache::CacheError; -use tracing::{debug, info}; +use tracing::info; use crate::error::Error; use crate::task_manager::TaskId; @@ -72,21 +71,8 @@ where } }; - // If the model does not exist, silently ignore it. - // This can happen if only specific namespaces are indexed. - let model = match ctx.cache.model(ctx.contract_address, event.selector).await { - Ok(m) => m, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %event.selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => { - return Err(e.into()); - } + let Some(model) = ctx.resolve_model_or_skip(event.selector).await? else { + return Ok(()); }; info!( diff --git a/crates/processors/src/processors/store_update_member.rs b/crates/processors/src/processors/store_update_member.rs index e9b679849..623a9a320 100644 --- a/crates/processors/src/processors/store_update_member.rs +++ b/crates/processors/src/processors/store_update_member.rs @@ -6,8 +6,7 @@ use dojo_world::contracts::abigen::world::Event as WorldEvent; use starknet::core::types::Event; use starknet::core::utils::get_selector_from_name; use starknet::providers::Provider; -use torii_cache::CacheError; -use tracing::{debug, info}; +use tracing::info; use crate::error::Error; use crate::task_manager::TaskId; @@ -83,21 +82,8 @@ where let entity_id = event.entity_id; let member_selector = event.member_selector; - // If the model does not exist, silently ignore it. - // This can happen if only specific namespaces are indexed. - let model = match ctx.cache.model(ctx.contract_address, model_selector).await { - Ok(m) => m, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %model_selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => { - return Err(e.into()); - } + let Some(model) = ctx.resolve_model_or_skip(model_selector).await? else { + return Ok(()); }; let schema = model.schema; diff --git a/crates/processors/src/processors/store_update_record.rs b/crates/processors/src/processors/store_update_record.rs index 573504e77..591206ee9 100644 --- a/crates/processors/src/processors/store_update_record.rs +++ b/crates/processors/src/processors/store_update_record.rs @@ -5,8 +5,7 @@ use dojo_types::schema::Ty; use dojo_world::contracts::abigen::world::Event as WorldEvent; use starknet::core::types::Event; use starknet::providers::Provider; -use torii_cache::CacheError; -use tracing::{debug, info}; +use tracing::info; use crate::task_manager::TaskId; use crate::{EventProcessor, EventProcessorConfig, EventProcessorContext}; @@ -78,21 +77,8 @@ where let model_selector = event.selector; let entity_id = event.entity_id; - // If the model does not exist, silently ignore it. - // This can happen if only specific namespaces are indexed. - let model = match ctx.cache.model(ctx.contract_address, event.selector).await { - Ok(m) => m, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %event.selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => { - return Err(e.into()); - } + let Some(model) = ctx.resolve_model_or_skip(event.selector).await? else { + return Ok(()); }; info!( diff --git a/crates/processors/src/processors/upgrade_event.rs b/crates/processors/src/processors/upgrade_event.rs index f56fa35ed..c085ac61f 100644 --- a/crates/processors/src/processors/upgrade_event.rs +++ b/crates/processors/src/processors/upgrade_event.rs @@ -7,7 +7,6 @@ use dojo_world::contracts::model::{ModelRPCReader, ModelReader}; use dojo_world::contracts::WorldContractReader; use starknet::core::types::{BlockId, Event}; use starknet::providers::Provider; -use torii_cache::CacheError; use torii_proto::Model; use tracing::{debug, info}; @@ -60,19 +59,8 @@ where // Called model here by language, but it's an event. Torii rework will make clear // distinction. - // If the model does not exist, silently ignore it. - // This can happen if only specific namespaces are indexed. - let model = match ctx.cache.model(ctx.contract_address, event.selector).await { - Ok(m) => m, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %event.selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => return Err(e.into()), + let Some(model) = ctx.resolve_model_or_skip(event.selector).await? else { + return Ok(()); }; let name = model.name; let namespace = model.namespace; diff --git a/crates/processors/src/processors/upgrade_model.rs b/crates/processors/src/processors/upgrade_model.rs index 1b87b1551..bcc42bf51 100644 --- a/crates/processors/src/processors/upgrade_model.rs +++ b/crates/processors/src/processors/upgrade_model.rs @@ -8,7 +8,6 @@ use dojo_world::contracts::model::{ModelError, ModelRPCReader, ModelReader}; use dojo_world::contracts::WorldContractReader; use starknet::core::types::{BlockId, Event, StarknetError}; use starknet::providers::{Provider, ProviderError}; -use torii_cache::CacheError; use torii_proto::Model; use tracing::{debug, info}; @@ -58,19 +57,8 @@ where } }; - // If the model does not exist, silently ignore it. - // This can happen if only specific namespaces are indexed. - let model = match ctx.cache.model(ctx.contract_address, event.selector).await { - Ok(m) => m, - Err(CacheError::ModelNotFound(_)) if !ctx.config.namespaces.is_empty() => { - debug!( - target: LOG_TARGET, - selector = %event.selector, - "Model not found in cache, skipping. This can happen if only specific namespaces are indexed." - ); - return Ok(()); - } - Err(e) => return Err(e.into()), + let Some(model) = ctx.resolve_model_or_skip(event.selector).await? else { + return Ok(()); }; let name = model.name; diff --git a/crates/processors/src/task_manager.rs b/crates/processors/src/task_manager.rs index 9d67451df..b317b7f22 100644 --- a/crates/processors/src/task_manager.rs +++ b/crates/processors/src/task_manager.rs @@ -101,6 +101,19 @@ impl TaskManager< task_data.events.push(parallelized_event); } } + + if let Err(e) = self + .task_network + .add_dependencies(task_identifier, dependencies.clone()) + { + error!( + target: LOG_TARGET, + error = ?e, + task_id = %task_identifier, + dependencies = ?dependencies, + "Failed to add dependencies to existing task." + ); + } } else { let task_data = match parallelized_event.indexing_mode { IndexingMode::Latest(event_key) => TaskData { diff --git a/crates/sqlite/sqlite/src/storage.rs b/crates/sqlite/sqlite/src/storage.rs index ecf4bf84e..e01df0489 100644 --- a/crates/sqlite/sqlite/src/storage.rs +++ b/crates/sqlite/sqlite/src/storage.rs @@ -54,9 +54,20 @@ impl ReadOnlyStorage for Sql { /// Returns the model metadata for the storage. async fn model(&self, world_address: Felt, selector: Felt) -> Result { + match self.model_optional(world_address, selector).await? { + Some(model) => Ok(model), + None => Err(Box::new(sqlx::Error::RowNotFound)), + } + } + + async fn model_optional( + &self, + world_address: Felt, + selector: Felt, + ) -> Result, StorageError> { if let Some(cache) = &self.cache { if let Ok(model) = cache.model(world_address, selector).await { - return Ok(model); + return Ok(Some(model)); } else { warn!( target: LOG_TARGET, @@ -66,10 +77,14 @@ impl ReadOnlyStorage for Sql { } } - let model = sqlx::query_as::<_, SQLModel>("SELECT * FROM models WHERE id = ?") + let row = sqlx::query_as::<_, SQLModel>("SELECT * FROM models WHERE id = ?") .bind(format_world_scoped_id(&world_address, &selector)) - .fetch_one(&self.pool) + .fetch_optional(&self.pool) .await?; + + let Some(model) = row else { + return Ok(None); + }; let model: torii_proto::Model = model.into(); // Update cache to prevent repeated cache misses @@ -79,7 +94,7 @@ impl ReadOnlyStorage for Sql { .await; } - Ok(model) + Ok(Some(model)) } /// Returns the models for the storage. @@ -2599,3 +2614,252 @@ impl Sql { Ok(unique_models) } } + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::str::FromStr; + use std::sync::Arc; + + use async_trait::async_trait; + use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; + use tempfile::TempDir; + use tokio::sync::mpsc::unbounded_channel; + use torii_cache::{CacheError, InMemoryCache, ReadOnlyCache}; + use torii_proto::{ + Achievement, AchievementQuery, PlayerAchievementEntry, PlayerAchievementQuery, + }; + use torii_storage::utils::format_world_scoped_id; + + use super::*; + use crate::SqlConfig; + + #[derive(Debug)] + struct EmptyStorage; + + #[async_trait] + impl ReadOnlyStorage for EmptyStorage { + fn as_read_only(&self) -> &dyn ReadOnlyStorage { + self + } + + async fn model(&self, _world_address: Felt, _model: Felt) -> Result { + Err(Box::new(sqlx::Error::RowNotFound)) + } + + async fn model_optional( + &self, + _world_address: Felt, + _model: Felt, + ) -> Result, StorageError> { + Ok(None) + } + + async fn models( + &self, + _world_addresses: &[Felt], + _selectors: &[Felt], + ) -> Result, StorageError> { + Ok(vec![]) + } + + async fn token_ids(&self) -> Result, StorageError> { + Ok(HashSet::new()) + } + + async fn controllers( + &self, + _query: &ControllerQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn contracts(&self, _query: &ContractQuery) -> Result, StorageError> { + unimplemented!() + } + + async fn tokens(&self, _query: &TokenQuery) -> Result, StorageError> { + unimplemented!() + } + + async fn token_balances( + &self, + _query: &TokenBalanceQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn token_contracts( + &self, + _query: &TokenContractQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn token_transfers( + &self, + _query: &TokenTransferQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn transactions( + &self, + _query: &TransactionQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn events(&self, _query: EventQuery) -> Result, StorageError> { + unimplemented!() + } + + async fn entities(&self, _query: &Query) -> Result, StorageError> { + unimplemented!() + } + + async fn event_messages(&self, _query: &Query) -> Result, StorageError> { + unimplemented!() + } + + async fn entity_model( + &self, + _world_address: Felt, + _entity_id: Felt, + _model_selector: Felt, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn aggregations( + &self, + _query: &AggregationQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn activities(&self, _query: &ActivityQuery) -> Result, StorageError> { + unimplemented!() + } + + async fn achievements( + &self, + _query: &AchievementQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn player_achievements( + &self, + _query: &PlayerAchievementQuery, + ) -> Result, StorageError> { + unimplemented!() + } + + async fn search(&self, _query: &SearchQuery) -> Result { + unimplemented!() + } + } + + async fn setup_sql() -> (TempDir, sqlx::Pool, Sql) { + let temp_dir = tempfile::tempdir().unwrap(); + let db_path = temp_dir.path().join("storage-tests.db"); + let options = SqliteConnectOptions::from_str(db_path.to_str().unwrap()) + .unwrap() + .create_if_missing(true); + let pool = SqlitePoolOptions::new() + .max_connections(1) + .connect_with(options) + .await + .unwrap(); + sqlx::migrate!("../../migrations").run(&pool).await.unwrap(); + + let (executor, _rx) = unbounded_channel(); + let sql = Sql { + pool: pool.clone(), + executor, + config: SqlConfig::default(), + cache: None, + }; + + (temp_dir, pool, sql) + } + + async fn insert_model_row( + pool: &sqlx::Pool, + world_address: Felt, + selector: Felt, + ) { + sqlx::query( + "INSERT INTO models (id, world_address, model_selector, namespace, name, class_hash, contract_address, layout, legacy_store, schema, packed_size, unpacked_size, executed_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ) + .bind(format_world_scoped_id(&world_address, &selector)) + .bind(felt_to_sql_string(&world_address)) + .bind(felt_to_sql_string(&selector)) + .bind("ns") + .bind("Model") + .bind(felt_to_sql_string(&Felt::from(0x123u64))) + .bind(felt_to_sql_string(&Felt::from(0x456u64))) + .bind(serde_json::to_string(&Layout::Fixed(vec![])).unwrap()) + .bind(true) + .bind(serde_json::to_string(&Ty::Tuple(vec![])).unwrap()) + .bind(0_i64) + .bind(0_i64) + .bind("2026-05-01T00:00:00Z") + .execute(pool) + .await + .unwrap(); + } + + #[tokio::test] + async fn model_optional_returns_none_when_model_is_missing() { + let (_temp_dir, _pool, sql) = setup_sql().await; + let world = Felt::from(0xa_u8); + let selector = Felt::from(0xb_u8); + + assert!(sql.model_optional(world, selector).await.unwrap().is_none()); + + let err = sql.model(world, selector).await.unwrap_err(); + assert!(matches!( + err.downcast_ref::(), + Some(sqlx::Error::RowNotFound) + )); + } + + #[tokio::test] + async fn model_optional_repopulates_cache_after_database_fallback() { + let (_temp_dir, pool, sql_no_cache) = setup_sql().await; + let cache = Arc::new(InMemoryCache::new(Arc::new(EmptyStorage)).await.unwrap()); + let world = Felt::from(0xc_u8); + let selector = Felt::from(0xd_u8); + + insert_model_row(&pool, world, selector).await; + + assert!(matches!( + cache.model(world, selector).await, + Err(CacheError::ModelNotFound(s)) if s == selector + )); + + let sql = sql_no_cache.with_cache(cache.clone()); + let fetched = sql.model_optional(world, selector).await.unwrap().unwrap(); + assert_eq!(fetched.world_address, world); + assert_eq!(fetched.selector, selector); + assert_eq!(fetched.namespace, "ns"); + assert_eq!(fetched.name, "Model"); + + let cached = cache.model(world, selector).await.unwrap(); + assert_eq!(cached.selector, selector); + assert_eq!(cached.name, "Model"); + + sqlx::query("DELETE FROM models WHERE id = ?") + .bind(format_world_scoped_id(&world, &selector)) + .execute(&pool) + .await + .unwrap(); + + let cached_after_delete = sql.model_optional(world, selector).await.unwrap().unwrap(); + assert_eq!(cached_after_delete.selector, selector); + assert_eq!(cached_after_delete.name, "Model"); + } +} diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index 3f892ee53..867920cdc 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -34,6 +34,18 @@ pub trait ReadOnlyStorage: Send + Sync + Debug { /// Returns the model metadata for the storage. async fn model(&self, world_address: Felt, model: Felt) -> Result; + /// Returns the model metadata if it is registered in the storage, or `Ok(None)` + /// when no row exists for `(world_address, selector)`. + /// + /// Only "no such row" returns `Ok(None)` — any other failure (DB error, decode error, + /// connection issue, etc.) propagates as `Err`. Callers that need to distinguish + /// "model is unknown" from "lookup failed" should prefer this method over [`model`]. + async fn model_optional( + &self, + world_address: Felt, + model: Felt, + ) -> Result, StorageError>; + /// Returns the models for the storage. /// If world_addresses is empty, returns models from all worlds. /// If selectors is empty, returns all models from the specified worlds. diff --git a/crates/task-network/src/lib.rs b/crates/task-network/src/lib.rs index dd5da232b..2158d8eff 100644 --- a/crates/task-network/src/lib.rs +++ b/crates/task-network/src/lib.rs @@ -4,6 +4,7 @@ pub use error::TaskNetworkError; pub type Result = std::result::Result; +use std::collections::{HashMap, HashSet}; use std::future::Future; use std::hash::Hash; use std::sync::Arc; @@ -22,6 +23,7 @@ where T: Clone + Send + Sync + 'static, { tasks: AcyclicDigraphMap, + pending_dependents: HashMap>, semaphore: Arc, } @@ -33,6 +35,7 @@ where pub fn new(max_concurrent_tasks: usize) -> Self { Self { tasks: AcyclicDigraphMap::new(), + pending_dependents: HashMap::new(), semaphore: Arc::new(Semaphore::new(max_concurrent_tasks)), } } @@ -52,17 +55,15 @@ where .add_node(task_id.clone(), task) .map_err(TaskNetworkError::GraphError)?; + self.resolve_pending_dependents(&task_id)?; + self.add_dependencies(task_id, dependencies)?; + + Ok(()) + } + + pub fn add_dependencies(&mut self, task_id: K, dependencies: Vec) -> Result<()> { for dep in dependencies { - if self.tasks.contains_key(&dep) { - self.add_dependency(dep, task_id.clone())?; - } else { - debug!( - target: LOG_TARGET, - task_id = ?task_id, - dependency = ?dep, - "Ignoring non-existent dependency." - ); - } + self.add_dependency_or_defer(dep, task_id.clone())?; } Ok(()) @@ -74,6 +75,35 @@ where .map_err(TaskNetworkError::GraphError) } + fn add_dependency_or_defer(&mut self, from: K, to: K) -> Result<()> { + if self.tasks.contains_key(&from) { + self.add_dependency(from, to) + } else { + debug!( + target: LOG_TARGET, + task_id = ?to, + dependency = ?from, + "Deferring dependency until prerequisite task exists." + ); + self.pending_dependents.entry(from).or_default().insert(to); + Ok(()) + } + } + + fn resolve_pending_dependents(&mut self, task_id: &K) -> Result<()> { + let Some(dependents) = self.pending_dependents.remove(task_id) else { + return Ok(()); + }; + + for dependent in dependents { + if self.tasks.contains_key(&dependent) { + self.add_dependency(task_id.clone(), dependent)?; + } + } + + Ok(()) + } + pub async fn process_tasks(&mut self, task_handler: F) -> Result<()> where F: Fn(K, T) -> Fut + Clone + Send + Sync + 'static, @@ -142,6 +172,7 @@ where } self.tasks.clear(); + self.pending_dependents.clear(); Ok(()) } @@ -160,6 +191,7 @@ where pub fn clear(&mut self) { self.tasks.clear(); + self.pending_dependents.clear(); } } @@ -232,6 +264,61 @@ mod tests { assert_eq!(result[2], 3); } + #[tokio::test] + async fn test_late_dependency_becomes_active() { + let mut manager = TaskNetwork::::new(4); + + manager + .add_task_with_dependencies(1, "Task 1".to_string(), vec![99]) + .unwrap(); + manager.add_task(99, "Task 99".to_string()).unwrap(); + + let executed = Arc::new(tokio::sync::Mutex::new(Vec::new())); + + let executed_clone = executed.clone(); + manager + .process_tasks(move |id, _task| { + let executed = executed_clone.clone(); + async move { + let mut locked = executed.lock().await; + locked.push(id); + Ok::<_, std::io::Error>(()) + } + }) + .await + .unwrap(); + + let result = executed.lock().await; + assert_eq!(&*result, &[99, 1]); + } + + #[tokio::test] + async fn test_add_dependencies_to_existing_task() { + let mut manager = TaskNetwork::::new(4); + + manager.add_task(1, "Task 1".to_string()).unwrap(); + manager.add_task(2, "Task 2".to_string()).unwrap(); + manager.add_dependencies(1, vec![2]).unwrap(); + + let executed = Arc::new(tokio::sync::Mutex::new(Vec::new())); + + let executed_clone = executed.clone(); + manager + .process_tasks(move |id, _task| { + let executed = executed_clone.clone(); + async move { + let mut locked = executed.lock().await; + locked.push(id); + Ok::<_, std::io::Error>(()) + } + }) + .await + .unwrap(); + + let result = executed.lock().await; + assert_eq!(&*result, &[2, 1]); + } + #[tokio::test] async fn test_dependency_ordering() { let mut manager = TaskNetwork::::new(4);