diff --git a/eslint.config.js b/eslint.config.js index 45f5a7e..b887cf9 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -17,6 +17,8 @@ export default [ Request: "readonly", Response: "readonly", URL: "readonly", + TextEncoder: "readonly", + ReadableStream: "readonly", }, }, rules: { diff --git a/index.js b/index.js index c00a412..3c905fb 100644 --- a/index.js +++ b/index.js @@ -178,7 +178,7 @@ export function buildSystemPrompt(messages, request) { .map((message) => message.content) const hints = [ - "You are answering through an OpenAI-compatible proxy backed by OpenCode.", + "You are answering through a proxy backed by OpenCode.", "Return only the assistant's reply content.", ] @@ -268,6 +268,69 @@ async function executePrompt(client, request, model, messages, system) { } } +async function executePromptStreaming(client, model, messages, system, onChunk) { + const tools = await getDisabledTools(client) + const session = await client.session.create({ + body: { title: `Proxy: ${model.id}` }, + }) + const sessionID = session.data.id + const prompt = buildPrompt(messages) + + // Subscribe to the event stream before sending the prompt so we don't miss events. + const { stream } = await client.event.subscribe() + + await client.session.promptAsync({ + path: { id: sessionID }, + body: { + model: { providerID: model.providerID, modelID: model.modelID }, + system, + tools, + parts: [{ type: "text", text: prompt }], + }, + }) + + let errorMessage = null + + for await (const event of stream) { + if (event.type === "message.part.updated") { + const part = event.properties?.part + const delta = event.properties?.delta + if ( + part?.sessionID === sessionID && + part?.type === "text" && + typeof delta === "string" && + delta.length > 0 + ) { + onChunk(delta) + } + } else if (event.type === "session.error") { + if (!event.properties?.sessionID || event.properties.sessionID === sessionID) { + errorMessage = event.properties?.error?.message ?? "Model call failed." + } + } else if (event.type === "session.idle") { + if (event.properties?.sessionID === sessionID) { + break + } + } + } + + if (errorMessage) { + throw new Error(errorMessage) + } + + // Fetch final message to get token usage. + const messages_ = await client.session.messages({ path: { id: sessionID } }) + const assistantMsg = (messages_.data ?? []) + .filter((m) => m.role === "assistant") + .at(-1) + + return { + sessionID, + tokens: assistantMsg?.tokens ?? { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + finish: assistantMsg?.finish, + } +} + function createChatCompletionResponse(result, model) { const now = Math.floor(Date.now() / 1000) return { @@ -424,6 +487,33 @@ export async function resolveModel(client, requestedModel, providerOverride) { throw new Error(`Unknown model '${requestedModel}'. Call GET /v1/models to inspect available IDs.`) } +function sseResponse(corsHeadersObj, generator) { + const encoder = new TextEncoder() + const body = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of generator) { + controller.enqueue(encoder.encode(chunk)) + } + } catch { + // Stream errors are surfaced via SSE data before this point. + } finally { + controller.close() + } + }, + }) + + return new Response(body, { + status: 200, + headers: { + "content-type": "text/event-stream; charset=utf-8", + "cache-control": "no-cache", + connection: "keep-alive", + ...corsHeadersObj, + }, + }) +} + function createModelResponse(models) { return { object: "list", @@ -473,10 +563,6 @@ export function createProxyFetchHandler(client) { return badRequest("Request body must be valid JSON.", 400, request) } - if (body.stream) { - return badRequest("Streaming is not implemented yet.", 400, request) - } - if (!body.model) { return badRequest("The 'model' field is required.", 400, request) } @@ -490,10 +576,111 @@ export function createProxyFetchHandler(client) { return badRequest("No text content was found in the supplied messages.", 400, request) } + let model try { const providerOverride = request.headers.get("x-opencode-provider") - const model = await resolveModel(client, body.model, providerOverride) - const system = buildSystemPrompt(messages, body) + model = await resolveModel(client, body.model, providerOverride) + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + await safeLog(client, "error", "Proxy completion failed", { + error: message, + requestedModel: body.model, + }) + return badRequest(message, 502, request) + } + + const system = buildSystemPrompt(messages, body) + + if (body.stream) { + const completionID = `chatcmpl_${crypto.randomUUID().replace(/-/g, "")}` + const now = Math.floor(Date.now() / 1000) + + const chunks = [] + let resolve = null + let done = false + + function enqueue(value) { + chunks.push(value) + if (resolve) { + const r = resolve + resolve = null + r() + } + } + + async function* generateSse() { + const runPromise = executePromptStreaming( + client, + model, + messages, + system, + (delta) => { + const chunk = JSON.stringify({ + id: completionID, + object: "chat.completion.chunk", + created: now, + model: model.id, + choices: [{ index: 0, delta: { role: "assistant", content: delta }, finish_reason: null }], + }) + enqueue(`data: ${chunk}\n\n`) + }, + ) + .then((streamResult) => { + const finalChunk = JSON.stringify({ + id: completionID, + object: "chat.completion.chunk", + created: now, + model: model.id, + choices: [{ index: 0, delta: {}, finish_reason: mapFinishReason(streamResult.finish) }], + usage: { + prompt_tokens: streamResult.tokens.input, + completion_tokens: streamResult.tokens.output, + total_tokens: streamResult.tokens.input + streamResult.tokens.output, + }, + }) + enqueue(`data: ${finalChunk}\n\ndata: [DONE]\n\n`) + }) + .catch(async (err) => { + const streamError = err instanceof Error ? err.message : String(err) + await safeLog(client, "error", "Proxy streaming completion failed", { + error: streamError, + requestedModel: body.model, + }) + const errChunk = JSON.stringify({ + error: { message: streamError, type: "server_error" }, + }) + enqueue(`data: ${errChunk}\n\ndata: [DONE]\n\n`) + }) + .finally(() => { + done = true + if (resolve) { + const r = resolve + resolve = null + r() + } + }) + + while (true) { + while (chunks.length > 0) { + yield chunks.shift() + } + if (done) break + await new Promise((r) => { + resolve = r + }) + } + // Drain any remaining chunks + while (chunks.length > 0) { + yield chunks.shift() + } + + await runPromise + } + + return sseResponse(corsHeaders(request), generateSse()) + } + + try { const result = await executePrompt(client, body, model, messages, system) return json(createChatCompletionResponse(result, model), 200, {}, request) } catch (error) { @@ -514,10 +701,6 @@ export function createProxyFetchHandler(client) { return badRequest("Request body must be valid JSON.", 400, request) } - if (body.stream) { - return badRequest("Streaming is not implemented yet.", 400, request) - } - if (!body.model) { return badRequest("The 'model' field is required.", 400, request) } @@ -527,19 +710,187 @@ export function createProxyFetchHandler(client) { return badRequest("The 'input' field must contain at least one text message.", 400, request) } + const instructionMessages = + typeof body.instructions === "string" && body.instructions.trim() + ? [{ role: "system", content: body.instructions.trim() }, ...messages] + : messages + + const system = buildSystemPrompt(instructionMessages, { + temperature: body.temperature, + max_tokens: body.max_output_tokens, + max_completion_tokens: body.max_output_tokens, + }) + + let model try { const providerOverride = request.headers.get("x-opencode-provider") - const model = await resolveModel(client, body.model, providerOverride) - const system = buildSystemPrompt( - typeof body.instructions === "string" && body.instructions.trim() - ? [{ role: "system", content: body.instructions.trim() }, ...messages] - : messages, - { - temperature: body.temperature, - max_tokens: body.max_output_tokens, - max_completion_tokens: body.max_output_tokens, - }, - ) + model = await resolveModel(client, body.model, providerOverride) + } catch (error) { + const message = error instanceof Error ? error.message : String(error) + await safeLog(client, "error", "Proxy responses call failed", { + error: message, + requestedModel: body.model, + }) + return badRequest(message, 502, request) + } + + if (body.stream) { + const responseID = `resp_${crypto.randomUUID().replace(/-/g, "")}` + const itemID = `msg_${crypto.randomUUID().replace(/-/g, "")}` + const now = Math.floor(Date.now() / 1000) + + const chunks = [] + let resolve = null + let done = false + + function enqueue(value) { + chunks.push(value) + if (resolve) { + const r = resolve + resolve = null + r() + } + } + + function sseEvent(eventType, data) { + return `event: ${eventType}\ndata: ${JSON.stringify(data)}\n\n` + } + + async function* generateSse() { + enqueue( + sseEvent("response.created", { + type: "response.created", + response: { + id: responseID, + object: "response", + created_at: now, + status: "in_progress", + model: model.id, + output: [], + }, + }), + ) + enqueue( + sseEvent("response.output_item.added", { + type: "response.output_item.added", + output_index: 0, + item: { id: itemID, type: "message", status: "in_progress", role: "assistant", content: [] }, + }), + ) + + let partIndex = 0 + const runPromise = executePromptStreaming( + client, + model, + messages, + system, + (delta) => { + if (partIndex === 0) { + enqueue( + sseEvent("response.content_part.added", { + type: "response.content_part.added", + item_id: itemID, + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "", annotations: [] }, + }), + ) + partIndex++ + } + enqueue( + sseEvent("response.output_text.delta", { + type: "response.output_text.delta", + item_id: itemID, + output_index: 0, + content_index: 0, + delta, + }), + ) + }, + ) + .then((streamResult) => { + enqueue( + sseEvent("response.output_text.done", { + type: "response.output_text.done", + item_id: itemID, + output_index: 0, + content_index: 0, + text: "", + }), + ) + enqueue( + sseEvent("response.output_item.done", { + type: "response.output_item.done", + output_index: 0, + item: { id: itemID, type: "message", status: "completed", role: "assistant" }, + }), + ) + enqueue( + sseEvent("response.completed", { + type: "response.completed", + response: { + id: responseID, + object: "response", + created_at: now, + status: "completed", + model: model.id, + usage: { + input_tokens: streamResult.tokens.input, + output_tokens: streamResult.tokens.output, + total_tokens: streamResult.tokens.input + streamResult.tokens.output, + }, + }, + }), + ) + }) + .catch(async (err) => { + const errMsg = err instanceof Error ? err.message : String(err) + await safeLog(client, "error", "Proxy streaming responses call failed", { + error: errMsg, + requestedModel: body.model, + }) + enqueue( + sseEvent("response.failed", { + type: "response.failed", + response: { + id: responseID, + object: "response", + created_at: now, + status: "failed", + error: { message: errMsg, code: "server_error" }, + }, + }), + ) + }) + .finally(() => { + done = true + if (resolve) { + const r = resolve + resolve = null + r() + } + }) + + while (true) { + while (chunks.length > 0) { + yield chunks.shift() + } + if (done) break + await new Promise((r) => { + resolve = r + }) + } + while (chunks.length > 0) { + yield chunks.shift() + } + + await runPromise + } + + return sseResponse(corsHeaders(request), generateSse()) + } + + try { const result = await executePrompt(client, body, model, messages, system) return json(createResponsesApiResponse(result, model), 200, {}, request) } catch (error) { diff --git a/index.test.js b/index.test.js index e4241f7..264f53c 100644 --- a/index.test.js +++ b/index.test.js @@ -32,6 +32,47 @@ function createClient() { } } +function createStreamingClient(chunks) { + async function* makeStream() { + for (const chunk of chunks) { + yield chunk + } + } + + return { + app: { log: async () => {} }, + tool: { ids: async () => ({ data: [] }) }, + config: { + providers: async () => ({ + data: { + providers: [ + { + id: "openai", + models: { "gpt-4o": { id: "gpt-4o", name: "GPT-4o" } }, + }, + ], + }, + }), + }, + session: { + create: async () => ({ data: { id: "sess-123" } }), + promptAsync: async () => {}, + messages: async () => ({ + data: [ + { + role: "assistant", + tokens: { input: 10, output: 5, reasoning: 0, cache: { read: 0, write: 0 } }, + finish: "end_turn", + }, + ], + }), + }, + event: { + subscribe: async () => ({ stream: makeStream() }), + }, + } +} + test("OPTIONS preflight returns CORS headers", async () => { const handler = createProxyFetchHandler(createClient()) const request = new Request("http://127.0.0.1:4010/v1/models", { @@ -260,8 +301,26 @@ test("missing messages field returns 400", async () => { assert.ok(body.error.message.includes("messages")) }) -test("stream: true returns 400 (not implemented)", async () => { - const handler = createProxyFetchHandler(createClient()) +test("stream: true returns SSE response", async () => { + const events = [ + { + type: "message.part.updated", + properties: { + part: { sessionID: "sess-123", type: "text" }, + delta: "Hello", + }, + }, + { + type: "message.part.updated", + properties: { + part: { sessionID: "sess-123", type: "text" }, + delta: " world", + }, + }, + { type: "session.idle", properties: { sessionID: "sess-123" } }, + ] + + const handler = createProxyFetchHandler(createStreamingClient(events)) const request = new Request("http://127.0.0.1:4010/v1/chat/completions", { method: "POST", headers: { "content-type": "application/json" }, @@ -272,11 +331,67 @@ test("stream: true returns 400 (not implemented)", async () => { }), }) + const response = await handler(request) + + assert.equal(response.status, 200) + assert.ok(response.headers.get("content-type")?.includes("text/event-stream")) + + const text = await response.text() + assert.ok(text.includes("chat.completion.chunk")) + assert.ok(text.includes("Hello")) + assert.ok(text.includes(" world")) + assert.ok(text.includes("[DONE]")) +}) + +test("stream: true with unknown model returns 502", async () => { + const handler = createProxyFetchHandler(createClient()) // no providers + const request = new Request("http://127.0.0.1:4010/v1/chat/completions", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + model: "nonexistent-model", + stream: true, + messages: [{ role: "user", content: "hi" }], + }), + }) + const response = await handler(request) const body = await response.json() - assert.equal(response.status, 400) - assert.ok(body.error.message.toLowerCase().includes("stream")) + assert.equal(response.status, 502) + assert.ok(body.error.message.includes("nonexistent-model")) +}) + +test("stream: true propagates session.error into the SSE stream", async () => { + const events = [ + { + type: "session.error", + properties: { + sessionID: "sess-123", + error: { message: "Model overloaded" }, + }, + }, + { type: "session.idle", properties: { sessionID: "sess-123" } }, + ] + + const handler = createProxyFetchHandler(createStreamingClient(events)) + const request = new Request("http://127.0.0.1:4010/v1/chat/completions", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + model: "gpt-4o", + stream: true, + messages: [{ role: "user", content: "hi" }], + }), + }) + + const response = await handler(request) + assert.equal(response.status, 200) + assert.ok(response.headers.get("content-type")?.includes("text/event-stream")) + + const text = await response.text() + assert.ok(text.includes("server_error") || text.includes("Model overloaded")) + assert.ok(text.includes("[DONE]")) }) test("unknown model returns 502", async () => { @@ -446,7 +561,7 @@ describe("buildSystemPrompt", () => { it("always includes the proxy hint lines", () => { const result = buildSystemPrompt([], {}) - assert.ok(result.includes("OpenAI-compatible proxy")) + assert.ok(result.includes("proxy backed by OpenCode")) assert.ok(result.includes("Return only the assistant")) })