From d00783635443eaa65bf4a92af90db0a3351f5992 Mon Sep 17 00:00:00 2001 From: Dimon Date: Sun, 17 May 2026 02:44:16 +0800 Subject: [PATCH] fix(core): scope L1 memory search in stores Pass sessionKey and sessionId filters through L1 search APIs so shared backends apply scope before topK truncation. Keep a caller-side guard as a final defense and cover both vector and FTS recall paths with regression tests. Signed-off-by: Dimon --- CHANGELOG.md | 8 ++ src/core/store/sqlite.ts | 72 +++++++++++--- src/core/store/tcvdb.ts | 59 +++++++++--- src/core/store/types.ts | 11 ++- src/core/tdai-core.ts | 2 + src/core/tools/memory-search.test.ts | 139 +++++++++++++++++++++++++++ src/core/tools/memory-search.ts | 41 +++++++- src/core/types.ts | 2 + src/gateway/server.ts | 2 + src/gateway/types.ts | 2 + 10 files changed, 302 insertions(+), 36 deletions(-) create mode 100644 src/core/tools/memory-search.test.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f2d346..6edd511 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,14 @@ --- +## [Unreleased] + +### ✨ 改进 + +- **L1 作用域过滤**:`searchMemories()` 支持 `sessionKey` / `sessionId`,Gateway `/search/memories` 支持 `session_key` / `session_id`,避免共享后端里的结构化记忆跨 session 召回。 + +--- + ## [0.3.4] - 2026-05-12 ### 🐛 修复 diff --git a/src/core/store/sqlite.ts b/src/core/store/sqlite.ts index 6252419..55d967a 100644 --- a/src/core/store/sqlite.ts +++ b/src/core/store/sqlite.ts @@ -28,6 +28,7 @@ import type { IMemoryStore, StoreCapabilities, L0Record, + L1QueryFilter, L1SearchResult, L1FtsResult, L0SearchResult, @@ -96,16 +97,6 @@ export interface L0RecordRow { timestamp: number; } -/** Filter options for querying L1 records from SQLite. */ -export interface L1QueryFilter { - /** If provided, only return records for this session key (conversation channel). */ - sessionKey?: string; - /** If provided, only return records for this session ID (single conversation instance). */ - sessionId?: string; - /** If provided, only return records with updated_time strictly after this ISO 8601 UTC timestamp. */ - updatedAfter?: string; -} - interface Logger { debug?: (message: string) => void; info: (message: string) => void; @@ -302,6 +293,36 @@ export function bm25RankToScore(rank: number): number { return 1 / (1 + rank); } +function hasL1ScopeFilter(filter?: L1QueryFilter): boolean { + return Boolean(filter?.sessionKey || filter?.sessionId); +} + +function matchesL1ScopeFilter( + row: { session_key: string; session_id: string }, + filter?: L1QueryFilter, +): boolean { + if (filter?.sessionKey && row.session_key !== filter.sessionKey) return false; + if (filter?.sessionId && row.session_id !== filter.sessionId) return false; + return true; +} + +function buildL1ScopeSql(filter?: L1QueryFilter): { + clauses: string[]; + values: string[]; +} { + const clauses: string[] = []; + const values: string[] = []; + if (filter?.sessionKey) { + clauses.push("session_key = ?"); + values.push(filter.sessionKey); + } + if (filter?.sessionId) { + clauses.push("session_id = ?"); + values.push(filter.sessionId); + } + return { clauses, values }; +} + /** FTS5 search result for L1 records. */ export interface FtsSearchResult { record_id: string; @@ -1109,7 +1130,12 @@ export class VectorStore implements IMemoryStore { * **Fault-tolerant**: returns an empty array on any error (e.g. dimension * mismatch, corrupted DB) so callers can fall back to keyword search. */ - searchL1Vector(queryEmbedding: Float32Array, topK = 5): VectorSearchResult[] { + searchL1Vector( + queryEmbedding: Float32Array, + topK = 5, + _queryText?: string, + filter?: L1QueryFilter, + ): VectorSearchResult[] { if (this.degraded || !this.vecTablesReady) { if (this.degraded) this.logger?.warn(`${TAG} [L1-search] SKIPPED (degraded mode)`); return []; @@ -1123,7 +1149,9 @@ export class VectorStore implements IMemoryStore { // NOTE: "AND distance IS NOT NULL" is NOT usable because vec0 does not // support that constraint — it causes an empty result set. const ZERO_VEC_BUFFER = 10; - const retrieveCount = topK + ZERO_VEC_BUFFER; + const retrieveCount = hasL1ScopeFilter(filter) + ? Math.max(this.countL1(), topK + ZERO_VEC_BUFFER) + : topK + ZERO_VEC_BUFFER; this.logger?.debug?.( `${TAG} [L1-search] START topK=${topK}, retrieveCount=${retrieveCount}, ` + @@ -1171,6 +1199,7 @@ export class VectorStore implements IMemoryStore { this.logger?.warn(`${TAG} [L1-search] record_id=${record_id} has vector but NO metadata (orphan)`); continue; } + if (!matchesL1ScopeFilter(meta, filter)) continue; const score = 1.0 - distance; this.logger?.debug?.( @@ -2026,10 +2055,25 @@ export class VectorStore implements IMemoryStore { * * **Fault-tolerant**: returns an empty array on any error. */ - searchL1Fts(ftsQuery: string, limit = 20): FtsSearchResult[] { + searchL1Fts(ftsQuery: string, limit = 20, filter?: L1QueryFilter): FtsSearchResult[] { if (this.degraded || !this.ftsAvailable) return []; try { - const rows = this.stmtL1FtsSearch.all(ftsQuery, limit) as Array<{ + const scopeSql = buildL1ScopeSql(filter); + const rows = (scopeSql.clauses.length === 0 + ? this.stmtL1FtsSearch.all(ftsQuery, limit) + : this.db + .prepare(` + SELECT record_id, content_original AS content, type, priority, scene_name, + session_key, session_id, timestamp_str, timestamp_start, timestamp_end, + metadata_json, + bm25(l1_fts) AS rank + FROM l1_fts + WHERE l1_fts MATCH ? + AND ${scopeSql.clauses.join(" AND ")} + ORDER BY rank ASC + LIMIT ? + `) + .all(ftsQuery, ...scopeSql.values, limit)) as Array<{ record_id: string; content: string; type: string; diff --git a/src/core/store/tcvdb.ts b/src/core/store/tcvdb.ts index 35fe6a9..954996a 100644 --- a/src/core/store/tcvdb.ts +++ b/src/core/store/tcvdb.ts @@ -98,6 +98,25 @@ function epochMsToIso(ms: number): string { return new Date(ms).toISOString(); } +function escapeFilterString(value: string): string { + return value.replace(/\\/g, "\\\\").replace(/"/g, '\\"'); +} + +function buildL1FilterExpression(filter?: L1QueryFilter): string | undefined { + const conditions: string[] = []; + if (filter?.sessionKey) { + conditions.push(`session_key = "${escapeFilterString(filter.sessionKey)}"`); + } + if (filter?.sessionId) { + conditions.push(`session_id = "${escapeFilterString(filter.sessionId)}"`); + } + if (filter?.updatedAfter) { + const afterMs = isoToEpochMs(filter.updatedAfter); + if (afterMs > 0) conditions.push(`updated_time_ms > ${afterMs}`); + } + return conditions.length > 0 ? conditions.join(" and ") : undefined; +} + /** * Extract agent ID from a sessionKey like `agent::`. * Returns empty string if the format doesn't match. @@ -552,15 +571,7 @@ export class TcvdbMemoryStore implements IMemoryStore { await this._ensureInit(); if (this.degraded) return []; - // Build TCVDB filter expression from L1QueryFilter - const conditions: string[] = []; - if (filter?.sessionKey) conditions.push(`session_key = "${filter.sessionKey}"`); - if (filter?.sessionId) conditions.push(`session_id = "${filter.sessionId}"`); - if (filter?.updatedAfter) { - const afterMs = isoToEpochMs(filter.updatedAfter); - if (afterMs > 0) conditions.push(`updated_time_ms > ${afterMs}`); - } - const filterExpr = conditions.length > 0 ? conditions.join(" and ") : undefined; + const filterExpr = buildL1FilterExpression(filter); const docs = await this._queryAllDocs( this.l1Collection, @@ -615,21 +626,26 @@ export class TcvdbMemoryStore implements IMemoryStore { // ── L1 Search Operations ───────────────────────────────── - async searchL1Vector(_queryEmbedding: Float32Array, topK?: number, queryText?: string): Promise { + async searchL1Vector( + _queryEmbedding: Float32Array, + topK?: number, + queryText?: string, + filter?: L1QueryFilter, + ): Promise { // TCVDB uses server-side embedding — delegate to hybrid search with text if (queryText) { - return this.searchL1HybridAsync({ queryText, topK }); + return this.searchL1HybridAsync({ queryText, topK, filter }); } // No queryText and TCVDB can't use client embeddings directly via embeddingItems // Return empty — callers should pass queryText for TCVDB return []; } - async searchL1Fts(ftsQuery: string, limit?: number): Promise { + async searchL1Fts(ftsQuery: string, limit?: number, filter?: L1QueryFilter): Promise { // TCVDB has no pure FTS — use hybrid search with sparse-only path // The ftsQuery is raw text, use it as queryText for hybrid if (!ftsQuery) return []; - const results = await this.searchL1HybridAsync({ queryText: ftsQuery, topK: limit }); + const results = await this.searchL1HybridAsync({ queryText: ftsQuery, topK: limit, filter }); // L1SearchResult and L1FtsResult have identical shapes return results; } @@ -637,12 +653,21 @@ export class TcvdbMemoryStore implements IMemoryStore { async searchL1Hybrid(params: { query?: string; queryEmbedding?: Float32Array; + sessionId?: string; + sessionKey?: string; sparseVector?: SparseVector; topK?: number; }): Promise { const queryText = params.query; if (!queryText) return []; - return this.searchL1HybridAsync({ queryText, topK: params.topK }); + return this.searchL1HybridAsync({ + queryText, + topK: params.topK, + filter: { + sessionId: params.sessionId, + sessionKey: params.sessionKey, + }, + }); } /** @@ -650,10 +675,11 @@ export class TcvdbMemoryStore implements IMemoryStore { * Call this directly from async contexts (hooks, tools). */ async searchL1HybridAsync(params: { + filter?: L1QueryFilter; queryText: string; topK?: number; }): Promise { - const { queryText, topK = 10 } = params; + const { filter, queryText, topK = 10 } = params; if (!queryText) return []; try { @@ -665,6 +691,8 @@ export class TcvdbMemoryStore implements IMemoryStore { limit: topK, outputFields: L1_OUTPUT_FIELDS, }; + const filterExpr = buildL1FilterExpression(filter); + if (filterExpr) searchParams.filter = filterExpr; // ann: use embedding field name "text" for server-side embedding // (per SDK: AnnSearch(field_name="text", data='query string')) @@ -702,6 +730,7 @@ export class TcvdbMemoryStore implements IMemoryStore { retrieveVector: false, outputFields: L1_OUTPUT_FIELDS, }; + if (filterExpr) denseSearch.filter = filterExpr; const resp = await this.client.search(this.l1Collection, denseSearch); return this._parseL1SearchResults(resp.documents); } diff --git a/src/core/store/types.ts b/src/core/store/types.ts index cfcb50a..9195f1c 100644 --- a/src/core/store/types.ts +++ b/src/core/store/types.ts @@ -270,11 +270,18 @@ export interface IMemoryStore { // ── L1 Search ──────────────────────────────────────────── - searchL1Vector(queryEmbedding: Float32Array, topK?: number, queryText?: string): MaybePromise; - searchL1Fts(ftsQuery: string, limit?: number): MaybePromise; + searchL1Vector( + queryEmbedding: Float32Array, + topK?: number, + queryText?: string, + filter?: L1QueryFilter, + ): MaybePromise; + searchL1Fts(ftsQuery: string, limit?: number, filter?: L1QueryFilter): MaybePromise; searchL1Hybrid?(params: { query?: string; queryEmbedding?: Float32Array; + sessionId?: string; + sessionKey?: string; sparseVector?: Array<[number, number]>; topK?: number; }): MaybePromise; diff --git a/src/core/tdai-core.ts b/src/core/tdai-core.ts index 977d4a2..b2c2708 100644 --- a/src/core/tdai-core.ts +++ b/src/core/tdai-core.ts @@ -291,6 +291,8 @@ export class TdaiCore { const result = await executeMemorySearch({ query: params.query, limit: params.limit ?? 5, + sessionKey: params.sessionKey, + sessionId: params.sessionId, type: params.type, scene: params.scene, vectorStore: this.vectorStore, diff --git a/src/core/tools/memory-search.test.ts b/src/core/tools/memory-search.test.ts new file mode 100644 index 0000000..1cc6b28 --- /dev/null +++ b/src/core/tools/memory-search.test.ts @@ -0,0 +1,139 @@ +import { describe, expect, it } from "vitest"; +import type { EmbeddingService } from "../store/embedding.js"; +import type { IMemoryStore, L1FtsResult, L1SearchResult } from "../store/types.js"; +import { executeMemorySearch } from "./memory-search.js"; + +function l1Result(params: { + content: string; + recordId: string; + sessionId?: string; + sessionKey: string; +}): L1FtsResult { + return { + content: params.content, + metadata_json: "{}", + priority: 50, + record_id: params.recordId, + scene_name: "review", + score: 0.9, + session_id: params.sessionId ?? `${params.sessionKey}:sub`, + session_key: params.sessionKey, + timestamp_end: "2026-05-16T00:00:00.000Z", + timestamp_start: "2026-05-16T00:00:00.000Z", + timestamp_str: "2026-05-16", + type: "preference", + }; +} + +describe("memory search scope", () => { + it("filters L1 FTS results by sessionKey before formatting", async () => { + const vectorStore = { + isFtsAvailable: () => true, + searchL1Fts: async () => [ + l1Result({ + content: "Refresh project prefers strict code review.", + recordId: "refresh-review-style", + sessionKey: "refresh-project", + }), + l1Result({ + content: "Other project prefers loose review.", + recordId: "other-review-style", + sessionKey: "other-project", + }), + ], + } as Partial as IMemoryStore; + + const result = await executeMemorySearch({ + limit: 5, + query: "review style", + sessionKey: "refresh-project", + vectorStore, + }); + + expect(result.total).toBe(1); + expect(result.results.map((item) => item.content)).toEqual([ + "Refresh project prefers strict code review.", + ]); + }); + + it("passes scope to the L1 store so filtering happens before topK truncation", async () => { + let observedFilter: { sessionKey?: string; sessionId?: string } | undefined; + const vectorStore = { + isFtsAvailable: () => true, + searchL1Fts: async ( + _ftsQuery: string, + _limit?: number, + filter?: { sessionKey?: string; sessionId?: string }, + ) => { + observedFilter = filter; + if (filter?.sessionKey === "refresh-project") { + return [ + l1Result({ + content: "Refresh project keeps in-scope memory after store filtering.", + recordId: "refresh-store-scoped", + sessionKey: "refresh-project", + }), + ]; + } + return Array.from({ length: 20 }, (_, index) => + l1Result({ + content: `Other project memory ${index}`, + recordId: `other-${index}`, + sessionKey: "other-project", + }), + ); + }, + } as Partial as IMemoryStore; + + const result = await executeMemorySearch({ + limit: 5, + query: "store scoped", + sessionKey: "refresh-project", + vectorStore, + }); + + expect(observedFilter).toEqual({ sessionKey: "refresh-project" }); + expect(result.total).toBe(1); + expect(result.results[0]?.id).toBe("refresh-store-scoped"); + }); + + it("passes scope to vector L1 search before vector topK truncation", async () => { + let observedFilter: { sessionKey?: string; sessionId?: string } | undefined; + const vectorStore = { + isFtsAvailable: () => false, + searchL1Vector: async ( + _embedding: Float32Array, + _limit?: number, + _queryText?: string, + filter?: { sessionKey?: string; sessionId?: string }, + ): Promise => { + observedFilter = filter; + return filter?.sessionId === "sub-1" + ? [ + l1Result({ + content: "Refresh vector memory stays scoped by session id.", + recordId: "refresh-vector-scoped", + sessionId: "sub-1", + sessionKey: "refresh-project", + }), + ] + : []; + }, + } as Partial as IMemoryStore; + const embeddingService = { + embed: async () => new Float32Array([0.1, 0.2]), + } as Partial as EmbeddingService; + + const result = await executeMemorySearch({ + embeddingService, + limit: 5, + query: "vector scoped", + sessionId: "sub-1", + vectorStore, + }); + + expect(observedFilter).toEqual({ sessionId: "sub-1" }); + expect(result.total).toBe(1); + expect(result.results[0]?.id).toBe("refresh-vector-scoped"); + }); +}); diff --git a/src/core/tools/memory-search.ts b/src/core/tools/memory-search.ts index dc9d2c2..38b7244 100644 --- a/src/core/tools/memory-search.ts +++ b/src/core/tools/memory-search.ts @@ -10,7 +10,7 @@ * The tool is registered via `api.registerTool()` in index.ts. */ -import type { IMemoryStore, L1SearchResult } from "../store/types.js"; +import type { IMemoryStore, L1QueryFilter, L1SearchResult } from "../store/types.js"; import { buildFtsQuery } from "../store/sqlite.js"; import type { EmbeddingService } from "../store/embedding.js"; @@ -46,6 +46,26 @@ export interface MemorySearchResult { const TAG = "[memory-tdai][tdai_memory_search]"; +function matchesL1Scope( + item: { session_key?: string; session_id?: string }, + sessionKey?: string, + sessionId?: string, +): boolean { + if (sessionKey && item.session_key !== sessionKey) return false; + if (sessionId && item.session_id !== sessionId) return false; + return true; +} + +function buildL1ScopeFilter(params: { + sessionKey?: string; + sessionId?: string; +}): L1QueryFilter | undefined { + const filter: L1QueryFilter = {}; + if (params.sessionKey) filter.sessionKey = params.sessionKey; + if (params.sessionId) filter.sessionId = params.sessionId; + return Object.keys(filter).length > 0 ? filter : undefined; +} + // ============================ // RRF (Reciprocal Rank Fusion) // ============================ @@ -88,6 +108,8 @@ function rrfMergeL1(...lists: MemorySearchResultItem[][]): MemorySearchResultIte export async function executeMemorySearch(params: { query: string; limit: number; + sessionKey?: string; + sessionId?: string; type?: string; scene?: string; vectorStore?: IMemoryStore; @@ -97,6 +119,8 @@ export async function executeMemorySearch(params: { const { query, limit, + sessionKey, + sessionId, type: typeFilter, scene: sceneFilter, vectorStore, @@ -106,6 +130,7 @@ export async function executeMemorySearch(params: { logger?.debug?.( `${TAG} CALLED: query="${query.slice(0, 100)}", limit=${limit}, ` + + `sessionKey=${sessionKey ?? "(all)"}, sessionId=${sessionId ?? "(all)"}, ` + `typeFilter=${typeFilter ?? "(none)"}, sceneFilter=${sceneFilter ?? "(none)"}, ` + `vectorStore=${vectorStore ? "available" : "UNAVAILABLE"}, ` + `embeddingService=${embeddingService ? "available" : "UNAVAILABLE"}`, @@ -124,6 +149,7 @@ export async function executeMemorySearch(params: { // ── Determine available capabilities ── const hasEmbedding = !!embeddingService; const hasFts = vectorStore.isFtsAvailable(); + const scopeFilter = buildL1ScopeFilter({ sessionKey, sessionId }); if (!hasEmbedding && !hasFts) { logger?.warn?.(`${TAG} Neither EmbeddingService nor FTS5 available — cannot search`); @@ -153,9 +179,9 @@ export async function executeMemorySearch(params: { return []; } logger?.debug?.(`${TAG} [hybrid-fts] FTS5 query: "${ftsQuery}"`); - const ftsResults = await vectorStore.searchL1Fts(ftsQuery, candidateK); + const ftsResults = await vectorStore.searchL1Fts(ftsQuery, candidateK, scopeFilter); logger?.debug?.(`${TAG} [hybrid-fts] FTS5 returned ${ftsResults.length} candidates`); - return ftsResults.map((r) => ({ + return ftsResults.filter((r) => matchesL1Scope(r, sessionKey, sessionId)).map((r) => ({ id: r.record_id, content: r.content, type: r.type, @@ -182,9 +208,14 @@ export async function executeMemorySearch(params: { logger?.debug?.( `${TAG} [hybrid-vec] Embedding OK, dims=${queryEmbedding.length}, searching top-${candidateK}...`, ); - const vecResults: L1SearchResult[] = await vectorStore.searchL1Vector(queryEmbedding, candidateK, query); + const vecResults: L1SearchResult[] = await vectorStore.searchL1Vector( + queryEmbedding, + candidateK, + query, + scopeFilter, + ); logger?.debug?.(`${TAG} [hybrid-vec] Vector search returned ${vecResults.length} candidates`); - return vecResults.map((r) => ({ + return vecResults.filter((r) => matchesL1Scope(r, sessionKey, sessionId)).map((r) => ({ id: r.record_id, content: r.content, type: r.type, diff --git a/src/core/types.ts b/src/core/types.ts index 8585b50..8cbaba0 100644 --- a/src/core/types.ts +++ b/src/core/types.ts @@ -229,6 +229,8 @@ export interface CaptureResult { export interface MemorySearchParams { query: string; limit?: number; + sessionKey?: string; + sessionId?: string; type?: string; scene?: string; } diff --git a/src/gateway/server.ts b/src/gateway/server.ts index bd7d0a0..f409842 100644 --- a/src/gateway/server.ts +++ b/src/gateway/server.ts @@ -290,6 +290,8 @@ export class TdaiGateway { const result = await this.core.searchMemories({ query: body.query, limit: body.limit, + sessionKey: body.session_key, + sessionId: body.session_id, type: body.type, scene: body.scene, }); diff --git a/src/gateway/types.ts b/src/gateway/types.ts index 50b2ff4..cce01e4 100644 --- a/src/gateway/types.ts +++ b/src/gateway/types.ts @@ -66,6 +66,8 @@ export interface CaptureResponse { export interface MemorySearchRequest { query: string; limit?: number; + session_key?: string; + session_id?: string; type?: string; scene?: string; }