Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ import {
} from "./providers"
import { NativeOllamaHandler } from "./providers/native-ollama"

export interface CompletePromptOptions {
/** Abort signal for cancelling the request mid-flight */
signal?: AbortSignal
/** Optional timeout override (ms) — falls back to provider default if omitted */
timeoutMs?: number
}

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
completePrompt(prompt: string, options?: CompletePromptOptions): Promise<string>
}

export interface ApiHandlerCreateMessageMetadata {
Expand Down
64 changes: 52 additions & 12 deletions src/api/providers/__tests__/anthropic-vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -834,18 +834,22 @@ describe("VertexHandler", () => {

const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(handler["client"].messages.create).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [
{
role: "user",
content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
},
],
stream: false,
})
expect(handler["client"].messages.create).toHaveBeenCalledWith(
{
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [
{
role: "user",
content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
},
],
stream: false,
thinking: undefined,
},
undefined,
)
})

it("should handle API errors for Claude", async () => {
Expand Down Expand Up @@ -895,6 +899,42 @@ describe("VertexHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should pass abort signal through to client", async () => {
handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const controller = new AbortController()
const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "response" }],
})
;(handler["client"].messages as any).create = mockCreate

await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), {
signal: controller.signal,
})
})

it("should work without options (backward compatible)", async () => {
handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "response" }],
})
;(handler["client"].messages as any).create = mockCreate

const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), undefined)
})
})

describe("getModel", () => {
Expand Down
70 changes: 62 additions & 8 deletions src/api/providers/__tests__/anthropic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,17 @@ describe("AnthropicHandler", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
})
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
undefined,
)
})

it("should handle API errors", async () => {
Expand All @@ -464,6 +467,57 @@ describe("AnthropicHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should pass abort signal through to client", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
{ signal: controller.signal },
)
})

it("should work without options (backward compatible)", async () => {
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
undefined,
)
})

it("should merge signal and timeout together", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
await handler.completePrompt("test prompt", { signal: controller.signal, timeoutMs: 10000 })
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
{ signal: controller.signal },
)
})
})

describe("getModel", () => {
Expand Down
42 changes: 42 additions & 0 deletions src/api/providers/__tests__/bedrock.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,48 @@ describe("AwsBedrockHandler", () => {
expect(isAdaptiveThinkingModel("anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe(false)
expect(isAdaptiveThinkingModel("amazon.nova-lite-v1:0")).toBe(false)
})

it("should pass abort signal through to client.send", async () => {
const mockConverseCommand = vi.mocked(ConverseCommand)
const mockSend = BedrockRuntimeClient.prototype.send as any

const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

const controller = new AbortController()
mockSend.mockResolvedValueOnce({
output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null },
})

await handler.completePrompt("test prompt", { signal: controller.signal })

expect(mockSend).toHaveBeenCalledWith(expect.any(Object), { abortSignal: controller.signal })
})

it("should work without options (backward compatible)", async () => {
const mockConverseCommand = vi.mocked(ConverseCommand)
const mockSend = BedrockRuntimeClient.prototype.send as any

const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

mockSend.mockResolvedValueOnce({
output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null },
})

const result = await handler.completePrompt("test prompt")

expect(result).toBe("response")
expect(mockSend).toHaveBeenCalledWith(expect.any(Object), undefined)
})
})
})
})
29 changes: 29 additions & 0 deletions src/api/providers/__tests__/complete-prompt-options.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { describe, it, expect } from "vitest"

import type { CompletePromptOptions } from "../../index"

describe("CompletePromptOptions", () => {
it("should allow signal property", () => {
const controller = new AbortController()
const options: CompletePromptOptions = { signal: controller.signal }
expect(options.signal).toBe(controller.signal)
})

it("should allow timeoutMs property", () => {
const options: CompletePromptOptions = { timeoutMs: 5000 }
expect(options.timeoutMs).toBe(5000)
})

it("should allow both signal and timeoutMs together", () => {
const controller = new AbortController()
const options: CompletePromptOptions = { signal: controller.signal, timeoutMs: 10000 }
expect(options.signal).toBe(controller.signal)
expect(options.timeoutMs).toBe(10000)
})

it("should allow empty options object", () => {
const options: CompletePromptOptions = {}
expect(options.signal).toBeUndefined()
expect(options.timeoutMs).toBeUndefined()
})
})
41 changes: 41 additions & 0 deletions src/api/providers/__tests__/deepseek.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -710,4 +710,45 @@ describe("DeepSeekHandler", () => {
expect(toolCallChunks[0].name).toBe("get_weather")
})
})

describe("completePrompt", () => {
it("should complete prompt successfully", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "response" } }],
})
const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
})

it("should pass abort signal through to client", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "response" } }],
})
await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: expect.any(String) }),
expect.objectContaining({ signal: controller.signal }),
)
})

it("should pass timeout through to client", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "response" } }],
})
await handler.completePrompt("test prompt", { timeoutMs: 5000 })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: expect.any(String) }),
expect.objectContaining({ timeout: 5000 }),
)
})

it("should work without options (backward compatible)", async () => {
mockCreate.mockResolvedValueOnce({
choices: [{ message: { content: "response" } }],
})
const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
})
})
})
25 changes: 25 additions & 0 deletions src/api/providers/__tests__/fireworks.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,31 @@ describe("FireworksHandler", () => {
expect(result).toBe("")
})

it("completePrompt should pass abort signal through to client", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] })
await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: expect.any(String) }),
expect.objectContaining({ signal: controller.signal }),
)
})

it("completePrompt should pass timeout through to client", async () => {
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] })
await handler.completePrompt("test prompt", { timeoutMs: 5000 })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: expect.any(String) }),
expect.objectContaining({ timeout: 5000 }),
)
})

it("completePrompt should work without options (backward compatible)", async () => {
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "response" } }] })
const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
})

it("createMessage should handle stream with multiple chunks", async () => {
mockCreate.mockImplementationOnce(async () => ({
[Symbol.asyncIterator]: async function* () {
Expand Down
38 changes: 38 additions & 0 deletions src/api/providers/__tests__/gemini-handler.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,44 @@ describe("GeminiHandler backend support", () => {
expect(promptConfig.tools).toBeUndefined()
})

it("completePrompt should pass abort signal through to client via httpOptions", async () => {
const options = {
apiProvider: "gemini",
enableUrlContext: false,
enableGrounding: false,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)

const controller = new AbortController()
const stub = vi.fn().mockResolvedValue({ text: "response" })
handler["client"].models.generateContent = stub

await handler.completePrompt("test prompt", { signal: controller.signal })

expect(stub).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
httpOptions: { signal: controller.signal },
}),
}),
)
})

it("completePrompt should work without options (backward compatible)", async () => {
const options = {
apiProvider: "gemini",
enableUrlContext: false,
enableGrounding: false,
} as ApiHandlerOptions
const handler = new GeminiHandler(options)

const stub = vi.fn().mockResolvedValue({ text: "response" })
handler["client"].models.generateContent = stub

const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
})

describe("error scenarios", () => {
it("should handle grounding metadata extraction failure gracefully", async () => {
const options = {
Expand Down
Loading
Loading