diff --git a/packages/types/src/__tests__/index.test.ts b/packages/types/src/__tests__/index.test.ts index 15441d48fd..99c680826d 100644 --- a/packages/types/src/__tests__/index.test.ts +++ b/packages/types/src/__tests__/index.test.ts @@ -15,6 +15,10 @@ describe("GLOBAL_STATE_KEYS", () => { expect(GLOBAL_STATE_KEYS).not.toContain("openRouterApiKey") }) + it("should not contain Umans API key (secret)", () => { + expect(GLOBAL_STATE_KEYS).not.toContain("umansApiKey") + }) + it("should contain OpenAI Compatible base URL setting", () => { expect(GLOBAL_STATE_KEYS).toContain("codebaseIndexOpenAiCompatibleBaseUrl") }) diff --git a/packages/types/src/global-settings.ts b/packages/types/src/global-settings.ts index 706c75cd1e..b79d6e9244 100644 --- a/packages/types/src/global-settings.ts +++ b/packages/types/src/global-settings.ts @@ -290,6 +290,7 @@ export type RooCodeSettings = GlobalSettings & ProviderSettings export const SECRET_STATE_KEYS = [ "apiKey", "openRouterApiKey", + "umansApiKey", "awsAccessKey", "awsApiKey", "awsSecretKey", diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 26c4dee7e1..bfaa8f19fc 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -38,6 +38,7 @@ export const DEFAULT_CONSECUTIVE_MISTAKE_LIMIT = 3 export const dynamicProviders = [ "openrouter", + "umans", "vercel-ai-gateway", "zoo-gateway", "litellm", @@ -85,7 +86,7 @@ export const isInternalProvider = (key: string): key is InternalProvider => * Custom providers are completely configurable within Roo Code settings. */ -export const customProviders = ["openai"] as const +export const customProviders = ["openai", "anthropic-custom"] as const export type CustomProvider = (typeof customProviders)[number] @@ -221,6 +222,11 @@ const openRouterSchema = baseProviderSettingsSchema.extend({ openRouterSpecificProvider: z.string().optional(), }) +const umansSchema = baseProviderSettingsSchema.extend({ + umansApiKey: z.string().optional(), + umansModelId: z.string().optional(), +}) + const bedrockSchema = apiModelIdProviderModelSchema.extend({ awsAccessKey: z.string().optional(), awsSecretKey: z.string().optional(), @@ -262,6 +268,15 @@ const openAiSchema = baseProviderSettingsSchema.extend({ openAiHeaders: z.record(z.string(), z.string()).optional(), }) +const anthropicCustomSchema = baseProviderSettingsSchema.extend({ + anthropicCustomBaseUrl: z.string().optional(), + anthropicCustomApiKey: z.string().optional(), + anthropicCustomModelId: z.string().optional(), + anthropicCustomModelInfo: modelInfoSchema.nullish(), + anthropicCustomStreamingEnabled: z.boolean().optional(), + anthropicCustomHeaders: z.record(z.string(), z.string()).optional(), +}) + const ollamaSchema = baseProviderSettingsSchema.extend({ ollamaModelId: z.string().optional(), ollamaBaseUrl: z.string().optional(), @@ -424,9 +439,11 @@ const defaultSchema = z.object({ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [ anthropicSchema.merge(z.object({ apiProvider: z.literal("anthropic") })), openRouterSchema.merge(z.object({ apiProvider: z.literal("openrouter") })), + umansSchema.merge(z.object({ apiProvider: z.literal("umans") })), bedrockSchema.merge(z.object({ apiProvider: z.literal("bedrock") })), vertexSchema.merge(z.object({ apiProvider: z.literal("vertex") })), openAiSchema.merge(z.object({ apiProvider: z.literal("openai") })), + anthropicCustomSchema.merge(z.object({ apiProvider: z.literal("anthropic-custom") })), ollamaSchema.merge(z.object({ apiProvider: z.literal("ollama") })), vsCodeLmSchema.merge(z.object({ apiProvider: z.literal("vscode-lm") })), lmStudioSchema.merge(z.object({ apiProvider: z.literal("lmstudio") })), @@ -460,9 +477,11 @@ export const providerSettingsSchema = z.object({ apiProvider: providerNamesWithRetiredSchema.optional(), ...anthropicSchema.shape, ...openRouterSchema.shape, + ...umansSchema.shape, ...bedrockSchema.shape, ...vertexSchema.shape, ...openAiSchema.shape, + ...anthropicCustomSchema.shape, ...ollamaSchema.shape, ...vsCodeLmSchema.shape, ...lmStudioSchema.shape, @@ -511,7 +530,9 @@ export const PROVIDER_SETTINGS_KEYS = providerSettingsSchema.keyof().options export const modelIdKeys = [ "apiModelId", "openRouterModelId", + "umansModelId", "openAiModelId", + "anthropicCustomModelId", "ollamaModelId", "lmStudioModelId", "lmStudioDraftModelId", @@ -542,6 +563,7 @@ export const isTypicalProvider = (key: unknown): key is TypicalProvider => export const modelIdKeysByProvider: Record = { anthropic: "apiModelId", openrouter: "openRouterModelId", + umans: "umansModelId", bedrock: "apiModelId", vertex: "apiModelId", "openai-codex": "apiModelId", @@ -575,7 +597,7 @@ export const modelIdKeysByProvider: Record = { */ // Providers that use Anthropic-style API protocol. -export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "bedrock", "minimax"] +export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "anthropic-custom", "bedrock", "minimax"] export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string): "anthropic" | "openai" => { if (provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider)) { @@ -615,7 +637,7 @@ export const getApiProtocol = (provider: ProviderName | undefined, modelId?: str */ export const MODELS_BY_PROVIDER: Record< - Exclude, + Exclude, { id: ProviderName; label: string; models: string[] } > = { anthropic: { @@ -697,6 +719,7 @@ export const MODELS_BY_PROVIDER: Record< poe: { id: "poe", label: "Poe", models: [] }, litellm: { id: "litellm", label: "LiteLLM", models: [] }, openrouter: { id: "openrouter", label: "OpenRouter", models: [] }, + umans: { id: "umans", label: "Umans", models: [] }, requesty: { id: "requesty", label: "Requesty", models: [] }, unbound: { id: "unbound", label: "Unbound", models: [] }, "vercel-ai-gateway": { id: "vercel-ai-gateway", label: "Vercel AI Gateway", models: [] }, diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index f283cb474c..ed92e7bbb3 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -18,6 +18,7 @@ export * from "./qwen-code.js" export * from "./requesty.js" export * from "./sambanova.js" export * from "./unbound.js" +export * from "./umans.js" export * from "./vertex.js" export * from "./vscode-llm.js" export * from "./xai.js" @@ -44,6 +45,7 @@ import { qwenCodeDefaultModelId } from "./qwen-code.js" import { requestyDefaultModelId } from "./requesty.js" import { sambaNovaDefaultModelId } from "./sambanova.js" import { unboundDefaultModelId } from "./unbound.js" +import { umansDefaultModelId } from "./umans.js" import { vertexDefaultModelId } from "./vertex.js" import { vscodeLlmDefaultModelId } from "./vscode-llm.js" import { xaiDefaultModelId } from "./xai.js" @@ -71,6 +73,8 @@ export function getProviderDefaultModelId( return openRouterDefaultModelId case "requesty": return requestyDefaultModelId + case "umans": + return umansDefaultModelId case "litellm": return litellmDefaultModelId case "xai": diff --git a/packages/types/src/providers/umans.ts b/packages/types/src/providers/umans.ts new file mode 100644 index 0000000000..2237620c0f --- /dev/null +++ b/packages/types/src/providers/umans.ts @@ -0,0 +1,18 @@ +import type { ModelInfo } from "../model.js" + +export const UMANS_DEFAULT_BASE_URL = "https://api.code.umans.ai/v1" + +// Umans +// https://api.code.umans.ai/v1/models/info +export const umansDefaultModelId = "umans-coder" + +export const umansDefaultModelInfo: ModelInfo = { + maxTokens: 32_768, + contextWindow: 262_144, + supportsImages: true, + supportsPromptCache: false, + supportsMaxTokens: true, + inputPrice: 0.95, + outputPrice: 4.0, + description: "Umans Coder is Umans' recommended model for complex, coding-heavy workloads and coding agents.", +} diff --git a/src/api/index.ts b/src/api/index.ts index 0c901f8e23..5a0016be80 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -24,6 +24,7 @@ import { VsCodeLmHandler, RequestyHandler, UnboundHandler, + UmansHandler, FakeAIHandler, XAIHandler, LiteLLMHandler, @@ -133,9 +134,12 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { switch (apiProvider) { case "anthropic": + case "anthropic-custom": return new AnthropicHandler(options) case "openrouter": return new OpenRouterHandler(options) + case "umans": + return new UmansHandler(options) case "bedrock": return new AwsBedrockHandler(options) case "vertex": diff --git a/src/api/providers/__tests__/umans.spec.ts b/src/api/providers/__tests__/umans.spec.ts new file mode 100644 index 0000000000..b235a1f1d6 --- /dev/null +++ b/src/api/providers/__tests__/umans.spec.ts @@ -0,0 +1,116 @@ +// npx vitest run api/providers/__tests__/umans.spec.ts + +vitest.mock("../utils/timeout-config", () => ({ + getApiRequestTimeout: vitest.fn().mockReturnValue(300_000), +})) + +const MOCK_TIMEOUT_MS = 300_000 + +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" + +import { UmansHandler } from "../umans" +import type { ApiHandlerOptions } from "../../../shared/api" +import { Package } from "../../../shared/package" + +const mockCreate = vitest.fn() + +vitest.mock("openai", () => ({ + default: vitest.fn().mockImplementation(function () { + return { + chat: { + completions: { + create: mockCreate, + }, + }, + } + }), +})) + +vitest.mock("../fetchers/modelCache", () => ({ + getModels: vitest.fn().mockResolvedValue({ + "umans-coder": { + maxTokens: 32768, + contextWindow: 262144, + supportsImages: true, + supportsPromptCache: false, + supportsMaxTokens: true, + inputPrice: 0.95, + outputPrice: 4, + description: "Umans Coder", + }, + "umans-glm-5.2": { + maxTokens: 131071, + contextWindow: 405504, + supportsImages: true, + supportsPromptCache: false, + supportsMaxTokens: true, + supportsReasoningEffort: ["none", "high", "max"], + reasoningEffort: "high", + inputPrice: 1.4, + outputPrice: 4.4, + description: "Umans GLM 5.2", + }, + }), +})) + +describe("UmansHandler", () => { + const mockOptions: ApiHandlerOptions = { + umansApiKey: "test-key", + umansModelId: "umans-coder", + } + + beforeEach(() => vitest.clearAllMocks()) + + it("initializes with the Umans base URL and API key", () => { + new UmansHandler(mockOptions) + + expect(OpenAI).toHaveBeenCalledWith({ + baseURL: "https://api.code.umans.ai/v1", + apiKey: "test-key", + defaultHeaders: { + "HTTP-Referer": "https://github.com/Zoo-Code-Org/Zoo-Code", + "X-Title": "Zoo Code", + "User-Agent": `ZooCode/${Package.version}`, + }, + timeout: MOCK_TIMEOUT_MS, + }) + }) + + it("returns the default model when no options are provided", async () => { + const handler = new UmansHandler({}) + const result = await handler.fetchModel() + + expect(result.id).toBe("umans-coder") + expect(result.info.description).toBe("Umans Coder") + }) + + it("uses the provider's default OpenAI reasoning payload for Umans GLM models", async () => { + const handler = new UmansHandler({ + umansApiKey: "test-key", + umansModelId: "umans-glm-5.2", + reasoningEffort: "max", + }) + + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + choices: [{ delta: { content: "done" } }], + } + }, + } + + mockCreate.mockResolvedValue(mockStream) + + const generator = handler.createMessage("system prompt", [{ role: "user" as const, content: "test" }]) + await generator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + model: "umans-glm-5.2", + reasoning_effort: "max", + stream: true, + }), + ) + }) +}) diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 7a4ef30ad0..9a22e80b92 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -8,6 +8,7 @@ import { type AnthropicModelId, anthropicDefaultModelId, anthropicModels, + openAiModelInfoSaneDefaults, ANTHROPIC_DEFAULT_MAX_TOKENS, ApiProviderError, } from "@roo-code/types" @@ -38,12 +39,17 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa super() this.options = options + const baseURL = this.options.anthropicCustomBaseUrl || this.options.anthropicBaseUrl || undefined + const apiKey = this.options.anthropicCustomApiKey || this.options.apiKey const apiKeyFieldName = - this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey" + baseURL && this.options.anthropicUseAuthToken && !this.options.anthropicCustomApiKey + ? "authToken" + : "apiKey" this.client = new Anthropic({ - baseURL: this.options.anthropicBaseUrl || undefined, - [apiKeyFieldName]: this.options.apiKey, + baseURL, + [apiKeyFieldName]: apiKey, + defaultHeaders: this.options.anthropicCustomHeaders, timeout: this.timeoutMs, }) } @@ -352,9 +358,14 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } getModel() { - const modelId = this.options.apiModelId - const id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId - let info: ModelInfo = anthropicModels[id] + const customModelId = this.options.anthropicCustomModelId + const modelId = customModelId || this.options.apiModelId + const id = + customModelId || + (modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId) + let info: ModelInfo = customModelId + ? this.options.anthropicCustomModelInfo || openAiModelInfoSaneDefaults + : anthropicModels[id as AnthropicModelId] // If 1M context beta is enabled for supported models, update the model info if ( diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts index 11395485a9..2d9ca66969 100644 --- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts +++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts @@ -43,6 +43,7 @@ vi.mock("fs", () => ({ vi.mock("../litellm") vi.mock("../openrouter") vi.mock("../requesty") +vi.mock("../umans") // Mock ContextProxy with a simple static instance vi.mock("../../../core/config/ContextProxy", () => ({ @@ -63,10 +64,12 @@ import { getModels, getModelsFromCache } from "../modelCache" import { getLiteLLMModels } from "../litellm" import { getOpenRouterModels } from "../openrouter" import { getRequestyModels } from "../requesty" +import { getUmansModels } from "../umans" const mockGetLiteLLMModels = getLiteLLMModels as Mock const mockGetOpenRouterModels = getOpenRouterModels as Mock const mockGetRequestyModels = getRequestyModels as Mock +const mockGetUmansModels = getUmansModels as Mock const DUMMY_REQUESTY_KEY = "requesty-key-for-testing" @@ -130,6 +133,23 @@ describe("getModels with new GetModelsOptions", () => { expect(result).toEqual(mockModels) }) + it("calls getUmansModels for umans provider", async () => { + const mockModels = { + "umans-coder": { + maxTokens: 32768, + contextWindow: 262144, + supportsPromptCache: false, + description: "Umans Coder", + }, + } + mockGetUmansModels.mockResolvedValue(mockModels) + + const result = await getModels({ provider: "umans" }) + + expect(mockGetUmansModels).toHaveBeenCalled() + expect(result).toEqual(mockModels) + }) + it("handles errors and re-throws them", async () => { const expectedError = new Error("LiteLLM connection failed") mockGetLiteLLMModels.mockRejectedValue(expectedError) diff --git a/src/api/providers/fetchers/__tests__/umans.spec.ts b/src/api/providers/fetchers/__tests__/umans.spec.ts new file mode 100644 index 0000000000..46f8edd59a --- /dev/null +++ b/src/api/providers/fetchers/__tests__/umans.spec.ts @@ -0,0 +1,62 @@ +// npx vitest run api/providers/fetchers/__tests__/umans.spec.ts + +import axios from "axios" + +import { getUmansModels } from "../umans" + +vi.mock("axios") +const mockAxiosGet = vi.mocked(axios.get) + +describe("getUmansModels", () => { + it("parses Umans model metadata and pricing", async () => { + mockAxiosGet + .mockResolvedValueOnce({ + data: { + "umans-flash": { + name: "umans-flash", + display_name: "Umans Flash", + description: "Fast coding model", + capabilities: { + max_completion_tokens: 262144, + recommended_max_tokens: 32768, + context_window: 262144, + supports_vision: true, + reasoning: { + supported: true, + can_disable: true, + levels: ["none", "low", "medium", "high"], + default_level: "medium", + }, + }, + }, + }, + }) + .mockResolvedValueOnce({ + data: { + data: [ + { + id: "umans-flash", + pricing: { input: 0.15, output: 1.0 }, + }, + ], + }, + }) + + const models = await getUmansModels() + + expect(mockAxiosGet).toHaveBeenNthCalledWith(1, "https://api.code.umans.ai/v1/models/info") + expect(mockAxiosGet).toHaveBeenNthCalledWith(2, "https://api.code.umans.ai/v1/models") + expect(models["umans-flash"]).toEqual({ + maxTokens: 32768, + contextWindow: 262144, + supportsImages: true, + supportsPromptCache: false, + supportsMaxTokens: true, + supportsReasoningEffort: ["none", "low", "medium", "high"], + reasoningEffort: "medium", + inputPrice: 0.15, + outputPrice: 1, + description: "Fast coding model", + }) + }) +}) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 404a60cd85..ed7700b65d 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -17,6 +17,7 @@ import type { RouterName } from "../../../shared/api" import { fileExistsAtPath } from "../../../utils/fs" import { getOpenRouterModels } from "./openrouter" +import { getUmansModels } from "./umans" import { getVercelAiGatewayModels } from "./vercel-ai-gateway" import { getOpencodeGoModels } from "./opencode-go" import { getRequestyModels } from "./requesty" @@ -78,6 +79,9 @@ async function fetchModelsFromProvider(options: GetModelsOptions): Promise { // Providers that work without API keys const publicProviders: Array<{ provider: RouterName; options: GetModelsOptions }> = [ { provider: "openrouter", options: { provider: "openrouter" } }, + { provider: "umans", options: { provider: "umans" } }, { provider: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } }, ] diff --git a/src/api/providers/fetchers/umans.ts b/src/api/providers/fetchers/umans.ts new file mode 100644 index 0000000000..ec5ccdbe66 --- /dev/null +++ b/src/api/providers/fetchers/umans.ts @@ -0,0 +1,90 @@ +import axios from "axios" +import { z } from "zod" + +import { type ModelInfo, UMANS_DEFAULT_BASE_URL } from "@roo-code/types" + +const supportedReasoningEffortSchema = z.enum(["none", "minimal", "low", "medium", "high", "xhigh", "max"]) + +const umansReasoningSchema = z + .object({ + supported: z.boolean().optional(), + can_disable: z.boolean().optional(), + levels: z.array(supportedReasoningEffortSchema).optional(), + default_level: supportedReasoningEffortSchema.nullish(), + }) + .optional() + +const umansModelSchema = z.object({ + name: z.string(), + display_name: z.string().optional(), + description: z.string().optional(), + capabilities: z.object({ + max_completion_tokens: z.number().nullish(), + recommended_max_tokens: z.number().nullish(), + context_window: z.number(), + supports_vision: z.union([z.boolean(), z.string()]).optional(), + reasoning: umansReasoningSchema, + }), +}) + +const umansModelsInfoResponseSchema = z.record(z.string(), umansModelSchema) + +const umansPricingResponseSchema = z.object({ + data: z.array( + z.object({ + id: z.string(), + pricing: z + .object({ + input: z.number().optional(), + output: z.number().optional(), + }) + .optional(), + }), + ), +}) + +export async function getUmansModels(): Promise> { + const models: Record = {} + + try { + const [infoResponse, pricingResponse] = await Promise.all([ + axios.get(`${UMANS_DEFAULT_BASE_URL}/models/info`), + axios.get(`${UMANS_DEFAULT_BASE_URL}/models`), + ]) + + const infoResult = umansModelsInfoResponseSchema.safeParse(infoResponse.data) + if (!infoResult.success) { + return models + } + + const pricingResult = umansPricingResponseSchema.safeParse(pricingResponse.data) + const pricingById = new Map( + (pricingResult.success ? pricingResult.data.data : []).map((entry) => [entry.id, entry.pricing]), + ) + + for (const [id, rawModel] of Object.entries(infoResult.data)) { + const reasoning = rawModel.capabilities.reasoning + const reasoningLevels = reasoning?.levels + + models[id] = { + maxTokens: + rawModel.capabilities.recommended_max_tokens ?? + rawModel.capabilities.max_completion_tokens ?? + undefined, + contextWindow: rawModel.capabilities.context_window, + supportsImages: rawModel.capabilities.supports_vision !== false, + supportsPromptCache: false, + supportsMaxTokens: rawModel.capabilities.max_completion_tokens != null, + supportsReasoningEffort: reasoningLevels && reasoningLevels.length > 0 ? reasoningLevels : undefined, + reasoningEffort: reasoning?.default_level ?? undefined, + inputPrice: pricingById.get(id)?.input, + outputPrice: pricingById.get(id)?.output, + description: rawModel.description ?? rawModel.display_name, + } + } + } catch (error) { + console.error(`Error fetching Umans models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + } + + return models +} diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts index 3c0d1e03e3..c01777bf8a 100644 --- a/src/api/providers/index.ts +++ b/src/api/providers/index.ts @@ -19,6 +19,7 @@ export { QwenCodeHandler } from "./qwen-code" export { RequestyHandler } from "./requesty" export { SambaNovaHandler } from "./sambanova" export { UnboundHandler } from "./unbound" +export { UmansHandler } from "./umans" export { VertexHandler } from "./vertex" export { VsCodeLmHandler } from "./vscode-lm" export { XAIHandler } from "./xai" diff --git a/src/api/providers/umans.ts b/src/api/providers/umans.ts new file mode 100644 index 0000000000..5a2d0e906a --- /dev/null +++ b/src/api/providers/umans.ts @@ -0,0 +1,189 @@ +import { Anthropic } from "@anthropic-ai/sdk" +import OpenAI from "openai" + +import { + type ModelInfo, + type ModelRecord, + umansDefaultModelId, + umansDefaultModelInfo, + UMANS_DEFAULT_BASE_URL, +} from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../shared/api" +import { calculateApiCostOpenAI } from "../../shared/cost" + +import { convertToOpenAiMessages } from "../transform/openai-format" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { DEFAULT_HEADERS } from "./constants" +import { getModels } from "./fetchers/modelCache" +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { handleOpenAIError } from "./utils/openai-error-handler" +import { applyRouterToolPreferences } from "./utils/router-tool-preferences" +import { extractReasoningFromDelta } from "./utils/extract-reasoning" + +type UmansUsage = OpenAI.CompletionUsage & { + prompt_tokens_details?: { + cache_write_tokens?: number + caching_tokens?: number + cached_tokens?: number + } +} + +export class UmansHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + protected models: ModelRecord = {} + private client: OpenAI + private readonly providerName = "Umans" + + constructor(options: ApiHandlerOptions) { + super() + + this.options = options + + const apiKey = this.options.umansApiKey ?? "not-provided" + + this.client = new OpenAI({ + baseURL: UMANS_DEFAULT_BASE_URL, + apiKey, + defaultHeaders: DEFAULT_HEADERS, + timeout: this.timeoutMs, + }) + } + + public async fetchModel() { + this.models = await getModels({ provider: "umans" }) + return this.getModel() + } + + override getModel() { + const id = this.options.umansModelId ?? umansDefaultModelId + const cachedInfo = this.models[id] ?? umansDefaultModelInfo + let info: ModelInfo = cachedInfo + + info = applyRouterToolPreferences(id, info) + + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: 0, + }) + + return { id, info, ...params } + } + + protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk { + const umansUsage = usage as UmansUsage + const inputTokens = umansUsage?.prompt_tokens || 0 + const outputTokens = umansUsage?.completion_tokens || 0 + const cacheWriteTokens = + umansUsage?.prompt_tokens_details?.cache_write_tokens || + umansUsage?.prompt_tokens_details?.caching_tokens || + 0 + const cacheReadTokens = umansUsage?.prompt_tokens_details?.cached_tokens || 0 + const { totalCost } = modelInfo + ? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens) + : { totalCost: 0 } + + return { + type: "usage", + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens, + totalCost, + } + } + + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { id: model, info, maxTokens: max_tokens, temperature, reasoning } = await this.fetchModel() + + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ + { role: "system", content: systemPrompt }, + ...convertToOpenAiMessages(messages), + ] + + const completionParams: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { + messages: openAiMessages, + model, + max_tokens, + temperature, + ...(reasoning ?? {}), + stream: true, + stream_options: { include_usage: true }, + tools: this.convertToolsForOpenAI(metadata?.tools), + tool_choice: metadata?.tool_choice, + parallel_tool_calls: metadata?.parallelToolCalls ?? true, + } + + let stream + try { + stream = await this.client.chat.completions.create(completionParams) + } catch (error) { + throw handleOpenAIError(error, this.providerName) + } + let lastUsage: any = undefined + + for await (const chunk of stream) { + const delta = chunk.choices[0]?.delta + + if (delta?.content) { + yield { type: "text", text: delta.content } + } + + const reasoningText = extractReasoningFromDelta(delta) + if (reasoningText) { + yield { type: "reasoning", text: reasoningText } + } + + if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) { + for (const toolCall of delta.tool_calls) { + yield { + type: "tool_call_partial", + index: toolCall.index, + id: toolCall.id, + name: toolCall.function?.name, + arguments: toolCall.function?.arguments, + } + } + } + + if (chunk.usage) { + lastUsage = chunk.usage + } + } + + if (lastUsage) { + yield this.processUsageMetrics(lastUsage, info) + } + } + + async completePrompt(prompt: string): Promise { + const { id: model, maxTokens: max_tokens, temperature, reasoning } = await this.fetchModel() + + const completionParams: OpenAI.Chat.ChatCompletionCreateParams = { + model, + max_tokens, + messages: [{ role: "system", content: prompt }], + temperature, + ...(reasoning ?? {}), + } + + let response: OpenAI.Chat.ChatCompletion + try { + response = await this.client.chat.completions.create(completionParams) + } catch (error) { + throw handleOpenAIError(error, this.providerName) + } + + return response.choices[0]?.message.content || "" + } +} diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index 1904b46bd4..bc73e918b5 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -2676,6 +2676,7 @@ describe("ClineProvider - Router Models", () => { // Verify getModels was called for each provider with correct options expect(getModels).toHaveBeenCalledWith({ provider: "openrouter" }) + expect(getModels).toHaveBeenCalledWith({ provider: "umans" }) expect(getModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" }) expect(getModels).toHaveBeenCalledWith({ provider: "unbound" }) expect(getModels).toHaveBeenCalledWith({ provider: "vercel-ai-gateway" }) @@ -2692,6 +2693,7 @@ describe("ClineProvider - Router Models", () => { type: "routerModels", routerModels: { openrouter: mockModels, + umans: mockModels, requesty: mockModels, unbound: mockModels, "vercel-ai-gateway": mockModels, @@ -2728,6 +2730,7 @@ describe("ClineProvider - Router Models", () => { // Mock some providers to succeed and others to fail vi.mocked(getModels) .mockResolvedValueOnce(mockModels) // openrouter success + .mockResolvedValueOnce(mockModels) // umans success .mockRejectedValueOnce(new Error("Requesty API error")) // requesty fail .mockResolvedValueOnce(mockModels) // unbound success .mockResolvedValueOnce(mockModels) // vercel-ai-gateway success @@ -2742,6 +2745,7 @@ describe("ClineProvider - Router Models", () => { type: "routerModels", routerModels: { openrouter: mockModels, + umans: mockModels, requesty: {}, unbound: mockModels, "vercel-ai-gateway": mockModels, @@ -2839,6 +2843,7 @@ describe("ClineProvider - Router Models", () => { type: "routerModels", routerModels: { openrouter: mockModels, + umans: mockModels, requesty: mockModels, unbound: mockModels, "vercel-ai-gateway": mockModels, diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts index 13af478c07..b1db81de0f 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts @@ -109,6 +109,7 @@ describe("webviewMessageHandler - requestRouterModels provider filter", () => { // Aggregate handler initializes many known routers - ensure a few expected keys exist expect(routerModels).toHaveProperty("openrouter") + expect(routerModels).toHaveProperty("umans") expect(routerModels).toHaveProperty("requesty") expect(routerModels).toHaveProperty("deepseek") expect(routerModels.deepseek).toEqual({}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 1dc53600cc..8d36b0a23e 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -370,6 +370,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { // Verify getModels was called for each provider expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "umans" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" }) expect(mockGetModels).toHaveBeenCalledWith( expect.objectContaining({ @@ -390,6 +391,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { type: "routerModels", routerModels: { openrouter: mockModels, + umans: mockModels, requesty: mockModels, unbound: mockModels, "vercel-ai-gateway": mockModels, @@ -541,6 +543,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { type: "routerModels", routerModels: { openrouter: mockModels, + umans: mockModels, requesty: mockModels, unbound: mockModels, "vercel-ai-gateway": mockModels, @@ -569,6 +572,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { // Mock some providers to succeed and others to fail mockGetModels .mockResolvedValueOnce(mockModels) // openrouter + .mockResolvedValueOnce(mockModels) // umans .mockRejectedValueOnce(new Error("Requesty API error")) // requesty .mockResolvedValueOnce(mockModels) // unbound .mockResolvedValueOnce(mockModels) // vercel-ai-gateway @@ -600,6 +604,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { type: "routerModels", routerModels: { openrouter: mockModels, + umans: mockModels, requesty: {}, unbound: mockModels, "vercel-ai-gateway": mockModels, @@ -619,11 +624,13 @@ describe("webviewMessageHandler - requestRouterModels", () => { // Mock providers to fail with different error types mockGetModels .mockRejectedValueOnce(new Error("Structured error message")) // openrouter + .mockRejectedValueOnce(new Error("Umans API error")) // umans .mockRejectedValueOnce(new Error("Requesty API error")) // requesty .mockRejectedValueOnce(new Error("Unbound error")) // unbound .mockRejectedValueOnce(new Error("Vercel AI Gateway error")) // vercel-ai-gateway .mockRejectedValueOnce(new Error("Zoo Gateway error")) // zoo-gateway .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm + .mockResolvedValueOnce({}) // opencode-go await webviewMessageHandler(mockClineProvider, { type: "requestRouterModels", @@ -637,6 +644,13 @@ describe("webviewMessageHandler - requestRouterModels", () => { values: { provider: "openrouter" }, }) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Umans API error", + values: { provider: "umans" }, + }) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ type: "singleRouterModelFetchResponse", success: false, diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 5fbe0acfa9..e9f09b0504 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -1032,6 +1032,7 @@ export const webviewMessageHandler = async ( ? ({} as Record) : { openrouter: {}, + umans: {}, "vercel-ai-gateway": {}, "zoo-gateway": {}, litellm: {}, @@ -1060,6 +1061,7 @@ export const webviewMessageHandler = async ( // Base candidates (only those handled by this aggregate fetcher) const candidates: { key: RouterName; options: GetModelsOptions }[] = [ { key: "openrouter", options: { provider: "openrouter" } }, + { key: "umans", options: { provider: "umans" } }, { key: "requesty", options: { diff --git a/src/shared/ProfileValidator.ts b/src/shared/ProfileValidator.ts index 7246a90177..bbdedecb82 100644 --- a/src/shared/ProfileValidator.ts +++ b/src/shared/ProfileValidator.ts @@ -53,6 +53,8 @@ export class ProfileValidator { switch (profile.apiProvider) { case "openai": return profile.openAiModelId + case "anthropic-custom": + return profile.anthropicCustomModelId case "anthropic": case "openai-native": case "bedrock": @@ -73,6 +75,8 @@ export class ProfileValidator { return profile.vsCodeLmModelSelector?.id case "openrouter": return profile.openRouterModelId + case "umans": + return profile.umansModelId case "ollama": return profile.ollamaModelId case "requesty": diff --git a/src/shared/__tests__/ProfileValidator.spec.ts b/src/shared/__tests__/ProfileValidator.spec.ts index 9bf913cdc2..500d363346 100644 --- a/src/shared/__tests__/ProfileValidator.spec.ts +++ b/src/shared/__tests__/ProfileValidator.spec.ts @@ -273,6 +273,21 @@ describe("ProfileValidator", () => { expect(ProfileValidator.isProfileAllowed(profile, allowList)).toBe(true) }) + it("should extract umansModelId for umans provider", () => { + const allowList: OrganizationAllowList = { + allowAll: false, + providers: { + umans: { allowAll: false, models: ["umans-coder"] }, + }, + } + const profile: ProviderSettings = { + apiProvider: "umans", + umansModelId: "umans-coder", + } + + expect(ProfileValidator.isProfileAllowed(profile, allowList)).toBe(true) + }) + it("should handle providers with undefined models list gracefully", () => { const allowList: OrganizationAllowList = { allowAll: false, diff --git a/src/shared/api.ts b/src/shared/api.ts index c0db55f661..153a804374 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -177,6 +177,7 @@ type CommonFetchParams = { // until a corresponding entry is added here. const dynamicProviderExtras = { openrouter: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type + umans: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type "vercel-ai-gateway": {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type "zoo-gateway": {} as { apiKey?: string; baseUrl?: string }, litellm: {} as { apiKey: string; baseUrl: string }, diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 70617a1ee6..55d22ec816 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -48,6 +48,7 @@ import { import { Anthropic, + AnthropicCustom, Baseten, Bedrock, DeepSeek, @@ -64,6 +65,7 @@ import { Poe, QwenCode, Requesty, + Umans, SambaNova, Unbound, Vertex, @@ -448,6 +450,17 @@ const ApiOptions = ({ /> )} + {selectedProvider === "umans" && ( + + )} + {selectedProvider === "unbound" && ( )} + {selectedProvider === "anthropic-custom" && ( + + )} + {selectedProvider === "openai-codex" && ( ( + field: K, + value: ProviderSettings[K], + isUserAction?: boolean, + ) => void + organizationAllowList: OrganizationAllowList + modelValidationError?: string + simplifySettings?: boolean +} + +const anthropicCustomDefaultModelId = "claude-sonnet-4-5" + +export const AnthropicCustom = ({ + apiConfiguration, + setApiConfigurationField, + organizationAllowList, + modelValidationError, + simplifySettings, +}: AnthropicCustomProps) => { + const { t } = useAppTranslation() + + const [anthropicBaseUrlSelected, setAnthropicBaseUrlSelected] = useState(!!apiConfiguration?.anthropicCustomBaseUrl) + + useEffect(() => { + if (!apiConfiguration.anthropicCustomModelInfo) { + setApiConfigurationField( + "anthropicCustomModelInfo", + { + ...openAiModelInfoSaneDefaults, + ...(anthropicModels[anthropicCustomDefaultModelId] || {}), + }, + false, + ) + } + }, [apiConfiguration.anthropicCustomModelInfo, setApiConfigurationField]) + + const handleInputChange = useCallback( + ( + field: K, + transform: (event: E) => ProviderSettings[K] = inputEventTransform, + ) => + (event: E | Event) => { + setApiConfigurationField(field, transform(event as E)) + }, + [setApiConfigurationField], + ) + + const getCustomModelInfo = () => apiConfiguration?.anthropicCustomModelInfo || openAiModelInfoSaneDefaults + + return ( + <> + + + +
+ {t("settings:providers.apiKeyStorageNotice")} +
+ {!apiConfiguration?.anthropicCustomApiKey && ( + + {t("settings:providers.getAnthropicApiKey")} + + )} +
+ { + setAnthropicBaseUrlSelected(checked) + + if (!checked) { + setApiConfigurationField("anthropicCustomBaseUrl", "") + } + }}> + {t("settings:providers.useCustomBaseUrl")} + + {anthropicBaseUrlSelected && ( + + + + )} +
+ { + setApiConfigurationField(field, value, isUserAction) + + if (field === "anthropicCustomModelId") { + setApiConfigurationField( + "anthropicCustomModelInfo", + { + ...openAiModelInfoSaneDefaults, + ...(anthropicModels[value as keyof typeof anthropicModels] || {}), + }, + false, + ) + } + }} + defaultModelId={anthropicCustomDefaultModelId} + models={anthropicModels} + modelIdKey="anthropicCustomModelId" + serviceName="Anthropic" + serviceUrl="https://docs.anthropic.com" + organizationAllowList={organizationAllowList} + errorMessage={modelValidationError} + simplifySettings={simplifySettings} + /> + +
+
+ {t("settings:providers.customModel.capabilities")} +
+ +
+ { + const value = parseInt((e.target as HTMLInputElement).value) + + return { + ...getCustomModelInfo(), + maxTokens: isNaN(value) ? undefined : value, + } + })} + placeholder={t("settings:placeholders.numbers.maxTokens")} + className="w-full"> + + +
+ {t("settings:providers.customModel.maxTokens.description")} +
+
+ +
+ { + const value = parseInt((e.target as HTMLInputElement).value) + + return { + ...getCustomModelInfo(), + contextWindow: isNaN(value) ? openAiModelInfoSaneDefaults.contextWindow : value, + } + })} + placeholder={t("settings:placeholders.numbers.contextWindow")} + className="w-full"> + + +
+ {t("settings:providers.customModel.contextWindow.description")} +
+
+ +
+
+ ({ + ...getCustomModelInfo(), + supportsImages: checked, + }))}> + + {t("settings:providers.customModel.imageSupport.label")} + + + + + +
+
+ +
+
+ ({ + ...getCustomModelInfo(), + supportsPromptCache: checked, + }))}> + {t("settings:providers.customModel.promptCache.label")} + + + + +
+
+ +
+ { + const parsed = parseFloat((e.target as HTMLInputElement).value) + + return { + ...getCustomModelInfo(), + inputPrice: isNaN(parsed) ? openAiModelInfoSaneDefaults.inputPrice : parsed, + } + })} + placeholder={t("settings:placeholders.numbers.inputPrice")} + className="w-full"> +
+ + + + +
+
+
+ +
+ { + const parsed = parseFloat((e.target as HTMLInputElement).value) + + return { + ...getCustomModelInfo(), + outputPrice: isNaN(parsed) ? openAiModelInfoSaneDefaults.outputPrice : parsed, + } + })} + placeholder={t("settings:placeholders.numbers.outputPrice")} + className="w-full"> +
+ + + + +
+
+
+ + {getCustomModelInfo().supportsPromptCache && ( + <> +
+ { + const parsed = parseFloat((e.target as HTMLInputElement).value) + + return { + ...getCustomModelInfo(), + cacheReadsPrice: isNaN(parsed) ? 0 : parsed, + } + })} + placeholder={t("settings:placeholders.numbers.inputPrice")} + className="w-full"> + + {t("settings:providers.customModel.pricing.cacheReads.label")} + + +
+
+ { + const parsed = parseFloat((e.target as HTMLInputElement).value) + + return { + ...getCustomModelInfo(), + cacheWritesPrice: isNaN(parsed) ? 0 : parsed, + } + })} + placeholder={t("settings:placeholders.numbers.cacheWritePrice")} + className="w-full"> + + +
+ + )} + + +
+ + ) +} diff --git a/webview-ui/src/components/settings/providers/Umans.tsx b/webview-ui/src/components/settings/providers/Umans.tsx new file mode 100644 index 0000000000..54df3b77e9 --- /dev/null +++ b/webview-ui/src/components/settings/providers/Umans.tsx @@ -0,0 +1,85 @@ +import { useCallback } from "react" +import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" + +import { + type ProviderSettings, + type OrganizationAllowList, + type RouterModels, + umansDefaultModelId, +} from "@roo-code/types" + +import { vscode } from "@src/utils/vscode" +import { useAppTranslation } from "@src/i18n/TranslationContext" +import { Button } from "@src/components/ui" + +import { inputEventTransform } from "../transforms" +import { ModelPicker } from "../ModelPicker" + +type UmansProps = { + apiConfiguration: ProviderSettings + setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void + routerModels?: RouterModels + organizationAllowList: OrganizationAllowList + modelValidationError?: string + simplifySettings?: boolean +} + +export const Umans = ({ + apiConfiguration, + setApiConfigurationField, + routerModels, + organizationAllowList, + modelValidationError, + simplifySettings, +}: UmansProps) => { + const { t } = useAppTranslation() + + const handleInputChange = useCallback( + ( + field: K, + transform: (event: E) => ProviderSettings[K] = inputEventTransform, + ) => + (event: E | Event) => { + setApiConfigurationField(field, transform(event as E)) + }, + [setApiConfigurationField], + ) + + return ( + <> + + + +
+ {t("settings:providers.apiKeyStorageNotice")} +
+ + + + ) +} diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index d5dd0d0ded..22071ac104 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -1,4 +1,5 @@ export { Anthropic } from "./Anthropic" +export { AnthropicCustom } from "./AnthropicCustom" export { Bedrock } from "./Bedrock" export { DeepSeek } from "./DeepSeek" export { Gemini } from "./Gemini" @@ -13,6 +14,7 @@ export { OpenRouter } from "./OpenRouter" export { Poe } from "./Poe" export { QwenCode } from "./QwenCode" export { Requesty } from "./Requesty" +export { Umans } from "./Umans" export { SambaNova } from "./SambaNova" export { Unbound } from "./Unbound" export { Vertex } from "./Vertex" diff --git a/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts b/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts index db3581634a..5019c81fa1 100644 --- a/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts +++ b/webview-ui/src/components/settings/utils/__tests__/providerModelConfig.spec.ts @@ -25,6 +25,13 @@ describe("providerModelConfig", () => { }) }) + it("contains service config for umans", () => { + expect(PROVIDER_SERVICE_CONFIG.umans).toEqual({ + serviceName: "Umans", + serviceUrl: "https://api.code.umans.ai/v1/models/info", + }) + }) + it("contains service config for ollama", () => { expect(PROVIDER_SERVICE_CONFIG.ollama).toEqual({ serviceName: "Ollama", @@ -64,6 +71,7 @@ describe("providerModelConfig", () => { describe("PROVIDER_DEFAULT_MODEL_IDS", () => { it("contains default model IDs for static providers", () => { expect(PROVIDER_DEFAULT_MODEL_IDS.anthropic).toBeDefined() + expect(PROVIDER_DEFAULT_MODEL_IDS.umans).toBe("umans-coder") expect(PROVIDER_DEFAULT_MODEL_IDS.bedrock).toBeDefined() expect(PROVIDER_DEFAULT_MODEL_IDS.gemini).toBeDefined() expect(PROVIDER_DEFAULT_MODEL_IDS["openai-native"]).toBeDefined() @@ -187,6 +195,7 @@ describe("providerModelConfig", () => { it("returns false for providers with custom model UI", () => { expect(shouldUseGenericModelPicker("openrouter")).toBe(false) + expect(shouldUseGenericModelPicker("umans")).toBe(false) expect(shouldUseGenericModelPicker("ollama")).toBe(false) expect(shouldUseGenericModelPicker("lmstudio")).toBe(false) expect(shouldUseGenericModelPicker("vscode-lm")).toBe(false) diff --git a/webview-ui/src/components/settings/utils/providerModelConfig.ts b/webview-ui/src/components/settings/utils/providerModelConfig.ts index 9cc9dafa01..056938f717 100644 --- a/webview-ui/src/components/settings/utils/providerModelConfig.ts +++ b/webview-ui/src/components/settings/utils/providerModelConfig.ts @@ -21,6 +21,7 @@ import { mimoDefaultModelId, poeDefaultModelId, requestyDefaultModelId, + umansDefaultModelId, unboundDefaultModelId, litellmDefaultModelId, vercelAiGatewayDefaultModelId, @@ -37,6 +38,7 @@ export interface ProviderServiceConfig { export const PROVIDER_SERVICE_CONFIG: Partial> = { anthropic: { serviceName: "Anthropic", serviceUrl: "https://console.anthropic.com" }, + umans: { serviceName: "Umans", serviceUrl: "https://api.code.umans.ai/v1/models/info" }, bedrock: { serviceName: "Amazon Bedrock", serviceUrl: "https://aws.amazon.com/bedrock" }, deepseek: { serviceName: "DeepSeek", serviceUrl: "https://platform.deepseek.com" }, moonshot: { serviceName: "Moonshot", serviceUrl: "https://platform.moonshot.cn" }, @@ -62,6 +64,7 @@ export const PROVIDER_SERVICE_CONFIG: Partial> = { anthropic: anthropicDefaultModelId, + umans: umansDefaultModelId, bedrock: bedrockDefaultModelId, deepseek: deepSeekDefaultModelId, moonshot: moonshotDefaultModelId, @@ -103,6 +106,7 @@ export type ProviderModelConfig = { // Kept in this file to keep ApiOptions.tsx from growing a second registry. const PROVIDER_MODEL_CONFIG: Partial> = { openrouter: { field: "openRouterModelId", default: openRouterDefaultModelId }, + umans: { field: "umansModelId", default: umansDefaultModelId }, requesty: { field: "requestyModelId", default: requestyDefaultModelId }, unbound: { field: "unboundModelId", default: unboundDefaultModelId }, litellm: { field: "litellmModelId", default: litellmDefaultModelId }, @@ -150,6 +154,7 @@ export function getProviderModelConfig(provider: string, apiConfiguration?: Prov const PROVIDER_DOCS_SLUGS: Partial> = { "openai-native": "openai", openai: "openai-compatible", + umans: "openai-compatible", } export function getProviderDocsSlug(provider: string) { @@ -191,6 +196,7 @@ export const isStaticModelProvider = (provider: ProviderName): boolean => { */ export const PROVIDERS_WITH_CUSTOM_MODEL_UI: ProviderName[] = [ "openrouter", + "umans", "requesty", "unbound", "openai", // OpenAI Compatible diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index 0dc42129c0..6f1731d686 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -645,6 +645,82 @@ describe("useSelectedModel", () => { }) }) + describe("umans provider", () => { + beforeEach(() => { + mockUseOpenRouterModelProviders.mockReturnValue({ + data: {}, + isLoading: false, + isError: false, + } as any) + }) + + it("returns router model info when a Umans model exists", () => { + const umansModelInfo: ModelInfo = { + maxTokens: 32768, + contextWindow: 262144, + supportsImages: true, + supportsPromptCache: false, + description: "Umans Coder", + } + + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + umans: { + "umans-coder": umansModelInfo, + }, + requesty: {}, + litellm: {}, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "umans", + umansModelId: "umans-coder", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("umans") + expect(result.current.id).toBe("umans-coder") + expect(result.current.info).toEqual(umansModelInfo) + }) + + it("falls back to the Umans default when the selected model is unavailable", () => { + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + umans: { + "umans-coder": { + maxTokens: 32768, + contextWindow: 262144, + supportsImages: true, + supportsPromptCache: false, + }, + }, + requesty: {}, + litellm: {}, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "umans", + umansModelId: "missing-model", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("umans") + expect(result.current.id).toBe("umans-coder") + }) + }) + describe("openai provider", () => { beforeEach(() => { mockUseRouterModels.mockReturnValue({ diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index d3ebb6c0dd..c14153c729 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -161,6 +161,11 @@ function getSelectedModel({ const routerInfo = routerModels.requesty?.[id] return { id, info: routerInfo } } + case "umans": { + const id = getValidatedModelId(apiConfiguration.umansModelId, routerModels.umans, defaultModelId) + const routerInfo = routerModels.umans?.[id] + return { id, info: routerInfo } + } case "unbound": { const id = getValidatedModelId(apiConfiguration.unboundModelId, routerModels.unbound, defaultModelId) const routerInfo = routerModels.unbound?.[id] @@ -370,8 +375,15 @@ function getSelectedModel({ // case "anthropic": // case "fake-ai": default: { - provider satisfies "anthropic" | "gemini-cli" | "fake-ai" - const id = apiConfiguration.apiModelId ?? defaultModelId + provider satisfies "anthropic" | "anthropic-custom" | "gemini-cli" | "fake-ai" + const id = apiConfiguration.apiModelId ?? apiConfiguration.anthropicCustomModelId ?? defaultModelId + + // For anthropic-custom, use custom model info if available + if (provider === "anthropic-custom") { + const info = apiConfiguration.anthropicCustomModelInfo || undefined + return { id, info } + } + const baseInfo = anthropicModels[id as keyof typeof anthropicModels] // Apply 1M context beta tier pricing for supported Claude 4 models diff --git a/webview-ui/src/utils/__tests__/validate.spec.ts b/webview-ui/src/utils/__tests__/validate.spec.ts index 5d4f54b927..d082cbf6ba 100644 --- a/webview-ui/src/utils/__tests__/validate.spec.ts +++ b/webview-ui/src/utils/__tests__/validate.spec.ts @@ -43,6 +43,7 @@ describe("Model Validation Functions", () => { outputPrice: 5.0, }, }, + umans: {}, requesty: {}, unbound: {}, litellm: {}, @@ -155,6 +156,16 @@ describe("Model Validation Functions", () => { expect(result).toBe("settings:validation.apiKey") }) + it("returns an apiKey error for Umans when the API key is missing", () => { + const config: ProviderSettings = { + apiProvider: "umans", + umansModelId: "umans-coder", + } + + const result = validateApiConfigurationExcludingModelErrors(config, mockRouterModels, allowAllOrganization) + expect(result).toBe("settings:validation.apiKey") + }) + it("excludes model-specific errors", () => { const config: ProviderSettings = { apiProvider: "openrouter", diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 3de6480802..44c0889a0c 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -52,6 +52,11 @@ function validateModelsAndKeysProvided( return i18next.t("settings:validation.apiKey") } break + case "umans": + if (!apiConfiguration.umansApiKey) { + return i18next.t("settings:validation.apiKey") + } + break case "unbound": if (!apiConfiguration.unboundApiKey) { return i18next.t("settings:validation.apiKey") @@ -97,6 +102,14 @@ function validateModelsAndKeysProvided( return i18next.t("settings:validation.openAi") } break + case "anthropic-custom": + if (!apiConfiguration.anthropicCustomApiKey) { + return i18next.t("settings:validation.apiKey") + } + if (!apiConfiguration.anthropicCustomModelId) { + return i18next.t("settings:validation.modelId") + } + break case "ollama": if (!apiConfiguration.ollamaModelId) { return i18next.t("settings:validation.modelId") @@ -195,6 +208,10 @@ function getModelIdForProvider(apiConfiguration: ProviderSettings, provider: Pro return apiConfiguration.vsCodeLmModelSelector?.id } + if (provider === "anthropic-custom") { + return apiConfiguration.anthropicCustomModelId + } + if (isCustomProvider(provider) || isFauxProvider(provider)) { return apiConfiguration.apiModelId }