From 64055fcdae0c8697574ec9e516ceedec4eb246ee Mon Sep 17 00:00:00 2001 From: Roo Code Date: Tue, 24 Mar 2026 04:10:33 +0000 Subject: [PATCH] fix: enable prompt caching for AWS Bedrock custom ARN models Add cachableFields to guessModelInfoFromId() for Claude model patterns so that custom ARN models matching known Claude patterns get proper caching metadata. Update supportsAwsPromptCache() to respect the user's explicit opt-in for custom ARN models, even when the model ID from the ARN is not recognized. When the Bedrock API receives cache points for a model that does not support caching, it simply ignores them without error. Fixes #11983 --- src/api/providers/__tests__/bedrock.spec.ts | 95 +++++++++++++++++++++ src/api/providers/bedrock.ts | 35 ++++++-- 2 files changed, 124 insertions(+), 6 deletions(-) diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index 975e38af123..97f3b62b3e7 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -1327,4 +1327,99 @@ describe("AwsBedrockHandler", () => { expect(hasCachePoint).toBe(false) }) }) + + describe("prompt caching with custom ARN", () => { + beforeEach(() => { + mockConverseStreamCommand.mockReset() + }) + + // System prompt must exceed minTokensPerCachePoint (1024) for cache points to be placed + const longSystemPrompt = "You are a helpful assistant. ".repeat(200) + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + it("should enable prompt caching for custom ARN with recognized Claude model ID", async () => { + // Custom ARN containing a Claude model ID that matches the guess pattern + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom", + }) + + const generator = customArnHandler.createMessage(longSystemPrompt, messages) + await generator.next() + + expect(mockConverseStreamCommand).toHaveBeenCalled() + const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + + // System content should include a cachePoint since prompt caching should work + const systemBlocks = commandArg.system + const hasCachePoint = systemBlocks?.some((block: any) => block.cachePoint !== undefined) + expect(hasCachePoint).toBe(true) + }) + + it("should enable prompt caching for custom ARN with unrecognized model ID when user opts in", async () => { + // Custom ARN with an opaque model ID that doesn't match any pattern + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsUsePromptCache: true, + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:provisioned-model/my-custom-model-xyz", + }) + + const generator = customArnHandler.createMessage(longSystemPrompt, messages) + await generator.next() + + expect(mockConverseStreamCommand).toHaveBeenCalled() + const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + + // System content should include a cachePoint since user explicitly enabled caching + const systemBlocks = commandArg.system + const hasCachePoint = systemBlocks?.some((block: any) => block.cachePoint !== undefined) + expect(hasCachePoint).toBe(true) + }) + + it("should disable prompt caching for custom ARN when user explicitly disables it", async () => { + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsUsePromptCache: false, + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom", + }) + + const generator = customArnHandler.createMessage(longSystemPrompt, messages) + await generator.next() + + expect(mockConverseStreamCommand).toHaveBeenCalled() + const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + + // System content should NOT include cachePoint since user explicitly disabled caching + const systemBlocks = commandArg.system + const hasCachePoint = systemBlocks?.some((block: any) => block.cachePoint !== undefined) + expect(hasCachePoint).toBe(false) + }) + + it("should include cachableFields in guessModelInfoFromId for Claude patterns", () => { + // Test with a custom ARN that has a Claude model ID in it + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom", + }) + + const modelConfig = customArnHandler.getModel() + expect(modelConfig.info.supportsPromptCache).toBe(true) + expect((modelConfig.info as any).cachableFields).toBeDefined() + expect((modelConfig.info as any).cachableFields).toContain("system") + expect((modelConfig.info as any).cachableFields).toContain("messages") + expect((modelConfig.info as any).cachableFields).toContain("tools") + }) + }) }) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 3ceb2510033..5101be8f071 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -295,36 +295,42 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-7": { maxTokens: 8192, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-5": { maxTokens: 8192, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-4-opus": { maxTokens: 4096, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-opus": { maxTokens: 4096, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-haiku": { maxTokens: 4096, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, } @@ -1172,12 +1178,29 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH // Check if the model supports prompt cache // The cachableFields property is not part of the ModelInfo type in schemas // but it's used in the bedrockModels object in shared/api.ts - return ( - modelConfig?.info?.supportsPromptCache && - // Use optional chaining and type assertion to access cachableFields - (modelConfig?.info as any)?.cachableFields && - (modelConfig?.info as any)?.cachableFields?.length > 0 - ) + const hasCachableFields = + (modelConfig?.info as any)?.cachableFields && (modelConfig?.info as any)?.cachableFields?.length > 0 + + if (modelConfig?.info?.supportsPromptCache && hasCachableFields) { + return true + } + + // When using a custom ARN and the user has enabled prompt caching (or left it + // at the default), respect their intent even if the model info is incomplete. + // The model info may lack cachableFields or supportsPromptCache when the model + // ID extracted from the ARN doesn't match a known model in bedrockModels. + // In this case, inject defaults so the downstream caching logic works correctly. + // If the underlying model truly does not support caching, the Bedrock API + // simply ignores cache points without erroring. + if (this.options.awsCustomArn && this.options.awsUsePromptCache !== false) { + if (!hasCachableFields) { + ;(modelConfig.info as any).cachableFields = ["system", "messages", "tools"] + } + modelConfig.info.supportsPromptCache = true + return true + } + + return false } /**