Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ export interface ApiHandlerCreateMessageMetadata {
* Only applies to providers that support function calling restrictions (e.g., Gemini).
*/
allowedFunctionNames?: string[]
/**
* Abort signal for cancelling the HTTP request mid-stream.
* Passed through to AI SDK's streamText() so the underlying HTTP request is aborted
* when the user clicks stop, preventing wasted API tokens/compute on the provider side.
*/
abortSignal?: AbortSignal
}

export interface ApiHandler {
Expand Down
122 changes: 110 additions & 12 deletions src/api/providers/__tests__/anthropic-vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -763,18 +763,21 @@ describe("VertexHandler", () => {

const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(handler["client"].messages.create).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [
{
role: "user",
content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
},
],
stream: false,
})
expect(handler["client"].messages.create).toHaveBeenCalledWith(
{
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [
{
role: "user",
content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
},
],
stream: false,
},
undefined,
)
})

it("should handle API errors for Claude", async () => {
Expand Down Expand Up @@ -824,6 +827,45 @@ describe("VertexHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should pass abortSignal to messages.create when provided in metadata", async () => {
handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const controller = new AbortController()
const mockAbortSignal = controller.signal

const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "Test response" }],
})
;(handler["client"].messages as any).create = mockCreate

await handler.completePrompt("Test prompt", { taskId: "test", abortSignal: mockAbortSignal })

const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBe(mockAbortSignal)
})

it("should pass undefined signal when abortSignal is not provided", async () => {
handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "Test response" }],
})
;(handler["client"].messages as any).create = mockCreate

await handler.completePrompt("Test prompt")

const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBeUndefined()
})
})

describe("getModel", () => {
Expand Down Expand Up @@ -1540,4 +1582,60 @@ describe("VertexHandler", () => {
})
})
})

describe("abortSignal support", () => {
it("should pass abortSignal to messages.create when provided in metadata", async () => {
const handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const mockCreate = handler["client"].messages.create as any
mockCreate.mockClear()

const controller = new AbortController()
const mockAbortSignal = controller.signal

const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
]

for await (const _chunk of handler.createMessage(systemPrompt, messages, {
taskId: "test",
abortSignal: mockAbortSignal,
})) {
break
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBe(mockAbortSignal)
})

it("should pass undefined signal when abortSignal is not provided", async () => {
const handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const mockCreate = handler["client"].messages.create as any
mockCreate.mockClear()

const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
]

for await (const _chunk of handler.createMessage(systemPrompt, messages)) {
break
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBeUndefined()
})
})
})
89 changes: 81 additions & 8 deletions src/api/providers/__tests__/anthropic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,17 @@ describe("AnthropicHandler", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
})
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
undefined,
)
})

it("should handle API errors", async () => {
Expand All @@ -462,6 +465,23 @@ describe("AnthropicHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should pass abortSignal to messages.create when provided in metadata", async () => {
const controller = new AbortController()
const mockAbortSignal = controller.signal

await handler.completePrompt("Test prompt", { taskId: "test", abortSignal: mockAbortSignal })

const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBe(mockAbortSignal)
})

it("should pass undefined signal when abortSignal is not provided", async () => {
await handler.completePrompt("Test prompt")

const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBeUndefined()
})
})

describe("getModel", () => {
Expand Down Expand Up @@ -1055,4 +1075,57 @@ describe("AnthropicHandler", () => {
})
})
})

describe("abortSignal support", () => {
it("should pass abortSignal to messages.create when provided in metadata", async () => {
const handler = new AnthropicHandler({
apiKey: "test-api-key",
apiModelId: "claude-3-5-sonnet-20241022",
})

const controller = new AbortController()
const mockAbortSignal = controller.signal

mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield { type: "message_start", message: { usage: { input_tokens: 10, output_tokens: 5 } } }
},
})

for await (const _chunk of handler.createMessage(
"system",
[{ role: "user", content: [{ type: "text", text: "Hello!" }] }],
{ taskId: "test", abortSignal: mockAbortSignal },
)) {
break
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBe(mockAbortSignal)
})

it("should pass undefined signal when abortSignal is not provided", async () => {
const handler = new AnthropicHandler({
apiKey: "test-api-key",
apiModelId: "claude-3-5-sonnet-20241022",
})

mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield { type: "message_start", message: { usage: { input_tokens: 10, output_tokens: 5 } } }
},
})

for await (const _chunk of handler.createMessage("system", [
{ role: "user", content: [{ type: "text", text: "Hello!" }] },
])) {
break
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][1]
expect(callArgs?.signal).toBeUndefined()
})
})
})
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ describe("BaseOpenAiCompatibleProvider", () => {
stream: true,
stream_options: { include_usage: true },
}),
undefined,
expect.any(Object),
)
})

Expand Down
72 changes: 72 additions & 0 deletions src/api/providers/__tests__/bedrock.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1576,4 +1576,76 @@ describe("AwsBedrockHandler", () => {
})
})
})

describe("abort signal", () => {
beforeEach(() => {
mockConverseStreamCommand.mockReset()
})

it("should handle pre-aborted signals by calling controller.abort() immediately", async () => {
const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

const controller = new AbortController()
controller.abort() // Pre-abort the signal

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]
const generator = handler.createMessage("System prompt", messages, {
taskId: "test-task",
abortSignal: controller.signal,
})

// Verify the external signal is already aborted before consuming
expect(controller.signal.aborted).toBe(true)

// Consume the stream - pre-aborted signal should trigger internal abort
for await (const _ of generator) {
// consume
}
})
Comment on lines +1585 to +1609

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Abort tests currently pass without proving cancellation propagation.

At Line 1603 and Line 1627, both tests only consume/suppress iteration; they never assert that the request signal became aborted. With the current mock stream shape, these can pass even if abort wiring breaks.

Suggested assertion tightening
 it("should handle pre-aborted signals by calling controller.abort() immediately", async () => {
@@
-  // Consume the stream - pre-aborted signal should trigger internal abort
-  for await (const _ of generator) {
-    // consume
-  }
+  await generator.next()
+
+  const clientInstance =
+    vi.mocked(BedrockRuntimeClient).mock.results[vi.mocked(BedrockRuntimeClient).mock.results.length - 1]?.value
+  const mockSendFn = clientInstance?.send as ReturnType<typeof vi.fn>
+  const sendOptions = mockSendFn.mock.calls[0][1]
+  expect(sendOptions.abortSignal).toBeDefined()
+  expect(sendOptions.abortSignal.aborted).toBe(true)
 })
@@
 it("should use { once: true } listener for external abort signal", async () => {
@@
-  controller.abort()
+  const clientInstance =
+    vi.mocked(BedrockRuntimeClient).mock.results[vi.mocked(BedrockRuntimeClient).mock.results.length - 1]?.value
+  const mockSendFn = clientInstance?.send as ReturnType<typeof vi.fn>
+  const sendOptions = mockSendFn.mock.calls[0][1]
+  expect(sendOptions.abortSignal.aborted).toBe(false)
+
+  controller.abort()
+  expect(sendOptions.abortSignal.aborted).toBe(true)
@@
-  await consumePromise.catch(() => {})
+  await consumePromise.catch(() => {})
 })

Also applies to: 1608-1639

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/api/providers/__tests__/bedrock.spec.ts` around lines 1585 - 1606, The
tests for AwsBedrockHandler.createMessage currently only iterate the generator
and never assert cancellation propagation; update both tests (the pre-aborted
case and the mid-stream abort case) to explicitly assert that the request/abort
signal was triggered: for the pre-aborted test, assert controller.signal.aborted
is true immediately after consuming the generator or assert that the generator
iteration throws an AbortError/DOMException; for the mid-stream-abort test, set
up a mock that records whether the internal request was aborted (or assert the
thrown AbortError) and assert that the internal controller.signal (or the passed
controller.signal) becomes aborted after you call controller.abort(); reference
AwsBedrockHandler.createMessage and the test generator variable to locate where
to add these assertions.


it("should use { once: true } listener for external abort signal", async () => {
const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

const controller = new AbortController()

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]
const generator = handler.createMessage("System prompt", messages, {
taskId: "test-task",
abortSignal: controller.signal,
})

// Start consuming and then abort
const consumePromise = (async () => {
try {
for await (const _ of generator) {
// consume
}
} catch {
// expected to fail due to mock
}
})()

await new Promise((r) => setTimeout(r, 10))

// Verify signal is not aborted before calling abort
expect(controller.signal.aborted).toBe(false)

controller.abort()

// Verify the signal becomes aborted after calling abort()
expect(controller.signal.aborted).toBe(true)

await consumePromise.catch(() => {})
})
})
})
40 changes: 40 additions & 0 deletions src/api/providers/__tests__/deepseek.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -639,4 +639,44 @@ describe("DeepSeekHandler", () => {
expect(toolCallChunks[0].name).toBe("get_weather")
})
})

describe("abortSignal support", () => {
it("should pass abortSignal to chat.completions.create when provided in metadata", async () => {
const handler = new DeepSeekHandler({ ...mockOptions, apiKey: "test-key" })
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
]

const controller = new AbortController()
const mockAbortSignal = controller.signal

for await (const _chunk of handler.createMessage(systemPrompt, messages, {
taskId: "test",
abortSignal: mockAbortSignal,
})) {
break
}

expect(mockCreate).toHaveBeenCalled()
const requestOptions = mockCreate.mock.calls[0][1]
expect(requestOptions?.signal).toBe(mockAbortSignal)
})

it("should not include signal when abortSignal is not provided", async () => {
const handler = new DeepSeekHandler({ ...mockOptions, apiKey: "test-key" })
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: [{ type: "text" as const, text: "Hello!" }] },
]

for await (const _chunk of handler.createMessage(systemPrompt, messages)) {
break
}

expect(mockCreate).toHaveBeenCalled()
const requestOptions = mockCreate.mock.calls[0][1]
expect(requestOptions?.signal).toBeUndefined()
})
})
})
Loading
Loading