diff --git a/src/commands/pr_comments.ts b/src/commands/pr_comments.ts deleted file mode 100644 index dd86cb32..00000000 --- a/src/commands/pr_comments.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { Command } from '@commands' - -export default { - type: 'prompt', - name: 'pr-comments', - description: 'Get comments from a GitHub pull request', - progressMessage: 'fetching PR comments', - isEnabled: true, - isHidden: false, - userFacingName() { - return 'pr-comments' - }, - async getPromptForCommand(args: string) { - return [ - { - role: 'user', - content: [ - { - type: 'text', - text: `You are an AI assistant integrated into a git-based version control system. Your task is to fetch and display comments from a GitHub pull request. - -Follow these steps: - -1. Use \`gh pr view --json number,headRepository\` to get the PR number and repository info -2. Use \`gh api /repos/{owner}/{repo}/issues/{number}/comments\` to get PR-level comments -3. Use \`gh api /repos/{owner}/{repo}/pulls/{number}/comments\` to get review comments. Pay particular attention to the following fields: \`body\`, \`diff_hunk\`, \`path\`, \`line\`, etc. If the comment references some code, consider fetching it using eg \`gh api /repos/{owner}/{repo}/contents/{path}?ref={branch} | jq .content -r | base64 -d\` -4. Parse and format all comments in a readable way -5. Return ONLY the formatted comments, with no additional text - -Format the comments as: - -## Comments - -[For each comment thread:] -- @author file.ts#line: - \`\`\`diff - [diff_hunk from the API response] - \`\`\` - > quoted comment text - - [any replies indented] - -If there are no comments, return "No comments found." - -Remember: -1. Only show the actual comments, no explanatory text -2. Include both PR-level and code review comments -3. Preserve the threading/nesting of comment replies -4. Show the file and line number context for code review comments -5. Use jq to parse the JSON responses from the GitHub API - -${args ? 'Additional user input: ' + args : ''} -`, - }, - ], - }, - ] - }, -} satisfies Command diff --git a/src/commands/refreshCommands.ts b/src/commands/refreshCommands.ts deleted file mode 100644 index b822a64e..00000000 --- a/src/commands/refreshCommands.ts +++ /dev/null @@ -1,54 +0,0 @@ -import { Command } from '@commands' -import { reloadCustomCommands } from '@services/customCommands' -import { getCommands } from '@commands' - -/** - * Refresh Commands - Reload custom commands from filesystem - * - * This command provides a runtime mechanism to refresh the custom commands - * cache without restarting the application. It's particularly useful during - * development or when users are actively creating/modifying custom commands. - * - * The command follows the standard local command pattern used throughout - * the project and provides detailed feedback about the refresh operation. - */ -const refreshCommands = { - type: 'local', - name: 'refresh-commands', - description: 'Reload custom commands from filesystem', - isEnabled: true, - isHidden: false, - async call(_, context) { - try { - // Clear custom commands cache to force filesystem rescan - reloadCustomCommands() - - // Clear the main commands cache to ensure full reload - // This ensures that changes to custom commands are reflected in the main command list - getCommands.cache.clear?.() - - // Reload commands to get updated count and validate the refresh - const commands = await getCommands() - const customCommands = commands.filter( - cmd => cmd.name.startsWith('project:') || cmd.name.startsWith('user:'), - ) - - // Provide detailed feedback about the refresh operation - return `✅ Commands refreshed successfully! - -Custom commands reloaded: ${customCommands.length} -- Project commands: ${customCommands.filter(cmd => cmd.name.startsWith('project:')).length} -- User commands: ${customCommands.filter(cmd => cmd.name.startsWith('user:')).length} - -Use /help to see updated command list.` - } catch (error) { - console.error('Failed to refresh commands:', error) - return '❌ Failed to refresh commands. Check console for details.' - } - }, - userFacingName() { - return 'refresh-commands' - }, -} satisfies Command - -export default refreshCommands diff --git a/src/ui/components/ModelListManager.tsx b/src/ui/components/ModelListManager.tsx index 4699f457..192fd4b1 100644 --- a/src/ui/components/ModelListManager.tsx +++ b/src/ui/components/ModelListManager.tsx @@ -17,6 +17,7 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { const theme = getTheme() const [selectedIndex, setSelectedIndex] = useState(0) const [showModelSelector, setShowModelSelector] = useState(false) + const [editingModelName, setEditingModelName] = useState(null) const [isDeleteMode, setIsDeleteMode] = useState(false) const [refreshKey, setRefreshKey] = useState(0) const exitState = useExitOnCtrlCD(onClose) @@ -66,11 +67,19 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { } const handleAddNewModel = () => { + setEditingModelName(null) + setIsDeleteMode(false) + setShowModelSelector(true) + } + + const handleEditModel = (modelName: string) => { + setEditingModelName(modelName) setShowModelSelector(true) } const handleModelConfigurationComplete = () => { setShowModelSelector(false) + setEditingModelName(null) setRefreshKey(prev => prev + 1) } @@ -82,7 +91,7 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { } else { onClose() } - } else if (input === 'd' && !isDeleteMode && availableModels.length > 1) { + } else if (input === 'd' && !isDeleteMode && availableModels.length > 0) { setIsDeleteMode(true) } else if (key.upArrow) { setSelectedIndex(prev => Math.max(0, prev - 1)) @@ -92,17 +101,11 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { const item = menuItems[selectedIndex] if (isDeleteMode && item.type === 'model') { - if (availableModels.length <= 1) { - setIsDeleteMode(false) - return - } - if (config.modelPointers?.main === item.id) { - setIsDeleteMode(false) - return - } handleDeleteModel(item.id) } else if (item.type === 'action') { handleAddNewModel() + } else if (item.type === 'model') { + handleEditModel(item.id) } } }, @@ -112,12 +115,18 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { useInput(handleInput, { isActive: !showModelSelector }) if (showModelSelector) { + const editingModel = + editingModelName === null + ? undefined + : availableModels.find(model => model.modelName === editingModelName) + return ( ) @@ -140,11 +149,7 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { {isDeleteMode ? ( - availableModels.length <= 1 ? ( - 'Cannot delete the last model, Esc to cancel' - ) : ( - 'Press Enter/Space to DELETE selected model (cannot delete main), Esc to cancel' - ) + 'Press Enter/Space to DELETE selected model, Esc to cancel' ) : ( <> Navigate: ↑↓ | Select: Enter |{' '} @@ -204,16 +209,11 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { )} - {isSelected && - isDeleteMode && - item.type === 'model' && - config.modelPointers?.main === item.id && ( - - - Cannot delete: This model is currently set as main - - - )} + {isSelected && !isDeleteMode && item.type === 'model' && ( + + Edit this model configuration + + )} ) })} @@ -226,20 +226,16 @@ export function ModelListManager({ onClose }: Props): React.ReactNode { > {isDeleteMode ? ( - availableModels.length <= 1 ? ( - 'Cannot delete the last model - press Esc to cancel' - ) : ( - 'DELETE MODE: Press Enter/Space to delete (cannot delete main model), Esc to cancel' - ) - ) : availableModels.length <= 1 ? ( - 'Use ↑/↓ to navigate, Enter to add new, Esc to exit (cannot delete last model)' + 'DELETE MODE: Press Enter/Space to delete. Pointers will be reassigned or cleared.' + ) : availableModels.length === 0 ? ( + 'Use ↑/↓ to navigate, Enter to add new, Esc to exit' ) : ( <> Use ↑/↓ to navigate,{' '} d to delete model - , Enter to add new, Esc to exit + , Enter to add/edit, Esc to exit )} diff --git a/src/ui/components/model-selector/ModelSelector.tsx b/src/ui/components/model-selector/ModelSelector.tsx index b41335db..60488fb1 100644 --- a/src/ui/components/model-selector/ModelSelector.tsx +++ b/src/ui/components/model-selector/ModelSelector.tsx @@ -13,6 +13,7 @@ import { } from '@services/gpt5ConnectionTest' import { getGlobalConfig, + ModelProfile, ModelPointerType, ProviderType, saveGlobalConfig, @@ -105,6 +106,7 @@ type Props = { isOnboarding?: boolean onCancel?: () => void skipModelType?: boolean + initialModelProfile?: ModelProfile } export function ModelSelector({ @@ -114,6 +116,7 @@ export function ModelSelector({ isOnboarding = false, onCancel, skipModelType = false, + initialModelProfile, }: Props): React.ReactNode { const config = getGlobalConfig() const theme = getTheme() @@ -130,6 +133,9 @@ export function ModelSelector({ const exitState = useExitOnCtrlCD(() => process.exit(0)) const getInitialScreen = (): string => { + if (initialModelProfile) { + return 'modelParams' + } return 'provider' } @@ -178,28 +184,65 @@ export function ModelSelector({ } } + const initialMaxTokens = + initialModelProfile?.maxTokens ?? config.maxTokens ?? DEFAULT_MAX_TOKENS + const initialContextLength = + initialModelProfile?.contextLength ?? DEFAULT_CONTEXT_LENGTH + const initialMaxTokensMode = MAX_TOKENS_OPTIONS.some( + option => option.value === initialMaxTokens, + ) + ? 'preset' + : 'custom' + const [selectedProvider, setSelectedProvider] = useState( - config.primaryProvider ?? 'anthropic', + (initialModelProfile?.provider as ProviderType | undefined) ?? + config.primaryProvider ?? + 'anthropic', ) - const [selectedModel, setSelectedModel] = useState('') - const [apiKey, setApiKey] = useState('') + const [selectedModel, setSelectedModel] = useState( + initialModelProfile?.modelName ?? '', + ) + const [apiKey, setApiKey] = useState( + initialModelProfile?.apiKey ?? '', + ) const [maxTokens, setMaxTokens] = useState( - config.maxTokens?.toString() || DEFAULT_MAX_TOKENS.toString(), + initialMaxTokens.toString(), ) const [maxTokensMode, setMaxTokensMode] = useState<'preset' | 'custom'>( - 'preset', + initialMaxTokensMode, ) const [selectedMaxTokensPreset, setSelectedMaxTokensPreset] = - useState(config.maxTokens || DEFAULT_MAX_TOKENS) - const [reasoningEffort, setReasoningEffort] = - useState('medium') + useState(initialMaxTokens) + const [reasoningEffort, setReasoningEffort] = useState( + (initialModelProfile?.reasoningEffort as ReasoningEffortOption) ?? 'medium', + ) const [supportsReasoningEffort, setSupportsReasoningEffort] = - useState(false) + useState(Boolean(initialModelProfile?.reasoningEffort)) + + const [contextLength, setContextLength] = + useState(initialContextLength) - const [contextLength, setContextLength] = useState( - DEFAULT_CONTEXT_LENGTH, + const contextLengthOptions = useMemo(() => { + if (CONTEXT_LENGTH_OPTIONS.some(opt => opt.value === contextLength)) { + return CONTEXT_LENGTH_OPTIONS + } + + return [ + ...CONTEXT_LENGTH_OPTIONS, + { + label: `${contextLength.toLocaleString()} tokens (current)`, + value: contextLength, + }, + ].sort((a, b) => a.value - b.value) + }, [contextLength]) + + const getContextLengthLabel = useCallback( + (value: number) => + contextLengthOptions.find(opt => opt.value === value)?.label || + `${value.toLocaleString()} tokens`, + [contextLengthOptions], ) const [activeFieldIndex, setActiveFieldIndex] = useState(0) @@ -215,7 +258,9 @@ export function ModelSelector({ const [modelSearchCursorOffset, setModelSearchCursorOffset] = useState(0) const [cursorOffset, setCursorOffset] = useState(0) - const [apiKeyEdited, setApiKeyEdited] = useState(false) + const [apiKeyEdited, setApiKeyEdited] = useState( + Boolean(initialModelProfile), + ) const [providerFocusIndex, setProviderFocusIndex] = useState(0) const [partnerProviderFocusIndex, setPartnerProviderFocusIndex] = useState(0) const [codingPlanFocusIndex, setCodingPlanFocusIndex] = useState(0) @@ -236,21 +281,31 @@ export function ModelSelector({ const [resourceName, setResourceName] = useState('') const [resourceNameCursorOffset, setResourceNameCursorOffset] = useState(0) - const [customModelName, setCustomModelName] = useState('') + const [customModelName, setCustomModelName] = useState( + initialModelProfile?.modelName ?? '', + ) const [customModelNameCursorOffset, setCustomModelNameCursorOffset] = useState(0) const [ollamaBaseUrl, setOllamaBaseUrl] = useState( - 'http://localhost:11434/v1', + initialModelProfile?.provider === 'ollama' && initialModelProfile.baseURL + ? initialModelProfile.baseURL + : 'http://localhost:11434/v1', ) const [ollamaBaseUrlCursorOffset, setOllamaBaseUrlCursorOffset] = useState(0) - const [customBaseUrl, setCustomBaseUrl] = useState('') + const [customBaseUrl, setCustomBaseUrl] = useState( + initialModelProfile?.provider === 'custom-openai' + ? initialModelProfile.baseURL || '' + : '', + ) const [customBaseUrlCursorOffset, setCustomBaseUrlCursorOffset] = useState(0) - const [providerBaseUrl, setProviderBaseUrl] = useState('') + const [providerBaseUrl, setProviderBaseUrl] = useState( + initialModelProfile?.baseURL ?? '', + ) const [providerBaseUrlCursorOffset, setProviderBaseUrlCursorOffset] = useState(0) @@ -317,6 +372,10 @@ export function ModelSelector({ }) useEffect(() => { + if (initialModelProfile) { + return + } + if (!apiKeyEdited && selectedProvider) { if (process.env[selectedProvider.toUpperCase() + '_API_KEY']) { setApiKey( @@ -326,16 +385,7 @@ export function ModelSelector({ setApiKey('') } } - }, [selectedProvider, apiKey, apiKeyEdited]) - - useEffect(() => { - if ( - currentScreen === 'contextLength' && - !CONTEXT_LENGTH_OPTIONS.find(opt => opt.value === contextLength) - ) { - setContextLength(DEFAULT_CONTEXT_LENGTH) - } - }, [currentScreen, contextLength]) + }, [selectedProvider, apiKey, apiKeyEdited, initialModelProfile]) const providerReservedLines = 8 + containerPaddingY * 2 + containerGap * 2 const partnerReservedLines = 10 + containerPaddingY * 2 + containerGap * 3 @@ -900,9 +950,6 @@ export function ModelSelector({ } const handleModelParamsSubmit = () => { - if (!CONTEXT_LENGTH_OPTIONS.find(opt => opt.value === contextLength)) { - setContextLength(DEFAULT_CONTEXT_LENGTH) - } navigateTo('contextLength') } @@ -1373,7 +1420,7 @@ export function ModelSelector({ } const handleContextLengthSubmit = () => { - navigateTo('connectionTest') + navigateTo(initialModelProfile ? 'confirmation' : 'connectionTest') } async function saveConfiguration( @@ -1412,7 +1459,7 @@ export function ModelSelector({ reasoningEffort, } - return await modelManager.addModel(modelConfig) + return await modelManager.upsertModel(modelConfig) } catch (error) { setValidationError( error instanceof Error ? error.message : 'Failed to add model', @@ -1430,6 +1477,11 @@ export function ModelSelector({ return } + if (initialModelProfile) { + onDone() + return + } + setModelPointer('main', modelId) if (isOnboarding) { @@ -1688,32 +1740,32 @@ export function ModelSelector({ } if (key.upArrow) { - const currentIndex = CONTEXT_LENGTH_OPTIONS.findIndex( + const currentIndex = contextLengthOptions.findIndex( opt => opt.value === contextLength, ) const newIndex = currentIndex > 0 ? currentIndex - 1 : currentIndex === -1 - ? CONTEXT_LENGTH_OPTIONS.findIndex( + ? contextLengthOptions.findIndex( opt => opt.value === DEFAULT_CONTEXT_LENGTH, ) || 0 - : CONTEXT_LENGTH_OPTIONS.length - 1 - setContextLength(CONTEXT_LENGTH_OPTIONS[newIndex].value) + : contextLengthOptions.length - 1 + setContextLength(contextLengthOptions[newIndex].value) return } if (key.downArrow) { - const currentIndex = CONTEXT_LENGTH_OPTIONS.findIndex( + const currentIndex = contextLengthOptions.findIndex( opt => opt.value === contextLength, ) const newIndex = currentIndex === -1 - ? CONTEXT_LENGTH_OPTIONS.findIndex( + ? contextLengthOptions.findIndex( opt => opt.value === DEFAULT_CONTEXT_LENGTH, ) || 0 - : (currentIndex + 1) % CONTEXT_LENGTH_OPTIONS.length - setContextLength(CONTEXT_LENGTH_OPTIONS[newIndex].value) + : (currentIndex + 1) % contextLengthOptions.length + setContextLength(contextLengthOptions[newIndex].value) return } } @@ -2542,8 +2594,9 @@ export function ModelSelector({ if (currentScreen === 'contextLength') { const selectedOption = - CONTEXT_LENGTH_OPTIONS.find(opt => opt.value === contextLength) || - CONTEXT_LENGTH_OPTIONS[2] + contextLengthOptions.find(opt => opt.value === contextLength) || + contextLengthOptions.find(opt => opt.value === DEFAULT_CONTEXT_LENGTH) || + contextLengthOptions[0] return ( @@ -2572,7 +2625,7 @@ export function ModelSelector({ - {CONTEXT_LENGTH_OPTIONS.map((option, index) => { + {contextLengthOptions.map(option => { const isSelected = option.value === contextLength return ( @@ -2799,9 +2852,7 @@ export function ModelSelector({ Context Length: - {CONTEXT_LENGTH_OPTIONS.find( - opt => opt.value === contextLength, - )?.label || `${contextLength.toLocaleString()} tokens`} + {getContextLengthLabel(contextLength)} diff --git a/src/utils/model/index.ts b/src/utils/model/index.ts index 29dcbf8e..ae7c308b 100644 --- a/src/utils/model/index.ts +++ b/src/utils/model/index.ts @@ -71,6 +71,12 @@ export function getVertexRegionForModel( export class ModelManager { private config: any private modelProfiles: ModelProfile[] + private readonly modelPointers: ModelPointerType[] = [ + 'main', + 'task', + 'compact', + 'quick', + ] constructor(config: any) { this.config = config @@ -544,6 +550,43 @@ export class ModelManager { return config.modelName } + async upsertModel( + config: Omit, + ): Promise { + const existingIndex = this.modelProfiles.findIndex( + p => p.modelName === config.modelName, + ) + + if (existingIndex === -1) { + return this.addModel(config) + } + + const existingByName = this.modelProfiles.find( + p => p.name === config.name && p.modelName !== config.modelName, + ) + if (existingByName) { + throw new Error(`Model with name '${config.name}' already exists`) + } + + const existing = this.modelProfiles[existingIndex] + const updatedModel: ModelProfile = { + ...existing, + ...config, + apiKey: config.apiKey || existing.apiKey, + reasoningEffort: config.reasoningEffort ?? existing.reasoningEffort, + createdAt: existing.createdAt, + lastUsed: existing.lastUsed, + isActive: true, + isGPT5: existing.isGPT5, + validationStatus: existing.validationStatus, + lastValidation: existing.lastValidation, + } + + this.modelProfiles[existingIndex] = updatedModel + this.saveConfig() + return config.modelName + } + setPointer(pointer: ModelPointerType, modelName: string): void { if (!this.findModelProfile(modelName)) { throw new Error(`Model '${modelName}' not found`) @@ -616,17 +659,32 @@ export class ModelManager { p => p.modelName !== modelName, ) - if (this.config.modelPointers) { - Object.keys(this.config.modelPointers).forEach(pointer => { - if ( - this.config.modelPointers[pointer as ModelPointerType] === modelName - ) { - this.config.modelPointers[pointer as ModelPointerType] = - this.config.defaultModelName || '' - } - }) + if (!this.config.modelPointers) { + this.config.modelPointers = { + main: '', + task: '', + compact: '', + quick: '', + } + } + + const fallbackModelName = + this.modelProfiles.find(p => p.isActive)?.modelName || '' + + for (const pointer of this.modelPointers) { + const currentModelName = this.config.modelPointers[pointer] + const pointsToDeletedModel = currentModelName === modelName + const pointsToMissingModel = + currentModelName && !this.findModelProfile(currentModelName) + + if (!fallbackModelName) { + this.config.modelPointers[pointer] = '' + } else if (pointsToDeletedModel || pointsToMissingModel) { + this.config.modelPointers[pointer] = fallbackModelName + } } + this.config.defaultModelName = fallbackModelName this.saveConfig() } @@ -642,6 +700,7 @@ export class ModelManager { private saveConfig(): void { const startedAt = Date.now() + this.config.modelProfiles = this.modelProfiles const updatedConfig = { ...this.config, modelProfiles: this.modelProfiles, diff --git a/tests/unit/model-manager-switching.test.ts b/tests/unit/model-manager-switching.test.ts index ff7fe5b2..e8bbb1ec 100644 --- a/tests/unit/model-manager-switching.test.ts +++ b/tests/unit/model-manager-switching.test.ts @@ -150,4 +150,133 @@ describe('ModelManager model switching', () => { expect(config.modelPointers.main).toBe(modelA.modelName) expect(result.message).toContain('Keeping') }) + + test('upsertModel updates existing model parameters and preserves metadata', async () => { + const modelA = makeProfile({ + name: 'Model A', + modelName: 'model-a', + apiKey: 'existing-key', + maxTokens: 1024, + contextLength: 128_000, + reasoningEffort: 'medium', + createdAt: 1, + lastUsed: 2, + isGPT5: true, + validationStatus: 'valid', + lastValidation: 3, + }) + + const config: any = { + modelProfiles: [modelA], + modelPointers: { + main: modelA.modelName, + task: modelA.modelName, + compact: modelA.modelName, + quick: modelA.modelName, + }, + defaultModelName: modelA.modelName, + } + + const manager = new ModelManager(config) + const modelId = await manager.upsertModel({ + name: 'Model A Updated', + provider: 'openai', + modelName: modelA.modelName, + baseURL: 'https://example.com/v1', + apiKey: '', + maxTokens: 8192, + contextLength: 256_000, + reasoningEffort: 'high', + }) + + expect(modelId).toBe(modelA.modelName) + expect(manager.getAllConfiguredModels()).toHaveLength(1) + + const updated = manager.getAllConfiguredModels()[0] + expect(updated.name).toBe('Model A Updated') + expect(updated.baseURL).toBe('https://example.com/v1') + expect(updated.apiKey).toBe('existing-key') + expect(updated.maxTokens).toBe(8192) + expect(updated.contextLength).toBe(256_000) + expect(updated.reasoningEffort).toBe('high') + expect(updated.createdAt).toBe(1) + expect(updated.lastUsed).toBe(2) + expect(updated.isActive).toBe(true) + expect(updated.isGPT5).toBe(true) + expect(updated.validationStatus).toBe('valid') + expect(updated.lastValidation).toBe(3) + }) + + test('removeModel clears pointers and default when deleting the last model', () => { + const modelA = makeProfile({ + name: 'Model A', + modelName: 'model-a', + contextLength: 128_000, + createdAt: 1, + }) + + const config: any = { + modelProfiles: [modelA], + modelPointers: { + main: modelA.modelName, + task: modelA.modelName, + compact: modelA.modelName, + quick: modelA.modelName, + }, + defaultModelName: modelA.modelName, + } + + const manager = new ModelManager(config) + manager.removeModel(modelA.modelName) + + expect(manager.getAllConfiguredModels()).toEqual([]) + expect(config.modelProfiles).toEqual([]) + expect(config.modelPointers).toEqual({ + main: '', + task: '', + compact: '', + quick: '', + }) + expect(config.defaultModelName).toBe('') + }) + + test('removeModel reassigns pointers when deleting the main model', () => { + const modelA = makeProfile({ + name: 'Model A', + modelName: 'model-a', + contextLength: 128_000, + createdAt: 1, + }) + const modelB = makeProfile({ + name: 'Model B', + modelName: 'model-b', + contextLength: 256_000, + createdAt: 2, + }) + + const config: any = { + modelProfiles: [modelA, modelB], + modelPointers: { + main: modelA.modelName, + task: modelA.modelName, + compact: modelA.modelName, + quick: modelA.modelName, + }, + defaultModelName: modelA.modelName, + } + + const manager = new ModelManager(config) + manager.removeModel(modelA.modelName) + + expect( + manager.getAllConfiguredModels().map(model => model.modelName), + ).toEqual([modelB.modelName]) + expect(config.modelPointers).toEqual({ + main: modelB.modelName, + task: modelB.modelName, + compact: modelB.modelName, + quick: modelB.modelName, + }) + expect(config.defaultModelName).toBe(modelB.modelName) + }) }) diff --git a/tests/unit/model-selector.test.tsx b/tests/unit/model-selector.test.tsx index 36ebe28e..20cd8030 100644 --- a/tests/unit/model-selector.test.tsx +++ b/tests/unit/model-selector.test.tsx @@ -1,11 +1,21 @@ -import { afterEach, describe, expect, test } from 'bun:test' +import { + afterAll, + afterEach, + beforeAll, + describe, + expect, + test, +} from 'bun:test' import React, { useState } from 'react' import { PassThrough } from 'stream' import stripAnsi from 'strip-ansi' import { Box, Text, render } from 'ink' import { buildModelOptions } from '@components/model-selector/filterModels' import { ModelSelectionScreen } from '@components/model-selector/ModelSelectionScreen' +import { ModelListManager } from '@components/ModelListManager' import { getTheme } from '@utils/theme' +import { getGlobalConfig, saveGlobalConfig } from '@utils/config' +import { getModelManager, reloadModelManager } from '@utils/model' type InkTestHarness = { stdin: PassThrough & { @@ -58,6 +68,42 @@ function createInkTestHarness(element: React.ReactElement): InkTestHarness { } const mounted: InkTestHarness[] = [] +const originalNodeEnv = process.env.NODE_ENV + +function configureModelProfiles(modelProfiles: any[]) { + const firstModelName = modelProfiles[0]?.modelName || '' + saveGlobalConfig({ + ...(getGlobalConfig() as any), + modelProfiles, + modelPointers: { + main: firstModelName, + task: firstModelName, + compact: firstModelName, + quick: firstModelName, + }, + defaultModelName: firstModelName, + } as any) + reloadModelManager() +} + +function makeModelProfile(overrides: Record = {}) { + return { + name: 'Model A', + provider: 'openai', + modelName: 'model-a', + baseURL: 'https://example.com/v1', + apiKey: 'test-key', + maxTokens: 1024, + contextLength: 300000, + isActive: true, + createdAt: 1, + ...overrides, + } +} + +beforeAll(() => { + process.env.NODE_ENV = 'test' +}) afterEach(() => { while (mounted.length > 0) { @@ -65,6 +111,15 @@ afterEach(() => { mounted.pop()!.unmount() } catch {} } + configureModelProfiles([]) +}) + +afterAll(() => { + if (originalNodeEnv === undefined) { + delete process.env.NODE_ENV + return + } + process.env.NODE_ENV = originalNodeEnv }) describe('ModelSelector modularization', () => { @@ -126,4 +181,47 @@ describe('ModelSelector modularization', () => { await h.wait(50) expect(h.getOutput()).toContain('SELECTED:foo') }) + + test('ModelListManager can delete the only configured model', async () => { + configureModelProfiles([makeModelProfile()]) + + const h = createInkTestHarness( {}} />) + mounted.push(h) + + await h.wait(100) + h.stdin.write('\u001B[B') + await h.wait(20) + h.stdin.write('d') + await h.wait(20) + + expect(h.getOutput()).toContain('DELETE MODE') + + h.stdin.write('\r') + await h.wait(50) + + expect(getModelManager().getAllConfiguredModels()).toEqual([]) + expect(getGlobalConfig().modelPointers).toEqual({ + main: '', + task: '', + compact: '', + quick: '', + }) + }) + + test('ModelListManager opens existing model in edit mode', async () => { + configureModelProfiles([makeModelProfile()]) + + const h = createInkTestHarness( {}} />) + mounted.push(h) + + await h.wait(100) + h.stdin.write('\u001B[B') + await h.wait(20) + h.stdin.write('\r') + await h.wait(50) + + expect(h.getOutput()).toContain('Model Parameters') + expect(h.getOutput()).toContain('Configure parameters for model-a') + expect(h.getOutput()).toContain('1K tokens') + }) })