From 8780970e1eaa6ead0f225f7f001a3ca021029ab8 Mon Sep 17 00:00:00 2001 From: Chris Cage Date: Sun, 12 Apr 2026 20:50:21 -0500 Subject: [PATCH 1/5] Add embedding provider seam --- src/llm.ts | 49 +++++++++++++++++++++++++++++++++++++++-------- src/store.ts | 50 +++++++++++++++++++++++++++++++++++------------- test/llm.test.ts | 14 ++++++++++++++ 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/src/llm.ts b/src/llm.ts index 7cccc3fa8..a67d86742 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -151,11 +151,25 @@ export type LLMSessionOptions = { }; /** - * Session interface for scoped LLM access with lifecycle guarantees + * Embedding-specific provider seam. + * This lets store/indexing code depend on embedding capabilities without + * coupling to generation or reranking details. */ -export interface ILLMSession { +export interface EmbeddingProvider { + /** Stable provider identifier (e.g. "llama.cpp", "openai") */ + readonly providerId: string; + /** Provider-specific embedding model identifier */ + readonly modelId: string; + /** Canonical key used to reason about embedding compatibility */ + readonly compatibilityKey: string; embed(text: string, options?: EmbedOptions): Promise; embedBatch(texts: string[], options?: EmbedOptions): Promise<(EmbeddingResult | null)[]>; +} + +/** + * Session interface for scoped LLM access with lifecycle guarantees + */ +export interface ILLMSession extends EmbeddingProvider { expandQuery(query: string, options?: { context?: string; includeLexical?: boolean }): Promise; rerank(query: string, documents: RerankDocument[], options?: RerankOptions): Promise; /** Whether this session is still valid (not released or aborted) */ @@ -366,12 +380,7 @@ export async function pullModels( /** * Abstract LLM interface - implement this for different backends */ -export interface LLM { - /** - * Get embeddings for text - */ - embed(text: string, options?: EmbedOptions): Promise; - +export interface LLM extends EmbeddingProvider { /** * Generate text completion */ @@ -514,6 +523,18 @@ export class LlamaCpp implements LLM { return this.embedModelUri; } + get providerId(): string { + return "llama.cpp"; + } + + get modelId(): string { + return this.embedModelUri; + } + + get compatibilityKey(): string { + return `${this.providerId}:${this.modelId}`; + } + /** * Reset the inactivity timer. Called after each model operation. * When timer fires, models are unloaded to free memory (if no active sessions). @@ -1497,6 +1518,18 @@ class LLMSession implements ILLMSession { return this.abortController.signal; } + get providerId(): string { + return this.manager.getLlamaCpp().providerId; + } + + get modelId(): string { + return this.manager.getLlamaCpp().modelId; + } + + get compatibilityKey(): string { + return this.manager.getLlamaCpp().compatibilityKey; + } + /** * Release the session and decrement ref count. * Called automatically by withLLMSession when the callback completes. diff --git a/src/store.ts b/src/store.ts index 16a55b7df..824f97700 100644 --- a/src/store.ts +++ b/src/store.ts @@ -25,7 +25,7 @@ import { formatDocForEmbedding, withLLMSessionForLlm, type RerankDocument, - type ILLMSession, + type EmbeddingProvider, } from "./llm.js"; import type { NamedCollection, @@ -66,6 +66,15 @@ function getLlm(store: Store): LlamaCpp { return store.llm ?? getDefaultLlamaCpp(); } +/** + * Get the embedding-capable provider for a store. + * This is intentionally narrower than getLlm(): callers that only need + * embed/embedBatch should not depend on generation or reranking capabilities. + */ +function getEmbeddingProvider(store: Store): EmbeddingProvider { + return store.llm ?? getDefaultLlamaCpp(); +} + // ============================================================================= // Smart Chunking - Break Point Detection // ============================================================================= @@ -1126,7 +1135,7 @@ export type Store = { // Search searchFTS: (query: string, limit?: number, collectionName?: string) => SearchResult[]; - searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => Promise; + searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[]) => Promise; // Query expansion & reranking expandQuery: (query: string, model?: string, intent?: string) => Promise; @@ -1431,7 +1440,8 @@ export async function generateEmbeddings( // Use store's LlamaCpp or global singleton, wrapped in a session const llm = getLlm(store); - const embedModelUri = llm.embedModelName; + const embedProvider = getEmbeddingProvider(store); + const embedModelUri = embedProvider.modelId; // Create a session manager for this llm instance const result = await withLLMSessionForLlm(llm, async (session) => { @@ -1640,7 +1650,7 @@ export function createStore(dbPath?: string): Store { // Search searchFTS: (query: string, limit?: number, collectionName?: string) => searchFTS(db, query, limit, collectionName), - searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding), + searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding), // Query expansion & reranking expandQuery: (query: string, model?: string, intent?: string) => expandQuery(query, model, db, intent, store.llm), @@ -3096,7 +3106,15 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle // Vector Search // ============================================================================= -export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]): Promise { +export async function searchVec( + db: Database, + query: string, + model: string, + limit: number = 20, + collectionName?: string, + session?: EmbeddingProvider, + precomputedEmbedding?: number[], +): Promise { const tableExists = db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get(); if (!tableExists) return []; @@ -3186,12 +3204,18 @@ export async function searchVec(db: Database, query: string, model: string, limi // Embeddings // ============================================================================= -async function getEmbedding(text: string, model: string, isQuery: boolean, session?: ILLMSession, llmOverride?: LlamaCpp): Promise { +async function getEmbedding( + text: string, + model: string, + isQuery: boolean, + session?: EmbeddingProvider, + providerOverride?: EmbeddingProvider, +): Promise { // Format text using the appropriate prompt template const formattedText = isQuery ? formatQueryForEmbedding(text, model) : formatDocForEmbedding(text, undefined, model); const result = session ? await session.embed(formattedText, { model, isQuery }) - : await (llmOverride ?? getDefaultLlamaCpp()).embed(formattedText, { model, isQuery }); + : await (providerOverride ?? getDefaultLlamaCpp()).embed(formattedText, { model, isQuery }); return result?.embedding || null; } @@ -4087,11 +4111,11 @@ export async function hybridQuery( } // Batch embed all vector queries in a single call - const llm = getLlm(store); - const textsToEmbed = vecQueries.map(q => formatQueryForEmbedding(q.text, llm.embedModelName)); + const embedProvider = getEmbeddingProvider(store); + const textsToEmbed = vecQueries.map(q => formatQueryForEmbedding(q.text, embedProvider.modelId)); hooks?.onEmbedStart?.(textsToEmbed.length); const embedStart = Date.now(); - const embeddings = await llm.embedBatch(textsToEmbed); + const embeddings = await embedProvider.embedBatch(textsToEmbed); hooks?.onEmbedDone?.(Date.now() - embedStart); // Run sqlite-vec lookups with pre-computed embeddings @@ -4470,11 +4494,11 @@ export async function structuredSearch( s.type === 'vec' || s.type === 'hyde' ); if (vecSearches.length > 0) { - const llm = getLlm(store); - const textsToEmbed = vecSearches.map(s => formatQueryForEmbedding(s.query, llm.embedModelName)); + const embedProvider = getEmbeddingProvider(store); + const textsToEmbed = vecSearches.map(s => formatQueryForEmbedding(s.query, embedProvider.modelId)); hooks?.onEmbedStart?.(textsToEmbed.length); const embedStart = Date.now(); - const embeddings = await llm.embedBatch(textsToEmbed); + const embeddings = await embedProvider.embedBatch(textsToEmbed); hooks?.onEmbedDone?.(Date.now() - embedStart); for (let i = 0; i < vecSearches.length; i++) { diff --git a/test/llm.test.ts b/test/llm.test.ts index 74b643015..1dc26c919 100644 --- a/test/llm.test.ts +++ b/test/llm.test.ts @@ -194,6 +194,17 @@ describe("LlamaCpp model resolution (config > env > default)", () => { }); }); +describe("LlamaCpp embedding provider seam", () => { + test("exposes provider metadata for embeddings", () => { + const llm = new LlamaCpp({ embedModel: "hf:config/model.gguf" }); + + expect(llm.providerId).toBe("llama.cpp"); + expect(llm.modelId).toBe("hf:config/model.gguf"); + expect(llm.compatibilityKey).toBe("llama.cpp:hf:config/model.gguf"); + expect(llm.embedModelName).toBe("hf:config/model.gguf"); + }); +}); + describe("LlamaCpp embedding truncation", () => { test("truncates against the active embedding context limit, not the model train context", async () => { const llm = new LlamaCpp({}) as any; @@ -690,6 +701,9 @@ describe.skipIf(!!process.env.CI)("LLM Session Management", () => { test("session provides access to LLM operations", async () => { const result = await withLLMSession(async (session) => { expect(session.isValid).toBe(true); + expect(session.providerId).toBe("llama.cpp"); + expect(session.modelId).toBe(getDefaultLlamaCpp().embedModelName); + expect(session.compatibilityKey).toBe(`llama.cpp:${getDefaultLlamaCpp().embedModelName}`); const embedding = await session.embed("test text"); expect(embedding).not.toBeNull(); expect(embedding!.embedding.length).toBe(768); From 53bf40d9cea1cfbb93d1c39d0a785d8922354517 Mon Sep 17 00:00:00 2001 From: Chris Cage Date: Sun, 12 Apr 2026 21:10:04 -0500 Subject: [PATCH 2/5] Add embedding model compatibility checks --- src/store.ts | 90 +++++++++++++++++++++++++++++++--------------- test/store.test.ts | 56 ++++++++++++++++++++++++++--- 2 files changed, 114 insertions(+), 32 deletions(-) diff --git a/src/store.ts b/src/store.ts index 824f97700..63cf541f9 100644 --- a/src/store.ts +++ b/src/store.ts @@ -75,6 +75,29 @@ function getEmbeddingProvider(store: Store): EmbeddingProvider { return store.llm ?? getDefaultLlamaCpp(); } +/** + * Resolve the key stored in vector metadata and used for compatibility checks. + * Explicit model overrides win. Otherwise prefer the provider's compatibility + * key, falling back to the provider model id or the legacy default. + */ +function getEmbeddingModelKey( + provider?: Partial, + explicitModel?: string, +): string { + return explicitModel ?? provider?.compatibilityKey ?? provider?.modelId ?? DEFAULT_EMBED_MODEL; +} + +/** + * Resolve the model identifier used for embedding prompt formatting. + * This should prefer the provider's actual model id when available. + */ +function getEmbeddingFormatModel( + provider?: Partial, + explicitModel?: string, +): string { + return provider?.modelId ?? explicitModel ?? DEFAULT_EMBED_MODEL; +} + // ============================================================================= // Smart Chunking - Break Point Detection // ============================================================================= @@ -1350,16 +1373,16 @@ function resolveEmbedOptions(options?: EmbedOptions): Required { const db = store.db; - const model = options?.model ?? DEFAULT_EMBED_MODEL; + const embedProvider = getEmbeddingProvider(store); + const model = getEmbeddingModelKey(embedProvider, options?.model); const now = new Date().toISOString(); const { maxDocsPerBatch, maxBatchBytes } = resolveEmbedOptions(options); const encoder = new TextEncoder(); @@ -1429,7 +1453,7 @@ export async function generateEmbeddings( clearAllEmbeddings(db); } - const docsToEmbed = getPendingEmbeddingDocs(db); + const docsToEmbed = getPendingEmbeddingDocs(db, model); if (docsToEmbed.length === 0) { return { docsProcessed: 0, chunksEmbedded: 0, errors: 0, durationMs: 0 }; @@ -1440,8 +1464,7 @@ export async function generateEmbeddings( // Use store's LlamaCpp or global singleton, wrapped in a session const llm = getLlm(store); - const embedProvider = getEmbeddingProvider(store); - const embedModelUri = embedProvider.modelId; + const embedModelUri = getEmbeddingFormatModel(embedProvider, options?.model); // Create a session manager for this llm instance const result = await withLLMSessionForLlm(llm, async (session) => { @@ -1609,6 +1632,7 @@ export function createStore(dbPath?: string): Store { const resolvedPath = dbPath || getDefaultDbPath(); const db = openDatabase(resolvedPath); initializeDatabase(db); + const resolveActiveEmbedModel = () => getEmbeddingModelKey(getEmbeddingProvider(store)); const store: Store = { db, @@ -1617,9 +1641,9 @@ export function createStore(dbPath?: string): Store { ensureVecTable: (dimensions: number) => ensureVecTableInternal(db, dimensions), // Index health - getHashesNeedingEmbedding: () => getHashesNeedingEmbedding(db), - getIndexHealth: () => getIndexHealth(db), - getStatus: () => getStatus(db), + getHashesNeedingEmbedding: () => getHashesNeedingEmbedding(db, resolveActiveEmbedModel()), + getIndexHealth: () => getIndexHealth(db, resolveActiveEmbedModel()), + getStatus: () => getStatus(db, resolveActiveEmbedModel()), // Caching getCacheKey, @@ -1677,7 +1701,7 @@ export function createStore(dbPath?: string): Store { getActiveDocumentPaths: (collectionName: string) => getActiveDocumentPaths(db, collectionName), // Vector/embedding operations - getHashesForEmbedding: () => getHashesForEmbedding(db), + getHashesForEmbedding: () => getHashesForEmbedding(db, resolveActiveEmbedModel()), clearAllEmbeddings: () => clearAllEmbeddings(db), insertEmbedding: (hash: string, seq: number, pos: number, embedding: Float32Array, model: string, embeddedAt: string) => insertEmbedding(db, hash, seq, pos, embedding, model, embeddedAt), }; @@ -1878,13 +1902,13 @@ export type IndexStatus = { // Index health // ============================================================================= -export function getHashesNeedingEmbedding(db: Database): number { +export function getHashesNeedingEmbedding(db: Database, model: string = DEFAULT_EMBED_MODEL): number { const result = db.prepare(` SELECT COUNT(DISTINCT d.hash) as count FROM documents d - LEFT JOIN content_vectors v ON d.hash = v.hash AND v.seq = 0 + LEFT JOIN content_vectors v ON d.hash = v.hash AND v.seq = 0 AND v.model = ? WHERE d.active = 1 AND v.hash IS NULL - `).get() as { count: number }; + `).get(model) as { count: number }; return result.count; } @@ -1894,8 +1918,8 @@ export type IndexHealthInfo = { daysStale: number | null; }; -export function getIndexHealth(db: Database): IndexHealthInfo { - const needsEmbedding = getHashesNeedingEmbedding(db); +export function getIndexHealth(db: Database, model: string = DEFAULT_EMBED_MODEL): IndexHealthInfo { + const needsEmbedding = getHashesNeedingEmbedding(db, model); const totalDocs = (db.prepare(`SELECT COUNT(*) as count FROM documents WHERE active = 1`).get() as { count: number }).count; const mostRecent = db.prepare(`SELECT MAX(modified_at) as latest FROM documents WHERE active = 1`).get() as { latest: string | null }; @@ -3118,6 +3142,7 @@ export async function searchVec( const tableExists = db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get(); if (!tableExists) return []; + const activeModelKey = getEmbeddingModelKey(session, model); const embedding = precomputedEmbedding ?? await getEmbedding(query, model, true, session); if (!embedding) return []; @@ -3154,8 +3179,9 @@ export async function searchVec( JOIN documents d ON d.hash = cv.hash AND d.active = 1 JOIN content ON content.hash = d.hash WHERE cv.hash || '_' || cv.seq IN (${placeholders}) + AND cv.model = ? `; - const params: string[] = [...hashSeqs]; + const params: string[] = [...hashSeqs, activeModelKey]; if (collectionName) { docSql += ` AND d.collection = ?`; @@ -3211,11 +3237,16 @@ async function getEmbedding( session?: EmbeddingProvider, providerOverride?: EmbeddingProvider, ): Promise { + const provider = session ?? providerOverride; + const formatModel = getEmbeddingFormatModel(provider, model); + const modelKey = getEmbeddingModelKey(provider, model); // Format text using the appropriate prompt template - const formattedText = isQuery ? formatQueryForEmbedding(text, model) : formatDocForEmbedding(text, undefined, model); + const formattedText = isQuery + ? formatQueryForEmbedding(text, formatModel) + : formatDocForEmbedding(text, undefined, formatModel); const result = session - ? await session.embed(formattedText, { model, isQuery }) - : await (providerOverride ?? getDefaultLlamaCpp()).embed(formattedText, { model, isQuery }); + ? await session.embed(formattedText, { model: modelKey, isQuery }) + : await (providerOverride ?? getDefaultLlamaCpp()).embed(formattedText, { model: modelKey, isQuery }); return result?.embedding || null; } @@ -3223,15 +3254,15 @@ async function getEmbedding( * Get all unique content hashes that need embeddings (from active documents). * Returns hash, document body, and a sample path for display purposes. */ -export function getHashesForEmbedding(db: Database): { hash: string; body: string; path: string }[] { +export function getHashesForEmbedding(db: Database, model: string = DEFAULT_EMBED_MODEL): { hash: string; body: string; path: string }[] { return db.prepare(` SELECT d.hash, c.doc as body, MIN(d.path) as path FROM documents d JOIN content c ON d.hash = c.hash - LEFT JOIN content_vectors v ON d.hash = v.hash AND v.seq = 0 + LEFT JOIN content_vectors v ON d.hash = v.hash AND v.seq = 0 AND v.model = ? WHERE d.active = 1 AND v.hash IS NULL GROUP BY d.hash - `).all() as { hash: string; body: string; path: string }[]; + `).all(model) as { hash: string; body: string; path: string }[]; } /** @@ -3773,7 +3804,7 @@ export function findDocuments( // Status // ============================================================================= -export function getStatus(db: Database): IndexStatus { +export function getStatus(db: Database, model: string = DEFAULT_EMBED_MODEL): IndexStatus { // DB is source of truth for collections — config provides supplementary metadata const dbCollections = db.prepare(` SELECT @@ -3808,7 +3839,7 @@ export function getStatus(db: Database): IndexStatus { }); const totalDocs = (db.prepare(`SELECT COUNT(*) as c FROM documents WHERE active = 1`).get() as { c: number }).c; - const needsEmbedding = getHashesNeedingEmbedding(db); + const needsEmbedding = getHashesNeedingEmbedding(db, model); const hasVectors = !!db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get(); return { @@ -4101,6 +4132,7 @@ export async function hybridQuery( // 3b: Collect all texts that need vector search (original query + vec/hyde expansions) if (hasVectors) { + const activeEmbedModel = getEmbeddingModelKey(getEmbeddingProvider(store)); const vecQueries: { text: string; queryType: "original" | "vec" | "hyde" }[] = [ { text: query, queryType: "original" }, ]; @@ -4124,7 +4156,7 @@ export async function hybridQuery( if (!embedding) continue; const vecResults = await store.searchVec( - vecQueries[i]!.text, DEFAULT_EMBED_MODEL, 20, collection, + vecQueries[i]!.text, activeEmbedModel, 20, collection, undefined, embedding ); if (vecResults.length > 0) { @@ -4346,6 +4378,7 @@ export async function vectorSearchQuery( `SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'` ).get(); if (!hasVectors) return []; + const activeEmbedModel = getEmbeddingModelKey(getEmbeddingProvider(store)); // Expand query — filter to vec/hyde only (lex queries target FTS, not vector) const expandStart = Date.now(); @@ -4357,7 +4390,7 @@ export async function vectorSearchQuery( const queryTexts = [query, ...vecExpanded.map(q => q.query)]; const allResults = new Map(); for (const q of queryTexts) { - const vecResults = await store.searchVec(q, DEFAULT_EMBED_MODEL, limit, collection); + const vecResults = await store.searchVec(q, activeEmbedModel, limit, collection); for (const r of vecResults) { const existing = allResults.get(r.filepath); if (!existing || r.score > existing.score) { @@ -4489,6 +4522,7 @@ export async function structuredSearch( // Step 2: Batch embed and run vector searches for vec/hyde if (hasVectors) { + const activeEmbedModel = getEmbeddingModelKey(getEmbeddingProvider(store)); const vecSearches = searches.filter( (s): s is ExpandedQuery & { type: 'vec' | 'hyde' } => s.type === 'vec' || s.type === 'hyde' @@ -4507,7 +4541,7 @@ export async function structuredSearch( for (const coll of collectionList) { const vecResults = await store.searchVec( - vecSearches[i]!.query, DEFAULT_EMBED_MODEL, 20, coll, + vecSearches[i]!.query, activeEmbedModel, 20, coll, undefined, embedding ); if (vecResults.length > 0) { diff --git a/test/store.test.ts b/test/store.test.ts index 848ec9683..4b9569a90 100644 --- a/test/store.test.ts +++ b/test/store.test.ts @@ -48,6 +48,7 @@ import { syncConfigToDb, STRONG_SIGNAL_MIN_SCORE, STRONG_SIGNAL_MIN_GAP, + DEFAULT_EMBED_MODEL, generateEmbeddings, type Store, type DocumentResult, @@ -2082,6 +2083,31 @@ describe("Index Status", () => { await cleanupTestDb(store); }); + test("getHashesNeedingEmbedding uses the active embedding model", async () => { + const store = await createTestStore(); + const collectionName = await createTestCollection(); + store.llm = { + providerId: "test", + modelId: DEFAULT_EMBED_MODEL, + compatibilityKey: DEFAULT_EMBED_MODEL, + } as any; + + await insertTestDocument(store.db, collectionName, { name: "doc1", hash: "hash1" }); + await insertTestDocument(store.db, collectionName, { name: "doc2", hash: "hash2" }); + + const now = new Date().toISOString(); + store.db.prepare( + `INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)` + ).run("hash1", DEFAULT_EMBED_MODEL, now); + store.db.prepare( + `INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)` + ).run("hash2", "other-model", now); + + expect(store.getHashesNeedingEmbedding()).toBe(1); + + await cleanupTestDb(store); + }); + test("getIndexHealth returns health info", async () => { const store = await createTestStore(); const collectionName = await createTestCollection(); @@ -2392,7 +2418,7 @@ describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => { // Create vector table and insert a vector store.ensureVecTable(768); const embedding = Array(768).fill(0).map(() => Math.random()); - store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, 'test', ?)`).run(hash, new Date().toISOString()); + store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)`).run(hash, DEFAULT_EMBED_MODEL, new Date().toISOString()); store.db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`).run(`${hash}_0`, new Float32Array(embedding)); const results = await store.searchVec("test query", "embeddinggemma", 10); @@ -2428,8 +2454,8 @@ describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => { store.ensureVecTable(768); const embedding1 = Array(768).fill(0).map(() => Math.random()); const embedding2 = Array(768).fill(0).map(() => Math.random()); - store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, 'test', ?)`).run(hash1, new Date().toISOString()); - store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, 'test', ?)`).run(hash2, new Date().toISOString()); + store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)`).run(hash1, DEFAULT_EMBED_MODEL, new Date().toISOString()); + store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)`).run(hash2, DEFAULT_EMBED_MODEL, new Date().toISOString()); store.db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`).run(`${hash1}_0`, new Float32Array(embedding1)); store.db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`).run(`${hash2}_0`, new Float32Array(embedding2)); @@ -2464,7 +2490,7 @@ describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => { // Create vector table and insert a test vector store.ensureVecTable(768); const embedding = Array(768).fill(0).map(() => Math.random()); - store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, 'test', ?)`).run(hash, new Date().toISOString()); + store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)`).run(hash, DEFAULT_EMBED_MODEL, new Date().toISOString()); store.db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`).run(`${hash}_0`, new Float32Array(embedding)); // This should complete quickly (not hang) due to the two-step fix @@ -2481,6 +2507,28 @@ describe.skipIf(!!process.env.CI)("LlamaCpp Integration", () => { await cleanupTestDb(store); }); + test("searchVec ignores vectors from an incompatible model", async () => { + const store = await createTestStore(); + const collectionName = await createTestCollection(); + + const hash = "wrongmodelhash"; + await insertTestDocument(store.db, collectionName, { + name: "doc1", + hash, + body: "Some content about testing", + }); + + store.ensureVecTable(768); + const embedding = Array(768).fill(0).map(() => Math.random()); + store.db.prepare(`INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) VALUES (?, 0, 0, ?, ?)`).run(hash, "other-model", new Date().toISOString()); + store.db.prepare(`INSERT INTO vectors_vec (hash_seq, embedding) VALUES (?, ?)`).run(`${hash}_0`, new Float32Array(embedding)); + + const results = await store.searchVec("test query", DEFAULT_EMBED_MODEL, 10); + expect(results).toHaveLength(0); + + await cleanupTestDb(store); + }); + test("expandQuery returns typed expansions (no original query)", async () => { const store = await createTestStore(); From 485e64ab9a413e3c1627555612fc4deab1c38224 Mon Sep 17 00:00:00 2001 From: Chris Cage Date: Sun, 12 Apr 2026 21:39:56 -0500 Subject: [PATCH 3/5] Add OpenAI embedding provider --- src/cli/qmd.ts | 9 ++- src/index.ts | 5 ++ src/llm.ts | 199 +++++++++++++++++++++++++++++++++++++++++++++++ src/store.ts | 40 ++++++++-- test/llm.test.ts | 76 ++++++++++++++++++ 5 files changed, 322 insertions(+), 7 deletions(-) diff --git a/src/cli/qmd.ts b/src/cli/qmd.ts index 50ae76486..343e3072a 100755 --- a/src/cli/qmd.ts +++ b/src/cli/qmd.ts @@ -78,7 +78,7 @@ import { type ReindexResult, type ChunkStrategy, } from "../store.js"; -import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, setDefaultLlamaCpp, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js"; +import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, setDefaultLlamaCpp, setDefaultEmbeddingProvider, createEmbeddingProvider, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js"; import { formatSearchResults, formatDocuments, @@ -121,10 +121,15 @@ function getStore(): ReturnType { const config = loadConfig(); syncConfigToDb(store.db, config); if (config.models) { - setDefaultLlamaCpp(new LlamaCpp({ + const llm = new LlamaCpp({ embedModel: config.models.embed, generateModel: config.models.generate, rerankModel: config.models.rerank, + }); + setDefaultLlamaCpp(llm); + setDefaultEmbeddingProvider(createEmbeddingProvider({ + embedModel: config.models.embed, + localProvider: llm, })); } } catch { diff --git a/src/index.ts b/src/index.ts index 677234743..2dd2863dc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -66,6 +66,7 @@ import { } from "./store.js"; import { LlamaCpp, + createEmbeddingProvider, } from "./llm.js"; import { setConfigSource, @@ -375,6 +376,10 @@ export async function createStore(options: StoreOptions): Promise { disposeModelsOnInactivity: true, }); internal.llm = llm; + internal.embeddingProvider = createEmbeddingProvider({ + embedModel: config?.models?.embed, + localProvider: llm, + }); const store: QMDStore = { internal, diff --git a/src/llm.ts b/src/llm.ts index a67d86742..b6cb27786 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -122,6 +122,8 @@ export type EmbedOptions = { title?: string; }; +export type EmbeddingProviderName = "llama.cpp" | "openai"; + /** * Options for text generation */ @@ -221,6 +223,10 @@ export const LFM2_INSTRUCT_MODEL = "hf:LiquidAI/LFM2.5-1.2B-Instruct-GGUF/LFM2.5 export const DEFAULT_EMBED_MODEL_URI = DEFAULT_EMBED_MODEL; export const DEFAULT_RERANK_MODEL_URI = DEFAULT_RERANK_MODEL; export const DEFAULT_GENERATE_MODEL_URI = DEFAULT_GENERATE_MODEL; +export const DEFAULT_OPENAI_EMBED_MODEL = "text-embedding-3-small"; +export const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"; +const DEFAULT_OPENAI_TIMEOUT_MS = 30_000; +const DEFAULT_OPENAI_MAX_BATCH_SIZE = 128; // Local model cache directory const MODEL_CACHE_DIR = process.env.XDG_CACHE_HOME @@ -409,6 +415,179 @@ export interface LLM extends EmbeddingProvider { dispose(): Promise; } +export type OpenAIEmbeddingProviderConfig = { + model?: string; + apiKey?: string; + baseUrl?: string; + organization?: string; + project?: string; + dimensions?: number; + timeoutMs?: number; + maxBatchSize?: number; +}; + +export type EmbeddingProviderConfig = { + provider?: EmbeddingProviderName | string; + embedModel?: string; + localProvider?: EmbeddingProvider; + openai?: OpenAIEmbeddingProviderConfig; +}; + +export function resolveEmbeddingProviderName( + provider = process.env.QMD_EMBED_PROVIDER, +): EmbeddingProviderName { + const normalized = provider?.trim().toLowerCase() ?? ""; + if (!normalized) return "llama.cpp"; + if (normalized === "openai") return "openai"; + if (normalized === "llama.cpp" || normalized === "llama" || normalized === "local") return "llama.cpp"; + throw new Error(`Unsupported embedding provider "${provider}". Expected "llama.cpp" or "openai".`); +} + +function resolveOpenAIBaseUrl(baseUrl = process.env.QMD_OPENAI_BASE_URL || process.env.OPENAI_BASE_URL): string { + const resolved = baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL; + return resolved.replace(/\/+$/, ""); +} + +function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)); +} + +type OpenAIEmbeddingsResponse = { + data: { embedding: number[]; index: number }[]; + model: string; +}; + +export class OpenAIEmbeddingProvider implements EmbeddingProvider { + readonly providerId = "openai"; + readonly modelId: string; + readonly compatibilityKey: string; + + private readonly apiKey: string; + private readonly baseUrl: string; + private readonly organization?: string; + private readonly project?: string; + private readonly dimensions?: number; + private readonly timeoutMs: number; + private readonly maxBatchSize: number; + + constructor(config: OpenAIEmbeddingProviderConfig = {}) { + this.modelId = config.model || process.env.QMD_EMBED_MODEL || DEFAULT_OPENAI_EMBED_MODEL; + this.apiKey = config.apiKey || process.env.OPENAI_API_KEY || ""; + this.baseUrl = resolveOpenAIBaseUrl(config.baseUrl); + this.organization = config.organization || process.env.OPENAI_ORG_ID || process.env.OPENAI_ORGANIZATION; + this.project = config.project || process.env.OPENAI_PROJECT_ID || process.env.OPENAI_PROJECT; + this.dimensions = config.dimensions ?? (process.env.QMD_OPENAI_EMBED_DIMENSIONS + ? Number.parseInt(process.env.QMD_OPENAI_EMBED_DIMENSIONS, 10) + : undefined); + this.timeoutMs = config.timeoutMs ?? DEFAULT_OPENAI_TIMEOUT_MS; + this.maxBatchSize = Math.max(1, config.maxBatchSize ?? DEFAULT_OPENAI_MAX_BATCH_SIZE); + this.compatibilityKey = this.dimensions + ? `${this.providerId}:${this.modelId}:${this.dimensions}` + : `${this.providerId}:${this.modelId}`; + } + + private buildHeaders(): HeadersInit { + const headers: Record = { + "Authorization": `Bearer ${this.apiKey}`, + "Content-Type": "application/json", + }; + if (this.organization) headers["OpenAI-Organization"] = this.organization; + if (this.project) headers["OpenAI-Project"] = this.project; + return headers; + } + + private async requestEmbeddings(inputs: string[]): Promise { + if (!this.apiKey) { + throw new Error("OPENAI_API_KEY is required when QMD_EMBED_PROVIDER=openai."); + } + + const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(new Error("OpenAI embeddings request timed out")), this.timeoutMs); + try { + let lastError: Error | null = null; + for (let attempt = 0; attempt < 3; attempt++) { + try { + const response = await fetch(`${this.baseUrl}/embeddings`, { + method: "POST", + headers: this.buildHeaders(), + body: JSON.stringify({ + input: inputs, + model: this.modelId, + encoding_format: "float", + ...(this.dimensions ? { dimensions: this.dimensions } : {}), + }), + signal: controller.signal, + }); + + if (!response.ok) { + const text = await response.text(); + if ((response.status === 429 || response.status >= 500) && attempt < 2) { + await sleep(250 * (attempt + 1)); + continue; + } + throw new Error(`OpenAI embeddings request failed (${response.status}): ${text}`); + } + + return await response.json() as OpenAIEmbeddingsResponse; + } catch (error) { + lastError = error instanceof Error ? error : new Error(String(error)); + if (attempt < 2) { + await sleep(250 * (attempt + 1)); + continue; + } + } + } + throw lastError ?? new Error("OpenAI embeddings request failed"); + } finally { + clearTimeout(timer); + } + } + + async embed(text: string, options: EmbedOptions = {}): Promise { + const [result] = await this.embedBatch([text], options); + return result; + } + + async embedBatch(texts: string[], options: EmbedOptions = {}): Promise<(EmbeddingResult | null)[]> { + if (texts.length === 0) return []; + + const batches: string[][] = []; + for (let i = 0; i < texts.length; i += this.maxBatchSize) { + batches.push(texts.slice(i, i + this.maxBatchSize)); + } + + const results: (EmbeddingResult | null)[] = []; + for (const batch of batches) { + try { + const response = await this.requestEmbeddings(batch); + const byIndex = new Map(response.data.map(item => [item.index, item.embedding])); + for (let i = 0; i < batch.length; i++) { + const embedding = byIndex.get(i); + results.push(embedding ? { + embedding, + model: options.model ?? this.compatibilityKey, + } : null); + } + } catch (error) { + console.error("OpenAI embedding batch error:", error); + results.push(...batch.map(() => null)); + } + } + return results; + } +} + +export function createEmbeddingProvider(config: EmbeddingProviderConfig = {}): EmbeddingProvider { + const providerName = resolveEmbeddingProviderName(config.provider); + if (providerName === "openai") { + return new OpenAIEmbeddingProvider({ + model: config.embedModel, + ...config.openai, + }); + } + return config.localProvider ?? new LlamaCpp({ embedModel: config.embedModel }); +} + // ============================================================================= // node-llama-cpp Implementation // ============================================================================= @@ -1668,6 +1847,7 @@ export function canUnloadLLM(): boolean { // ============================================================================= let defaultLlamaCpp: LlamaCpp | null = null; +let defaultEmbeddingProvider: EmbeddingProvider | null = null; /** * Get the default LlamaCpp instance (creates one if needed) @@ -1679,11 +1859,27 @@ export function getDefaultLlamaCpp(): LlamaCpp { return defaultLlamaCpp; } +export function getDefaultEmbeddingProvider(): EmbeddingProvider { + if (!defaultEmbeddingProvider) { + defaultEmbeddingProvider = createEmbeddingProvider({ + localProvider: getDefaultLlamaCpp(), + }); + } + return defaultEmbeddingProvider; +} + /** * Set a custom default LlamaCpp instance (useful for testing) */ export function setDefaultLlamaCpp(llm: LlamaCpp | null): void { defaultLlamaCpp = llm; + if (resolveEmbeddingProviderName() === "llama.cpp") { + defaultEmbeddingProvider = llm; + } +} + +export function setDefaultEmbeddingProvider(provider: EmbeddingProvider | null): void { + defaultEmbeddingProvider = provider; } /** @@ -1695,4 +1891,7 @@ export async function disposeDefaultLlamaCpp(): Promise { await defaultLlamaCpp.dispose(); defaultLlamaCpp = null; } + if (defaultEmbeddingProvider instanceof LlamaCpp) { + defaultEmbeddingProvider = null; + } } diff --git a/src/store.ts b/src/store.ts index 63cf541f9..ad3d9d0c9 100644 --- a/src/store.ts +++ b/src/store.ts @@ -21,11 +21,13 @@ import fastGlob from "fast-glob"; import { LlamaCpp, getDefaultLlamaCpp, + getDefaultEmbeddingProvider, formatQueryForEmbedding, formatDocForEmbedding, withLLMSessionForLlm, type RerankDocument, type EmbeddingProvider, + type ILLMSession, } from "./llm.js"; import type { NamedCollection, @@ -72,7 +74,7 @@ function getLlm(store: Store): LlamaCpp { * embed/embedBatch should not depend on generation or reranking capabilities. */ function getEmbeddingProvider(store: Store): EmbeddingProvider { - return store.llm ?? getDefaultLlamaCpp(); + return store.embeddingProvider ?? store.llm ?? getDefaultEmbeddingProvider(); } /** @@ -1121,6 +1123,8 @@ export type Store = { dbPath: string; /** Optional LlamaCpp instance for this store (overrides the global singleton) */ llm?: LlamaCpp; + /** Optional embedding provider for this store (can be remote) */ + embeddingProvider?: EmbeddingProvider; close: () => void; ensureVecTable: (dimensions: number) => void; @@ -1433,6 +1437,34 @@ function getEmbeddingDocsForBatch(db: Database, batch: PendingEmbeddingDoc[]): E })); } +type EmbeddingWorkSession = EmbeddingProvider & Pick; + +async function withEmbeddingSession( + provider: EmbeddingProvider, + fn: (session: EmbeddingWorkSession) => Promise, +): Promise { + if (provider instanceof LlamaCpp) { + return withLLMSessionForLlm(provider, fn); + } + + const controller = new AbortController(); + const session: EmbeddingWorkSession = { + providerId: provider.providerId, + modelId: provider.modelId, + compatibilityKey: provider.compatibilityKey, + get isValid() { return !controller.signal.aborted; }, + signal: controller.signal, + embed: provider.embed.bind(provider), + embedBatch: provider.embedBatch.bind(provider), + }; + + try { + return await fn(session); + } finally { + controller.abort(new Error("Embedding session released")); + } +} + /** * Generate vector embeddings for documents that need them. * Pure function — no console output, no db lifecycle management. @@ -1462,12 +1494,10 @@ export async function generateEmbeddings( const totalDocs = docsToEmbed.length; const startTime = Date.now(); - // Use store's LlamaCpp or global singleton, wrapped in a session - const llm = getLlm(store); + // Use the provider's actual embedding model for prompt formatting. const embedModelUri = getEmbeddingFormatModel(embedProvider, options?.model); - // Create a session manager for this llm instance - const result = await withLLMSessionForLlm(llm, async (session) => { + const result = await withEmbeddingSession(embedProvider, async (session) => { let chunksEmbedded = 0; let errors = 0; let bytesProcessed = 0; diff --git a/test/llm.test.ts b/test/llm.test.ts index 1dc26c919..22b112b65 100644 --- a/test/llm.test.ts +++ b/test/llm.test.ts @@ -10,9 +10,12 @@ import { describe, test, expect, beforeAll, afterAll, vi } from "vitest"; import { LlamaCpp, + OpenAIEmbeddingProvider, + createEmbeddingProvider, getDefaultLlamaCpp, disposeDefaultLlamaCpp, resolveLlamaGpuMode, + resolveEmbeddingProviderName, withLLMSession, canUnloadLLM, SessionReleasedError, @@ -205,6 +208,79 @@ describe("LlamaCpp embedding provider seam", () => { }); }); +describe("Embedding provider resolution", () => { + test("uses llama.cpp by default", () => { + const prevProvider = process.env.QMD_EMBED_PROVIDER; + try { + delete process.env.QMD_EMBED_PROVIDER; + expect(resolveEmbeddingProviderName()).toBe("llama.cpp"); + } finally { + if (prevProvider === undefined) delete process.env.QMD_EMBED_PROVIDER; + else process.env.QMD_EMBED_PROVIDER = prevProvider; + } + }); + + test("uses openai when explicitly configured", () => { + expect(resolveEmbeddingProviderName("openai")).toBe("openai"); + }); + + test("createEmbeddingProvider returns OpenAI provider when QMD_EMBED_PROVIDER=openai", () => { + const prevProvider = process.env.QMD_EMBED_PROVIDER; + const prevModel = process.env.QMD_EMBED_MODEL; + const prevKey = process.env.OPENAI_API_KEY; + try { + process.env.QMD_EMBED_PROVIDER = "openai"; + process.env.QMD_EMBED_MODEL = "text-embedding-3-small"; + process.env.OPENAI_API_KEY = "test-key"; + const provider = createEmbeddingProvider(); + expect(provider).toBeInstanceOf(OpenAIEmbeddingProvider); + expect(provider.compatibilityKey).toBe("openai:text-embedding-3-small"); + } finally { + if (prevProvider === undefined) delete process.env.QMD_EMBED_PROVIDER; + else process.env.QMD_EMBED_PROVIDER = prevProvider; + if (prevModel === undefined) delete process.env.QMD_EMBED_MODEL; + else process.env.QMD_EMBED_MODEL = prevModel; + if (prevKey === undefined) delete process.env.OPENAI_API_KEY; + else process.env.OPENAI_API_KEY = prevKey; + } + }); +}); + +describe("OpenAIEmbeddingProvider", () => { + test("batches embeddings through the OpenAI API and preserves order", async () => { + const fetchMock = vi.fn(async () => ({ + ok: true, + json: async () => ({ + data: [ + { index: 0, embedding: [0.1, 0.2] }, + { index: 1, embedding: [0.3, 0.4] }, + ], + model: "text-embedding-3-small", + }), + })); + + const originalFetch = globalThis.fetch; + vi.stubGlobal("fetch", fetchMock as any); + + try { + const provider = new OpenAIEmbeddingProvider({ + apiKey: "test-key", + model: "text-embedding-3-small", + }); + const results = await provider.embedBatch(["alpha", "beta"]); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(results).toEqual([ + { embedding: [0.1, 0.2], model: "openai:text-embedding-3-small" }, + { embedding: [0.3, 0.4], model: "openai:text-embedding-3-small" }, + ]); + } finally { + vi.unstubAllGlobals(); + globalThis.fetch = originalFetch; + } + }); +}); + describe("LlamaCpp embedding truncation", () => { test("truncates against the active embedding context limit, not the model train context", async () => { const llm = new LlamaCpp({}) as any; From 31e0a22a9ab442e2035967f25c1715015c2429ab Mon Sep 17 00:00:00 2001 From: Chris Cage Date: Mon, 13 Apr 2026 08:33:31 -0500 Subject: [PATCH 4/5] Fix OpenAI embedding runtime integration --- src/cli/qmd.ts | 121 ++++++++++++++++++++++++++++------------------- src/index.ts | 15 +++++- src/llm.ts | 9 ++-- src/store.ts | 53 ++++++++++++++------- test/llm.test.ts | 34 +++++++++++++ test/sdk.test.ts | 63 +++++++++++++++++++++++- 6 files changed, 222 insertions(+), 73 deletions(-) diff --git a/src/cli/qmd.ts b/src/cli/qmd.ts index 343e3072a..535d180ad 100755 --- a/src/cli/qmd.ts +++ b/src/cli/qmd.ts @@ -78,7 +78,7 @@ import { type ReindexResult, type ChunkStrategy, } from "../store.js"; -import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, setDefaultLlamaCpp, setDefaultEmbeddingProvider, createEmbeddingProvider, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js"; +import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, getDefaultEmbeddingProvider, setDefaultLlamaCpp, setDefaultEmbeddingProvider, createEmbeddingProvider, LlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "../llm.js"; import { formatSearchResults, formatDocuments, @@ -239,10 +239,22 @@ function formatETA(seconds: number): string { return `${Math.floor(seconds / 3600)}h ${Math.floor((seconds % 3600) / 60)}m`; } +function resolveActiveEmbeddingModelForStore(store: ReturnType): string { + const provider = store.embeddingProvider ?? getDefaultEmbeddingProvider(); + return provider.compatibilityKey ?? provider.modelId; +} + +function resolveEmbeddingDisplayModelForStore(store: ReturnType): string { + const provider = store.embeddingProvider ?? getDefaultEmbeddingProvider(); + return provider.modelId; +} // Check index health and print warnings/tips -function checkIndexHealth(db: Database): void { - const { needsEmbedding, totalDocs, daysStale } = getIndexHealth(db); +function checkIndexHealth(store: ReturnType): void { + const { needsEmbedding, totalDocs, daysStale } = getIndexHealth( + store.db, + resolveActiveEmbeddingModelForStore(store), + ); // Warn if many docs need embedding if (needsEmbedding > 0) { @@ -322,7 +334,10 @@ function formatBytes(bytes: number): string { async function showStatus(): Promise { const dbPath = getDbPath(); - const db = getDb(); + const store = getStore(); + const db = store.db; + const activeEmbeddingModel = resolveActiveEmbeddingModelForStore(store); + const embeddingDisplayModel = resolveEmbeddingDisplayModelForStore(store); // Collections are defined in YAML; no duplicate cleanup needed. // Collections are defined in YAML; no duplicate cleanup needed. @@ -340,7 +355,7 @@ async function showStatus(): Promise { // Overall stats const totalDocs = db.prepare(`SELECT COUNT(*) as count FROM documents WHERE active = 1`).get() as { count: number }; const vectorCount = db.prepare(`SELECT COUNT(*) as count FROM content_vectors`).get() as { count: number }; - const needsEmbedding = getHashesNeedingEmbedding(db); + const needsEmbedding = getHashesNeedingEmbedding(db, activeEmbeddingModel); // Most recent update across all collections const mostRecent = db.prepare(`SELECT MAX(modified_at) as latest FROM documents WHERE active = 1`).get() as { latest: string | null }; @@ -467,7 +482,7 @@ async function showStatus(): Promise { return match ? `https://huggingface.co/${match[1]}` : uri; }; console.log(`\n${c.bold}Models${c.reset}`); - console.log(` Embedding: ${hfLink(DEFAULT_EMBED_MODEL_URI)}`); + console.log(` Embedding: ${hfLink(embeddingDisplayModel)}`); console.log(` Reranking: ${hfLink(DEFAULT_RERANK_MODEL_URI)}`); console.log(` Generation: ${hfLink(DEFAULT_GENERATE_MODEL_URI)}`); } @@ -627,7 +642,7 @@ async function updateCollections(): Promise { } // Check if any documents need embedding (show once at end) - const needsEmbedding = getHashesNeedingEmbedding(db); + const needsEmbedding = getHashesNeedingEmbedding(db, resolveActiveEmbeddingModelForStore(storeInstance)); closeDb(); console.log(`${c.green}✓ All collections updated.${c.reset}`); @@ -1519,6 +1534,8 @@ function collectionRename(oldName: string, newName: string): void { async function indexFiles(pwd?: string, globPattern: string = DEFAULT_GLOB, collectionName?: string, suppressEmbedNotice: boolean = false, ignorePatterns?: string[]): Promise { const db = getDb(); + const storeInstance = getStore(); + const activeModelKey = resolveActiveEmbeddingModelForStore(storeInstance); const resolvedPwd = pwd || getPwd(); const now = new Date().toISOString(); const excludeDirs = ["node_modules", ".git", ".cache", "vendor", "dist", "build"]; @@ -1640,7 +1657,7 @@ async function indexFiles(pwd?: string, globPattern: string = DEFAULT_GLOB, coll const orphanedContent = cleanupOrphanedContent(db); // Check if vector index needs updating - const needsEmbedding = getHashesNeedingEmbedding(db); + const needsEmbedding = getHashesNeedingEmbedding(db, activeModelKey); progress.clear(); console.log(`\nIndexed: ${indexed} new, ${updated} updated, ${unchanged} unchanged, ${removed} removed`); @@ -1679,26 +1696,29 @@ function parseChunkStrategy(value: unknown): ChunkStrategy | undefined { } async function vectorIndex( - model: string = DEFAULT_EMBED_MODEL_URI, + model?: string, force: boolean = false, batchOptions?: { maxDocsPerBatch?: number; maxBatchBytes?: number; chunkStrategy?: ChunkStrategy }, ): Promise { const storeInstance = getStore(); const db = storeInstance.db; + const embedProvider = storeInstance.embeddingProvider ?? getDefaultEmbeddingProvider(); + const activeModelKey = model ?? embedProvider.compatibilityKey ?? embedProvider.modelId; + const displayModel = model ?? embedProvider.modelId; if (force) { console.log(`${c.yellow}Force re-indexing: clearing all vectors...${c.reset}`); } // Check if there's work to do before starting - const hashesToEmbed = getHashesNeedingEmbedding(db); + const hashesToEmbed = getHashesNeedingEmbedding(db, activeModelKey); if (hashesToEmbed === 0 && !force) { console.log(`${c.green}✓ All content hashes already have embeddings.${c.reset}`); closeDb(); return; } - console.log(`${c.dim}Model: ${model}${c.reset}\n`); + console.log(`${c.dim}Model: ${displayModel}${c.reset}\n`); if (batchOptions?.maxDocsPerBatch !== undefined || batchOptions?.maxBatchBytes !== undefined) { const maxDocsPerBatch = batchOptions.maxDocsPerBatch ?? DEFAULT_EMBED_MAX_DOCS_PER_BATCH; const maxBatchBytes = batchOptions.maxBatchBytes ?? DEFAULT_EMBED_MAX_BATCH_BYTES; @@ -1808,6 +1828,7 @@ type OutputOptions = { candidateLimit?: number; // Max candidates to rerank (default: 40) intent?: string; // Domain intent for disambiguation skipRerank?: boolean; // Skip LLM reranking, use RRF scores only + skipExpand?: boolean; // Skip query expansion, search only the original query chunkStrategy?: ChunkStrategy; // "auto" (default) or "regex" }; @@ -2289,47 +2310,46 @@ async function vectorSearch(query: string, opts: OutputOptions, _model: string = const collectionNames = resolveCollectionFilter(opts.collection, true); const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined; - checkIndexHealth(store.db); - - await withLLMSession(async () => { - let results = await vectorSearchQuery(store, query, { - collection: singleCollection, - limit: opts.all ? 500 : (opts.limit || 10), - minScore: opts.minScore || 0.3, - intent: opts.intent, - hooks: { - onExpand: (original, expanded) => { - logExpansionTree(original, expanded); - process.stderr.write(`${c.dim}Searching ${expanded.length + 1} vector queries...${c.reset}\n`); - }, + checkIndexHealth(store); + + let results = await vectorSearchQuery(store, query, { + collection: singleCollection, + limit: opts.all ? 500 : (opts.limit || 10), + minScore: opts.minScore || 0.3, + intent: opts.intent, + skipExpand: opts.skipExpand, + hooks: opts.skipExpand ? undefined : { + onExpand: (original, expanded) => { + logExpansionTree(original, expanded); + process.stderr.write(`${c.dim}Searching ${expanded.length + 1} vector queries...${c.reset}\n`); }, - }); + }, + }); - // Post-filter for multi-collection - if (collectionNames.length > 1) { - results = results.filter(r => { - const prefixes = collectionNames.map(n => `qmd://${n}/`); - return prefixes.some(p => r.file.startsWith(p)); - }); - } + // Post-filter for multi-collection + if (collectionNames.length > 1) { + results = results.filter(r => { + const prefixes = collectionNames.map(n => `qmd://${n}/`); + return prefixes.some(p => r.file.startsWith(p)); + }); + } - closeDb(); + closeDb(); - if (results.length === 0) { - printEmptySearchResults(opts.format); - return; - } + if (results.length === 0) { + printEmptySearchResults(opts.format); + return; + } - outputResults(results.map(r => ({ - file: r.file, - displayPath: r.displayPath, - title: r.title, - body: r.body, - score: r.score, - context: r.context, - docid: r.docid, - })), query, { ...opts, limit: results.length }); - }, { maxDuration: 10 * 60 * 1000, name: 'vectorSearch' }); + outputResults(results.map(r => ({ + file: r.file, + displayPath: r.displayPath, + title: r.title, + body: r.body, + score: r.score, + context: r.context, + docid: r.docid, + })), query, { ...opts, limit: results.length }); } async function querySearch(query: string, opts: OutputOptions, _embedModel: string = DEFAULT_EMBED_MODEL, _rerankModel: string = DEFAULT_RERANK_MODEL): Promise { @@ -2340,7 +2360,7 @@ async function querySearch(query: string, opts: OutputOptions, _embedModel: stri const collectionNames = resolveCollectionFilter(opts.collection, true); const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined; - checkIndexHealth(store.db); + checkIndexHealth(store); // Check for structured query syntax (lex:/vec:/hyde:/intent: prefixes) const parsed = parseStructuredQuery(query); @@ -2517,6 +2537,7 @@ function parseCLI() { // Query options "candidate-limit": { type: "string", short: "C" }, "no-rerank": { type: "boolean", default: false }, + "no-expand": { type: "boolean", default: false }, intent: { type: "string" }, // Chunking options "chunk-strategy": { type: "string" }, // "regex" (default) or "auto" (AST for code files) @@ -2559,6 +2580,7 @@ function parseCLI() { lineNumbers: !!values["line-numbers"], candidateLimit: values["candidate-limit"] ? parseInt(String(values["candidate-limit"]), 10) : undefined, skipRerank: !!values["no-rerank"], + skipExpand: !!values["no-expand"], explain: !!values.explain, intent: values.intent as string | undefined, chunkStrategy: parseChunkStrategy(values["chunk-strategy"]), @@ -2781,6 +2803,7 @@ function showHelp(): void { console.log(" --full - Output full document instead of snippet"); console.log(" -C, --candidate-limit - Max candidates to rerank (default 40, lower = faster)"); console.log(" --no-rerank - Skip LLM reranking (use RRF scores only, much faster on CPU)"); + console.log(" --no-expand - Skip local query expansion (search only the original query)"); console.log(" --line-numbers - Include line numbers in output"); console.log(" --explain - Include retrieval score traces (query --json/CLI)"); console.log(" --files | --json | --csv | --md | --xml - Output format"); @@ -3112,7 +3135,7 @@ if (isMain) { const maxDocsPerBatch = parseEmbedBatchOption("maxDocsPerBatch", cli.values["max-docs-per-batch"]); const maxBatchMb = parseEmbedBatchOption("maxBatchBytes", cli.values["max-batch-mb"]); const embedChunkStrategy = parseChunkStrategy(cli.values["chunk-strategy"]); - await vectorIndex(DEFAULT_EMBED_MODEL_URI, !!cli.values.force, { + await vectorIndex(undefined, !!cli.values.force, { maxDocsPerBatch, maxBatchBytes: maxBatchMb === undefined ? undefined : maxBatchMb * 1024 * 1024, chunkStrategy: embedChunkStrategy, diff --git a/src/index.ts b/src/index.ts index 2dd2863dc..3a9d9a113 100644 --- a/src/index.ts +++ b/src/index.ts @@ -23,7 +23,6 @@ import { structuredSearch, extractSnippet, addLineNumbers, - DEFAULT_EMBED_MODEL, DEFAULT_MULTI_GET_MAX_BYTES, reindexCollection, generateEmbeddings, @@ -380,6 +379,10 @@ export async function createStore(options: StoreOptions): Promise { embedModel: config?.models?.embed, localProvider: llm, }); + const resolveActiveEmbedModel = () => + internal.embeddingProvider?.compatibilityKey ?? + internal.embeddingProvider?.modelId ?? + llm.compatibilityKey; const store: QMDStore = { internal, @@ -422,7 +425,15 @@ export async function createStore(options: StoreOptions): Promise { }); }, searchLex: async (q, opts) => internal.searchFTS(q, opts?.limit, opts?.collection), - searchVector: async (q, opts) => internal.searchVec(q, DEFAULT_EMBED_MODEL, opts?.limit, opts?.collection), + searchVector: async (q, opts) => internal.searchVec( + q, + resolveActiveEmbedModel(), + opts?.limit, + opts?.collection, + undefined, + undefined, + internal.embeddingProvider, + ), expandQuery: async (q, opts) => internal.expandQuery(q, undefined, opts?.intent), get: async (pathOrDocid, opts) => internal.findDocument(pathOrDocid, opts), getDocumentBody: async (pathOrDocid, opts) => { diff --git a/src/llm.ts b/src/llm.ts index b6cb27786..5a5dcb4c5 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -1861,9 +1861,12 @@ export function getDefaultLlamaCpp(): LlamaCpp { export function getDefaultEmbeddingProvider(): EmbeddingProvider { if (!defaultEmbeddingProvider) { - defaultEmbeddingProvider = createEmbeddingProvider({ - localProvider: getDefaultLlamaCpp(), - }); + const providerName = resolveEmbeddingProviderName(); + defaultEmbeddingProvider = providerName === "llama.cpp" + ? createEmbeddingProvider({ + localProvider: getDefaultLlamaCpp(), + }) + : createEmbeddingProvider(); } return defaultEmbeddingProvider; } diff --git a/src/store.ts b/src/store.ts index ad3d9d0c9..0213613c2 100644 --- a/src/store.ts +++ b/src/store.ts @@ -1162,7 +1162,7 @@ export type Store = { // Search searchFTS: (query: string, limit?: number, collectionName?: string) => SearchResult[]; - searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[]) => Promise; + searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[], providerOverride?: EmbeddingProvider) => Promise; // Query expansion & reranking expandQuery: (query: string, model?: string, intent?: string) => Promise; @@ -1521,13 +1521,24 @@ export async function generateEmbeddings( if (!doc.body.trim()) continue; const title = extractTitle(doc.body, doc.path); - const chunks = await chunkDocumentByTokens( - doc.body, - undefined, undefined, undefined, - doc.path, - options?.chunkStrategy, - session.signal, - ); + const chunks = embedProvider instanceof LlamaCpp + ? await chunkDocumentByTokens( + doc.body, + undefined, undefined, undefined, + doc.path, + options?.chunkStrategy, + session.signal, + ) + : (await chunkDocumentAsync( + doc.body, + undefined, undefined, undefined, + doc.path, + options?.chunkStrategy, + )).map(chunk => ({ + text: chunk.text, + pos: chunk.pos, + tokens: 0, + })); for (let seq = 0; seq < chunks.length; seq++) { batchChunks.push({ @@ -1704,7 +1715,7 @@ export function createStore(dbPath?: string): Store { // Search searchFTS: (query: string, limit?: number, collectionName?: string) => searchFTS(db, query, limit, collectionName), - searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding), + searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[], providerOverride?: EmbeddingProvider) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding, providerOverride), // Query expansion & reranking expandQuery: (query: string, model?: string, intent?: string) => expandQuery(query, model, db, intent, store.llm), @@ -3168,12 +3179,14 @@ export async function searchVec( collectionName?: string, session?: EmbeddingProvider, precomputedEmbedding?: number[], + providerOverride?: EmbeddingProvider, ): Promise { const tableExists = db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get(); if (!tableExists) return []; - const activeModelKey = getEmbeddingModelKey(session, model); - const embedding = precomputedEmbedding ?? await getEmbedding(query, model, true, session); + const provider = session ?? providerOverride; + const activeModelKey = getEmbeddingModelKey(provider, model); + const embedding = precomputedEmbedding ?? await getEmbedding(query, model, true, session, providerOverride); if (!embedding) return []; // IMPORTANT: We use a two-step query approach here because sqlite-vec virtual tables @@ -4372,6 +4385,7 @@ export interface VectorSearchOptions { limit?: number; // default 10 minScore?: number; // default 0.3 intent?: string; // domain intent hint for disambiguation + skipExpand?: boolean; // skip local query expansion, search only the original query hooks?: Pick; } @@ -4403,24 +4417,29 @@ export async function vectorSearchQuery( const minScore = options?.minScore ?? 0.3; const collection = options?.collection; const intent = options?.intent; + const skipExpand = options?.skipExpand ?? false; const hasVectors = !!store.db.prepare( `SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'` ).get(); if (!hasVectors) return []; - const activeEmbedModel = getEmbeddingModelKey(getEmbeddingProvider(store)); + const embedProvider = getEmbeddingProvider(store); + const activeEmbedModel = getEmbeddingModelKey(embedProvider); // Expand query — filter to vec/hyde only (lex queries target FTS, not vector) - const expandStart = Date.now(); - const allExpanded = await store.expandQuery(query, undefined, intent); - const vecExpanded = allExpanded.filter(q => q.type !== 'lex'); - options?.hooks?.onExpand?.(query, vecExpanded, Date.now() - expandStart); + let vecExpanded: ExpandedQuery[] = []; + if (!skipExpand) { + const expandStart = Date.now(); + const allExpanded = await store.expandQuery(query, undefined, intent); + vecExpanded = allExpanded.filter(q => q.type !== 'lex'); + options?.hooks?.onExpand?.(query, vecExpanded, Date.now() - expandStart); + } // Run original + vec/hyde expanded through vector, sequentially — concurrent embed() hangs const queryTexts = [query, ...vecExpanded.map(q => q.query)]; const allResults = new Map(); for (const q of queryTexts) { - const vecResults = await store.searchVec(q, activeEmbedModel, limit, collection); + const vecResults = await store.searchVec(q, activeEmbedModel, limit, collection, undefined, undefined, embedProvider); for (const r of vecResults) { const existing = allResults.get(r.filepath); if (!existing || r.score > existing.score) { diff --git a/test/llm.test.ts b/test/llm.test.ts index 22b112b65..ba95d661b 100644 --- a/test/llm.test.ts +++ b/test/llm.test.ts @@ -12,16 +12,20 @@ import { LlamaCpp, OpenAIEmbeddingProvider, createEmbeddingProvider, + getDefaultEmbeddingProvider, getDefaultLlamaCpp, disposeDefaultLlamaCpp, resolveLlamaGpuMode, resolveEmbeddingProviderName, + setDefaultEmbeddingProvider, + setDefaultLlamaCpp, withLLMSession, canUnloadLLM, SessionReleasedError, type RerankDocument, type ILLMSession, } from "../src/llm.js"; +import * as llmModule from "../src/llm.js"; // ============================================================================= // Singleton Tests (no model loading required) @@ -244,6 +248,36 @@ describe("Embedding provider resolution", () => { else process.env.OPENAI_API_KEY = prevKey; } }); + + test("getDefaultEmbeddingProvider does not construct LlamaCpp when openai is active", () => { + const prevProvider = process.env.QMD_EMBED_PROVIDER; + const prevModel = process.env.QMD_EMBED_MODEL; + const prevKey = process.env.OPENAI_API_KEY; + const llamaSpy = vi.spyOn(llmModule, "getDefaultLlamaCpp"); + + try { + setDefaultEmbeddingProvider(null); + setDefaultLlamaCpp(null); + process.env.QMD_EMBED_PROVIDER = "openai"; + process.env.QMD_EMBED_MODEL = "text-embedding-3-small"; + process.env.OPENAI_API_KEY = "test-key"; + + const provider = getDefaultEmbeddingProvider(); + + expect(provider).toBeInstanceOf(OpenAIEmbeddingProvider); + expect(llamaSpy).not.toHaveBeenCalled(); + } finally { + llamaSpy.mockRestore(); + setDefaultEmbeddingProvider(null); + setDefaultLlamaCpp(null); + if (prevProvider === undefined) delete process.env.QMD_EMBED_PROVIDER; + else process.env.QMD_EMBED_PROVIDER = prevProvider; + if (prevModel === undefined) delete process.env.QMD_EMBED_MODEL; + else process.env.QMD_EMBED_MODEL = prevModel; + if (prevKey === undefined) delete process.env.OPENAI_API_KEY; + else process.env.OPENAI_API_KEY = prevKey; + } + }); }); describe("OpenAIEmbeddingProvider", () => { diff --git a/test/sdk.test.ts b/test/sdk.test.ts index 689da27b9..c2e0f2f8f 100644 --- a/test/sdk.test.ts +++ b/test/sdk.test.ts @@ -5,7 +5,7 @@ * Uses inline config (no YAML files) to verify the SDK works self-contained. */ -import { describe, test, expect, beforeAll, afterAll, beforeEach, afterEach } from "vitest"; +import { describe, test, expect, beforeAll, afterAll, beforeEach, afterEach, vi } from "vitest"; import { mkdtemp, writeFile, mkdir, rm } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; @@ -22,7 +22,7 @@ import { type VectorSearchOptions, type ExpandQueryOptions, } from "../src/index.js"; -import { setDefaultLlamaCpp } from "../src/llm.js"; +import { LlamaCpp, OpenAIEmbeddingProvider, setDefaultLlamaCpp } from "../src/llm.js"; // ============================================================================= // Test Helpers @@ -146,6 +146,65 @@ describe("createStore", () => { expect(store.dbPath).toBe(dbPath); await store.close(); }); + test("uses the local llama.cpp embedding provider by default", async () => { + const store = await createStore({ + dbPath: freshDbPath(), + config: { + collections: {}, + models: { + embed: "hf:custom/embed-model.gguf", + }, + }, + }); + + expect(store.internal.llm).toBeInstanceOf(LlamaCpp); + expect(store.internal.embeddingProvider).toBe(store.internal.llm); + expect(store.internal.embeddingProvider?.compatibilityKey).toBe("llama.cpp:hf:custom/embed-model.gguf"); + await store.close(); + }); + + test("searchVector uses the active OpenAI embedding compatibility key", async () => { + const prevProvider = process.env.QMD_EMBED_PROVIDER; + const prevKey = process.env.OPENAI_API_KEY; + process.env.QMD_EMBED_PROVIDER = "openai"; + process.env.OPENAI_API_KEY = "test-key"; + + try { + const store = await createStore({ + dbPath: freshDbPath(), + config: { + collections: {}, + models: { + embed: "text-embedding-3-small", + }, + }, + }); + + const searchVec = vi.fn(async () => []); + store.internal.searchVec = searchVec as typeof store.internal.searchVec; + + await store.searchVector("authentication"); + + expect(store.internal.llm).toBeInstanceOf(LlamaCpp); + expect(store.internal.embeddingProvider).toBeInstanceOf(OpenAIEmbeddingProvider); + expect(searchVec).toHaveBeenCalledWith( + "authentication", + "openai:text-embedding-3-small", + undefined, + undefined, + undefined, + undefined, + store.internal.embeddingProvider, + ); + + await store.close(); + } finally { + if (prevProvider === undefined) delete process.env.QMD_EMBED_PROVIDER; + else process.env.QMD_EMBED_PROVIDER = prevProvider; + if (prevKey === undefined) delete process.env.OPENAI_API_KEY; + else process.env.OPENAI_API_KEY = prevKey; + } + }); }); // ============================================================================= From 1f19d5144733a8a312044293ce1273cf6e059886 Mon Sep 17 00:00:00 2001 From: Chris Cage Date: Sun, 12 Apr 2026 21:47:58 -0500 Subject: [PATCH 5/5] Document and test OpenAI embedding provider --- README.md | 29 +++++++++++++++++++++++++++++ test/llm.test.ts | 29 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/README.md b/README.md index 6f318446b..39910205b 100644 --- a/README.md +++ b/README.md @@ -515,6 +515,35 @@ Supported model families: > since vectors are not cross-compatible between models. The prompt format is > automatically adjusted for each model family. +### OpenAI Embedding Provider (Prototype) + +QMD can also use OpenAI for embeddings while keeping query expansion and reranking +on the local GGUF models. Provider selection is env-driven for now. + +```sh +export QMD_EMBED_PROVIDER="openai" +export OPENAI_API_KEY="sk-..." +export QMD_EMBED_MODEL="text-embedding-3-small" + +# Optional: +export QMD_OPENAI_BASE_URL="https://api.openai.com/v1" +export QMD_OPENAI_EMBED_DIMENSIONS="1024" +export OPENAI_ORG_ID="org_..." +export OPENAI_PROJECT_ID="proj_..." + +qmd embed -f +``` + +Notes: +- `QMD_EMBED_PROVIDER=openai` switches only the embedding path. Reranking and query expansion remain local. +- `QMD_EMBED_MODEL` should be the OpenAI embedding model name when using the OpenAI provider. +- `QMD_OPENAI_BASE_URL` overrides the embeddings endpoint base URL. `OPENAI_BASE_URL` is also accepted. +- `QMD_OPENAI_EMBED_DIMENSIONS` becomes part of the embedding compatibility key, so changing it also requires `qmd embed -f`. +- Switching provider, model, or dimensions requires a full re-embed because stored vectors are only compatible with the active embedding configuration. + +If you use YAML or SDK config, `models.embed` can hold the embedding model string, +but provider selection still comes from the environment in this prototype. + ## Installation ```sh diff --git a/test/llm.test.ts b/test/llm.test.ts index ba95d661b..e3a35741a 100644 --- a/test/llm.test.ts +++ b/test/llm.test.ts @@ -281,6 +281,16 @@ describe("Embedding provider resolution", () => { }); describe("OpenAIEmbeddingProvider", () => { + test("includes dimensions in the compatibility key when configured", () => { + const provider = new OpenAIEmbeddingProvider({ + apiKey: "test-key", + model: "text-embedding-3-large", + dimensions: 1024, + }); + + expect(provider.compatibilityKey).toBe("openai:text-embedding-3-large:1024"); + }); + test("batches embeddings through the OpenAI API and preserves order", async () => { const fetchMock = vi.fn(async () => ({ ok: true, @@ -313,6 +323,25 @@ describe("OpenAIEmbeddingProvider", () => { globalThis.fetch = originalFetch; } }); + + test("returns null results and logs when the OpenAI API key is missing", async () => { + const errorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); + const prevKey = process.env.OPENAI_API_KEY; + delete process.env.OPENAI_API_KEY; + + try { + const provider = new OpenAIEmbeddingProvider({ + model: "text-embedding-3-small", + }); + + await expect(provider.embedBatch(["alpha", "beta"])).resolves.toEqual([null, null]); + expect(errorSpy).toHaveBeenCalled(); + } finally { + errorSpy.mockRestore(); + if (prevKey === undefined) delete process.env.OPENAI_API_KEY; + else process.env.OPENAI_API_KEY = prevKey; + } + }); }); describe("LlamaCpp embedding truncation", () => {