diff --git a/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx new file mode 100644 index 000000000..aef7237ce --- /dev/null +++ b/chainforge/react-server/src/EvalGen/EvalGenWizard.tsx @@ -0,0 +1,435 @@ +/** + * EvalGen 2.0 + * + * Ian Arawjo, Shreya Shankar, J.D. Zamfirescu, Helen Weixu Chen + * + * This file and its directory concerns the front-end to evaluation generator, EvalGen. + * EvalGen supports users in generating eval funcs (here binary assertions) and aligning them with their preferences. + * + * Specifically, the modal lets users: + * - make and refine criteria to grade on (on the left) + * - grade responses (on the right) + * - while in the backend, an LLM is generating candidate assertions and selected the ones most aligned with user grades + * As the user grades responses, they add/refine existing criteria. + * This modal presents a shared interface where criteria can be iterated on *alongside* grading. + * This is because of **criteria drift,** a phenomenon identified observing users in EvalGen 1.0 (unreleased). + * + * An AI (LLM call) can also suggest criteria based on the implicit context (inputs, such as the prompt) + * and user feedback during grading (written feedback about failing outputs whose failure couldn't be classified under the immediate criteria set.) + */ + +import React, { useCallback, useEffect, useMemo, useState } from "react"; +import { + EvalCriteria, + EvalFunctionSetReport, + EvalGenReport, +} from "../backend/evalgen/typing"; +import { Dict, LLMResponse, RatingDict } from "../backend/typing"; +import useStore from "../store"; +import { escapeBraces } from "../backend/template"; +import StorageCache, { StringLookup } from "../backend/cache"; +import { generateLLMEvaluationCriteria } from "../backend/evalgen/utils"; +import { Button, Flex, Modal, Stepper } from "@mantine/core"; +import WelcomeStep from "./WelcomeStep"; +import FeedbackStep from "./FeedbackStep"; +import PickCriteriaStep from "./PickCriteriaStep"; +import ReportCardStep from "./ReportCardStep"; +import GradingResponsesStep from "./GradeResponsesStep"; +import { + batchResponsesByUID, + deepcopy, + sampleRandomElements, +} from "../backend/utils"; +import { getRatingKeyForResponse } from "../ResponseRatingToolbar"; +import EvaluationFunctionExecutor from "../backend/evalgen/executor"; +import { getAIFeaturesModels } from "../backend/ai"; + +// Main wizard component props +interface EvalGenWizardProps { + opened: boolean; + onClose: () => void; + onComplete: (result: EvalFunctionSetReport) => void; + responses: LLMResponse[]; +} + +const EvalGenWizard: React.FC = ({ + opened, + onClose, + onComplete, + responses, // The LLM responses to operate over +}) => { + // The active screen (stage) of EvalGen + const [active, setActive] = useState(0); + + // From global state + const apiKeys = useStore((state) => state.apiKeys); + const genAIFeaturesProvider = useStore((state) => state.aiFeaturesProvider); + const genAIModelNames = useMemo(() => { + const models = getAIFeaturesModels(genAIFeaturesProvider); + return { + large: models.large, + small: models.small, + }; + }, [genAIFeaturesProvider]); + + // Regroup input responses by batch UID, whenever jsonResponses changes + const batchedResponses = useMemo( + () => (responses ? batchResponsesByUID(responses) : []), + [responses], + ); + + // For updating the global human ratings state + const setState = useStore((store) => store.setState); + const updateGlobalRating = useCallback( + (uid: string, label: string, payload: RatingDict) => { + const key = getRatingKeyForResponse(uid, label); + const safe_payload = deepcopy(payload); + setState(key, safe_payload); + StorageCache.store(key, safe_payload); + }, + [setState], + ); + + // Criteria the user defines across the stages + const [criteria, setCriteria] = useState([]); + const [onNextCallback, setOnNextCallback] = useState(() => () => {}); + + // Per-criteria grades (indexed by uid of response, then uid of criteria) + const [perCriteriaGrades, setPerCriteriaGrades] = useState< + Dict> + >({}); + const [annotation, setAnnotation] = useState(undefined); + const setPerCriteriaGrade = ( + responseUID: string, + criteriaUID: string, + newGrade: boolean | undefined, + ) => { + setPerCriteriaGrades((grades) => { + if (!grades[responseUID]) grades[responseUID] = {}; + grades[responseUID][criteriaUID] = newGrade; + updateGlobalRating(responseUID, "perCriteriaGrades", grades[responseUID]); + + // If the EvalGen executor is running, update the per-criteria grade for this sample: + executor?.setGradeForExample(responseUID, grades[responseUID]); + + return { ...grades }; + }); + }; + const numResponsesGraded = useMemo(() => { + let count = 0; + for (const uid in perCriteriaGrades) { + const gs = perCriteriaGrades[uid]; + if (Object.values(gs).some((v) => v !== undefined && v !== null)) + count += 1; + } + return count; + }, [perCriteriaGrades]); + const minNumToGrade = useMemo(() => { + return Math.min(10, Math.ceil(batchedResponses.length * 0.5)); + }, [batchedResponses]); + const minNumToGradeToStartExecutor = useMemo(() => { + return Math.min(5, Math.ceil(batchedResponses.length * 0.25)); + }, [batchedResponses]); + + // The EvalGen object responsible for generating, implementing, and filtering candidate implementations + // :: Used on screen 4 (when `active` === 3). + const [executor, setExecutor] = useState( + null, + ); + const [evalGenReport, setEvalGenReport] = + useState(null); + + // Logs and state from the EvalGen backend + const [logs, setLogs] = useState<{ date: Date; message: string }[]>([]); + const [numCallsMade, setNumCallsMade] = useState({ strong: 0, weak: 0 }); + const [execProgress, setExecProgress] = useState(0); + + // The samples to pass the executor / grading responses features. This will be bounded + // by maxNumSamplesForExecutor, instead of the whole dataset. + const samplesForExecutor = useMemo(() => { + // The max number of samples (responses) to pass the executor. This controls how many requests will + // need to be sent off and how many evaluation function executions are performed. + // TODO: Give the user some control over this. + const maxNumSamplesForExecutor = 16; + + // Sample from the full set of responses, if needed: + if (batchedResponses.length > maxNumSamplesForExecutor) + return sampleRandomElements(responses, maxNumSamplesForExecutor); + else return batchedResponses.slice(); + }, [batchedResponses]); + + // When the user is done per-criteria grading + const handleDonePerCriteriaGrading = useCallback(async () => { + // Await completion of all gen + execution of eval funcs + await executor?.waitForCompletion(); + + // Filtering eval funcs by grades and present results + const filteredFunctions = + (await executor?.filterEvaluationFunctions(0.25)) ?? null; + console.log("Filtered Functions: ", filteredFunctions); + + // Return selected implementations to caller + // TODO + console.warn(filteredFunctions); + + setActive(4); // Move to the report card step + setEvalGenReport(filteredFunctions); + }, [executor]); + + // Update executor whenever resps, grades, or criteria change + useEffect(() => { + if ( + criteria.length === 0 || + numResponsesGraded < minNumToGradeToStartExecutor + ) + return; + if (!executor) { + const addLog = (message: string) => { + setLogs((prevLogs) => [...prevLogs, { date: new Date(), message }]); + }; + + const ex = new EvaluationFunctionExecutor( + genAIModelNames, + apiKeys, + getLikelyPromptTemplateAsContext(samplesForExecutor) ?? "", + samplesForExecutor, + criteria, + (strong, weak) => { + // Callback to update GPT call counts + setNumCallsMade((n_calls) => { + n_calls.strong += strong; + n_calls.weak += weak; + return { ...n_calls }; + }); + }, + addLog, + undefined, // don't pass any holistic grades at this stage + perCriteriaGrades, + ); + setExecutor(ex); + + // Start executor process + ex.start((progress) => { + setExecProgress(progress?.success ?? 0); + }); + } else if (executor) { + // Update criteria in executor + executor.updateCriteria(criteria); + } + }, [ + criteria, + samplesForExecutor, + numResponsesGraded, + minNumToGradeToStartExecutor, + ]); + + const handleNext = useCallback(() => { + setActive((current) => Math.min(4, current + 1)); + }, []); + + const handlePrevious = useCallback(() => { + setActive((current) => Math.max(0, current - 1)); + }, []); + + const handleComplete = (evalFuncReport: EvalFunctionSetReport) => { + // Return final data to the caller + onComplete(evalFuncReport); + onClose(); + }; + + const getLikelyPromptTemplateAsContext = (resps: LLMResponse[]) => { + // Attempt to infer the prompt template used to generate the responses: + const prompts = new Set(); + for (const resp_obj of resps) { + const pt = resp_obj?.metavars?.__pt; + if (pt !== undefined) { + prompts.add(StringLookup.get(pt) as string); + } + } + + if (prompts.size === 0) return null; + + // Pick a prompt template at random to serve as context.... + return escapeBraces(prompts.values().next().value ?? ""); + }; + + const exportGradesAndNotes = useStore((store) => store.exportGradesAndNotes); + async function genCriteriaFromContext(responses: LLMResponse[]) { + // Get the context from the input responses + const inputPromptTemplate = + getLikelyPromptTemplateAsContext(batchedResponses); + + if (inputPromptTemplate === null) { + console.error("No context found. Cannot proceed."); + return; + } + + // Get the user feedback on the responses, if any, from the global state + const feedback = exportGradesAndNotes(responses); + + // Attempt to generate criteria using an LLM + return await generateLLMEvaluationCriteria( + inputPromptTemplate, + genAIModelNames.large, + apiKeys, + undefined, + undefined, + feedback, + ); + } + + return ( + + {active === 0 && } + + {active === 1 && ( + + )} + + {active === 2 && ( + + genCriteriaFromContext(batchedResponses) + } + genAIModelNames={genAIModelNames} + setOnNextCallback={setOnNextCallback} + /> + )} + + {active === 3 && ( + + )} + + {active === 4 && ( + + )} + + {/* Sticky footer - button and steppers */} +
+ + + + + +
+
+ + + {/* Step content is rendered below */} + + + {/* Step content is rendered below */} + + + {/* Step content is rendered below */} + + + {/* Step content is rendered below */} + + + {/* Step content is rendered below */} + + +
+
+ ); +}; + +export default EvalGenWizard; diff --git a/chainforge/react-server/src/EvalGen/FeedbackStep.tsx b/chainforge/react-server/src/EvalGen/FeedbackStep.tsx new file mode 100644 index 000000000..719e50c86 --- /dev/null +++ b/chainforge/react-server/src/EvalGen/FeedbackStep.tsx @@ -0,0 +1,169 @@ +import React, { useCallback, useEffect, useMemo, useState } from "react"; +import { Dict, LLMResponse, RatingDict } from "../backend/typing"; +import { + Button, + Center, + Flex, + Stack, + Text, + Textarea, + Title, + Tooltip, +} from "@mantine/core"; +import GradingView from "./GradingView"; +import { IconThumbDown, IconThumbUp } from "@tabler/icons-react"; +import { getRatingKeyForResponse } from "../ResponseRatingToolbar"; +import useStore from "../store"; +import { deepcopy } from "../backend/utils"; +import StorageCache from "../backend/cache"; + +interface FeedbackStepProps { + onNext: () => void; + onPrevious: () => void; + responses: LLMResponse[]; + setOnNextCallback: React.Dispatch unknown>>; +} + +const FeedbackStep: React.FC = ({ + onNext, + onPrevious, + responses, + setOnNextCallback, +}) => { + const [shownResponse, setShownResponse] = useState( + undefined, + ); + const [shownResponseIdx, setShownResponseIdx] = useState(0); + + // Global state + const storeState = useStore>((store) => store.state); + const setStoreState = useStore((store) => store.setState); + + // The cache keys storing the ratings for this response object + const grade = useMemo(() => { + if (!shownResponse) return null; + const key = getRatingKeyForResponse(shownResponse?.uid, "grade"); + const g = storeState[key]; + if (g) return g[0]; + else return null; + }, [shownResponse, storeState]); + const annotation = useMemo(() => { + if (!shownResponse) return ""; + const key = getRatingKeyForResponse(shownResponse?.uid, "note"); + const a = storeState[key]; + if (a) return a[0]?.toString(); + else return ""; + }, [shownResponse, storeState]); + + // Set the rating in the global store, which *should* update the above. + const setRating = useCallback( + ( + uid: string | undefined, + label: string, + payload: boolean | string | null, + ) => { + if (!uid) return; + const key = getRatingKeyForResponse(uid, label); + setStoreState(key, { 0: payload }); // TODO: This will erase any feedback given on n>1 responses in the input. + StorageCache.store(key, { 0: payload }); + }, + [setStoreState], + ); + const setGrade = (val: boolean | null) => + setRating(shownResponse?.uid, "grade", val); + const setAnnotation = (val: string) => + setRating(shownResponse?.uid, "note", val); + + useEffect(() => { + if (!responses || responses.length === 0) return; + setShownResponse(responses[0]); // We only show the first response if n>1 resps per prompt, for simplicity's sake + setShownResponseIdx(0); + }, [responses]); + + const nextResponse = useCallback(() => { + if (responses.length === 0) return; + if (shownResponseIdx < responses.length - 1) { + setShownResponseIdx(shownResponseIdx + 1); + setShownResponse(responses[shownResponseIdx + 1]); + } + }, [shownResponseIdx, responses]); + + const prevResponse = useCallback(() => { + if (shownResponseIdx > 0) { + setShownResponseIdx(shownResponseIdx - 1); + setShownResponse(responses[shownResponseIdx - 1]); + } + }, [shownResponseIdx, responses]); + + return ( + + Provide Feedback on Some Model Outputs + + + + + + + + + + + +
+ + What's the reason for your grade? Explain why: + + + + + + + */} + + {/* */} + + + {/* + + + Suggest New Criteria + + + + + + + + + ); +}; + +export default GradingResponsesStep; diff --git a/chainforge/react-server/src/EvalGen/GradingView.tsx b/chainforge/react-server/src/EvalGen/GradingView.tsx new file mode 100644 index 000000000..9f7aca3b0 --- /dev/null +++ b/chainforge/react-server/src/EvalGen/GradingView.tsx @@ -0,0 +1,183 @@ +import React, { ReactNode, useMemo } from "react"; +import { LLMResponse } from "../backend/typing"; +import { + cleanMetavarsFilterFunc, + llmResponseDataToString, + transformDict, +} from "../backend/utils"; +import { Box, Button, Center, Flex, Stack, Text, Tooltip } from "@mantine/core"; +import { + IconChevronLeft, + IconChevronRight, + IconSparkles, +} from "@tabler/icons-react"; +import { StringLookup } from "../backend/cache"; +import { cleanEscapedBraces } from "../backend/template"; + +const HeaderText = ({ children }: { children: ReactNode }) => { + return ( + + {children} + + ); +}; + +export interface GradingViewProps { + shownResponse: LLMResponse | undefined; + shownResponseIdx: number; + responseCount: number; + gotoPrevResponse: () => void; + gotoNextResponse: () => void; +} + +const GradingView: React.FC = ({ + shownResponse, + shownResponseIdx, + responseCount, + gotoPrevResponse, + gotoNextResponse, +}) => { + // Calculate inner values only when shownResponse changes + const responseText = useMemo( + () => + shownResponse && shownResponse.responses?.length > 0 + ? cleanEscapedBraces( + llmResponseDataToString(shownResponse.responses[0]), + ) + : "", + [shownResponse], + ); + + const prompt = useMemo( + () => StringLookup.get(shownResponse?.prompt) ?? "", + [shownResponse], + ); + const varsDivs = useMemo(() => { + const combined_vars_metavars = shownResponse + ? { + ...StringLookup.concretizeDict(shownResponse.vars), + ...transformDict( + StringLookup.concretizeDict(shownResponse.metavars), + cleanMetavarsFilterFunc, + ), + } + : {}; + + return Object.entries(combined_vars_metavars).map(([varname, val]) => ( +
+ {varname} =  + {val} +
+ )); + }, [shownResponse]); + + return ( + + + {/* Top header */} + + + {/* What do you think of this response? */} + What do you think of response #{shownResponseIdx + 1} of{" "} + {responseCount}? + + + {/* Middle response box with chevron buttons < and > for going back and forward a response */} + + {/* Go back to previous response */} + + + + + {/* The response one is currently grading */} +
+
+
+ {responseText} +
+
+
+ + {/* Go forward to the next response */} + + + +
+ {/* Views for the vars (inputs) that generated this response, and the concrete prompt */} + +
+ Vars +
+
+ {varsDivs} +
+
+
+ Prompt +
+
+ {prompt} +
+
+
+
+
+ ); +}; + +export default GradingView; diff --git a/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx new file mode 100644 index 000000000..cd4e908fe --- /dev/null +++ b/chainforge/react-server/src/EvalGen/PickCriteriaStep.tsx @@ -0,0 +1,580 @@ +import React, { useMemo, useState } from "react"; +import { EvalCriteria, EvalFunctionReport } from "../backend/evalgen/typing"; +import { + Accordion, + Button, + Card, + Checkbox, + Code, + Divider, + Flex, + Group, + Popover, + RingProgress, + ScrollArea, + SimpleGrid, + Skeleton, + Stack, + Switch, + Text, + Textarea, + TextInput, + Title, + Tooltip, + useMantineTheme, +} from "@mantine/core"; +import { useDisclosure } from "@mantine/hooks"; +import { + IconCode, + IconRepeat, + IconRobot, + IconSparkles, + IconTrash, +} from "@tabler/icons-react"; +import useStore from "../store"; +import { accuracyToColor, cmatrixTextAnnotations } from "../backend/utils"; +import { + generateLLMEvaluationCriteria, + getPromptForGenEvalCriteriaFromDesc, +} from "../backend/evalgen/utils"; +import { v4 as uuid } from "uuid"; +import Plot from "react-plotly.js"; + +interface PickCriteriaStepProps { + onNext: () => void; + onPrevious: () => void; + criteria: EvalCriteria[]; + setCriteria: React.Dispatch>; + genCriteriaFromContext: () => Promise; + setOnNextCallback: React.Dispatch unknown>>; + genAIModelNames: { large: string; small: string }; +} + +export interface CriteriaCardProps { + title: string; + description: string; + evalMethod: string; + onTitleChange?: (newTitle: string) => void; + onDescriptionChange?: (newDesc: string) => void; + onEvalMethodChange?: (newEvalMethod: string) => void; + onRemove?: () => void; + reportMode?: boolean; + evalFuncReport?: EvalFunctionReport; + onCheck?: (newChecked: boolean) => void; + otherFuncs?: EvalFunctionReport[]; +} + +export const CriteriaCard: React.FC = function CriteriaCard({ + title, + description, + evalMethod, + onTitleChange, + onDescriptionChange, + onEvalMethodChange, + onRemove, + reportMode, + evalFuncReport, + onCheck, + otherFuncs, +}) { + const [checked, setChecked] = useState(true); + const [codeChecked, setCodeChecked] = useState(evalMethod === "code"); + const theme = useMantineTheme(); + + // Report card specific + const [openedCMatrix, { close: closeCMatrix, open: openCMatrix }] = + useDisclosure(false); + const [viewedCode, { close: closeViewedCode, open: openViewedCode }] = + useDisclosure(false); + const cMatrixPlot = useMemo(() => { + if (!evalFuncReport) return undefined; + const x = ["Pred.
fail", "Pred.
pass"]; + const y = ["Human
pass", "Human
fail"]; + const z = [ + [evalFuncReport.false_fail, evalFuncReport.true_pass], + [evalFuncReport.true_fail, evalFuncReport.false_pass], + ]; + return ( + + ); + }, [evalFuncReport]); + + const reportAccuracyRing = useMemo(() => { + if (!evalFuncReport) return undefined; + return { + percent: Math.floor((evalFuncReport.f1 ?? 0) * 100), + color: accuracyToColor(evalFuncReport.f1 ?? 0), + }; + }, [evalFuncReport]); + + const setCheckedAndRealign = (newChecked: boolean) => { + setChecked(newChecked); + + // oncheck is a callback to the parent to update the selected eval functions + // oncheck is an awaitable function + if (onCheck && evalFuncReport) onCheck(newChecked); + }; + + const unselectedImplementations = useMemo( + () => + otherFuncs !== undefined && otherFuncs.length > 0 + ? otherFuncs.map((item, idx) => ( +
+ + {item.evalFunction.code} + + +
+ )) + : null, + [otherFuncs], + ); + + return ( + +
setChecked(!checked)} + onKeyUp={(e) => e.preventDefault()} + className="checkcard" + > + + setCheckedAndRealign(!checked)} + tabIndex={-1} + size="xs" + mr="sm" + mt="xs" + styles={{ input: { cursor: "pointer" } }} + aria-hidden + /> + + +
+ + onTitleChange ? onTitleChange(e.currentTarget.value) : null + } + mb={7} + lh={1} + styles={{ + input: { + border: "none", + borderWidth: "0px", + padding: "0px", + background: "transparent", + fontWeight: 500, + fontSize: "12pt", + margin: "0px", + height: "auto", + minHeight: "auto", + }, + }} + /> + +