diff --git a/src/api/providers/__tests__/deepseek.spec.ts b/src/api/providers/__tests__/deepseek.spec.ts index 4ea60a24ea..69ae2c6d82 100644 --- a/src/api/providers/__tests__/deepseek.spec.ts +++ b/src/api/providers/__tests__/deepseek.spec.ts @@ -494,7 +494,7 @@ describe("DeepSeekHandler", () => { expect.objectContaining({ thinking: { type: "enabled" }, }), - {}, // Empty path options for non-Azure URLs + undefined, // No signal, non-Azure URL ) const callArgs = mockCreate.mock.calls[0][0] expect(callArgs.reasoning_effort).toBeUndefined() @@ -517,7 +517,7 @@ describe("DeepSeekHandler", () => { reasoning_effort: "high", max_completion_tokens: 200_000, }), - {}, + undefined, ) }) @@ -554,10 +554,27 @@ describe("DeepSeekHandler", () => { thinking: { type: "enabled" }, reasoning_effort: "max", }), - {}, + undefined, ) }) + it("should pass Azure path even when abortSignal is not provided", async () => { + const azureHandler = new DeepSeekHandler({ + ...mockOptions, + deepSeekBaseUrl: "https://example.services.ai.azure.com", + }) + + const stream = azureHandler.createMessage(systemPrompt, messages) + for await (const _chunk of stream) { + // Consume the stream + } + + expect(mockCreate).toHaveBeenCalledTimes(1) + const [requestOptions, requestConfig] = mockCreate.mock.calls[0] + expect(requestOptions.model).toBe("deepseek-chat") + expect(requestConfig).toEqual({ path: expect.any(String) }) + }) + it("should disable thinking for deepseek-v4 models when reasoning is disabled", async () => { const v4Handler = new DeepSeekHandler({ ...mockOptions, @@ -640,5 +657,48 @@ describe("DeepSeekHandler", () => { expect(toolCallChunks.length).toBeGreaterThan(0) expect(toolCallChunks[0].name).toBe("get_weather") }) + + it("should use metadata.abortSignal when provided in createMessage", async () => { + const controller = new AbortController() + mockCreate.mockResolvedValueOnce({ + id: "test-completion", + choices: [{ message: { content: "Response with abort signal" } }], + }) + + const stream = 
handler.createMessage(systemPrompt, messages, { abortSignal: controller.signal }) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in createMessage", async () => { + mockCreate.mockResolvedValueOnce({ + id: "test-completion", + choices: [{ message: { content: "Response without metadata" } }], + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) }) }) diff --git a/src/api/providers/__tests__/lmstudio-native-tools.spec.ts b/src/api/providers/__tests__/lmstudio-native-tools.spec.ts index c6870e80ef..6244275c8c 100644 --- a/src/api/providers/__tests__/lmstudio-native-tools.spec.ts +++ b/src/api/providers/__tests__/lmstudio-native-tools.spec.ts @@ -83,6 +83,7 @@ describe("LmStudioHandler Native Tools", () => { }), ]), }), + undefined, ) // parallel_tool_calls should be true by default when not explicitly set const callArgs = mockCreate.mock.calls[0][0] @@ -109,6 +110,7 @@ describe("LmStudioHandler Native Tools", () => { expect.objectContaining({ tool_choice: "auto", }), + undefined, ) }) @@ -221,9 +223,9 @@ describe("LmStudioHandler Native Tools", () => { expect.objectContaining({ parallel_tool_calls: true, }), + undefined, ) }) - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { mockCreate.mockImplementationOnce(() => ({ [Symbol.asyncIterator]: async function* () { diff --git a/src/api/providers/__tests__/lmstudio.spec.ts 
b/src/api/providers/__tests__/lmstudio.spec.ts index c6ebd8a6e9..8bae7125d5 100644 --- a/src/api/providers/__tests__/lmstudio.spec.ts +++ b/src/api/providers/__tests__/lmstudio.spec.ts @@ -133,12 +133,15 @@ describe("LmStudioHandler", () => { it("should complete prompt successfully", async () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.lmStudioModelId, - messages: [{ role: "user", content: "Test prompt" }], - temperature: 0, - stream: false, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.lmStudioModelId, + messages: [{ role: "user", content: "Test prompt" }], + temperature: 0, + stream: false, + }, + undefined, + ) }) it("should handle API errors", async () => { @@ -155,6 +158,39 @@ describe("LmStudioHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "Response with abort signal" } }], + }) + + const controller = new AbortController() + const result = await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(result).toBe("Response with abort signal") + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in completePrompt", async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ message: { content: "Response without metadata" } }], + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Response without metadata") + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) }) describe("getModel", () => 
{ diff --git a/src/api/providers/__tests__/mimo.spec.ts b/src/api/providers/__tests__/mimo.spec.ts index 9e7ec97d28..7ec32dac13 100644 --- a/src/api/providers/__tests__/mimo.spec.ts +++ b/src/api/providers/__tests__/mimo.spec.ts @@ -376,6 +376,7 @@ describe("MimoHandler", () => { expect.objectContaining({ extra_body: { thinking: { type: "enabled" } }, }), + undefined, ) }) diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts index d87ae1190b..28349b1ffc 100644 --- a/src/api/providers/__tests__/minimax.spec.ts +++ b/src/api/providers/__tests__/minimax.spec.ts @@ -220,6 +220,37 @@ describe("MiniMaxHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow() }) + it("should use metadata.abortSignal when provided in completePrompt", async () => { + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "Response with abort signal" }], + }) + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in completePrompt", async () => { + mockCreate.mockResolvedValueOnce({ + content: [{ type: "text", text: "Response without metadata" }], + }) + + await handler.completePrompt("test prompt") + + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) + it("createMessage should yield text content from stream", async () => { const testContent = "This is test content from MiniMax stream" @@ -305,6 +336,7 @@ describe("MiniMaxHandler", () => { messages: expect.any(Array), stream: true, }), + undefined, ) }) @@ -324,6 +356,7 @@ describe("MiniMaxHandler", () => { expect.objectContaining({ temperature: 1, }), + undefined, ) }) 
diff --git a/src/api/providers/__tests__/mistral.spec.ts b/src/api/providers/__tests__/mistral.spec.ts index 96e42e356b..61f0af42af 100644 --- a/src/api/providers/__tests__/mistral.spec.ts +++ b/src/api/providers/__tests__/mistral.spec.ts @@ -497,5 +497,40 @@ describe("MistralHandler", () => { mockComplete.mockRejectedValueOnce(new Error("API Error")) await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Mistral completion error: API Error") }) + + it("should still complete when metadata.abortSignal is provided in completePrompt", async () => { + mockComplete.mockResolvedValueOnce({ + id: "test-completion", + choices: [{ message: { content: "Response with abort signal" } }], + }) + + const controller = new AbortController() + const result = await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(result).toBe("Response with abort signal") + // completePrompt does not forward the signal (see mistral.ts), so only + // the request payload is asserted here. + const completeCalls = mockComplete.mock.calls + expect(completeCalls[completeCalls.length - 1][0]).toEqual( + expect.objectContaining({ messages: [{ role: "user", content: "Test prompt" }] }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in completePrompt", async () => { + mockComplete.mockResolvedValueOnce({ + id: "test-completion", + choices: [{ message: { content: "Response without metadata" } }], + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Response without metadata") + expect(mockComplete).toHaveBeenCalled() + const completeCalls = mockComplete.mock.calls + const [, requestConfig] = completeCalls[completeCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) }) }) diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index ed5e82496e..2bf6d6918c 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -600,6 +600,49 @@ describe("OpenAiHandler", () => { } }).rejects.toThrow("Rate limit exceeded") }) + + it("should use metadata.abortSignal when provided in createMessage", async () => { + const controller = new 
AbortController() + mockCreate.mockResolvedValue({ + id: "test-completion", + choices: [{ message: { content: "Response with abort signal" } }], + }) + + const stream = handler.createMessage("system prompt", testMessages, { abortSignal: controller.signal }) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in createMessage", async () => { + mockCreate.mockResolvedValue({ + id: "test-completion", + choices: [{ message: { content: "Response without metadata" } }], + }) + + const stream = handler.createMessage("system prompt", testMessages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) }) describe("completePrompt", () => { @@ -627,6 +670,41 @@ describe("OpenAiHandler", () => { const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + mockCreate.mockResolvedValue({ + id: "test-completion", + choices: [{ message: { content: "Response with abort signal" } }], + }) + + const controller = new AbortController() + const result = await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(result).toBe("Response with abort signal") + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in completePrompt", async () => { + mockCreate.mockResolvedValue({ + id: 
"test-completion", + choices: [{ message: { content: "Response without metadata" } }], + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Response without metadata") + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) }) describe("getModel", () => { diff --git a/src/api/providers/__tests__/poe.spec.ts b/src/api/providers/__tests__/poe.spec.ts index b22d42179c..4aaa4b4c9c 100644 --- a/src/api/providers/__tests__/poe.spec.ts +++ b/src/api/providers/__tests__/poe.spec.ts @@ -309,5 +309,20 @@ describe("PoeHandler", () => { }), ) }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + const handler = new PoeHandler({ poeApiKey: "key", apiModelId: "openai/gpt-4o" }) + mockGenerateText.mockResolvedValue({ text: "response with abort signal" }) + + const controller = new AbortController() + await handler.completePrompt("complete this", { abortSignal: controller.signal }) + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "complete this", + abortSignal: controller.signal, + }), + ) + }) }) }) diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts index 7556aa58f6..2a4c320647 100644 --- a/src/api/providers/__tests__/requesty.spec.ts +++ b/src/api/providers/__tests__/requesty.spec.ts @@ -210,6 +210,7 @@ describe("RequestyHandler", () => { stream_options: { include_usage: true }, temperature: 0, }), + undefined, ) }) @@ -243,6 +244,7 @@ describe("RequestyHandler", () => { thinking: { type: "adaptive" }, temperature: undefined, }), + undefined, ) }) @@ -314,6 +316,7 @@ describe("RequestyHandler", () => { ]), tool_choice: "auto", }), + undefined, ) }) @@ -412,12 +415,15 @@ describe("RequestyHandler", () => { expect(result).toBe("test completion") - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.requestyModelId, - 
max_tokens: 8192, - messages: [{ role: "system", content: "test prompt" }], - temperature: 0, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: mockOptions.requestyModelId, + max_tokens: 8192, + messages: [{ role: "system", content: "test prompt" }], + temperature: 0, + }, + undefined, + ) }) it("omits temperature for Claude Fable 5 in completePrompt", async () => { @@ -429,12 +435,15 @@ describe("RequestyHandler", () => { await handler.completePrompt("test prompt") - expect(mockCreate).toHaveBeenCalledWith({ - model: "anthropic/claude-fable-5", - max_tokens: 8192, - messages: [{ role: "system", content: "test prompt" }], - temperature: undefined, - }) + expect(mockCreate).toHaveBeenCalledWith( + { + model: "anthropic/claude-fable-5", + max_tokens: 8192, + messages: [{ role: "system", content: "test prompt" }], + temperature: undefined, + }, + undefined, + ) }) it("handles API errors", async () => { @@ -445,11 +454,40 @@ describe("RequestyHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow("API Error") }) - it("handles unexpected errors", async () => { + it("should handle unexpected errors", async () => { const handler = new RequestyHandler(mockOptions) mockCreate.mockRejectedValue(new Error("Unexpected error")) await expect(handler.completePrompt("test prompt")).rejects.toThrow("Unexpected error") }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + const handler = new RequestyHandler(mockOptions) + mockCreate.mockResolvedValue({ choices: [{ message: { content: "response" } }] }) + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in completePrompt", async () => { + const handler = new 
RequestyHandler(mockOptions) + mockCreate.mockResolvedValue({ choices: [{ message: { content: "response" } }] }) + + await handler.completePrompt("test prompt") + + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) }) }) diff --git a/src/api/providers/__tests__/unbound.spec.ts b/src/api/providers/__tests__/unbound.spec.ts index d741f2a371..4d30d51d71 100644 --- a/src/api/providers/__tests__/unbound.spec.ts +++ b/src/api/providers/__tests__/unbound.spec.ts @@ -90,6 +90,66 @@ describe("UnboundHandler", () => { mode: "architect", }, }), + undefined, + ) + }) + + it("should use metadata.abortSignal when provided in createMessage", async () => { + const mockCreate = (OpenAI as unknown as any)().chat.completions.create + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "ok" } }] } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 1, completion_tokens: 1 } } + }, + }) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + const controller = new AbortController() + const stream = handler.createMessage("system", [{ role: "user" as const, content: "hello" }], { + abortSignal: controller.signal, + }) + + for await (const _chunk of stream) { + // drain stream + } + + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in createMessage", async () => { + const mockCreate = (OpenAI as unknown as any)().chat.completions.create + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "ok" } }] } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 1, completion_tokens: 1 } } + }, + }) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + 
unboundModelId: "openai/gpt-4o", + }) + + const stream = handler.createMessage("system", [{ role: "user" as const, content: "hello" }]) + + for await (const _chunk of stream) { + // drain stream + } + + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // No abortSignal in metadata: options are omitted or carry no signal. + expect(requestConfig?.signal).toBe( + undefined, ) }) }) diff --git a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts index 2342122e17..745adfaa28 100644 --- a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts +++ b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts @@ -263,9 +263,33 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ temperature: customTemp, }), + undefined, ) }) + it("forwards abortSignal to createMessage", async () => { + const controller = new AbortController() + const handler = new VercelAiGatewayHandler(mockOptions) + + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + await handler.createMessage(systemPrompt, messages, { abortSignal: controller.signal }).next() + + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({}), { signal: controller.signal }) + }) + + it("forwards abortSignal to completePrompt", async () => { + const controller = new AbortController() + const handler = new VercelAiGatewayHandler(mockOptions) + + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "test completion" } }] }) + + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({}), { signal: controller.signal }) + }) + it("uses default temperature when none provided", async () => { const handler = new VercelAiGatewayHandler(mockOptions) @@ -278,6 +302,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ temperature: VERCEL_AI_GATEWAY_DEFAULT_TEMPERATURE, }), + 
undefined, ) }) @@ -295,6 +320,7 @@ describe("VercelAiGatewayHandler", () => { temperature: undefined, max_completion_tokens: 128000, }), + undefined, ) }) @@ -325,6 +351,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ max_completion_tokens: 64000, // max tokens for sonnet 4 }), + undefined, ) }) @@ -403,6 +430,7 @@ describe("VercelAiGatewayHandler", () => { }), ]), }), + undefined, ) }) @@ -420,6 +448,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ tool_choice: "auto", }), + undefined, ) }) @@ -437,6 +466,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ parallel_tool_calls: true, }), + undefined, ) }) @@ -454,6 +484,7 @@ describe("VercelAiGatewayHandler", () => { tools: expect.any(Array), parallel_tool_calls: true, }), + undefined, ) }) @@ -553,6 +584,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ stream_options: { include_usage: true }, }), + undefined, ) }) }) @@ -591,6 +623,7 @@ describe("VercelAiGatewayHandler", () => { temperature: VERCEL_AI_GATEWAY_DEFAULT_TEMPERATURE, max_completion_tokens: 64000, }), + undefined, ) }) @@ -607,6 +640,7 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ temperature: customTemp, }), + undefined, ) }) @@ -655,6 +689,42 @@ describe("VercelAiGatewayHandler", () => { expect.objectContaining({ temperature: 0.9, }), + undefined, + ) + }) + + it("should use metadata.abortSignal when provided in completePrompt", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + mockCreate.mockResolvedValueOnce({ + id: "test-completion", + choices: [{ message: { content: "Response with abort signal" } }], + }) + + const controller = new AbortController() + await handler.completePrompt("Test prompt", { abortSignal: controller.signal }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal 
when metadata.abortSignal not provided in completePrompt", async () => { + const handler = new VercelAiGatewayHandler(mockOptions) + mockCreate.mockResolvedValueOnce({ + id: "test-completion", + choices: [{ message: { content: "Response without metadata" } }], + }) + + await handler.completePrompt("Test prompt") + + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // No abortSignal in metadata: options are omitted or carry no signal. + expect(requestConfig?.signal).toBe( + undefined, ) }) }) diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts index 19ecc45ec2..d4ca7883d7 100644 --- a/src/api/providers/__tests__/zai.spec.ts +++ b/src/api/providers/__tests__/zai.spec.ts @@ -385,6 +385,33 @@ describe("ZAiHandler", () => { ) }) + it("should use metadata.abortSignal when provided in completePrompt", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "Response with abort signal" } }] }) + + const controller = new AbortController() + await handler.completePrompt("test prompt", { abortSignal: controller.signal }) + + expect(mockCreate).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass a signal when metadata.abortSignal not provided in completePrompt", async () => { + mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: "Response without metadata" } }] }) + + await handler.completePrompt("test prompt") + + expect(mockCreate).toHaveBeenCalled() + const createCalls = mockCreate.mock.calls + const [, requestConfig] = createCalls[createCalls.length - 1] + // Without metadata.abortSignal the request options are omitted or carry + // no signal; the handler must not invent one. + expect(requestConfig?.signal).toBeUndefined() + }) + it("createMessage should yield text content from stream", async () => { const testContent = "This is test content from Z AI stream" @@ -502,6 +529,7 @@ describe("ZAiHandler", () => { model: "glm-5.1", max_tokens: 40_000, }), + undefined, ) }) @@ -542,6 +570,7 @@ describe("ZAiHandler", () => { model: "glm-5.1", max_tokens: 100_000, }), + undefined, ) }) @@ -572,6 +601,7 @@ 
describe("ZAiHandler", () => { model: "glm-4.7", thinking: { type: "enabled" }, }), + undefined, ) }) @@ -603,6 +633,7 @@ describe("ZAiHandler", () => { model: "glm-4.7", thinking: { type: "disabled" }, }), + undefined, ) }) @@ -634,6 +665,7 @@ describe("ZAiHandler", () => { model: "glm-4.7", thinking: { type: "enabled" }, }), + undefined, ) }) @@ -687,6 +719,7 @@ describe("ZAiHandler", () => { model: "glm-5-turbo", thinking: { type: "enabled" }, }), + undefined, ) }) @@ -717,6 +750,7 @@ describe("ZAiHandler", () => { model: "glm-5-turbo", thinking: { type: "disabled" }, }), + undefined, ) }) }) diff --git a/src/api/providers/deepseek.ts b/src/api/providers/deepseek.ts index e2ffd29169..25eee22747 100644 --- a/src/api/providers/deepseek.ts +++ b/src/api/providers/deepseek.ts @@ -133,7 +133,11 @@ export class DeepSeekHandler extends OpenAiHandler { try { stream = await this.client.chat.completions.create( requestOptions as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + metadata?.abortSignal + ? isAzureAiInference + ? { path: OPENAI_AZURE_AI_INFERENCE_PATH, signal: metadata.abortSignal } + : { signal: metadata.abortSignal } + : isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : undefined, // keep the Azure path even without a signal ) } catch (error) { const { handleOpenAIError } = await import("./utils/openai-error-handler") diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts index a771394c53..22d284dd45 100644 --- a/src/api/providers/lm-studio.ts +++ b/src/api/providers/lm-studio.ts @@ -99,7 +99,10 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan let results try { - results = await this.client.chat.completions.create(params) + results = await this.client.chat.completions.create( + params, + metadata?.abortSignal ? 
{ signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } @@ -185,7 +188,7 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { try { // Create params object with optional draft model const params: any = { @@ -202,7 +205,10 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan let response try { - response = await this.client.chat.completions.create(params) + response = await this.client.chat.completions.create( + params, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/mimo.ts b/src/api/providers/mimo.ts index f842926b60..bf2725aea4 100644 --- a/src/api/providers/mimo.ts +++ b/src/api/providers/mimo.ts @@ -99,7 +99,10 @@ export class MimoHandler extends OpenAiHandler { let stream: AsyncIterable try { - stream = (await this.client.chat.completions.create(params as any)) as any + stream = (await this.client.chat.completions.create( + params as any, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + )) as any } catch (error) { throw handleProviderError(error, "MiMo") } diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index bfcf4e3be4..fd66bb182f 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -113,7 +113,10 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand tool_choice: convertOpenAIToolChoice(metadata?.tool_choice), } - stream = await this.client.messages.create(requestParams) + stream = await this.client.messages.create( + requestParams, + metadata?.abortSignal ? 
{ signal: metadata.abortSignal } : undefined, + ) let inputTokens = 0 let outputTokens = 0 @@ -289,16 +292,19 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand } } - async completePrompt(prompt: string) { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata) { const { id: model, temperature } = this.getModel() - const message = await this.client.messages.create({ - model, - max_tokens: 16_384, - temperature: temperature ?? 1.0, - messages: [{ role: "user", content: prompt }], - stream: false, - }) + const message = await this.client.messages.create( + { + model, + max_tokens: 16_384, + temperature: temperature ?? 1.0, + messages: [{ role: "user", content: prompt }], + stream: false, + }, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) const content = message.content.find(({ type }) => type === "text") return content?.type === "text" ? content.text : "" diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index e0e19298f4..2f69e0222a 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -103,7 +103,7 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand let response try { - response = await this.client.chat.stream(requestOptions) + response = await (metadata?.abortSignal ? this.client.chat.stream(requestOptions, { fetchOptions: { signal: metadata.abortSignal } }) : this.client.chat.stream(requestOptions)) } catch (error) { const errorMessage = error instanceof Error ? 
error.message : String(error) const apiError = new ApiProviderError(errorMessage, this.providerName, model, "createMessage") @@ -193,10 +193,11 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand return { id, info, maxTokens, temperature } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, _metadata?: ApiHandlerCreateMessageMetadata): Promise { const { id: model, temperature } = this.getModel() try { + // Note: Mistral non-streaming API doesn't support cancellation for completePrompt const response = await this.client.chat.complete({ model, messages: [{ role: "user", content: prompt }], diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 532ed38ba2..01d73d4726 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -179,7 +179,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl try { stream = await this.client.chat.completions.create( requestOptions, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + metadata?.abortSignal + ? { + ...(isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + signal: metadata.abortSignal, + } + : isAzureAiInference + ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } + : {}, ) } catch (error) { throw handleOpenAIError(error, this.providerName) @@ -247,7 +254,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl try { response = await this.client.chat.completions.create( requestOptions, - this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + metadata?.abortSignal + ? { + ...(this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + signal: metadata.abortSignal, + } + : this._isAzureAiInference(modelUrl) + ? 
{ path: OPENAI_AZURE_AI_INFERENCE_PATH } + : {}, ) } catch (error) { throw handleOpenAIError(error, this.providerName) @@ -300,7 +314,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl return { id, info, ...params } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { try { const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl) const model = this.getModel() @@ -318,7 +332,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl try { response = await this.client.chat.completions.create( requestOptions, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + metadata?.abortSignal + ? { + ...(isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + signal: metadata.abortSignal, + } + : isAzureAiInference + ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } + : {}, ) } catch (error) { throw handleOpenAIError(error, this.providerName) @@ -374,7 +395,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl try { stream = await this.client.chat.completions.create( requestOptions, - methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + metadata?.abortSignal + ? { + ...(methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + signal: metadata.abortSignal, + } + : methodIsAzureAiInference + ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } + : {}, ) } catch (error) { throw handleOpenAIError(error, this.providerName) @@ -408,7 +436,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl try { response = await this.client.chat.completions.create( requestOptions, - methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + metadata?.abortSignal + ? { + ...(methodIsAzureAiInference ? 
{ path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + signal: metadata.abortSignal, + } + : methodIsAzureAiInference + ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } + : {}, ) } catch (error) { throw handleOpenAIError(error, this.providerName) diff --git a/src/api/providers/poe.ts b/src/api/providers/poe.ts index 536d222acd..82872f22f3 100644 --- a/src/api/providers/poe.ts +++ b/src/api/providers/poe.ts @@ -101,6 +101,7 @@ export class PoeHandler extends BaseProvider implements SingleCompletionHandler tools: aiSdkTools, toolChoice: mapToolChoice(metadata?.tool_choice as any), ...(Object.keys(providerOptions).length > 0 && { providerOptions }), + ...(metadata?.abortSignal && { abortSignal: metadata.abortSignal }), }) } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) @@ -134,12 +135,13 @@ export class PoeHandler extends BaseProvider implements SingleCompletionHandler } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { const { id } = this.getModel() try { const { text } = await generateText({ model: this.poe(id), prompt, + ...(metadata?.abortSignal && { abortSignal: metadata.abortSignal }), }) return text } catch (error) { diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index 3e50adf9cc..2684c93209 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -161,7 +161,10 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan let stream try { // With streaming params type, SDK returns an async iterable stream - stream = await this.client.chat.completions.create(completionParams) + stream = await this.client.chat.completions.create( + completionParams, + metadata?.abortSignal ? 
{ signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } @@ -201,7 +204,7 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel() let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }] @@ -215,7 +218,10 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan let response: OpenAI.Chat.ChatCompletion try { - response = await this.client.chat.completions.create(completionParams) + response = await this.client.chat.completions.create( + completionParams, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index a1de7dfa14..7877fad5de 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -149,7 +149,10 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand let stream try { - stream = await this.client.chat.completions.create(completionParams) + stream = await this.client.chat.completions.create( + completionParams, + metadata?.abortSignal ? 
{ signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } @@ -189,7 +192,7 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel() let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }] @@ -203,7 +206,10 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand let response: OpenAI.Chat.ChatCompletion try { - response = await this.client.chat.completions.create(completionParams) + response = await this.client.chat.completions.create( + completionParams, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/vercel-ai-gateway.ts b/src/api/providers/vercel-ai-gateway.ts index 0c7bd1d485..d2670aace4 100644 --- a/src/api/providers/vercel-ai-gateway.ts +++ b/src/api/providers/vercel-ai-gateway.ts @@ -68,7 +68,10 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp parallel_tool_calls: metadata?.parallelToolCalls ?? true, } - const completion = await this.client.chat.completions.create(body) + const completion = await this.client.chat.completions.create( + body, + metadata?.abortSignal ? 
{ signal: metadata.abortSignal } : undefined, + ) for await (const chunk of completion) { // Vercel AI Gateway reports mid-stream failures as an in-band error chunk @@ -117,7 +120,7 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp } } - async completePrompt(prompt: string): Promise { + async completePrompt(prompt: string, metadata?: ApiHandlerCreateMessageMetadata): Promise { const { id: modelId, info } = await this.fetchModel() try { @@ -133,7 +136,10 @@ export class VercelAiGatewayHandler extends RouterProvider implements SingleComp requestOptions.max_completion_tokens = info.maxTokens - const response = await this.client.chat.completions.create(requestOptions) + const response = await this.client.chat.completions.create( + requestOptions, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) return response.choices[0]?.message.content || "" } catch (error) { if (error instanceof Error) { diff --git a/src/api/providers/zai.ts b/src/api/providers/zai.ts index 113cf655d3..ff55cb156a 100644 --- a/src/api/providers/zai.ts +++ b/src/api/providers/zai.ts @@ -108,6 +108,9 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider { parallel_tool_calls: metadata?.parallelToolCalls ?? true, } - return this.client.chat.completions.create(params) + return this.client.chat.completions.create( + params, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) } }