Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions src/api/providers/__tests__/bedrock.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1327,4 +1327,99 @@ describe("AwsBedrockHandler", () => {
expect(hasCachePoint).toBe(false)
})
})

describe("prompt caching with custom ARN", () => {
	/** The only handler options that vary between the scenarios in this suite. */
	type CustomArnOverrides = { awsCustomArn: string; awsUsePromptCache?: boolean }

	// ARN whose resource name embeds a Claude model ID that matches the guess pattern.
	const claudeProfileArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/claude-3-5-sonnet-custom"
	// ARN with an opaque resource name that doesn't match any pattern.
	const opaqueProvisionedArn = "arn:aws:bedrock:us-east-1:123456789012:provisioned-model/my-custom-model-xyz"

	// Cache points are only placed when the system prompt exceeds
	// minTokensPerCachePoint (1024), so the prompt is repeated to be long enough.
	const longSystemPrompt = "You are a helpful assistant. ".repeat(200)
	const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

	/**
	 * Builds a handler with the credentials/region/model shared by every test,
	 * layering the scenario-specific ARN (and optional cache flag) on top.
	 * A flag that is omitted from `overrides` stays absent from the options.
	 */
	const buildHandler = (overrides: CustomArnOverrides) =>
		new AwsBedrockHandler({
			apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
			awsAccessKey: "test-access-key",
			awsSecretKey: "test-secret-key",
			awsRegion: "us-east-1",
			...overrides,
		})

	/**
	 * Starts a message stream on `handler`, asserts that a ConverseStream
	 * command was built, and reports whether any block of the first command's
	 * `system` payload carries a `cachePoint`.
	 *
	 * @returns the result of `some(...)` over the system blocks, or `undefined`
	 *          when the command has no `system` field at all.
	 */
	const systemHasCachePoint = async (handler: AwsBedrockHandler) => {
		// Pulling the first chunk is enough to make the handler issue the command.
		const stream = handler.createMessage(longSystemPrompt, messages)
		await stream.next()

		expect(mockConverseStreamCommand).toHaveBeenCalled()
		const firstCommandInput = mockConverseStreamCommand.mock.calls[0][0] as any

		return firstCommandInput.system?.some((entry: any) => entry.cachePoint !== undefined)
	}

	beforeEach(() => {
		mockConverseStreamCommand.mockReset()
	})

	it("should enable prompt caching for custom ARN with recognized Claude model ID", async () => {
		const handler = buildHandler({ awsCustomArn: claudeProfileArn })

		// Prompt caching should work, so the system content must carry a cachePoint.
		expect(await systemHasCachePoint(handler)).toBe(true)
	})

	it("should enable prompt caching for custom ARN with unrecognized model ID when user opts in", async () => {
		const handler = buildHandler({ awsUsePromptCache: true, awsCustomArn: opaqueProvisionedArn })

		// The user explicitly enabled caching, so a cachePoint is expected.
		expect(await systemHasCachePoint(handler)).toBe(true)
	})

	it("should disable prompt caching for custom ARN when user explicitly disables it", async () => {
		const handler = buildHandler({ awsUsePromptCache: false, awsCustomArn: claudeProfileArn })

		// The user explicitly disabled caching, so no cachePoint may be present.
		expect(await systemHasCachePoint(handler)).toBe(false)
	})

	it("should include cachableFields in guessModelInfoFromId for Claude patterns", () => {
		// Custom ARN that has a Claude model ID in it.
		const handler = buildHandler({ awsCustomArn: claudeProfileArn })

		const { info } = handler.getModel()
		expect(info.supportsPromptCache).toBe(true)

		// cachableFields is read through `any`, as it is not accessed via the typed model info.
		expect((info as any).cachableFields).toBeDefined()
		for (const field of ["system", "messages", "tools"]) {
			expect((info as any).cachableFields).toContain(field)
		}
	})
})
})
35 changes: 29 additions & 6 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,36 +295,42 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
cachableFields: ["system", "messages", "tools"],
},
"claude-3-7": {
maxTokens: 8192,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
cachableFields: ["system", "messages", "tools"],
},
"claude-3-5": {
maxTokens: 8192,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
cachableFields: ["system", "messages", "tools"],
},
"claude-4-opus": {
maxTokens: 4096,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
cachableFields: ["system", "messages", "tools"],
},
"claude-3-opus": {
maxTokens: 4096,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
cachableFields: ["system", "messages", "tools"],
},
"claude-3-haiku": {
maxTokens: 4096,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
cachableFields: ["system", "messages", "tools"],
},
}

Expand Down Expand Up @@ -1172,12 +1178,29 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
// Check if the model supports prompt cache
// The cachableFields property is not part of the ModelInfo type in schemas
// but it's used in the bedrockModels object in shared/api.ts
return (
modelConfig?.info?.supportsPromptCache &&
// Use optional chaining and type assertion to access cachableFields
(modelConfig?.info as any)?.cachableFields &&
(modelConfig?.info as any)?.cachableFields?.length > 0
)
const hasCachableFields =
(modelConfig?.info as any)?.cachableFields && (modelConfig?.info as any)?.cachableFields?.length > 0

if (modelConfig?.info?.supportsPromptCache && hasCachableFields) {
return true
}

// When using a custom ARN and the user has enabled prompt caching (or left it
// at the default), respect their intent even if the model info is incomplete.
// The model info may lack cachableFields or supportsPromptCache when the model
// ID extracted from the ARN doesn't match a known model in bedrockModels.
// In this case, inject defaults so the downstream caching logic works correctly.
// If the underlying model truly does not support caching, the Bedrock API
// simply ignores cache points without erroring.
if (this.options.awsCustomArn && this.options.awsUsePromptCache !== false) {
if (!hasCachableFields) {
;(modelConfig.info as any).cachableFields = ["system", "messages", "tools"]
}
modelConfig.info.supportsPromptCache = true
return true
}

return false
}

/**
Expand Down
Loading