From a9933f095bed2ec0a200d6183b08832a9f9cec46 Mon Sep 17 00:00:00 2001 From: humayrakhanom Date: Mon, 16 Feb 2026 02:11:32 -0600 Subject: [PATCH 1/2] [Feature] Add core memory types, errors, and vector index Introduce the memory layer foundation with MemoryId, Memory, MemoryScope, MemoryConfig types, an in-memory vector index for semantic search, the Memory error variant in GraphBitError, and rusqlite workspace dependency. Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 58 +++++++- Cargo.toml | 1 + core/Cargo.toml | 1 + core/src/errors.rs | 22 +++ core/src/lib.rs | 1 + core/src/memory/mod.rs | 12 ++ core/src/memory/types.rs | 282 ++++++++++++++++++++++++++++++++++++++ core/src/memory/vector.rs | 193 ++++++++++++++++++++++++++ 8 files changed, 569 insertions(+), 1 deletion(-) create mode 100644 core/src/memory/mod.rs create mode 100644 core/src/memory/types.rs create mode 100644 core/src/memory/vector.rs diff --git a/Cargo.lock b/Cargo.lock index e574f6d4..6d37e03e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -528,6 +528,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -801,6 +813,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest", + "rusqlite", "scraper", "serde", "serde_json", @@ -850,12 +863,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.5" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -1107,7 +1138,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.5", ] [[package]] @@ -1207,6 +1238,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2066,6 +2108,20 @@ dependencies = [ "winreg", ] +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags 2.9.1", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-demangle" version = "0.1.26" diff --git a/Cargo.toml b/Cargo.toml index b6714ed9..9c0dabb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,6 +131,7 @@ quick-xml = {version = "0.36", features = ["serialize"]} rand = "0.8" regex = "1.10" reqwest = {version = "0.11", features = ["json", "stream"]} +rusqlite = { version = "0.31", features = ["bundled"] } # HTML parsing scraper = "0.20" serde = {version = "1.0", features = ["derive"]} diff --git a/core/Cargo.toml b/core/Cargo.toml index b516aa3b..c37cb915 100644 --- a/core/Cargo.toml 
+++ b/core/Cargo.toml @@ -15,6 +15,7 @@ quick-xml.workspace = true rand.workspace = true regex.workspace = true reqwest.workspace = true +rusqlite.workspace = true scraper.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/core/src/errors.rs b/core/src/errors.rs index ed250557..08878c98 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -126,6 +126,13 @@ pub enum GraphBitError { /// Error message message: String, }, + + /// Memory layer errors + #[error("Memory error: {message}")] + Memory { + /// Error message + message: String, + }, } impl GraphBitError { @@ -211,6 +218,13 @@ impl GraphBitError { } } + /// Create a new memory layer error + pub fn memory(message: impl Into) -> Self { + Self::Memory { + message: message.into(), + } + } + /// Check if the error is retryable pub fn is_retryable(&self) -> bool { matches!( @@ -269,3 +283,11 @@ impl From for GraphBitError { } } } + +impl From for GraphBitError { + fn from(error: rusqlite::Error) -> Self { + Self::Internal { + message: format!("SQLite error: {error}"), + } + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index 4c20d0b0..844fe209 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -36,6 +36,7 @@ pub mod embeddings; pub mod errors; pub mod graph; pub mod llm; +pub mod memory; pub mod text_splitter; pub mod types; pub mod validation; diff --git a/core/src/memory/mod.rs b/core/src/memory/mod.rs new file mode 100644 index 00000000..4b7142a2 --- /dev/null +++ b/core/src/memory/mod.rs @@ -0,0 +1,12 @@ +//! Memory layer for `GraphBit` +//! +//! Provides LLM-driven fact extraction from conversations, vector-based semantic +//! search, SQLite-backed persistent storage, and scoped memory isolation. 
+ +pub mod types; +pub mod vector; + +pub use types::{ + Memory, MemoryAction, MemoryConfig, MemoryDecision, MemoryHistory, MemoryId, MemoryScope, + ScoredMemory, +}; diff --git a/core/src/memory/types.rs b/core/src/memory/types.rs new file mode 100644 index 00000000..07382303 --- /dev/null +++ b/core/src/memory/types.rs @@ -0,0 +1,282 @@ +//! Core data types for the memory layer. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; +use uuid::Uuid; + +use crate::embeddings::EmbeddingConfig; +use crate::llm::LlmConfig; + +/// Unique identifier for memories, following the `AgentId` pattern. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct MemoryId(pub Uuid); + +impl MemoryId { + /// Create a new random memory ID. + #[inline] + pub fn new() -> Self { + Self(Uuid::new_v4()) + } + + /// Parse a memory ID from a string. + pub fn from_string(s: &str) -> Result { + Ok(Self(Uuid::parse_str(s)?)) + } + + /// Get the underlying UUID. + #[inline] + pub fn as_uuid(&self) -> &Uuid { + &self.0 + } +} + +impl Default for MemoryId { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl fmt::Display for MemoryId { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Scoping information for memory isolation. +/// +/// Memories can be scoped to a specific user, agent, or run. All fields are +/// optional; omitted fields act as wildcards when filtering. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct MemoryScope { + /// User-level scope. + pub user_id: Option, + /// Agent-level scope. + pub agent_id: Option, + /// Run-level scope. + pub run_id: Option, +} + +/// A single stored memory fact. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Memory { + /// Unique identifier. + pub id: MemoryId, + /// The fact / content of the memory. + pub content: String, + /// Scoping information. 
+ pub scope: MemoryScope, + /// Arbitrary key-value metadata. + pub metadata: HashMap, + /// Creation timestamp. + pub created_at: DateTime, + /// Last update timestamp. + pub updated_at: DateTime, + /// Content hash for deduplication. + pub hash: String, +} + +/// A memory paired with its similarity score from a search. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScoredMemory { + /// The memory. + pub memory: Memory, + /// Cosine similarity score (0.0 .. 1.0). + pub score: f64, +} + +/// A historical record of a memory mutation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryHistory { + /// Which memory was affected. + pub memory_id: MemoryId, + /// Content before the change (empty for `Add`). + pub old_content: String, + /// Content after the change (empty for `Delete`). + pub new_content: String, + /// What kind of mutation. + pub action: MemoryAction, + /// When the mutation occurred. + pub timestamp: DateTime, +} + +/// Describes the kind of memory mutation. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MemoryAction { + /// A brand-new fact was added. + Add, + /// An existing fact was updated / refined. + Update, + /// A fact was removed. + Delete, + /// No change required (duplicate / irrelevant). + Noop, +} + +impl fmt::Display for MemoryAction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Add => write!(f, "ADD"), + Self::Update => write!(f, "UPDATE"), + Self::Delete => write!(f, "DELETE"), + Self::Noop => write!(f, "NOOP"), + } + } +} + +impl MemoryAction { + /// Parse an action from its string representation. + pub fn from_str_lossy(s: &str) -> Self { + match s.to_uppercase().as_str() { + "ADD" => Self::Add, + "UPDATE" => Self::Update, + "DELETE" => Self::Delete, + _ => Self::Noop, + } + } +} + +/// An LLM decision about what to do with an extracted fact. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryDecision { + /// The extracted fact text. + pub fact: String, + /// The decided action. + pub action: MemoryAction, + /// If updating or deleting, the target memory's ID (as a string). + pub target_memory_id: Option, +} + +/// Configuration for the memory subsystem. +#[derive(Debug, Clone)] +pub struct MemoryConfig { + /// LLM provider configuration for fact extraction. + pub llm_config: LlmConfig, + /// Embedding provider configuration for vector search. + pub embedding_config: EmbeddingConfig, + /// Path to the SQLite database file (`:memory:` for in-memory). + pub db_path: String, + /// Minimum cosine-similarity threshold for search results (0.0 .. 1.0). + pub similarity_threshold: f64, + /// Maximum tokens for the extraction LLM call. + pub max_extraction_tokens: u32, + /// Temperature for the extraction LLM call. + pub extraction_temperature: f32, +} + +impl MemoryConfig { + /// Create a new config with sensible defaults. + pub fn new(llm_config: LlmConfig, embedding_config: EmbeddingConfig) -> Self { + Self { + llm_config, + embedding_config, + db_path: "graphbit_memory.db".to_string(), + similarity_threshold: 0.7, + max_extraction_tokens: 1500, + extraction_temperature: 0.1, + } + } + + /// Override the database path. + pub fn with_db_path(mut self, db_path: impl Into) -> Self { + self.db_path = db_path.into(); + self + } + + /// Override the similarity threshold. + pub fn with_similarity_threshold(mut self, threshold: f64) -> Self { + self.similarity_threshold = threshold.clamp(0.0, 1.0); + self + } + + /// Override the max extraction tokens. + pub fn with_max_extraction_tokens(mut self, tokens: u32) -> Self { + self.max_extraction_tokens = tokens; + self + } + + /// Override the extraction temperature. 
+ pub fn with_extraction_temperature(mut self, temperature: f32) -> Self { + self.extraction_temperature = temperature.clamp(0.0, 1.0); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::embeddings::EmbeddingConfig; + use crate::llm::LlmConfig; + + #[test] + fn test_memory_id_creation() { + let id1 = MemoryId::new(); + let id2 = MemoryId::new(); + assert_ne!(id1, id2, "Two fresh IDs should be unique"); + + let id_str = id1.to_string(); + let parsed = MemoryId::from_string(&id_str).expect("should parse valid UUID"); + assert_eq!(id1, parsed); + } + + #[test] + fn test_memory_id_from_invalid_string() { + let result = MemoryId::from_string("not-a-uuid"); + assert!(result.is_err()); + } + + #[test] + fn test_memory_scope_default() { + let scope = MemoryScope::default(); + assert!(scope.user_id.is_none()); + assert!(scope.agent_id.is_none()); + assert!(scope.run_id.is_none()); + } + + #[test] + fn test_memory_action_display() { + assert_eq!(MemoryAction::Add.to_string(), "ADD"); + assert_eq!(MemoryAction::Update.to_string(), "UPDATE"); + assert_eq!(MemoryAction::Delete.to_string(), "DELETE"); + assert_eq!(MemoryAction::Noop.to_string(), "NOOP"); + } + + #[test] + fn test_memory_action_from_str_lossy() { + assert_eq!(MemoryAction::from_str_lossy("ADD"), MemoryAction::Add); + assert_eq!(MemoryAction::from_str_lossy("add"), MemoryAction::Add); + assert_eq!(MemoryAction::from_str_lossy("UPDATE"), MemoryAction::Update); + assert_eq!(MemoryAction::from_str_lossy("DELETE"), MemoryAction::Delete); + assert_eq!(MemoryAction::from_str_lossy("unknown"), MemoryAction::Noop); + } + + #[test] + fn test_memory_config_builder() { + let llm_config = LlmConfig::openai("test-key", "gpt-4o-mini"); + let embedding_config = EmbeddingConfig { + provider: crate::embeddings::EmbeddingProvider::OpenAI, + api_key: "test-key".to_string(), + model: "text-embedding-3-small".to_string(), + base_url: None, + timeout_seconds: None, + max_batch_size: None, + extra_params: 
HashMap::new(), + #[cfg(feature = "python")] + python_instance: None, + }; + + let config = MemoryConfig::new(llm_config, embedding_config) + .with_db_path(":memory:") + .with_similarity_threshold(0.8) + .with_max_extraction_tokens(2000) + .with_extraction_temperature(0.2); + + assert_eq!(config.db_path, ":memory:"); + assert!((config.similarity_threshold - 0.8).abs() < f64::EPSILON); + assert_eq!(config.max_extraction_tokens, 2000); + assert!((config.extraction_temperature - 0.2).abs() < f32::EPSILON); + } +} diff --git a/core/src/memory/vector.rs b/core/src/memory/vector.rs new file mode 100644 index 00000000..248af6ee --- /dev/null +++ b/core/src/memory/vector.rs @@ -0,0 +1,193 @@ +//! In-memory vector index for semantic search over memories. + +use tokio::sync::RwLock; + +use crate::embeddings::EmbeddingService; +use crate::errors::GraphBitResult; + +use super::types::MemoryId; + +/// A single entry in the vector index. +#[derive(Debug, Clone)] +struct VectorEntry { + memory_id: MemoryId, + embedding: Vec, +} + +/// In-memory vector index backed by brute-force cosine similarity. +/// +/// Suitable for moderate memory counts (thousands). For larger datasets a +/// purpose-built ANN index should replace this implementation. +pub struct VectorIndex { + entries: RwLock>, +} + +impl VectorIndex { + /// Create a new, empty vector index. + pub fn new() -> Self { + Self { + entries: RwLock::new(Vec::new()), + } + } + + /// Insert an embedding for the given memory. + pub async fn insert(&self, memory_id: MemoryId, embedding: Vec) { + let mut entries = self.entries.write().await; + entries.push(VectorEntry { + memory_id, + embedding, + }); + } + + /// Search for the `top_k` most similar entries to `query_embedding`, + /// returning `(MemoryId, similarity_score)` pairs above `threshold`. 
+ pub async fn search( + &self, + query_embedding: &[f32], + top_k: usize, + threshold: f64, + ) -> GraphBitResult> { + let entries = self.entries.read().await; + + let mut scored: Vec<(MemoryId, f64)> = entries + .iter() + .filter_map(|entry| { + let sim = EmbeddingService::cosine_similarity(query_embedding, &entry.embedding) + .ok()?; + let sim_f64 = f64::from(sim); + if sim_f64 >= threshold { + Some((entry.memory_id.clone(), sim_f64)) + } else { + None + } + }) + .collect(); + + // Sort descending by score. + scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored.truncate(top_k); + + Ok(scored) + } + + /// Remove entries for a specific memory. + pub async fn remove(&self, memory_id: &MemoryId) { + let mut entries = self.entries.write().await; + entries.retain(|e| &e.memory_id != memory_id); + } + + /// Replace the embedding for an existing memory. + pub async fn update(&self, memory_id: &MemoryId, embedding: Vec) { + let mut entries = self.entries.write().await; + if let Some(entry) = entries.iter_mut().find(|e| &e.memory_id == memory_id) { + entry.embedding = embedding; + } else { + entries.push(VectorEntry { + memory_id: memory_id.clone(), + embedding, + }); + } + } + + /// Remove all entries from the index. 
+ pub async fn clear(&self) { + let mut entries = self.entries.write().await; + entries.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_vector_index_operations() { + let index = VectorIndex::new(); + + let id1 = MemoryId::new(); + let id2 = MemoryId::new(); + + // Insert two vectors + index.insert(id1.clone(), vec![1.0, 0.0, 0.0]).await; + index.insert(id2.clone(), vec![0.0, 1.0, 0.0]).await; + + // Search with a vector close to id1 + let results = index + .search(&[0.9, 0.1, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert_eq!(results.len(), 2); + // The first result should be closer to id1 + assert_eq!(results[0].0, id1); + assert!(results[0].1 > results[1].1); + + // Remove id1 + index.remove(&id1).await; + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, id2); + } + + #[tokio::test] + async fn test_vector_index_threshold() { + let index = VectorIndex::new(); + + let id1 = MemoryId::new(); + index.insert(id1.clone(), vec![1.0, 0.0, 0.0]).await; + + // Orthogonal vector should have ~0 similarity + let results = index + .search(&[0.0, 1.0, 0.0], 10, 0.5) + .await + .expect("search ok"); + assert!( + results.is_empty(), + "Orthogonal vector should be below threshold 0.5" + ); + + // Identical vector should have similarity 1.0 + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.99) + .await + .expect("search ok"); + assert_eq!(results.len(), 1); + assert!((results[0].1 - 1.0).abs() < 0.01); + } + + #[tokio::test] + async fn test_vector_index_update() { + let index = VectorIndex::new(); + let id = MemoryId::new(); + + index.insert(id.clone(), vec![1.0, 0.0, 0.0]).await; + + // Update embedding + index.update(&id, vec![0.0, 1.0, 0.0]).await; + + // Now id should be similar to [0, 1, 0] rather than [1, 0, 0] + let results = index + .search(&[0.0, 1.0, 0.0], 10, 0.5) + .await + .expect("search ok"); + 
assert_eq!(results.len(), 1); + assert!((results[0].1 - 1.0).abs() < 0.01); + } + + #[tokio::test] + async fn test_vector_index_clear() { + let index = VectorIndex::new(); + index.insert(MemoryId::new(), vec![1.0, 0.0]).await; + index.insert(MemoryId::new(), vec![0.0, 1.0]).await; + + index.clear().await; + + let results = index + .search(&[1.0, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert!(results.is_empty()); + } +} From 99f372fbb4631fa04ce9d16ad6b728fa59b76f70 Mon Sep 17 00:00:00 2001 From: humayrakhanom Date: Mon, 16 Feb 2026 02:36:15 -0600 Subject: [PATCH 2/2] [Feature] Add SQLite metadata store for memory persistence Implement MetadataStore with CRUD operations, scope-based filtering, and memory mutation history tracking backed by SQLite. Co-Authored-By: Claude Opus 4.6 --- core/src/memory/mod.rs | 1 + core/src/memory/store.rs | 524 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 525 insertions(+) create mode 100644 core/src/memory/store.rs diff --git a/core/src/memory/mod.rs b/core/src/memory/mod.rs index 4b7142a2..bfc3de14 100644 --- a/core/src/memory/mod.rs +++ b/core/src/memory/mod.rs @@ -3,6 +3,7 @@ //! Provides LLM-driven fact extraction from conversations, vector-based semantic //! search, SQLite-backed persistent storage, and scoped memory isolation. +pub mod store; pub mod types; pub mod vector; diff --git a/core/src/memory/store.rs b/core/src/memory/store.rs new file mode 100644 index 00000000..daaed8d0 --- /dev/null +++ b/core/src/memory/store.rs @@ -0,0 +1,524 @@ +//! SQLite-backed metadata store for persistent memory storage. + +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::errors::{GraphBitError, GraphBitResult}; + +use super::types::{Memory, MemoryAction, MemoryHistory, MemoryId, MemoryScope}; + +/// Persistent metadata store backed by SQLite. 
+/// +/// The connection is wrapped in `Arc` so that it can be shared across +/// async tasks and moved into `spawn_blocking` closures. +pub struct MetadataStore { + conn: Arc>, +} + +impl MetadataStore { + /// Open (or create) the database at `db_path`. + /// Pass `":memory:"` for an in-memory database. + pub fn new(db_path: &str) -> GraphBitResult { + let conn = rusqlite::Connection::open(db_path)?; + // Enable foreign key enforcement (must be set per-connection). + conn.execute("PRAGMA foreign_keys = ON", [])?; + let store = Self { + conn: Arc::new(Mutex::new(conn)), + }; + store.init_schema_sync()?; + Ok(store) + } + + /// Create the required tables if they do not already exist. + fn init_schema_sync(&self) -> GraphBitResult<()> { + let conn = self + .conn + .try_lock() + .map_err(|_| GraphBitError::memory("Failed to acquire database lock during init"))?; + + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + user_id TEXT, + agent_id TEXT, + run_id TEXT, + hash TEXT NOT NULL, + metadata TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS memory_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + old_content TEXT NOT NULL DEFAULT '', + new_content TEXT NOT NULL DEFAULT '', + action TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY (memory_id) REFERENCES memories(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id); + CREATE INDEX IF NOT EXISTS idx_memories_agent_id ON memories(agent_id); + CREATE INDEX IF NOT EXISTS idx_memories_run_id ON memories(run_id); + CREATE INDEX IF NOT EXISTS idx_memories_hash ON memories(hash); + CREATE INDEX IF NOT EXISTS idx_history_memory_id ON memory_history(memory_id);", + )?; + + Ok(()) + } + + /// Insert a new memory. 
+ pub async fn insert_memory(&self, memory: &Memory) -> GraphBitResult<()> { + let id = memory.id.to_string(); + let content = memory.content.clone(); + let user_id = memory.scope.user_id.clone(); + let agent_id = memory.scope.agent_id.clone(); + let run_id = memory.scope.run_id.clone(); + let hash = memory.hash.clone(); + let metadata = serde_json::to_string(&memory.metadata)?; + let created_at = memory.created_at.to_rfc3339(); + let updated_at = memory.updated_at.to_rfc3339(); + + let conn_arc = Arc::clone(&self.conn); + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + conn.execute( + "INSERT INTO memories (id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + rusqlite::params![id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at], + )?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Get a single memory by ID. + pub async fn get_memory(&self, memory_id: &MemoryId) -> GraphBitResult> { + let id = memory_id.to_string(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult> { + let conn = conn_arc.blocking_lock(); + let mut stmt = + conn.prepare("SELECT id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at FROM memories WHERE id = ?1")?; + let mut rows = stmt.query(rusqlite::params![id])?; + if let Some(row) = rows.next()? { + Ok(Some(row_to_memory(row)?)) + } else { + Ok(None) + } + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Get all memories matching the given scope. 
+ pub async fn get_all_memories(&self, scope: &MemoryScope) -> GraphBitResult> { + let user_id = scope.user_id.clone(); + let agent_id = scope.agent_id.clone(); + let run_id = scope.run_id.clone(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult> { + let conn = conn_arc.blocking_lock(); + let (where_clause, params) = build_scope_filter(&user_id, &agent_id, &run_id); + let sql = format!( + "SELECT id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at FROM memories{}", + where_clause + ); + let mut stmt = conn.prepare(&sql)?; + let param_refs: Vec<&dyn rusqlite::types::ToSql> = + params.iter().map(|p| p as &dyn rusqlite::types::ToSql).collect(); + let mut rows = stmt.query(param_refs.as_slice())?; + let mut memories = Vec::new(); + while let Some(row) = rows.next()? { + memories.push(row_to_memory(row)?); + } + Ok(memories) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Update content and metadata of an existing memory. + pub async fn update_memory( + &self, + memory_id: &MemoryId, + content: &str, + hash: &str, + ) -> GraphBitResult<()> { + let id = memory_id.to_string(); + let content = content.to_string(); + let hash = hash.to_string(); + let updated_at = Utc::now().to_rfc3339(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + let changed = conn.execute( + "UPDATE memories SET content = ?1, hash = ?2, updated_at = ?3 WHERE id = ?4", + rusqlite::params![content, hash, updated_at, id], + )?; + if changed == 0 { + return Err(GraphBitError::memory(format!( + "Memory not found: {id}" + ))); + } + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Delete a single memory by ID. 
+ pub async fn delete_memory(&self, memory_id: &MemoryId) -> GraphBitResult<()> { + let id = memory_id.to_string(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + conn.execute("DELETE FROM memories WHERE id = ?1", rusqlite::params![id])?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Delete all memories matching the given scope. + pub async fn delete_all_memories(&self, scope: &MemoryScope) -> GraphBitResult<()> { + let user_id = scope.user_id.clone(); + let agent_id = scope.agent_id.clone(); + let run_id = scope.run_id.clone(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + let (where_clause, params) = build_scope_filter(&user_id, &agent_id, &run_id); + let sql = format!("DELETE FROM memories{where_clause}"); + let param_refs: Vec<&dyn rusqlite::types::ToSql> = + params.iter().map(|p| p as &dyn rusqlite::types::ToSql).collect(); + conn.execute(&sql, param_refs.as_slice())?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Record a history entry for a memory mutation. 
+ pub async fn insert_history(&self, history: &MemoryHistory) -> GraphBitResult<()> { + let memory_id = history.memory_id.to_string(); + let old_content = history.old_content.clone(); + let new_content = history.new_content.clone(); + let action = history.action.to_string(); + let timestamp = history.timestamp.to_rfc3339(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + conn.execute( + "INSERT INTO memory_history (memory_id, old_content, new_content, action, timestamp) + VALUES (?1, ?2, ?3, ?4, ?5)", + rusqlite::params![memory_id, old_content, new_content, action, timestamp], + )?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Get the full history for a specific memory. + pub async fn get_history(&self, memory_id: &MemoryId) -> GraphBitResult> { + let id = memory_id.to_string(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult> { + let conn = conn_arc.blocking_lock(); + let mut stmt = conn.prepare( + "SELECT memory_id, old_content, new_content, action, timestamp + FROM memory_history WHERE memory_id = ?1 ORDER BY timestamp ASC", + )?; + let mut rows = stmt.query(rusqlite::params![id])?; + let mut entries = Vec::new(); + while let Some(row) = rows.next()? { + entries.push(row_to_history(row)?); + } + Ok(entries) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? 
+ } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn row_to_memory(row: &rusqlite::Row<'_>) -> GraphBitResult { + let id_str: String = row.get(0)?; + let content: String = row.get(1)?; + let user_id: Option = row.get(2)?; + let agent_id: Option = row.get(3)?; + let run_id: Option = row.get(4)?; + let hash: String = row.get(5)?; + let metadata_json: String = row.get(6)?; + let created_at_str: String = row.get(7)?; + let updated_at_str: String = row.get(8)?; + + let id = MemoryId(Uuid::parse_str(&id_str).map_err(|e| { + GraphBitError::memory(format!("Invalid UUID in database: {e}")) + })?); + + let metadata: HashMap = + serde_json::from_str(&metadata_json).unwrap_or_default(); + + let created_at = DateTime::parse_from_rfc3339(&created_at_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()); + + let updated_at = DateTime::parse_from_rfc3339(&updated_at_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()); + + Ok(Memory { + id, + content, + scope: MemoryScope { + user_id, + agent_id, + run_id, + }, + metadata, + created_at, + updated_at, + hash, + }) +} + +fn row_to_history(row: &rusqlite::Row<'_>) -> GraphBitResult { + let memory_id_str: String = row.get(0)?; + let old_content: String = row.get(1)?; + let new_content: String = row.get(2)?; + let action_str: String = row.get(3)?; + let timestamp_str: String = row.get(4)?; + + let memory_id = MemoryId(Uuid::parse_str(&memory_id_str).map_err(|e| { + GraphBitError::memory(format!("Invalid UUID in history: {e}")) + })?); + + let timestamp = DateTime::parse_from_rfc3339(×tamp_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()); + + Ok(MemoryHistory { + memory_id, + old_content, + new_content, + action: MemoryAction::from_str_lossy(&action_str), + timestamp, + }) +} + +/// Build a SQL WHERE clause + params from optional scope 
/// Build a SQL `WHERE` clause and its positional parameters from the
/// optional scope fields.
///
/// Each field that is `Some` contributes a `column = ?N` condition, where `N`
/// is the 1-based position of its value in the returned parameter list;
/// conditions are joined with `AND`. When every field is `None` the clause is
/// an empty string. A non-empty clause starts with a leading space so it can
/// be appended directly to a statement.
fn build_scope_filter(
    user_id: &Option<String>,
    agent_id: &Option<String>,
    run_id: &Option<String>,
) -> (String, Vec<String>) {
    let columns = [
        ("user_id", user_id),
        ("agent_id", agent_id),
        ("run_id", run_id),
    ];

    let mut params: Vec<String> = Vec::new();
    let mut conditions: Vec<String> = Vec::new();
    for (column, value) in columns {
        if let Some(bound) = value {
            params.push(bound.clone());
            conditions.push(format!("{column} = ?{}", params.len()));
        }
    }

    let where_clause = if conditions.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", conditions.join(" AND "))
    };
    (where_clause, params)
}
.expect("get should succeed"); + assert!(gone.is_none()); + } + + #[tokio::test] + async fn test_metadata_store_scope_filtering() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + // Insert memories for different users + for (user, content) in &[("alice", "Fact A"), ("bob", "Fact B"), ("alice", "Fact C")] { + let memory = Memory { + id: MemoryId::new(), + content: content.to_string(), + scope: MemoryScope { + user_id: Some(user.to_string()), + agent_id: None, + run_id: None, + }, + metadata: HashMap::new(), + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + hash: format!("hash_{content}"), + }; + store.insert_memory(&memory).await.expect("insert ok"); + } + + // Filter by Alice + let alice_scope = MemoryScope { + user_id: Some("alice".to_string()), + ..Default::default() + }; + let alice_memories = store + .get_all_memories(&alice_scope) + .await + .expect("get_all ok"); + assert_eq!(alice_memories.len(), 2); + + // Filter by Bob + let bob_scope = MemoryScope { + user_id: Some("bob".to_string()), + ..Default::default() + }; + let bob_memories = store + .get_all_memories(&bob_scope) + .await + .expect("get_all ok"); + assert_eq!(bob_memories.len(), 1); + + // No filter (all) + let all = store + .get_all_memories(&MemoryScope::default()) + .await + .expect("get_all ok"); + assert_eq!(all.len(), 3); + + // Delete all for Alice + store + .delete_all_memories(&alice_scope) + .await + .expect("delete_all ok"); + let remaining = store + .get_all_memories(&MemoryScope::default()) + .await + .expect("get_all ok"); + assert_eq!(remaining.len(), 1); + assert_eq!(remaining[0].scope.user_id.as_deref(), Some("bob")); + } + + #[tokio::test] + async fn test_metadata_store_history() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + let id = MemoryId::new(); + + // Insert a parent memory so the FK constraint is satisfied. 
+ let memory = Memory { + id: id.clone(), + content: "Initial".to_string(), + scope: MemoryScope::default(), + metadata: HashMap::new(), + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + hash: "h".to_string(), + }; + store.insert_memory(&memory).await.expect("insert ok"); + + store + .insert_history(&MemoryHistory { + memory_id: id.clone(), + old_content: String::new(), + new_content: "First version".to_string(), + action: MemoryAction::Add, + timestamp: chrono::Utc::now(), + }) + .await + .expect("insert_history ok"); + + store + .insert_history(&MemoryHistory { + memory_id: id.clone(), + old_content: "First version".to_string(), + new_content: "Second version".to_string(), + action: MemoryAction::Update, + timestamp: chrono::Utc::now(), + }) + .await + .expect("insert_history ok"); + + let history = store.get_history(&id).await.expect("get_history ok"); + assert_eq!(history.len(), 2); + assert_eq!(history[0].action, MemoryAction::Add); + assert_eq!(history[1].action, MemoryAction::Update); + } +}