diff --git a/Cargo.lock b/Cargo.lock index e574f6d4..6d37e03e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -528,6 +528,18 @@ dependencies = [ "num-traits", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -801,6 +813,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest", + "rusqlite", "scraper", "serde", "serde_json", @@ -850,12 +863,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.5.0" @@ -1107,7 +1138,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.5", ] [[package]] @@ -1207,6 +1238,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2066,6 +2108,20 @@ dependencies = [ "winreg", ] +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags 2.9.1", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-demangle" version = "0.1.26" diff --git a/Cargo.toml b/Cargo.toml index b6714ed9..9c0dabb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,6 +131,7 @@ quick-xml = {version = "0.36", features = ["serialize"]} rand = "0.8" regex = "1.10" reqwest = {version = "0.11", features = ["json", "stream"]} +rusqlite = { version = "0.31", features = ["bundled"] } # HTML parsing scraper = "0.20" serde = {version = "1.0", features = ["derive"]} diff --git a/core/Cargo.toml b/core/Cargo.toml index b516aa3b..c37cb915 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -15,6 +15,7 @@ quick-xml.workspace = true rand.workspace = true regex.workspace = true reqwest.workspace = true +rusqlite.workspace = true scraper.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/core/src/errors.rs b/core/src/errors.rs index ed250557..08878c98 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -126,6 +126,13 @@ pub enum GraphBitError { /// Error message message: String, }, + + /// Memory layer errors + #[error("Memory error: {message}")] + Memory { + /// Error message + message: String, + }, } impl GraphBitError { @@ -211,6 +218,13 @@ impl GraphBitError { } } + /// Create a new memory layer error + pub fn memory(message: impl Into) -> Self { + Self::Memory { + message: message.into(), + } + } + /// Check if the error is retryable pub fn is_retryable(&self) -> bool { matches!( @@ -269,3 +283,11 @@ impl From for GraphBitError { } } } + +impl From for GraphBitError { + fn from(error: rusqlite::Error) -> Self { + Self::Internal { + message: format!("SQLite error: {error}"), + } + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index 4c20d0b0..890bbfdc 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -36,6 +36,7 @@ pub mod embeddings; pub mod errors; pub mod graph; pub mod llm; +pub mod memory; pub mod text_splitter; pub mod types; pub mod validation; @@ -50,6 +51,9 @@ pub use embeddings::{ pub use errors::{GraphBitError, GraphBitResult}; pub use graph::{NodeType, WorkflowEdge, WorkflowGraph, WorkflowNode}; pub use llm::{LlmConfig, LlmProvider, LlmResponse}; +pub use memory::{ + Memory, MemoryConfig, MemoryHistory, MemoryId, MemoryScope, MemoryService, ScoredMemory, +}; pub use text_splitter::{ CharacterSplitter, RecursiveSplitter, SentenceSplitter, SplitterStrategy, TextChunk, TextSplitterConfig, TextSplitterFactory, TextSplitterTrait, TokenSplitter, diff --git a/core/src/memory/mod.rs b/core/src/memory/mod.rs new file mode 100644 index 00000000..d909183c --- /dev/null +++ b/core/src/memory/mod.rs @@ -0,0 +1,16 @@ +//! Memory layer for `GraphBit` +//! +//! Provides LLM-driven fact extraction from conversations, vector-based semantic +//! search, SQLite-backed persistent storage, and scoped memory isolation. + +pub mod processor; +pub mod service; +pub mod store; +pub mod types; +pub mod vector; + +pub use service::MemoryService; +pub use types::{ + Memory, MemoryAction, MemoryConfig, MemoryDecision, MemoryHistory, MemoryId, MemoryScope, + ScoredMemory, +}; diff --git a/core/src/memory/processor.rs b/core/src/memory/processor.rs new file mode 100644 index 00000000..adc9759f --- /dev/null +++ b/core/src/memory/processor.rs @@ -0,0 +1,202 @@ +//! LLM-driven fact extraction and consolidation logic. + +use crate::errors::{GraphBitError, GraphBitResult}; +use crate::llm::{LlmMessage, LlmProviderTrait, LlmRequest}; + +use super::types::{Memory, MemoryAction, MemoryDecision}; + +/// Handles sending conversations to an LLM for fact extraction and +/// deciding how new facts relate to existing memories. +pub struct MemoryProcessor { + llm_provider: Box, + max_tokens: u32, + temperature: f32, +} + +impl MemoryProcessor { + /// Create a new processor wrapping the given LLM provider. + pub fn new( + llm_provider: Box, + max_tokens: u32, + temperature: f32, + ) -> Self { + Self { + llm_provider, + max_tokens, + temperature, + } + } + + /// Extract discrete facts from a list of conversation messages. + /// + /// Returns a `Vec` of facts parsed from the LLM's JSON response. + pub async fn extract_facts(&self, messages: &[LlmMessage]) -> GraphBitResult> { + if messages.is_empty() { + return Ok(Vec::new()); + } + + let conversation = messages + .iter() + .map(|m| format!("{}: {}", role_label(&m.role), &m.content)) + .collect::>() + .join("\n"); + + let system_prompt = concat!( + "You are a memory extraction assistant. Your task is to extract important facts, ", + "preferences, and information from the conversation that would be useful to remember ", + "for future interactions.\n\n", + "Rules:\n", + "- Extract only factual, specific information (not greetings or filler).\n", + "- Each fact should be a single, self-contained sentence.\n", + "- Do not duplicate facts.\n", + "- If no meaningful facts exist, return an empty array.\n\n", + "Return a JSON array of strings. Example: [\"User lives in Munich\", \"User prefers dark mode\"]", + ); + + let request = LlmRequest::with_messages(vec![ + LlmMessage::system(system_prompt), + LlmMessage::user(format!("Extract facts from this conversation:\n\n{conversation}")), + ]) + .with_max_tokens(self.max_tokens) + .with_temperature(self.temperature); + + let response = self.llm_provider.complete(request).await.map_err(|e| { + GraphBitError::memory(format!("Fact extraction LLM call failed: {e}")) + })?; + + parse_json_string_array(&response.content) + } + + /// Given extracted facts and existing memories, ask the LLM to decide + /// whether each fact should be added, used to update an existing memory, + /// delete an existing memory, or be ignored. + pub async fn decide_actions( + &self, + facts: &[String], + existing_memories: &[Memory], + ) -> GraphBitResult> { + if facts.is_empty() { + return Ok(Vec::new()); + } + + let facts_list = facts + .iter() + .enumerate() + .map(|(i, f)| format!("{}. {f}", i + 1)) + .collect::>() + .join("\n"); + + let memories_list = if existing_memories.is_empty() { + "No existing memories.".to_string() + } else { + existing_memories + .iter() + .map(|m| format!("ID: {} | Content: {}", m.id, m.content)) + .collect::>() + .join("\n") + }; + + let system_prompt = concat!( + "You are a memory management assistant. Given new facts and existing memories, ", + "decide what action to take for each fact.\n\n", + "Actions:\n", + "- ADD: The fact is new information not captured by any existing memory.\n", + "- UPDATE: The fact refines or corrects an existing memory. Provide the target memory ID.\n", + "- DELETE: The fact contradicts or invalidates an existing memory. Provide the target memory ID.\n", + "- NOOP: The fact is already captured or is not worth storing.\n\n", + "Return a JSON array of objects with keys: \"fact\", \"action\", \"target_memory_id\" (null if ADD/NOOP).\n", + "Example: [{\"fact\":\"User lives in Berlin\",\"action\":\"UPDATE\",\"target_memory_id\":\"\"}]", + ); + + let request = LlmRequest::with_messages(vec![ + LlmMessage::system(system_prompt), + LlmMessage::user(format!( + "New facts:\n{facts_list}\n\nExisting memories:\n{memories_list}" + )), + ]) + .with_max_tokens(self.max_tokens) + .with_temperature(self.temperature); + + let response = self.llm_provider.complete(request).await.map_err(|e| { + GraphBitError::memory(format!("Decision LLM call failed: {e}")) + })?; + + parse_decisions(&response.content) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn role_label(role: &crate::llm::LlmRole) -> &'static str { + match role { + crate::llm::LlmRole::User => "User", + crate::llm::LlmRole::Assistant => "Assistant", + crate::llm::LlmRole::System => "System", + crate::llm::LlmRole::Tool => "Tool", + } +} + +/// Parse a JSON array of strings from potentially messy LLM output. +fn parse_json_string_array(text: &str) -> GraphBitResult> { + // Try to find the JSON array in the response. + let trimmed = text.trim(); + + // First try direct parse. + if let Ok(arr) = serde_json::from_str::>(trimmed) { + return Ok(arr); + } + + // Try to extract the first JSON array from the text. + if let Some(start) = trimmed.find('[') { + if let Some(end) = trimmed.rfind(']') { + let slice = &trimmed[start..=end]; + if let Ok(arr) = serde_json::from_str::>(slice) { + return Ok(arr); + } + } + } + + // Fallback: return empty array if we can't parse. + Ok(Vec::new()) +} + +/// Parse the decision JSON from the LLM response. +fn parse_decisions(text: &str) -> GraphBitResult> { + let trimmed = text.trim(); + + // Try to find and parse the JSON array. + let json_str = if let Some(start) = trimmed.find('[') { + if let Some(end) = trimmed.rfind(']') { + &trimmed[start..=end] + } else { + trimmed + } + } else { + trimmed + }; + + let raw: Vec = serde_json::from_str(json_str).unwrap_or_default(); + + let decisions = raw + .into_iter() + .filter_map(|v| { + let fact = v.get("fact")?.as_str()?.to_string(); + let action_str = v.get("action")?.as_str()?; + let action = MemoryAction::from_str_lossy(action_str); + let target_memory_id = v + .get("target_memory_id") + .and_then(|t| t.as_str()) + .map(String::from); + + Some(MemoryDecision { + fact, + action, + target_memory_id, + }) + }) + .collect(); + + Ok(decisions) +} diff --git a/core/src/memory/service.rs b/core/src/memory/service.rs new file mode 100644 index 00000000..8557c7c5 --- /dev/null +++ b/core/src/memory/service.rs @@ -0,0 +1,326 @@ +//! `MemoryService` -- the primary public API that orchestrates the full +//! memory pipeline: fact extraction, embedding, vector search, LLM-driven +//! deduplication, and persistent storage. + +use std::collections::HashMap; + +use chrono::Utc; + +use crate::embeddings::EmbeddingService; +use crate::errors::{GraphBitError, GraphBitResult}; +use crate::llm::{LlmMessage, LlmProviderFactory}; + +use super::processor::MemoryProcessor; +use super::store::MetadataStore; +use super::types::{ + Memory, MemoryAction, MemoryConfig, MemoryHistory, MemoryId, MemoryScope, ScoredMemory, +}; +use super::vector::VectorIndex; + +/// Orchestrates the full memory pipeline. +pub struct MemoryService { + store: MetadataStore, + vector_index: VectorIndex, + embedding_service: EmbeddingService, + processor: MemoryProcessor, + config: MemoryConfig, +} + +impl MemoryService { + /// Build a new `MemoryService` from the provided configuration. + /// + /// This creates the SQLite store, vector index, embedding service, and + /// LLM processor, then loads any existing memories into the vector index. + pub async fn new(config: MemoryConfig) -> GraphBitResult { + let store = MetadataStore::new(&config.db_path)?; + let vector_index = VectorIndex::new(); + let embedding_service = EmbeddingService::new(config.embedding_config.clone())?; + let llm_provider = LlmProviderFactory::create_provider(config.llm_config.clone())?; + let processor = MemoryProcessor::new( + llm_provider, + config.max_extraction_tokens, + config.extraction_temperature, + ); + + let service = Self { + store, + vector_index, + embedding_service, + processor, + config, + }; + + // Load existing memories into the vector index. + service.load_existing_memories().await?; + + Ok(service) + } + + /// Extract facts from `messages`, embed them, decide actions against + /// existing memories, and persist the results. Returns newly created + /// or updated memories. + pub async fn add( + &self, + messages: &[LlmMessage], + scope: &MemoryScope, + ) -> GraphBitResult> { + // Phase 1: extract facts. + let facts = self.processor.extract_facts(messages).await?; + if facts.is_empty() { + return Ok(Vec::new()); + } + + // Phase 2: get existing memories for this scope for deduplication. + let existing = self.store.get_all_memories(scope).await?; + let decisions = self.processor.decide_actions(&facts, &existing).await?; + + let mut result_memories = Vec::new(); + + for decision in &decisions { + match decision.action { + MemoryAction::Add => { + let memory = self + .create_memory(&decision.fact, scope.clone()) + .await?; + result_memories.push(memory); + } + MemoryAction::Update => { + if let Some(ref target_id_str) = decision.target_memory_id { + if let Ok(target_id) = MemoryId::from_string(target_id_str) { + if let Some(old_memory) = self.store.get_memory(&target_id).await? { + let updated = self + .update_memory_internal( + &target_id, + &decision.fact, + &old_memory.content, + ) + .await?; + result_memories.push(updated); + } + } + } + } + MemoryAction::Delete => { + if let Some(ref target_id_str) = decision.target_memory_id { + if let Ok(target_id) = MemoryId::from_string(target_id_str) { + if let Some(old_memory) = self.store.get_memory(&target_id).await? { + self.delete_memory_internal(&target_id, &old_memory.content) + .await?; + } + } + } + } + MemoryAction::Noop => {} + } + } + + Ok(result_memories) + } + + /// Embed a query and search for the most similar memories within a scope. + pub async fn search( + &self, + query: &str, + scope: &MemoryScope, + top_k: usize, + ) -> GraphBitResult> { + let query_embedding = self.embedding_service.embed_text(query).await?; + let results = self + .vector_index + .search(&query_embedding, top_k, self.config.similarity_threshold) + .await?; + + let mut scored = Vec::new(); + for (memory_id, score) in results { + if let Some(memory) = self.store.get_memory(&memory_id).await? { + if matches_scope(&memory.scope, scope) { + scored.push(ScoredMemory { memory, score }); + } + } + } + + Ok(scored) + } + + /// Get a single memory by its ID. + pub async fn get(&self, memory_id: &MemoryId) -> GraphBitResult> { + self.store.get_memory(memory_id).await + } + + /// Get all memories matching a scope. + pub async fn get_all(&self, scope: &MemoryScope) -> GraphBitResult> { + self.store.get_all_memories(scope).await + } + + /// Update a memory's content by ID. + pub async fn update(&self, memory_id: &MemoryId, content: &str) -> GraphBitResult { + let old = self + .store + .get_memory(memory_id) + .await? + .ok_or_else(|| GraphBitError::memory(format!("Memory not found: {memory_id}")))?; + + self.update_memory_internal(memory_id, content, &old.content) + .await + } + + /// Delete a single memory by its ID. + pub async fn delete(&self, memory_id: &MemoryId) -> GraphBitResult<()> { + let old = self.store.get_memory(memory_id).await?; + let old_content = old.map(|m| m.content).unwrap_or_default(); + self.delete_memory_internal(memory_id, &old_content).await + } + + /// Delete all memories matching a scope. + pub async fn delete_all(&self, scope: &MemoryScope) -> GraphBitResult<()> { + // Remove from vector index first. + let memories = self.store.get_all_memories(scope).await?; + for m in &memories { + self.vector_index.remove(&m.id).await; + } + self.store.delete_all_memories(scope).await + } + + /// Get the mutation history for a memory. + pub async fn history(&self, memory_id: &MemoryId) -> GraphBitResult> { + self.store.get_history(memory_id).await + } + + // ----------------------------------------------------------------------- + // Internal helpers + // ----------------------------------------------------------------------- + + /// Load all existing memories from the store into the vector index. + async fn load_existing_memories(&self) -> GraphBitResult<()> { + let all_memories = self.store.get_all_memories(&MemoryScope::default()).await?; + for memory in &all_memories { + let embedding = self.embedding_service.embed_text(&memory.content).await?; + self.vector_index + .insert(memory.id.clone(), embedding) + .await; + } + Ok(()) + } + + /// Create a brand-new memory: hash, embed, store, index, record history. + async fn create_memory( + &self, + content: &str, + scope: MemoryScope, + ) -> GraphBitResult { + let now = Utc::now(); + let hash = simple_hash(content); + let id = MemoryId::new(); + + let memory = Memory { + id: id.clone(), + content: content.to_string(), + scope, + metadata: HashMap::new(), + created_at: now, + updated_at: now, + hash, + }; + + self.store.insert_memory(&memory).await?; + + let embedding = self.embedding_service.embed_text(content).await?; + self.vector_index.insert(id.clone(), embedding).await; + + self.store + .insert_history(&MemoryHistory { + memory_id: id, + old_content: String::new(), + new_content: content.to_string(), + action: MemoryAction::Add, + timestamp: now, + }) + .await?; + + Ok(memory) + } + + /// Update an existing memory: re-hash, re-embed, persist, record history. + async fn update_memory_internal( + &self, + memory_id: &MemoryId, + new_content: &str, + old_content: &str, + ) -> GraphBitResult { + let hash = simple_hash(new_content); + self.store + .update_memory(memory_id, new_content, &hash) + .await?; + + let embedding = self.embedding_service.embed_text(new_content).await?; + self.vector_index.update(memory_id, embedding).await; + + self.store + .insert_history(&MemoryHistory { + memory_id: memory_id.clone(), + old_content: old_content.to_string(), + new_content: new_content.to_string(), + action: MemoryAction::Update, + timestamp: Utc::now(), + }) + .await?; + + self.store + .get_memory(memory_id) + .await? + .ok_or_else(|| GraphBitError::memory("Memory disappeared after update")) + } + + /// Delete a memory from store and index, record history. + async fn delete_memory_internal( + &self, + memory_id: &MemoryId, + old_content: &str, + ) -> GraphBitResult<()> { + // Record history BEFORE deleting (so FK constraint is satisfied). + // Note: with ON DELETE CASCADE, the history entry will be deleted + // along with the memory, so this serves as an audit log only if + // the cascade is disabled or history is preserved elsewhere. + // For now, we skip history on delete to avoid FK issues. + + // Delete from store (cascades to history) and index. + self.store.delete_memory(memory_id).await?; + self.vector_index.remove(memory_id).await; + + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Check if a memory's scope matches the filter scope. +/// `None` fields in the filter are treated as wildcards. +fn matches_scope(memory_scope: &MemoryScope, filter: &MemoryScope) -> bool { + if let Some(ref uid) = filter.user_id { + if memory_scope.user_id.as_ref() != Some(uid) { + return false; + } + } + if let Some(ref aid) = filter.agent_id { + if memory_scope.agent_id.as_ref() != Some(aid) { + return false; + } + } + if let Some(ref rid) = filter.run_id { + if memory_scope.run_id.as_ref() != Some(rid) { + return false; + } + } + true +} + +/// Produce a simple content hash for deduplication. +fn simple_hash(content: &str) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + content.hash(&mut hasher); + format!("{:x}", hasher.finish()) +} diff --git a/core/src/memory/store.rs b/core/src/memory/store.rs new file mode 100644 index 00000000..d64fd603 --- /dev/null +++ b/core/src/memory/store.rs @@ -0,0 +1,353 @@ +//! SQLite-backed metadata store for persistent memory storage. + +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::errors::{GraphBitError, GraphBitResult}; + +use super::types::{Memory, MemoryAction, MemoryHistory, MemoryId, MemoryScope}; + +/// Persistent metadata store backed by SQLite. +/// +/// The connection is wrapped in `Arc` so that it can be shared across +/// async tasks and moved into `spawn_blocking` closures. +pub struct MetadataStore { + conn: Arc>, +} + +impl MetadataStore { + /// Open (or create) the database at `db_path`. + /// Pass `":memory:"` for an in-memory database. + pub fn new(db_path: &str) -> GraphBitResult { + let conn = rusqlite::Connection::open(db_path)?; + // Enable foreign key enforcement (must be set per-connection). + conn.execute("PRAGMA foreign_keys = ON", [])?; + let store = Self { + conn: Arc::new(Mutex::new(conn)), + }; + store.init_schema_sync()?; + Ok(store) + } + + /// Create the required tables if they do not already exist. + fn init_schema_sync(&self) -> GraphBitResult<()> { + let conn = self + .conn + .try_lock() + .map_err(|_| GraphBitError::memory("Failed to acquire database lock during init"))?; + + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + user_id TEXT, + agent_id TEXT, + run_id TEXT, + hash TEXT NOT NULL, + metadata TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS memory_history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id TEXT NOT NULL, + old_content TEXT NOT NULL DEFAULT '', + new_content TEXT NOT NULL DEFAULT '', + action TEXT NOT NULL, + timestamp TEXT NOT NULL, + FOREIGN KEY (memory_id) REFERENCES memories(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id); + CREATE INDEX IF NOT EXISTS idx_memories_agent_id ON memories(agent_id); + CREATE INDEX IF NOT EXISTS idx_memories_run_id ON memories(run_id); + CREATE INDEX IF NOT EXISTS idx_memories_hash ON memories(hash); + CREATE INDEX IF NOT EXISTS idx_history_memory_id ON memory_history(memory_id);", + )?; + + Ok(()) + } + + /// Insert a new memory. + pub async fn insert_memory(&self, memory: &Memory) -> GraphBitResult<()> { + let id = memory.id.to_string(); + let content = memory.content.clone(); + let user_id = memory.scope.user_id.clone(); + let agent_id = memory.scope.agent_id.clone(); + let run_id = memory.scope.run_id.clone(); + let hash = memory.hash.clone(); + let metadata = serde_json::to_string(&memory.metadata)?; + let created_at = memory.created_at.to_rfc3339(); + let updated_at = memory.updated_at.to_rfc3339(); + + let conn_arc = Arc::clone(&self.conn); + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + conn.execute( + "INSERT INTO memories (id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + rusqlite::params![id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at], + )?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Get a single memory by ID. + pub async fn get_memory(&self, memory_id: &MemoryId) -> GraphBitResult> { + let id = memory_id.to_string(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult> { + let conn = conn_arc.blocking_lock(); + let mut stmt = + conn.prepare("SELECT id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at FROM memories WHERE id = ?1")?; + let mut rows = stmt.query(rusqlite::params![id])?; + if let Some(row) = rows.next()? { + Ok(Some(row_to_memory(row)?)) + } else { + Ok(None) + } + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Get all memories matching the given scope. + pub async fn get_all_memories(&self, scope: &MemoryScope) -> GraphBitResult> { + let user_id = scope.user_id.clone(); + let agent_id = scope.agent_id.clone(); + let run_id = scope.run_id.clone(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult> { + let conn = conn_arc.blocking_lock(); + let (where_clause, params) = build_scope_filter(&user_id, &agent_id, &run_id); + let sql = format!( + "SELECT id, content, user_id, agent_id, run_id, hash, metadata, created_at, updated_at FROM memories{}", + where_clause + ); + let mut stmt = conn.prepare(&sql)?; + let param_refs: Vec<&dyn rusqlite::types::ToSql> = + params.iter().map(|p| p as &dyn rusqlite::types::ToSql).collect(); + let mut rows = stmt.query(param_refs.as_slice())?; + let mut memories = Vec::new(); + while let Some(row) = rows.next()? { + memories.push(row_to_memory(row)?); + } + Ok(memories) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Update content and metadata of an existing memory. + pub async fn update_memory( + &self, + memory_id: &MemoryId, + content: &str, + hash: &str, + ) -> GraphBitResult<()> { + let id = memory_id.to_string(); + let content = content.to_string(); + let hash = hash.to_string(); + let updated_at = Utc::now().to_rfc3339(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + let changed = conn.execute( + "UPDATE memories SET content = ?1, hash = ?2, updated_at = ?3 WHERE id = ?4", + rusqlite::params![content, hash, updated_at, id], + )?; + if changed == 0 { + return Err(GraphBitError::memory(format!( + "Memory not found: {id}" + ))); + } + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Delete a single memory by ID. + pub async fn delete_memory(&self, memory_id: &MemoryId) -> GraphBitResult<()> { + let id = memory_id.to_string(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + conn.execute("DELETE FROM memories WHERE id = ?1", rusqlite::params![id])?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Delete all memories matching the given scope. + pub async fn delete_all_memories(&self, scope: &MemoryScope) -> GraphBitResult<()> { + let user_id = scope.user_id.clone(); + let agent_id = scope.agent_id.clone(); + let run_id = scope.run_id.clone(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + let (where_clause, params) = build_scope_filter(&user_id, &agent_id, &run_id); + let sql = format!("DELETE FROM memories{where_clause}"); + let param_refs: Vec<&dyn rusqlite::types::ToSql> = + params.iter().map(|p| p as &dyn rusqlite::types::ToSql).collect(); + conn.execute(&sql, param_refs.as_slice())?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Record a history entry for a memory mutation. + pub async fn insert_history(&self, history: &MemoryHistory) -> GraphBitResult<()> { + let memory_id = history.memory_id.to_string(); + let old_content = history.old_content.clone(); + let new_content = history.new_content.clone(); + let action = history.action.to_string(); + let timestamp = history.timestamp.to_rfc3339(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult<()> { + let conn = conn_arc.blocking_lock(); + conn.execute( + "INSERT INTO memory_history (memory_id, old_content, new_content, action, timestamp) + VALUES (?1, ?2, ?3, ?4, ?5)", + rusqlite::params![memory_id, old_content, new_content, action, timestamp], + )?; + Ok(()) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } + + /// Get the full history for a specific memory. + pub async fn get_history(&self, memory_id: &MemoryId) -> GraphBitResult> { + let id = memory_id.to_string(); + let conn_arc = Arc::clone(&self.conn); + + tokio::task::spawn_blocking(move || -> GraphBitResult> { + let conn = conn_arc.blocking_lock(); + let mut stmt = conn.prepare( + "SELECT memory_id, old_content, new_content, action, timestamp + FROM memory_history WHERE memory_id = ?1 ORDER BY timestamp ASC", + )?; + let mut rows = stmt.query(rusqlite::params![id])?; + let mut entries = Vec::new(); + while let Some(row) = rows.next()? { + entries.push(row_to_history(row)?); + } + Ok(entries) + }) + .await + .map_err(|e| GraphBitError::memory(format!("Join error: {e}")))? + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn row_to_memory(row: &rusqlite::Row<'_>) -> GraphBitResult { + let id_str: String = row.get(0)?; + let content: String = row.get(1)?; + let user_id: Option = row.get(2)?; + let agent_id: Option = row.get(3)?; + let run_id: Option = row.get(4)?; + let hash: String = row.get(5)?; + let metadata_json: String = row.get(6)?; + let created_at_str: String = row.get(7)?; + let updated_at_str: String = row.get(8)?; + + let id = MemoryId(Uuid::parse_str(&id_str).map_err(|e| { + GraphBitError::memory(format!("Invalid UUID in database: {e}")) + })?); + + let metadata: HashMap = + serde_json::from_str(&metadata_json).unwrap_or_default(); + + let created_at = DateTime::parse_from_rfc3339(&created_at_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()); + + let updated_at = DateTime::parse_from_rfc3339(&updated_at_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()); + + Ok(Memory { + id, + content, + scope: MemoryScope { + user_id, + agent_id, + run_id, + }, + metadata, + created_at, + updated_at, + hash, + }) +} + +fn row_to_history(row: &rusqlite::Row<'_>) -> GraphBitResult { + let memory_id_str: String = row.get(0)?; + let old_content: String = row.get(1)?; + let new_content: String = row.get(2)?; + let action_str: String = row.get(3)?; + let timestamp_str: String = row.get(4)?; + + let memory_id = MemoryId(Uuid::parse_str(&memory_id_str).map_err(|e| { + GraphBitError::memory(format!("Invalid UUID in history: {e}")) + })?); + + let timestamp = DateTime::parse_from_rfc3339(×tamp_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()); + + Ok(MemoryHistory { + memory_id, + old_content, + new_content, + action: MemoryAction::from_str_lossy(&action_str), + timestamp, + }) +} + +/// Build a SQL WHERE clause + params from optional scope fields. +fn build_scope_filter( + user_id: &Option, + agent_id: &Option, + run_id: &Option, +) -> (String, Vec) { + let mut conditions = Vec::new(); + let mut params = Vec::new(); + + if let Some(uid) = user_id { + params.push(uid.clone()); + conditions.push(format!("user_id = ?{}", params.len())); + } + if let Some(aid) = agent_id { + params.push(aid.clone()); + conditions.push(format!("agent_id = ?{}", params.len())); + } + if let Some(rid) = run_id { + params.push(rid.clone()); + conditions.push(format!("run_id = ?{}", params.len())); + } + + if conditions.is_empty() { + (String::new(), params) + } else { + (format!(" WHERE {}", conditions.join(" AND ")), params) + } +} diff --git a/core/src/memory/types.rs b/core/src/memory/types.rs new file mode 100644 index 00000000..a75f4ab0 --- /dev/null +++ b/core/src/memory/types.rs @@ -0,0 +1,206 @@ +//! Core data types for the memory layer. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; +use uuid::Uuid; + +use crate::embeddings::EmbeddingConfig; +use crate::llm::LlmConfig; + +/// Unique identifier for memories, following the `AgentId` pattern. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct MemoryId(pub Uuid); + +impl MemoryId { + /// Create a new random memory ID. + #[inline] + pub fn new() -> Self { + Self(Uuid::new_v4()) + } + + /// Parse a memory ID from a string. + pub fn from_string(s: &str) -> Result { + Ok(Self(Uuid::parse_str(s)?)) + } + + /// Get the underlying UUID. + #[inline] + pub fn as_uuid(&self) -> &Uuid { + &self.0 + } +} + +impl Default for MemoryId { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl fmt::Display for MemoryId { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Scoping information for memory isolation. +/// +/// Memories can be scoped to a specific user, agent, or run. All fields are +/// optional; omitted fields act as wildcards when filtering. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct MemoryScope { + /// User-level scope. + pub user_id: Option, + /// Agent-level scope. + pub agent_id: Option, + /// Run-level scope. + pub run_id: Option, +} + +/// A single stored memory fact. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Memory { + /// Unique identifier. + pub id: MemoryId, + /// The fact / content of the memory. + pub content: String, + /// Scoping information. + pub scope: MemoryScope, + /// Arbitrary key-value metadata. + pub metadata: HashMap, + /// Creation timestamp. + pub created_at: DateTime, + /// Last update timestamp. + pub updated_at: DateTime, + /// Content hash for deduplication. + pub hash: String, +} + +/// A memory paired with its similarity score from a search. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScoredMemory { + /// The memory. + pub memory: Memory, + /// Cosine similarity score (0.0 .. 1.0). + pub score: f64, +} + +/// A historical record of a memory mutation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryHistory { + /// Which memory was affected. + pub memory_id: MemoryId, + /// Content before the change (empty for `Add`). + pub old_content: String, + /// Content after the change (empty for `Delete`). + pub new_content: String, + /// What kind of mutation. + pub action: MemoryAction, + /// When the mutation occurred. + pub timestamp: DateTime, +} + +/// Describes the kind of memory mutation. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MemoryAction { + /// A brand-new fact was added. + Add, + /// An existing fact was updated / refined. + Update, + /// A fact was removed. + Delete, + /// No change required (duplicate / irrelevant). + Noop, +} + +impl fmt::Display for MemoryAction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Add => write!(f, "ADD"), + Self::Update => write!(f, "UPDATE"), + Self::Delete => write!(f, "DELETE"), + Self::Noop => write!(f, "NOOP"), + } + } +} + +impl MemoryAction { + /// Parse an action from its string representation. + pub fn from_str_lossy(s: &str) -> Self { + match s.to_uppercase().as_str() { + "ADD" => Self::Add, + "UPDATE" => Self::Update, + "DELETE" => Self::Delete, + _ => Self::Noop, + } + } +} + +/// An LLM decision about what to do with an extracted fact. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryDecision { + /// The extracted fact text. + pub fact: String, + /// The decided action. + pub action: MemoryAction, + /// If updating or deleting, the target memory's ID (as a string). + pub target_memory_id: Option, +} + +/// Configuration for the memory subsystem. +#[derive(Debug, Clone)] +pub struct MemoryConfig { + /// LLM provider configuration for fact extraction. + pub llm_config: LlmConfig, + /// Embedding provider configuration for vector search. + pub embedding_config: EmbeddingConfig, + /// Path to the SQLite database file (`:memory:` for in-memory). + pub db_path: String, + /// Minimum cosine-similarity threshold for search results (0.0 .. 1.0). + pub similarity_threshold: f64, + /// Maximum tokens for the extraction LLM call. + pub max_extraction_tokens: u32, + /// Temperature for the extraction LLM call. + pub extraction_temperature: f32, +} + +impl MemoryConfig { + /// Create a new config with sensible defaults. + pub fn new(llm_config: LlmConfig, embedding_config: EmbeddingConfig) -> Self { + Self { + llm_config, + embedding_config, + db_path: "graphbit_memory.db".to_string(), + similarity_threshold: 0.7, + max_extraction_tokens: 1500, + extraction_temperature: 0.1, + } + } + + /// Override the database path. + pub fn with_db_path(mut self, db_path: impl Into) -> Self { + self.db_path = db_path.into(); + self + } + + /// Override the similarity threshold. + pub fn with_similarity_threshold(mut self, threshold: f64) -> Self { + self.similarity_threshold = threshold.clamp(0.0, 1.0); + self + } + + /// Override the max extraction tokens. + pub fn with_max_extraction_tokens(mut self, tokens: u32) -> Self { + self.max_extraction_tokens = tokens; + self + } + + /// Override the extraction temperature. + pub fn with_extraction_temperature(mut self, temperature: f32) -> Self { + self.extraction_temperature = temperature.clamp(0.0, 1.0); + self + } +} diff --git a/core/src/memory/vector.rs b/core/src/memory/vector.rs new file mode 100644 index 00000000..ea89e582 --- /dev/null +++ b/core/src/memory/vector.rs @@ -0,0 +1,97 @@ +//! In-memory vector index for semantic search over memories. + +use tokio::sync::RwLock; + +use crate::embeddings::EmbeddingService; +use crate::errors::GraphBitResult; + +use super::types::MemoryId; + +/// A single entry in the vector index. +#[derive(Debug, Clone)] +struct VectorEntry { + memory_id: MemoryId, + embedding: Vec, +} + +/// In-memory vector index backed by brute-force cosine similarity. +/// +/// Suitable for moderate memory counts (thousands). For larger datasets a +/// purpose-built ANN index should replace this implementation. +pub struct VectorIndex { + entries: RwLock>, +} + +impl VectorIndex { + /// Create a new, empty vector index. + pub fn new() -> Self { + Self { + entries: RwLock::new(Vec::new()), + } + } + + /// Insert an embedding for the given memory. + pub async fn insert(&self, memory_id: MemoryId, embedding: Vec) { + let mut entries = self.entries.write().await; + entries.push(VectorEntry { + memory_id, + embedding, + }); + } + + /// Search for the `top_k` most similar entries to `query_embedding`, + /// returning `(MemoryId, similarity_score)` pairs above `threshold`. + pub async fn search( + &self, + query_embedding: &[f32], + top_k: usize, + threshold: f64, + ) -> GraphBitResult> { + let entries = self.entries.read().await; + + let mut scored: Vec<(MemoryId, f64)> = entries + .iter() + .filter_map(|entry| { + let sim = EmbeddingService::cosine_similarity(query_embedding, &entry.embedding) + .ok()?; + let sim_f64 = f64::from(sim); + if sim_f64 >= threshold { + Some((entry.memory_id.clone(), sim_f64)) + } else { + None + } + }) + .collect(); + + // Sort descending by score. + scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored.truncate(top_k); + + Ok(scored) + } + + /// Remove entries for a specific memory. + pub async fn remove(&self, memory_id: &MemoryId) { + let mut entries = self.entries.write().await; + entries.retain(|e| &e.memory_id != memory_id); + } + + /// Replace the embedding for an existing memory. + pub async fn update(&self, memory_id: &MemoryId, embedding: Vec) { + let mut entries = self.entries.write().await; + if let Some(entry) = entries.iter_mut().find(|e| &e.memory_id == memory_id) { + entry.embedding = embedding; + } else { + entries.push(VectorEntry { + memory_id: memory_id.clone(), + embedding, + }); + } + } + + /// Remove all entries from the index. + pub async fn clear(&self) { + let mut entries = self.entries.write().await; + entries.clear(); + } +} diff --git a/tests/rust_unit_tests/memory_store_tests.rs b/tests/rust_unit_tests/memory_store_tests.rs new file mode 100644 index 00000000..a3e25455 --- /dev/null +++ b/tests/rust_unit_tests/memory_store_tests.rs @@ -0,0 +1,277 @@ +use std::collections::HashMap; + +use graphbit_core::memory::{ + store::MetadataStore, Memory, MemoryAction, MemoryHistory, MemoryId, MemoryScope, +}; + +fn make_memory(content: &str, user_id: Option<&str>) -> Memory { + Memory { + id: MemoryId::new(), + content: content.to_string(), + scope: MemoryScope { + user_id: user_id.map(String::from), + agent_id: None, + run_id: None, + }, + metadata: HashMap::new(), + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + hash: format!("hash_{content}"), + } +} + +#[tokio::test] +async fn test_metadata_store_insert_and_get() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let memory = make_memory("User lives in Munich", Some("user1")); + let id = memory.id.clone(); + + store + .insert_memory(&memory) + .await + .expect("insert should succeed"); + + let fetched = store + .get_memory(&id) + .await + .expect("get should succeed") + .expect("memory should exist"); + assert_eq!(fetched.content, "User lives in Munich"); + assert_eq!(fetched.scope.user_id.as_deref(), Some("user1")); +} + +#[tokio::test] +async fn test_metadata_store_update() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let memory = make_memory("User lives in Munich", Some("user1")); + let id = memory.id.clone(); + store.insert_memory(&memory).await.expect("insert ok"); + + store + .update_memory(&id, "User lives in Berlin", "def456") + .await + .expect("update should succeed"); + + let updated = store + .get_memory(&id) + .await + .expect("get ok") + .expect("should exist"); + assert_eq!(updated.content, "User lives in Berlin"); + assert_eq!(updated.hash, "def456"); +} + +#[tokio::test] +async fn test_metadata_store_update_nonexistent() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let result = store + .update_memory(&MemoryId::new(), "content", "hash") + .await; + assert!(result.is_err(), "Updating non-existent memory should fail"); +} + +#[tokio::test] +async fn test_metadata_store_delete() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let memory = make_memory("To be deleted", None); + let id = memory.id.clone(); + store.insert_memory(&memory).await.expect("insert ok"); + + store.delete_memory(&id).await.expect("delete ok"); + + let gone = store.get_memory(&id).await.expect("get ok"); + assert!(gone.is_none()); +} + +#[tokio::test] +async fn test_metadata_store_get_nonexistent() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let result = store + .get_memory(&MemoryId::new()) + .await + .expect("get should not error"); + assert!(result.is_none()); +} + +#[tokio::test] +async fn test_metadata_store_scope_filtering() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + for (user, content) in &[("alice", "Fact A"), ("bob", "Fact B"), ("alice", "Fact C")] { + let memory = make_memory(content, Some(user)); + store.insert_memory(&memory).await.expect("insert ok"); + } + + // Filter by Alice + let alice_scope = MemoryScope { + user_id: Some("alice".to_string()), + ..Default::default() + }; + let alice_memories = store + .get_all_memories(&alice_scope) + .await + .expect("get_all ok"); + assert_eq!(alice_memories.len(), 2); + + // Filter by Bob + let bob_scope = MemoryScope { + user_id: Some("bob".to_string()), + ..Default::default() + }; + let bob_memories = store + .get_all_memories(&bob_scope) + .await + .expect("get_all ok"); + assert_eq!(bob_memories.len(), 1); + + // No filter (all) + let all = store + .get_all_memories(&MemoryScope::default()) + .await + .expect("get_all ok"); + assert_eq!(all.len(), 3); +} + +#[tokio::test] +async fn test_metadata_store_delete_all_by_scope() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + for (user, content) in &[("alice", "Fact A"), ("bob", "Fact B"), ("alice", "Fact C")] { + let memory = make_memory(content, Some(user)); + store.insert_memory(&memory).await.expect("insert ok"); + } + + let alice_scope = MemoryScope { + user_id: Some("alice".to_string()), + ..Default::default() + }; + store + .delete_all_memories(&alice_scope) + .await + .expect("delete_all ok"); + + let remaining = store + .get_all_memories(&MemoryScope::default()) + .await + .expect("get_all ok"); + assert_eq!(remaining.len(), 1); + assert_eq!(remaining[0].scope.user_id.as_deref(), Some("bob")); +} + +#[tokio::test] +async fn test_metadata_store_history() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let memory = make_memory("Initial", None); + let id = memory.id.clone(); + store.insert_memory(&memory).await.expect("insert ok"); + + store + .insert_history(&MemoryHistory { + memory_id: id.clone(), + old_content: String::new(), + new_content: "First version".to_string(), + action: MemoryAction::Add, + timestamp: chrono::Utc::now(), + }) + .await + .expect("insert_history ok"); + + store + .insert_history(&MemoryHistory { + memory_id: id.clone(), + old_content: "First version".to_string(), + new_content: "Second version".to_string(), + action: MemoryAction::Update, + timestamp: chrono::Utc::now(), + }) + .await + .expect("insert_history ok"); + + let history = store.get_history(&id).await.expect("get_history ok"); + assert_eq!(history.len(), 2); + assert_eq!(history[0].action, MemoryAction::Add); + assert_eq!(history[1].action, MemoryAction::Update); + assert_eq!(history[0].new_content, "First version"); + assert_eq!(history[1].old_content, "First version"); + assert_eq!(history[1].new_content, "Second version"); +} + +#[tokio::test] +async fn test_metadata_store_history_empty() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let memory = make_memory("No history", None); + let id = memory.id.clone(); + store.insert_memory(&memory).await.expect("insert ok"); + + let history = store.get_history(&id).await.expect("get_history ok"); + assert!(history.is_empty()); +} + +#[tokio::test] +async fn test_metadata_store_delete_cascades_history() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let memory = make_memory("Will be deleted", None); + let id = memory.id.clone(); + store.insert_memory(&memory).await.expect("insert ok"); + + store + .insert_history(&MemoryHistory { + memory_id: id.clone(), + old_content: String::new(), + new_content: "First version".to_string(), + action: MemoryAction::Add, + timestamp: chrono::Utc::now(), + }) + .await + .expect("insert_history ok"); + + // Delete the memory — history should cascade + store.delete_memory(&id).await.expect("delete ok"); + + // History should be gone due to ON DELETE CASCADE + let history = store.get_history(&id).await.expect("get_history ok"); + assert!(history.is_empty()); +} + +#[tokio::test] +async fn test_metadata_store_metadata_roundtrip() { + let store = MetadataStore::new(":memory:").expect("should create in-memory store"); + + let mut metadata = HashMap::new(); + metadata.insert("source".to_string(), serde_json::json!("conversation")); + metadata.insert("confidence".to_string(), serde_json::json!(0.95)); + + let memory = Memory { + id: MemoryId::new(), + content: "With metadata".to_string(), + scope: MemoryScope::default(), + metadata, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + hash: "h".to_string(), + }; + let id = memory.id.clone(); + store.insert_memory(&memory).await.expect("insert ok"); + + let fetched = store + .get_memory(&id) + .await + .expect("get ok") + .expect("should exist"); + assert_eq!( + fetched.metadata.get("source"), + Some(&serde_json::json!("conversation")) + ); + assert_eq!( + fetched.metadata.get("confidence"), + Some(&serde_json::json!(0.95)) + ); +} diff --git a/tests/rust_unit_tests/memory_types_tests.rs b/tests/rust_unit_tests/memory_types_tests.rs new file mode 100644 index 00000000..9e8d094c --- /dev/null +++ b/tests/rust_unit_tests/memory_types_tests.rs @@ -0,0 +1,205 @@ +use graphbit_core::memory::{ + MemoryAction, MemoryConfig, MemoryId, MemoryScope, +}; +use std::collections::HashMap; + +#[test] +fn test_memory_id_creation() { + let id1 = MemoryId::new(); + let id2 = MemoryId::new(); + assert_ne!(id1, id2, "Two fresh IDs should be unique"); + + let id_str = id1.to_string(); + let parsed = MemoryId::from_string(&id_str).expect("should parse valid UUID"); + assert_eq!(id1, parsed); +} + +#[test] +fn test_memory_id_default() { + let id = MemoryId::default(); + // Default should produce a valid UUID + let _ = id.as_uuid(); + assert!(!id.to_string().is_empty()); +} + +#[test] +fn test_memory_id_from_invalid_string() { + let result = MemoryId::from_string("not-a-uuid"); + assert!(result.is_err()); +} + +#[test] +fn test_memory_id_display() { + let id = MemoryId::new(); + let display = format!("{id}"); + // UUID format: 8-4-4-4-12 hex chars + assert_eq!(display.len(), 36); + assert_eq!(display.chars().filter(|c| *c == '-').count(), 4); +} + +#[test] +fn test_memory_id_equality_and_hash() { + let id = MemoryId::new(); + let cloned = id.clone(); + assert_eq!(id, cloned); + + // Test HashMap usage (requires Hash + Eq) + let mut map = HashMap::new(); + map.insert(id.clone(), "value"); + assert_eq!(map.get(&id), Some(&"value")); +} + +#[test] +fn test_memory_scope_default() { + let scope = MemoryScope::default(); + assert!(scope.user_id.is_none()); + assert!(scope.agent_id.is_none()); + assert!(scope.run_id.is_none()); +} + +#[test] +fn test_memory_scope_with_fields() { + let scope = MemoryScope { + user_id: Some("user1".to_string()), + agent_id: Some("agent1".to_string()), + run_id: Some("run1".to_string()), + }; + assert_eq!(scope.user_id.as_deref(), Some("user1")); + assert_eq!(scope.agent_id.as_deref(), Some("agent1")); + assert_eq!(scope.run_id.as_deref(), Some("run1")); +} + +#[test] +fn test_memory_action_display() { + assert_eq!(MemoryAction::Add.to_string(), "ADD"); + assert_eq!(MemoryAction::Update.to_string(), "UPDATE"); + assert_eq!(MemoryAction::Delete.to_string(), "DELETE"); + assert_eq!(MemoryAction::Noop.to_string(), "NOOP"); +} + +#[test] +fn test_memory_action_from_str_lossy() { + assert_eq!(MemoryAction::from_str_lossy("ADD"), MemoryAction::Add); + assert_eq!(MemoryAction::from_str_lossy("add"), MemoryAction::Add); + assert_eq!(MemoryAction::from_str_lossy("Add"), MemoryAction::Add); + assert_eq!(MemoryAction::from_str_lossy("UPDATE"), MemoryAction::Update); + assert_eq!(MemoryAction::from_str_lossy("update"), MemoryAction::Update); + assert_eq!(MemoryAction::from_str_lossy("DELETE"), MemoryAction::Delete); + assert_eq!(MemoryAction::from_str_lossy("delete"), MemoryAction::Delete); + assert_eq!(MemoryAction::from_str_lossy("unknown"), MemoryAction::Noop); + assert_eq!(MemoryAction::from_str_lossy(""), MemoryAction::Noop); +} + +#[test] +fn test_memory_action_equality() { + assert_eq!(MemoryAction::Add, MemoryAction::Add); + assert_ne!(MemoryAction::Add, MemoryAction::Update); + assert_ne!(MemoryAction::Delete, MemoryAction::Noop); +} + +#[test] +fn test_memory_config_builder() { + let llm_config = graphbit_core::llm::LlmConfig::openai("test-key", "gpt-4o-mini"); + let embedding_config = graphbit_core::embeddings::EmbeddingConfig { + provider: graphbit_core::embeddings::EmbeddingProvider::OpenAI, + api_key: "test-key".to_string(), + model: "text-embedding-3-small".to_string(), + base_url: None, + timeout_seconds: None, + max_batch_size: None, + extra_params: HashMap::new(), + python_instance: None, + }; + + let config = MemoryConfig::new(llm_config, embedding_config) + .with_db_path(":memory:") + .with_similarity_threshold(0.8) + .with_max_extraction_tokens(2000) + .with_extraction_temperature(0.2); + + assert_eq!(config.db_path, ":memory:"); + assert!((config.similarity_threshold - 0.8).abs() < f64::EPSILON); + assert_eq!(config.max_extraction_tokens, 2000); + assert!((config.extraction_temperature - 0.2).abs() < f32::EPSILON); +} + +#[test] +fn test_memory_config_defaults() { + let llm_config = graphbit_core::llm::LlmConfig::openai("key", "model"); + let embedding_config = graphbit_core::embeddings::EmbeddingConfig { + provider: graphbit_core::embeddings::EmbeddingProvider::OpenAI, + api_key: "key".to_string(), + model: "model".to_string(), + base_url: None, + timeout_seconds: None, + max_batch_size: None, + extra_params: HashMap::new(), + python_instance: None, + }; + + let config = MemoryConfig::new(llm_config, embedding_config); + assert_eq!(config.db_path, "graphbit_memory.db"); + assert!((config.similarity_threshold - 0.7).abs() < f64::EPSILON); + assert_eq!(config.max_extraction_tokens, 1500); + assert!((config.extraction_temperature - 0.1).abs() < f32::EPSILON); +} + +#[test] +fn test_memory_config_threshold_clamping() { + let llm_config = graphbit_core::llm::LlmConfig::openai("key", "model"); + let embedding_config = graphbit_core::embeddings::EmbeddingConfig { + provider: graphbit_core::embeddings::EmbeddingProvider::OpenAI, + api_key: "key".to_string(), + model: "model".to_string(), + base_url: None, + timeout_seconds: None, + max_batch_size: None, + extra_params: HashMap::new(), + python_instance: None, + }; + + // Threshold above 1.0 should be clamped + let config = MemoryConfig::new(llm_config.clone(), embedding_config.clone()) + .with_similarity_threshold(1.5); + assert!((config.similarity_threshold - 1.0).abs() < f64::EPSILON); + + // Threshold below 0.0 should be clamped + let config = MemoryConfig::new(llm_config.clone(), embedding_config.clone()) + .with_similarity_threshold(-0.5); + assert!(config.similarity_threshold.abs() < f64::EPSILON); + + // Temperature above 1.0 should be clamped + let config = MemoryConfig::new(llm_config, embedding_config) + .with_extraction_temperature(2.0); + assert!((config.extraction_temperature - 1.0).abs() < f32::EPSILON); +} + +#[test] +fn test_memory_action_serialization() { + let action = MemoryAction::Add; + let json = serde_json::to_string(&action).expect("serialize ok"); + let deserialized: MemoryAction = serde_json::from_str(&json).expect("deserialize ok"); + assert_eq!(action, deserialized); +} + +#[test] +fn test_memory_scope_serialization() { + let scope = MemoryScope { + user_id: Some("user1".to_string()), + agent_id: None, + run_id: Some("run1".to_string()), + }; + let json = serde_json::to_string(&scope).expect("serialize ok"); + let deserialized: MemoryScope = serde_json::from_str(&json).expect("deserialize ok"); + assert_eq!(deserialized.user_id.as_deref(), Some("user1")); + assert!(deserialized.agent_id.is_none()); + assert_eq!(deserialized.run_id.as_deref(), Some("run1")); +} + +#[test] +fn test_memory_id_serialization() { + let id = MemoryId::new(); + let json = serde_json::to_string(&id).expect("serialize ok"); + let deserialized: MemoryId = serde_json::from_str(&json).expect("deserialize ok"); + assert_eq!(id, deserialized); +} diff --git a/tests/rust_unit_tests/memory_vector_tests.rs b/tests/rust_unit_tests/memory_vector_tests.rs new file mode 100644 index 00000000..c3342f5d --- /dev/null +++ b/tests/rust_unit_tests/memory_vector_tests.rs @@ -0,0 +1,165 @@ +use graphbit_core::memory::{vector::VectorIndex, MemoryId}; + +#[tokio::test] +async fn test_vector_index_insert_and_search() { + let index = VectorIndex::new(); + + let id1 = MemoryId::new(); + let id2 = MemoryId::new(); + + index.insert(id1.clone(), vec![1.0, 0.0, 0.0]).await; + index.insert(id2.clone(), vec![0.0, 1.0, 0.0]).await; + + // Search with a vector close to id1 + let results = index + .search(&[0.9, 0.1, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert_eq!(results.len(), 2); + // The first result should be closer to id1 + assert_eq!(results[0].0, id1); + assert!(results[0].1 > results[1].1); +} + +#[tokio::test] +async fn test_vector_index_remove() { + let index = VectorIndex::new(); + + let id1 = MemoryId::new(); + let id2 = MemoryId::new(); + + index.insert(id1.clone(), vec![1.0, 0.0, 0.0]).await; + index.insert(id2.clone(), vec![0.0, 1.0, 0.0]).await; + + index.remove(&id1).await; + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, id2); +} + +#[tokio::test] +async fn test_vector_index_threshold() { + let index = VectorIndex::new(); + + let id1 = MemoryId::new(); + index.insert(id1.clone(), vec![1.0, 0.0, 0.0]).await; + + // Orthogonal vector should have ~0 similarity + let results = index + .search(&[0.0, 1.0, 0.0], 10, 0.5) + .await + .expect("search ok"); + assert!( + results.is_empty(), + "Orthogonal vector should be below threshold 0.5" + ); + + // Identical vector should have similarity 1.0 + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.99) + .await + .expect("search ok"); + assert_eq!(results.len(), 1); + assert!((results[0].1 - 1.0).abs() < 0.01); +} + +#[tokio::test] +async fn test_vector_index_update() { + let index = VectorIndex::new(); + let id = MemoryId::new(); + + index.insert(id.clone(), vec![1.0, 0.0, 0.0]).await; + + // Update embedding + index.update(&id, vec![0.0, 1.0, 0.0]).await; + + // Now id should be similar to [0, 1, 0] rather than [1, 0, 0] + let results = index + .search(&[0.0, 1.0, 0.0], 10, 0.5) + .await + .expect("search ok"); + assert_eq!(results.len(), 1); + assert!((results[0].1 - 1.0).abs() < 0.01); +} + +#[tokio::test] +async fn test_vector_index_update_nonexistent_inserts() { + let index = VectorIndex::new(); + let id = MemoryId::new(); + + // Update on a non-existent ID should insert + index.update(&id, vec![1.0, 0.0, 0.0]).await; + + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.5) + .await + .expect("search ok"); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, id); +} + +#[tokio::test] +async fn test_vector_index_clear() { + let index = VectorIndex::new(); + index.insert(MemoryId::new(), vec![1.0, 0.0]).await; + index.insert(MemoryId::new(), vec![0.0, 1.0]).await; + + index.clear().await; + + let results = index + .search(&[1.0, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert!(results.is_empty()); +} + +#[tokio::test] +async fn test_vector_index_top_k_limit() { + let index = VectorIndex::new(); + + // Insert 5 similar vectors + for i in 0..5 { + let mut v = vec![0.0; 3]; + v[0] = 1.0 - (i as f32 * 0.1); + v[1] = i as f32 * 0.1; + index.insert(MemoryId::new(), v).await; + } + + // Request top_k=2 + let results = index + .search(&[1.0, 0.0, 0.0], 2, 0.0) + .await + .expect("search ok"); + assert_eq!(results.len(), 2); + // Results should be sorted descending by score + assert!(results[0].1 >= results[1].1); +} + +#[tokio::test] +async fn test_vector_index_empty_search() { + let index = VectorIndex::new(); + + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert!(results.is_empty()); +} + +#[tokio::test] +async fn test_vector_index_remove_nonexistent() { + let index = VectorIndex::new(); + let id = MemoryId::new(); + + // Removing a non-existent ID should not panic + index.remove(&id).await; + + let results = index + .search(&[1.0, 0.0, 0.0], 10, 0.0) + .await + .expect("search ok"); + assert!(results.is_empty()); +} diff --git a/tests/rust_unit_tests/mod.rs b/tests/rust_unit_tests/mod.rs index 8d4ba757..3e54f43a 100644 --- a/tests/rust_unit_tests/mod.rs +++ b/tests/rust_unit_tests/mod.rs @@ -10,6 +10,9 @@ mod error_tests; mod graph_advanced_tests; mod llm_provider_tests; mod llm_tests; +mod memory_store_tests; +mod memory_types_tests; +mod memory_vector_tests; mod python_bindings_tests; mod serialization_comprehensive_tests; mod text_splitter_tests;