diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index 156df8e540..507a81d9e2 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -1572,10 +1572,55 @@ describe("AwsBedrockHandler", () => { it("returns false for older / non-adaptive models", () => { expect(isAdaptiveThinkingModel("anthropic.claude-opus-4-6-v1")).toBe(false) - expect(isAdaptiveThinkingModel("anthropic.claude-sonnet-4-6")).toBe(false) - expect(isAdaptiveThinkingModel("anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe(false) expect(isAdaptiveThinkingModel("amazon.nova-lite-v1:0")).toBe(false) }) }) }) + + describe("abortSignal pass-through", () => { + it("should use metadata.abortSignal when provided in completePrompt", async () => { + const mockConverseCommand = vi.fn() + vi.mock("@aws-sdk/client-bedrock-runtime", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + BedrockRuntimeClient: vi.fn().mockReturnValue({ send: vi.fn().mockResolvedValue({ stream: null }) }), + } + }) + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockConverseCommand).toHaveBeenCalled() + }) + + it("should use default controller signal when metadata.abortSignal not provided", async () => { + const mockConverseCommand = vi.fn() + vi.mock("@aws-sdk/client-bedrock-runtime", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + BedrockRuntimeClient: vi.fn().mockReturnValue({ send: vi.fn().mockResolvedValue({ stream: null }) }), + } + }) + + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + await handler.completePrompt("Test prompt") + + expect(mockConverseCommand).toHaveBeenCalled() + }) + }) }) diff --git a/src/api/providers/__tests__/native-ollama.spec.ts b/src/api/providers/__tests__/native-ollama.spec.ts index 200868022b..70cb217778 100644 --- a/src/api/providers/__tests__/native-ollama.spec.ts +++ b/src/api/providers/__tests__/native-ollama.spec.ts @@ -226,6 +226,34 @@ describe("NativeOllamaHandler", () => { }), ) }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + mockChat.mockResolvedValue({ + message: { content: "Response with abort signal" }, + }) + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockChat).toHaveBeenCalledWith( + expect.objectContaining({ + options: expect.any(Object), + }), + ) + // Verify the call was made (signal is passed via Ollama SDK options) + const callArgs = mockChat.mock.calls[0][0] as any + expect(callArgs).toBeDefined() + }) + + it("should work without metadata in completePrompt", async () => { + mockChat.mockResolvedValue({ + message: { content: "Response without metadata" }, + }) + + await handler.completePrompt("Test prompt") + + expect(mockChat).toHaveBeenCalled() + }) }) describe("error handling", () => { diff --git a/src/api/providers/__tests__/openai-codex.spec.ts b/src/api/providers/__tests__/openai-codex.spec.ts index dcc0c4d035..9f24e3c8ea 100644 --- a/src/api/providers/__tests__/openai-codex.spec.ts +++ b/src/api/providers/__tests__/openai-codex.spec.ts @@ -42,3 +42,52 @@ describe("OpenAiCodexHandler.getModel", () => { expect(model.info).toBeDefined() }) }) + +describe("OpenAiCodexHandler.completePrompt", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + output: [{ type: "message", content: [{ type: "output_text", text: "Codex response" }] }], + }), + }) + global.fetch = mockFetch as any + + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.3-codex" }) + const controller = new AbortController() + + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining("/responses"), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should use default abortController signal when metadata.abortSignal not provided", async () => { + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue({ + output: [{ type: "message", content: [{ type: "output_text", text: "Codex response" }] }], + }), + }) + global.fetch = mockFetch as any + + const handler = new OpenAiCodexHandler({ apiModelId: "gpt-5.3-codex" }) + + await handler.completePrompt("Test prompt") + + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining("/responses"), + expect.objectContaining({ + signal: expect.any(AbortSignal), + }), + ) + }) +}) diff --git a/src/api/providers/__tests__/openai-native.spec.ts b/src/api/providers/__tests__/openai-native.spec.ts index 4d35387992..3d3026a646 100644 --- a/src/api/providers/__tests__/openai-native.spec.ts +++ b/src/api/providers/__tests__/openai-native.spec.ts @@ -216,7 +216,7 @@ describe("OpenAiNativeHandler", () => { ], }), expect.objectContaining({ - signal: expect.any(Object), + signal: expect.any(AbortSignal), }), ) }) @@ -245,6 +245,49 @@ describe("OpenAiNativeHandler", () => { expect(result).toBe("") }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "This is the completion response" }], + }, + ], + }) + + const controller = new AbortController() + const result = await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(result).toBe("This is the completion response") + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should use default abortController signal when metadata.abortSignal not provided", async () => { + mockResponsesCreate.mockResolvedValue({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: "This is the completion response" }], + }, + ], + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("This is the completion response") + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: handler["abortController"].signal, + }), + ) + }) }) describe("getModel", () => { @@ -1624,7 +1667,7 @@ describe("GPT-5 streaming event coverage (additional)", () => { store: false, }), expect.objectContaining({ - signal: expect.any(Object), + signal: expect.any(AbortSignal), }), ) }) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 5b009f40d0..369cd1f846 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -544,7 +544,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH const command = new ConverseStreamCommand(payload) const response = await this.client.send(command, { - abortSignal: controller.signal, + abortSignal: metadata?.abortSignal ?? controller.signal, }) if (!response.stream) { @@ -793,7 +793,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { + const controller = new AbortController() try { const modelConfig = this.getModel() @@ -835,7 +836,9 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } const command = new ConverseCommand(payload) - const response = await this.client.send(command) + const response = await this.client.send(command, { + abortSignal: metadata?.abortSignal ?? controller.signal, + }) if ( response?.output?.message?.content && diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 99c1dc03cf..85f1f1d5d6 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -234,14 +234,14 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio chatOptions.num_ctx = this.options.ollamaNumCtx } - // Create the actual API request promise - const stream = await client.chat({ + const controller = new AbortController() + const stream = (await client.chat({ model: modelId, messages: ollamaMessages, stream: true, options: chatOptions, tools: this.convertToolsToOllama(metadata?.tools), - }) + })) as any let totalInputTokens = 0 let totalOutputTokens = 0 diff --git a/src/api/providers/openai-codex.ts b/src/api/providers/openai-codex.ts index b5891c0e47..1554747e7d 100644 --- a/src/api/providers/openai-codex.ts +++ b/src/api/providers/openai-codex.ts @@ -188,7 +188,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion // Make the request with retry on auth failure for (let attempt = 0; attempt < 2; attempt++) { try { - yield* this.executeRequest(requestBody, model, accessToken, metadata?.taskId) + yield* this.executeRequest(requestBody, model, accessToken, metadata?.taskId, metadata) return } catch (error) { const message = error instanceof Error ? error.message : String(error) @@ -345,6 +345,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion model: OpenAiCodexModel, accessToken: string, taskId?: string, + metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { // Create AbortController for cancellation this.abortController = new AbortController() @@ -374,7 +375,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion }) const stream = (await (client as any).responses.create(requestBody, { - signal: this.abortController.signal, + signal: metadata?.abortSignal ?? this.abortController.signal, // If the SDK supports per-request overrides, ensure headers are present. headers: codexHeaders, })) as AsyncIterable @@ -399,7 +400,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion } } catch (_sdkErr) { // Fallback to manual SSE via fetch (Codex backend). - yield* this.makeCodexRequest(requestBody, model, accessToken, taskId) + yield* this.makeCodexRequest(requestBody, model, accessToken, taskId, metadata) } } finally { this.abortController = undefined @@ -492,6 +493,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion model: OpenAiCodexModel, accessToken: string, taskId?: string, + metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { // Per the implementation guide: route to Codex backend with Bearer token const url = `${CODEX_API_BASE_URL}/responses` @@ -518,7 +520,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion method: "POST", headers, body: JSON.stringify(requestBody), - signal: this.abortController?.signal, + signal: metadata?.abortSignal ?? this.abortController?.signal, }) if (!response.ok) { @@ -1151,7 +1153,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion return this.lastResponseId } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { this.abortController = new AbortController() try { @@ -1213,7 +1215,7 @@ export class OpenAiCodexHandler extends BaseProvider implements SingleCompletion method: "POST", headers, body: JSON.stringify(requestBody), - signal: this.abortController.signal, + signal: metadata?.abortSignal ?? this.abortController.signal, }) if (!response.ok) { diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index 37545f9979..bb9e260571 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -424,8 +424,9 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio try { // Use the official SDK with per-request headers + const signal = metadata?.abortSignal ?? this.abortController.signal const stream = (await (this.client as any).responses.create(requestBody, { - signal: this.abortController.signal, + signal, headers: requestHeaders, })) as AsyncIterable @@ -576,7 +577,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio "User-Agent": userAgent, }, body: JSON.stringify(requestBody), - signal: this.abortController.signal, + signal: metadata?.abortSignal ?? this.abortController.signal, }) if (!response.ok) { @@ -1482,7 +1483,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio return this.lastResponseId } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { // Create AbortController for cancellation this.abortController = new AbortController() @@ -1544,9 +1545,9 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio requestBody.prompt_cache_retention = promptCacheRetention } - // Make the non-streaming request + // Make the non-streaming request with conditional signal pass-through const response = await (this.client as any).responses.create(requestBody, { - signal: this.abortController.signal, + signal: metadata?.abortSignal ?? this.abortController.signal, }) // Extract text from the response