From 3d9c82352dd31595aa90d892bc03d9dbf2c3e0c9 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Wed, 12 Mar 2025 22:35:02 -0400 Subject: [PATCH 01/35] Add prompt variants feature to Prompt Node --- .../react-server/src/LLMListComponent.tsx | 20 +- chainforge/react-server/src/PromptNode.tsx | 236 +++++++++++++++--- .../react-server/src/backend/backend.ts | 40 +-- chainforge/react-server/src/backend/utils.ts | 49 ++++ 4 files changed, 273 insertions(+), 72 deletions(-) diff --git a/chainforge/react-server/src/LLMListComponent.tsx b/chainforge/react-server/src/LLMListComponent.tsx index 900870fd1..4aac22daa 100644 --- a/chainforge/react-server/src/LLMListComponent.tsx +++ b/chainforge/react-server/src/LLMListComponent.tsx @@ -31,31 +31,13 @@ import useStore, { initLLMProviders, initLLMProviderMenu } from "./store"; import { Dict, JSONCompatible, LLMGroup, LLMSpec } from "./backend/typing"; import { useContextMenu } from "mantine-contextmenu"; import { ContextMenuItemOptions } from "mantine-contextmenu/dist/types"; +import { ensureUniqueName } from "./backend/utils"; // The LLM(s) to include by default on a PromptNode whenever one is created. // Defaults to ChatGPT (GPT3.5) when running locally, and HF-hosted falcon-7b for online version since it's free. const DEFAULT_INIT_LLMS = [initLLMProviders[0]]; // Helper funcs -// Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)") -const ensureUniqueName = (_name: string, _prev_names: string[]) => { - // Strip whitespace around names - const prev_names = _prev_names.map((n) => n.trim()); - const name = _name.trim(); - - // Check if name is unique - if (!prev_names.includes(name)) return name; - - // Name isn't unique; find a unique one: - let i = 2; - let new_name = `${name} (${i})`; - while (prev_names.includes(new_name)) { - i += 1; - new_name = `${name} (${i})`; - } - return new_name; -}; - /** Get position CSS style below and left-aligned to the input element */ const getPositionCSSStyle = ( elem: HTMLButtonElement, diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index 93af5a8dd..194f1e8d2 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -18,9 +18,20 @@ import { Modal, Box, Tooltip, + Group, + Flex, + Button, + ActionIcon, } from "@mantine/core"; import { useDisclosure } from "@mantine/hooks"; -import { IconEraser, IconList } from "@tabler/icons-react"; +import { + IconArrowLeft, + IconArrowRight, + IconEraser, + IconList, + IconPlus, + IconTrash, +} from "@tabler/icons-react"; import useStore from "./store"; import BaseNode from "./BaseNode"; import NodeLabel from "./NodeLabelComponent"; @@ -41,6 +52,7 @@ import { extractSettingsVars, truncStr, genDebounceFunc, + ensureUniqueName, } from "./backend/utils"; import LLMResponseInspectorDrawer from "./LLMResponseInspectorDrawer"; import CancelTracker from "./backend/canceler"; @@ -64,6 +76,7 @@ import { queryLLM, } from "./backend/backend"; import { StringLookup } from "./backend/cache"; +import { union } from "./backend/setUtils"; const getUniqueLLMMetavarKey = (responses: LLMResponse[]) => { const metakeys = new Set( @@ -221,6 +234,7 @@ export interface PromptNodeProps { contChat: boolean; refresh: boolean; refreshLLMList: boolean; + idxPromptVariantShown?: number; }; id: string; type: string; @@ -257,10 +271,15 @@ const PromptNode: React.FC = ({ null, ); const [templateVars, setTemplateVars] = useState(data.vars ?? []); - const [promptText, setPromptText] = useState(data.prompt ?? ""); - const [promptTextOnLastRun, setPromptTextOnLastRun] = useState( - null, + const [promptText, setPromptText] = useState( + data.prompt ?? "", + ); + const [idxPromptVariantShown, setIdxPromptVariantShown] = useState( + data.idxPromptVariantShown ?? 0, ); + const [promptTextOnLastRun, setPromptTextOnLastRun] = useState< + string | string[] | null + >(null); const [status, setStatus] = useState(Status.NONE); const [numGenerations, setNumGenerations] = useState(data.n ?? 1); const [numGenerationsLastRun, setNumGenerationsLastRun] = useState( @@ -391,10 +410,17 @@ const PromptNode: React.FC = ({ }, [templateVars, id, pullInputData, updateShowContToggle]); const refreshTemplateHooks = useCallback( - (text: string) => { - // Update template var fields + handles - const found_template_vars = new Set(extractBracketedSubstrings(text)); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} + (text: string | string[]) => { + const texts = typeof text === "string" ? [text] : text; + + // Get all template vars in the prompt(s) + let found_template_vars = new Set(); + for (const t of texts) { + const substrs = extractBracketedSubstrings(t); // gets all strs within braces {} that aren't escaped; e.g., ignores \{this\} but captures {this} + found_template_vars = union(found_template_vars, new Set(substrs)); + } + // Update template var fields + handles if (!setsAreEqual(found_template_vars, new Set(templateVars))) { if (node_type !== "chat") { try { @@ -413,27 +439,29 @@ const PromptNode: React.FC = ({ const handleInputChange = useCallback( (event: React.ChangeEvent) => { - const value = event.target.value; + const value = event.target.value as string; const updateStatus = promptTextOnLastRun !== null && status !== Status.WARNING && value !== promptTextOnLastRun; - // Store prompt text - data.prompt = value; - // Debounce the global state change to happen only after 500ms, as it forces a costly rerender: - debounce((_value, _updateStatus) => { - setPromptText(_value); - setDataPropsForNode(id, { prompt: _value }); - refreshTemplateHooks(_value); + debounce((_value: string, _updateStatus, _idxPromptVariantShown) => { + setPromptText((prompts) => { + if (typeof prompts === "string") prompts = _value; + else prompts[_idxPromptVariantShown] = _value; + setDataPropsForNode(id, { prompt: prompts }); + refreshTemplateHooks(prompts); + return prompts; + }); if (_updateStatus) setStatus(Status.WARNING); - }, 300)(value, updateStatus); + }, 300)(value, updateStatus, idxPromptVariantShown); // Debounce refreshing the template hooks so we don't annoy the user // debounce((_value) => refreshTemplateHooks(_value), 500)(value); }, [ + idxPromptVariantShown, promptTextOnLastRun, status, refreshTemplateHooks, @@ -552,7 +580,7 @@ const PromptNode: React.FC = ({ // Ask the backend how many responses it needs to collect, given the input data: const fetchResponseCounts = useCallback( ( - prompt: string, + prompt: string | string[], vars: Dict, llms: (StringOrHash | LLMSpec)[], chat_histories?: @@ -592,14 +620,24 @@ const PromptNode: React.FC = ({ const pulled_vars = pullInputData(templateVars, id); updateShowContToggle(pulled_vars); - generatePrompts(promptText, pulled_vars).then((prompts) => { - setPromptPreviews( - prompts.map( - (p: PromptTemplate) => - new PromptInfo(p.toString(), extractSettingsVars(p.fill_history)), - ), - ); - }); + const prompts = + typeof promptText === "string" ? [promptText] : promptText; + + Promise.all(prompts.map((p) => generatePrompts(p, pulled_vars))).then( + (results) => { + // Handle all the results here + const all_concrete_prompts = results.flatMap((ps) => + ps.map( + (p: PromptTemplate) => + new PromptInfo( + p.toString(), + extractSettingsVars(p.fill_history), + ), + ), + ); + setPromptPreviews(all_concrete_prompts); + }, + ); pullInputChats(); } catch (err) { @@ -827,9 +865,18 @@ Soft failing by replacing undefined with empty strings.`, // Pull the data to fill in template input variables, if any let pulled_data: Dict<(string | TemplateVarInfo)[]> = {}; + let var_for_prompt_templates: string | undefined; try { // Try to pull inputs pulled_data = pullInputData(templateVars, id); + + // Add a special new variable for the root prompt template(s) + var_for_prompt_templates = ensureUniqueName( + "prompt", + Object.keys(pulled_data), + ); + if (typeof promptText !== "string") + pulled_data[var_for_prompt_templates] = promptText; // this will be filled in when calling queryLLMs } catch (err) { if (showAlert) showAlert((err as Error)?.message ?? err); console.error(err); @@ -873,7 +920,9 @@ Soft failing by replacing undefined with empty strings.`, // Fetch info about the number of queries we'll need to make const fetch_resp_count = () => fetchResponseCounts( - prompt_template, + typeof prompt_template === "string" + ? prompt_template + : `{${var_for_prompt_templates}}`, // Use special root prompt if there's multiple prompt variants pulled_data, _llmItemsCurrState, pulled_chats as ChatHistoryInfo[], @@ -951,9 +1000,11 @@ Soft failing by replacing undefined with empty strings.`, const query_llms = () => { return queryLLM( id, - _llmItemsCurrState, // deep clone it first + _llmItemsCurrState, numGenerations, - prompt_template, + typeof prompt_template === "string" + ? prompt_template + : `{${var_for_prompt_templates}}`, // Use special root prompt if there's multiple prompt variants pulled_data, chat_hist_by_llm, apiKeys || {}, @@ -1015,7 +1066,7 @@ Soft failing by replacing undefined with empty strings.`, o.metavars = resp_obj.metavars ?? {}; // Add a metavar for the prompt *template* in this PromptNode - o.metavars.__pt = prompt_template; + // o.metavars.__pt = prompt_template; // Carry over any chat history if (resp_obj.chat_history) @@ -1156,6 +1207,48 @@ Soft failing by replacing undefined with empty strings.`, [numGenerationsLastRun, status], ); + const handleAddPromptVariant = useCallback(() => { + // Pushes a new prompt variant, updating the prompts list and duplicating the current shown prompt + const prompts = typeof promptText === "string" ? [promptText] : promptText; + const curIdx = Math.max( + 0, + Math.min(prompts.length - 1, idxPromptVariantShown), + ); // clamp + const curShownPrompt = prompts[curIdx]; + setPromptText(prompts.concat([curShownPrompt])); + setIdxPromptVariantShown(prompts.length); + }, [promptText, idxPromptVariantShown]); + + const gotoPromptVariant = useCallback( + (shift: number) => { + const prompts = + typeof promptText === "string" ? [promptText] : promptText; + const newIdx = Math.max( + 0, + Math.min(prompts.length - 1, idxPromptVariantShown + shift), + ); // clamp + setIdxPromptVariantShown(newIdx); + }, + [promptText, idxPromptVariantShown], + ); + + const handleRemovePromptVariant = useCallback(() => { + setPromptText((prompts) => { + if (typeof prompts === "string") return prompts; // cannot remove the last one + prompts.splice(idxPromptVariantShown, 1); // remove the indexed variant + setIdxPromptVariantShown(Math.max(0, idxPromptVariantShown - 1)); // goto the previous variant, if possible + return prompts; + }); + }, [idxPromptVariantShown]); + + // Whenever idx of prompt variant changes, we need to refresh the Textarea: + useEffect(() => { + if (textAreaRef.current && Array.isArray(promptText)) { + // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this + textAreaRef.current.value = promptText[idxPromptVariantShown]; + } + }, [idxPromptVariantShown]); + const hideStatusIndicator = () => { if (status !== Status.NONE) setStatus(Status.NONE); }; @@ -1254,7 +1347,12 @@ Soft failing by replacing undefined with empty strings.`, key={0} className="prompt-field-fixed nodrag nowheel" minRows={4} - defaultValue={data.prompt} + defaultValue={ + typeof data.prompt === "string" + ? data.prompt + : data.prompt && + data.prompt[data.idxPromptVariantShown ?? 0] + } onChange={handleInputChange} miw={230} styles={{ @@ -1277,11 +1375,69 @@ Soft failing by replacing undefined with empty strings.`, className="prompt-field-fixed nodrag nowheel" minRows={4} maxRows={12} - defaultValue={data.prompt} + defaultValue={ + typeof data.prompt === "string" + ? data.prompt + : data.prompt && data.prompt[data.idxPromptVariantShown ?? 0] + } onChange={handleInputChange} + // value={typeof promptText === "string" ? promptText : promptText[idxPromptVariantShown]} /> )} + + {typeof promptText === "string" || promptText.length === 1 ? ( + + ) : ( + <> + gotoPromptVariant(-1)} + > + + + + + Variant {idxPromptVariantShown + 1} of{" "} + {typeof promptText === "string" ? 1 : promptText.length} + + + gotoPromptVariant(1)} + > + + + + + + + + + + + + + + + + )} + + - + + + + +
diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts index 55e3775e6..20d25bac0 100644 --- a/chainforge/react-server/src/backend/backend.ts +++ b/chainforge/react-server/src/backend/backend.ts @@ -29,6 +29,8 @@ import { repairCachedResponses, deepcopy, llmResponseDataToString, + extendArray, + extendArrayDict, } from "./utils"; import StorageCache, { StringLookup } from "./cache"; import { PromptPipeline } from "./query"; @@ -520,7 +522,7 @@ export async function generatePrompts( /** * Calculates how many queries we need to make, given the passed prompt and vars. * - * @param prompt the prompt template, with any {{}} vars + * @param prompt the prompt template, with any {} vars; or alternatively, an array of such templates * @param vars a dict of the template variables to fill the prompt template with, by name. * For each var value, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.) * @param llms the list of LLMs you will query @@ -531,7 +533,7 @@ export async function generatePrompts( * If there was an error, returns a dict with a single key, 'error'. */ export async function countQueries( - prompt: string, + prompt: string | string[], vars: PromptVarsDict, llms: Array, n: number, @@ -545,19 +547,27 @@ export async function countQueries( vars = deepcopy(vars); llms = deepcopy(llms); - let all_prompt_permutations: PromptTemplate[] | Dict; - - const gen_prompts = new PromptPermutationGenerator(prompt); - if (cont_only_w_prior_llms && Array.isArray(llms)) { - all_prompt_permutations = {}; - llms.forEach((llm_spec) => { - const llm_key = extract_llm_key(llm_spec); - (all_prompt_permutations as Dict)[llm_key] = Array.from( - gen_prompts.generate(filterVarsByLLM(vars, llm_key)), + const prompt_templates = typeof prompt === "string" ? [prompt] : prompt; + const all_prompt_permutations: PromptTemplate[] | Dict = + cont_only_w_prior_llms && Array.isArray(llms) ? {} : []; + + for (const pt of prompt_templates) { + const gen_prompts = new PromptPermutationGenerator(pt); + if (cont_only_w_prior_llms && Array.isArray(llms)) { + llms.forEach((llm_spec) => { + const llm_key = extract_llm_key(llm_spec); + extendArrayDict( + all_prompt_permutations as Dict, + llm_key, + Array.from(gen_prompts.generate(filterVarsByLLM(vars, llm_key))), + ); + }); + } else { + extendArray( + all_prompt_permutations as PromptTemplate[], + Array.from(gen_prompts.generate(vars)), ); - }); - } else { - all_prompt_permutations = Array.from(gen_prompts.generate(vars)); + } } let cache_file_lookup: Dict = {}; @@ -739,7 +749,7 @@ export async function ensureUniqueFlowFilename( * @param id a unique ID to refer to this information. Used when cache'ing responses. * @param llm a string, list of strings, or list of LLM spec dicts specifying the LLM(s) to query. * @param n the amount of generations for each prompt. All LLMs will be queried the same number of times 'n' per each prompt. - * @param prompt the prompt template, with any {{}} vars + * @param prompt the prompt template, with any {} vars * @param vars a dict of the template variables to fill the prompt template with, by name. For each var, can be single values or a list; in the latter, all permutations are passed. (Pass empty dict if no vars.) * @param chat_histories Either an array of `ChatHistory` (to use across all LLMs), or a dict indexed by LLM nicknames of `ChatHistory` arrays to use per LLM. diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 55f9ff31b..92c0d8d59 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -2398,3 +2398,52 @@ export const compressBase64Image = (b64: string): Promise => { ) .then((compressedBlob) => blobToBase64(compressedBlob as Blob)); }; + +/** + * Extends array `a` with the values of `b`. + * @param a The array to extend (in-place). + * @param b The array to add to the end of `a`. + * @returns `a`, extended. + */ +export const extendArray = (a: Array, b: Array): Array => { + for (const i in b) { + a.push(b[i]); + } + return a; +}; + +/** + * Extends the array `key` in a dict with `values`, creating a new array if the key is missing. + * @param dict The dictionary to extend (in-place). + * @param key The key of the dictionary. + * @param values The new array to append to the end of the dict value for `key`. + */ +export const extendArrayDict = ( + dict: Record, + key: K, + values: V[], +): void => { + if (!dict[key]) { + dict[key] = []; + } + extendArray(dict[key], values); +}; + +/** Ensure that a name is 'unique'; if not, return an amended version with a count tacked on (e.g. "GPT-4 (2)") */ +export const ensureUniqueName = (_name: string, _prev_names: string[]) => { + // Strip whitespace around names + const prev_names = _prev_names.map((n) => n.trim()); + const name = _name.trim(); + + // Check if name is unique + if (!prev_names.includes(name)) return name; + + // Name isn't unique; find a unique one: + let i = 2; + let new_name = `${name} (${i})`; + while (prev_names.includes(new_name)) { + i += 1; + new_name = `${name} (${i})`; + } + return new_name; +}; From d6723459e80d32f969d84d9c3f355b4eee4a98d7 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Fri, 14 Mar 2025 12:34:12 -0400 Subject: [PATCH 02/35] fix countQueries backwards compatibility; add alertmodal for deleting prompt variant; add prompt variant to PromptPreview screens --- chainforge/flask_app.py | 48 ++- chainforge/react-server/src/App.tsx | 1 + .../react-server/src/AreYouSureModal.tsx | 7 +- chainforge/react-server/src/FlowSidebar.tsx | 72 ++++- .../src/LLMResponseInspectorModal.tsx | 5 +- chainforge/react-server/src/PromptNode.tsx | 306 ++++++++++++------ chainforge/react-server/src/backend/models.ts | 2 + chainforge/react-server/src/backend/utils.ts | 10 +- 8 files changed, 324 insertions(+), 127 deletions(-) diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index e2876bf4f..8ec97925e 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -1,4 +1,4 @@ -import json, os, sys, asyncio, time +import json, os, sys, asyncio, time, shutil from dataclasses import dataclass from enum import Enum from typing import List @@ -772,7 +772,7 @@ def delete_flow(filename): @app.route('/api/flows/', methods=['PUT']) def save_or_rename_flow(filename): - """Save or rename a flow""" + """Save, rename, or duplicate a flow""" data = request.json if not filename.endswith('.cforge'): @@ -805,6 +805,36 @@ def save_or_rename_flow(filename): return jsonify({"message": f"Flow renamed from {filename} to {new_name}"}) except Exception as error: return jsonify({"error": str(error)}), 404 + + elif data.get('duplicate'): + # Duplicate flow + try: + # Check for name clashes (if a flow already exists with the new name) + copy_name = _get_unique_flow_name(filename, "Copy of ") + # Copy the file to the new (safe) path, and copy metadata too: + shutil.copy2(os.path.join(FLOWS_DIR, filename), os.path.join(FLOWS_DIR, f"{copy_name}.cforge")) + # Return the new filename + return jsonify({"copyName": copy_name}) + except Exception as error: + return jsonify({"error": str(error)}), 404 + +def _get_unique_flow_name(filename: str, prefix: str = None) -> str: + base, ext = os.path.splitext(filename) + if ext is None or len(ext) == 0: + ext = ".cforge" + unique_filename = base + ext + if prefix is not None: + unique_filename = prefix + unique_filename + i = 1 + + # Find the first non-clashing filename of the form (i).cforge where i=1,2,3 etc + while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)): + unique_filename = f"{base}({i}){ext}" + if prefix is not None: + unique_filename = prefix + unique_filename + i += 1 + + return unique_filename.replace(".cforge", "") @app.route('/api/getUniqueFlowFilename', methods=['PUT']) def get_unique_flow_name(): @@ -813,18 +843,8 @@ def get_unique_flow_name(): filename = data.get("name") try: - base, ext = os.path.splitext(filename) - if ext is None or len(ext) == 0: - ext = ".cforge" - unique_filename = base + ext - i = 1 - - # Find the first non-clashing filename of the form (i).cforge where i=1,2,3 etc - while os.path.isfile(os.path.join(FLOWS_DIR, unique_filename)): - unique_filename = f"{base}({i}){ext}" - i += 1 - - return jsonify(unique_filename.replace(".cforge", "")) + new_name = _get_unique_flow_name(filename) + return jsonify(new_name) except Exception as e: return jsonify({"error": str(e)}), 404 diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index ba203e372..309f2f9dd 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -1334,6 +1334,7 @@ const App = () => { ml="sm" size="1.625rem" onClick={() => saveFlow()} + bg="#eee" loading={isSaving} disabled={isLoading || isSaving} > diff --git a/chainforge/react-server/src/AreYouSureModal.tsx b/chainforge/react-server/src/AreYouSureModal.tsx index b0ac59cd9..957817526 100644 --- a/chainforge/react-server/src/AreYouSureModal.tsx +++ b/chainforge/react-server/src/AreYouSureModal.tsx @@ -5,6 +5,7 @@ import { useDisclosure } from "@mantine/hooks"; export interface AreYouSureModalProps { title: string; message: string; + color?: string; onConfirm?: () => void; } @@ -14,7 +15,7 @@ export interface AreYouSureModalRef { /** Modal that lets user rename a single value, using a TextInput field. */ const AreYouSureModal = forwardRef( - function AreYouSureModal({ title, message, onConfirm }, ref) { + function AreYouSureModal({ title, message, color, onConfirm }, ref) { const [opened, { open, close }] = useDisclosure(false); const description = message || "Are you sure?"; @@ -37,7 +38,7 @@ const AreYouSureModal = forwardRef( onClose={close} title={title} styles={{ - header: { backgroundColor: "orange", color: "white" }, + header: { backgroundColor: color ?? "orange", color: "white" }, root: { position: "relative", left: "-5%" }, }} > @@ -54,7 +55,7 @@ const AreYouSureModal = forwardRef( >
} - styles={{ title: { justifyContent: "space-between", width: "100%" } }} + styles={{ + title: { justifyContent: "space-between", width: "100%" }, + header: { paddingBottom: "0px" }, + }} >
{ const metakeys = new Set( @@ -98,19 +101,33 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => { export class PromptInfo { prompt: string; - settings: Dict; + settings?: Dict; + label?: string; - constructor(prompt: string, settings: Dict) { + constructor(prompt: string, settings?: Dict, label?: string) { this.prompt = prompt; this.settings = settings; + this.label = label; } } -const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) => +const displayPromptInfos = ( + promptInfos: PromptInfo[], + wideFormat: boolean, + bgColor?: string, +) => promptInfos.map((info, idx) => (
-
{info.prompt}
- {info.settings ? ( +
+ {info.label && ( + + {info.label} +
+
+ )} + {info.prompt} +
+ {info.settings && Object.entries(info.settings).map(([key, val]) => { return (
@@ -120,10 +137,7 @@ const displayPromptInfos = (promptInfos: PromptInfo[], wideFormat: boolean) =>
); - }) - ) : ( - <> - )} + })}
)); @@ -131,12 +145,14 @@ export interface PromptListPopoverProps { promptInfos: PromptInfo[]; onHover: () => void; onClick: () => void; + promptTemplates?: string[] | string; } export const PromptListPopover: React.FC = ({ promptInfos, onHover, onClick, + promptTemplates, }) => { const [opened, { close, open }] = useDisclosure(false); @@ -185,6 +201,29 @@ export const PromptListPopover: React.FC = ({ Preview of generated prompts ({promptInfos.length} total) + {Array.isArray(promptTemplates) && promptTemplates.length > 1 && ( + + + {displayPromptInfos( + promptTemplates.map( + (t, i) => new PromptInfo(t, undefined, `Variant ${i + 1}`), + ), + false, + "#ddf1f8", + )} + + + )} {displayPromptInfos(promptInfos, false)} @@ -195,12 +234,14 @@ export interface PromptListModalProps { promptPreviews: PromptInfo[]; infoModalOpened: boolean; closeInfoModal: () => void; + promptTemplates?: string[] | string; } export const PromptListModal: React.FC = ({ promptPreviews, infoModalOpened, closeInfoModal, + promptTemplates, }) => { return ( = ({ }} > + {Array.isArray(promptTemplates) && promptTemplates.length > 1 && ( + + + {displayPromptInfos( + promptTemplates.map( + (t, i) => new PromptInfo(t, undefined, `Variant ${i + 1}`), + ), + true, + "#ddf1f8", + )} + + + )} {displayPromptInfos(promptPreviews, true)} @@ -1207,6 +1271,39 @@ Soft failing by replacing undefined with empty strings.`, [numGenerationsLastRun, status], ); + const hideStatusIndicator = () => { + if (status !== Status.NONE) setStatus(Status.NONE); + }; + + // Dynamically update the textareas and position of the template hooks + const textAreaRef = useRef(null); + const [hooksY, setHooksY] = useState(138); + const setRef = useCallback( + (elem: HTMLDivElement | HTMLTextAreaElement | null) => { + if (!elem) return; + // To listen for resize events of the textarea, we need to use a ResizeObserver. + // We initialize the ResizeObserver only once, when the 'ref' is first set, and only on the div wrapping textfields. + // NOTE: This won't work on older browsers, but there's no alternative solution. + if (!textAreaRef.current && elem && window.ResizeObserver) { + let past_hooks_y = 138; + const incr = 68 + (node_type === "chat" ? -6 : 0); + const observer = new window.ResizeObserver(() => { + if (!textAreaRef || !textAreaRef.current) return; + const new_hooks_y = textAreaRef.current.clientHeight + incr; + if (past_hooks_y !== new_hooks_y) { + setHooksY(new_hooks_y); + past_hooks_y = new_hooks_y; + } + }); + + observer.observe(elem); + textAreaRef.current = elem; + } + }, + [textAreaRef], + ); + + const deleteVariantConfirmModal = useRef(null); const handleAddPromptVariant = useCallback(() => { // Pushes a new prompt variant, updating the prompts list and duplicating the current shown prompt const prompts = typeof promptText === "string" ? [promptText] : promptText; @@ -1234,12 +1331,20 @@ Soft failing by replacing undefined with empty strings.`, const handleRemovePromptVariant = useCallback(() => { setPromptText((prompts) => { - if (typeof prompts === "string") return prompts; // cannot remove the last one + if (typeof prompts === "string" || prompts.length === 1) return prompts; // cannot remove the last one prompts.splice(idxPromptVariantShown, 1); // remove the indexed variant - setIdxPromptVariantShown(Math.max(0, idxPromptVariantShown - 1)); // goto the previous variant, if possible - return prompts; + const newIdx = Math.max(0, idxPromptVariantShown - 1); + setIdxPromptVariantShown(newIdx); // goto the previous variant, if possible + + if (textAreaRef.current) { + // We have to force an update here since idxPromptVariantShown might've not changed + // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this + textAreaRef.current.value = prompts[newIdx]; + } + + return [...prompts]; }); - }, [idxPromptVariantShown]); + }, [idxPromptVariantShown, textAreaRef]); // Whenever idx of prompt variant changes, we need to refresh the Textarea: useEffect(() => { @@ -1249,37 +1354,92 @@ Soft failing by replacing undefined with empty strings.`, } }, [idxPromptVariantShown]); - const hideStatusIndicator = () => { - if (status !== Status.NONE) setStatus(Status.NONE); - }; + const promptVariantControls = useMemo(() => { + return ( + + {typeof promptText === "string" || promptText.length === 1 ? ( + + + + ) : ( + <> + gotoPromptVariant(-1)} + > + + - // Dynamically update the textareas and position of the template hooks - const textAreaRef = useRef(null); - const [hooksY, setHooksY] = useState(138); - const setRef = useCallback( - (elem: HTMLDivElement | HTMLTextAreaElement | null) => { - if (!elem) return; - // To listen for resize events of the textarea, we need to use a ResizeObserver. - // We initialize the ResizeObserver only once, when the 'ref' is first set, and only on the div wrapping textfields. - // NOTE: This won't work on older browsers, but there's no alternative solution. - if (!textAreaRef.current && elem && window.ResizeObserver) { - let past_hooks_y = 138; - const incr = 68 + (node_type === "chat" ? -6 : 0); - const observer = new window.ResizeObserver(() => { - if (!textAreaRef || !textAreaRef.current) return; - const new_hooks_y = textAreaRef.current.clientHeight + incr; - if (past_hooks_y !== new_hooks_y) { - setHooksY(new_hooks_y); - past_hooks_y = new_hooks_y; - } - }); + + Variant {idxPromptVariantShown + 1} of{" "} + {typeof promptText === "string" ? 1 : promptText.length} + - observer.observe(elem); - textAreaRef.current = elem; - } - }, - [textAreaRef], - ); + gotoPromptVariant(1)} + > + + + + + + + + + + + deleteVariantConfirmModal?.current?.trigger()} + > + + + + + )} + + ); + }, [idxPromptVariantShown, promptText, deleteVariantConfirmModal]); // Add custom context menu options on right-click. // 1. Convert TextFields to Items Node, for convenience. @@ -1322,6 +1482,7 @@ Soft failing by replacing undefined with empty strings.`, , @@ -1333,9 +1494,17 @@ Soft failing by replacing undefined with empty strings.`, /> + {node_type === "chat" ? (
@@ -1385,58 +1554,7 @@ Soft failing by replacing undefined with empty strings.`, /> )} - - {typeof promptText === "string" || promptText.length === 1 ? ( - - ) : ( - <> - gotoPromptVariant(-1)} - > - - - - - Variant {idxPromptVariantShown + 1} of{" "} - {typeof promptText === "string" ? 1 : promptText.length} - - - gotoPromptVariant(1)} - > - - - - - - - - - - - - - - - - )} - + {promptVariantControls} { vars !== undefined && Object.keys(vars).some((k) => k.charAt(0) === "=") ) { - return transformDict( - deepcopy(vars), - (k) => k.charAt(0) === "=", - (k) => k.substring(1), + return StringLookup.concretizeDict( + transformDict( + deepcopy(vars), + (k) => k.charAt(0) === "=", + (k) => k.substring(1), + ), ); } else return {}; }; From bdf0b6a0239abb3bff0389b54611f4465a96f3ce Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Fri, 14 Mar 2025 19:48:00 -0400 Subject: [PATCH 03/35] Autoresize textarea when switching prompt variants. Ensure auto-templating is only used for variants when length exceeds 1. --- chainforge/flask_app.py | 18 ++ chainforge/react-server/src/App.tsx | 189 ++++++++++++++---- chainforge/react-server/src/FlowSidebar.tsx | 1 - chainforge/react-server/src/ItemsNode.tsx | 2 +- chainforge/react-server/src/PromptNode.tsx | 35 +++- .../react-server/src/backend/backend.ts | 2 + 6 files changed, 193 insertions(+), 54 deletions(-) diff --git a/chainforge/flask_app.py b/chainforge/flask_app.py index 8ec97925e..c954fbc70 100644 --- a/chainforge/flask_app.py +++ b/chainforge/flask_app.py @@ -759,6 +759,17 @@ def get_flow(filename): except FileNotFoundError: return jsonify({"error": "Flow not found"}), 404 +@app.route('/api/flowExists/', methods=['GET']) +def get_flow_exists(filename): + """Return the content of a specific flow""" + if not filename.endswith('.cforge'): + filename += '.cforge' + try: + is_file = os.path.isfile(os.path.join(FLOWS_DIR, filename)) + return jsonify({"exists": is_file}) + except FileNotFoundError: + return jsonify({"error": "Flow not found"}), 404 + @app.route('/api/flows/', methods=['DELETE']) def delete_flow(filename): """Delete a flow""" @@ -781,11 +792,18 @@ def save_or_rename_flow(filename): if data.get('flow'): # Save flow (overwriting any existing flow file with the same name) flow_data = data.get('flow') + also_autosave = data.get('alsoAutosave') try: filepath = os.path.join(FLOWS_DIR, filename) with open(filepath, 'w') as f: json.dump(flow_data, f) + + # If we should also autosave, then attempt to override the autosave cache file: + if also_autosave: + autosave_filepath = os.path.join(FLOWS_DIR, '__autosave.cforge') + shutil.copy2(filepath, autosave_filepath) # copy the file to __autosave + return jsonify({"message": f"Flow '{filename}' saved!"}) except FileNotFoundError: return jsonify({"error": f"Could not save flow '{filename}' to local filesystem. See terminal for more details."}), 404 diff --git a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index 309f2f9dd..c41f451c4 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -63,6 +63,7 @@ import { getDefaultModelSettings, } from "./ModelSettingSchemas"; import { v4 as uuid } from "uuid"; +import axios from "axios"; import LZString from "lz-string"; import { EXAMPLEFLOW_1 } from "./example_flows"; @@ -78,7 +79,11 @@ import "lazysizes/plugins/attrchange/ls.attrchange"; import { shallow } from "zustand/shallow"; import useStore, { StoreHandles } from "./store"; import StorageCache, { StringLookup } from "./backend/cache"; -import { APP_IS_RUNNING_LOCALLY, browserTabIsActive } from "./backend/utils"; +import { + APP_IS_RUNNING_LOCALLY, + browserTabIsActive, + FLASK_BASE_URL, +} from "./backend/utils"; import { Dict, JSONCompatible, LLMSpec } from "./backend/typing"; import { ensureUniqueFlowFilename, @@ -113,6 +118,14 @@ const IS_ACCEPTED_BROWSER = // we have access to the Flask backend for, e.g., Python code evaluation. const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY(); +const SAVE_FLOW_FILENAME_TO_BROWSER_CACHE = (name: string) => { + console.log("Saving flow filename", name); + // Save the current filename of the user's working flow + StorageCache.saveToLocalStorage("chainforge-cur-file", { + flowFileName: name, + }); +}; + const selector = (state: StoreHandles) => ({ nodes: state.nodes, edges: state.edges, @@ -266,6 +279,11 @@ const App = () => { const safeSetFlowFileName = useCallback(async (newName: string) => { const uniqueName = await ensureUniqueFlowFilename(newName); setFlowFileName(uniqueName); + SAVE_FLOW_FILENAME_TO_BROWSER_CACHE(uniqueName); + }, []); + const setFlowFileNameAndCache = useCallback((newName: string) => { + setFlowFileName(newName); + SAVE_FLOW_FILENAME_TO_BROWSER_CACHE(newName); }, []); // For 'share' button @@ -387,6 +405,7 @@ const App = () => { flowData?: unknown, saveToLocalFilesystem?: string, hideErrorAlert?: boolean, + onError?: () => void, ) => { if (!rfInstance && !flowData) return; @@ -406,11 +425,16 @@ const App = () => { // Save! const flowFile = `${saveToLocalFilesystem ?? flowFileName}.cforge`; if (saveToLocalFilesystem !== undefined) - return saveFlowToLocalFilesystem(flow_and_cache, flowFile); + return saveFlowToLocalFilesystem( + flow_and_cache, + flowFile, + saveToLocalFilesystem !== "__autosave", + ); // @ts-expect-error The exported RF instance is JSON compatible but TypeScript won't read it as such. else downloadJSON(flow_and_cache, flowFile); }) .catch((err) => { + if (onError) onError(); if (hideErrorAlert) console.error(err); else handleError(err); }); @@ -432,14 +456,18 @@ const App = () => { setShowSaveSuccess(false); startSaveTransition(() => { - // NOTE: This currently only saves the front-end state. Cache files - // are not pulled or overwritten upon loading from localStorage. + // Get current flow state const flow = rf.toObject(); - StorageCache.saveToLocalStorage("chainforge-flow", flow); - // Attempt to save the current state of the back-end state, - // the StorageCache. (This does LZ compression to save space.) - StorageCache.saveToLocalStorage("chainforge-state"); + const saveToLocalStorage = () => { + // This line only saves the front-end state. Cache files + // are not pulled or overwritten upon loading from localStorage. + StorageCache.saveToLocalStorage("chainforge-flow", flow); + + // Attempt to save the current back-end state, + // in the StorageCache. (This does LZ compression to save space.) + StorageCache.saveToLocalStorage("chainforge-state"); + }; const onFlowSaved = () => { console.log("Flow saved!"); @@ -452,10 +480,18 @@ const App = () => { // If running locally, aattempt to save a copy of the flow to the lcoal filesystem, // so it shows up in the list of saved flows. if (IS_RUNNING_LOCALLY) - exportFlow(flow, fileName ?? flowFileName, hideErrorAlert)?.then( - onFlowSaved, - ); - else onFlowSaved(); + // SAVE TO LOCAL FILESYSTEM (only), and if that fails, try to save to localStorage + exportFlow( + flow, + fileName ?? flowFileName, + hideErrorAlert, + saveToLocalStorage, + )?.then(onFlowSaved); + else { + // SAVE TO BROWSER LOCALSTORAGE + saveToLocalStorage(); + onFlowSaved(); + } }); }, [rfInstance, exportFlow, flowFileName], @@ -475,8 +511,13 @@ const App = () => { // Initialize auto-saving const initAutosaving = useCallback( - (rf_inst: ReactFlowInstance) => { - if (autosavingInterval !== undefined) return; // autosaving interval already set + (rf_inst: ReactFlowInstance, reinit?: boolean) => { + if (autosavingInterval !== undefined) { + // Autosaving interval already set + if (reinit) + clearInterval(autosavingInterval); // reinitialize interval, clearing the current one + else return; // do nothing + } console.log("Init autosaving"); // Autosave the flow to localStorage every minute: @@ -539,7 +580,9 @@ const App = () => { StorageCache.clear(); // New flow filename - setFlowFileName(`flow-${Date.now()}`); + const new_filename = `flow-${Date.now()}`; + setFlowFileNameAndCache(new_filename); + if (rfInstance) rfInstance.setViewport({ x: 200, y: 80, zoom: 1 }); }, [setNodes, setEdges, resetLLMColors, rfInstance]); @@ -575,7 +618,7 @@ const App = () => { }, 10); // Start auto-saving, if it's not already enabled - if (rf_inst) initAutosaving(rf_inst); + if (rf_inst) initAutosaving(rf_inst, true); }, [resetLLMColors, setNodes, setEdges, initAutosaving], ); @@ -584,23 +627,28 @@ const App = () => { importState(StorageCache.getAllMatching((key) => key.startsWith("r."))); }, [importState]); - const autosavedFlowExists = useCallback(() => { - return window.localStorage.getItem("chainforge-flow") !== null; - }, []); - const loadFlowFromAutosave = useCallback( - async (rf_inst: ReactFlowInstance) => { - const saved_flow = StorageCache.loadFromLocalStorage( - "chainforge-flow", - false, - ) as Dict; - if (saved_flow) { - StorageCache.loadFromLocalStorage("chainforge-state", true); - importGlobalStateFromCache(); - loadFlow(saved_flow, rf_inst); + // Find the autosaved flow, if it exists, returning + // whether it exists and the location ("browser" or "filesystem") that it exists at. + const autosavedFlowExists = useCallback(async () => { + if (IS_RUNNING_LOCALLY) { + // If running locally, we try to fetch a flow autosaved on the user's local machine first: + try { + const response = await axios.get( + `${FLASK_BASE_URL}api/flowExists/__autosave`, + ); + const autosave_file_exists = response.data.exists as boolean; + if (autosave_file_exists) + return { exists: autosave_file_exists, location: "filesystem" }; + } catch (error) { + // Soft fail, continuing onwards to checking localStorage instead } - }, - [importGlobalStateFromCache, loadFlow], - ); + } + + return { + exists: window.localStorage.getItem("chainforge-flow") !== null, + location: "browser", + }; + }, []); // Import data to the cache stored on the local filesystem (in backend) const handleImportCache = useCallback( @@ -715,6 +763,38 @@ const App = () => { fetchOpenAIEval(evalname).then(importFlowFromJSON).catch(handleError); }; + const loadFlowFromAutosave = useCallback( + async (rf_inst: ReactFlowInstance, fromFilesystem?: boolean) => { + if (fromFilesystem) { + // From local filesystem + // Fetch the flow + const response = await axios.get( + `${FLASK_BASE_URL}api/flows/__autosave`, + ); + + // Attempt to load flow into the UI + try { + importFlowFromJSON(response.data, rf_inst); + console.log("Loaded flow from autosave on local machine."); + } catch (error) { + handleError(error as Error); + } + } else { + // From browser localStorage + const saved_flow = StorageCache.loadFromLocalStorage( + "chainforge-flow", + false, + ) as Dict; + if (saved_flow) { + StorageCache.loadFromLocalStorage("chainforge-state", true); + importGlobalStateFromCache(); + loadFlow(saved_flow, rf_inst); + } + } + }, + [importGlobalStateFromCache, loadFlow, importFlowFromJSON, handleError], + ); + // Load flow from examples modal const onSelectExampleFlow = (name: string, example_category?: string) => { // Trigger the 'loading' modal @@ -723,7 +803,7 @@ const App = () => { // Detect a special category of the example flow, and use the right loader for it: if (example_category === "openai-eval") { importFlowFromOpenAIEval(name); - setFlowFileName(`flow-${Date.now()}`); + setFlowFileNameAndCache(`flow-${Date.now()}`); return; } @@ -732,7 +812,7 @@ const App = () => { .then(function (flowJSON) { // We have the data, import it: importFlowFromJSON(flowJSON); - setFlowFileName(`flow-${Date.now()}`); + setFlowFileNameAndCache(`flow-${Date.now()}`); }) .catch(handleError); }; @@ -871,6 +951,20 @@ const App = () => { err.message, ); }); + + // We also need to fetch the current flowFileName + // Attempt to get the last working filename on component mount + const last_working_flow_filename = StorageCache.loadFromLocalStorage( + "chainforge-cur-file", + ); + if ( + last_working_flow_filename && + typeof last_working_flow_filename === "object" && + "flowFileName" in last_working_flow_filename + ) { + // Use last working flow name + setFlowFileName(last_working_flow_filename.flowFileName as string); + } } else { // Check if there's a shared flow UID in the URL as a GET param // If so, we need to look it up in the database and attempt to load it: @@ -910,14 +1004,19 @@ const App = () => { } // Attempt to load an autosaved flow, if one exists: - if (autosavedFlowExists()) loadFlowFromAutosave(rf_inst); - else { - // Load an interesting default starting flow for new users - importFlowFromJSON(EXAMPLEFLOW_1, rf_inst); - - // Open a welcome pop-up - // openWelcomeModal(); - } + autosavedFlowExists().then(({ exists, location }) => { + if (!exists) { + // Load an interesting default starting flow for new users + importFlowFromJSON(EXAMPLEFLOW_1, rf_inst); + + // Open a welcome pop-up + // openWelcomeModal(); + } else if (location === "browser") { + loadFlowFromAutosave(rf_inst, false); + } else if (location === "filesystem") { + loadFlowFromAutosave(rf_inst, true); + } + }); // Turn off loading wheel setIsLoading(false); @@ -1218,7 +1317,9 @@ const App = () => { { - if (name !== undefined) setFlowFileName(name); + if (name !== undefined) { + setFlowFileNameAndCache(name); + } if (flowData !== undefined) { try { importFlowFromJSON(flowData); @@ -1231,7 +1332,7 @@ const App = () => { }} /> ); - }, [flowFileName, importFlowFromJSON, showAlert]); + }, [flowFileName, importFlowFromJSON, showAlert, setFlowFileNameAndCache]); if (!IS_ACCEPTED_BROWSER) { return ( diff --git a/chainforge/react-server/src/FlowSidebar.tsx b/chainforge/react-server/src/FlowSidebar.tsx index c95170851..e9612b358 100644 --- a/chainforge/react-server/src/FlowSidebar.tsx +++ b/chainforge/react-server/src/FlowSidebar.tsx @@ -24,7 +24,6 @@ import { Tooltip, } from "@mantine/core"; import { FLASK_BASE_URL } from "./backend/utils"; -import { ensureUniqueFlowFilename } from "./backend/backend"; interface FlowFile { name: string; diff --git a/chainforge/react-server/src/ItemsNode.tsx b/chainforge/react-server/src/ItemsNode.tsx index 16e6b3db4..0f0198ed9 100644 --- a/chainforge/react-server/src/ItemsNode.tsx +++ b/chainforge/react-server/src/ItemsNode.tsx @@ -55,7 +55,7 @@ const ItemsNode: React.FC = ({ data, id }) => { const flags = useStore((state) => state.flags); const [contentDiv, setContentDiv] = useState(null); - const [isEditing, setIsEditing] = useState(true); + const [isEditing, setIsEditing] = useState(false); const [csvInput, setCsvInput] = useState(null); const [countText, setCountText] = useState(null); diff --git a/chainforge/react-server/src/PromptNode.tsx b/chainforge/react-server/src/PromptNode.tsx index bbfa49873..a4f533106 100644 --- a/chainforge/react-server/src/PromptNode.tsx +++ b/chainforge/react-server/src/PromptNode.tsx @@ -98,6 +98,14 @@ const bucketChatHistoryInfosByLLM = (chat_hist_infos: ChatHistoryInfo[]) => { }); return chats_by_llm; }; +const getRootPromptFor = ( + promptTexts: string | string[], + varNameForRootTemplate: string, +) => { + if (typeof promptTexts === "string") return promptTexts; + else if (promptTexts.length === 1) return promptTexts[0]; + else return `{${varNameForRootTemplate}}`; +}; export class PromptInfo { prompt: string; @@ -929,7 +937,7 @@ Soft failing by replacing undefined with empty strings.`, // Pull the data to fill in template input variables, if any let pulled_data: Dict<(string | TemplateVarInfo)[]> = {}; - let var_for_prompt_templates: string | undefined; + let var_for_prompt_templates: string; try { // Try to pull inputs pulled_data = pullInputData(templateVars, id); @@ -939,7 +947,7 @@ Soft failing by replacing undefined with empty strings.`, "prompt", Object.keys(pulled_data), ); - if (typeof promptText !== "string") + if (typeof promptText !== "string" && promptText.length > 1) pulled_data[var_for_prompt_templates] = promptText; // this will be filled in when calling queryLLMs } catch (err) { if (showAlert) showAlert((err as Error)?.message ?? err); @@ -1066,9 +1074,7 @@ Soft failing by replacing undefined with empty strings.`, id, _llmItemsCurrState, numGenerations, - typeof prompt_template === "string" - ? prompt_template - : `{${var_for_prompt_templates}}`, // Use special root prompt if there's multiple prompt variants + getRootPromptFor(prompt_template, var_for_prompt_templates), // Use special root prompt if there's multiple prompt variants pulled_data, chat_hist_by_llm, apiKeys || {}, @@ -1277,6 +1283,16 @@ Soft failing by replacing undefined with empty strings.`, // Dynamically update the textareas and position of the template hooks const textAreaRef = useRef(null); + const resizeTextarea = () => { + const textarea = textAreaRef.current; + + if (textarea) { + textarea.style.height = "auto"; // Reset height to shrink if needed + const newHeight = Math.min(textarea.scrollHeight, 600); + textarea.style.height = `${newHeight}px`; + } + }; + const [hooksY, setHooksY] = useState(138); const setRef = useCallback( (elem: HTMLDivElement | HTMLTextAreaElement | null) => { @@ -1325,6 +1341,7 @@ Soft failing by replacing undefined with empty strings.`, Math.min(prompts.length - 1, idxPromptVariantShown + shift), ); // clamp setIdxPromptVariantShown(newIdx); + resizeTextarea(); }, [promptText, idxPromptVariantShown], ); @@ -1340,6 +1357,7 @@ Soft failing by replacing undefined with empty strings.`, // We have to force an update here since idxPromptVariantShown might've not changed // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this textAreaRef.current.value = prompts[newIdx]; + resizeTextarea(); } return [...prompts]; @@ -1351,6 +1369,7 @@ Soft failing by replacing undefined with empty strings.`, if (textAreaRef.current && Array.isArray(promptText)) { // @ts-expect-error Mantine has a 'value' property on Textareas, but TypeScript doesn't know this textAreaRef.current.value = promptText[idxPromptVariantShown]; + resizeTextarea(); } }, [idxPromptVariantShown]); @@ -1399,7 +1418,7 @@ Soft failing by replacing undefined with empty strings.`, gotoPromptVariant(1)} > @@ -1540,9 +1559,9 @@ Soft failing by replacing undefined with empty strings.`, ) : ( + + + + + */} + {/* */}
-
+ + + - + {/* - Suggest New Criteria Based on the Feedback + Suggest New Criteria + + + + + ); diff --git a/chainforge/react-server/src/EvalGen/GradingView.tsx b/chainforge/react-server/src/EvalGen/GradingView.tsx index 1640c48fb..e26927992 100644 --- a/chainforge/react-server/src/EvalGen/GradingView.tsx +++ b/chainforge/react-server/src/EvalGen/GradingView.tsx @@ -110,7 +110,7 @@ const GradingView: React.FC = ({ {/* Go forward to the next response */} - diff --git a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx index 0b1867d42..27d2fae36 100644 --- a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx +++ b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx @@ -33,7 +33,7 @@ import { } from "@tabler/icons-react"; import useStore from "../store"; import { accuracyToColor, cmatrixTextAnnotations } from "../backend/utils"; -import { generateLLMEvaluationCriteria } from "../backend/evalgen/utils"; +import { generateLLMEvaluationCriteria, getPromptForGenEvalCriteriaFromDesc } from "../backend/evalgen/utils"; import { v4 as uuid } from "uuid"; import Plot from "react-plotly.js"; @@ -402,14 +402,7 @@ const PickCriteriaStep: React.FC = ({ generateLLMEvaluationCriteria( "", apiKeys, - `I've described a criteria I want to use to evaluate text. I want you to take the criteria and output a JSON object in the format below. - -CRITERIA: -\`\`\` -${addCriteriaValue} -\`\`\` - -Your response should contain a short title for the criteria ("shortname"), a description of the criteria in 2 sentences ("criteria"), and whether it should be evaluated with "code", or by an "expert" if the criteria is difficult to evaluate ("eval_method"). Your answer should be JSON within a \`\`\`json \`\`\` marker, with the following three fields: "criteria", "shortname", and "eval_method" (code or expert). The "criteria" should expand upon the user's input, the "shortname" should be a very brief title for the criteria, and this list should contain as many evaluation criteria as you can think of. Each evaluation criteria should test a unit concept that should evaluate to "true" in the ideal case. Only output JSON, nothing else.`, // prompt + getPromptForGenEvalCriteriaFromDesc(addCriteriaValue), // prompt null, // system_msg ) .then((evalCrits) => { @@ -430,6 +423,7 @@ Your response should contain a short title for the criteria ("shortname"), a des setIsLoadingCriteria((num) => num - 1); }); }; + const updateCriteria = ( newValue: string, critIdx: number, @@ -449,7 +443,7 @@ Your response should contain a short title for the criteria ("shortname"), a des }; return ( - + Define Evaluation Criteria
@@ -467,10 +461,10 @@ Your response should contain a short title for the criteria ("shortname"), a des setAddCriteriaValue(evt.currentTarget.value)} - placeholder="the response is valid JSON" + placeholder="e.g., the response is valid JSON" mb="lg" pl="sm" pr="sm" @@ -485,6 +479,16 @@ Your response should contain a short title for the criteria ("shortname"), a des /> +
- + {/* @@ -557,7 +561,7 @@ Your response should contain a short title for the criteria ("shortname"), a des Ready to Grade! - + */}
); }; diff --git a/chainforge/react-server/src/EvalGenModal.tsx b/chainforge/react-server/src/EvalGenModal.tsx index 88444038c..9ea550273 100644 --- a/chainforge/react-server/src/EvalGenModal.tsx +++ b/chainforge/react-server/src/EvalGenModal.tsx @@ -428,21 +428,6 @@ const EvalGenModal = forwardRef>( count += grade === grades[respUid][criteriaUID] ? 1 : 0; } return count; - - // if (grades[responseUID]) { - // let count = 0; - // for (const critUid in grades[responseUID]) { - // count += grades[responseUID][critUid] ? 1 : 0; - // } - // // return grade === grades[responseUID][criteriaUID] ? 1 : 0; // this needs to be changed after the grading feature is fully implemented on server side. - // return count; - // // return 10; - // } - - // if (grades[responseUID]) { - // return grade === grades[responseUID][criteriaUID] ? 1 : 0; // this needs to be changed after the grading feature is fully implemented on server side. - // } - // return 0; }; // The EvalGen object responsible for generating, implementing, and filtering candidate implementations diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index 09d9e0d8e..07532cd5d 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -11,7 +11,7 @@ import { EvalFunctionSetReport, EvalCriteriaUID, } from "./typing"; -import { LLMResponse, ResponseUID, QueryProgress, Dict } from "../typing"; +import { LLMResponse, ResponseUID, QueryProgress, Dict, LLMSpec } from "../typing"; import { EventEmitter } from "events"; /** @@ -66,6 +66,7 @@ export default class EvaluationFunctionExecutor { private scores: Map; // Cache function results for each example private resultsCache: Map>; + private llms: { small: string | LLMSpec, large: string | LLMSpec }; private grades: Map; // Grades for all examples private perCriteriaGrades: Dict>; // Grades per criteria private annotations: Dict; // Annotations for each response @@ -77,22 +78,23 @@ export default class EvaluationFunctionExecutor { private backgroundTaskPromise: Promise | null = null; // To keep track of the background task for generating and executing evaluation functions private criteriaQueue: EvalCriteria[] = []; // Queue for new criteria to be processed private processing = false; // To keep track of whether we are currently processing a criteria - private updateGPTCalls: (numGPT4Calls: number, numGPT35Calls: number) => void; + private updateNumLLMCalls: (numStrongModelCalls: number, numWeakModelCalls: number) => void; private logFunction: (logMessage: string) => void; /** * Initializes a new instance of the EvaluationFunctionExecutor class. * * @param evalCriteria The criteria used to generate evaluation functions. Provided/confirmed by the developer. - * @param promptTemplate The prompt demplate for the developer's LLM chain. This is useful for GPT-4 to generate correct evaluation functions. + * @param promptTemplate The prompt template for the developer's LLM chain. This is useful for the LLM to generate correct evaluation functions. * @param examples A set of variable-prompt-response triples that we want the developer to grade (and use for filtering incorrect evaluation functions). * @param existingGrades Optional. A dict in format {uid: grade}, containing existing grades. */ constructor( + genAIModels: { small: string | LLMSpec, large: string | LLMSpec }, promptTemplate: string, examples: LLMResponse[], evalCriteria: EvalCriteria[] = [], - updateGPTCalls: (numGPT4Calls: number, numGPT35Calls: number) => void, + updateNumLLMCalls: (numStrongModelCalls: number, numWeakModelCalls: number) => void, addLog: (log: string) => void, existingGrades?: Record, existingPerCriteriaGrades?: Dict>, @@ -108,6 +110,7 @@ export default class EvaluationFunctionExecutor { this.examples = examples; this.evalCriteria = evalCriteria; this.promptTemplate = promptTemplate; + this.llms = genAIModels; // Set scores and grades to default values of 0 this.scores = new Map(); @@ -141,7 +144,7 @@ export default class EvaluationFunctionExecutor { this.criteriaQueue = []; this.processing = false; - this.updateGPTCalls = updateGPTCalls; + this.updateNumLLMCalls = updateNumLLMCalls; this.logFunction = addLog; } @@ -216,14 +219,15 @@ export default class EvaluationFunctionExecutor { const result = await funcToExecute( evalFunction, + this.llms.small, example, randomPositiveExample, randomNegativeExample, ); - // Update GPT-3.5 call count by 1 if the eval method is expert + // Update weak model call count by 1 if the eval method is expert if (evalFunction.evalCriteria.eval_method === "expert") { - this.updateGPTCalls(0, 1); + this.updateNumLLMCalls(0, 1); } if (onProgress) { @@ -263,8 +267,8 @@ export default class EvaluationFunctionExecutor { emitter, badExample, ); - // Update GPT-4o call count by 1 - this.updateGPTCalls(1, 0); + // Update LLM call count by 1 + this.updateNumLLMCalls(1, 0); console.log(`Generated functions for criteria: ${criteria.shortname}`); console.log( @@ -335,14 +339,15 @@ export default class EvaluationFunctionExecutor { // Run the function on the example and if there's an error, increment skipped const result = await funcToExecute( evalFunction, + this.llms.small, example, randomPositiveExample, randomNegativeExample, ); - // Update GPT-3.5 call count by 1 if the eval method is expert + // Update weak model call count by 1 if the eval method is expert if (evalFunction.evalCriteria.eval_method === "expert") { - this.updateGPTCalls(0, 1); + this.updateNumLLMCalls(0, 1); } funcsExecuted++; @@ -382,8 +387,8 @@ export default class EvaluationFunctionExecutor { emitter, // Pass the EventEmitter instance ).then(() => { emitter.emit("criteriaProcessed"); - // Update GPT-4o call count by 1 - this.updateGPTCalls(1, 0); + // Update LLM call count by 1 + this.updateNumLLMCalls(1, 0); }); }); @@ -438,9 +443,12 @@ export default class EvaluationFunctionExecutor { * @param criteria The new evaluation criteria to be added. */ public addCriteria(criteriaList: EvalCriteria[]): void { + // See if there are criteria to remove + this.evalCriteria = this.evalCriteria.filter((c) => (!criteriaList.includes(c))); + // See if there are new criteria to add for (const criteria of criteriaList) { - if (this.evalCriteria.includes(criteria)) { + if (this.evalCriteria.includes(criteria)) { // criteria already included continue; } @@ -453,14 +461,6 @@ export default class EvaluationFunctionExecutor { this.processNextCriteria(); } } - - // See if there are criteria to remove - for (const criteria of this.evalCriteria) { - if (!criteriaList.includes(criteria)) { - console.log(`Removing criteria: ${criteria.shortname}`); - this.evalCriteria = this.evalCriteria.filter((c) => c !== criteria); - } - } } private async processNextCriteria() { @@ -591,7 +591,7 @@ export default class EvaluationFunctionExecutor { this.grades.set(exampleId, boolHolistic); } - if (perCriteriaGrades !== null) { + if (perCriteriaGrades) { this.perCriteriaGrades[exampleId] = perCriteriaGrades; // If holisticGrade was null, set it based on the perCriteriaGrades---if all criteria in the perCriteriaGrades are true, set the holisticGrade to true, else false @@ -603,7 +603,7 @@ export default class EvaluationFunctionExecutor { } } - if (annotation !== null) { + if (annotation) { this.annotations[exampleId] = annotation; } @@ -662,13 +662,6 @@ export default class EvaluationFunctionExecutor { for (const example of examples) { this.scores.set(example.uid, 0); } - - // Set grades if examples contain them - for (const example of examples) { - if (example.metavars.grade !== undefined) { - this.grades.set(example.uid, example.metavars.grade); - } - } } /** @@ -775,7 +768,7 @@ export default class EvaluationFunctionExecutor { evalFunction.evalCriteria.eval_method === "code" ? execPyFunc : executeLLMEval; - const result = await funcToExecute(evalFunction, example); + const result = await funcToExecute(evalFunction, this.llms.small, example); // Put result in cache if (!this.resultsCache.has(evalFunction)) { @@ -1013,7 +1006,7 @@ export default class EvaluationFunctionExecutor { evalFunction.evalCriteria.eval_method === "code" ? execPyFunc : executeLLMEval; - const result = await funcToExecute(evalFunction, example); + const result = await funcToExecute(evalFunction, this.llms.small, example); // Put result in cache if (!this.resultsCache.has(evalFunction)) { diff --git a/chainforge/react-server/src/backend/evalgen/utils.ts b/chainforge/react-server/src/backend/evalgen/utils.ts index e3e585ee4..6ea69676f 100644 --- a/chainforge/react-server/src/backend/evalgen/utils.ts +++ b/chainforge/react-server/src/backend/evalgen/utils.ts @@ -9,7 +9,7 @@ import { EvalFunctionResult, validEvalCriteriaFormat, } from "./typing"; -import { Dict, LLMResponse } from "../typing"; +import { Dict, LLMResponse, LLMSpec } from "../typing"; import { executejs, executepy, simpleQueryLLM } from "../backend"; import { getVarsAndMetavars, @@ -48,6 +48,7 @@ function extractJSONBlocks(mdText: string): string[] | undefined { */ export async function generateLLMEvaluationCriteria( prompt: string, + llm: string | LLMSpec, apiKeys?: Dict, promptTemplate?: string, // overrides prompt template used systemMsg?: string | null, // overrides default system message, if present. Use null to specify empty. @@ -65,7 +66,7 @@ export async function generateLLMEvaluationCriteria( async function _query() { const result = await simpleQueryLLM( detailedPrompt, // prompt - "gpt-4o", // llm + typeof llm === "string" ? llm : [llm], // llm // spec, // llm systemMsg !== undefined ? systemMsg === null @@ -114,11 +115,23 @@ export async function generateLLMEvaluationCriteria( return retryAsyncFunc(_query, 3); } +export function getPromptForGenEvalCriteriaFromDesc(desc: string) { + return `I've described a criteria I want to use to evaluate text. I want you to take the criteria and output a JSON object in the format below. + +CRITERIA: +\`\`\` +${desc} +\`\`\` + +Your response should contain a short title for the criteria ("shortname"), a description of the criteria in 2 sentences ("criteria"), and whether it should be evaluated with "code", or by an "expert" if the criteria is difficult to evaluate ("eval_method"). Your answer should be JSON within a \`\`\`json \`\`\` marker, with the following three fields: "criteria", "shortname", and "eval_method" (code or expert). The "criteria" should expand upon the user's input, the "shortname" should be a very brief title for the criteria, and this list should contain as many evaluation criteria as you can think of. Each evaluation criteria should test a unit concept that should evaluate to "true" in the ideal case. Only output JSON, nothing else.`; +} + export async function executeLLMEval( evalFunction: EvalFunction, + llm: string | LLMSpec, example: LLMResponse, - positiveExample: LLMResponse, - negativeExample: LLMResponse, + positiveExample?: LLMResponse, + negativeExample?: LLMResponse, ): Promise { // Construct call to an LLM to evaluate the example const evalPrompt = @@ -128,30 +141,25 @@ export async function executeLLMEval( example.responses[0] + "\n```"; - // Sleep a random number of seconds between 1 and 30 - // const sleep = (ms: number) => - // new Promise((resolve) => setTimeout(resolve, ms)); - // await sleep(Math.floor(Math.random() * 30) * 1000); - // Query an LLM as an evaluator let systemMessage = "You are an expert evaluator."; if ( positiveExample && - positiveExample.responses[0] && + positiveExample.responses.length > 0 && negativeExample && - negativeExample.responses[0] + negativeExample.responses.length > 0 ) { systemMessage += - " Please consider the following good example: " + - positiveExample.responses[0] + - " and bad example: " + - negativeExample.responses[0] + - " when making your evaluation."; + " Please consider the following GOOD example: \n" + + llmResponseDataToString(positiveExample.responses[0]) + + "\nand BAD example: \n" + + llmResponseDataToString(negativeExample.responses[0]) + + "\nwhen making your evaluation."; } const result = await simpleQueryLLM( evalPrompt, // prompt - "gpt-3.5-turbo-16k", // llm + typeof llm === "string" ? llm : [llm], // llm systemMessage, // system_msg ); // Get the output @@ -223,9 +231,10 @@ export async function execJSFunc( */ export async function execPyFunc( evalFunction: EvalFunction, + llm: string | LLMSpec, // not used, but provided for consistency with the other exec func signature example: LLMResponse, - positiveExample: LLMResponse, - negativeExample: LLMResponse, + positiveExample?: LLMResponse, + negativeExample?: LLMResponse, ): Promise { try { // We need to replace the function name with "evaluate", which is what is expected by backend: diff --git a/chainforge/react-server/src/text-fields-node.css b/chainforge/react-server/src/text-fields-node.css index 0069509e5..f83c0ea0f 100644 --- a/chainforge/react-server/src/text-fields-node.css +++ b/chainforge/react-server/src/text-fields-node.css @@ -1321,20 +1321,21 @@ th .content-editable-div { .gradeContainer { position: relative; - width: 20px; + overflow: visible; + /* width: 20px; */ } .gradeUpCount { position: absolute; - left: 12px; - top: -5px; + right: 0px; + top: -3px; font-size: x-small; } .gradeDownCount { position: absolute; - left: 13px; - top: 13px; + right: 0px; + bottom: 0px; font-size: x-small; } From 6b2b3cf7abc5f9af1c6e36cf879847129f286959 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Wed, 19 Mar 2025 19:34:29 -0400 Subject: [PATCH 20/35] wip getting executor to work --- .../src/EvalGen/EvalGenWizard.tsx | 102 +++-- .../react-server/src/EvalGen/FeedbackStep.tsx | 2 - .../src/EvalGen/GradeResponsesStep.tsx | 266 ++++++------ .../react-server/src/EvalGen/GradingView.tsx | 7 +- .../src/EvalGen/PickCriteriaStep.tsx | 10 +- chainforge/react-server/src/EvalGenModal.tsx | 2 +- .../src/ResponseRatingToolbar.tsx | 2 +- chainforge/react-server/src/backend/ai.ts | 4 +- .../src/backend/evalgen/executor.ts | 56 ++- .../src/backend/evalgen/oai_utils.ts | 382 +++--------------- .../react-server/src/backend/evalgen/utils.ts | 25 +- 11 files changed, 359 insertions(+), 499 deletions(-) diff --git a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx index d29df0086..93633f45a 100644 --- a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx +++ b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx @@ -11,7 +11,11 @@ import FeedbackStep from "./FeedbackStep"; import PickCriteriaStep from "./PickCriteriaStep"; import ReportCardStep from "./ReportCardStep"; import GradingResponsesStep from "./GradeResponsesStep"; -import { batchResponsesByUID, deepcopy, sampleRandomElements } from "../backend/utils"; +import { + batchResponsesByUID, + deepcopy, + sampleRandomElements, +} from "../backend/utils"; import { getRatingKeyForResponse } from "../ResponseRatingToolbar"; import EvaluationFunctionExecutor from "../backend/evalgen/executor"; import { getAIFeaturesModels } from "../backend/ai"; @@ -39,9 +43,9 @@ const EvalGenWizard: React.FC = ({ const genAIModelNames = useMemo(() => { const models = getAIFeaturesModels(genAIFeaturesProvider); return { - strong: models.large, - weak: models.small, - } + large: models.large, + small: models.small, + }; }, [genAIFeaturesProvider]); // Regroup input responses by batch UID, whenever jsonResponses changes @@ -67,7 +71,9 @@ const EvalGenWizard: React.FC = ({ const [onNextCallback, setOnNextCallback] = useState(() => () => {}); // Per-criteria grades (indexed by uid of response, then uid of criteria) - const [perCriteriaGrades, setPerCriteriaGrades] = useState>>({}); + const [perCriteriaGrades, setPerCriteriaGrades] = useState< + Dict> + >({}); const [annotation, setAnnotation] = useState(undefined); const setPerCriteriaGrade = ( responseUID: string, @@ -78,6 +84,12 @@ const EvalGenWizard: React.FC = ({ if (!grades[responseUID]) grades[responseUID] = {}; grades[responseUID][criteriaUID] = newGrade; updateGlobalRating(responseUID, "perCriteriaGrades", grades[responseUID]); + + // If the EvalGen executor is running, update the per-criteria grade for this sample: + executor?.setGradeForExample( + responseUID, + grades[responseUID]); + return { ...grades }; }); }; @@ -85,16 +97,16 @@ const EvalGenWizard: React.FC = ({ let count = 0; for (const uid in perCriteriaGrades) { const gs = perCriteriaGrades[uid]; - if (Object.values(gs).some(v => (v !== undefined && v !== null))) - count += 1; + if (Object.values(gs).some((v) => v !== undefined && v !== null)) + count += 1; } - return count; + return count; }, [perCriteriaGrades]); const minNumToGrade = useMemo(() => { - return Math.min(10, Math.ceil(batchedResponses.length * 0.5)) + return Math.min(10, Math.ceil(batchedResponses.length * 0.5)); }, [batchedResponses]); const minNumToGradeToStartExecutor = useMemo(() => { - return Math.min(5, Math.ceil(batchedResponses.length * 0.25)) + return Math.min(5, Math.ceil(batchedResponses.length * 0.25)); }, [batchedResponses]); // The EvalGen object responsible for generating, implementing, and filtering candidate implementations @@ -106,6 +118,7 @@ const EvalGenWizard: React.FC = ({ // Logs and state from the EvalGen backend const [logs, setLogs] = useState<{ date: Date; message: string }[]>([]); const [numCallsMade, setNumCallsMade] = useState({ strong: 0, weak: 0 }); + const [execProgress, setExecProgress] = useState(0); // The samples to pass the executor / grading responses features. This will be bounded // by maxNumSamplesForExecutor, instead of the whole dataset. @@ -121,15 +134,35 @@ const EvalGenWizard: React.FC = ({ else return batchedResponses.slice(); }, [batchedResponses]); + // When the user is done per-criteria grading + const handleDonePerCriteriaGrading = useCallback(async () => { + // Await completion of all gen + execution of eval funcs + await executor?.waitForCompletion(); + + // Filtering eval funcs by grades and present results + const filteredFunctions = await executor?.filterEvaluationFunctions(0.25); + console.log("Filtered Functions: ", filteredFunctions); + + // Return selected implementations to caller + // TODO + console.warn(filteredFunctions); + }, [executor]); + // Update executor whenever resps, grades, or criteria change useEffect(() => { - if (criteria.length === 0 || numResponsesGraded < minNumToGradeToStartExecutor) return; + if ( + criteria.length === 0 || + numResponsesGraded < minNumToGradeToStartExecutor + ) + return; if (!executor) { const addLog = (message: string) => { setLogs((prevLogs) => [...prevLogs, { date: new Date(), message }]); }; const ex = new EvaluationFunctionExecutor( + genAIModelNames, + apiKeys, getLikelyPromptTemplateAsContext(samplesForExecutor) ?? "", samplesForExecutor, criteria, @@ -138,24 +171,29 @@ const EvalGenWizard: React.FC = ({ setNumCallsMade((n_calls) => { n_calls.strong += strong; n_calls.weak += weak; - return {...n_calls}; + return { ...n_calls }; }); }, addLog, - undefined, // don't pass any holistic grades at this stage + undefined, // don't pass any holistic grades at this stage perCriteriaGrades, ); setExecutor(ex); - // ex.start((progress) => { - // setExecProgress(progress?.success ?? 0); - // }); + // Start executor process + ex.start((progress) => { + setExecProgress(progress?.success ?? 0); + }); } else if (executor) { // Update criteria in executor - executor.addCriteria(criteria); + executor.updateCriteria(criteria); } - - }, [criteria, samplesForExecutor, numResponsesGraded, minNumToGradeToStartExecutor]); + }, [ + criteria, + samplesForExecutor, + numResponsesGraded, + minNumToGradeToStartExecutor, + ]); const handleNext = useCallback(() => { setActive((current) => Math.min(4, current + 1)); @@ -204,7 +242,11 @@ const EvalGenWizard: React.FC = ({ } // Attempt to generate criteria using an LLM - return await generateLLMEvaluationCriteria(inputPromptTemplate, apiKeys); + return await generateLLMEvaluationCriteria( + inputPromptTemplate, + genAIModelNames.large, + apiKeys, + ); } return ( @@ -258,6 +300,7 @@ const EvalGenWizard: React.FC = ({ genCriteriaFromContext={() => genCriteriaFromContext(batchedResponses) } + genAIModelNames={genAIModelNames} setOnNextCallback={setOnNextCallback} /> )} @@ -266,9 +309,11 @@ const EvalGenWizard: React.FC = ({ = ({ < Back -
diff --git a/chainforge/react-server/src/EvalGen/FeedbackStep.tsx b/chainforge/react-server/src/EvalGen/FeedbackStep.tsx index 6ad4926f2..c86fbef85 100644 --- a/chainforge/react-server/src/EvalGen/FeedbackStep.tsx +++ b/chainforge/react-server/src/EvalGen/FeedbackStep.tsx @@ -43,7 +43,6 @@ const FeedbackStep: React.FC = ({ if (!shownResponse) return null; const key = getRatingKeyForResponse(shownResponse?.uid, "grade"); const g = storeState[key]; - console.log(shownResponse?.uid); if (g) return g[0]; else return null; }, [shownResponse, storeState]); @@ -51,7 +50,6 @@ const FeedbackStep: React.FC = ({ if (!shownResponse) return ""; const key = getRatingKeyForResponse(shownResponse?.uid, "note"); const a = storeState[key]; - console.log(shownResponse?.uid); if (a) return a[0]?.toString(); else return ""; }, [shownResponse, storeState]); diff --git a/chainforge/react-server/src/EvalGen/GradeResponsesStep.tsx b/chainforge/react-server/src/EvalGen/GradeResponsesStep.tsx index de0d9a19a..078437738 100644 --- a/chainforge/react-server/src/EvalGen/GradeResponsesStep.tsx +++ b/chainforge/react-server/src/EvalGen/GradeResponsesStep.tsx @@ -1,16 +1,13 @@ import React, { useCallback, useEffect, useState } from "react"; import { EvalCriteria } from "../backend/evalgen/typing"; -import { Dict, LLMResponse, RatingDict } from "../backend/typing"; +import { Dict, LLMResponse } from "../backend/typing"; import { ActionIcon, Button, Center, - Divider, Flex, Grid, Group, - Popover, - Radio, rem, ScrollArea, Skeleton, @@ -25,14 +22,16 @@ import GradingView from "./GradingView"; import { useDisclosure } from "@mantine/hooks"; import { v4 as uuid } from "uuid"; import { - IconPencil, IconRobot, IconTerminal2, IconThumbDown, IconThumbUp, IconTrash, } from "@tabler/icons-react"; -import { generateLLMEvaluationCriteria, getPromptForGenEvalCriteriaFromDesc } from "../backend/evalgen/utils"; +import { + generateLLMEvaluationCriteria, + getPromptForGenEvalCriteriaFromDesc, +} from "../backend/evalgen/utils"; import useStore from "../store"; import EvaluationFunctionExecutor from "../backend/evalgen/executor"; @@ -45,7 +44,6 @@ const ThumbUpDownButtons = ({ onChangeGrade: (newGrade: boolean | undefined) => void; getGradeCount: (grade: boolean | undefined) => number; }) => { - const true_count = getGradeCount(true); const false_count = getGradeCount(false); @@ -82,7 +80,9 @@ const ThumbUpDownButtons = ({ size="20pt" fill={grade === false ? "pink" : "white"} /> - {false_count > 0 &&
{false_count}
} + {false_count > 0 && ( +
{false_count}
+ )} @@ -230,13 +230,17 @@ interface GradingResponsesStepProps { onPrevious: () => void; executor: EvaluationFunctionExecutor | null; logs: { date: Date; message: string }[]; - genAIModelNames: { strong: string; weak: string }; + genAIModelNames: { large: string; small: string }; numCallsMade: { strong: number; weak: number }; responses: LLMResponse[]; criteria: EvalCriteria[]; setCriteria: React.Dispatch>; - grades: Dict>; // per-criteria grades - setPerCriteriaGrade: (responseUID: string, criteriaUID: string, newGrade: boolean | undefined) => void; + grades: Dict>; // per-criteria grades + setPerCriteriaGrade: ( + responseUID: string, + criteriaUID: string, + newGrade: boolean | undefined, + ) => void; setOnNextCallback: React.Dispatch unknown>>; } @@ -268,10 +272,7 @@ const GradingResponsesStep: React.FC = ({ const getStateValue = (stateId: number) => { return Math.floor(Math.random() * 30 + 6); }; - const getGradeCount = ( - criteriaUID: string, - grade: boolean | undefined, - ) => { + const getGradeCount = (criteriaUID: string, grade: boolean | undefined) => { let count = 0; for (const respUid in grades) { count += grade === grades[respUid][criteriaUID] ? 1 : 0; @@ -346,6 +347,7 @@ const GradingResponsesStep: React.FC = ({ generateLLMEvaluationCriteria( "", + genAIModelNames.large, apiKeys, `I've given some feedback on some text output. Use this feedback to decide on a single new evaluation criteria with a yes/no answer, only if the feedback isn't encompassed by existing criteria. I want you to take the criteria and output a JSON object in the format below. @@ -400,6 +402,7 @@ const GradingResponsesStep: React.FC = ({ // Make async LLM call to expand criteria generateLLMEvaluationCriteria( "", + genAIModelNames.large, apiKeys, getPromptForGenEvalCriteriaFromDesc(desc), // prompt null, // system_msg @@ -437,45 +440,45 @@ const GradingResponsesStep: React.FC = ({ /> - - - LLM Activity - - {/* GPT Call Tally */} - - Executed {numCallsMade.strong} {genAIModelNames.strong} calls and {numCallsMade.weak}{" "} - {genAIModelNames.weak} calls. - - -
{ - if (el) { - el.scrollTop = el.scrollHeight; - } - }} - > - {logs.map((log, index) => ( -
- - {log.date.toLocaleString()} -{" "} - - {log.message} -
- ))} -
+ + + LLM Activity + + {/* GPT Call Tally */} + + Executed {numCallsMade.strong} {genAIModelNames.large} calls and{" "} + {numCallsMade.weak} {genAIModelNames.small} calls. + +
{ + if (el) { + el.scrollTop = el.scrollHeight; + } + }} + > + {logs.map((log, index) => ( +
+ + {log.date.toLocaleString()} -{" "} + + {log.message} +
+ ))} +
+
{/* Progress bar */} {/* @@ -488,62 +491,74 @@ const GradingResponsesStep: React.FC = ({ */} - +
Per-criteria grading
- -
-
- {criteria.map((e) => ( - handleChangeCriteria(newCrit, e.uid)} - onDelete={() => handleDeleteCriteria(e.uid)} - grade={ - (shownResponse && grades[shownResponse.uid]) ? grades[shownResponse.uid][e.uid] : undefined - } - getGradeCount={(grade) => { - return shownResponse - ? getGradeCount( - // shownResponse.uid, - e.uid, - grade, - ) - : 0; - }} - onChangeGrade={(newGrade) => { - if (shownResponse) - setPerCriteriaGrade(shownResponse.uid, e.uid, newGrade); - }} - initiallyOpen={true} - getStateValue={(stateId) => getStateValue(stateId)} - /> - ))} - {isLoadingCriteria > 0 ? ( - Array.from( - { length: isLoadingCriteria }, - (v: unknown, idx: number) => ( - - ), - ) - ) : ( - <> - )} -
+
+
+ {criteria.map((e) => ( + handleChangeCriteria(newCrit, e.uid)} + onDelete={() => handleDeleteCriteria(e.uid)} + grade={ + shownResponse && grades[shownResponse.uid] + ? grades[shownResponse.uid][e.uid] + : undefined + } + getGradeCount={(grade) => { + return shownResponse + ? getGradeCount( + // shownResponse.uid, + e.uid, + grade, + ) + : 0; + }} + onChangeGrade={(newGrade) => { + if (shownResponse) + setPerCriteriaGrade(shownResponse.uid, e.uid, newGrade); + }} + initiallyOpen={true} + getStateValue={(stateId) => getStateValue(stateId)} + /> + ))} + {isLoadingCriteria > 0 ? ( + Array.from( + { length: isLoadingCriteria }, + (v: unknown, idx: number) => ( + + ), + ) + ) : ( + <> + )} +
-
- {/* +
+ {/* */}
- - - - {/* + {/* Suggest New Criteria @@ -635,24 +647,32 @@ const GradingResponsesStep: React.FC<GradingResponsesStepProps> = ({ </Button> </Group> </Radio.Group> */} - {/* </Stack> */} - </div> - - <Textarea value={newCriteriaDesc} onChange={(e) => setNewCriteriaDesc(e.currentTarget.value)} label="Add new criteria:" placeholder="Describe the criteria to add." ml="md" mr="md"></Textarea> - <Group position="right" mr="md" mt="sm"> - <Button - color="green" - variant="filled" - disabled={newCriteriaDesc?.trim().length === 0 || isLoadingCriteria > 0} - onClick={() => { - addCriteria(newCriteriaDesc); - setNewCriteriaDesc(""); - }} - > - + Add criteria - </Button> - </Group> + {/* </Stack> */} + </div> + <Textarea + value={newCriteriaDesc} + onChange={(e) => setNewCriteriaDesc(e.currentTarget.value)} + label="Add new criteria:" + placeholder="Describe the criteria to add." + ml="md" + mr="md" + ></Textarea> + <Group position="right" mr="md" mt="sm"> + <Button + color="green" + variant="filled" + disabled={ + newCriteriaDesc?.trim().length === 0 || isLoadingCriteria > 0 + } + onClick={() => { + addCriteria(newCriteriaDesc); + setNewCriteriaDesc(""); + }} + > + + Add criteria + </Button> + </Group> </ScrollArea> </Grid.Col> </Grid> diff --git a/chainforge/react-server/src/EvalGen/GradingView.tsx b/chainforge/react-server/src/EvalGen/GradingView.tsx index e26927992..41f3767f4 100644 --- a/chainforge/react-server/src/EvalGen/GradingView.tsx +++ b/chainforge/react-server/src/EvalGen/GradingView.tsx @@ -110,7 +110,12 @@ const GradingView: React.FC<GradingViewProps> = ({ {/* Go forward to the next response */} <Tooltip label="To next response" withArrow> - <Button variant="white" color="dark" bg="transparent" onClick={gotoNextResponse}> + <Button + variant="white" + color="dark" + bg="transparent" + onClick={gotoNextResponse} + > <IconChevronRight /> </Button> </Tooltip> diff --git a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx index 27d2fae36..000aadb4b 100644 --- a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx +++ b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx @@ -33,7 +33,10 @@ import { } from "@tabler/icons-react"; import useStore from "../store"; import { accuracyToColor, cmatrixTextAnnotations } from "../backend/utils"; -import { generateLLMEvaluationCriteria, getPromptForGenEvalCriteriaFromDesc } from "../backend/evalgen/utils"; +import { + generateLLMEvaluationCriteria, + getPromptForGenEvalCriteriaFromDesc, +} from "../backend/evalgen/utils"; import { v4 as uuid } from "uuid"; import Plot from "react-plotly.js"; @@ -44,6 +47,7 @@ interface PickCriteriaStepProps { setCriteria: React.Dispatch<React.SetStateAction<EvalCriteria[]>>; genCriteriaFromContext: () => Promise<EvalCriteria[] | undefined>; setOnNextCallback: React.Dispatch<React.SetStateAction<() => unknown>>; + genAIModelNames: { large: string; small: string }; } interface CriteriaCardProps { @@ -375,6 +379,7 @@ const PickCriteriaStep: React.FC<PickCriteriaStepProps> = ({ criteria, setCriteria, genCriteriaFromContext, + genAIModelNames, }) => { // State for criteria cards const [addCriteriaValue, setAddCriteriaValue] = useState(""); @@ -401,6 +406,7 @@ const PickCriteriaStep: React.FC<PickCriteriaStepProps> = ({ // Make async LLM call to expand criteria generateLLMEvaluationCriteria( "", + genAIModelNames.large, apiKeys, getPromptForGenEvalCriteriaFromDesc(addCriteriaValue), // prompt null, // system_msg @@ -423,7 +429,7 @@ const PickCriteriaStep: React.FC<PickCriteriaStepProps> = ({ setIsLoadingCriteria((num) => num - 1); }); }; - + const updateCriteria = ( newValue: string, critIdx: number, diff --git a/chainforge/react-server/src/EvalGenModal.tsx b/chainforge/react-server/src/EvalGenModal.tsx index 9ea550273..b35bd7d38 100644 --- a/chainforge/react-server/src/EvalGenModal.tsx +++ b/chainforge/react-server/src/EvalGenModal.tsx @@ -499,7 +499,7 @@ const EvalGenModal = forwardRef<EvalGenModalRef, NonNullable<unknown>>( // }); } else if (executor) { // Update criteria in executor - executor.addCriteria(criteria); + executor.updateCriteria(criteria); } updateCriteriaForDisplay(); diff --git a/chainforge/react-server/src/ResponseRatingToolbar.tsx b/chainforge/react-server/src/ResponseRatingToolbar.tsx index 8e40e6e91..f07d77839 100644 --- a/chainforge/react-server/src/ResponseRatingToolbar.tsx +++ b/chainforge/react-server/src/ResponseRatingToolbar.tsx @@ -123,7 +123,7 @@ const ResponseRatingToolbar: React.FC<ResponseRatingToolbarProps> = ({ // Override the text in the internal textarea whenever upstream annotation changes. useEffect(() => { - setNoteText(note !== undefined ? note.toString() : ""); + setNoteText(note != null ? note.toString() : ""); }, [note]); // The label for the pop-up comment box. diff --git a/chainforge/react-server/src/backend/ai.ts b/chainforge/react-server/src/backend/ai.ts index 389866441..cb89e4918 100644 --- a/chainforge/react-server/src/backend/ai.ts +++ b/chainforge/react-server/src/backend/ai.ts @@ -25,8 +25,8 @@ export type Row = string; const AIFeaturesLLMs = [ { provider: "OpenAI", - small: { value: "gpt-4o", label: "OpenAI GPT4o" }, - large: { value: "gpt-4", label: "OpenAI GPT4" }, + small: { value: "gpt-4o-mini", label: "OpenAI GPT4o-mini" }, + large: { value: "gpt-4o", label: "OpenAI GPT4o" }, }, { provider: "Bedrock", diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index 07532cd5d..7f429ac3e 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -11,7 +11,13 @@ import { EvalFunctionSetReport, EvalCriteriaUID, } from "./typing"; -import { LLMResponse, ResponseUID, QueryProgress, Dict, LLMSpec } from "../typing"; +import { + LLMResponse, + ResponseUID, + QueryProgress, + Dict, + LLMSpec, +} from "../typing"; import { EventEmitter } from "events"; /** @@ -66,7 +72,8 @@ export default class EvaluationFunctionExecutor { private scores: Map<ResponseUID, number>; // Cache function results for each example private resultsCache: Map<EvalFunction, Map<ResponseUID, EvalFunctionResult>>; - private llms: { small: string | LLMSpec, large: string | LLMSpec }; + private llms: { small: string | LLMSpec; large: string | LLMSpec }; + private apiKeys: Dict; private grades: Map<ResponseUID, boolean>; // Grades for all examples private perCriteriaGrades: Dict<Dict<boolean | undefined>>; // Grades per criteria private annotations: Dict<string>; // Annotations for each response @@ -78,7 +85,11 @@ export default class EvaluationFunctionExecutor { private backgroundTaskPromise: Promise<void> | null = null; // To keep track of the background task for generating and executing evaluation functions private criteriaQueue: EvalCriteria[] = []; // Queue for new criteria to be processed private processing = false; // To keep track of whether we are currently processing a criteria - private updateNumLLMCalls: (numStrongModelCalls: number, numWeakModelCalls: number) => void; + private updateNumLLMCalls: ( + numStrongModelCalls: number, + numWeakModelCalls: number, + ) => void; + private logFunction: (logMessage: string) => void; /** @@ -90,11 +101,15 @@ export default class EvaluationFunctionExecutor { * @param existingGrades Optional. A dict in format {uid: grade}, containing existing grades. */ constructor( - genAIModels: { small: string | LLMSpec, large: string | LLMSpec }, + genAIModels: { small: string | LLMSpec; large: string | LLMSpec }, + apiKeys: Dict, promptTemplate: string, examples: LLMResponse[], evalCriteria: EvalCriteria[] = [], - updateNumLLMCalls: (numStrongModelCalls: number, numWeakModelCalls: number) => void, + updateNumLLMCalls: ( + numStrongModelCalls: number, + numWeakModelCalls: number, + ) => void, addLog: (log: string) => void, existingGrades?: Record<ResponseUID, boolean>, existingPerCriteriaGrades?: Dict<Dict<boolean | undefined>>, @@ -111,6 +126,7 @@ export default class EvaluationFunctionExecutor { this.evalCriteria = evalCriteria; this.promptTemplate = promptTemplate; this.llms = genAIModels; + this.apiKeys = apiKeys; // Set scores and grades to default values of 0 this.scores = new Map<ResponseUID, number>(); @@ -262,10 +278,12 @@ export default class EvaluationFunctionExecutor { await generateFunctionsForCriteria( criteria, + this.llms.large, this.promptTemplate, this.examples[Math.floor(Math.random() * this.examples.length)], emitter, badExample, + this.apiKeys, ); // Update LLM call count by 1 this.updateNumLLMCalls(1, 0); @@ -382,9 +400,12 @@ export default class EvaluationFunctionExecutor { console.log(criteria); generateFunctionsForCriteria( criteria, + this.llms.large, this.promptTemplate, this.examples[Math.floor(Math.random() * this.examples.length)], emitter, // Pass the EventEmitter instance + undefined, + this.apiKeys, ).then(() => { emitter.emit("criteriaProcessed"); // Update LLM call count by 1 @@ -435,20 +456,23 @@ export default class EvaluationFunctionExecutor { } /** - * Adds another evaluation criteria and triggers the generation and execution of evaluation functions for the new criteria. + * Updates the set of evaluation criteria and triggers the generation and execution of evaluation functions for any new criteria. * This method allows the client to add new evaluation criteria after the executor has been initialized. * The new criteria will be processed in parallel with the existing criteria. * The method returns immediately, allowing the client to continue with other tasks. * - * @param criteria The new evaluation criteria to be added. + * @param criteria The new state of the evaluation criteria list. */ - public addCriteria(criteriaList: EvalCriteria[]): void { + public updateCriteria(criteriaList: EvalCriteria[]): void { // See if there are criteria to remove - this.evalCriteria = this.evalCriteria.filter((c) => (!criteriaList.includes(c))); + this.evalCriteria = this.evalCriteria.filter( + (c) => !criteriaList.includes(c), + ); // See if there are new criteria to add for (const criteria of criteriaList) { - if (this.evalCriteria.includes(criteria)) { // criteria already included + if (this.evalCriteria.includes(criteria)) { + // criteria already included continue; } @@ -768,7 +792,11 @@ export default class EvaluationFunctionExecutor { evalFunction.evalCriteria.eval_method === "code" ? execPyFunc : executeLLMEval; - const result = await funcToExecute(evalFunction, this.llms.small, example); + const result = await funcToExecute( + evalFunction, + this.llms.small, + example, + ); // Put result in cache if (!this.resultsCache.has(evalFunction)) { @@ -1006,7 +1034,11 @@ export default class EvaluationFunctionExecutor { evalFunction.evalCriteria.eval_method === "code" ? execPyFunc : executeLLMEval; - const result = await funcToExecute(evalFunction, this.llms.small, example); + const result = await funcToExecute( + evalFunction, + this.llms.small, + example, + ); // Put result in cache if (!this.resultsCache.has(evalFunction)) { diff --git a/chainforge/react-server/src/backend/evalgen/oai_utils.ts b/chainforge/react-server/src/backend/evalgen/oai_utils.ts index b9f23ce18..1f8634f32 100644 --- a/chainforge/react-server/src/backend/evalgen/oai_utils.ts +++ b/chainforge/react-server/src/backend/evalgen/oai_utils.ts @@ -1,339 +1,81 @@ // import { env as process_env } from "process"; import { EventEmitter } from "events"; // import { AzureKeyCredential, OpenAIClient } from "@azure/openai"; -import { get_openai_api_key } from "../utils"; -type ContentType = "criteria" | "python_fn" | "llm_eval"; +import { llmResponseDataToString } from "../utils"; +import { simpleQueryLLM } from "../backend"; +import { Dict, LLMSpec } from "../typing"; +import { extractMdBlocks } from "./utils"; +type ContentType = "python_fn" | "llm_eval"; -export class OpenAIStreamer extends EventEmitter { - private buffer = ""; - private isJsonContentStarted = false; - private isPythonContentStarted = false; - private pythonBlockBuffer = ""; - // private client; - private openai_api_key; +export class EvalGenAssertionEmitter extends EventEmitter { + private apiKeys: Dict | undefined; - constructor() { + constructor(apiKeys?: Dict) { super(); - - const OPENAI_API_KEY = get_openai_api_key(); - this.openai_api_key = OPENAI_API_KEY; - - // this.client = new OpenAIClient( - // process?.env?.AZURE_OPENAI_ENDPOINT ?? AZURE_OPENAI_ENDPOINT ?? "", - // new AzureKeyCredential( - // process?.env?.AZURE_OPENAI_KEY ?? AZURE_OPENAI_KEY ?? "", - // ), - // ); - - // this.client = new OpenAIApi(configuration); - } - - private buildMessages(prompt: string): any[] { - return [ - { - content: - "You are an expert Python programmer and helping me write assertions for my LLM pipeline. An LLM pipeline accepts an example and prompt template, fills the template's placeholders with the example, and generates a response.", - role: "system", - }, - { role: "user", content: prompt }, - ]; - } - - private resetBuffer(): void { - this.buffer = ""; - this.isJsonContentStarted = false; - this.isPythonContentStarted = false; - this.pythonBlockBuffer = ""; + this.apiKeys = apiKeys; } async generate( prompt: string, - model: string, - type: ContentType, + llm: string | LLMSpec, + contentType: ContentType, ): Promise<void> { - this.resetBuffer(); - const messages = this.buildMessages(prompt); - - // const events = await this.client.listChatCompletions(model, messages, {}); - - // for await (const event of events) { - // for (const choice of event.choices) { - // const delta = choice.delta?.content; - // if (delta !== undefined) { - // if (type === "criteria") { - // this.processCriteriaDelta(delta); - // } else if (type === "llm_eval") { - // this.processStringDelta(delta); - // } else if (type === "python_fn") { - // this.processFunctionDelta(delta); - // } else { - // throw new Error("Invalid type"); - // } - // } - // } - // } - - // Used restapi as here: https://stackoverflow.com/questions/76137987/openai-completion-stream-with-node-js-and-express-js - - const streamRes = await fetch( - "https://api.openai.com/v1/chat/completions", - { - method: "POST", - headers: { - Authorization: `Bearer ${this.openai_api_key}`, - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model, - messages, - stream: true, - }), - }, + const emit_prompt = ((p: string) => this.emit("function", p)).bind(this); + + const result = await simpleQueryLLM( + prompt, // prompt + typeof llm === "string" ? llm : [llm], // llm + // spec, // llm + "You are an expert Python programmer and helping me write assertions for my LLM pipeline. An LLM pipeline accepts an example and prompt template, fills the template's placeholders with the example, and generates a response.", // system_msg + this.apiKeys, // API keys (if any) ); - const reader = streamRes.body?.getReader(); - if (!reader) { - console.error("Error initializing reader for OpenAI requests."); - return; - } - - let done = false; - let concenattedJsonStrn = ""; - - while (!done) { - const { value, done: readerDone } = await reader.read(); - done = readerDone; - const buffer = Buffer.from(value as ArrayBuffer); - const textPayload = buffer.toString(); - concenattedJsonStrn += textPayload; - if ( - !concenattedJsonStrn.includes(`data: `) || - !concenattedJsonStrn.includes(`\n\n`) - ) { - continue; - } - const payloads = concenattedJsonStrn.toString().split("\n\n"); - concenattedJsonStrn = ""; - - for (const payload of payloads) { - if (payload.includes("[DONE]")) return; - if (payload.startsWith("data:")) { - try { - const data = JSON.parse(payload.replace("data: ", "")); - const delta: undefined | string = data.choices[0].delta?.content; - if (delta !== undefined) { - if (type === "criteria") { - this.processCriteriaDelta(delta); - } else if (type === "llm_eval") { - this.processStringDelta(delta); - } else if (type === "python_fn") { - this.processFunctionDelta(delta); - } else { - throw new Error("Invalid type"); - } - } - } catch (error) { - console.log(`Error with JSON.parse and ${payload}.\n${error}`); - concenattedJsonStrn += payload; - } - } - } - } - - this.emit("end"); // Signal that streaming is complete - } - - private processCriteriaDelta(delta: string): void { - this.buffer += delta; - if (!this.isJsonContentStarted) { - const startIndex = this.buffer.indexOf("```json\n"); - if (startIndex !== -1) { - this.isJsonContentStarted = true; - this.buffer = this.buffer.substring(startIndex + 8); // Skip the '```json \n' part - } - // Trim the buffer to avoid whitespace at beginning and end - this.buffer = this.buffer.trim(); - } - - if (this.isJsonContentStarted) { - this.tryEmitEvalCriteria(); - } - } - - private tryEmitEvalCriteria(): void { - let braceCount = 0; - let lastIndex = 0; // Track start of the next JSON object - - // Detect and handle the start of an array - if (this.buffer.trim().startsWith("[")) { - this.buffer = this.buffer.trim().substring(1); // Remove the leading '[' - } - - // Remove leading commas if they exist right before a JSON object - this.buffer = this.buffer.replace(/^\s*,\s*/, ""); - - for (let i = 0; i < this.buffer.length; i++) { - const char = this.buffer[i]; - if (char === "{") { - braceCount++; - } else if (char === "}") { - braceCount--; - } - - // When a complete JSON object is detected - if (braceCount === 0 && char === "}") { - const jsonStr = this.buffer.substring(lastIndex, i + 1).trim(); - lastIndex = i + 1; // Update for potential next object - - // Remove any leading comma for the next object - if (this.buffer[lastIndex] === ",") { - lastIndex++; // Skip the comma for the next object - } - - try { - const jsonObj = JSON.parse(jsonStr); - this.emit("evalCriteria", jsonObj); - } catch (error) { - console.error("Error parsing JSON:", error); - } - } - } - - // Keep any incomplete JSON for the next delta - this.buffer = this.buffer.substring(lastIndex).trim(); - } - - private processStringDelta(delta: string): void { - this.buffer += delta; - if (!this.isJsonContentStarted) { - const startIndex = this.buffer.indexOf("```json\n"); - if (startIndex !== -1) { - this.isJsonContentStarted = true; - this.buffer = this.buffer.substring(startIndex + 8); // Skip the '```json\n' part - } - } - - if (this.isJsonContentStarted) { - this.tryEmitStrings(); - } - } - - private tryEmitStrings(): void { - let quoteCount = 0; - let lastIndex = 0; // Track the start of the next string - - // Detect and handle the start of an array - if (this.buffer.startsWith("[")) { - this.buffer = this.buffer.substring(1); // Remove the leading '[' - } - - // Remove leading commas and whitespace that might be right before a JSON string - this.buffer = this.buffer.replace(/^\s*,\s*/, ""); - - for (let i = 0; i < this.buffer.length; i++) { - const char = this.buffer[i]; - - // Toggle quote count on encountering quotes, ignoring escaped quotes - if (char === '"' && (i === 0 || this.buffer[i - 1] !== "\\")) { - quoteCount++; - } - - // When a complete string is detected (every second quote) - if (quoteCount === 2) { - const jsonString = this.buffer.substring(lastIndex, i + 1); // Include the closing quote - lastIndex = i + 1; // Update for the potential next string - - // Remove any leading comma for the next string - if (this.buffer[lastIndex] === ",") { - lastIndex++; // Skip the comma for the next string - } - - quoteCount = 0; // Reset for the next string - - // Extract the string value from JSON - try { - const strValue = JSON.parse(jsonString); - this.emit("function", strValue); - } catch (error) { - console.error("Error parsing JSON string:", error); - } - } - } - - // Keep any incomplete JSON string for the next delta - this.buffer = this.buffer.substring(lastIndex).trim(); - } - - private processFunctionDelta(delta: string): void { - this.buffer += delta; - if (!this.isPythonContentStarted) { - let startIndex = this.buffer.indexOf("```python"); - if (startIndex === -1) startIndex = this.buffer.indexOf("```"); - if (startIndex !== -1) { - this.isPythonContentStarted = true; - this.buffer = this.buffer.substring(startIndex); - } - } else { - const endIndex = this.buffer.indexOf("```", 8); // Look for end marker after the start - if (endIndex !== -1) { - // Extract Python code block - const pythonCode = this.buffer - .replace("```python", "") - .replaceAll("```", "") - .trim(); - this.pythonBlockBuffer += pythonCode; - this.buffer = this.buffer.substring(endIndex + 3); - this.isPythonContentStarted = false; - // Now process the Python code block for functions - this.tryEmitFunctionCriteria(); - } - } - } - - private tryEmitFunctionCriteria(): void { - // Split the buffer into lines - const lines = this.pythonBlockBuffer.split("\n"); - let collecting = false; - let functionBody: string[] = []; - let baseIndentation = 0; - - for (const line of lines) { - if (!collecting) { - // Check if the line is a function definition - if (line.trim().startsWith("def ")) { - collecting = true; - functionBody = [line]; - // Determine the base indentation level - baseIndentation = line.indexOf("def"); - } + if (result.errors && Object.keys(result.errors).length > 0) + throw new Error(Object.values(result.errors as Dict)[0].toString()); + + // Get output (text from LLM response) + const output = llmResponseDataToString(result.responses[0].responses[0]); + console.log("Streamer: LLM said: ", output); // for debuggging + + // Attempt to extract output depending on content type + if (contentType === "llm_eval") { + // Expected output is a ``json block that is just a list of three strings representing the prompts i.e. ["str1", "str2", "str3"] + // Attempt to extract JSON blocks (strings) from output + const json_blocks = extractMdBlocks(output, "json"); + if (json_blocks === undefined || json_blocks.length === 0) + throw new Error( + "EvalGen: Could not parse LLM response into evaluation prompt: No JSON detected in output.", + ); + + // If we passed, this should be a list of strings: + const prompts = json_blocks.flatMap((b) => JSON.parse(b)); + // Verify format: + if (prompts.every((p) => typeof p === "string")) { + // If these are all strings, we are good to go-- + // Emit all the LLM eval prompt candidates in one burst + prompts.forEach(emit_prompt); } else { - // Check if the line returns to the base indentation level or lower, indicating the end of the function - const currentIndentation = line.search(/\S|$/); // Find first non-space character or end of line - if (currentIndentation <= baseIndentation) { - // Emit the collected function body - this.emit("function", functionBody.join("\n")); - functionBody = []; // Reset for the next function - collecting = false; - - // If the current line is another function definition, start collecting again - if (line.trim().startsWith("def ")) { - collecting = true; - functionBody = [line]; - baseIndentation = line.indexOf("def"); - } - } else if (collecting) { - // Continue collecting the function body - functionBody.push(line); - } + console.error( + "Unexpected output type after JSON parsing: At least generated LLM eval prompt is not a string.", + prompts, + ); + throw new Error("Unexpected output type after JSON parsing"); } + } else if (contentType === "python_fn") { + // Expected output has ~3 Python codeblocks within ```python markers + // Attempt to extract code blocks from output + const code_blocks = extractMdBlocks(output, "python"); + if (code_blocks === undefined || code_blocks.length === 0) + throw new Error( + "EvalGen: Could not parse LLM response into Python function: No code detected in output.", + ); + + // If we passed, this should be a list of Python code functions. We assume it is OK, and treat them separately: + code_blocks.forEach(emit_prompt); + } else { + throw new Error("Unknown content type: " + contentType); } - // Check if there's a function body collected at the end of the buffer without returning to the base indentation - if (collecting && functionBody.length > 0) { - this.emit("function", functionBody.join("\n")); - } - - // Clear the buffer after processing - this.pythonBlockBuffer = ""; + this.emit("end"); // Signal that streaming is complete } } diff --git a/chainforge/react-server/src/backend/evalgen/utils.ts b/chainforge/react-server/src/backend/evalgen/utils.ts index 6ea69676f..fdc47d170 100644 --- a/chainforge/react-server/src/backend/evalgen/utils.ts +++ b/chainforge/react-server/src/backend/evalgen/utils.ts @@ -17,19 +17,22 @@ import { retryAsyncFunc, } from "../utils"; import { v4 as uuid } from "uuid"; -import { OpenAIStreamer } from "./oai_utils"; +import { EvalGenAssertionEmitter } from "./oai_utils"; import { buildContextPromptForVarsMetavars, buildGenEvalCodePrompt, } from "../../AiPopover"; /** - * Extracts substrings within "```json" and "```" ticks. Excludes the ticks from return. + * Extracts substrings within "```" and "```" ticks. Excludes the ticks from return. * @param mdText * @returns */ -function extractJSONBlocks(mdText: string): string[] | undefined { - const regex = /```json(.*?)```/gs; +export function extractMdBlocks( + mdText: string, + blockName: string, +): string[] | undefined { + const regex = new RegExp(`\`\`\`${blockName}(.*?)\`\`\``, "gs"); const matches = mdText.match(regex); if (matches) return matches.map((s) => s.replace("```json", "").replace("```", "")); @@ -84,7 +87,7 @@ export async function generateLLMEvaluationCriteria( // console.log("LLM said: ", output); // for debuggging // Attempt to extract JSON blocks (strings) from input - const json_blocks = extractJSONBlocks(output); + const json_blocks = extractMdBlocks(output, "json"); if (json_blocks === undefined || json_blocks.length === 0) throw new Error( "EvalGen: Could not parse LLM response into evaluation critera: No JSON detected in output.", @@ -138,7 +141,7 @@ export async function executeLLMEval( "Evaluate the text below according to this criteria: " + evalFunction.code + ' Only return "yes" or "no", nothing else.\n\n```\n' + - example.responses[0] + + llmResponseDataToString(example.responses[0]) + "\n```"; // Query an LLM as an evaluator @@ -286,10 +289,12 @@ export async function execPyFunc( export async function generateFunctionsForCriteria( criteria: EvalCriteria, + llm: string | LLMSpec, promptTemplate: string, example: LLMResponse, emitter: EventEmitter, badExample?: LLMResponse, + apiKeys?: Dict, ): Promise<void> { const functionGenPrompt = buildFunctionGenPrompt( criteria, @@ -300,7 +305,7 @@ export async function generateFunctionsForCriteria( console.log("Function generation prompt:", functionGenPrompt); try { - const streamer = new OpenAIStreamer(); + const streamer = new EvalGenAssertionEmitter(apiKeys); streamer.on("function", (functionDefinition: string) => { processAndEmitFunction(criteria, functionDefinition, emitter); @@ -308,7 +313,7 @@ export async function generateFunctionsForCriteria( const modelType = criteria.eval_method === "expert" ? "llm_eval" : "python_fn"; - await streamer.generate(functionGenPrompt, "gpt-4o", modelType); + await streamer.generate(functionGenPrompt, llm, modelType); } catch (error) { console.error("Error generating function for criteria:", error); throw new Error( @@ -328,7 +333,7 @@ function buildFunctionGenPrompt( badExampleSection = ` Here is an example response that DOES NOT meet the criteria: \`\`\` - ${badExample.responses[0]} + ${llmResponseDataToString(badExample.responses[0])} \`\`\` `; } @@ -343,7 +348,7 @@ function buildFunctionGenPrompt( ${badExampleSection} Create 3 implementations of the criterion. ${buildGenEvalCodePrompt("python", buildContextPromptForVarsMetavars(getVarsAndMetavars([example])), criteria.criteria, true)} - Be creative in your implementations. Our goal is to explore diverse approaches to evaluate LLM responses effectively. Try to avoid using third-party libraries for code-based evaluation methods. Include the full implementation of each function. Each function should return only True or False.`; + Be creative in your implementations. Our goal is to explore diverse approaches to evaluate LLM responses effectively. Try to avoid using third-party libraries for code-based evaluation methods. Include the full implementation of each function in separate "\`\`\`python" blocks. Each function should return only True or False.`; return prompt; } From bc453a973c7872254590b2c6f03238e17db9ee9a Mon Sep 17 00:00:00 2001 From: Ian Arawjo <fatso784@gmail.com> Date: Sat, 22 Mar 2025 09:55:50 -0400 Subject: [PATCH 21/35] wip --- chainforge/react-server/src/EvalGen/EvalGenWizard.tsx | 4 +--- chainforge/react-server/src/backend/evalgen/oai_utils.ts | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx index 93633f45a..fc2684d12 100644 --- a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx +++ b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx @@ -86,9 +86,7 @@ const EvalGenWizard: React.FC<EvalGenWizardProps> = ({ updateGlobalRating(responseUID, "perCriteriaGrades", grades[responseUID]); // If the EvalGen executor is running, update the per-criteria grade for this sample: - executor?.setGradeForExample( - responseUID, - grades[responseUID]); + executor?.setGradeForExample(responseUID, grades[responseUID]); return { ...grades }; }); diff --git a/chainforge/react-server/src/backend/evalgen/oai_utils.ts b/chainforge/react-server/src/backend/evalgen/oai_utils.ts index 1f8634f32..840789119 100644 --- a/chainforge/react-server/src/backend/evalgen/oai_utils.ts +++ b/chainforge/react-server/src/backend/evalgen/oai_utils.ts @@ -20,7 +20,7 @@ export class EvalGenAssertionEmitter extends EventEmitter { llm: string | LLMSpec, contentType: ContentType, ): Promise<void> { - const emit_prompt = ((p: string) => this.emit("function", p)).bind(this); + const emit_prompt = (p: string) => this.emit("function", p); const result = await simpleQueryLLM( prompt, // prompt From 2a0d6c42b04911ab49677dd549d915cee0aa5219 Mon Sep 17 00:00:00 2001 From: Ian Arawjo <fatso784@gmail.com> Date: Tue, 25 Mar 2025 23:37:04 -0400 Subject: [PATCH 22/35] Fixed bug in executor (whew) --- .../react-server/src/backend/backend.ts | 80 ++++++--- .../src/backend/evalgen/executor.ts | 165 ++---------------- .../react-server/src/backend/evalgen/utils.ts | 69 ++++++-- chainforge/react-server/src/backend/utils.ts | 49 ++++++ 4 files changed, 180 insertions(+), 183 deletions(-) diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts index 988c837fa..00dd413e7 100644 --- a/chainforge/react-server/src/backend/backend.ts +++ b/chainforge/react-server/src/backend/backend.ts @@ -17,6 +17,7 @@ import { LLMResponseData, PromptVarType, StringOrHash, + ChatHistory, } from "./typing"; import { LLM, LLMProvider, getEnumName, getProvider } from "./models"; import { @@ -31,6 +32,7 @@ import { llmResponseDataToString, extendArray, extendArrayDict, + stripWrappingQuotes, } from "./utils"; import StorageCache, { StringLookup } from "./cache"; import { PromptPipeline } from "./query"; @@ -1280,41 +1282,45 @@ export async function executepy( * * @param id a unique ID to refer to this information. Used when cache'ing evaluation results. * @param llm the LLM to query (as an LLM specification dict) - * @param root_prompt the prompt template to use as the scoring function. Should include exactly one template var, {input}, where input responses will be put. + * @param root_prompt the prompt template to use as the scoring function. Should include exactly one template var, {__input}, where input responses will be put. * @param response_ids the cache'd response to run on, which must be a unique ID or list of unique IDs of cache'd data * @param api_keys optional. any api keys to set before running the LLM */ export async function evalWithLLM( id: string, - llm: LLMSpec, + llm: string | LLMSpec, root_prompt: string, - response_ids: string | string[], + response_ids: string | string[] | LLMResponse[], api_keys?: Dict, progress_listener?: (progress: { [key: symbol]: any }) => void, cancel_id?: string | number, + system_msg?: string, ): Promise<{ responses?: LLMResponse[]; errors: string[] }> { // Check format of response_ids if (!Array.isArray(response_ids)) response_ids = [response_ids]; - response_ids = response_ids as Array<string>; + if (response_ids.length === 0) return { responses: [], errors: [] }; + + const load_resps_from_cache = typeof response_ids[0] === "string"; + const system_message: ChatHistoryInfo[] | undefined = system_msg + ? [ + { + messages: [{ role: "system", content: system_msg }], + fill_history: {}, + }, + ] + : undefined; if (api_keys !== undefined) set_api_keys(api_keys); // Load all responses with the given ID: let all_evald_responses: LLMResponse[] = []; let all_errors: string[] = []; - for (const cache_id of response_ids) { - const fname = `${cache_id}.json`; - if (!StorageCache.has(fname)) - throw new Error(`Did not find cache file for id ${cache_id}`); - - // Load the raw responses from the cache + clone them all: - const resp_objs = (load_cache_responses(fname) as LLMResponse[]).map((r) => - JSON.parse(JSON.stringify(r)), - ) as LLMResponse[]; - - if (resp_objs.length === 0) continue; - console.log(resp_objs); + const _runOverResponses = async ( + resp_objs: LLMResponse[], + cache_id?: string, + ) => { + console.log("Running LLM evaluator over response objects:", resp_objs); // We need to keep track of the index of each response in the response object. // We can generate var dicts with metadata to store the indices: @@ -1338,16 +1344,16 @@ export async function evalWithLLM( // Now run all inputs through the LLM grader!: const { responses, errors } = await queryLLM( - `eval-${id}-${cache_id}`, + `eval-${id}-${cache_id ?? "provided"}`, [llm], 1, root_prompt, { __input: inputs }, - undefined, + system_message, // if there's a sys_message, we pass it in chat history format undefined, undefined, progress_listener, - false, + !cache_id, // if there's no cache_id, we don't want to cache the responses cancel_id, ); @@ -1371,7 +1377,34 @@ export async function evalWithLLM( } }); - all_evald_responses = all_evald_responses.concat(resp_objs); + return resp_objs; + }; + + // Run over cache'd response data + if (load_resps_from_cache) { + for (const cache_id of response_ids) { + const fname = `${cache_id}.json`; + if (!StorageCache.has(fname)) + throw new Error(`Did not find cache file for id ${cache_id}`); + + // Load the raw responses from the cache + clone them all: + const resp_objs = (load_cache_responses(fname) as LLMResponse[]).map( + (r) => JSON.parse(JSON.stringify(r)), + ) as LLMResponse[]; + if (resp_objs.length === 0) continue; + + const evald_resp_objs = await _runOverResponses( + resp_objs, + cache_id as string, + ); + + all_evald_responses = all_evald_responses.concat(evald_resp_objs); + } + } else { + // Run over provided response objects + const resp_objs = response_ids as LLMResponse[]; + const evald_resp_objs = await _runOverResponses(resp_objs); // no cache + all_evald_responses = all_evald_responses.concat(evald_resp_objs); } // Do additional processing to check if all evaluations are @@ -1381,7 +1414,9 @@ export async function evalWithLLM( if (!resp_obj.eval_res) continue; for (const score of resp_obj.eval_res.items) { if (score !== undefined) - all_eval_res.add(score.toString().trim().toLowerCase()); + all_eval_res.add( + stripWrappingQuotes(score.toString().trim().toLowerCase()), + ); } } @@ -1421,7 +1456,8 @@ export async function evalWithLLM( } // Store the evaluated responses in a new cache json: - StorageCache.store(`${id}.json`, all_evald_responses); + if (load_resps_from_cache) + StorageCache.store(`${id}.json`, all_evald_responses); return { responses: all_evald_responses, errors: all_errors }; } diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index 7f429ac3e..98915d56c 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -194,8 +194,9 @@ export default class EvaluationFunctionExecutor { */ public async waitForCompletion(): Promise<void> { if (this.backgroundTaskPromise) { - await this.backgroundTaskPromise; + const promise = this.backgroundTaskPromise; this.backgroundTaskPromise = null; + await promise; } } @@ -214,6 +215,10 @@ export default class EvaluationFunctionExecutor { const functionExecutionPromises: Promise<any>[] = []; emitter.on("functionGenerated", (evalFunction) => { + this.logFunction( + `Generated a new ${evalFunction.evalCriteria.eval_method === "code" ? "code-based" : "LLM-based"} validator for criteria: ${evalFunction.evalCriteria.shortname}${evalFunction.evalCriteria.eval_method === "expert" ? `, with prompt: ${evalFunction.name}` : ""}. Executing it on ${this.examples.length} examples.`, + ); + const executionPromise = (async () => { this.evalFunctions.push(evalFunction); const executionPromises = this.examples.map(async (example) => { @@ -233,6 +238,7 @@ export default class EvaluationFunctionExecutor { ? execPyFunc : executeLLMEval; + // Run the function on the example and if there's an error, increment skipped const result = await funcToExecute( evalFunction, this.llms.small, @@ -285,6 +291,7 @@ export default class EvaluationFunctionExecutor { badExample, this.apiKeys, ); + // Update LLM call count by 1 this.updateNumLLMCalls(1, 0); @@ -306,152 +313,17 @@ export default class EvaluationFunctionExecutor { public async generateAndExecuteEvaluationFunctions( onProgress?: (progress: QueryProgress) => void, ): Promise<void> { - const emitter = new EventEmitter(); - const numCriteriaToProcess = this.evalCriteria.length; - - // Since we don't know how many implementations the LLM will suggest, - // we must estimate it here so we can use this information to stream - // "progress" updates back to the client: - let funcsExecuted = 0; - const estimatedFuncsToExecute = - numCriteriaToProcess + - this.evalCriteria.length * 5 * this.examples.length; - - let criteriaProcessed = 0; // Track the number of criteria processed - let resolveAllFunctionsGenerated: any; // To be called when all functions are generated and executed - const functionExecutionPromises: Promise<any>[] = []; // Track execution promises for function executions - - // This promise resolves when the 'allFunctionsGenerated' event is emitted - const allFunctionsGeneratedPromise = new Promise<void>((resolve) => { - resolveAllFunctionsGenerated = resolve; - }); - - // Listen for generated functions and execute them as they come in - emitter.on("functionGenerated", (evalFunction) => { - this.logFunction( - `Generated a new ${evalFunction.evalCriteria.eval_method === "code" ? "code-based" : "LLM-based"} validator for criteria: ${evalFunction.evalCriteria.shortname}${evalFunction.evalCriteria.eval_method === "expert" ? `, with prompt: ${evalFunction.name}` : ""}. Executing it on ${this.examples.length} examples.`, - ); - - // Capture the execution promise of each function - const executionPromise = (async () => { - // Add the eval function to the list of functions - this.evalFunctions.push(evalFunction); - - const executionPromises = this.examples.map(async (example) => { - // Get random positive and negative examples for this criteria using the perCriteriaGrades - const criteriaId = evalFunction.evalCriteria.uid; - const randomPositiveExample = this.examples.find( - (example) => - this.perCriteriaGrades[criteriaId]?.[example.uid] === true, - ); - const randomNegativeExample = this.examples.find( - (example) => - this.perCriteriaGrades[criteriaId]?.[example.uid] === false, - ); - - const funcToExecute = - evalFunction.evalCriteria.eval_method === "code" - ? execPyFunc - : executeLLMEval; - - // Run the function on the example and if there's an error, increment skipped - const result = await funcToExecute( - evalFunction, - this.llms.small, - example, - randomPositiveExample, - randomNegativeExample, - ); - - // Update weak model call count by 1 if the eval method is expert - if (evalFunction.evalCriteria.eval_method === "expert") { - this.updateNumLLMCalls(0, 1); - } - - funcsExecuted++; - if (onProgress) { - onProgress({ - success: (100 * funcsExecuted) / estimatedFuncsToExecute, - error: 0, - }); - } - - // Put result in cache - if (!this.resultsCache.has(evalFunction)) { - this.resultsCache.set(evalFunction, new Map()); - } - this.resultsCache.get(evalFunction)?.set(example.uid, result); - - // Update the score if the result is false - if (result === EvalFunctionResult.FAIL) { - this.updateScore(example.uid, evalFunction); - } - }); - - await Promise.all(executionPromises); - // console.log(`Function ${evalFunction.name} executed on all examples.`); - })(); - - functionExecutionPromises.push(executionPromise); - }); - - // Generate functions for each criterion - this.evalCriteria.forEach((criteria) => { - console.log(criteria); - generateFunctionsForCriteria( - criteria, - this.llms.large, - this.promptTemplate, - this.examples[Math.floor(Math.random() * this.examples.length)], - emitter, // Pass the EventEmitter instance - undefined, - this.apiKeys, - ).then(() => { - emitter.emit("criteriaProcessed"); - // Update LLM call count by 1 - this.updateNumLLMCalls(1, 0); - }); - }); - - // Listen for a custom 'criteriaProcessed' event to track when each criterion's functions have been generated - emitter.on("criteriaProcessed", () => { - criteriaProcessed++; - if (criteriaProcessed === this.evalCriteria.length) { - // Ensure all function executions have completed before emitting 'allFunctionsGenerated' - Promise.all(functionExecutionPromises).then(() => { - console.log( - "All evaluation functions have been generated and executed.", - ); - this.logFunction( - "All initially-generated evaluation functions have been generated and executed.", - ); - if (resolveAllFunctionsGenerated) { - resolveAllFunctionsGenerated(); // Resolve the promise when all functions have been generated and executed - } - }); - - if (onProgress) - onProgress({ - success: 100, - error: 0, - }); + // Enter a continuous monitoring loop for new criteria + while (this.backgroundTaskPromise !== null) { + // Check if there are any criteria in the queue to process + if (this.criteriaQueue.length > 0 && !this.processing) { + // Pop a criteria off the queue and process it + // TODO: use worker pool to parallelize this + await this.processNextCriteria(); } - }); - - // Wait for the 'allFunctionsGenerated' event, which now waits for all executions - await allFunctionsGeneratedPromise; - } - public generateNewImplementationsForCriteria( - criteriaID: EvalCriteriaUID, - ): void { - const crit = this.evalCriteria.find((c) => c.uid === criteriaID); - if (!crit) { - throw new Error(`Criteria with ID ${criteriaID} not found.`); - } - this.criteriaQueue.push(crit); - if (!this.processing) { - this.processNextCriteria(); + // Sleep for a short time before checking again (prevents CPU hogging) + await new Promise((resolve) => setTimeout(resolve, 500)); } } @@ -488,11 +360,12 @@ export default class EvaluationFunctionExecutor { } private async processNextCriteria() { - // TODO: use worker pool to parallelize this this.processing = true; while (this.criteriaQueue.length > 0) { const criteria = this.criteriaQueue.shift(); if (criteria) { + // Log the processing of new criteria + this.logFunction(`Processing new criteria: ${criteria.shortname}`); await this.generateAndExecuteFunctionsForCriteria(criteria); } } diff --git a/chainforge/react-server/src/backend/evalgen/utils.ts b/chainforge/react-server/src/backend/evalgen/utils.ts index fdc47d170..7bcfb0718 100644 --- a/chainforge/react-server/src/backend/evalgen/utils.ts +++ b/chainforge/react-server/src/backend/evalgen/utils.ts @@ -10,9 +10,16 @@ import { validEvalCriteriaFormat, } from "./typing"; import { Dict, LLMResponse, LLMSpec } from "../typing"; -import { executejs, executepy, simpleQueryLLM } from "../backend"; +import { + evalWithLLM, + executejs, + executepy, + queryLLM, + simpleQueryLLM, +} from "../backend"; import { getVarsAndMetavars, + hashtagTemplateVars, llmResponseDataToString, retryAsyncFunc, } from "../utils"; @@ -136,42 +143,74 @@ export async function executeLLMEval( positiveExample?: LLMResponse, negativeExample?: LLMResponse, ): Promise<EvalFunctionResult> { + // The LLM eval prompt might include template vars. We need to add + // a hashtag to indicate to ChainForge that it should use the + // fill_history in the provided `example` LLMResponse. + const candidateCriteriaPrompt = hashtagTemplateVars(evalFunction.code); + // Construct call to an LLM to evaluate the example const evalPrompt = "Evaluate the text below according to this criteria: " + - evalFunction.code + + candidateCriteriaPrompt + ' Only return "yes" or "no", nothing else.\n\n```\n' + - llmResponseDataToString(example.responses[0]) + + "{__input}" + "\n```"; // Query an LLM as an evaluator - let systemMessage = "You are an expert evaluator."; + let systemMessage; if ( positiveExample && positiveExample.responses.length > 0 && negativeExample && negativeExample.responses.length > 0 ) { - systemMessage += - " Please consider the following GOOD example: \n" + + systemMessage = + "You are an expert evaluator. Please consider the following GOOD example:\n" + llmResponseDataToString(positiveExample.responses[0]) + - "\nand BAD example: \n" + + "\n\nand BAD example:\n" + llmResponseDataToString(negativeExample.responses[0]) + - "\nwhen making your evaluation."; + "\n\nwhen making your evaluation."; } - const result = await simpleQueryLLM( - evalPrompt, // prompt - typeof llm === "string" ? llm : [llm], // llm - systemMessage, // system_msg + // We use ChainForge's infrastructure for running LLM evaluators + // to score responses based on the criteria. + const { responses, errors } = await evalWithLLM( + Date.now().toString(), // id to refer to this query + llm, // llm + evalPrompt, + [example], // we pass in a single example + undefined, + undefined, + undefined, + systemMessage, ); + + if ( + !responses || + responses.length === 0 || + !responses[0].eval_res || + responses[0].eval_res.items.length === 0 + ) { + console.error( + "Error executing LLM eval candidate:", + errors, + evalFunction.code, + ); + return EvalFunctionResult.SKIP; + } + // Get the output - const output = llmResponseDataToString(result.responses[0].responses[0]); + const output = responses[0].eval_res?.items[0]; + // This should be a boolean... but we need to parse it + const is_pass = + output === true || (typeof output === "string" && output.includes("yes")); + const is_fail = + output === false || (typeof output === "string" && output.includes("no")); // Parse the response to determine the boolean value to return - if (output.toLowerCase().includes("yes")) { + if (is_pass) { return EvalFunctionResult.PASS; - } else if (output.toLowerCase().includes("no")) { + } else if (is_fail) { return EvalFunctionResult.FAIL; } else { // throw new EvalExecutionError( diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 823bda765..e512fdbb1 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -2488,3 +2488,52 @@ export const cmatrixTextAnnotations = ( } return annotations as Partial<Annotations>[]; }; + +/** + * Adds a hashtag prefix to template variables in a string. + * Converts unescaped templates of the form {template} to {#template}. + * Ignores escaped braces like \{ and \}. + * + * @param input - The input string containing templates + * @returns The string with templates converted to hashtagged form + */ +export function hashtagTemplateVars(input: string): string { + let result = ""; + let i = 0; + + while (i < input.length) { + // Check for escaped braces + if ( + input[i] === "\\" && + i + 1 < input.length && + (input[i + 1] === "{" || input[i + 1] === "}") + ) { + // Add the escape character and the brace + result += input[i] + input[i + 1]; + i += 2; + } + // Check for opening brace of a template (that isn't already hashtagged) + else if (input[i] === "{" && i + 1 < input.length && input[i + 1] !== "#") { + // Add the opening brace and the hashtag + result += "{#"; + i++; + } + // Regular character + else { + result += input[i]; + i++; + } + } + + return result; +} + +export function stripWrappingQuotes(s: string): string { + if (s.startsWith('"') && s.endsWith('"')) { + return s.slice(1, -1); + } + if (s.startsWith("'") && s.endsWith("'")) { + return s.slice(1, -1); + } + return s; +} From 8910d78e7e905edcc86872e5bb33c104a66e7f1b Mon Sep 17 00:00:00 2001 From: Ian Arawjo <fatso784@gmail.com> Date: Thu, 27 Mar 2025 22:44:24 -0400 Subject: [PATCH 23/35] wip --- chainforge/react-server/src/backend/evalgen/executor.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index 98915d56c..6846a6c03 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -716,6 +716,7 @@ export default class EvaluationFunctionExecutor { // Calculate alignment for this function based on the graded examples for (const example of gradedExamples) { + // TODO: Change this to use perCriteriaGrades !! const result = gradedResultMap.get(example.uid)?.get(evalFunction); const grade = this.grades.get(example.uid) ? EvalFunctionResult.PASS From 1d207f13cc5c10476922bdf0401c1aa2deb4ddfc Mon Sep 17 00:00:00 2001 From: Ian Arawjo <fatso784@gmail.com> Date: Sun, 30 Mar 2025 16:57:14 -0400 Subject: [PATCH 24/35] Began refactoring for executor to use perCriteriaGrades. Changed 'alignment' to three options: F1, MCC, and Cohen's kappa. --- .../src/EvalGen/EvalGenWizard.tsx | 17 +- .../src/EvalGen/PickCriteriaStep.tsx | 5 +- .../src/EvalGen/ReportCardStep.tsx | 117 ++++++++-- .../src/backend/evalgen/executor.ts | 221 +++++++++++------- .../src/backend/evalgen/typing.ts | 6 +- .../react-server/src/backend/evalgen/utils.ts | 77 ++++++ 6 files changed, 337 insertions(+), 106 deletions(-) diff --git a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx index fc2684d12..07d205b72 100644 --- a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx +++ b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx @@ -1,5 +1,9 @@ import React, { useCallback, useEffect, useMemo, useState } from "react"; -import { EvalCriteria, EvalGenReport } from "../backend/evalgen/typing"; +import { + EvalCriteria, + EvalFunctionSetReport, + EvalGenReport, +} from "../backend/evalgen/typing"; import { Dict, LLMResponse, RatingDict } from "../backend/typing"; import useStore from "../store"; import { escapeBraces } from "../backend/template"; @@ -112,6 +116,8 @@ const EvalGenWizard: React.FC<EvalGenWizardProps> = ({ const [executor, setExecutor] = useState<EvaluationFunctionExecutor | null>( null, ); + const [evalGenReport, setEvalGenReport] = + useState<EvalFunctionSetReport | null>(null); // Logs and state from the EvalGen backend const [logs, setLogs] = useState<{ date: Date; message: string }[]>([]); @@ -138,12 +144,16 @@ const EvalGenWizard: React.FC<EvalGenWizardProps> = ({ await executor?.waitForCompletion(); // Filtering eval funcs by grades and present results - const filteredFunctions = await executor?.filterEvaluationFunctions(0.25); + const filteredFunctions = + (await executor?.filterEvaluationFunctions(0.25)) ?? null; console.log("Filtered Functions: ", filteredFunctions); // Return selected implementations to caller // TODO console.warn(filteredFunctions); + + setActive(4); // Move to the report card step + setEvalGenReport(filteredFunctions); }, [executor]); // Update executor whenever resps, grades, or criteria change @@ -323,9 +333,10 @@ const EvalGenWizard: React.FC<EvalGenWizardProps> = ({ {active === 4 && ( <ReportCardStep onPrevious={handlePrevious} - onComplete={handleComplete} + onFinish={handleComplete} criteria={criteria} setOnNextCallback={setOnNextCallback} + report={evalGenReport} /> )} diff --git a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx index 000aadb4b..e6c775a9a 100644 --- a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx +++ b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx @@ -50,7 +50,7 @@ interface PickCriteriaStepProps { genAIModelNames: { large: string; small: string }; } -interface CriteriaCardProps { +export interface CriteriaCardProps { title: string; description: string; evalMethod: string; @@ -64,7 +64,7 @@ interface CriteriaCardProps { otherFuncs?: EvalFunctionReport[]; } -const CriteriaCard: React.FC<CriteriaCardProps> = function CriteriaCard({ +export const CriteriaCard: React.FC<CriteriaCardProps> = function CriteriaCard({ title, description, evalMethod, @@ -118,6 +118,7 @@ const CriteriaCard: React.FC<CriteriaCardProps> = function CriteriaCard({ /> ); }, [evalFuncReport]); + const reportAccuracyRing = useMemo(() => { if (!evalFuncReport) return undefined; return { diff --git a/chainforge/react-server/src/EvalGen/ReportCardStep.tsx b/chainforge/react-server/src/EvalGen/ReportCardStep.tsx index 43dbf9355..9e1deea20 100644 --- a/chainforge/react-server/src/EvalGen/ReportCardStep.tsx +++ b/chainforge/react-server/src/EvalGen/ReportCardStep.tsx @@ -1,39 +1,116 @@ -import React from "react"; -import { Button, Group, Stack, Text, Title } from "@mantine/core"; -import { EvalCriteria } from "../backend/evalgen/typing"; +import React, { useMemo } from "react"; +import { + Button, + Card, + Flex, + Group, + ScrollArea, + SimpleGrid, + Stack, + Text, +} from "@mantine/core"; +import { EvalCriteria, EvalFunctionSetReport } from "../backend/evalgen/typing"; +import { CriteriaCard } from "./PickCriteriaStep"; interface ReportCardStepProps { - onPrevious: () => void; - onComplete: () => void; criteria: EvalCriteria[]; + report: EvalFunctionSetReport | null; + onFinish: (reports: EvalFunctionSetReport) => void; + onPrevious: () => void; setOnNextCallback: React.Dispatch<React.SetStateAction<() => unknown>>; } const ReportCardStep: React.FC<ReportCardStepProps> = ({ + report, + onFinish, onPrevious, - onComplete, }) => { - // TODO: Calculate alignment scores based on criteria and grading data - const alignmentScores = {}; + const cards = useMemo(() => { + if (!report) return null; + const cards = []; + + // Iterate through selected eval functions and create cards + for (const selectedFunc of report.selectedEvalFunctions) { + const c = selectedFunc.evalCriteria; + // Find corresponding report in allEvalFunctionReports map from criteria to list + const evalFuncReports = report.allEvalFunctionReports.get(c); + const evalFuncReport = evalFuncReports?.find( + (rep) => rep.evalFunction === selectedFunc, + ); + // Get the functions that were not selected for this criteria + const otherFuncs = evalFuncReports?.filter( + (rep) => rep.evalFunction !== selectedFunc, + ); + + cards.push( + <CriteriaCard + reportMode + title={c.shortname} + description={c.criteria} + evalMethod={c.eval_method} + key={c.uid} + evalFuncReport={evalFuncReport} + otherFuncs={otherFuncs} + />, + ); + } + return cards; + }, [report]); return ( <Stack spacing="lg"> - <Title order={3}>Evaluation Results - - Here's how well each evaluation criteria aligns with your grades: + + Chosen Functions and Alignment - {/* TODO: Display alignment scores */} - TODO: Show alignment scores for each criteria + {/* Show coverage and false failure rate numbers */} + + + + + Coverage of Bad Responses + + + {report?.failureCoverage.toFixed(2)}% + + + + + False Failure Rate + + + {report?.falseFailureRate.toFixed(2)}% + + + + - - - - + ); }; diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index 6846a6c03..68c699272 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -1,4 +1,7 @@ import { + calculateCohensKappa, + calculateF1Score, + calculateMCC, execPyFunc, executeLLMEval, generateFunctionsForCriteria, @@ -629,6 +632,114 @@ export default class EvaluationFunctionExecutor { return this.getExampleForId(ungraded[pickIndex]); } + /** + * Given an eval function and the results of that function against the examples (LLM responses), + * computes the alignment statistics between the eval function and the user grades. + * @param evalFunc + * @returns A Report, assuming the the function has been executed over some examples and the user has provided grades for those examples. If there's not enough data, returns undefined. + */ + public computeAlignmentStats( + evalFunc: EvalFunction, + ): EvalFunctionReport | undefined { + // Get the eval function results from the cache + const results = this.resultsCache.get(evalFunc); + if (results === undefined) { + console.warn( + "No cache results found for this eval function. First ensure that the function has been executed over some examples.", + ); + return undefined; + } + + // Get a reference to the perCriteria grades for this eval function + const criteriaId = evalFunc.evalCriteria.uid; + if (!(criteriaId in this.perCriteriaGrades)) { + console.warn( + "No user grades found for this eval criteria. You must first grade some examples against this criteria (thumbs up/down) before we can compute alignment.", + ); + return undefined; + } + // The perCriteriaGrades is a map of ResponseUID to boolean (user grade true/false) + // or undefined (no user grade for that example). + const userGradedExamples = this.perCriteriaGrades[criteriaId]; + + // Now `evalFuncResults` is a Map. + // We can compute the alignment stats across all examples. + // First, create a report for this function + const report: EvalFunctionReport = { + evalFunction: evalFunc, + true_pass: 0, + true_fail: 0, + false_pass: 0, + false_fail: 0, + skipped: 0, + }; + + // Calculate alignment for this function based on the graded examples + Object.entries(userGradedExamples).forEach(([exampleId, grade]) => { + if (grade === undefined) return; // Skip if user provides no grade for this example + const result = results.get(exampleId); + const userGrade = grade + ? EvalFunctionResult.PASS + : EvalFunctionResult.FAIL; + + if (result !== undefined) { + // Handle true positives and true negatives + if (result === userGrade) { + if (result === EvalFunctionResult.PASS) { + report.true_pass++; + } else if (result === EvalFunctionResult.FAIL) { + report.true_fail++; + } + } else { + if (result === EvalFunctionResult.PASS) { + report.false_pass++; + } else if (result === EvalFunctionResult.FAIL) { + report.false_fail++; + } else { + report.skipped++; + } + } + } + }); + + // Calculate alignment in different ways + // NOTE: If a denominator during the calculate is 0, this will set the score to undefined. + report.f1 = calculateF1Score( + report.true_pass, + report.false_pass, + report.false_fail, + ); + report.mcc = calculateMCC( + report.true_pass, + report.true_fail, + report.false_pass, + report.false_fail, + ); + report.cohens_kappa = calculateCohensKappa( + report.true_pass, + report.true_fail, + report.false_pass, + report.false_fail, + ); + + // Calculate failure coverage + const failureCoverage = + report.true_fail + report.false_pass > 0 + ? report.true_fail / (report.true_fail + report.false_pass) + : 0.0; // 0.0 if there are no failures to detect + + // Calculate false failure rate + const falseFailureRate = + report.true_pass + report.false_fail > 0 + ? report.false_fail / (report.true_pass + report.false_fail) + : 0.0; // Default to 0.0 if there are no examples that could trigger false failures + + report.failureCoverage = failureCoverage; + report.falseFailureRate = falseFailureRate; + + return report; + } + /** * Filters out evaluation functions that are incorrect based on the grades provided by the developer. * @@ -682,12 +793,6 @@ export default class EvaluationFunctionExecutor { gradedResultMap.set(example.uid, row); } - const numFailGrades = gradedExamples.filter( - (example) => !this.grades.get(example.uid), - ).length; - const numPassGrades = gradedExamples.filter((example) => - this.grades.get(example.uid), - ).length; const bestEvalFunctions: EvalFunction[] = []; const evalFunctionReport: Map = new Map(); @@ -695,7 +800,7 @@ export default class EvaluationFunctionExecutor { // Iterate through each criteria // For each criteria, select the function with the highest alignment rate for (const criteria of this.evalCriteria) { - let scoredFunctions = []; + const scoredFunctions = []; for (const evalFunction of this.evalFunctions) { // Skip functions that don't match the criteria @@ -704,60 +809,8 @@ export default class EvaluationFunctionExecutor { } // Create a report for this function - const report: EvalFunctionReport = { - evalFunction, - true_pass: 0, - true_fail: 0, - false_pass: 0, - false_fail: 0, - alignment: 0, - skipped: 0, - }; - - // Calculate alignment for this function based on the graded examples - for (const example of gradedExamples) { - // TODO: Change this to use perCriteriaGrades !! - const result = gradedResultMap.get(example.uid)?.get(evalFunction); - const grade = this.grades.get(example.uid) - ? EvalFunctionResult.PASS - : EvalFunctionResult.FAIL; - - if (result !== undefined) { - // Handle true positives and true negatives - if (result === grade) { - if (result === EvalFunctionResult.PASS) { - report.true_pass++; - } else if (result === EvalFunctionResult.FAIL) { - report.true_fail++; - } - } else { - if (result === EvalFunctionResult.PASS) { - report.false_pass++; - } else if (result === EvalFunctionResult.FAIL) { - report.false_fail++; - } else { - report.skipped++; - } - } - } - } - - // Calculate coverage - const failureCoverage = - numFailGrades > 0 - ? report.true_fail / (report.true_fail + report.false_pass) - : 1.0; - - // Calculate false failure rate - const falseFailureRate = - report.false_fail / (report.true_pass + report.false_fail); - - // The alignment is the F1 score of failure coverage and 1 - false failure rate - report.alignment = - numFailGrades > 0 || numPassGrades > 0 - ? (2 * failureCoverage * (1 - falseFailureRate)) / - (failureCoverage + (1 - falseFailureRate)) - : undefined; + const report: EvalFunctionReport | undefined = + this.computeAlignmentStats(evalFunction); // Save the report for this function if (!evalFunctionReport.has(criteria)) { @@ -768,33 +821,41 @@ export default class EvaluationFunctionExecutor { scoredFunctions.push({ evalFunction, - failureCoverage, - falseFailureRate: - report.false_fail / (report.true_pass + report.false_fail), + report, }); } - // See if we can filter out functions with ffr > threshold - const numFunctionsBelowThreshold = scoredFunctions.filter( - (func) => func.falseFailureRate <= falseFailureRateThreshold, - ).length; - if (numFunctionsBelowThreshold > 0) { - // Filter out functions with ffr > threshold - scoredFunctions = scoredFunctions.filter( - (func) => func.falseFailureRate <= falseFailureRateThreshold, - ); - } - - // Save the best function for this criteria - // Maximize failure coverage and minimize false failure rate + // Sort the functions by "alignment" + // Here, we are using MCC as the alignment metric, where higher is better. scoredFunctions.sort((a, b) => { - if (a.failureCoverage === b.failureCoverage) { - return a.falseFailureRate - b.falseFailureRate; + const a_mcc = a.report?.mcc ?? -1; // If undefined, set to -1, which is lowest possible. + const b_mcc = b.report?.mcc ?? -1; + if (a_mcc === b_mcc) { + // If MCC is the same or not present, sort by false failure rate + return ( + (a.report?.falseFailureRate ?? 0) - + (b.report?.falseFailureRate ?? 0) + ); } - return b.failureCoverage - a.failureCoverage; + return b_mcc - a_mcc; // Sort by MCC descending }); + // // See if we can filter out functions with ffr > threshold + // const funcsBelowThreshold = scoredFunctions.filter( + // (func) => func.report?.falseFailureRate !== undefined && func.report?.falseFailureRate <= falseFailureRateThreshold, + // ); + + // // Save the best function for this criteria + // // Maximize failure coverage and minimize false failure rate + // funcsBelowThreshold.sort((a, b) => { + // if (a.report?.failureCoverage === b.report?.failureCoverage) { + // return a.report?.falseFailureRate - b.report?.falseFailureRate; + // } + // return b.failureCoverage - a.failureCoverage; + // }); + if (scoredFunctions.length > 0) { + // The top result is the 'best' / most aligned function bestEvalFunctions.push(scoredFunctions[0].evalFunction); } } diff --git a/chainforge/react-server/src/backend/evalgen/typing.ts b/chainforge/react-server/src/backend/evalgen/typing.ts index e9e6cd24d..ef15d1551 100644 --- a/chainforge/react-server/src/backend/evalgen/typing.ts +++ b/chainforge/react-server/src/backend/evalgen/typing.ts @@ -45,7 +45,11 @@ export interface EvalFunctionReport { false_pass: number; false_fail: number; skipped: number; - alignment?: number; + mcc?: number; // Matthews correlation coefficient, which is a measure of the quality of binary classifications + f1?: number; // F1 score, which is the harmonic mean of precision and recall + cohens_kappa?: number; // Cohen's kappa, which is a measure of inter-rater agreement + failureCoverage?: number; // The percentage of failures that were covered by the eval function + falseFailureRate?: number; // The percentage of false failures } export interface EvalFunctionSetReport { diff --git a/chainforge/react-server/src/backend/evalgen/utils.ts b/chainforge/react-server/src/backend/evalgen/utils.ts index 7bcfb0718..054a50b2c 100644 --- a/chainforge/react-server/src/backend/evalgen/utils.ts +++ b/chainforge/react-server/src/backend/evalgen/utils.ts @@ -421,3 +421,80 @@ function processAndEmitFunction( emitter.emit("functionGenerated", evalFunction); } + +/** + * Calculates the F1 score based on true positives, false positives, and false negatives. + * The F1 score is the harmonic mean of precision and recall. + * Precision = TP / (TP + FP) + * Recall = TP / (TP + FN) + * F1 = 2 * (Precision * Recall) / (Precision + Recall) + * @param true_positive The number of true positive predictions + * @param false_positive The number of false positive predictions + * @param false_negative The number of false negative predictions + * @returns The F1 score, or undefined if precision and recall are both zero + */ +export function calculateF1Score( + true_positive: number, + false_positive: number, + false_negative: number, +): number | undefined { + const precision = true_positive / (true_positive + false_positive); + const recall = true_positive / (true_positive + false_negative); + if (precision + recall === 0) return undefined; // Avoid division by zero + return (2 * precision * recall) / (precision + recall); +} + +/** + * Calculates Matthews correlation coefficient (MCC) based on the confusion matrix values. + * ``` + * MCC = (TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) + * ``` + * @param true_positive The number of true positive predictions + * @param true_negative The number of true negative predictions + * @param false_positive The number of false positive predictions + * @param false_negative The number of false negative predictions + * @returns The Matthews correlation coefficient, or undefined if the denominator is zero + */ +export function calculateMCC( + true_positive: number, + true_negative: number, + false_positive: number, + false_negative: number, +): number | undefined { + const numerator = + true_positive * true_negative - false_positive * false_negative; + const denominator = Math.sqrt( + (true_positive + false_positive) * + (true_positive + false_negative) * + (true_negative + false_positive) * + (true_negative + false_negative), + ); + if (denominator === 0) return undefined; // Avoid division by zero + return numerator / denominator; +} + +/** + * Calculates Cohen's Kappa coefficient based on the confusion matrix values. + * ``` + * Kappa = (Po - Pe) / (1 - Pe) + * ``` + * where Po is the observed agreement and Pe is the expected agreement. + * @param TP The number of true positive predictions + * @param TN The number of true negative predictions + * @param FP The number of false positive predictions + * @param FN The number of false negative predictions + * @returns The Cohen's Kappa coefficient, or undefined if the denominator is zero + */ +export function calculateCohensKappa( + TP: number, + TN: number, + FP: number, + FN: number, +): number | undefined { + const numerator = 2 * (TP * TN - FP * FN); + const denominator = (TP + FP) * (FP + TN) + (TP + FN) * (FN + TN); + if (denominator === 0) { + return undefined; // Avoid division by zero + } + return numerator / denominator; +} From 16fbaa6def356a55e579ff4adeebe495337b114d Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Sun, 11 May 2025 20:55:05 -0400 Subject: [PATCH 25/35] cleanup --- chainforge/react-server/src/ItemsNode.tsx | 17 ++++++----------- chainforge/react-server/src/ResponseBoxes.tsx | 3 +-- chainforge/react-server/src/backend/utils.ts | 17 +++++++++++++++-- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/chainforge/react-server/src/ItemsNode.tsx b/chainforge/react-server/src/ItemsNode.tsx index 7d007e0a3..d1aad2c0c 100644 --- a/chainforge/react-server/src/ItemsNode.tsx +++ b/chainforge/react-server/src/ItemsNode.tsx @@ -12,7 +12,12 @@ import NodeLabel from "./NodeLabelComponent"; import { IconForms, IconTransform } from "@tabler/icons-react"; import { Handle, Node, Position } from "reactflow"; import BaseNode from "./BaseNode"; -import { DebounceRef, genDebounceFunc, processCSV } from "./backend/utils"; +import { + DebounceRef, + genDebounceFunc, + processCSV, + stripWrappingQuotes, +} from "./backend/utils"; import { AIGenReplaceItemsPopover } from "./AiPopover"; import { cleanEscapedBraces, escapeBraces } from "./backend/template"; import { TextFieldsNodeProps } from "./TextFieldsNode"; @@ -22,16 +27,6 @@ const wrapInQuotesIfContainsComma = (str: string) => str.includes(",") ? `"${str}"` : str; export const makeSafeForCSLFormat = (str: string) => wrapInQuotesIfContainsComma(replaceDoubleQuotesWithSingle(str)); -const stripWrappingQuotes = (str: string) => { - if ( - typeof str === "string" && - str.length >= 2 && - str.charAt(0) === '"' && - str.charAt(str.length - 1) === '"' - ) - return str.substring(1, str.length - 1); - else return str; -}; export const prepareItemsNodeData = (text: string) => ({ text, fields: processCSV(text).map(stripWrappingQuotes).map(escapeBraces), diff --git a/chainforge/react-server/src/ResponseBoxes.tsx b/chainforge/react-server/src/ResponseBoxes.tsx index 786e90e89..74962da39 100644 --- a/chainforge/react-server/src/ResponseBoxes.tsx +++ b/chainforge/react-server/src/ResponseBoxes.tsx @@ -20,11 +20,10 @@ import { LLMResponse, LLMResponseData, } from "./backend/typing"; -import StorageCache, { StringLookup } from "./backend/cache"; +import StorageCache, { MediaLookup } from "./backend/cache"; import { IconCheck, IconChecks, IconX } from "@tabler/icons-react"; import { getRatingKeyForResponse } from "./ResponseRatingToolbar"; import useStore from "./store"; -import { MediaLookup } from "./backend/cache"; // Lazy load the response toolbars const ResponseRatingToolbar = lazy(() => import("./ResponseRatingToolbar")); diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 810d7169a..a8dc2013d 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -1485,7 +1485,9 @@ export async function call_ollama_provider( } else { // Text-only models query.prompt = prompt; - query.images = (await imagesToBase64(images ?? [])).map(getBase64DataFromDataURL); + query.images = (await imagesToBase64(images ?? [])).map( + getBase64DataFromDataURL, + ); url += "generate"; } @@ -2764,7 +2766,7 @@ export function dataURLToBlob(dataURL: string): Blob { * Extracts the MIME type from a Data URL. * @param dataUrl The Data URL to extract the MIME type from. * @returns The MIME type as a string, or null if not found. -*/ + */ function getMimeTypeFromDataURL(dataUrl: string): string | null { const match = dataUrl.match(/^data:([^;,]+)[;,]/); return match ? match[1] : null; @@ -2841,3 +2843,14 @@ export const __http_url_to_base64 = (url: string) => { xhr.send(); }); }; + +export const stripWrappingQuotes = (str: string) => { + if ( + typeof str === "string" && + str.length >= 2 && + str.charAt(0) === '"' && + str.charAt(str.length - 1) === '"' + ) + return str.substring(1, str.length - 1); + else return str; +}; From ed86c9999686fea8dcdc21abecf1dcfbc17cd8b2 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Tue, 13 May 2025 11:53:55 -0400 Subject: [PATCH 26/35] wip --- chainforge/react-server/src/backend/evalgen/executor.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index 68c699272..d38fd1f9d 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -812,6 +812,11 @@ export default class EvaluationFunctionExecutor { const report: EvalFunctionReport | undefined = this.computeAlignmentStats(evalFunction); + if (!report) { + console.warn("Could not compute alignment stats for an eval function. Skipping."); + continue; + } + // Save the report for this function if (!evalFunctionReport.has(criteria)) { evalFunctionReport.set(criteria, []); From 65c24e149de4ebcf78bfa12c33e6282406e43b8a Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Tue, 13 May 2025 11:57:56 -0400 Subject: [PATCH 27/35] wip --- .../src/backend/evalgen/executor.ts | 4 +- .../react-server/src/backend/evalgen/utils.ts | 1 + chainforge/react-server/src/backend/utils.ts | 74 +++++++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) diff --git a/chainforge/react-server/src/backend/evalgen/executor.ts b/chainforge/react-server/src/backend/evalgen/executor.ts index d38fd1f9d..7ffc30639 100644 --- a/chainforge/react-server/src/backend/evalgen/executor.ts +++ b/chainforge/react-server/src/backend/evalgen/executor.ts @@ -813,7 +813,9 @@ export default class EvaluationFunctionExecutor { this.computeAlignmentStats(evalFunction); if (!report) { - console.warn("Could not compute alignment stats for an eval function. Skipping."); + console.warn( + "Could not compute alignment stats for an eval function. Skipping.", + ); continue; } diff --git a/chainforge/react-server/src/backend/evalgen/utils.ts b/chainforge/react-server/src/backend/evalgen/utils.ts index c0c3dd6a5..054a50b2c 100644 --- a/chainforge/react-server/src/backend/evalgen/utils.ts +++ b/chainforge/react-server/src/backend/evalgen/utils.ts @@ -19,6 +19,7 @@ import { } from "../backend"; import { getVarsAndMetavars, + hashtagTemplateVars, llmResponseDataToString, retryAsyncFunc, } from "../utils"; diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index a8dc2013d..887f8b22c 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -54,6 +54,7 @@ import { } from "@mirai73/bedrock-fm"; import StorageCache, { StringLookup, MediaLookup } from "./cache"; import Compressor from "compressorjs"; +import { Annotations } from "plotly.js"; // import { Models } from "@mirai73/bedrock-fm/lib/bedrock"; const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:"; @@ -2854,3 +2855,76 @@ export const stripWrappingQuotes = (str: string) => { return str.substring(1, str.length - 1); else return str; }; + +export const accuracyToColor = (acc: number) => { + if (acc > 0.9) return "green"; + else if (acc > 0.7) return "yellow"; + else if (acc > 0.5) return "orange"; + else return "red"; +}; + +export const cmatrixTextAnnotations = ( + x: string[], + y: string[], + z: number[][], +) => { + const annotations = []; + const midVal = Math.max(...z.flat()); + for (let i = 0; i < y.length; i++) { + for (let j = 0; j < x.length; j++) { + annotations.push({ + xref: "x1", + yref: "y1", + x: x[j], + y: y[i], + text: z[i][j].toString(), + font: { + // family: "monospace", + // size: 12, + color: z[i][j] < midVal ? "white" : "black", + }, + showarrow: false, + }); + } + } + return annotations as Partial[]; +}; + +/** + * Adds a hashtag prefix to template variables in a string. + * Converts unescaped templates of the form {template} to {#template}. + * Ignores escaped braces like \{ and \}. + * + * @param input - The input string containing templates + * @returns The string with templates converted to hashtagged form + */ +export function hashtagTemplateVars(input: string): string { + let result = ""; + let i = 0; + + while (i < input.length) { + // Check for escaped braces + if ( + input[i] === "\\" && + i + 1 < input.length && + (input[i + 1] === "{" || input[i + 1] === "}") + ) { + // Add the escape character and the brace + result += input[i] + input[i + 1]; + i += 2; + } + // Check for opening brace of a template (that isn't already hashtagged) + else if (input[i] === "{" && i + 1 < input.length && input[i + 1] !== "#") { + // Add the opening brace and the hashtag + result += "{#"; + i++; + } + // Regular character + else { + result += input[i]; + i++; + } + } + + return result; +} From ef3045b48a25e26a950de474be14796fbf2d82a4 Mon Sep 17 00:00:00 2001 From: Ian Arawjo Date: Tue, 13 May 2025 13:30:58 -0400 Subject: [PATCH 28/35] Bug and typing fixing --- .../src/EvalGen/EvalGenWizard.tsx | 23 +- .../react-server/src/EvalGen/FeedbackStep.tsx | 55 +- .../src/EvalGen/GradeResponsesStep.tsx | 6 +- .../react-server/src/EvalGen/GradingView.tsx | 27 +- .../src/EvalGen/PickCriteriaStep.tsx | 8 +- .../src/EvalGen/ReportCardStep.tsx | 4 +- .../react-server/src/EvalGen/WelcomeStep.tsx | 31 +- chainforge/react-server/src/EvalGenModal.tsx | 3330 ++++++++--------- chainforge/react-server/src/MultiEvalNode.tsx | 23 +- chainforge/react-server/src/ResponseBoxes.tsx | 2 +- .../src/backend/evalgen/executor.ts | 85 +- chainforge/react-server/src/styles.css | 5 +- 12 files changed, 1830 insertions(+), 1769 deletions(-) diff --git a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx index 07d205b72..d2782986a 100644 --- a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx +++ b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx @@ -1,3 +1,23 @@ +/** + * EvalGen 2.0 + * + * Ian Arawjo, Shreya Shankar, J.D. Zamfirescu, Helen Weixu Chen + * + * This file and its directory concerns the front-end to evaluation generator, EvalGen. + * EvalGen supports users in generating eval funcs (here binary assertions) and aligning them with their preferences. + * + * Specifically, the modal lets users: + * - make and refine criteria to grade on (on the left) + * - grade responses (on the right) + * - while in the backend, an LLM is generating candidate assertions and selected the ones most aligned with user grades + * As the user grades responses, they add/refine existing criteria. + * This modal presents a shared interface where criteria can be iterated on *alongside* grading. + * This is because of **criteria drift,** a phenomenon identified observing users in EvalGen 1.0 (unreleased). + * + * An AI (LLM call) can also suggest criteria based on the implicit context (inputs, such as the prompt) + * and user feedback during grading (written feedback about failing outputs whose failure couldn't be classified under the immediate criteria set.) + */ + import React, { useCallback, useEffect, useMemo, useState } from "react"; import { EvalCriteria, @@ -262,11 +282,10 @@ const EvalGenWizard: React.FC = ({ opened={opened} onClose={onClose} // title="EvalGen Wizard" - size="90%" + size="95%" padding="md" // keepMounted // closeOnClickOutside={true} - style={{ position: "relative", left: "-5%" }} styles={{ inner: { padding: "5%", // This creates space around the modal (10% total) diff --git a/chainforge/react-server/src/EvalGen/FeedbackStep.tsx b/chainforge/react-server/src/EvalGen/FeedbackStep.tsx index c86fbef85..2ba0b24c7 100644 --- a/chainforge/react-server/src/EvalGen/FeedbackStep.tsx +++ b/chainforge/react-server/src/EvalGen/FeedbackStep.tsx @@ -8,6 +8,7 @@ import { Text, Textarea, Title, + Tooltip, } from "@mantine/core"; import GradingView from "./GradingView"; import { IconThumbDown, IconThumbUp } from "@tabler/icons-react"; @@ -95,7 +96,7 @@ const FeedbackStep: React.FC = ({ }, [shownResponseIdx, responses]); return ( - + Provide Feedback on Some Model Outputs = ({ gotoPrevResponse={prevResponse} /> - - - + + + + + + +
- What's the reason for your score? + What's the reason for your grade? Explain why: