diff --git a/crates/lib/src/usage_limit/usage_store.rs b/crates/lib/src/usage_limit/usage_store.rs index 4b18d3bc9..ccb6b199d 100644 --- a/crates/lib/src/usage_limit/usage_store.rs +++ b/crates/lib/src/usage_limit/usage_store.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Mutex}; use async_trait::async_trait; use deadpool_redis::{Connection, Pool}; +use once_cell::sync::Lazy; use redis::AsyncCommands; use crate::{error::KoraError, sanitize_error}; @@ -18,10 +19,37 @@ pub trait UsageStore: Send + Sync { /// Get current usage count for a key (returns 0 if not found) async fn get(&self, key: &str) -> Result; + /// Atomic check and increment: check if (current + delta) <= max, and increment if so. + /// Returns true if allowed and incremented, false if denied. + async fn check_and_increment( + &self, + key: &str, + delta: u64, + max: u64, + expiry: Option, + ) -> Result; + /// Clear all usage data (mainly for testing) async fn clear(&self) -> Result<(), KoraError>; } +/// Atomically checks the limit and increments — sets TTL only on first +/// increment. ARGV[3] = 0 is the sentinel for "no expiry". +static CHECK_AND_INCREMENT_SCRIPT: Lazy = Lazy::new(|| { + redis::Script::new( + r" + local current = redis.call('GET', KEYS[1]) + local count = current and tonumber(current) or 0 + if count + tonumber(ARGV[1]) > tonumber(ARGV[2]) then return 0 end + redis.call('INCRBY', KEYS[1], ARGV[1]) + if ARGV[3] ~= '0' and redis.call('TTL', KEYS[1]) < 0 then + redis.call('EXPIREAT', KEYS[1], ARGV[3]) + end + return 1 + ", + ) +}); + /// Redis-based implementation for production pub struct RedisUsageStore { pool: Pool, @@ -90,6 +118,32 @@ impl UsageStore for RedisUsageStore { Ok(count.unwrap_or(0)) } + async fn check_and_increment( + &self, + key: &str, + delta: u64, + max: u64, + expiry: Option, + ) -> Result { + let mut conn = self.get_connection().await?; + + let allowed: i32 = CHECK_AND_INCREMENT_SCRIPT + .key(key) + .arg(delta) + .arg(max) + .arg(expiry.unwrap_or(0)) + .invoke_async(&mut conn) + .await + .map_err(|e| { + KoraError::InternalServerError(sanitize_error!(format!( + "Failed to execute check_and_increment script: {}", + e + ))) + })?; + + Ok(allowed == 1) + } + async fn clear(&self) -> Result<(), KoraError> { let mut conn = self.get_connection().await?; let _: () = conn.flushdb().await.map_err(|e| { @@ -189,6 +243,44 @@ impl UsageStore for InMemoryUsageStore { } } + async fn check_and_increment( + &self, + key: &str, + delta: u64, + max: u64, + expiry: Option, + ) -> Result { + let mut data = self.data.lock().map_err(|e| { + KoraError::InternalServerError(sanitize_error!(format!( + "Failed to lock usage store: {}", + e + ))) + })?; + + let now = Self::current_timestamp(); + let entry = data.entry(key.to_string()).or_insert(UsageEntry { count: 0, expiry: None }); + + if let Some(e) = entry.expiry { + if now >= e { + entry.count = 0; + entry.expiry = None; + } + } + + let new_count = entry.count as u64 + delta; + if new_count > max || new_count > u32::MAX as u64 { + return Ok(false); + } + entry.count = new_count as u32; + if let Some(e) = expiry { + if entry.expiry.is_none() { + entry.expiry = Some(e); + } + } + + Ok(true) + } + async fn clear(&self) -> Result<(), KoraError> { let mut data = self.data.lock().map_err(|e| { KoraError::InternalServerError(sanitize_error!(format!( @@ -242,6 +334,20 @@ impl UsageStore for ErrorUsageStore { } } + async fn check_and_increment( + &self, + _key: &str, + _delta: u64, + _max: u64, + _expiry: Option, + ) -> Result { + if self.should_error_increment { + Err(KoraError::InternalServerError("Redis connection failed".to_string())) + } else { + Ok(true) + } + } + async fn clear(&self) -> Result<(), KoraError> { Ok(()) } diff --git a/crates/lib/src/usage_limit/usage_tracker.rs b/crates/lib/src/usage_limit/usage_tracker.rs index 3ac74babe..108f79b7d 100644 --- a/crates/lib/src/usage_limit/usage_tracker.rs +++ b/crates/lib/src/usage_limit/usage_tracker.rs @@ -241,8 +241,9 @@ impl UsageTracker { SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).map(|d| d.as_secs()).unwrap_or(0) } - /// Check and record usage for a transaction - /// Uses two-phase commit: check all rules first, then increment only if all pass + /// Check and record usage for a transaction. + /// Pre-checks all rules before incrementing; denied requests may still + /// consume quota from earlier rules under concurrent load. async fn check_and_record( &self, ctx: &mut LimiterContext<'_>, @@ -265,74 +266,59 @@ impl UsageTracker { Vec::new() }; - // Build HashSet for O(1) lookup instead of Vec::contains O(n) let ix_idx_set: HashSet = self.instruction_rule_indices.iter().copied().collect(); - // Phase 1: Check all rules first (no incrementing yet) - // Collect rule checks: (key, increment_count, window_seconds) - let mut pending_increments: Vec<(String, u64, Option)> = Vec::new(); + let mut rule_increments = Vec::with_capacity(self.rules.len()); let mut instruction_count_idx = 0; - for (idx, rule) in self.rules.iter().enumerate() { + for (idx, _rule) in self.rules.iter().enumerate() { let increment_count = if ix_idx_set.contains(&idx) { - // Use pre-computed count for instruction rule let count = instruction_counts[instruction_count_idx]; instruction_count_idx += 1; count } else { - // Transaction rules always increment by 1 1 }; + rule_increments.push(increment_count); + } + let mut pending_increments = Vec::with_capacity(self.rules.len()); + + for (idx, rule) in self.rules.iter().enumerate() { + let increment_count = rule_increments[idx]; if increment_count == 0 { continue; } let key = rule.storage_key(&ctx.user_id, ctx.timestamp); + let expiry = + rule.window_seconds().filter(|&w| w > 0).map(|w| (ctx.timestamp / w + 1) * w); + let max = rule.max(); + let description = rule.description(); + // fast path: skip check_and_increment if already over limit let current = self.store.get(&key).await?; - let new_count = current as u64 + increment_count; - if new_count > rule.max() { + if current as u64 + increment_count > max { return Ok(LimiterResult::Denied { reason: format!( "User {} exceeded {} limit: {}/{}", ctx.user_id, - rule.description(), - new_count, - rule.max() + description, + current as u64 + increment_count, + max ), }); } - // Queue for increment (don't increment yet) - pending_increments.push((key, increment_count, rule.window_seconds())); - - log::debug!( - "[rule] User {} {}: {}/{} ({})", - ctx.user_id, - rule.description(), - new_count, - rule.max(), - rule.window_seconds().map_or("lifetime".to_string(), |w| format!("{}s window", w)) - ); + pending_increments.push((key, increment_count, max, expiry, description)); } - for (key, increment_count, window_seconds) in pending_increments { - if let Some(window) = window_seconds.filter(|&w| w > 0) { - // Calculate bucket boundary: key expires at end of current bucket - // bucket = timestamp / window, so bucket_end = (bucket + 1) * window - let expires_at = (ctx.timestamp / window + 1) * window; - // First increment with expiry - self.store.increment_with_expiry(&key, expires_at).await?; - // Subsequent increments without resetting expiry - for _ in 1..increment_count { - self.store.increment(&key).await?; - } - } else { - for _ in 0..increment_count { - self.store.increment(&key).await?; - } + for (key, increment_count, max, expiry, description) in pending_increments { + if !self.store.check_and_increment(&key, increment_count, max, expiry).await? { + return Ok(LimiterResult::Denied { + reason: format!("User {} exceeded {} limit", ctx.user_id, description), + }); } } @@ -497,10 +483,58 @@ mod tests { transaction::TransactionUtil, usage_limit::{InMemoryUsageStore, UsageLimitConfig, UsageLimitRuleConfig}, }; + use async_trait::async_trait; use solana_message::{Message, VersionedMessage}; use solana_sdk::{account::Account, signature::Keypair, signer::Signer}; use std::sync::Arc; + struct ConcurrentMockStore { + inner: InMemoryUsageStore, + } + + impl ConcurrentMockStore { + fn new() -> Self { + Self { inner: InMemoryUsageStore::new() } + } + } + + #[async_trait] + impl UsageStore for ConcurrentMockStore { + async fn increment(&self, key: &str) -> Result { + tokio::task::yield_now().await; + self.inner.increment(key).await + } + + async fn increment_with_expiry( + &self, + key: &str, + expires_at: u64, + ) -> Result { + tokio::task::yield_now().await; + self.inner.increment_with_expiry(key, expires_at).await + } + + async fn get(&self, key: &str) -> Result { + tokio::task::yield_now().await; + self.inner.get(key).await + } + + async fn clear(&self) -> Result<(), KoraError> { + self.inner.clear().await + } + + async fn check_and_increment( + &self, + key: &str, + delta: u64, + max: u64, + expiry: Option, + ) -> Result { + tokio::task::yield_now().await; + self.inner.check_and_increment(key, delta, max, expiry).await + } + } + fn create_test_tracker(max_transactions: u64) -> UsageTracker { let store = Arc::new(InMemoryUsageStore::new()); let config = UsageLimitConfig { @@ -1028,4 +1062,93 @@ mod tests { assert_eq!(result, None); } + + #[tokio::test] + async fn test_concurrent_requests_enforce_limit() { + let max = 5; + let store = Arc::new(ConcurrentMockStore::new()); + let config = UsageLimitConfig { + enabled: true, + cache_url: None, + fallback_if_unavailable: false, + rules: vec![UsageLimitRuleConfig::Transaction { max, window_seconds: None }], + }; + let rules = config.build_rules().unwrap(); + let tracker = Arc::new(UsageTracker::new(true, store, rules, HashSet::new(), false)); + let user_id = "concurrent-user".to_string(); + let mut handles = Vec::new(); + + for _ in 0..10 { + let tracker = tracker.clone(); + let user_id = user_id.clone(); + handles.push(tokio::spawn(async move { + let mut tx = create_mock_resolved_transaction(); + let mut ctx = LimiterContext { + transaction: &mut tx, + user_id, + kora_signer: None, + timestamp: 1000000, + }; + tracker.check_and_record(&mut ctx).await.unwrap() + })); + } + + let mut allowed_count = 0; + for handle in handles { + if matches!(handle.await.unwrap(), LimiterResult::Allowed) { + allowed_count += 1; + } + } + + assert_eq!(allowed_count, max as usize); + } + + #[tokio::test] + async fn test_multi_rule_concurrent_enforces_bottleneck_rule() { + let store = Arc::new(ConcurrentMockStore::new()); + let config = UsageLimitConfig { + enabled: true, + cache_url: None, + fallback_if_unavailable: false, + rules: vec![ + UsageLimitRuleConfig::Transaction { max: 10, window_seconds: None }, + UsageLimitRuleConfig::Transaction { max: 2, window_seconds: Some(60) }, + ], + }; + let rules = config.build_rules().unwrap(); + let tracker = + Arc::new(UsageTracker::new(true, store.clone(), rules, HashSet::new(), false)); + let user_id = "multi-rule-concurrent-user".to_string(); + let mut handles = Vec::new(); + + for _ in 0..10 { + let tracker = tracker.clone(); + let user_id = user_id.clone(); + handles.push(tokio::spawn(async move { + let mut tx = create_mock_resolved_transaction(); + let mut ctx = LimiterContext { + transaction: &mut tx, + user_id, + kora_signer: None, + timestamp: UsageTracker::current_timestamp(), + }; + tracker.check_and_record(&mut ctx).await.unwrap() + })); + } + + let mut allowed_count = 0; + for handle in handles { + if matches!(handle.await.unwrap(), LimiterResult::Allowed) { + allowed_count += 1; + } + } + + assert_eq!(allowed_count, 2); + + let lifetime_key = "kora:tx:multi-rule-concurrent-user"; + let lifetime_count = store.get(lifetime_key).await.unwrap(); + // denied requests may have consumed quota from earlier rules + assert!(lifetime_count <= 10); + assert!(lifetime_count >= allowed_count as u32); + } }