diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index dd58d11b..7a63cd78 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -1342,6 +1342,44 @@ def proxy_image(): except Exception as e: return jsonify({"error": f"Error fetching image: {str(e)}"}), 500 +@app.route('/api/getModelsDotDev', methods=['GET']) +def get_models_dot_dev(): + """ + Called at Initialization/Creation of a Prompt Node + + Fetches the list of models from the models.dev API. + Returns a JSON response with the model names and their details. + and save the json file to the local disk in the FLOWS_DIR. + """ + MODELS_DOT_DEV_FILE = os.path.join(FLOWS_DIR, "models_dot_dev.json") + + # If the models.dev file already exists, return its content + if os.path.isfile(MODELS_DOT_DEV_FILE): + try: + with open(MODELS_DOT_DEV_FILE, 'r', encoding='utf-8') as f: + models_data = json.load(f) + print(f"Returning cached models from {MODELS_DOT_DEV_FILE}") + return jsonify(models_data) + except Exception as e: + return jsonify({"error": f"Error reading cached models file: {str(e)}"}), 500 + + # If the file does not exist, fetch it from the API + try: + # Fetch the models from the API + response = py_requests.get("https://models.dev/api.json") + if response.status_code != 200: + return jsonify({"error": f"Failed to fetch models: {response.status_code} {response.reason}"}), response.status_code + + # Parse the JSON response + models_data = response.json() + + with open(MODELS_DOT_DEV_FILE, 'w', encoding='utf-8') as f: + json.dump(models_data, f, indent=2, ensure_ascii=False) + + return jsonify(models_data) + + except Exception as e: + return jsonify({"error": f"Error fetching models: {str(e)}"}), 500 """ SPIN UP SERVER diff --git a/chainforge/react-server/src/GlobalSettingsModal.tsx b/chainforge/react-server/src/GlobalSettingsModal.tsx index 3ad75656..66df48bf 100644 --- a/chainforge/react-server/src/GlobalSettingsModal.tsx +++ b/chainforge/react-server/src/GlobalSettingsModal.tsx @@ -43,6 +43,7 @@ import { Dict, JSONCompatible, LLMSpec, + ModelOllama, } from "./backend/typing"; import { getGlobalConfig, @@ -302,10 +303,54 @@ const GlobalSettingsModal = forwardRef( console.log("Ollama models available:", models_available); console.log("Loaded Ollama model list from backend."); + + // Populate model informations for Ollama models + // This is used to display the model information in the UI Prompt Node + const modelsInfos: Record = {}; + data.models.forEach((model_obj: Dict) => { + modelsInfos[model_obj.name] = { + name: model_obj.name, + format: model_obj.details.format, + families: model_obj.details.families, + parameter_size: model_obj.details.parameter_size, + quantization_level: model_obj.details.quantization_level, + size: (model_obj.size / 1000 ** 3).toFixed(2) + " GB", + }; + + fetch( + new Request(`${Ollama_BaseURL}/api/show`, { + method: "POST", + body: '{"model": "' + model_obj.name + '"}', + }), + ) + .then((response) => { + if (!response.ok) { + throw new Error( + `Failed to fetch model details on ${Ollama_BaseURL}/api/show`, + ); + } + return response.json(); + }) + .then((modelDetails: Dict) => { + modelsInfos[model_obj.name].capabilities = + modelDetails.capabilities; + }); + }); + LLMsProvidersInfos.ollama = { + id: "ollama", + name: "Ollama", + env: [], + npm: "", + doc: "https://ollama.com/search", + models: modelsInfos, + }; + setLLMsProvidersInfos(LLMsProvidersInfos); }) .catch((error) => { console.error("Error trying to fetch Ollama models", error); }); + + fetch(`${Ollama_BaseURL}/api/`); }); }, [form, settings]); @@ -362,6 +407,10 @@ const GlobalSettingsModal = forwardRef( const setFavorites = useStore((state) => state.setFavorites); const nodes = useStore((state) => state.nodes); const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + const LLMsProvidersInfos = useStore((state) => state.LLMsProvidersInfos); + const setLLMsProvidersInfos = useStore( + (state) => state.setLLMsProvidersInfos, + ); const showAlert = useContext(AlertModalContext); diff --git a/chainforge/react-server/src/LLMListComponent.tsx b/chainforge/react-server/src/LLMListComponent.tsx index 105980fa..dd019428 100644 --- a/chainforge/react-server/src/LLMListComponent.tsx +++ b/chainforge/react-server/src/LLMListComponent.tsx @@ -27,7 +27,11 @@ import { getDefaultModelSettings } from "./ModelSettingSchemas"; import useStore, { initLLMProviders, initLLMProviderMenu } from "./store"; import { Dict, JSONCompatible, LLMGroup, LLMSpec } from "./backend/typing"; import { ContextMenuItemOptions } from "mantine-contextmenu/dist/types"; -import { deepcopy, ensureUniqueName } from "./backend/utils"; +import { + deepcopy, + ensureUniqueName, + getModelsDotDevInfos, +} from "./backend/utils"; import NestedMenu, { NestedMenuItemProps } from "./NestedMenu"; // The LLM(s) to include by default on a PromptNode whenever one is created. @@ -299,6 +303,9 @@ export const LLMListContainer = forwardRef< forceUpdate(); }; + // Get the LLMs infos from the store, which is fetched from `models.dev` website. + const LLMsInfos = useStore((state) => state.LLMsProvidersInfos); + // Selecting LLM models to prompt const [llmItems, setLLMItems] = useState( initLLMItems || @@ -453,7 +460,7 @@ export const LLMListContainer = forwardRef< }; } else { initModels.add(item.base_model); - return { + const res: NestedMenuItemProps = { key: item.key ?? item.model, title: `${item.emoji} ${item.name}`, onClick: () => handleSelectModel(item), @@ -466,6 +473,11 @@ export const LLMListContainer = forwardRef< } : undefined, }; + // if `LLMsInfos` is not empty dict, add tooltip + if (LLMsInfos && Object.keys(LLMsInfos).length > 0) { + res.tooltip = getModelsDotDevInfos(item, LLMsInfos); + } + return res; } }; const res = initLLMProviderMenu.map((i) => convert(i)); diff --git a/chainforge/react-server/src/NestedMenu.tsx b/chainforge/react-server/src/NestedMenu.tsx index f52e468a..79889a0d 100644 --- a/chainforge/react-server/src/NestedMenu.tsx +++ b/chainforge/react-server/src/NestedMenu.tsx @@ -2,6 +2,7 @@ import React, { ReactNode, useMemo, useState } from "react"; import { Menu, Tooltip, Popover, ActionIcon } from "@mantine/core"; import { IconChevronRight, IconTrash } from "@tabler/icons-react"; import { ContextMenuItemOptions } from "mantine-contextmenu/dist/types"; +import ReactMarkdown from "react-markdown"; const NESTED_MENU_STYLE = { dropdown: { padding: "0px !important" }, @@ -19,9 +20,9 @@ export const MenuTooltip = ({ else return ( {label}} position="right" - width={200} + width={400} multiline withArrow arrowSize={10} diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index c5781c4b..43e96a85 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -60,6 +60,7 @@ import { truncStr, genDebounceFunc, ensureUniqueName, + FLASK_BASE_URL, } from "./backend/utils"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; import CancelTracker from "./backend/canceler"; @@ -370,6 +371,10 @@ const PromptNode: React.FC = ({ // API Keys (set by user in popup GlobalSettingsModal) const apiKeys = useStore((state) => state.apiKeys); + const LLMsProvidersInfos = useStore((state) => state.LLMsProvidersInfos); + const setLLMsProvidersInfos = useStore( + (state) => state.setLLMsProvidersInfos, + ); const [jsonResponses, setJSONResponses] = useState( null, ); @@ -578,6 +583,24 @@ const PromptNode: React.FC = ({ // On initialization useEffect(() => { + // Fetch models from the models.dev API + fetch(`${FLASK_BASE_URL}api/getModelsDotDev`) + .then((res) => { + return res.json(); + }) + .then((data) => { + console.log("Fetched models from models.dev:", data); + if (data && !data.error) { + setLLMsProvidersInfos({ + ...data, + ...LLMsProvidersInfos, + }); + } + }) + .catch((err) => { + console.error(err); + }); + refreshTemplateHooks(promptText); // Attempt to grab cache'd responses diff --git a/chainforge/react-server/src/backend/typing.ts b/chainforge/react-server/src/backend/typing.ts index 2559782d..8db08d13 100644 --- a/chainforge/react-server/src/backend/typing.ts +++ b/chainforge/react-server/src/backend/typing.ts @@ -335,3 +335,70 @@ export type RatingDict = Record; export interface FileWithContent extends FileWithPath { content?: string; } + +// Typing for the LLM provider dictionnary fecthed from the models.dev website +// that is stored in the `LLMsInfos` variable in the Zustand store + +export type Modality = "text" | "image" | "audio" | "video" | "pdf"; + +export interface Modalities { + input: Modality[]; + output: Modality[]; +} + +export interface Cost { + input: number; + output: number; + cache_read?: number; + cache_write?: number; +} + +export interface Limit { + context: number; + output: number; +} + +export interface Model { + id: string; + name: string; + attachment: boolean; + reasoning: boolean; + temperature: boolean; + tool_call: boolean; + knowledge?: string; + release_date: string; + last_updated: string; + modalities: Modalities; + open_weights: boolean; + cost?: Cost; + limit: Limit; +} + +export type CapabilityOllama = + | "completion" + | "tools" + | "thinking" + | "embedding" + | "vision"; + +export interface ModelOllama { + name: string; + format: string; + families: string[]; + parameter_size: string; + quantization_level: string; + size: string; + capabilities?: CapabilityOllama[]; +} + +export interface LLMProvider { + id: string; + env: string[]; + npm: string; + name: string; + doc: string; + api?: string; + models: Record; +} + +export type ModelDotDevInfos = Record; diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 887f8b22..3cbab8e3 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -32,6 +32,11 @@ import { MultiModalContentOpenAI, MultiModalContentGemini, PromptVarType, + ModelDotDevInfos, + Model, + Modality, + ModelOllama, + CapabilityOllama, } from "./typing"; import { v4 as uuid } from "uuid"; import { StringTemplate } from "./template"; @@ -267,13 +272,24 @@ async function imagesToBase64(images: string[]) { const base64_images: Array = []; for (const image of images) { const imageBlob = await MediaLookup.get(image); + if (!imageBlob) { // This should never happen, but just in case: console.error(`Image not found in MediaLookup: ${image}`); continue; } - const base64_image = await blobOrFileToDataURL(imageBlob); - if (base64_image) base64_images.push(base64_image); + + // Check if the blob is of type image/jpeg, if not, create a new blob with the correct type + let processedBlob = imageBlob; + if (!imageBlob.type.startsWith("image/")) { + console.log(`Converting blob from ${imageBlob.type} to image/png`); + processedBlob = new Blob([imageBlob], { type: "image/png" }); + } + + const base64_image = await blobOrFileToDataURL(processedBlob); + if (base64_image) { + base64_images.push(base64_image); + } } return base64_images; } @@ -2928,3 +2944,194 @@ export function hashtagTemplateVars(input: string): string { return result; } + +/** + * Given a LLM item and a dictionary of modelsDotDevInfos, returns the + * corresponding ModelDotDevInfo object for the LLM's model, or null if not found. + * @param llm_item The LLM item to look up. + * @param modelsDotDevInfos The dictionary of ModelDotDevInfos. + * @returns The ModelDotDevInfo object for the LLM's model, or null if not found. + */ +export function getModelsDotDevInfos( + llm_item: LLMSpec, + modelsDotDevInfos: ModelDotDevInfos, +): string { + // Determine the provider for the LLM item based on its base_model + let providerModelsDotDev: string | null = null; + + if ( + llm_item.base_model.startsWith("gpt-") || + llm_item.base_model.startsWith("dall-e") + ) { + providerModelsDotDev = "openai"; + } else if (llm_item.base_model.startsWith("claude-")) { + providerModelsDotDev = "anthropic"; + } else if (llm_item.base_model.startsWith("palm2-bison")) { + providerModelsDotDev = "google"; + } else if (llm_item.base_model.startsWith("deepseek")) { + providerModelsDotDev = "deepseek"; + } else if (llm_item.base_model.startsWith("br.")) { + providerModelsDotDev = "amazon-bedrock"; + } else if (llm_item.base_model.startsWith("ollama")) { + providerModelsDotDev = "ollama"; + llm_item.model = llm_item.name; + } + + if ( + providerModelsDotDev === null || + modelsDotDevInfos[providerModelsDotDev] === undefined + ) { + // If we don't know the provider, or the provider is not in models.dev, return + console.log( + `The models.dev dict info does not contain any PROVIDER info for LLM ${[llm_item.base_model, llm_item.model]}.`, + ); + return `No provider in [models.dev](https://models.dev) for LLM **${llm_item.model} (${llm_item.base_model})**`; + } + + // Anthropic uses aliases for their model names, so we need to map them. + if (providerModelsDotDev === "anthropic") { + const anthropic_alias_map: { [key: string]: string } = { + "claude-opus-4-0": "claude-opus-4-20250514", + "claude-sonnet-4-0": "claude-sonnet-4-20250514", + "claude-3-7-sonnet-latest": "claude-3-7-sonnet-20250219", + "claude-3-5-sonnet-latest": "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-latest": "claude-3-5-haiku-20241022", + "claude-3-opus-latest": "claude-3-opus-20240229", + }; + if (llm_item.model in anthropic_alias_map) { + llm_item.model = anthropic_alias_map[llm_item.model]; + } + } + // At this point, we have the providerModelsDotDev filled in. + // Now we make sure that the provider contains info the model `llm_item.model`. + if (!modelsDotDevInfos[providerModelsDotDev].models[llm_item.model]) { + console.log( + `The models.dev dict info does not contain any MODEL info for LLM ${[ + llm_item.base_model, + llm_item.model, + ]} for PROVIDER ${providerModelsDotDev}.`, + ); + return `For provider **${providerModelsDotDev}** in [models.dev](https://models.dev), no model info for LLM **${llm_item.model} (${llm_item.base_model})**`; + } + + return prettifyModelInfo( + modelsDotDevInfos[providerModelsDotDev].models[llm_item.model], + providerModelsDotDev === "ollama", + ); +} + +/** + * Converts a ModelDotDevInfos.models.Model object to a string representation. + * in order to be displayed in a tooltip and with some great formatting (like emojis). + * + * @param modelInfo The ModelDotDevInfos.models.Model object to convert. + * @returns A string representation of the model info. + */ +export function prettifyModelInfo( + modelInfo: Model | ModelOllama, + is_ollama: boolean, +): string { + if (!is_ollama) { + const { + name, + id, + attachment, + reasoning, + temperature, + tool_call, + knowledge, + release_date, + last_updated, + modalities, + open_weights, + cost, + limit, + } = modelInfo as Model; + + const modalityToEmoji: { [key in Modality]: string } = { + text: "๐Ÿ“", + image: "๐Ÿ–ผ๏ธ", + audio: "๐Ÿ”Š", + video: "๐ŸŽฅ", + pdf: "๐Ÿ“„", + }; + + const formatModalities = (m: Modality[]) => + m.map((mod) => `${modalityToEmoji[mod]} ${mod}`).join(", ") || "N/A"; + + let result = `**${name}** (${id})\n\n`; + + result += `**Features**\n`; + result += `- ๐Ÿง  Reasoning: ${reasoning ? "โœ…" : "โŒ"}\n`; + result += `- ๐Ÿ“Ž Attachments: ${attachment ? "โœ…" : "โŒ"}\n`; + result += `- ๐ŸŒก๏ธ Temperature Control: ${temperature ? "โœ…" : "โŒ"}\n`; + result += `- ๐Ÿ› ๏ธ Tool Call: ${tool_call ? "โœ…" : "โŒ"}\n`; + result += `- โš–๏ธ Open Weights: ${open_weights ? "โœ…" : "โŒ"}\n\n`; + + result += `**Modalities**\n`; + result += `- ๐Ÿ“ฅ Input: ${formatModalities(modalities.input)}\n`; + result += `- ๐Ÿ“ค Output: ${formatModalities(modalities.output)}\n\n`; + + result += `**Limits**\n`; + result += `- ๐Ÿ”„ Context Window: ${limit.context.toLocaleString()} tokens\n`; + result += `- โžก๏ธ Max Output: ${limit.output.toLocaleString()} tokens\n\n`; + + if (knowledge) { + result += `**Knowledge Cutoff**\n`; + result += `- ๐Ÿ“… ${knowledge}\n\n`; + } + + result += `**Dates**\n`; + result += `- ๐Ÿš€ Release: ${release_date}\n`; + result += `- ๐Ÿ”„ Last Updated: ${last_updated}\n\n`; + + if (cost) { + result += `**Cost (per 1M tokens, ~750,000 words)**\n`; + result += `- ๐Ÿ’ต Input: $${cost.input.toFixed(2)}\n`; + result += `- ๐Ÿ’ต Output: $${cost.output.toFixed(2)}\n`; + } + + return result.trim(); + } else { + const { + name, + format, + families, + parameter_size, + quantization_level, + size, + capabilities, + } = modelInfo as ModelOllama; + + let result = `**${name}**\n\n`; + + result += `**๐Ÿ“ Details**\n`; + result += `- ๐Ÿ’พ Size: ${size}\n`; + result += `- ๐Ÿ“ฆ Format: ${format}\n`; + if (families && families.length > 0) { + result += `- ๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆ Families: ${families.join(", ")}\n`; + } + result += `\n`; + + result += `**โš™๏ธ Model Specs**\n`; + result += `- ๐Ÿง  Parameters: ${parameter_size}\n`; + result += `- ๐Ÿ“‰ Quantization: ${quantization_level}\n\n`; + + if (capabilities && capabilities.length > 0) { + const capabilityToEmoji: { [key in CapabilityOllama]: string } = { + completion: "โœ๏ธ", + tools: "๐Ÿ› ๏ธ", + thinking: "๐Ÿค”", + embedding: "๐Ÿ”—", + vision: "๐Ÿ‘๏ธ", + }; + const formatCapabilities = (caps: CapabilityOllama[]) => + caps.map((cap) => `${capabilityToEmoji[cap]} ${cap}`).join(", "); + + result += `**โœจ Capabilities**\n`; + result += `- ${formatCapabilities(capabilities)}\n`; + } + + return result.trim(); + } +} diff --git a/chainforge/react-server/src/store.tsx b/chainforge/react-server/src/store.tsx index bb381416..22eb094f 100644 --- a/chainforge/react-server/src/store.tsx +++ b/chainforge/react-server/src/store.tsx @@ -29,6 +29,7 @@ import { TabularDataRowType, JSONCompatible, LLMResponse, + ModelDotDevInfos, } from "./backend/typing"; import { TogetherChatSettings } from "./ModelSettingSchemas"; import { NativeLLM } from "./backend/models"; @@ -503,6 +504,11 @@ export interface StoreHandles { _targetHandles: string[], node_id: string, ) => Dict; + + // Store infos about LLMs providers models + // fetched from `models.dev` website. + LLMsProvidersInfos: ModelDotDevInfos; + setLLMsProvidersInfos: (infos: ModelDotDevInfos) => void; } // A global store of variables, used for maintaining state @@ -511,6 +517,12 @@ const useStore = create((set, get) => ({ nodes: [], edges: [], + // models.dev LLM providers infos + LLMsProvidersInfos: {}, + setLLMsProvidersInfos: (infos) => { + set({ LLMsProvidersInfos: infos }); + }, + // Available LLMs in ChainForge, in the format expected by LLMListItems. AvailableLLMs: [...initLLMProviders], setAvailableLLMs: (llmProviderList) => {