diff --git a/Cargo.lock b/Cargo.lock index 66610c7..1b41e0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + [[package]] name = "anyhow" version = "1.0.95" @@ -214,6 +220,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "either" version = "1.10.0" @@ -240,6 +252,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8878864ba14bb86e818a412bfd6f18f9eabd4ec0f008a28e8f7eb61db532fcf9" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + [[package]] name = "gimli" version = "0.29.0" @@ -411,6 +438,32 @@ dependencies = [ "adler", ] +[[package]] +name = "mockall" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f58d964098a5f9c6b63d0798e5372fd04708193510a7af313c22e9f29b7b620b" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca41ce716dda6a9be188b385aa78ee5260fc25cd3802cb2a8afdc6afbe6b6dbf" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote 1.0.36", + "syn 2.0.58", +] + [[package]] name = "nix" version = "0.26.4" @@ -517,6 +570,32 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "predicates" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ada8f2932f28a27ee7b70dd6c1c39ea0675c55a36879ab92f3a715eaa1e63cfe" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cad38746f3166b4031b1a0d39ad9f954dd291e7854fcc0eed52ee41a0b50d144" + +[[package]] +name = "predicates-tree" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0de1b847b39c8131db0467e9df1ff60e6d0562ab8e9a16e568ad0fdb372e2f2" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.2.17" @@ -751,6 +830,12 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "tinyvec" version = "1.6.0" @@ -822,6 +907,7 @@ dependencies = [ "libc", "linkme", "log", + "mockall", "nix", "num-traits 0.2.19", "paste", diff --git a/Cargo.toml b/Cargo.toml index ada6534..d7529fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,16 +173,18 @@ cfg-if = "1" valkey-module-macros-internals = { path = "valkeymodule-rs-macros-internals", version = "0.1.4"} log = "0.4" paste = "1.0.15" +mockall = { version = "0.14.0", optional = true } [dev-dependencies] anyhow = "1" redis = "0.28" lazy_static = "1" valkey-module-macros = { path = "valkeymodule-rs-macros", version = "0.1.4" } -valkey-module = { path = "./", default-features = false, features = ["min-valkey-compatibility-version-8-0", "min-redis-compatibility-version-7-2"] } +valkey-module = { path = "./", default-features = false, features = ["min-valkey-compatibility-version-8-0", "min-redis-compatibility-version-7-2", "test-mocks", "enable-system-alloc"] } cron = "0.15.0" chrono = "0.4.41" dashmap = "6.1.0" +mockall = { version = "0.14.0"} [build-dependencies] bindgen = "0.70" @@ -199,3 +201,5 @@ min-redis-compatibility-version-6-0 = [] enable-system-alloc = [] # this is to indicate the Module wants to use RedisModule APIs for calls use-redismodule-api = [] +# exposes MockContext (generated via mockall) for downstream unit tests +test-mocks = ["dep:mockall"] diff --git a/README.md b/README.md index 84c61d8..a5abd2b 100644 --- a/README.md +++ b/README.md @@ -61,3 +61,42 @@ default = [] ``` cargo build --release --features use-redismodule-api ``` + +3. Mock contexts for unit tests + +`Context`, `CommandFilterCtx`, and `InfoContext` are thin wrappers around raw pointers that Valkey hands to the module at runtime. To unit-test module logic without a live Valkey server, `valkey-module` exposes three trait abstractions — `ContextTrait`, `CommandFilterCtxTrait`, `InfoContextTrait` — each implemented for the concrete wrapper, plus `mockall`-generated mocks behind the `test-mocks` feature. + +The traits are always available; only the `Mock*` types require the feature. Add `valkey-module` as a `dev-dependency` with `test-mocks` enabled: + +```toml +[dev-dependencies] +valkey-module = { version = "...", features = ["test-mocks"] } +mockall = "0.14" +``` + +Write your command / filter / info handler against the trait (`&impl ContextTrait`) instead of the concrete `&Context`. Monomorphization still produces a `fn(&Context, ...)` for the `valkey_module!` macro to register. + +```rust +use valkey_module::{ContextTrait, ValkeyResult, ValkeyString, ValkeyValue}; + +fn get_client_id(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { + Ok((ctx.get_client_id() as i64).into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::MockContext; + + #[test] + fn returns_client_id_from_context() { + let mut ctx = MockContext::new(); + ctx.expect_get_client_id().times(1).returning(|| 42); + + let reply = get_client_id(&ctx, vec![ValkeyString::create_for_test("")]).unwrap(); + assert_eq!(reply, ValkeyValue::Integer(42)); + } +} +``` + +`MockCommandFilterCtx` and `MockInfoContext` follow the same pattern. See `examples/client.rs`, `examples/preload.rs`, `examples/server_events.rs`, and `examples/info_handler_struct.rs` for full working tests. diff --git a/examples/client.rs b/examples/client.rs index 0106341..3660472 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,14 +1,15 @@ use valkey_module::alloc::ValkeyAlloc; use valkey_module::{ - valkey_module, Context, NextArg, Status, ValkeyError, ValkeyResult, ValkeyString, ValkeyValue, + valkey_module, ContextTrait, NextArg, Status, ValkeyError, ValkeyResult, ValkeyString, + ValkeyValue, }; -fn get_client_id(ctx: &Context, _args: Vec) -> ValkeyResult { +fn get_client_id(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { let client_id = ctx.get_client_id(); Ok((client_id as i64).into()) } -fn get_client_name(ctx: &Context, _args: Vec) -> ValkeyResult { +fn get_client_name(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { // test for invalid client_id match ctx.get_client_name_by_id(0) { Ok(tmp) => ctx.log_notice(&format!( @@ -21,7 +22,7 @@ fn get_client_name(ctx: &Context, _args: Vec) -> ValkeyResult { Ok(ValkeyValue::from(client_name.to_string())) } -fn get_client_username(ctx: &Context, _args: Vec) -> ValkeyResult { +fn get_client_username(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { // test for invalid client_id match ctx.get_client_username_by_id(0) { Ok(tmp) => ctx.log_notice(&format!( @@ -34,7 +35,7 @@ fn get_client_username(ctx: &Context, _args: Vec) -> ValkeyResult Ok(ValkeyValue::from(client_username.to_string())) } -fn set_client_name(ctx: &Context, args: Vec) -> ValkeyResult { +fn set_client_name(ctx: &impl ContextTrait, args: Vec) -> ValkeyResult { if args.len() != 2 { return Err(ValkeyError::WrongArity); } @@ -47,7 +48,7 @@ fn set_client_name(ctx: &Context, args: Vec) -> ValkeyResult { Ok(ValkeyValue::Integer(resp2 as i64)) } -fn get_client_cert(ctx: &Context, _args: Vec) -> ValkeyResult { +fn get_client_cert(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { // unless connection is made with cert, this will return Err, so just log it and return nothing match ctx.get_client_cert() { Ok(tmp) => ctx.log_notice(&format!("client_cert: {:?}", tmp.to_string())), @@ -56,7 +57,7 @@ fn get_client_cert(ctx: &Context, _args: Vec) -> ValkeyResult { Ok("".into()) } -fn get_client_info(ctx: &Context, _args: Vec) -> ValkeyResult { +fn get_client_info(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { // test for invalid client_id let client_info_by_id = ctx.get_client_info_by_id(0); ctx.log_notice(&format!( @@ -69,7 +70,7 @@ fn get_client_info(ctx: &Context, _args: Vec) -> ValkeyResult { Ok(ValkeyValue::from(client_info.version.to_string())) } -fn get_client_ip(ctx: &Context, _args: Vec) -> ValkeyResult { +fn get_client_ip(ctx: &impl ContextTrait, _args: Vec) -> ValkeyResult { // test for invalid client_id let client_ip_by_id = ctx.get_client_ip_by_id(0); ctx.log_notice(&format!( @@ -79,7 +80,7 @@ fn get_client_ip(ctx: &Context, _args: Vec) -> ValkeyResult { Ok(ctx.get_client_ip()?.into()) } -fn deauth_client_by_id(ctx: &Context, args: Vec) -> ValkeyResult { +fn deauth_client_by_id(ctx: &impl ContextTrait, args: Vec) -> ValkeyResult { if args.len() != 2 { return Err(ValkeyError::WrongArity); } @@ -95,7 +96,7 @@ fn deauth_client_by_id(ctx: &Context, args: Vec) -> ValkeyResult { } } -fn config_get(ctx: &Context, args: Vec) -> ValkeyResult { +fn config_get(ctx: &impl ContextTrait, args: Vec) -> ValkeyResult { if args.len() != 2 { return Err(ValkeyError::WrongArity); } @@ -125,3 +126,160 @@ valkey_module! { ["client.config_get", config_get, "", 0, 0, 0] ] } + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::{MockContext, RedisModuleClientInfo}; + + fn module_client_info(id: u64) -> RedisModuleClientInfo { + RedisModuleClientInfo { + version: 1, + flags: 0, + id, + addr: [0; 46], + port: 6379, + db: 0, + } + } + + #[test] + fn test_get_client_id() { + let mut ctx = MockContext::new(); + ctx.expect_get_client_id().times(1).returning(|| 42); + let reply = get_client_id(&ctx, vec![]).unwrap(); + assert_eq!(reply, ValkeyValue::Integer(42)); + } + + #[test] + fn test_get_client_name() { + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(1).return_const(()); + ctx.expect_get_client_name_by_id() + .times(1) + .returning(|_| Err(ValkeyError::Str("no such client"))); + ctx.expect_get_client_name() + .times(1) + .returning(|| Err(ValkeyError::Str("no name"))); + + let err = get_client_name(&ctx, vec![]).unwrap_err(); + assert!(matches!(err, ValkeyError::Str("no name"))); + } + + #[test] + fn test_get_client_username() { + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(1).return_const(()); + ctx.expect_get_client_username_by_id() + .times(1) + .returning(|_| Err(ValkeyError::Str("no such client"))); + ctx.expect_get_client_username() + .times(1) + .returning(|| Err(ValkeyError::Str("no user"))); + + let err = get_client_username(&ctx, vec![]).unwrap_err(); + assert!(matches!(err, ValkeyError::Str("no user"))); + } + + #[test] + fn test_set_client_name_wrong_arity() { + let ctx = MockContext::new(); + let err = set_client_name(&ctx, vec![ValkeyString::create_for_test("client.set_name")]) + .unwrap_err(); + assert!(matches!(err, ValkeyError::WrongArity)); + } + + #[test] + fn test_set_client_name() { + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(1).return_const(()); + ctx.expect_set_client_name_by_id() + .times(1) + .returning(|_, client_name| { + assert_eq!(client_name.try_as_str().unwrap(), "my-client"); + Status::Ok + }); + ctx.expect_set_client_name() + .times(1) + .returning(|client_name| { + assert_eq!(client_name.try_as_str().unwrap(), "my-client"); + Status::Ok + }); + + let reply = set_client_name( + &ctx, + vec![ + ValkeyString::create_for_test("client.set_name"), + ValkeyString::create_for_test("my-client"), + ], + ) + .unwrap(); + assert_eq!(reply, ValkeyValue::Integer(Status::Ok as i64)); + } + + #[test] + fn test_get_client_cert() { + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(1).return_const(()); + ctx.expect_get_client_cert() + .times(1) + .returning(|| Err(ValkeyError::Str("no cert"))); + let reply = get_client_cert(&ctx, vec![]).unwrap(); + assert_eq!(reply, ValkeyValue::BulkString(String::new())); + } + + #[test] + fn test_get_client_info() { + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(2).return_const(()); + ctx.expect_get_client_info_by_id() + .times(1) + .returning(|_| Ok(module_client_info(0))); + ctx.expect_get_client_info() + .times(1) + .returning(|| Ok(module_client_info(42))); + + let reply = get_client_info(&ctx, vec![]).unwrap(); + assert_eq!(reply, ValkeyValue::BulkString("1".to_string())); + } + + #[test] + fn test_get_client_ip() { + let mut ctx = MockContext::new(); + ctx.expect_get_client_ip_by_id() + .times(1) + .returning(|_| Ok("0.0.0.0".to_string())); + ctx.expect_log_notice().times(1).return_const(()); + ctx.expect_get_client_ip() + .times(1) + .returning(|| Ok("127.0.0.1".to_string())); + let reply = get_client_ip(&ctx, vec![]).unwrap(); + assert_eq!(reply, ValkeyValue::BulkString("127.0.0.1".to_string())); + } + + #[test] + fn test_deauth_client_by_id_wrong_arity() { + let ctx = MockContext::new(); + let err = deauth_client_by_id(&ctx, vec![ValkeyString::create_for_test("")]).unwrap_err(); + assert!(matches!(err, ValkeyError::WrongArity)); + } + + #[test] + #[ignore] + fn test_death_client_by_id() { + // TODO - test when ValkeyString is unit testable + } + + #[test] + fn test_config_get_wrong_arity() { + let ctx = MockContext::new(); + let err = config_get(&ctx, vec![ValkeyString::create_for_test("")]).unwrap_err(); + assert!(matches!(err, ValkeyError::WrongArity)); + } + + #[test] + #[ignore] + fn test_config_get() { + // TODO - test when ValkeyString is unit testable + } +} diff --git a/examples/crontab.rs b/examples/crontab.rs index 30dee40..1fa1ef5 100644 --- a/examples/crontab.rs +++ b/examples/crontab.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::{LazyLock, RwLock}; use valkey_module::alloc::ValkeyAlloc; -use valkey_module::{valkey_module, Context, Status, ValkeyString}; +use valkey_module::{valkey_module, Context, ContextTrait, Status, ValkeyString}; use valkey_module_macros::cron_event_handler; // struct to hold environment-specific configs, based on the environment name passed in via MODULE LOAD @@ -86,7 +86,7 @@ fn cron_event_handler(ctx: &Context, _hz: u64) { } } -fn initialize(ctx: &Context, args: &[ValkeyString]) -> Status { +fn initialize(ctx: &impl ContextTrait, args: &[ValkeyString]) -> Status { // if arg passed in MODULE LOAD use it to set env_name let env_name = match args.get(0) { Some(tmp) => tmp.to_string(), @@ -113,3 +113,41 @@ valkey_module! { commands: [ ], } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + use valkey_module::MockContext; + + static TEST_LOCK: Mutex<()> = Mutex::new(()); + + #[test] + fn initialize_uses_dev_environment_config() { + let _guard = TEST_LOCK.lock().unwrap(); + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(1).return_const(()); + let args = vec![ValkeyString::create_for_test("dev")]; + + let status = initialize(&ctx, &args); + let config = ENV_CONFIG.read().unwrap(); + + assert_eq!(status, Status::Ok); + assert_eq!(config.cron_fn1_fn2, "*/5 * * * * * *"); + assert_eq!(config.cron_fn3, "*/10 * * * * * *"); + } + + #[test] + fn initialize_uses_default_config_without_args() { + let _guard = TEST_LOCK.lock().unwrap(); + let mut ctx = MockContext::new(); + ctx.expect_log_notice().times(1).return_const(()); + + let test = initialize(&ctx, &[]); + let config = ENV_CONFIG.read().unwrap(); + + assert_eq!(test, Status::Ok); + assert_eq!(config.cron_fn1_fn2, "*/15 * * * * * *"); + assert_eq!(config.cron_fn3, "*/30 * * * * * *"); + } +} diff --git a/examples/filter2.rs b/examples/filter2.rs index 7a371c6..649f439 100644 --- a/examples/filter2.rs +++ b/examples/filter2.rs @@ -4,8 +4,9 @@ use std::sync::LazyLock; use valkey_module::alloc::ValkeyAlloc; use valkey_module::server_events::ClientChangeSubevent; use valkey_module::{ - valkey_module, CommandFilterCtx, Context, RedisModuleCommandFilterCtx, Status, ValkeyError, - ValkeyString, AUTH_HANDLED, AUTH_NOT_HANDLED, VALKEYMODULE_CMDFILTER_NOSELF, + valkey_module, CommandFilterCtx, CommandFilterCtxTrait, Context, RedisModuleCommandFilterCtx, + Status, ValkeyError, ValkeyString, AUTH_HANDLED, AUTH_NOT_HANDLED, + VALKEYMODULE_CMDFILTER_NOSELF, }; use valkey_module_macros::client_changed_event_handler; @@ -54,16 +55,18 @@ fn auth_callback( Ok(AUTH_NOT_HANDLED) } +fn filter1_logic(ctx: &impl CommandFilterCtxTrait) -> String { + let client_id = ctx.get_client_id(); + match CLIENT_ID_USERNAME_MAP.get(&client_id) { + Some(tmp) => tmp.clone(), + None => "default".to_string(), + } +} + fn filter1_fn(ctx: *mut RedisModuleCommandFilterCtx) { // registered via valkey_module! macro // making sure that two modules can have the same filter fn name - let cf_ctx = CommandFilterCtx::new(ctx); - let client_id = cf_ctx.get_client_id(); - // lookup username by client_id - let _username = match CLIENT_ID_USERNAME_MAP.get(&client_id) { - Some(tmp) => tmp.clone(), - None => "default".to_string(), - }; + let _username = filter1_logic(&CommandFilterCtx::new(ctx)); // do something with the username } @@ -85,3 +88,33 @@ valkey_module! { [filter2_fn, VALKEYMODULE_CMDFILTER_NOSELF] ] } + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::MockCommandFilterCtx; + + #[test] + fn filter1_uses_default_username_when_client_is_unknown() { + let mut ctx = MockCommandFilterCtx::new(); + let client_id = 100_u64; + CLIENT_ID_USERNAME_MAP.remove(&client_id); + + ctx.expect_get_client_id().times(1).return_const(client_id); + + assert_eq!(filter1_logic(&ctx), "default"); + } + + #[test] + fn filter1_uses_mapped_username_when_client_is_known() { + let mut ctx = MockCommandFilterCtx::new(); + let client_id = 200_u64; + CLIENT_ID_USERNAME_MAP.insert(client_id, "alice".to_string()); + + ctx.expect_get_client_id().times(1).return_const(client_id); + + assert_eq!(filter1_logic(&ctx), "alice"); + + CLIENT_ID_USERNAME_MAP.remove(&client_id); + } +} diff --git a/examples/hello.rs b/examples/hello.rs index 00245e2..d2c5650 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -30,3 +30,45 @@ valkey_module! { ["hello.mul", hello_mul, "", 0, 0, 0], ], } + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::{Context, ValkeyValue}; + + fn test_args(values: &[&str]) -> Vec { + values + .iter() + .map(|value| ValkeyString::create_for_test(*value)) + .collect() + } + + #[test] + fn hello_mul_returns_inputs_and_product() { + let reply = hello_mul(&Context::dummy(), test_args(&["hello.mul", "2", "3", "4"])).unwrap(); + + assert_eq!( + reply, + ValkeyValue::Array(vec![ + ValkeyValue::Integer(2), + ValkeyValue::Integer(3), + ValkeyValue::Integer(4), + ValkeyValue::Integer(24), + ]) + ); + } + + #[test] + fn hello_mul_returns_wrong_arity_without_numbers() { + let err = hello_mul(&Context::dummy(), test_args(&["hello.mul"])).unwrap_err(); + + assert!(matches!(err, ValkeyError::WrongArity)); + } + + #[test] + fn hello_mul_rejects_invalid_integer() { + let err = hello_mul(&Context::dummy(), test_args(&["hello.mul", "2", "xx"])).unwrap_err(); + + assert!(matches!(err, ValkeyError::Str("Couldn't parse as integer"))); + } +} diff --git a/examples/info_handler_struct.rs b/examples/info_handler_struct.rs index 68ab770..742653b 100644 --- a/examples/info_handler_struct.rs +++ b/examples/info_handler_struct.rs @@ -1,8 +1,7 @@ use std::collections::HashMap; use valkey_module::alloc::ValkeyAlloc; -use valkey_module::InfoContext; -use valkey_module::{valkey_module, ValkeyResult}; +use valkey_module::{valkey_module, InfoContext, InfoContextTrait, ValkeyResult}; use valkey_module_macros::{info_command_handler, InfoSection}; #[derive(Debug, Clone, InfoSection)] @@ -11,15 +10,20 @@ struct Info { dictionary: HashMap, } -#[info_command_handler] -fn add_info(ctx: &InfoContext, _for_crash_report: bool) -> ValkeyResult<()> { +// handler logic written against InfoContextTrait so it is mockable in unit tests +fn add_info_logic(ctx: &impl InfoContextTrait) -> ValkeyResult<()> { let mut dictionary = HashMap::new(); dictionary.insert("key".to_owned(), "value".into()); let data = Info { field: "value".to_owned(), dictionary, }; - ctx.build_one_section(data) + ctx.build_one_section(data.into()) +} + +#[info_command_handler] +fn add_info(ctx: &InfoContext, _for_crash_report: bool) -> ValkeyResult<()> { + add_info_logic(ctx) } ////////////////////////////////////////////////////// @@ -31,3 +35,27 @@ valkey_module! { data_types: [], commands: [], } + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::MockInfoContext; + + #[test] + fn test_add_info_logic_builds_expected_section() { + let mut ctx = MockInfoContext::new(); + ctx.expect_build_one_section() + .withf(|(name, fields)| { + // section name is derived from the struct name by the InfoSection derive + name == "Info" + && fields.iter().any(|(field_name, _)| field_name == "field") + && fields + .iter() + .any(|(field_name, _)| field_name == "dictionary") + }) + .times(1) + .returning(|_| Ok(())); + + add_info_logic(&ctx).unwrap(); + } +} diff --git a/examples/preload.rs b/examples/preload.rs index 2c39b87..3fe99e6 100644 --- a/examples/preload.rs +++ b/examples/preload.rs @@ -1,7 +1,7 @@ use valkey_module::alloc::ValkeyAlloc; -use valkey_module::{valkey_module, Context, Status, ValkeyString}; +use valkey_module::{valkey_module, ContextTrait, Status, ValkeyString}; -fn preload(ctx: &Context, args: &[ValkeyString]) -> Status { +fn preload(ctx: &impl ContextTrait, args: &[ValkeyString]) -> Status { // perform preload validations here, useful for MODULE LOAD // unlike init which is called at the end of the valkey_module! macro this is called at the beginning let version = ctx.get_server_version().unwrap(); @@ -21,3 +21,26 @@ valkey_module! { preload: preload, commands: [], } + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::raw::Version; + use valkey_module::MockContext; + + #[test] + fn test_preload_calls_get_server_version() { + let mut ctx = MockContext::new(); + ctx.expect_get_server_version().times(1).returning(|| { + Ok(Version { + major: 8, + minor: 0, + patch: 1, + }) + }); + ctx.expect_log_notice().times(1).return_const(()); + + let status = preload(&ctx, &[]); + assert!(matches!(status, Status::Ok)); + } +} diff --git a/examples/server_events.rs b/examples/server_events.rs index 8547b26..77b0335 100644 --- a/examples/server_events.rs +++ b/examples/server_events.rs @@ -6,8 +6,8 @@ use valkey_module::server_events::{ MasterLinkChangeSubevent, PersistenceSubevent, ReplAsyncLoadSubevent, ReplicaChangeSubevent, }; use valkey_module::{ - server_events::FlushSubevent, valkey_module, Context, ModuleOptions, Status, ValkeyResult, - ValkeyString, ValkeyValue, + server_events::FlushSubevent, valkey_module, Context, ContextTrait, ModuleOptions, Status, + ValkeyResult, ValkeyString, ValkeyValue, }; use valkey_module_macros::{ client_changed_event_handler, config_changed_event_handler, cron_event_handler, @@ -306,7 +306,7 @@ fn num_event_loop_after_sleep(_ctx: &Context, _args: Vec) -> Valke )) } -fn init(ctx: &Context, _args: &[ValkeyString]) -> Status { +fn init(ctx: &impl ContextTrait, _args: &[ValkeyString]) -> Status { // https://valkey.io/topics/modules-api-ref/#ValkeyModule_SetModuleOptions // otherwise you get: Skipping diskless-load because there are modules that are not aware of async replication. // needed for repl_async_load_event_handler @@ -341,3 +341,21 @@ valkey_module! { ["num_event_loop_after_sleep", num_event_loop_after_sleep, "readonly", 0, 0, 0], ] } + +#[cfg(test)] +mod tests { + use super::*; + use valkey_module::MockContext; + + #[test] + fn test_init_sets_module_options() { + let mut ctx = MockContext::new(); + ctx.expect_set_module_options() + .withf(|opts| opts.bits() == ModuleOptions::HANDLE_REPL_ASYNC_LOAD.bits()) + .times(1) + .return_const(()); + + let status = init(&ctx, &[]); + assert!(matches!(status, Status::Ok)); + } +} diff --git a/src/context/filter.rs b/src/context/filter.rs index 145e687..b13c6fe 100644 --- a/src/context/filter.rs +++ b/src/context/filter.rs @@ -65,7 +65,7 @@ impl CommandFilterCtx { } /// wrapper to get Vector of all args minus the command (0th arg) - pub fn get_all_args_wo_cmd(&self) -> Vec<&str> { + pub fn get_all_args_wo_cmd<'a>(&self) -> Vec<&'a str> { let mut output = Vec::new(); for pos in 1..self.args_count() { match self.arg_get_try_as_str(pos) { diff --git a/src/context/mock/cmd_filter_ctx_impl.rs b/src/context/mock/cmd_filter_ctx_impl.rs new file mode 100644 index 0000000..7b60f80 --- /dev/null +++ b/src/context/mock/cmd_filter_ctx_impl.rs @@ -0,0 +1,46 @@ +use super::CommandFilterCtxTrait; +use crate::{CommandFilterCtx, RedisModuleString}; +use std::ffi::c_int; +use std::str::Utf8Error; + +impl CommandFilterCtxTrait for CommandFilterCtx { + fn args_count(&self) -> c_int { + CommandFilterCtx::args_count(self) + } + + fn arg_get(&self, pos: c_int) -> *mut RedisModuleString { + CommandFilterCtx::arg_get(&self, pos) + } + + fn arg_get_try_as_str<'a>(&self, pos: c_int) -> Result<&'a str, Utf8Error> { + CommandFilterCtx::arg_get_try_as_str(self, pos) + } + + fn cmd_get_try_as_str<'a>(&self) -> Result<&'a str, Utf8Error> { + CommandFilterCtx::cmd_get_try_as_str(self) + } + + fn get_all_args_wo_cmd<'a>(&self) -> Vec<&'a str> { + CommandFilterCtx::get_all_args_wo_cmd(self) + } + + fn arg_replace(&self, pos: c_int, arg: &str) { + CommandFilterCtx::arg_replace(self, pos, arg); + } + + fn arg_insert(&self, pos: c_int, arg: &str) { + CommandFilterCtx::arg_insert(self, pos, arg); + } + + fn arg_delete(&self, pos: c_int) { + CommandFilterCtx::arg_delete(self, pos); + } + + #[cfg(all(any( + feature = "min-redis-compatibility-version-7-2", + feature = "min-valkey-compatibility-version-8-0" + ),))] + fn get_client_id(&self) -> u64 { + CommandFilterCtx::get_client_id(self) + } +} diff --git a/src/context/mock/cmd_filter_ctx_trait.rs b/src/context/mock/cmd_filter_ctx_trait.rs new file mode 100644 index 0000000..5350ecc --- /dev/null +++ b/src/context/mock/cmd_filter_ctx_trait.rs @@ -0,0 +1,48 @@ +use crate::RedisModuleString; +use std::ffi::c_int; +use std::str::Utf8Error; + +#[cfg_attr(any(test, feature = "test-mocks"), mockall::automock)] +pub trait CommandFilterCtxTrait { + fn args_count(&self) -> c_int; + fn arg_get(&self, pos: c_int) -> *mut RedisModuleString; + fn arg_get_try_as_str<'a>(&self, pos: c_int) -> Result<&'a str, Utf8Error>; + fn cmd_get_try_as_str<'a>(&self) -> Result<&'a str, Utf8Error>; + fn get_all_args_wo_cmd<'a>(&self) -> Vec<&'a str>; + + fn arg_replace(&self, pos: c_int, arg: &str); + fn arg_insert(&self, pos: c_int, arg: &str); + fn arg_delete(&self, pos: c_int); + + #[cfg(all(any( + feature = "min-redis-compatibility-version-7-2", + feature = "min-valkey-compatibility-version-8-0" + ),))] + fn get_client_id(&self) -> u64; +} + +#[cfg(test)] +mod tests { + use super::*; + use mockall::predicate::eq; + + #[test] + fn test_dispatches_through_impl_and_dyn() { + fn static_dispatch(ctx: &impl CommandFilterCtxTrait) { + ctx.arg_replace(0, "info2"); + } + + fn dynamic_dispatch(ctx: &dyn CommandFilterCtxTrait) { + ctx.arg_replace(0, "info2"); + } + + let mut ctx = MockCommandFilterCtxTrait::new(); + ctx.expect_arg_replace() + .with(eq(0), eq("info2")) + .times(2) + .return_const(()); + + static_dispatch(&ctx); + dynamic_dispatch(&ctx); + } +} diff --git a/src/context/mock/context_impl.rs b/src/context/mock/context_impl.rs new file mode 100644 index 0000000..9d341ae --- /dev/null +++ b/src/context/mock/context_impl.rs @@ -0,0 +1,80 @@ +//! `ContextTrait` delegations for the real [`Context`], split out so +//! `mod.rs` only shows the mockable surface. + +use super::ContextTrait; +use crate::logging::ValkeyLogLevel; +use crate::{Context, RedisModuleClientInfo, Status, ValkeyResult, ValkeyString}; + +impl ContextTrait for Context { + fn log(&self, level: ValkeyLogLevel, message: &str) { + Context::log(self, level, message); + } + fn create_string(&self, s: &str) -> ValkeyString { + Context::create_string(self, s) + } + fn get_current_user(&self) -> ValkeyString { + Context::get_current_user(self) + } + // call methods + fn call<'a>(&self, command: &str, args: &'a [&'a str]) -> ValkeyResult { + Context::call(self, command, args) + } + fn set_module_options(&self, options: crate::raw::ModuleOptions) { + Context::set_module_options(self, options); + } + fn get_server_version(&self) -> ValkeyResult { + Context::get_server_version(self) + } + + // auth methods + fn authenticate_client_with_acl_user(&self, username: &ValkeyString) -> Status { + Context::authenticate_client_with_acl_user(self, username) + } + + // client methods + fn get_client_id(&self) -> u64 { + Context::get_client_id(self) + } + fn get_client_name_by_id(&self, client_id: u64) -> ValkeyResult { + Context::get_client_name_by_id(self, client_id) + } + fn get_client_name(&self) -> ValkeyResult { + Context::get_client_name(self) + } + fn set_client_name_by_id(&self, client_id: u64, client_name: &ValkeyString) -> Status { + Context::set_client_name_by_id(self, client_id, client_name) + } + fn set_client_name(&self, client_name: &ValkeyString) -> Status { + Context::set_client_name(self, client_name) + } + fn get_client_username_by_id(&self, client_id: u64) -> ValkeyResult { + Context::get_client_username_by_id(self, client_id) + } + fn get_client_username(&self) -> ValkeyResult { + Context::get_client_username(self) + } + fn get_client_cert(&self) -> ValkeyResult { + Context::get_client_cert(self) + } + fn get_client_info_by_id(&self, client_id: u64) -> ValkeyResult { + Context::get_client_info_by_id(self, client_id) + } + fn get_client_info(&self) -> ValkeyResult { + Context::get_client_info(self) + } + fn get_client_ip_by_id(&self, client_id: u64) -> ValkeyResult { + Context::get_client_ip_by_id(self, client_id) + } + fn get_client_ip(&self) -> ValkeyResult { + Context::get_client_ip(self) + } + fn deauthenticate_and_close_client_by_id(&self, client_id: u64) -> Status { + Context::deauthenticate_and_close_client_by_id(self, client_id) + } + fn deauthenticate_and_close_client(&self) -> Status { + Context::deauthenticate_and_close_client(self) + } + fn config_get(&self, config: String) -> ValkeyResult { + Context::config_get(self, config) + } +} diff --git a/src/context/mock/context_trait.rs b/src/context/mock/context_trait.rs new file mode 100644 index 0000000..9b33a07 --- /dev/null +++ b/src/context/mock/context_trait.rs @@ -0,0 +1,81 @@ +use crate::logging::ValkeyLogLevel; +use crate::{RedisModuleClientInfo, Status, ValkeyResult, ValkeyString}; + +#[cfg_attr(any(test, feature = "test-mocks"), mockall::automock)] +pub trait ContextTrait { + fn log(&self, level: ValkeyLogLevel, message: &str); + fn log_debug(&self, message: &str) { + self.log(ValkeyLogLevel::Debug, message); + } + fn log_notice(&self, message: &str) { + self.log(ValkeyLogLevel::Notice, message); + } + fn log_verbose(&self, message: &str) { + self.log(ValkeyLogLevel::Verbose, message); + } + fn log_warning(&self, message: &str) { + self.log(ValkeyLogLevel::Warning, message); + } + fn create_string(&self, s: &str) -> ValkeyString; + fn get_current_user(&self) -> ValkeyString; + fn call<'a>(&self, command: &str, args: &'a [&'a str]) -> ValkeyResult; + fn set_module_options(&self, options: crate::raw::ModuleOptions); + fn get_server_version(&self) -> ValkeyResult; + + // auth methods + fn authenticate_client_with_acl_user(&self, username: &ValkeyString) -> Status; + + // client methods + fn get_client_id(&self) -> u64; + fn get_client_name_by_id(&self, client_id: u64) -> ValkeyResult; + fn get_client_name(&self) -> ValkeyResult; + fn set_client_name_by_id(&self, client_id: u64, client_name: &ValkeyString) -> Status; + fn set_client_name(&self, client_name: &ValkeyString) -> Status; + fn get_client_username_by_id(&self, client_id: u64) -> ValkeyResult; + fn get_client_username(&self) -> ValkeyResult; + fn get_client_cert(&self) -> ValkeyResult; + fn get_client_info_by_id(&self, client_id: u64) -> ValkeyResult; + fn get_client_info(&self) -> ValkeyResult; + fn get_client_ip_by_id(&self, client_id: u64) -> ValkeyResult; + fn get_client_ip(&self) -> ValkeyResult; + fn deauthenticate_and_close_client_by_id(&self, client_id: u64) -> Status; + fn deauthenticate_and_close_client(&self) -> Status; + fn config_get(&self, config: String) -> ValkeyResult; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ValkeyValue; + use mockall::predicate::eq; + + #[test] + fn test_dispatches_through_impl_and_dyn() { + fn static_dispatch(ctx: &impl ContextTrait) { + ctx.log_notice("hi"); + } + fn dynamic_dispatch(ctx: &dyn ContextTrait) { + ctx.log_notice("hi"); + } + + let mut ctx = MockContextTrait::new(); + ctx.expect_log_notice() + .with(eq("hi")) + .times(2) + .return_const(()); + static_dispatch(&ctx); + dynamic_dispatch(&ctx); + } + + #[test] + fn test_call() { + let mut ctx = MockContextTrait::new(); + ctx.expect_call() + .withf(|cmd, args| cmd == "SET" && args == ["key", "val"]) + .times(1) + .returning(|_, _| Ok(ValkeyValue::SimpleStringStatic("OK"))); + + let res = ctx.call("SET", &["key", "val"]).unwrap(); + assert!(matches!(res, ValkeyValue::SimpleStringStatic("OK"))); + } +} diff --git a/src/context/mock/info_context_impl.rs b/src/context/mock/info_context_impl.rs new file mode 100644 index 0000000..642f701 --- /dev/null +++ b/src/context/mock/info_context_impl.rs @@ -0,0 +1,8 @@ +use super::InfoContextTrait; +use crate::{InfoContext, OneInfoSectionData, ValkeyResult}; + +impl InfoContextTrait for InfoContext { + fn build_one_section(&self, data: OneInfoSectionData) -> ValkeyResult<()> { + InfoContext::build_one_section(self, data) + } +} diff --git a/src/context/mock/info_context_trait.rs b/src/context/mock/info_context_trait.rs new file mode 100644 index 0000000..9054fa1 --- /dev/null +++ b/src/context/mock/info_context_trait.rs @@ -0,0 +1,35 @@ +use crate::{OneInfoSectionData, ValkeyResult}; + +#[cfg_attr(any(test, feature = "test-mocks"), mockall::automock)] +pub trait InfoContextTrait { + fn build_one_section(&self, data: OneInfoSectionData) -> ValkeyResult<()>; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::InfoContextFieldTopLevelData; + + fn sample_section() -> OneInfoSectionData { + ("mysec".to_string(), InfoContextFieldTopLevelData::new()) + } + + #[test] + fn test_dispatches_through_impl_and_dyn() { + fn static_dispatch(ctx: &impl InfoContextTrait) { + let _ = ctx.build_one_section(sample_section()); + } + fn dynamic_dispatch(ctx: &dyn InfoContextTrait) { + let _ = ctx.build_one_section(sample_section()); + } + + let mut ctx = MockInfoContextTrait::new(); + ctx.expect_build_one_section() + .withf(|(name, fields)| name == "mysec" && fields.is_empty()) + .times(2) + .returning(|_| Ok(())); + + static_dispatch(&ctx); + dynamic_dispatch(&ctx); + } +} diff --git a/src/context/mock/mod.rs b/src/context/mock/mod.rs new file mode 100644 index 0000000..e47064b --- /dev/null +++ b/src/context/mock/mod.rs @@ -0,0 +1,22 @@ +//! Mockable trait abstractions over context wrappers used by the crate. +//! +//! These traits mirror subsets of the concrete context APIs so module logic can +//! be written generically and unit-tested with `mockall`-generated mocks instead +//! of requiring a running Valkey server. + +mod cmd_filter_ctx_impl; +mod cmd_filter_ctx_trait; +mod context_impl; +mod context_trait; +mod info_context_impl; +mod info_context_trait; + +pub use self::cmd_filter_ctx_trait::CommandFilterCtxTrait; +#[cfg(any(test, feature = "test-mocks"))] +pub use self::cmd_filter_ctx_trait::MockCommandFilterCtxTrait as MockCommandFilterCtx; +pub use self::context_trait::ContextTrait; +#[cfg(any(test, feature = "test-mocks"))] +pub use self::context_trait::MockContextTrait as MockContext; +pub use self::info_context_trait::InfoContextTrait; +#[cfg(any(test, feature = "test-mocks"))] +pub use self::info_context_trait::MockInfoContextTrait as MockInfoContext; diff --git a/src/context/mod.rs b/src/context/mod.rs index c3ba27b..4545a8a 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -35,6 +35,7 @@ pub mod commands; pub mod filter; pub mod info; pub mod keys_cursor; +pub mod mock; pub mod server_events; pub mod thread_safe; diff --git a/src/lib.rs b/src/lib.rs index 299f3f3..663c0dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,8 @@ pub mod native_types; pub mod raw; pub mod rediserror; mod redismodule; +#[cfg(any(test, feature = "test-mocks"))] +mod redismodule_test; pub mod redisraw; pub mod redisvalue; pub mod stream; @@ -35,6 +37,26 @@ pub use crate::context::call_reply::{CallReply, CallResult, ErrorReply, PromiseC pub use crate::context::commands; pub use crate::context::info::ServerInfo; pub use crate::context::keys_cursor::KeysCursor; +/// Trait abstraction over [`CommandFilterCtx`] used to make filter logic +/// mockable in tests. +pub use crate::context::mock::CommandFilterCtxTrait; +/// Trait abstraction over [`Context`] used to make module logic mockable in tests. +pub use crate::context::mock::ContextTrait; +/// Trait abstraction over [`InfoContext`] used to make info-handler logic +/// mockable in tests. +pub use crate::context::mock::InfoContextTrait; +#[cfg(any(test, feature = "test-mocks"))] +/// Mock generated from [`CommandFilterCtxTrait`]. Available in crate tests and +/// with the `test-mocks` feature for downstream users. +pub use crate::context::mock::MockCommandFilterCtx; +#[cfg(any(test, feature = "test-mocks"))] +/// Mock generated from [`ContextTrait`]. Available in crate tests and with +/// the `test-mocks` feature for downstream users. +pub use crate::context::mock::MockContext; +#[cfg(any(test, feature = "test-mocks"))] +/// Mock generated from [`InfoContextTrait`]. Available in crate tests and +/// with the `test-mocks` feature for downstream users. +pub use crate::context::mock::MockInfoContext; pub use crate::context::server_events; pub use crate::context::AclPermissions; #[cfg(all(any( diff --git a/src/redismodule_test.rs b/src/redismodule_test.rs new file mode 100644 index 0000000..9de4237 --- /dev/null +++ b/src/redismodule_test.rs @@ -0,0 +1,418 @@ +use std::cmp::Ordering; +use std::os::raw::{c_char, c_double, c_int, c_longlong}; +use std::slice; +use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; +use std::sync::Once; + +use crate::raw; +use crate::ValkeyString; + +/// Heap-allocated backing store for a test `ValkeyString`. +/// The pointer is cast to `*mut RedisModuleString` so the existing +/// API surface (`try_as_str`, `as_slice`, `len`, etc.) works without +/// a running Valkey server. +#[repr(C)] +struct TestStringInner { + refcount: AtomicUsize, + data: Vec, +} + +impl TestStringInner { + fn new(data: Vec) -> Self { + // Start at one reference to match a freshly created module string. + Self { + refcount: AtomicUsize::new(1), + data, + } + } +} + +fn test_string_inner<'a>(str_: *const raw::RedisModuleString) -> &'a TestStringInner { + // Test shims only pass pointers created from `TestStringInner` allocations. + unsafe { &*str_.cast::() } +} + +fn test_string_inner_mut<'a>(str_: *mut raw::RedisModuleString) -> &'a mut TestStringInner { + // Mutable access is only used for those same test-owned string allocations. + unsafe { &mut *str_.cast::() } +} + +fn allocate_test_string(data: Vec) -> *mut raw::RedisModuleString { + // Wrap test-owned bytes in the same opaque pointer shape the raw API uses. + let test_string = TestStringInner::new(data); + let boxed_test_string = Box::new(test_string); + let test_string_ptr = Box::into_raw(boxed_test_string); + + test_string_ptr.cast() +} + +fn parse_test_string(str_: *const raw::RedisModuleString) -> Option { + // Shared parser for shim APIs that first require valid UTF-8 input bytes. + let test_string = test_string_inner(str_); + let string_bytes = &test_string.data; + let string_utf8 = str::from_utf8(string_bytes); + + string_utf8.ok().and_then(|s| s.parse::().ok()) +} + +fn write_optional_output(out: *mut T, value: T) { + // Match Valkey's out-parameter convention: write only when the caller supplies one. + if !out.is_null() { + unsafe { + *out = value; + } + } +} + +extern "C" fn test_string_ptr_len( + str_: *const raw::RedisModuleString, + len: *mut usize, +) -> *const c_char { + // The real Valkey API returns a raw pointer to the string bytes and writes + // the byte length into the provided out-parameter when one is supplied. + let test_string = test_string_inner(str_); + let string_bytes = &test_string.data; + let string_len = string_bytes.len(); + + if !len.is_null() { + unsafe { + *len = string_len; + } + } + + let string_bytes_ptr = string_bytes.as_ptr(); + string_bytes_ptr.cast::() +} + +extern "C" fn test_create_string( + _ctx: *mut raw::RedisModuleCtx, + ptr: *const c_char, + len: usize, +) -> *mut raw::RedisModuleString { + // The real Valkey API accepts a raw C pointer/length pair, copies those + // bytes into a new module-owned string allocation, and returns an opaque + // `RedisModuleString` pointer to that owned storage. + let input_bytes_ptr = ptr.cast::(); + let input_bytes = unsafe { slice::from_raw_parts(input_bytes_ptr, len) }; + + allocate_test_string(input_bytes.to_vec()) +} + +extern "C" fn test_create_string_from_string( + _ctx: *mut raw::RedisModuleCtx, + str_: *const raw::RedisModuleString, +) -> *mut raw::RedisModuleString { + // The real Valkey API returns a new module string allocation containing a + // copy of the source string bytes, so the cloned string can outlive and + // diverge from the original. + allocate_test_string(test_string_inner(str_).data.clone()) +} + +extern "C" fn test_free_string(_ctx: *mut raw::RedisModuleCtx, str_: *mut raw::RedisModuleString) { + if !str_.is_null() { + // `ValkeyString::safe_clone` can retain the original allocation, so the + // test shim needs refcounted destruction instead of freeing eagerly on + // every drop. + let test_string = test_string_inner_mut(str_); + let previous_refcount = test_string.refcount.fetch_sub(1, AtomicOrdering::AcqRel); + let should_drop_allocation = previous_refcount == 1; + + if should_drop_allocation { + let test_string_ptr = str_.cast::(); + let _owned_test_string = unsafe { Box::from_raw(test_string_ptr) }; + } + } +} + +extern "C" fn test_retain_string( + _ctx: *mut raw::RedisModuleCtx, + str_: *mut raw::RedisModuleString, +) { + if !str_.is_null() { + // The real Valkey API retains the existing module string allocation by + // incrementing its reference count, so this test shim mirrors that + // behavior on the Rust-owned backing allocation. + let test_string = test_string_inner_mut(str_); + let refcount = &test_string.refcount; + + refcount.fetch_add(1, AtomicOrdering::AcqRel); + } +} + +extern "C" fn test_string_compare( + a: *const raw::RedisModuleString, + b: *const raw::RedisModuleString, +) -> c_int { + // The real Valkey API compares the underlying string bytes and returns a + // negative value, zero, or a positive value depending on the ordering. + let left_test_string = test_string_inner(a); + let right_test_string = test_string_inner(b); + let left_bytes = &left_test_string.data; + let right_bytes = &right_test_string.data; + let ordering = left_bytes.cmp(right_bytes); + + match ordering { + Ordering::Less => -1, + Ordering::Equal => 0, + Ordering::Greater => 1, + } +} + +extern "C" fn test_string_append_buffer( + _ctx: *mut raw::RedisModuleCtx, + str_: *mut raw::RedisModuleString, + buf: *const c_char, + len: usize, +) -> c_int { + // The real Valkey API appends raw bytes into the existing string buffer. + // Reconstruct that byte slice from the incoming C pointer/length pair, then + // mutate the test-owned backing store in place. + let appended_bytes_ptr = buf.cast::(); + let appended_bytes = unsafe { slice::from_raw_parts(appended_bytes_ptr, len) }; + let test_string = test_string_inner_mut(str_); + let existing_bytes = &mut test_string.data; + + existing_bytes.extend_from_slice(appended_bytes); + raw::REDISMODULE_OK as c_int +} + +extern "C" fn test_string_to_longlong( + str_: *const raw::RedisModuleString, + ll: *mut c_longlong, +) -> c_int { + // The real Valkey API parses the string bytes as a signed integer and, on + // success, writes the parsed value into the provided out-parameter. + match parse_test_string::(str_) { + Some(value) => { + write_optional_output(ll, value as c_longlong); + raw::REDISMODULE_OK as c_int + } + None => raw::REDISMODULE_ERR as c_int, + } +} + +extern "C" fn test_string_to_double( + str_: *const raw::RedisModuleString, + d: *mut c_double, +) -> c_int { + // The real Valkey API parses the string bytes as a floating-point number + // and, on success, writes the parsed value into the provided + // out-parameter. + match parse_test_string::(str_) { + Some(value) => { + write_optional_output(d, value as c_double); + raw::REDISMODULE_OK as c_int + } + None => raw::REDISMODULE_ERR as c_int, + } +} + +static INIT: Once = Once::new(); + +/// Install test shim function pointers so `ValkeyString` methods work +/// without the Valkey C runtime. +fn ensure_test_shims() { + INIT.call_once(|| unsafe { + let create_string = raw::RedisModule_CreateString; + if create_string.is_none() { + // Only install the shim when the raw API has not already been + // initialized by a real Valkey context. + raw::RedisModule_StringPtrLen = Some(test_string_ptr_len); + raw::RedisModule_CreateString = Some(test_create_string); + raw::RedisModule_CreateStringFromString = Some(test_create_string_from_string); + raw::RedisModule_FreeString = Some(test_free_string); + raw::RedisModule_RetainString = Some(test_retain_string); + raw::RedisModule_StringCompare = Some(test_string_compare); + raw::RedisModule_StringAppendBuffer = Some(test_string_append_buffer); + raw::RedisModule_StringToLongLong = Some(test_string_to_longlong); + raw::RedisModule_StringToDouble = Some(test_string_to_double); + } + }); +} + +#[cfg(any(test, feature = "test-mocks"))] +impl ValkeyString { + /// Create a `ValkeyString` that owns its bytes without a running Valkey + /// server. Only available in test / `test-mocks` builds. + /// + /// The returned value supports `try_as_str`, `as_slice`, `len`, and + /// `is_empty` — everything needed to pass it as a command argument in + /// unit tests. + #[cfg(any(test, feature = "test-mocks"))] + pub fn create_for_test>>(s: T) -> Self { + // Install the shim lazily so tests can opt in only when they need it. + ensure_test_shims(); + ValkeyString::create(None, s) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::borrow::Borrow; + + #[test] + fn create_for_test_installs_all_required_string_shims() { + ensure_test_shims(); + + let create = unsafe { raw::RedisModule_CreateString }; + let create_from_string = unsafe { raw::RedisModule_CreateStringFromString }; + let free = unsafe { raw::RedisModule_FreeString }; + let retain = unsafe { raw::RedisModule_RetainString }; + let ptr_len = unsafe { raw::RedisModule_StringPtrLen }; + let compare = unsafe { raw::RedisModule_StringCompare }; + let append = unsafe { raw::RedisModule_StringAppendBuffer }; + let to_longlong = unsafe { raw::RedisModule_StringToLongLong }; + let to_double = unsafe { raw::RedisModule_StringToDouble }; + + assert!(create.is_some()); + assert!(create_from_string.is_some()); + assert!(free.is_some()); + assert!(retain.is_some()); + assert!(ptr_len.is_some()); + assert!(compare.is_some()); + assert!(append.is_some()); + assert!(to_longlong.is_some()); + assert!(to_double.is_some()); + } + + #[test] + fn test_string_inner_returns_original_allocation() { + let s = ValkeyString::create_for_test("hello"); + let inner = test_string_inner(s.inner); + + assert_eq!( + inner as *const TestStringInner, + s.inner.cast::() + ); + assert_eq!(inner.data, b"hello"); + } + + #[test] + fn test_string_inner_mut_updates_backing_bytes() { + let s = ValkeyString::create_for_test("hello"); + + let inner = test_string_inner_mut(s.inner); + inner.data.extend_from_slice(b"!"); + + assert_eq!(s.as_slice(), b"hello!"); + assert_eq!(s.try_as_str().unwrap(), "hello!"); + } + + #[test] + fn accessors_return_input_bytes() { + let s = ValkeyString::create_for_test("hello"); + assert_eq!(s.try_as_str().unwrap(), "hello"); + assert_eq!(s.as_slice(), b"hello"); + assert_eq!(s.len(), 5); + assert!(!s.is_empty()); + } + + #[test] + fn create_for_test_accepts_byte_buffers() { + let s = ValkeyString::create_for_test(vec![0x66, 0x6f, 0x6f]); + + assert_eq!(s.as_slice(), b"foo"); + assert_eq!(s.try_as_str().unwrap(), "foo"); + } + + #[test] + fn invalid_utf8_is_preserved_in_slice_accessors() { + let s = ValkeyString::create_for_test(vec![0xff, b'o', b'o']); + + assert_eq!(s.as_slice(), &[0xff, b'o', b'o']); + assert!(matches!( + s.try_as_str(), + Err(crate::ValkeyError::Str("Couldn't parse as UTF-8 string")) + )); + } + + #[test] + fn empty_string_is_empty() { + let s = ValkeyString::create_for_test(""); + assert_eq!(s.len(), 0); + assert!(s.is_empty()); + } + + #[test] + fn ordering_and_equality() { + let a = ValkeyString::create_for_test("aaa"); + let b = ValkeyString::create_for_test("bbb"); + let a2 = ValkeyString::create_for_test("aaa"); + assert!(a < b); + assert!(b > a); + assert_eq!(a, a2); + } + + #[test] + fn clone_creates_independent_string() { + let mut s = ValkeyString::create_for_test("41"); + let cloned = s.clone(); + + s.append(".5"); + + assert_eq!(cloned.try_as_str().unwrap(), "41"); + assert_eq!(s.try_as_str().unwrap(), "41.5"); + } + + #[test] + fn append_returns_ok_and_updates_bytes() { + let mut s = ValkeyString::create_for_test("foo"); + + assert_eq!(s.append("bar"), raw::Status::Ok); + assert_eq!(s.as_slice(), b"foobar"); + } + + #[test] + fn integer_and_float_parsing_work() { + let integer = ValkeyString::create_for_test("41"); + let float = ValkeyString::create_for_test("41.5"); + + assert_eq!(integer.parse_integer().unwrap(), 41); + assert_eq!(integer.parse_unsigned_integer().unwrap(), 41); + assert_eq!(float.parse_float().unwrap(), 41.5); + } + + #[test] + fn parsing_errors_match_expected_messages() { + let invalid_int = ValkeyString::create_for_test("abc"); + let negative = ValkeyString::create_for_test("-1"); + let invalid_float = ValkeyString::create_for_test("abc"); + + assert!(matches!( + invalid_int.parse_integer(), + Err(crate::ValkeyError::Str("Couldn't parse as integer")) + )); + assert!(matches!( + negative.parse_unsigned_integer(), + Err(crate::ValkeyError::Str( + "Couldn't parse negative number as unsigned integer" + )) + )); + assert!(matches!( + invalid_float.parse_float(), + Err(crate::ValkeyError::Str("Couldn't parse as float")) + )); + } + + #[test] + fn safe_clone_keeps_original_alive() { + let s = ValkeyString::create_for_test("hello"); + let cloned = s.safe_clone(&crate::Context::dummy()); + + drop(s); + + assert_eq!(cloned.try_as_str().unwrap(), "hello"); + } + + #[test] + fn borrow_returns_placeholder_for_invalid_utf8() { + let s = ValkeyString::create_for_test(vec![0xff, b'o']); + + assert_eq!( + >::borrow(&s), + "" + ); + } +} diff --git a/test.sh b/test.sh index f2f5986..693192a 100755 --- a/test.sh +++ b/test.sh @@ -1,4 +1,3 @@ #!/usr/bin/env sh -# TODO cargo test --all --all-targets --no-default-features rm dump.rdb -cargo test --all --no-default-features +cargo test --all --all-targets --no-default-features --features enable-system-alloc diff --git a/tests/integration.rs b/tests/integration.rs index f200fe3..0178910 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -570,13 +570,25 @@ fn test_client_change_event() -> Result<()> { let con2: redis::Connection = get_valkey_connection(port).with_context(|| FAILED_TO_CONNECT_TO_SERVER)?; - let conn_res: i64 = redis::cmd("num_connects").query(&mut con)?; + let wait_for_num_connects = |con: &mut redis::Connection, expected: i64| -> Result { + let mut last = -1; + for _ in 0..20 { + last = redis::cmd("num_connects").query(con)?; + if last == expected { + return Ok(last); + } + thread::sleep(Duration::from_millis(50)); + } + Ok(last) + }; + + let conn_res = wait_for_num_connects(&mut con, 2)?; println!("Connection result: {}", conn_res); assert_eq!(conn_res, 2); drop(con2); - let disconn_res: i64 = redis::cmd("num_connects").query(&mut con)?; + let disconn_res = wait_for_num_connects(&mut con, 1)?; println!("Disconnection result: {}", disconn_res); assert_eq!(disconn_res, 1);