Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions crates/lib/src/usage_limit/usage_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Mutex};

use async_trait::async_trait;
use deadpool_redis::{Connection, Pool};
use once_cell::sync::Lazy;
use redis::AsyncCommands;

use crate::{error::KoraError, sanitize_error};
Expand All @@ -18,10 +19,37 @@ pub trait UsageStore: Send + Sync {
/// Get current usage count for a key (returns 0 if not found)
async fn get(&self, key: &str) -> Result<u32, KoraError>;

/// Atomic check and increment: check if (current + delta) <= max, and increment if so.
/// Returns true if allowed and incremented, false if denied.
async fn check_and_increment(
&self,
key: &str,
delta: u64,
max: u64,
expiry: Option<u64>,
) -> Result<bool, KoraError>;

/// Clear all usage data (mainly for testing)
async fn clear(&self) -> Result<(), KoraError>;
}

/// Atomically checks the limit and increments — sets TTL only on first
/// increment. ARGV[3] = 0 is the sentinel for "no expiry".
static CHECK_AND_INCREMENT_SCRIPT: Lazy<redis::Script> = Lazy::new(|| {
redis::Script::new(
r"
local current = redis.call('GET', KEYS[1])
local count = current and tonumber(current) or 0
if count + tonumber(ARGV[1]) > tonumber(ARGV[2]) then return 0 end
redis.call('INCRBY', KEYS[1], ARGV[1])
if ARGV[3] ~= '0' and redis.call('TTL', KEYS[1]) < 0 then
Comment thread
raushan728 marked this conversation as resolved.
redis.call('EXPIREAT', KEYS[1], ARGV[3])
end
return 1
",
)
});

/// Redis-based implementation for production
pub struct RedisUsageStore {
pool: Pool,
Expand Down Expand Up @@ -90,6 +118,32 @@ impl UsageStore for RedisUsageStore {
Ok(count.unwrap_or(0))
}

async fn check_and_increment(
&self,
key: &str,
delta: u64,
max: u64,
expiry: Option<u64>,
) -> Result<bool, KoraError> {
let mut conn = self.get_connection().await?;

let allowed: i32 = CHECK_AND_INCREMENT_SCRIPT
.key(key)
.arg(delta)
.arg(max)
.arg(expiry.unwrap_or(0))
.invoke_async(&mut conn)
.await
.map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to execute check_and_increment script: {}",
e
)))
})?;

Ok(allowed == 1)
}

async fn clear(&self) -> Result<(), KoraError> {
let mut conn = self.get_connection().await?;
let _: () = conn.flushdb().await.map_err(|e| {
Expand Down Expand Up @@ -189,6 +243,44 @@ impl UsageStore for InMemoryUsageStore {
}
}

async fn check_and_increment(
&self,
key: &str,
delta: u64,
max: u64,
expiry: Option<u64>,
) -> Result<bool, KoraError> {
let mut data = self.data.lock().map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to lock usage store: {}",
e
)))
})?;

let now = Self::current_timestamp();
let entry = data.entry(key.to_string()).or_insert(UsageEntry { count: 0, expiry: None });

if let Some(e) = entry.expiry {
if now >= e {
entry.count = 0;
entry.expiry = None;
}
}

let new_count = entry.count as u64 + delta;
if new_count > max || new_count > u32::MAX as u64 {
return Ok(false);
}
entry.count = new_count as u32;
if let Some(e) = expiry {
if entry.expiry.is_none() {
entry.expiry = Some(e);
}
}

Ok(true)
}

async fn clear(&self) -> Result<(), KoraError> {
let mut data = self.data.lock().map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
Expand Down Expand Up @@ -242,6 +334,20 @@ impl UsageStore for ErrorUsageStore {
}
}

async fn check_and_increment(
&self,
_key: &str,
_delta: u64,
_max: u64,
_expiry: Option<u64>,
) -> Result<bool, KoraError> {
if self.should_error_increment {
Err(KoraError::InternalServerError("Redis connection failed".to_string()))
} else {
Ok(true)
}
}

async fn clear(&self) -> Result<(), KoraError> {
Ok(())
}
Expand Down
Loading
Loading