diff --git a/README.md b/README.md index 9433263..6fc1b29 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,53 @@ const result = await rlm.completion( The LLM will know it can access `context.users`, `context.settings`, etc. with full type awareness. +### Structured Output with Zod (`generateObject`) + +If you want schema-validated JSON output directly (without REPL/code execution), use `generateObject`. +RLLM retries (up to `maxRetries` times) when the output is not valid JSON or fails Zod validation, and throws once the retries are exhausted. + +```typescript +import { z } from 'zod'; +import { createRLLM } from 'rllm'; + +const rlm = createRLLM({ model: 'gpt-4o-mini' }); + +const OutputSchema = z.object({ + summary: z.string(), + keyPoints: z.array(z.string()), + confidence: z.number().min(0).max(1), +}); +const InputSchema = z.object({ + reportText: z.string(), + locale: z.string(), +}); + +const result = await rlm.generateObject( + "Summarize this report and provide key points with confidence", + { + input: { + reportText: hugeDocument, + locale: "en-US", + }, + inputSchema: InputSchema, + outputSchema: OutputSchema, + }, + { + maxRetries: 2, // total attempts = 3 + onRetry: (event) => { + console.log(`Attempt ${event.attempt}/${event.maxRetries + 1} failed: ${event.errorType}`); + }, + } +); + +console.log(result.object.summary); +console.log(result.attempts, result.usage.tokenUsage.totalTokens); +``` + +`generateObject` differs from `completion()`: +- `generateObject` asks for one JSON object and validates it against your schema. +- `completion()` runs the full recursive REPL workflow where the model writes and executes JS code. 
+ The LLM will write code like: ```javascript // LLM-generated code runs in V8 isolate @@ -153,6 +200,7 @@ Defaults: | Method | Description | |--------|-------------| | `rlm.completion(prompt, options)` | Full RLM completion with code execution | +| `rlm.generateObject(prompt, { input?, inputSchema?, outputSchema }, options?)` | Structured output with Zod validation + retries | | `rlm.chat(messages)` | Direct LLM chat | | `rlm.getClient()` | Get underlying LLM client | @@ -164,6 +212,23 @@ Defaults: | `context` | `string \| T` | The context data available to LLM-generated code | | `contextSchema` | `ZodType` | Optional Zod schema describing context structure | +### `GenerateObjectOptions` + +| Option | Type | Description | +|--------|------|-------------| +| `maxRetries` | `number` | Retries after the first attempt; total attempts = `maxRetries + 1` (default `2`) | +| `temperature` | `number` | Optional generation temperature | +| `maxTokens` | `number` | Optional max completion tokens | +| `onRetry` | `(event: GenerateObjectRetryEvent) => void` | Called when JSON parsing or schema validation fails and a retry is scheduled | + +### `GenerateObjectSchemas` + +| Field | Type | Description | +|-------|------|-------------| +| `input` | `TInput` | Optional structured input value | +| `inputSchema` | `ZodType` | Optional input schema; validates `input` before any LLM call and is shown to the model as a TypeScript type | +| `outputSchema` | `ZodType` | Required output schema; every response is validated against it, and a failure triggers a retry | + ### Sandbox Bindings The V8 isolate provides these bindings to LLM-generated code: diff --git a/src/index.ts b/src/index.ts index 7efa633..d57646c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -57,4 +57,10 @@ export type { RLMEventType, RLMEvent, RLMEventCallback, + GenerateObjectErrorType, + GenerateObjectRetryEvent, + GenerateObjectOptions, + GenerateObjectSchemas, + GenerateObjectUsage, + GenerateObjectResult, } from "./types.js"; diff --git a/src/prompts.ts b/src/prompts.ts index 961b1d8..e1bb52e 100644 --- a/src/prompts.ts +++ b/src/prompts.ts @@ -6,6 +6,7 @@ */ import type { 
ZodType } from "zod"; +import type { GenerateObjectRetryEvent } from "./types.js"; /** * Main RLM system prompt - instructs the LLM on how to use the REPL environment @@ -347,3 +348,65 @@ export function buildUserPrompt( return { role: "user", content }; } + +/** + * Build messages for schema-constrained JSON generation. + */ +export function buildGenerateObjectMessages( + prompt: string, + outputSchema: ZodType, + input?: unknown, + inputSchema?: ZodType +): Array<{ role: "system" | "user"; content: string }> { + const outputSchemaDescription = zodSchemaToTypeDescription(outputSchema); + const inputSchemaDescription = inputSchema ? zodSchemaToTypeDescription(inputSchema) : null; + const inputSection = input !== undefined + ? ( + "Input value (JSON):\n" + + `\`\`\`json\n${JSON.stringify(input, null, 2)}\n\`\`\`\n\n` + ) + : ""; + const inputTypeSection = inputSchemaDescription + ? ( + "Input TypeScript type:\n" + + `\`\`\`typescript\ntype Input = ${inputSchemaDescription}\n\`\`\`\n\n` + ) + : ""; + + return [ + { + role: "system", + content: + "You generate structured data. Return exactly one valid JSON object that matches the provided schema. " + + "Do not include markdown, code fences, comments, or any extra text before/after the JSON.", + }, + { + role: "user", + content: + `Task:\n${prompt}\n\n` + + inputTypeSection + + inputSection + + "Target TypeScript type:\n" + + `\`\`\`typescript\ntype Output = ${outputSchemaDescription}\n\`\`\`\n\n` + + "Return only the JSON object.", + }, + ]; +} + +/** + * Build retry feedback after a failed parse/validation attempt. + */ +export function buildGenerateObjectRetryPrompt( + event: GenerateObjectRetryEvent +): { role: "user"; content: string } { + const issues = event.validationIssues?.length + ? 
`\nValidation issues:\n- ${event.validationIssues.join("\n- ")}` + : ""; + + return { + role: "user", + content: + `Previous attempt ${event.attempt} failed (${event.errorType}): ${event.errorMessage}.${issues}\n\n` + + "Please try again and return only one corrected JSON object with no surrounding text.", + }; +} diff --git a/src/rlm.test.ts b/src/rlm.test.ts new file mode 100644 index 0000000..4c5f102 --- /dev/null +++ b/src/rlm.test.ts @@ -0,0 +1,190 @@ +import { describe, it, expect, vi } from "vitest"; +import { z } from "zod"; +import { RLLM } from "./rlm.js"; +import type { ChatMessage, TokenUsage } from "./types.js"; + +interface MockCompletionResponse { + content: string; + usage: TokenUsage; +} + +function createTestRLLMWithMock(responses: MockCompletionResponse[]): { + rllm: RLLM; + completeMock: ReturnType; +} { + const rllm = new RLLM({ + client: { + provider: "openai", + model: "gpt-4o-mini", + apiKey: "test-key", + }, + }); + + const completeMock = vi.fn().mockImplementation(async () => { + const next = responses.shift(); + if (!next) { + throw new Error("No mock response configured"); + } + return { + message: { role: "assistant", content: next.content }, + usage: next.usage, + finishReason: "stop", + }; + }); + + ( + rllm as unknown as { + client: { + complete: (options: { messages: ChatMessage[] }) => Promise; + }; + } + ).client = { complete: completeMock }; + + return { rllm, completeMock }; +} + +describe("RLLM.generateObject", () => { + it("returns typed object on first valid attempt", async () => { + const outputSchema = z.object({ + name: z.string(), + count: z.number(), + }); + const inputSchema = z.object({ + report: z.string(), + }); + + const { rllm } = createTestRLLMWithMock([ + { + content: '{"name":"ok","count":3}', + usage: { promptTokens: 11, completionTokens: 7, totalTokens: 18 }, + }, + ]); + + const result = await rllm.generateObject( + "Generate object", + { + input: { report: "hello" }, + inputSchema, + outputSchema, + } + ); + 
+ expect(result.object).toEqual({ name: "ok", count: 3 }); + expect(result.attempts).toBe(1); + expect(result.rawResponse).toBe('{"name":"ok","count":3}'); + expect(result.usage.totalCalls).toBe(1); + expect(result.usage.tokenUsage).toEqual({ + promptTokens: 11, + completionTokens: 7, + totalTokens: 18, + }); + }); + + it("retries after invalid JSON and succeeds", async () => { + const outputSchema = z.object({ + city: z.string(), + }); + const onRetry = vi.fn(); + + const { rllm, completeMock } = createTestRLLMWithMock([ + { + content: '{"city":"Tel Aviv"', + usage: { promptTokens: 5, completionTokens: 4, totalTokens: 9 }, + }, + { + content: '{"city":"Tel Aviv"}', + usage: { promptTokens: 6, completionTokens: 4, totalTokens: 10 }, + }, + ]); + + const result = await rllm.generateObject("Return city", { outputSchema }, { + maxRetries: 2, + onRetry, + }); + + expect(result.object).toEqual({ city: "Tel Aviv" }); + expect(result.attempts).toBe(2); + expect(result.usage.totalCalls).toBe(2); + expect(result.usage.tokenUsage).toEqual({ + promptTokens: 11, + completionTokens: 8, + totalTokens: 19, + }); + + expect(onRetry).toHaveBeenCalledTimes(1); + expect(onRetry.mock.calls[0]?.[0].errorType).toBe("json_parse"); + expect(completeMock).toHaveBeenCalledTimes(2); + }); + + it("retries after schema mismatch and succeeds", async () => { + const outputSchema = z.object({ + status: z.enum(["ok", "error"]), + count: z.number(), + }); + const onRetry = vi.fn(); + + const { rllm } = createTestRLLMWithMock([ + { + content: '{"status":"ok","count":"3"}', + usage: { promptTokens: 8, completionTokens: 5, totalTokens: 13 }, + }, + { + content: '{"status":"ok","count":3}', + usage: { promptTokens: 9, completionTokens: 5, totalTokens: 14 }, + }, + ]); + + const result = await rllm.generateObject("Return status and count", { outputSchema }, { + maxRetries: 2, + onRetry, + }); + + expect(result.object).toEqual({ status: "ok", count: 3 }); + expect(result.attempts).toBe(2); + 
expect(onRetry).toHaveBeenCalledTimes(1); + expect(onRetry.mock.calls[0]?.[0].errorType).toBe("schema_validation"); + expect(onRetry.mock.calls[0]?.[0].validationIssues?.length).toBeGreaterThan(0); + }); + + it("throws actionable error after exhausting retries", async () => { + const outputSchema = z.object({ + id: z.string(), + }); + + const { rllm } = createTestRLLMWithMock([ + { + content: '{"id":123}', + usage: { promptTokens: 3, completionTokens: 2, totalTokens: 5 }, + }, + { + content: '{"id":456}', + usage: { promptTokens: 3, completionTokens: 2, totalTokens: 5 }, + }, + ]); + + await expect( + rllm.generateObject("Return id", { outputSchema }, { maxRetries: 1 }) + ).rejects.toThrow(/generateObject failed after 2 attempt/); + }); + + it("fails fast when input does not satisfy inputSchema", async () => { + const outputSchema = z.object({ answer: z.string() }); + const inputSchema = z.object({ age: z.number() }); + const { rllm, completeMock } = createTestRLLMWithMock([ + { + content: '{"answer":"ok"}', + usage: { promptTokens: 1, completionTokens: 1, totalTokens: 2 }, + }, + ]); + + await expect( + rllm.generateObject("Use input", { + input: { age: "not-a-number" } as unknown as { age: number }, + inputSchema, + outputSchema, + }) + ).rejects.toThrow(/input failed inputSchema validation/); + + expect(completeMock).toHaveBeenCalledTimes(0); + }); +}); diff --git a/src/rlm.ts b/src/rlm.ts index a7b518b..022558c 100644 --- a/src/rlm.ts +++ b/src/rlm.ts @@ -9,8 +9,26 @@ import { LLMClient, type LLMClientOptions } from "./llm-client.js"; import { Sandbox, type SandboxResult } from "./sandbox.js"; import { findCodeBlocks, formatIteration, formatExecutionResult } from "./parsing.js"; import type { FinalAnswer } from "./sandbox.js"; -import { RLM_SYSTEM_PROMPT, buildSystemPrompt, buildUserPrompt, zodSchemaToTypeDescription } from "./prompts.js"; -import type { ChatMessage, TokenUsage, RLMResult, RLMTraceEntry, RLMEventCallback, RLMEvent } from "./types.js"; +import { 
+ RLM_SYSTEM_PROMPT, + buildSystemPrompt, + buildUserPrompt, + zodSchemaToTypeDescription, + buildGenerateObjectMessages, + buildGenerateObjectRetryPrompt, +} from "./prompts.js"; +import type { + ChatMessage, + TokenUsage, + RLMResult, + RLMTraceEntry, + RLMEventCallback, + RLMEvent, + GenerateObjectOptions, + GenerateObjectResult, + GenerateObjectRetryEvent, + GenerateObjectSchemas, +} from "./types.js"; // ============================================================================ // Configuration @@ -302,6 +320,162 @@ export class RLLM { }; } + /** + * Generate structured output that must satisfy the provided Zod schema. + * Retries with validation feedback until a valid object is produced. + */ + async generateObject( + prompt: string, + schemas: GenerateObjectSchemas, + options?: GenerateObjectOptions + ): Promise>; + async generateObject( + prompt: string, + schemas: GenerateObjectSchemas, + options: GenerateObjectOptions = {} + ): Promise> { + const startTime = Date.now(); + const maxRetries = options.maxRetries ?? 2; + + if (maxRetries < 0 || !Number.isInteger(maxRetries)) { + throw new Error("generateObject: maxRetries must be a non-negative integer"); + } + + if (schemas.inputSchema && schemas.input !== undefined) { + const inputValidation = schemas.inputSchema.safeParse(schemas.input); + if (!inputValidation.success) { + const issues = inputValidation.error.issues + .map((issue) => { + const path = issue.path.length > 0 ? 
issue.path.join(".") : ""; + return `${path}: ${issue.message}`; + }) + .join("; "); + throw new Error(`generateObject: input failed inputSchema validation - ${issues}`); + } + } + + if (!schemas.outputSchema) { + throw new Error("generateObject: outputSchema is required"); + } + + const messages = buildGenerateObjectMessages( + prompt, + schemas.outputSchema, + schemas.input, + schemas.inputSchema + ) as ChatMessage[]; + const totalUsage: TokenUsage = { promptTokens: 0, completionTokens: 0, totalTokens: 0 }; + + let attempts = 0; + let lastResponse = ""; + let lastError: GenerateObjectRetryEvent | null = null; + + for (let attempt = 1; attempt <= maxRetries + 1; attempt++) { + attempts = attempt; + + const llmResult = await this.client.complete({ + messages, + temperature: options.temperature, + maxTokens: options.maxTokens, + }); + + totalUsage.promptTokens += llmResult.usage.promptTokens; + totalUsage.completionTokens += llmResult.usage.completionTokens; + totalUsage.totalTokens += llmResult.usage.totalTokens; + + lastResponse = llmResult.message.content; + + const parseResult = this.parseJSONObject(lastResponse); + if (!parseResult.success) { + lastError = { + attempt, + maxRetries, + rawResponse: lastResponse, + errorType: "json_parse", + errorMessage: parseResult.error, + }; + + if (attempt <= maxRetries) { + options.onRetry?.(lastError); + messages.push({ role: "assistant", content: lastResponse }); + messages.push(buildGenerateObjectRetryPrompt(lastError)); + continue; + } + + throw new Error(this.buildGenerateObjectErrorMessage(lastError)); + } + + const validated = schemas.outputSchema.safeParse(parseResult.value); + if (validated.success) { + return { + object: validated.data, + attempts, + rawResponse: lastResponse, + usage: { + totalCalls: attempts, + tokenUsage: totalUsage, + executionTimeMs: Date.now() - startTime, + }, + }; + } + + const issues = validated.error.issues.map((issue) => { + const path = issue.path.length > 0 ? 
issue.path.join(".") : ""; + return `${path}: ${issue.message}`; + }); + + lastError = { + attempt, + maxRetries, + rawResponse: lastResponse, + errorType: "schema_validation", + errorMessage: "Response did not match schema", + validationIssues: issues, + }; + + if (attempt <= maxRetries) { + options.onRetry?.(lastError); + messages.push({ role: "assistant", content: lastResponse }); + messages.push(buildGenerateObjectRetryPrompt(lastError)); + continue; + } + + throw new Error(this.buildGenerateObjectErrorMessage(lastError)); + } + + throw new Error("generateObject: reached an unexpected end state"); + } + + private parseJSONObject(text: string): { success: true; value: unknown } | { success: false; error: string } { + const raw = text.trim(); + const fencedMatch = raw.match(/^```(?:json)?\s*([\s\S]*?)\s*```$/i); + const candidate = fencedMatch ? fencedMatch[1]! : raw; + + let parsed: unknown; + try { + parsed = JSON.parse(candidate); + } catch (error: unknown) { + const message = error instanceof Error ? error.message : String(error); + return { success: false, error: `Invalid JSON: ${message}` }; + } + + if (typeof parsed !== "object" || parsed === null || Array.isArray(parsed)) { + return { success: false, error: "JSON root must be an object" }; + } + + return { success: true, value: parsed }; + } + + private buildGenerateObjectErrorMessage(event: GenerateObjectRetryEvent): string { + const issues = event.validationIssues?.length + ? 
`\nValidation issues:\n- ${event.validationIssues.join("\n- ")}` + : ""; + return ( + `generateObject failed after ${event.attempt} attempt(s): ${event.errorType} - ${event.errorMessage}.` + + `${issues}\nLast response:\n${event.rawResponse}` + ); + } + private addUsage(a: TokenUsage, b: TokenUsage): TokenUsage { return { promptTokens: a.promptTokens + b.promptTokens, diff --git a/src/types.ts b/src/types.ts index 442f21a..451b874 100644 --- a/src/types.ts +++ b/src/types.ts @@ -123,6 +123,70 @@ export interface RLMEvent { export type RLMEventCallback = (event: RLMEvent) => void; +// ============================================================================ +// Structured Output Types +// ============================================================================ + +export type GenerateObjectErrorType = "json_parse" | "schema_validation"; + +export interface GenerateObjectRetryEvent { + attempt: number; + maxRetries: number; + rawResponse: string; + errorType: GenerateObjectErrorType; + errorMessage: string; + validationIssues?: string[]; +} + +export interface GenerateObjectOptions { + /** + * Number of retries after the first attempt (default: 2). + * Total attempts = maxRetries + 1. + */ + maxRetries?: number; + /** + * Optional temperature for object generation requests. + */ + temperature?: number; + /** + * Optional max tokens for object generation requests. + */ + maxTokens?: number; + /** + * Callback fired whenever an attempt fails and a retry is scheduled. + */ + onRetry?: (event: GenerateObjectRetryEvent) => void; +} + +export interface GenerateObjectSchemas { + /** + * Optional structured input data available to generation. + */ + input?: TInput; + /** + * Optional schema describing the input structure. + * If input is provided, this schema is used for pre-validation. + */ + inputSchema?: ZodType; + /** + * Required schema for output validation. 
+ */ + outputSchema: ZodType; +} + +export interface GenerateObjectUsage { + totalCalls: number; + tokenUsage: TokenUsage; + executionTimeMs: number; +} + +export interface GenerateObjectResult { + object: T; + usage: GenerateObjectUsage; + attempts: number; + rawResponse: string; +} + // ============================================================================ // Chunking Types // ============================================================================