diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index b5f0755bee..b68b616ea9 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -4,6 +4,9 @@ import type { Mock } from "vitest" // Mock dependencies - must come before imports vi.mock("../../../api/providers/fetchers/modelCache") +vi.mock("../../../api/providers/fetchers/lmstudio", () => ({ + getLMStudioModels: vi.fn(), +})) vi.mock("../../../integrations/openai-codex/oauth", () => ({ openAiCodexOAuthManager: { @@ -42,11 +45,13 @@ import type { ModelRecord } from "@roo-code/types" import { webviewMessageHandler } from "../webviewMessageHandler" import type { ClineProvider } from "../ClineProvider" import { getModels } from "../../../api/providers/fetchers/modelCache" +import { getLMStudioModels } from "../../../api/providers/fetchers/lmstudio" import { getCommands } from "../../../services/command/commands" const { openAiCodexOAuthManager } = await import("../../../integrations/openai-codex/oauth") const { fetchOpenAiCodexRateLimitInfo } = await import("../../../integrations/openai-codex/rate-limits") const mockGetModels = getModels as Mock +const mockGetLMStudioModels = getLMStudioModels as Mock const mockGetCommands = vi.mocked(getCommands) const mockGetAccessToken = vi.mocked(openAiCodexOAuthManager.getAccessToken) const mockGetAccountId = vi.mocked(openAiCodexOAuthManager.getAccountId) @@ -166,6 +171,7 @@ import { resolveImageMentions } from "../../mentions/resolveImageMentions" describe("webviewMessageHandler - requestLmStudioModels", () => { beforeEach(() => { vi.clearAllMocks() + mockGetLMStudioModels.mockReset() mockClineProvider.getState = vi.fn().mockResolvedValue({ apiConfiguration: { lmStudioModelId: "model-1", @@ -203,6 +209,30 @@ describe("webviewMessageHandler - requestLmStudioModels", () => { lmStudioModels: mockModels, }) }) + + it("prefers the request payload base URL over persisted settings", async () => { + mockGetLMStudioModels.mockResolvedValue({}) + + await webviewMessageHandler(mockClineProvider, { + type: "requestLmStudioModels", + values: { baseUrl: "http://127.0.0.1:4321" }, + }) + + expect(mockGetLMStudioModels).toHaveBeenCalledWith("http://127.0.0.1:4321") + expect(mockGetModels).not.toHaveBeenCalled() + }) + + it("treats an empty-string base URL as an explicit preview request", async () => { + mockGetLMStudioModels.mockResolvedValue({}) + + await webviewMessageHandler(mockClineProvider, { + type: "requestLmStudioModels", + values: { baseUrl: "" }, + }) + + expect(mockGetLMStudioModels).toHaveBeenCalledWith("") + expect(mockGetModels).not.toHaveBeenCalled() + }) }) describe("webviewMessageHandler - image mentions", () => { diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 4d6733f981..dc029cb7dd 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -72,6 +72,7 @@ import { GetModelsOptions } from "../../shared/api" import { generateSystemPrompt } from "./generateSystemPrompt" import { resolveDefaultSaveUri, saveLastExportPath } from "../../utils/export" import { getCommand } from "../../utils/commands" +import { getLMStudioModels } from "../../api/providers/fetchers/lmstudio" const ALLOWED_VSCODE_SETTINGS = new Set(["terminal.integrated.inheritEnv"]) @@ -1086,14 +1087,20 @@ export const webviewMessageHandler = async ( // Specific handler for LM Studio models only. const { apiConfiguration: lmStudioApiConfig } = await provider.getState() try { - const lmStudioOptions = { - provider: "lmstudio" as const, - baseUrl: lmStudioApiConfig.lmStudioBaseUrl, + const requestedBaseUrl = message.values?.baseUrl + const hasPreviewBaseUrl = typeof requestedBaseUrl === "string" + let lmStudioModels: ModelRecord + if (hasPreviewBaseUrl) { + lmStudioModels = await getLMStudioModels(requestedBaseUrl) + } else { + const lmStudioOptions = { + provider: "lmstudio" as const, + baseUrl: lmStudioApiConfig.lmStudioBaseUrl, + } + // Flush cache and refresh to ensure fresh models. + await flushModels(lmStudioOptions, true) + lmStudioModels = await getModels(lmStudioOptions) } - // Flush cache and refresh to ensure fresh models. - await flushModels(lmStudioOptions, true) - - const lmStudioModels = await getModels(lmStudioOptions) if (Object.keys(lmStudioModels).length > 0) { provider.postMessageToWebview({ diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index f598b707ed..8981f8d3b7 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -47,6 +47,7 @@ import { validateApiConfigurationExcludingModelErrors, getModelValidationError } import { useAppTranslation } from "@src/i18n/TranslationContext" import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" import { useSelectedModel } from "@src/components/ui/hooks/useSelectedModel" +import { requestLmStudioModels } from "@src/components/ui/hooks/useLmStudioModels" import { useExtensionState } from "@src/context/ExtensionStateContext" import { useOpenRouterModelProviders, @@ -236,7 +237,7 @@ const ApiOptions = ({ } else if (selectedProvider === "ollama") { vscode.postMessage({ type: "requestOllamaModels" }) } else if (selectedProvider === "lmstudio") { - vscode.postMessage({ type: "requestLmStudioModels" }) + requestLmStudioModels(apiConfiguration?.lmStudioBaseUrl) } else if (selectedProvider === "vscode-lm") { vscode.postMessage({ type: "requestVsCodeLmModels" }) } else if (selectedProvider === "litellm" || selectedProvider === "poe") { diff --git a/webview-ui/src/components/settings/providers/LMStudio.tsx b/webview-ui/src/components/settings/providers/LMStudio.tsx index 48eab8d9da..786c3f4474 100644 --- a/webview-ui/src/components/settings/providers/LMStudio.tsx +++ b/webview-ui/src/components/settings/providers/LMStudio.tsx @@ -1,4 +1,4 @@ -import { useCallback, useState, useMemo, useEffect } from "react" +import { useCallback, useState, useMemo, useEffect, useRef } from "react" import { useEvent } from "react-use" import { Trans } from "react-i18next" import { Checkbox } from "vscrui" @@ -7,8 +7,8 @@ import { VSCodeLink, VSCodeTextField } from "@vscode/webview-ui-toolkit/react" import type { ProviderSettings, ExtensionMessage, ModelRecord } from "@roo-code/types" import { useAppTranslation } from "@src/i18n/TranslationContext" +import { requestLmStudioModels } from "@src/components/ui/hooks/useLmStudioModels" import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" -import { vscode } from "@src/utils/vscode" import { inputEventTransform } from "../transforms" import { ModelPicker } from "../ModelPicker" @@ -23,6 +23,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi const [lmStudioModels, setLmStudioModels] = useState({}) const routerModels = useRouterModels() + const initialBaseUrlRef = useRef(apiConfiguration?.lmStudioBaseUrl) const handleInputChange = useCallback( ( @@ -53,7 +54,7 @@ export const LMStudio = ({ apiConfiguration, setApiConfigurationField }: LMStudi // Refresh models on mount useEffect(() => { // Request fresh models - the handler now flushes cache automatically - vscode.postMessage({ type: "requestLmStudioModels" }) + requestLmStudioModels(initialBaseUrlRef.current) }, []) // Check if the selected model exists in the fetched models diff --git a/webview-ui/src/components/ui/hooks/__tests__/useLmStudioModels.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useLmStudioModels.spec.ts new file mode 100644 index 0000000000..a2f74a0ad6 --- /dev/null +++ b/webview-ui/src/components/ui/hooks/__tests__/useLmStudioModels.spec.ts @@ -0,0 +1,33 @@ +vi.mock("@src/utils/vscode", () => ({ + vscode: { + postMessage: vi.fn(), + }, +})) + +import { vscode } from "@src/utils/vscode" + +import { requestLmStudioModels } from "../useLmStudioModels" + +describe("requestLmStudioModels", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("includes the current unsaved base URL when requesting models", () => { + requestLmStudioModels("http://127.0.0.1:1234") + + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "requestLmStudioModels", + values: { baseUrl: "http://127.0.0.1:1234" }, + }) + }) + + it("preserves an empty base URL so the extension can fall back to the default", () => { + requestLmStudioModels("") + + expect(vscode.postMessage).toHaveBeenCalledWith({ + type: "requestLmStudioModels", + values: { baseUrl: "" }, + }) + }) +}) diff --git a/webview-ui/src/components/ui/hooks/useLmStudioModels.ts b/webview-ui/src/components/ui/hooks/useLmStudioModels.ts index 29f50cb0e8..79e46fe6cf 100644 --- a/webview-ui/src/components/ui/hooks/useLmStudioModels.ts +++ b/webview-ui/src/components/ui/hooks/useLmStudioModels.ts @@ -4,7 +4,13 @@ import { type ModelRecord, type ExtensionMessage } from "@roo-code/types" import { vscode } from "@src/utils/vscode" -const getLmStudioModels = async () => +export const requestLmStudioModels = (baseUrl?: string) => + vscode.postMessage({ + type: "requestLmStudioModels", + values: typeof baseUrl === "string" ? { baseUrl } : undefined, + }) + +const getLmStudioModels = async (baseUrl?: string) => new Promise((resolve, reject) => { const cleanup = () => { window.removeEventListener("message", handler) @@ -31,8 +37,11 @@ const getLmStudioModels = async () => } window.addEventListener("message", handler) - vscode.postMessage({ type: "requestLmStudioModels" }) + requestLmStudioModels(baseUrl) }) export const useLmStudioModels = (modelId?: string) => - useQuery({ queryKey: ["lmStudioModels"], queryFn: () => (modelId ? getLmStudioModels() : {}) }) + useQuery({ + queryKey: ["lmStudioModels"], + queryFn: () => (modelId ? getLmStudioModels() : {}), + }) diff --git a/webview-ui/src/i18n/locales/en/chat.json b/webview-ui/src/i18n/locales/en/chat.json index 2578b939b1..b682954f2b 100644 --- a/webview-ui/src/i18n/locales/en/chat.json +++ b/webview-ui/src/i18n/locales/en/chat.json @@ -351,7 +351,7 @@ "support": "Please support Zoo Code by starring us on GitHub.", "handoff": { "heading": "Roo Code is back! Now as a community-maintained plugin called Zoo Code!!", - "description": "If you haven't been following, the Roo Code team recently announced they are sun setting the development of Roo Code and are archiving the work they have done. But fear not, the community has stepped up to continue the legacy of Roo Code with a new name and a new home! We are not just a single \"Roo\" anymore, we are a community, a \"Zoo\" if you will, Zoo Code is a community-maintained plugin that picks up where Zoo Code left off, and we're committed to keeping the spirit of Roo alive while also introducing new features and improvements. We want to give a huge shoutout to the entire Roo Code team for their incredible work and for creating such an amazing tool for developers. We're excited to continue building on their foundation and to see where the community takes Zoo Code in the future!", + "description": "If you haven't been following, the Roo Code team recently announced they are sun setting the development of Roo Code and are archiving the work they have done. But fear not, the community has stepped up to continue the legacy of Roo Code with a new name and a new home! We are not just a single \"Roo\" anymore, we are a community, a \"Zoo\" if you will, Zoo Code is a community-maintained plugin that picks up where Roo Code left off, and we're committed to keeping the spirit of Roo alive while also introducing new features and improvements. We want to give a huge shoutout to the entire Roo Code team for their incredible work and for creating such an amazing tool for developers. We're excited to continue building on their foundation and to see where the community takes Zoo Code in the future!", "readMore": "See the new home page of Zoo Code and read the full announcement" }, "release": {