diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7031c948..1986ea91 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,6 +49,25 @@ jobs: test: name: Test runs-on: ubuntu-latest + + services: + postgres: + image: postgres:16 + env: + POSTGRES_USER: torii + POSTGRES_PASSWORD: torii + POSTGRES_DB: torii + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U torii" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + DATABASE_URL: postgres://torii:torii@localhost:5432/torii + steps: - uses: actions/checkout@v4 diff --git a/Cargo.lock b/Cargo.lock index 16dbbecd..e9e3f967 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5038,6 +5038,7 @@ dependencies = [ "torii-dojo", "torii-introspect", "torii-runtime-common", + "torii-sqlite", "tracing", ] @@ -5403,7 +5404,9 @@ version = "0.1.0" dependencies = [ "async-trait", "futures", + "libsqlite3-sys", "sqlx", + "tokio", "torii-common", ] diff --git a/crates/introspect-postgres-sink/migrations/004_hex2int_function.sql b/crates/introspect-postgres-sink/migrations/004_hex2int_function.sql new file mode 100644 index 00000000..a1eca0a7 --- /dev/null +++ b/crates/introspect-postgres-sink/migrations/004_hex2int_function.sql @@ -0,0 +1,25 @@ +CREATE OR REPLACE FUNCTION public.hex2int(hex_text TEXT) RETURNS BIGINT AS $$ +DECLARE + stripped TEXT; +BEGIN + IF hex_text IS NULL THEN + RETURN NULL; + END IF; + + -- Strip a single leading 0x / 0X prefix (not all occurrences). + IF left(hex_text, 2) IN ('0x', '0X') THEN + stripped := substr(hex_text, 3); + ELSE + stripped := hex_text; + END IF; + + -- Take the rightmost 16 hex chars (lower 64 bits). Hex strings may be up + -- to 256 bits; values that fit in u64 are preserved exactly, larger values + -- are truncated to their low 64 bits. + IF length(stripped) > 16 THEN + stripped := right(stripped, 16); + END IF; + + RETURN ('x' || lpad(stripped, 16, '0'))::bit(64)::bigint; +END; +$$ LANGUAGE plpgsql IMMUTABLE STRICT; diff --git a/crates/introspect-postgres-sink/tests/hex2int.rs b/crates/introspect-postgres-sink/tests/hex2int.rs new file mode 100644 index 00000000..fd1e3cc8 --- /dev/null +++ b/crates/introspect-postgres-sink/tests/hex2int.rs @@ -0,0 +1,70 @@ +//! Integration test for the `hex2int` Postgres function shipped via migration +//! `004_hex2int_function.sql`. Requires a running Postgres reachable via +//! `DATABASE_URL`; skipped otherwise so local `cargo test` still passes. + +use sqlx::postgres::PgPoolOptions; +use sqlx::Row; +use torii_introspect_postgres_sink::INTROSPECT_PG_SINK_MIGRATIONS; +use torii_postgres::migration::SchemaMigrator; + +async fn get_pool() -> Option { + let url = std::env::var("DATABASE_URL").ok()?; + let pool = PgPoolOptions::new() + .max_connections(1) + .connect(&url) + .await + .expect("failed to connect to DATABASE_URL"); + SchemaMigrator::new("introspect", INTROSPECT_PG_SINK_MIGRATIONS) + .run(&pool) + .await + .expect("failed to run migrations"); + Some(pool) +} + +async fn hex2int(pool: &sqlx::PgPool, input: Option<&str>) -> Option { + let row = sqlx::query("SELECT hex2int($1) AS v") + .bind(input) + .fetch_one(pool) + .await + .unwrap(); + row.try_get::, _>("v").unwrap() +} + +#[tokio::test] +async fn test_hex2int_postgres() { + let Some(pool) = get_pool().await else { + eprintln!("DATABASE_URL not set; skipping hex2int Postgres test"); + return; + }; + + assert_eq!(hex2int(&pool, Some("0xff")).await, Some(255)); + assert_eq!(hex2int(&pool, Some("ff")).await, Some(255)); + assert_eq!(hex2int(&pool, Some("0x0")).await, Some(0)); + assert_eq!(hex2int(&pool, Some("0XAB")).await, Some(171)); + + // u64::MAX → -1 as i64 + assert_eq!(hex2int(&pool, Some("0xffffffffffffffff")).await, Some(-1)); + + // NULL passthrough + assert_eq!(hex2int(&pool, None).await, None); + + // 256-bit value zero-padded: only the lower 64 bits are kept + assert_eq!( + hex2int( + &pool, + Some("0x00000000000000000000000000000000000000000000000000000000000000ff"), + ) + .await, + Some(255), + ); + + // 256-bit value with non-zero high bits: high bits ignored, lower 64 bits returned + assert_eq!( + hex2int( + &pool, + Some("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef1234567890abcdef"), + ) + .await, + Some(0x1234567890abcdef_i64), + ); +} diff --git a/crates/sqlite/Cargo.toml b/crates/sqlite/Cargo.toml index 08f62825..e0a1b34a 100644 --- a/crates/sqlite/Cargo.toml +++ b/crates/sqlite/Cargo.toml @@ -7,6 +7,7 @@ description = "SQLite connection helpers for Torii storage crates" [dependencies] async-trait.workspace = true futures.workspace = true +libsqlite3-sys = { version = "0.30", features = ["bundled"] } sqlx = { workspace = true, features = [ "sqlite", "runtime-tokio-rustls", @@ -15,5 +16,8 @@ sqlx = { workspace = true, features = [ torii-common.workspace = true +[dev-dependencies] +tokio = { version = "1", features = ["rt", "macros"] } + [lints] workspace = true diff --git a/crates/sqlite/src/lib.rs b/crates/sqlite/src/lib.rs index e0e1dc65..7bbd03c7 100644 --- a/crates/sqlite/src/lib.rs +++ b/crates/sqlite/src/lib.rs @@ -1,4 +1,6 @@ pub mod db; pub mod migration; +pub mod udf; pub use db::{is_sqlite_memory_path, sqlite_connect_options, SqliteConnection}; +pub use udf::install_udfs; diff --git a/crates/sqlite/src/udf.rs b/crates/sqlite/src/udf.rs new file mode 100644 index 00000000..243fd9e2 --- /dev/null +++ b/crates/sqlite/src/udf.rs @@ -0,0 +1,194 @@ +use std::ffi::CString; +use std::os::raw::{c_char, c_int}; +use std::ptr; +use std::sync::Once; + +use libsqlite3_sys::{ + sqlite3, sqlite3_api_routines, sqlite3_auto_extension, sqlite3_create_function_v2, + sqlite3_result_error, sqlite3_result_int64, sqlite3_result_null, sqlite3_value, + sqlite3_value_text, sqlite3_value_type, SQLITE_OK, SQLITE_TEXT, SQLITE_UTF8, +}; + +static INIT: Once = Once::new(); + +/// Register all custom UDFs as auto-extensions. +/// +/// After calling this, every new SQLite connection in the process will +/// automatically have the UDFs available. Safe to call multiple times; +/// only the first call has any effect. +#[allow(unsafe_code)] +pub fn install_udfs() { + INIT.call_once(|| { + // SAFETY: sqlite3_auto_extension expects an extension entry point cast to + // Option c_int>. + // `udfs_init` matches this signature. + unsafe { + let rc = sqlite3_auto_extension(Some(udfs_init)); + assert_eq!(rc, SQLITE_OK, "failed to install SQLite UDF auto-extension"); + } + }); +} + +/// Auto-extension entry point called by SQLite for each new connection. +#[allow(unsafe_code)] +unsafe extern "C" fn udfs_init( + db: *mut sqlite3, + _pz_err_msg: *mut *mut c_char, + _p_thunk: *const sqlite3_api_routines, +) -> c_int { + register_hex2int(db); + SQLITE_OK +} + +/// Register the `hex2int` scalar function on a raw sqlite3 handle. +/// +/// `hex2int(hex_string)` converts a hex-encoded integer (with or without `0x` prefix) +/// to an i64. Returns NULL for NULL input. +#[allow(unsafe_code)] +unsafe fn register_hex2int(db: *mut sqlite3) { + let name = CString::new("hex2int").unwrap(); + let rc = sqlite3_create_function_v2( + db, + name.as_ptr(), + 1, // nArg + SQLITE_UTF8, // eTextRep + ptr::null_mut(), // pApp + Some(hex2int_fn), // xFunc + None, // xStep + None, // xFinal + None, // xDestroy + ); + assert_eq!(rc, SQLITE_OK, "failed to register hex2int UDF"); +} + +/// The C callback implementing hex2int. +#[allow(unsafe_code)] +unsafe extern "C" fn hex2int_fn( + ctx: *mut libsqlite3_sys::sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) { + debug_assert_eq!(argc, 1); + let val = *argv; + + // NULL in → NULL out + if sqlite3_value_type(val) == libsqlite3_sys::SQLITE_NULL { + sqlite3_result_null(ctx); + return; + } + + // Must be text + if sqlite3_value_type(val) != SQLITE_TEXT { + let msg = CString::new("hex2int: expected text argument").unwrap(); + sqlite3_result_error(ctx, msg.as_ptr(), -1); + return; + } + + let text_ptr = sqlite3_value_text(val); + if text_ptr.is_null() { + sqlite3_result_null(ctx); + return; + } + + let text = std::ffi::CStr::from_ptr(text_ptr.cast::()) + .to_str() + .unwrap_or(""); + let stripped = text + .strip_prefix("0x") + .or_else(|| text.strip_prefix("0X")) + .unwrap_or(text); + // Take the rightmost 16 hex chars (lower 64 bits). Hex strings may be up + // to 256 bits; values that fit in u64 are preserved exactly, larger values + // are truncated to their low 64 bits. + let lower = stripped + .get(stripped.len().saturating_sub(16)..) + .unwrap_or(stripped); + + if let Ok(n) = u64::from_str_radix(lower, 16) { + sqlite3_result_int64(ctx, n as i64); + } else { + let msg = CString::new(format!("hex2int: invalid hex string '{text}'")).unwrap(); + sqlite3_result_error(ctx, msg.as_ptr(), -1); + } +} + +#[cfg(test)] +mod tests { + use sqlx::sqlite::SqlitePoolOptions; + use sqlx::Row; + + use super::*; + + #[tokio::test] + async fn test_hex2int() { + install_udfs(); + + let pool = SqlitePoolOptions::new() + .max_connections(1) + .connect("sqlite::memory:") + .await + .unwrap(); + + // UDFs are automatically available — no per-connection registration needed + let row = sqlx::query("SELECT hex2int('0xff') AS v") + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), 255); + + // Without prefix + let row = sqlx::query("SELECT hex2int('ff') AS v") + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), 255); + + // Large value + let row = sqlx::query("SELECT hex2int('0xffffffffffffffff') AS v") + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), -1); // u64::MAX as i64 + + // NULL passthrough + let row = sqlx::query("SELECT hex2int(NULL) AS v") + .fetch_one(&pool) + .await + .unwrap(); + assert!(row.try_get::("v").is_err() || row.get::, _>("v").is_none()); + + // Zero + let row = sqlx::query("SELECT hex2int('0x0') AS v") + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), 0); + + // Uppercase prefix + let row = sqlx::query("SELECT hex2int('0XAB') AS v") + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), 171); + + // 256-bit value: only the lower 64 bits are kept + // 0x0000...0000_00000000000000ff → 255 + let row = sqlx::query( + "SELECT hex2int('0x00000000000000000000000000000000000000000000000000000000000000ff') AS v", + ) + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), 255); + + // 256-bit value where high bits are non-zero but ignored + // high 192 bits: 0xdead...; low 64 bits: 0x1234567890abcdef + let row = sqlx::query( + "SELECT hex2int('0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef1234567890abcdef') AS v", + ) + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.get::("v"), 0x1234567890abcdef_i64); + } +} diff --git a/crates/torii-ecs-sink/Cargo.toml b/crates/torii-ecs-sink/Cargo.toml index 3598c8c9..1edc7cf8 100644 --- a/crates/torii-ecs-sink/Cargo.toml +++ b/crates/torii-ecs-sink/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" torii = { path = "../.." } torii-dojo = { path = "../dojo" } torii-introspect = { path = "../introspect" } +torii-sqlite.workspace = true torii-runtime-common.workspace = true dojo-introspect.workspace = true introspect-types.workspace = true diff --git a/crates/torii-ecs-sink/src/grpc_service.rs b/crates/torii-ecs-sink/src/grpc_service.rs index 565828cf..80b92e4c 100644 --- a/crates/torii-ecs-sink/src/grpc_service.rs +++ b/crates/torii-ecs-sink/src/grpc_service.rs @@ -645,6 +645,7 @@ impl EcsService { erc1155_url: Option<&str>, ) -> Result { sqlx::any::install_default_drivers(); + torii_sqlite::install_udfs(); let backend = DbBackend::detect(database_url); let database_url = match backend {