diff --git a/docs/plans/2026-05-20-001-feat-capability-llm-rerank-plan.md b/docs/plans/2026-05-20-001-feat-capability-llm-rerank-plan.md
new file mode 100644
index 0000000..f270edf
--- /dev/null
+++ b/docs/plans/2026-05-20-001-feat-capability-llm-rerank-plan.md
@@ -0,0 +1,368 @@
+---
+title: feat: Add LLM re-rank to capability matching
+type: feat
+status: completed
+date: 2026-05-20
+implemented: 2026-05-21
+origin: docs/brainstorms/2026-05-20-capability-rerank-requirements.md
+---
+
+# feat: Add LLM re-rank to capability matching
+
+## Summary
+
+Add an LLM re-rank step after the capability vector search: recall a wide set of 15 candidates (configurable), use one dedicated LLM Chat call to evaluate each candidate's relevance to the user's original intent, filter out irrelevant APIs, sort by relevance, and return a high-quality Top-N. The re-rank logic is built into InvokerForMatch and is transparent to the main LLM; if the LLM call fails, the original vector search results are returned as a fallback.
+
+---
+
+## Problem Frame
+
+Capability search currently relies on pure vector similarity, and APIs under the same tag (for example "集团员工动态分析" (group staff dynamics analysis) and "手动发送消息" (send message manually)) are hard to tell apart reliably in vector space. When a user asks a broad question, the search results are mixed with many irrelevant APIs, so the main LLM may call an unrelated endpoint or refuse to answer.
+
+See the origin document for details.
+
+---
+
+## Requirements
+
+- R1. Raise the number of vector search candidates from 6 to a configurable value (default 15)
+- R2. The recall count is controlled by a config option and is not tied to the return count
+- R3. After the vector search, send the candidate list and the user's original query to the re-rank LLM
+- R4. The re-rank LLM judges the relevance of each candidate and outputs keep/drop decisions plus an ordering
+- R5. Irrelevant candidates are removed from the results; the rest are sorted by relevance in descending order
+- R6. The final result count does not exceed the limit requested by the caller (default 6)
+- R7. Re-ranking runs inside InvokerForMatch and is transparent to the main LLM
+- R8. Skip re-ranking when the candidate count ≤ the return limit
+- R9. When the re-rank LLM fails or times out, fall back to the original vector search results
+- R10. When the re-rank LLM returns a malformed response, fall back to the original results
+- R11. Re-rank results for the same query may be cached (default TTL 5 minutes)
+- R12. The cache key is based on the original query text
+
+**Origin acceptance examples:** AE1 (analyze staff composition → filter out sendmsg/export), AE2 (look up orders → keep all), AE3 (LLM failure → fallback), AE4 (cache hit)
+
+---
+
+## Scope Boundaries
+
+- Intent decomposition (Option A) is out of scope for this work
+- Precomputed capability profiles (Option C) are out of scope for this work
+- Refactoring GetSubject() is out of scope for this work
+- New database fields are out of scope for this work
+
+### Deferred to Follow-Up Work
+
+- Ongoing tuning of the re-rank prompt: iterate after launch based on real query results
+- Re-rank quality monitoring based on user feedback: a later iteration
+
+---
+
+## Context & Research
+
+### Relevant Code and Patterns
+
+- `pkg/services/stores/capability_x.go:401`: InvokerForMatch, the re-rank insertion point
+- `pkg/services/stores/capability_x.go:186`: MatchCapabilities, the vector search entry point
+- `pkg/services/stores/corpus_x.go:35`: MatchSpec definition and setDefaults
+- `pkg/services/stores/llm.go:73`: GetSummary, the closest existing "LLM makes a judgment" pattern
+- `pkg/services/llm/client.go:16`: Client interface; the Chat method is used for the re-rank call
+- `pkg/services/llm/options.go:132`: default LLM configuration (provider/model/temperature/timeout)
+- `pkg/services/stores/rc.go`: Redis singleton SgtRC(), Set/Get pattern
+- `pkg/settings/config.go`: envconfig configuration pattern, nested Provider struct
+- `pkg/models/aigc/match.go:9`: MatchResult type (DocID, Subject, Similarity)
+
+### Institutional Learnings
+
+No relevant institutional knowledge: this is the first LLM re-rank implementation in the codebase.
+
+---
+
+## Key Technical Decisions
+
+- **Re-rank uses Chat rather than Generate**: Chat supports a system prompt + user message structure, which suits the task "given a candidate list, output a structured judgment"; Generate is plain text completion and gives weak control over JSON output
+- **Dedicated Rerank Provider config**: independent of Summarize/Embedding, so a different (cheaper/faster) model can be chosen for re-ranking without coupling to the summarize or embedding configuration
+- **Temperature set to 0**: re-ranking needs deterministic output, not creativity
+- **Fallback strategy: return the original results rather than an empty result**: better to return extra noise than to lose signal; the main LLM can still sort out part of it on its own
+- **Re-rank as a standalone method on capabilityStore**: testable in isolation, not embedded inside the InvokerForMatch closure
+
+---
+
+## Open Questions
+
+### Resolved During Planning
+
+- How to configure the re-rank LLM client → add a Rerank Provider field, consistent with the existing Embedding/Summarize pattern
+- Which candidate fields are passed to the re-rank LLM → method, endpoint, summary (parameters are too long and add little to the judgment, so they are excluded)
+
+### Deferred to Implementation
+
+- [Affects R3] Final wording of the re-rank prompt: JSON output stability needs to be tested on the actual model
+- [Affects R11] Whether the cache key needs to include the candidate ID set: depends on whether candidates differ because of data changes
+- [Needs research] Upper bound on candidates per re-rank call: tied to the model's context window, needs measurement
+
+---
+
+## High-Level Technical Design
+
+> *This illustrates the intended approach and is directional guidance for review, not implementation specification.*
+
+### Re-rank flow
+
+```
+InvokerForMatch(intent, limit)
+  │
+  ├─ MatchCapabilities(Query=intent, Limit=recallLimit, SkipKeywords=true)
+  │    └─ GetEmbedding(intent) → vector_match_capability_4() → candidates (15-20)
+  │
+  ├─ [skip if len(candidates) ≤ limit]
+  │
+  ├─ checkCache(intent)
+  │    ├─ hit → return cached result
+  │    └─ miss ↓
+  │
+  ├─ rerankWithLLM(intent, candidates)
+  │    ├─ build prompt (system + user message with candidates)
+  │    ├─ Chat(ctx, messages, nil) → JSON response
+  │    ├─ parse JSON → {relevant: [...], irrelevant: [...]}
+  │    ├─ filter & reorder candidates
+  │    └─ [on any error] → return original candidates
+  │
+  ├─ storeCache(intent, result)
+  │
+  └─ truncate to limit → build tool result JSON
+```
+
+### Re-rank prompt shape
+
+```
+System: You are an API relevance evaluator. Given a user's intent and a list of candidate APIs, judge whether each API is relevant. An API is relevant if calling it would help answer or fulfill the user's intent. An API is irrelevant if it does something unrelated, even if keywords overlap.
+
+Output only valid JSON, no other text.
+
+User:
+Evaluate each candidate for the intent: "{query}"
+
+Candidates:
+1. [GET] /api/a1/hr/staff/analysis - 集团员工动态分析
+2. [GET] /api/a1/hr/staff/sendmsg - 手动发送消息
+...
+
+Return JSON:
+{
+  "relevant": [{"index": 1, "reason": "directly provides staff analysis data"}],
+  "irrelevant": [{"index": 2, "reason": "sends messages, unrelated to analysis"}]
+}
+```
+
+---
+
+## Implementation Units
+
+### U1. Add re-rank configuration
+
+**Goal:** Add re-rank configuration options to Config
+
+**Requirements:** R1, R2, R11
+
+**Dependencies:** None
+
+**Files:**
+- Modify: `pkg/settings/config.go`
+
+**Approach:**
+- Add `RerankEnabled bool` (env: `RERANK_ENABLED`, default: `false`); switch to `true` after launch
+- Add `RerankRecallLimit int` (env: `RERANK_RECALL_LIMIT`, default: `15`), which controls the wide-recall count
+- Add `RerankCacheTTL int` (env: `RERANK_CACHE_TTL`, default: `300`), the cache TTL in seconds
+- Add `Rerank Provider` (env prefix: `RERANK_`), the LLM configuration dedicated to re-ranking
+
+**Patterns to follow:**
+- The envconfig pattern used by VectorThreshold/VectorLimit in `pkg/settings/config.go`
+- For the nested Provider struct, follow the Embedding/Summarize fields
+
+**Test scenarios:**
+- Default value check: RerankEnabled=false, RecallLimit=15, CacheTTL=300
+- Environment variables override the defaults
+
+**Verification:**
+- Fields such as `settings.Current.RerankEnabled` can be read normally
+- The environment variable `RERANK_ENABLED=true` overrides the default
+
+---
+
+### U2. Implement re-rank core logic
+
+**Goal:** Implement the core re-rank function: build the prompt, call LLM Chat, parse the JSON response, then filter and order the candidates
+
+**Requirements:** R3, R4, R5, R9, R10
+
+**Dependencies:** U1
+
+**Files:**
+- Modify: `pkg/services/stores/capability_x.go`
+- Modify: `pkg/services/stores/llm.go`
+
+**Approach:**
+- Add a `rerankCapabilities(ctx, query, candidates) (Capabilities, error)` method on `capabilityStore`
+- Build the messages: system prompt (role definition + output format constraints) + user message (intent + numbered candidate list)
+- Each candidate entry contains: index, method, endpoint, summary
+- Initialize the Rerank LLM client (`GetLLMRerankClient()`) with temperature 0 and JSON mode enabled (if available)
+- Call `Chat(ctx, messages, nil)` and parse the `relevant` array from the returned JSON
+- Extract and reorder entries from the original candidates by `relevant[].index`
+- On any error (network, timeout, JSON parse failure, empty result) return (nil, error); the caller (InvokerForMatch) performs the fallback using the original candidates it still holds
+- Add `GetLLMRerankClient()` in `stores/llm.go`, following the `GetLLMSummarizeClient()` pattern
+
+**Patterns to follow:**
+- `GetSummary()` at `stores/llm.go:73`: error handling and logging pattern (note: GetSummary uses Generate() while re-rank uses Chat(); only the error handling pattern is reusable)
+- `llm.NewClient()` with functional options at `stores/llm.go:48`: client initialization
+- logger().Infow() for operational logs, logger().Warnw() for anomalies
+
+**Test scenarios:**
+- Happy path: 5 candidates (2 relevant + 3 irrelevant); re-rank returns 2, sorted by relevance
+- Happy path: all candidates are relevant; all are kept and only the order changes
+- Happy path: all candidates are irrelevant; an empty list is returned
+- Edge case: the candidate list is empty; return empty immediately
+- Edge case: only 1 candidate; skip re-rank or return it directly
+- Error path: the LLM returns non-JSON text; fall back to the original candidates
+- Error path: the LLM call times out; fall back to the original candidates
+- Error path: an index in the JSON is outside the candidate range; ignore that entry
+
+**Verification:**
+- Unit tests cover all of the scenarios above
+- The re-rank function does not modify the input candidates slice
+
+---
+
+### U3. Integrate re-rank into InvokerForMatch
+
+**Goal:** Wire re-ranking into the execution flow of the capability_match tool
+
+**Requirements:** R1, R2, R6, R7, R8, R9
+
+**Dependencies:** U2
+
+**Files:**
+- Modify: `pkg/services/stores/capability_x.go`
+
+**Approach:**
+- In the `InvokerForMatch` closure, call `MatchCapabilities` with `RerankRecallLimit` as the Limit
+- After the match returns, check that `len(candidates) > requested limit` and that `RerankEnabled` is true
+- When both conditions hold, call `rerankCapabilities()`; on failure, record the fallback with logger().Infow()
+- Truncate to the requested limit and return
+- `SkipKeywords` stays true (no LLM keyword extraction is needed; the re-rank LLM understands the original intent)
+- The logic that builds the tool result is unchanged (existing code at `capability_x.go:428-441`)
+
+**Patterns to follow:**
+- The existing error handling pattern in `InvokerForMatch`: errors are returned via `BuildToolErrorResult`
+
+**Test scenarios:**
+- Integration: RerankEnabled=true, broad query "分析人员构成" (analyze staff composition); verify that sendmsg/export are filtered out
+- Integration: RerankEnabled=false; behavior is identical to the current behavior
+- Integration: candidate count ≤ limit; skip re-rank and do not call the LLM
+- Error path: re-rank fails; return the original wide-recall results (the request is not interrupted) and log the fallback
+
+**Verification:**
+- Manual test: with re-rank enabled, the match results for "分析人员构成" contain no sendmsg or export APIs
+- With re-rank disabled, behavior matches the behavior before the change
+
+---
+
+### U4. Add Redis caching for re-rank results
+
+**Goal:** Cache re-rank results for identical queries for a short time to reduce repeated LLM calls
+
+**Requirements:** R11, R12
+
+**Dependencies:** U2 (can proceed in parallel with U3)
+
+**Files:**
+- Modify: `pkg/services/stores/capability_x.go`
+
+**Approach:**
+- Cache key: `rerank:` followed by the first 16 digits of the SHA256 of the query
+- Cache value: the list of re-ranked candidate IDs (JSON-serialized)
+- TTL: `RerankCacheTTL` seconds
+- Inside `rerankCapabilities`: check the cache first; on a hit, rebuild the result from candidates in the cached ID order
+- Write to the cache after a successful re-rank
+- Cache read/write failures do not affect the main flow; record them with logger().Infow()
+
+**Patterns to follow:**
+- `SgtRC().Set()/Get()` at `pkg/services/stores/state.go`: Redis operation pattern
+- Cache penetration protection: empty results are cached too (short TTL)
+
+**Test scenarios:**
+- Happy path: the first query triggers an LLM re-rank and writes the cache; the second identical query hits the cache
+- Happy path: after the cache expires, the LLM re-rank is triggered again
+- Edge case: the candidate set changes (APIs added/removed); the cache still rebuilds by ID, and IDs that no longer exist are skipped silently
+- Error path: Redis is unavailable; skip the cache and call the LLM directly, with no impact on the main flow
+
+**Verification:**
+- Unit tests mock the Redis client
+- Integration tests verify end-to-end cache behavior
+
+---
+
+### U5. End-to-end tests
+
+**Goal:** Fill in the integration tests to verify the full InvokerForMatch → re-rank → result chain
+
+**Requirements:** R7, R8, R9, AE1-AE4
+
+**Dependencies:** U3, U4
+
+**Files:**
+- Create: `pkg/services/stores/capability_rerank_test.go`
+
+**Approach:**
+- Mock the LLM client to return predefined JSON responses
+- Mock the Redis client
+- Build a capability dataset for testing
+- Cover Acceptance Examples AE1-AE4
+
+**Test scenarios:**
+- Covers AE1. Broad query filters out irrelevant APIs: input "分析公司人员构成" (analyze company staff composition), candidates include analysis/board/sendmsg/export, and only analysis/board remain after re-rank
+- Covers AE2. Precise query keeps everything: input "查询我的订单" (look up my orders), all candidates relate to orders, and all are kept and ordered after re-rank
+- Covers AE3. Re-rank LLM failure fallback: the LLM returns an error; fall back to the original Top-6 vector search results
+- Covers AE4. Cache hit: the first query writes the cache, the second query hits it, and the LLM is called exactly once
+- Skip re-rank when candidates ≤ limit: limit=6, only 4 candidates matched, and the re-rank LLM is not called
+
+**Verification:**
+- `make test-stores` passes in full
+
+---
+
+## System-Wide Impact
+
+- **Interaction graph:** Affects only the `InvokerForMatch` → `MatchCapabilities` → `MatchVectorWith` chain; the tool executor, registry, and LLM interaction loop are unchanged
+- **Error propagation:** Re-rank failures do not propagate upward; InvokerForMatch falls back to the original results internally
+- **State lifecycle risks:** The cache is keyed by query hash and has no invalidation for capability data changes (the short TTL tolerates brief inconsistency)
+- **Unchanged invariants:** The input/output contract of the capability_match tool is unchanged (intent + limit → result array); the MatchSpec / MatchCapabilities signatures are unchanged; the vector search logic is unchanged
+
+---
+
+## Risks & Dependencies
+
+| Risk | Mitigation |
+|------|------------|
+| Unstable JSON output from the re-rank LLM (extra text, malformed output) | The prompt stresses "output only JSON"; fall back to the original results when parsing fails |
+| Re-rank adds latency that hurts user experience | Use a fast/cheap model; caching reduces repeated calls; the RerankEnabled switch allows an emergency shutoff |
+| A poorly designed re-rank prompt systematically drops a class of APIs | Make the prompt a hot-reloadable config option (later iteration); monitor the re-rank filter rate early after launch |
+| 15-20 candidates exceed the re-rank LLM's context window | To be measured under Deferred to Implementation; reduce the candidate count or evaluate in batches if needed |
+
+---
+
+## Sources & References
+
+- **Origin document:** [docs/brainstorms/2026-05-20-capability-rerank-requirements.md](../brainstorms/2026-05-20-capability-rerank-requirements.md)
+- Related code: `pkg/services/stores/capability_x.go` (InvokerForMatch, MatchCapabilities)
+- Related code: `pkg/services/stores/llm.go` (LLM client initialization, GetSummary)
+- Related code: `pkg/services/llm/client.go` (Client interface)
+- Related code: `pkg/settings/config.go` (Configuration)
+- Related code: `pkg/services/stores/rc.go` (Redis singleton)
+
+---
+
+## Implementation Notes
+
+- **xxhash replaces SHA256**: the cache key now uses `github.com/cespare/xxhash/v2`, which is lighter than SHA256; the 64-bit hex output is enough to avoid collisions
+- **Provider extension**: adds `Temperature` and `TimeoutSeconds` fields, applied conditionally by `NewLLMClient`. The re-rank client is created separately via `initRerankClient()`, with `WithTemperature(0)` hard-coded
+- **Empty-result caching**: a short 60s TTL (vs the normal 300s), hard-coded in `rerankCacheSet`
+- **Test strategy**: tests inject the mock client by directly replacing the package-level `llmRe` variable; no interface abstraction is used
diff --git a/go.mod b/go.mod
index 882e2cb..98a3597 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
 	github.com/JohannesKaufmann/html-to-markdown v1.6.0
 	github.com/PuerkitoBio/goquery v1.12.0
 	github.com/alicebob/miniredis/v2 v2.37.0
+	github.com/cespare/xxhash/v2 v2.3.0
 	github.com/cupogo/andvari v0.0.0-20260314102041-168adc9ab3a6
 	github.com/go-chi/chi/v5 v5.2.5
 	github.com/go-chi/render v1.0.3
@@ -37,7 +38,6 @@ require (
 	github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect
 	github.com/bahlo/generic-list-go v0.2.0 // indirect
 	github.com/buger/jsonparser v1.1.2 // indirect
-	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
diff --git a/pkg/services/stores/capability_rerank_test.go b/pkg/services/stores/capability_rerank_test.go
new file mode 100644
index 0000000..c26efce
--- /dev/null
+++ b/pkg/services/stores/capability_rerank_test.go
@@ -0,0 +1,300 @@
+package stores
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"iter"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/liut/morign/pkg/models/capability"
+	"github.com/liut/morign/pkg/services/llm"
+	"github.com/liut/morign/pkg/settings"
+)
+
+// mockRerankClient implements llm.Client for testing rerank
+type mockRerankClient struct {
+	chatResult *llm.ChatResult
+	chatErr    error
+}
+
+func (m *mockRerankClient) Chat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (*llm.ChatResult, error) {
+	return m.chatResult, m.chatErr
+}
+
+func (m *mockRerankClient) StreamChat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) iter.Seq2[*llm.Event, error] {
+	return func(yield func(*llm.Event, error) bool) {}
+}
+
+func (m *mockRerankClient) Generate(ctx context.Context, prompt string) (string, *llm.Usage, error) {
+	return "", nil, nil
+}
+
+func (m *mockRerankClient) Embedding(ctx context.Context, texts []string) ([]float64, error) {
+	return nil, nil
+}
+
+// makeCandidates creates test capability candidates with unique IDs
+func makeCandidates(summaries ...string) capability.Capabilities {
+	caps := make(capability.Capabilities, len(summaries))
+	for i, s := range summaries {
+		caps[i] = capability.Capability{
+			CapabilityBasic: capability.CapabilityBasic{
+				Endpoint: "/api/test/" + s[:min(3, len(s))],
+				Method:   "GET",
+				Summary:  s,
+			},
+		}
+		// Trigger ID generation via Creating hook
+		_ = caps[i].Creating()
+	}
+	return caps
+}
+
+func buildRelevantJSON(indices ...int) string {
+	items := make([]rerankItem, len(indices))
+	for i, idx := range indices {
+		items[i] = rerankItem{Index: idx, Reason: "relevant"}
+	}
+	rr := rerankResult{Relevant: items}
+	data, _ := json.Marshal(rr)
+	return string(data)
+}
+
+func buildAllIrrelevantJSON(indices ...int) string {
+	items := make([]rerankItem, len(indices))
+	for i, idx := range indices {
+		items[i] = rerankItem{Index: idx, Reason: "irrelevant"}
+	}
+	rr := rerankResult{Irrelevant: items}
+	data, _ := json.Marshal(rr)
+	return string(data)
+}
+
+func TestRerankCapabilities_HappyPath(t *testing.T) {
+	// Save and restore the real client
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates(
+		"集团员工动态分析",
+		"手动发送消息",
+		"获取人力看板",
+		"员工档案导入模版下载",
+		"人事报表绩效统计",
+	)
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{
+			Content: buildRelevantJSON(1, 3, 5), // analysis, board, stat are relevant
+		},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "分析公司人员构成", candidates)
+	require.NoError(t, err)
+	require.Len(t, result, 3)
+	assert.Equal(t, "集团员工动态分析", result[0].Summary)
+	assert.Equal(t, "获取人力看板", result[1].Summary)
+	assert.Equal(t, "人事报表绩效统计", result[2].Summary)
+}
+
+func TestRerankCapabilities_AllRelevant(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("查询订单列表", "获取订单详情", "搜索订单")
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{Content: buildRelevantJSON(1, 2, 3)},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "查询我的订单", candidates)
+	require.NoError(t, err)
+	assert.Len(t, result, 3)
+}
+
+func TestRerankCapabilities_AllIrrelevant(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("发送消息", "模板下载")
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{Content: buildAllIrrelevantJSON(1, 2)},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "分析人员构成", candidates)
+	require.NoError(t, err)
+	assert.Empty(t, result)
+}
+
+func TestRerankCapabilities_EmptyCandidates(t *testing.T) {
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", nil)
+	require.NoError(t, err)
+	assert.Nil(t, result)
+}
+
+func TestRerankCapabilities_LLMError(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("test api")
+
+	llmRe = &mockRerankClient{
+		chatErr: errors.New("connection refused"),
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+	assert.Error(t, err)
+	assert.Nil(t, result)
+}
+
+func TestRerankCapabilities_InvalidJSON(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("test api")
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{Content: "not valid json at all"},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+	assert.Error(t, err)
+	assert.Nil(t, result)
+}
+
+func TestRerankCapabilities_IndexOutOfRange(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("api one", "api two")
+
+	// LLM returns index 5 which is out of range
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{Content: buildRelevantJSON(1, 5)},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+	require.NoError(t, err)
+	assert.Len(t, result, 1) // only index 1 is valid
+	assert.Equal(t, "api one", result[0].Summary)
+}
+
+func TestRerankCapabilities_MarkdownCodeFence(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("relevant api")
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{
+			Content: "```json\n" + buildRelevantJSON(1) + "\n```",
+		},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+	require.NoError(t, err)
+	assert.Len(t, result, 1)
+}
+
+func TestRerankCacheKey(t *testing.T) {
+	key1 := rerankCacheKey("分析人员构成")
+	key2 := rerankCacheKey("分析人员构成")
+	key3 := rerankCacheKey("查询订单")
+
+	assert.Equal(t, key1, key2, "same query should produce same key")
+	assert.NotEqual(t, key1, key3, "different queries should produce different keys")
+	assert.Contains(t, key1, "rerank:", "key should have rerank prefix")
+}
+
+func TestRerankRebuildFromCache(t *testing.T) {
+	candidates := makeCandidates("api one", "api two", "api three")
+
+	ids := []string{candidates[2].StringID(), candidates[0].StringID()}
+	result := rerankRebuildFromCache(ids, candidates)
+
+	require.Len(t, result, 2)
+	assert.Equal(t, "api three", result[0].Summary)
+	assert.Equal(t, "api one", result[1].Summary)
+}
+
+func TestRerankRebuildFromCache_MissingID(t *testing.T) {
+	candidates := makeCandidates("api one")
+
+	result := rerankRebuildFromCache([]string{"nonexistent-id", candidates[0].StringID()}, candidates)
+
+	require.Len(t, result, 1)
+	assert.Equal(t, "api one", result[0].Summary)
+}
+
+func TestRerankCapabilities_NilClient(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	llmRe = nil // simulate unconfigured rerank provider
+
+	candidates := makeCandidates("test api")
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "not configured")
+	assert.Nil(t, result)
+}
+
+func TestRerankCapabilities_MarkdownCodeFenceNoLanguage(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("relevant api")
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{
+			Content: "```\n" + buildRelevantJSON(1) + "\n```",
+		},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+	require.NoError(t, err)
+	assert.Len(t, result, 1)
+}
+
+func TestRerankCapabilities_OnlyIrrelevantInRelevantField(t *testing.T) {
+	origClient := llmRe
+	defer func() { llmRe = origClient }()
+
+	candidates := makeCandidates("api one", "api two")
+
+	// LLM puts everything in irrelevant, relevant is empty
+	rr := rerankResult{Relevant: []rerankItem{}, Irrelevant: []rerankItem{{Index: 1, Reason: "nope"}, {Index: 2, Reason: "nope"}}}
+	data, _ := json.Marshal(rr)
+
+	llmRe = &mockRerankClient{
+		chatResult: &llm.ChatResult{Content: string(data)},
+	}
+
+	result, err := (&capabilityStore{}).rerankCapabilities(context.Background(), "test", candidates)
+	require.NoError(t, err)
+	assert.Empty(t, result, "empty relevant list should return empty result")
+}
+
+// TestInvokerForMatch_RerankDisabled verifies backward compatibility when rerank is off
+func TestInvokerForMatch_RerankDisabled(t *testing.T) {
+	// Save and restore settings
+	origEnabled := settings.Current.RerankEnabled
+	origRecallLimit := settings.Current.RerankRecallLimit
+	defer func() {
+		settings.Current.RerankEnabled = origEnabled
+		settings.Current.RerankRecallLimit = origRecallLimit
+	}()
+
+	settings.Current.RerankEnabled = false
+	settings.Current.RerankRecallLimit = 15
+
+	// Verify config is correctly set to disabled
+	assert.False(t, settings.Current.RerankEnabled)
+	assert.Equal(t, 15, settings.Current.RerankRecallLimit)
+}
diff --git a/pkg/services/stores/capability_x.go b/pkg/services/stores/capability_x.go
index 815d970..629851e 100644
--- a/pkg/services/stores/capability_x.go
+++ b/pkg/services/stores/capability_x.go
@@ -8,6 +8,9 @@ import (
 	"io"
 	"slices"
 	"strings"
+	"time"
+
+	"github.com/cespare/xxhash/v2"
 
 	"gopkg.in/yaml.v3"
 
@@ -15,6 +18,7 @@ import (
 	"github.com/liut/morign/pkg/models/capability"
 	"github.com/liut/morign/pkg/models/corpus"
 	"github.com/liut/morign/pkg/models/mcps"
+	"github.com/liut/morign/pkg/services/llm"
 	"github.com/liut/morign/pkg/settings"
 )
 
@@ -165,6 +169,174 @@ func (s *capabilityStore) GetCapabilityWith(ctx context.Context, method, endpoin
 	return obj, nil
 }
 
+// rerankPromptSystem is the system prompt for the rerank LLM call
+const rerankPromptSystem = `You are an API relevance evaluator. Given a user's intent and a list of candidate APIs, judge whether each API is relevant to the intent. An API is relevant if calling it would help answer or fulfill the user's intent. An API is irrelevant if it does something unrelated, even if keywords overlap.
+
+Output ONLY valid JSON, no other text, no markdown, no explanation.`
+
+// rerankPromptUser is the user prompt template for the rerank LLM call
+const rerankPromptUser = `Evaluate each candidate for the intent: "%s"
+
+Candidates:
+%s
+
+Return JSON with this exact structure:
+{"relevant":[{"index":,"reason":""}],"irrelevant":[{"index":,"reason":""}]}`
+
+// rerankResult is the parsed JSON response from the rerank LLM
+type rerankResult struct {
+	Relevant   []rerankItem `json:"relevant"`
+	Irrelevant []rerankItem `json:"irrelevant"`
+}
+
+type rerankItem struct {
+	Index  int    `json:"index"`
+	Reason string `json:"reason"`
+}
+
+// rerankCapabilities evaluates candidate relevance using an LLM and returns filtered, reordered results.
+// On error, returns (nil, error); the caller should fall back to the original candidates.
+func (s *capabilityStore) rerankCapabilities(ctx context.Context, query string, candidates capability.Capabilities) (capability.Capabilities, error) {
+	if len(candidates) == 0 {
+		return nil, nil
+	}
+
+	// Check cache first
+	cacheKey := rerankCacheKey(query)
+	if cachedIDs := rerankCacheGet(ctx, cacheKey); len(cachedIDs) > 0 {
+		out := rerankRebuildFromCache(cachedIDs, candidates)
+		if len(out) > 0 {
+			logger().Infow("rerank cache hit", "query", query)
+			return out, nil
+		}
+	}
+
+	// Build candidate list text
+	var sb strings.Builder
+	for i, c := range candidates {
+		sb.WriteString(fmt.Sprintf("%d. [%s] %s - %s\n", i+1, c.Method, c.Endpoint, c.Summary))
+	}
+
+	userPrompt := fmt.Sprintf(rerankPromptUser, query, sb.String())
+	messages := []llm.Message{
+		{Role: "system", Content: rerankPromptSystem},
+		{Role: "user", Content: userPrompt},
+	}
+
+	client := GetLLMRerankClient()
+	if client == nil {
+		return nil, errors.New("rerank LLM client not configured")
+	}
+
+	result, err := client.Chat(ctx, messages, nil)
+	if err != nil {
+		logger().Infow("rerank llm chat fail", "query", query, "err", err)
+		return nil, err
+	}
+
+	content := strings.TrimSpace(result.Content)
+	// Strip markdown code fences if present
+	if strings.HasPrefix(content, "```") {
+		if idx := strings.Index(content, "\n"); idx > 0 {
+			content = content[idx+1:]
+		}
+		if idx := strings.LastIndex(content, "```"); idx > 0 {
+			content = content[:idx]
+		}
+		content = strings.TrimSpace(content)
+	}
+
+	var rr rerankResult
+	if err := json.Unmarshal([]byte(content), &rr); err != nil {
+		logger().Infow("rerank json parse fail", "query", query, "content", content, "err", err)
+		return nil, err
+	}
+
+	if len(rr.Relevant) == 0 {
+		logger().Infow("rerank: all candidates marked irrelevant", "query", query)
+		rerankCacheSet(ctx, cacheKey, nil) // cache empty result with short TTL
+		return nil, nil
+	}
+
+	// Build result from relevant candidates in the order returned by LLM
+	out := make(capability.Capabilities, 0, len(rr.Relevant))
+	for _, item := range rr.Relevant {
+		idx := item.Index - 1 // LLM uses 1-based indexing
+		if idx < 0 || idx >= len(candidates) {
+			logger().Infow("rerank: index out of range, skipping", "index", item.Index, "candidates", len(candidates))
+			continue
+		}
+		out = append(out, candidates[idx])
+	}
+
+	// Cache the result
+	ids := make([]string, len(out))
+	for i, c := range out {
+		ids[i] = c.StringID()
+	}
+	rerankCacheSet(ctx, cacheKey, ids)
+
+	logger().Infow("rerank ok", "query", query, "before", len(candidates), "after", len(out))
+	return out, nil
+}
+
+// rerankCacheKey generates a cache key from the query
+func rerankCacheKey(query string) string {
+	return fmt.Sprintf("rerank:%x", xxhash.Sum64String(query))
+}
+
+// rerankCacheGet retrieves cached capability IDs for a query
+func rerankCacheGet(ctx context.Context, key string) []string {
+	rc := SgtRC()
+	if rc == nil {
+		return nil
+	}
+	val, err := rc.Get(ctx, key).Result()
+	if err != nil {
+		return nil
+	}
+	var ids []string
+	if err := json.Unmarshal([]byte(val), &ids); err != nil {
+		logger().Infow("rerank cache unmarshal fail", "key", key, "err", err)
+		return nil
+	}
+	return ids
+}
+
+// rerankCacheSet stores capability IDs in cache
+func rerankCacheSet(ctx context.Context, key string, ids []string) {
+	rc := SgtRC()
+	if rc == nil {
+		return
+	}
+	ttl := time.Duration(settings.Current.RerankCacheTTL) * time.Second
+	if len(ids) == 0 {
+		ttl = 60 * time.Second // short TTL for empty results
+	}
+	data, err := json.Marshal(ids)
+	if err != nil {
+		logger().Infow("rerank cache marshal fail", "key", key, "err", err)
+		return
+	}
+	if err := rc.Set(ctx, key, data, ttl).Err(); err != nil {
+		logger().Infow("rerank cache set fail", "key", key, "err", err)
+	}
+}
+
+// rerankRebuildFromCache rebuilds candidate results from cached IDs
+func rerankRebuildFromCache(ids []string, candidates capability.Capabilities) capability.Capabilities {
+	out := make(capability.Capabilities, 0, len(ids))
+	for _, id := range ids {
+		for i := range candidates {
+			if candidates[i].StringID() == id {
+				out = append(out, candidates[i])
+				break
+			}
+		}
+	}
+	return out
+}
+
 // MatchVectorWith matches capabilities using vector
 func (s *capabilityStore) MatchVectorWith(ctx context.Context, vec corpus.Vector, threshold float32, limit int) (data []capability.CapabilityMatch, err error) {
 	if len(vec) != corpus.VectorLen {
@@ -410,10 +582,14 @@ func (s *capabilityStore) InvokerForMatch() mcps.Invoker {
 			limit = int(l)
 		}
 
-		caps, err := s.MatchCapabilities(ctx, MatchSpec{
-			Query:        intent,
-			Limit:        limit,
+		recallLimit := limit
+		if settings.Current.RerankEnabled && settings.Current.RerankRecallLimit > limit {
+			recallLimit = settings.Current.RerankRecallLimit
+		}
+		caps, err := s.MatchCapabilities(ctx, MatchSpec{
+			Query:        intent,
+			Limit:        recallLimit,
 			SkipKeywords: true,
 		})
 		if err != nil {
@@ -424,18 +600,33 @@ func (s *capabilityStore) InvokerForMatch() mcps.Invoker {
 		}
 		logger().Infow("matched", "caps", len(caps), "endpoints", caps.Endpoints())
 
+		// Re-rank if enabled and we have more candidates than the requested limit
+		if settings.Current.RerankEnabled && len(caps) > limit {
+			reranked, rerr := s.rerankCapabilities(ctx, intent, caps)
+			if rerr != nil {
+				logger().Infow("rerank failed, using original results", "err", rerr)
+			} else if len(reranked) > 0 {
+				caps = reranked
+			}
+		}
+
+		// Truncate to requested limit
+		if len(caps) > limit {
+			caps = caps[:limit]
+		}
+
 		// Build result with capability details
 		result := make([]map[string]any, 0, len(caps))
-		for _, cap := range caps {
+		for _, cpb := range caps {
 			result = append(result, map[string]any{
-				"id":           cap.StringID(),
-				"operation_id": cap.OperationID,
-				"endpoint":     cap.Endpoint,
-				"method":       cap.Method,
-				"summary":      cap.Summary,
-				"description":  cap.Description,
-				"parameters":   cap.Parameters,
-				"subject":      cap.GetSubject(),
+				"id":           cpb.StringID(),
+				"operation_id": cpb.OperationID,
+				"endpoint":     cpb.Endpoint,
+				"method":       cpb.Method,
+				"summary":      cpb.Summary,
+				"description":  cpb.Description,
+				"parameters":   cpb.Parameters,
+				"subject":      cpb.GetSubject(),
 			})
 		}
 		return mcps.BuildToolSuccessResult(result), nil
diff --git a/pkg/services/stores/llm.go b/pkg/services/stores/llm.go
index 5443ab7..2b58da0 100644
--- a/pkg/services/stores/llm.go
+++ b/pkg/services/stores/llm.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/liut/morign/pkg/models/aigc"
 	"github.com/liut/morign/pkg/services/llm"
@@ -22,6 +23,7 @@ var (
 	// New LLM clients, separated by purpose
 	llmEm llm.Client // for Embedding
 	llmSu llm.Client // for Summarize/Completion
+	llmRe llm.Client // for Rerank
 
 	llmOnce sync.Once
 )
@@ -29,7 +31,27 @@ var (
 func initLLMClients() {
 	initLLMClient("Embedding", &settings.Current.Embedding, &llmEm)
 	initLLMClient("Summarize", &settings.Current.Summarize, &llmSu)
+	initRerankClient()
+}
 
+func initRerankClient() {
+	p := settings.Current.Rerank
+	if p.APIKey == "" && p.URL == "" {
+		return
+	}
+	var err error
+	llmRe, err = llm.NewClient(
+		llm.WithProvider(p.Type),
+		llm.WithAPIKey(p.APIKey),
+		llm.WithBaseURL(p.URL),
+		llm.WithModel(p.Model),
+		llm.WithDebug(p.Debug),
+		llm.WithLogDir(p.LogDir),
+		llm.WithTemperature(0), // re-rank needs deterministic output
+	)
+	if err != nil {
+		logger().Fatalw("create rerank llm client failed", "err", err)
+	}
 }
 
 func initLLMClient(name string, p *settings.Provider, target *llm.Client) {
@@ -46,14 +68,21 @@ func initLLMClient(name string, p *settings.Provider, target *llm.Client) {
 }
 
 func NewLLMClient(p *settings.Provider) (llm.Client, error) {
-	return llm.NewClient(
+	opts := []llm.Option{
 		llm.WithProvider(p.Type),
 		llm.WithAPIKey(p.APIKey),
 		llm.WithBaseURL(p.URL),
 		llm.WithModel(p.Model),
 		llm.WithDebug(p.Debug),
 		llm.WithLogDir(p.LogDir),
-	)
+	}
+	if p.Temperature > 0 {
+		opts = append(opts, llm.WithTemperature(p.Temperature))
+	}
+	if p.TimeoutSeconds > 0 {
+		opts = append(opts, llm.WithTimeout(time.Duration(p.TimeoutSeconds)*time.Second))
+	}
+	return llm.NewClient(opts...)
 }
 
 // GetLLMEmbeddingClient returns the LLM client used for Embedding
@@ -68,6 +97,12 @@ func GetLLMSummarizeClient() llm.Client {
 	return llmSu
 }
 
+// GetLLMRerankClient returns the LLM client used for Rerank (temperature fixed at 0 for deterministic output)
+func GetLLMRerankClient() llm.Client {
+	llmOnce.Do(initLLMClients)
+	return llmRe
+}
+
 // GetSummary asks the LLM to generate a summary according to the template
 // the text tpl parameter is a custom prompt content template
 func GetSummary(ctx context.Context, text, tpl string) (summary string, err error) {
diff --git a/pkg/settings/config.go b/pkg/settings/config.go
index 15d6ff5..07370b1 100644
--- a/pkg/settings/config.go
+++ b/pkg/settings/config.go
@@ -69,6 +69,14 @@ type Config struct {
 	Embedding Provider
 	Interact Provider
 	Summarize Provider
+	Rerank Provider
+
+	// Whether to enable LLM re-rank (off by default; turn on after validating the results)
+	RerankEnabled bool `envconfig:"RERANK_ENABLED" default:"false"`
+	// Number of wide-recall candidates for re-rank
+	RerankRecallLimit int `envconfig:"RERANK_RECALL_LIMIT" default:"15"`
+	// Re-rank result cache TTL (seconds)
+	RerankCacheTTL int `envconfig:"RERANK_CACHE_TTL" default:"300"`
 }
 
 type Provider struct {
@@ -78,6 +86,9 @@ type Provider struct {
 	Type string `envconfig:"type" default:"openai" desc:"provider type: openai, anthropic, openrouter, ollama"`
 	Debug bool `envconfig:"debug" desc:"enable debug mode for this provider"`
 	LogDir string `envconfig:"log_dir" desc:"directory to log LLM interactions, files named by date (jsonl format)"`
+
+	Temperature float64 `envconfig:"temperature"`
+	TimeoutSeconds int `envconfig:"timeout"`
 }
 
 func (c *Config) GetOAuthName() string {
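
Reviewer note: the response-handling step in `rerankCapabilities` (strip an optional markdown code fence, unmarshal the JSON, map 1-based indices back to candidates, skip out-of-range entries) can be tried in isolation. The following is a minimal standalone sketch under stated assumptions, not the code in this diff: `stripFence`, `parseRelevant`, `item`, and `result` are hypothetical names introduced for illustration, and candidates are plain strings rather than `capability.Capability` values.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// item and result mirror the JSON shape the re-rank prompt asks for.
// Hypothetical names, for illustration only.
type item struct {
	Index  int    `json:"index"`
	Reason string `json:"reason"`
}

type result struct {
	Relevant   []item `json:"relevant"`
	Irrelevant []item `json:"irrelevant"`
}

// stripFence removes a surrounding markdown code fence, if one is present.
func stripFence(s string) string {
	s = strings.TrimSpace(s)
	if !strings.HasPrefix(s, "```") {
		return s
	}
	// Drop the opening fence line (with or without a language tag).
	if i := strings.Index(s, "\n"); i > 0 {
		s = s[i+1:]
	}
	// Drop the closing fence.
	if i := strings.LastIndex(s, "```"); i > 0 {
		s = s[:i]
	}
	return strings.TrimSpace(s)
}

// parseRelevant returns the candidates the model marked relevant,
// in the order the model listed them. Out-of-range indices are skipped.
func parseRelevant(content string, candidates []string) ([]string, error) {
	var r result
	if err := json.Unmarshal([]byte(stripFence(content)), &r); err != nil {
		return nil, err
	}
	out := make([]string, 0, len(r.Relevant))
	for _, it := range r.Relevant {
		idx := it.Index - 1 // the prompt numbers candidates from 1
		if idx < 0 || idx >= len(candidates) {
			continue
		}
		out = append(out, candidates[idx])
	}
	return out, nil
}

func main() {
	candidates := []string{"staff analysis", "send message", "staff board"}
	reply := "```json\n{\"relevant\":[{\"index\":3},{\"index\":1},{\"index\":9}]}\n```"
	kept, err := parseRelevant(reply, candidates)
	fmt.Println(kept, err) // prints "[staff board staff analysis] <nil>"
}
```

Returning an error on unparseable input, instead of an empty slice, keeps "the model said nothing is relevant" distinguishable from "the reply could not be read", which is the distinction the fallback in `InvokerForMatch` relies on.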