# Patch summary (commentary only; everything from the first "diff --git" header onward is unchanged).
#
# Files touched:
#   1. src/api/providers/__tests__/bedrock.spec.ts  (one hunk, 95 added lines)
#   2. src/api/providers/bedrock.ts                 (two hunks)
#
# bedrock.spec.ts
#   Adds a describe("prompt caching with custom ARN") suite. A beforeEach resets
#   mockConverseStreamCommand. The suite shares a long system prompt
#   ("You are a helpful assistant. " repeated 200 times) and a single user message.
#   Three tests build an AwsBedrockHandler with awsCustomArn set, pull one item from
#   createMessage(), and inspect the `system` field of the first argument passed to
#   mockConverseStreamCommand for a block carrying `cachePoint`:
#     - inference-profile ARN, awsUsePromptCache unset  -> expects a cachePoint
#     - provisioned-model ARN, awsUsePromptCache: true  -> expects a cachePoint
#     - inference-profile ARN, awsUsePromptCache: false -> expects no cachePoint
#   A fourth test calls getModel() and expects info.supportsPromptCache === true and
#   info.cachableFields to contain "system", "messages" and "tools".
#
# bedrock.ts, hunk @@ -295,36 +295,42 @@
#   Adds cachableFields: ["system", "messages", "tools"] to six model-info entries:
#   "claude-3-7", "claude-3-5", "claude-4-opus", "claude-3-opus", "claude-3-haiku",
#   plus one entry whose key lies above the hunk context.
#
# bedrock.ts, hunk @@ -1172,12 +1178,29 @@
#   Replaces the single `return (supportsPromptCache && cachableFields && length > 0)`
#   expression with three steps:
#     1. compute hasCachableFields from modelConfig.info.cachableFields;
#     2. return true when supportsPromptCache and hasCachableFields both hold;
#     3. otherwise, when options.awsCustomArn is set and options.awsUsePromptCache is
#        not exactly `false`, write cachableFields (if missing) and
#        supportsPromptCache = true onto modelConfig.info and return true;
#     4. otherwise return false.
#
# NOTE(review): step 3 mutates modelConfig.info in place from inside what reads like a
#   predicate. Whether that object is a per-call copy or shared state is not visible
#   in this patch -- confirm before merging.
# NOTE(review): step 3 dereferences `modelConfig.info` without the `?.` guard used in
#   steps 1-2; if modelConfig or modelConfig.info can be undefined here this throws a
#   TypeError -- confirm against the caller.
# NOTE(review): the removed expression could evaluate to a non-boolean falsy value
#   (whatever operand short-circuited); the replacement returns strict true/false.
# NOTE(review): the "unrecognized model ID" test still passes a Claude apiModelId;
#   whether model info is derived from the ARN or from apiModelId is not shown here,
#   so it is unclear that the test exercises the unrecognized-model path -- TODO confirm.
#
diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index 975e38af12..97f3b62b3e 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -1327,4 +1327,99 @@ describe("AwsBedrockHandler", () => { expect(hasCachePoint).toBe(false) }) }) + + describe("prompt caching with custom ARN", () => { + beforeEach(() => { + mockConverseStreamCommand.mockReset() + }) + + // System prompt must exceed minTokensPerCachePoint (1024) for cache points to be placed + const longSystemPrompt = "You are a helpful assistant. ".repeat(200) + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + it("should enable prompt caching for custom ARN with recognized Claude model ID", async () => { + // Custom ARN containing a Claude model ID that matches the guess pattern + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom", + }) + + const generator = customArnHandler.createMessage(longSystemPrompt, messages) + await generator.next() + + expect(mockConverseStreamCommand).toHaveBeenCalled() + const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + + // System content should include a cachePoint since prompt caching should work + const systemBlocks = commandArg.system + const hasCachePoint = systemBlocks?.some((block: any) => block.cachePoint !== undefined) + expect(hasCachePoint).toBe(true) + }) + + it("should enable prompt caching for custom ARN with unrecognized model ID when user opts in", async () => { + // Custom ARN with an opaque model ID that doesn't match any pattern + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + 
 awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsUsePromptCache: true, + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:provisioned-model/my-custom-model-xyz", + }) + + const generator = customArnHandler.createMessage(longSystemPrompt, messages) + await generator.next() + + expect(mockConverseStreamCommand).toHaveBeenCalled() + const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + + // System content should include a cachePoint since user explicitly enabled caching + const systemBlocks = commandArg.system + const hasCachePoint = systemBlocks?.some((block: any) => block.cachePoint !== undefined) + expect(hasCachePoint).toBe(true) + }) + + it("should disable prompt caching for custom ARN when user explicitly disables it", async () => { + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + awsRegion: "us-east-1", + awsUsePromptCache: false, + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom", + }) + + const generator = customArnHandler.createMessage(longSystemPrompt, messages) + await generator.next() + + expect(mockConverseStreamCommand).toHaveBeenCalled() + const commandArg = mockConverseStreamCommand.mock.calls[0][0] as any + + // System content should NOT include cachePoint since user explicitly disabled caching + const systemBlocks = commandArg.system + const hasCachePoint = systemBlocks?.some((block: any) => block.cachePoint !== undefined) + expect(hasCachePoint).toBe(false) + }) + + it("should include cachableFields in guessModelInfoFromId for Claude patterns", () => { + // Test with a custom ARN that has a Claude model ID in it + const customArnHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test-access-key", + awsSecretKey: "test-secret-key", + 
 awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom", + }) + + const modelConfig = customArnHandler.getModel() + expect(modelConfig.info.supportsPromptCache).toBe(true) + expect((modelConfig.info as any).cachableFields).toBeDefined() + expect((modelConfig.info as any).cachableFields).toContain("system") + expect((modelConfig.info as any).cachableFields).toContain("messages") + expect((modelConfig.info as any).cachableFields).toContain("tools") + }) + }) }) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 3ceb251003..5101be8f07 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -295,36 +295,42 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-7": { maxTokens: 8192, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-5": { maxTokens: 8192, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-4-opus": { maxTokens: 4096, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-opus": { maxTokens: 4096, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, "claude-3-haiku": { maxTokens: 4096, contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + cachableFields: ["system", "messages", "tools"], }, } @@ -1172,12 +1178,29 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH // Check if the model supports prompt cache // The cachableFields property is not part of the ModelInfo type in schemas // but it's used in 
 the bedrockModels object in shared/api.ts - return ( - modelConfig?.info?.supportsPromptCache && - // Use optional chaining and type assertion to access cachableFields - (modelConfig?.info as any)?.cachableFields && - (modelConfig?.info as any)?.cachableFields?.length > 0 - ) + const hasCachableFields = + (modelConfig?.info as any)?.cachableFields && (modelConfig?.info as any)?.cachableFields?.length > 0 + + if (modelConfig?.info?.supportsPromptCache && hasCachableFields) { + return true + } + + // When using a custom ARN and the user has enabled prompt caching (or left it + // at the default), respect their intent even if the model info is incomplete. + // The model info may lack cachableFields or supportsPromptCache when the model + // ID extracted from the ARN doesn't match a known model in bedrockModels. + // In this case, inject defaults so the downstream caching logic works correctly. + // If the underlying model truly does not support caching, the Bedrock API + // simply ignores cache points without erroring. + if (this.options.awsCustomArn && this.options.awsUsePromptCache !== false) { + if (!hasCachableFields) { + ;(modelConfig.info as any).cachableFields = ["system", "messages", "tools"] + } + modelConfig.info.supportsPromptCache = true + return true + } + + return false } /**