diff --git a/community-contributions/Jonas Thamane Week 6 PR..ipynb b/community-contributions/Jonas Thamane Week 6 PR..ipynb new file mode 100644 index 0000000000..d0678120aa --- /dev/null +++ b/community-contributions/Jonas Thamane Week 6 PR..ipynb @@ -0,0 +1,810 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "eb92f125", + "metadata": {}, + "source": [ + "\n", + "## Frontier Model Price Estimation\n", + "\n", + "**Pipeline:**\n", + "1. Load dataset from Hugging Face (`ed-donner/items_lite`)\n", + "2. Prepare training data in JSONL format\n", + "3. Build a few-shot Claude pricer using those examples\n", + "4. Evaluate with the full `Tester` class (scatter plot + error trend chart)" + ] + }, + { + "cell_type": "markdown", + "id": "feef3dbb", + "metadata": {}, + "source": [ + "## 1. Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c53214dd", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "!{sys.executable} -m pip install -q anthropic python-dotenv huggingface_hub datasets \\\n", + " scikit-learn pandas plotly tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "1f1b6931", + "metadata": {}, + "source": [ + "## 2. 
import os
import re
import json
import math
from pathlib import Path
from itertools import accumulate
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Optional

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import mean_squared_error, r2_score
from tqdm.notebook import tqdm
from dotenv import load_dotenv
from huggingface_hub import login
import anthropic

# Read credentials from a local .env file; override=True lets values in the
# file replace variables that are already set in the process environment.
load_dotenv(override=True)

# Credentials are only reported as present/absent. Notebook output is saved
# (and often committed) with the notebook, so no part of a key is echoed.
hf_token = os.getenv("HF_TOKEN", "")
if hf_token:
    # NOTE(review): add_to_git_credential=True also hands the token to the
    # git credential helper — confirm that side effect is wanted here.
    login(hf_token, add_to_git_credential=True)
    print("HuggingFace token found")
else:
    print("⚠️ HF_TOKEN not set")

anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
if anthropic_key:
    print("Anthropic API key found")
else:
    print("⚠️ ANTHROPIC_API_KEY not set")


# ── Tunable settings ────────────────────────────────────────────────────────
LITE_MODE = False
MODEL = "claude-haiku-4-5-20251001"  # Claude model used by every pricer
N_TRAIN = 60        # number of training items embedded as few-shot examples
N_VAL = 50          # number of validation items written to JSONL
WORKERS = 5
DEFAULT_SIZE = 200  # default number of datapoints a Tester run evaluates

print(f"\nModel : {MODEL}")
print(f"Train : {N_TRAIN} examples (few-shot)")
@dataclass
class Item:
    """A priced product used for few-shot examples and evaluation.

    Mirrors the interface of pricer.items.Item.

    Fields:
        title:   product title.
        price:   price in dollars.
        summary: text shown to the model; falls back to ``title`` when empty.
    """
    title: str
    price: float
    summary: str = ""

    def __post_init__(self):
        # An item created without a summary is described by its title alone.
        if not self.summary:
            self.summary = self.title

    @classmethod
    def _from_row(cls, row):
        """Build one Item from a dataset row (a mapping supporting ``.get``).

        Missing or ``None`` values are tolerated: the title becomes ``""`` and
        the price ``0.0`` instead of producing the string ``"None"`` or raising
        ``TypeError`` from ``float(None)``.
        """
        title = str(row.get("title") or "")
        price = float(row.get("price") or 0)
        # NOTE(review): only `details` / `description` feed the summary text;
        # any other descriptive column the dataset may ship is ignored —
        # confirm against the dataset schema.
        details = str(row.get("details", "") or row.get("description", "") or "")
        summary = f"{title}\n{details}".strip() if details else title
        return cls(title=title, price=price, summary=summary)

    @classmethod
    def from_hub(cls, dataset_name: str):
        """Load items from a Hugging Face dataset and partition them.

        Only one split is read: ``train`` when the dataset has it, otherwise
        the first split the dataset exposes. Its rows are converted to Items
        and partitioned *in order* into 70% / 15% / 15%. Any other splits
        published with the dataset are not used.

        Args:
            dataset_name: Hub identifier passed to ``datasets.load_dataset``.

        Returns:
            A ``(train, val, test)`` tuple of Item lists.
        """
        # Imported here so the class carries its own dependency.
        from datasets import load_dataset

        ds = load_dataset(dataset_name)
        key = "train" if "train" in ds else list(ds.keys())[0]
        all_items = [cls._from_row(row) for row in ds[key]]

        n = len(all_items)
        t_end = int(n * 0.7)
        v_end = int(n * 0.85)
        return all_items[:t_end], all_items[t_end:v_end], all_items[v_end:]
def messages_for(item: Item) -> list[dict]:
    """Return the two chat turns for ``item``: the pricing request, then the priced reply."""
    request = (
        "Estimate the price of this product. "
        "Respond with the price only, no explanation.\n\n{}"
    ).format(item.summary)
    reply = "${:.2f}".format(item.price)
    return [
        dict(role="user", content=request),
        dict(role="assistant", content=reply),
    ]


def make_jsonl(items: list[Item]) -> str:
    """
    Serialise items as JSON Lines text.
    One line per item, shaped {"messages": [{role, content}, ...]};
    lines are newline-separated with no trailing newline.
    """
    return "\n".join(
        json.dumps({"messages": messages_for(entry)}) for entry in items
    )


def write_jsonl(items: list[Item], filename: str):
    """Write the JSONL form of ``items`` to ``filename``, creating parent folders first."""
    target = Path(filename)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as handle:
        handle.write(make_jsonl(items))
    print(f"Written {len(items)} examples → {filename}")
def build_few_shot_system_prompt(examples: list["Item"], max_chars: Optional[int] = 300) -> str:
    """
    Build a system prompt that embeds the given items as few-shot demonstrations.

    The prompt starts with fixed pricing instructions, followed by one
    "Product: ... / Price: $X.XX" pair per example. This is in-context
    prompting: the model itself is not trained or modified.

    Args:
        examples:  items providing ``summary`` and ``price``.
        max_chars: each summary is cut to this many characters to bound the
                   prompt size; ``None`` keeps summaries whole.

    Returns:
        The system prompt. With no examples, only the instructions are
        returned (no dangling "Here are examples" lead-in).
    """
    instructions = (
        "You are an expert product pricer. "
        "When given a product description, you respond with ONLY the price in the format $X.XX — "
        "no explanation, no other text."
    )
    shots = [
        f"Product: {item.summary[:max_chars]}\n"
        f"Price: ${item.price:.2f}"
        for item in examples
    ]
    if not shots:
        return instructions
    return (
        instructions
        + "\n\n"
        + "Here are examples of correct pricing:\n\n"
        + "\n\n".join(shots)
    )
class Tester:
    """Run a price predictor over a list of items and chart how well it did.

    ``predictor`` is called once per datapoint and must return a number or a
    string containing one (for example ``"$12.50"``). Each datapoint must
    expose ``title`` and ``price``. Results accumulate on the instance in five
    parallel lists: ``titles``, ``guesses``, ``truths``, ``errors``, ``colors``.
    """

    def __init__(self, predictor, data, title=None, size=DEFAULT_SIZE, workers=SAFE_WORKERS):
        """Store the run configuration and create empty result lists.

        Args:
            predictor: callable taking one datapoint and returning a price.
            data:      indexable collection of datapoints.
            title:     chart title; derived from the predictor's name if omitted.
            size:      maximum number of datapoints to evaluate.
            workers:   requested worker count, capped at SAFE_WORKERS. It is
                       only stored; ``run`` evaluates datapoints one by one.
        """
        self.predictor = predictor
        self.data = data
        self.title = title or self.make_title(predictor)

        self.size = size
        self.workers = min(workers, SAFE_WORKERS)

        self.titles = []
        self.guesses = []
        self.truths = []
        self.errors = []
        self.colors = []

    @staticmethod
    def make_title(predictor):
        """Turn a function name into a display title (``__`` → ``.``, ``_`` → space, title case)."""
        return (
            predictor.__name__
            .replace("__", ".")
            .replace("_", " ")
            .title()
            .replace("Gpt", "GPT")
        )

    @staticmethod
    def post_process(value):
        """Extract a numeric price from a predictor's return value.

        Strings have ``$`` and ``,`` removed and the first number found is
        returned; a string containing no number yields ``nan``. Any other
        value is passed to ``float``.
        """
        if isinstance(value, str):
            value = value.replace("$", "").replace(",", "")
            match = re.search(r"[-+]?\d*\.\d+|\d+", value)
            return float(match.group()) if match else float("nan")

        return float(value)

    def color_for(self, error, truth):
        """Grade an error: green (< $40 or < 20%), orange (< $80 or < 40%), else red.

        A zero ``truth`` is graded orange because the relative error is undefined.
        """
        if truth == 0:
            return "orange"

        if error < 40 or error / truth < 0.2:
            return "green"

        elif error < 80 or error / truth < 0.4:
            return "orange"

        return "red"

    def run_datapoint(self, i):
        """Evaluate datapoint ``i``.

        Returns:
            ``(title, guess, truth, error, color)`` where ``title`` is cut to
            40 characters plus ``...`` when longer.
        """
        datapoint = self.data[i]

        value = self.predictor(datapoint)
        guess = self.post_process(value)

        # A reply with no parseable number is scored as a $0 guess (a full
        # miss). Leaving it as NaN would poison the averages and make
        # sklearn's metrics raise once the whole run has finished.
        if math.isnan(guess):
            guess = 0.0

        truth = datapoint.price
        error = abs(guess - truth)
        color = self.color_for(error, truth)

        title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40] + "..."

        return title, guess, truth, error, color

    def chart(self, title):
        """Show a scatter plot of predicted vs. actual price, coloured by grade.

        A dashed diagonal marks perfect predictions; both axes share the range
        ``[0, max(truth, guess)]``.
        """
        df = pd.DataFrame({
            "truth": self.truths,
            "guess": self.guesses,
            "title": self.titles,
            "error": self.errors,
            "color": self.colors,
        })

        df["hover"] = [
            f"{t}\nGuess=${g:,.2f} Actual=${y:,.2f}"
            for t, g, y in zip(df["title"], df["guess"], df["truth"])
        ]

        max_val = float(max(df["truth"].max(), df["guess"].max()))

        fig = px.scatter(
            df,
            x="truth",
            y="guess",
            color="color",
            color_discrete_map={
                "green": "green",
                "orange": "orange",
                "red": "red"
            },
            title=title,
            labels={"truth": "Actual Price", "guess": "Predicted Price"},
            width=1000,
            height=800
        )

        # px.scatter creates one trace per colour; give each trace the hover
        # text of exactly the rows it contains.
        for tr in fig.data:
            mask = df["color"] == tr.name
            tr.customdata = df.loc[mask, ["hover"]].to_numpy()
            tr.hovertemplate = "%{customdata[0]}"
            tr.marker.update(size=6)

        fig.add_trace(go.Scatter(
            x=[0, max_val],
            y=[0, max_val],
            mode="lines",
            line=dict(width=2, dash="dash", color="deepskyblue"),
            hoverinfo="skip",
            showlegend=False
        ))

        fig.update_xaxes(range=[0, max_val])
        fig.update_yaxes(range=[0, max_val])

        fig.update_layout(showlegend=False)

        fig.show()

    def error_trend_chart(self):
        """Show the running mean absolute error with a 95% confidence band.

        Does nothing when no datapoints have been evaluated.
        """
        n = len(self.errors)
        if n == 0:
            return

        x = list(range(1, n + 1))

        running_sums = list(accumulate(self.errors))
        running_means = [s / i for s, i in zip(running_sums, x)]

        running_squares = list(accumulate(e * e for e in self.errors))

        # E[x²] − E[x]² can come out a hair below zero through float rounding
        # (e.g. when all errors are equal); clamp so sqrt never raises.
        running_stds = [
            math.sqrt(max(0.0, (sq / i) - (m ** 2))) if i > 1 else 0
            for i, sq, m in zip(x, running_squares, running_means)
        ]

        ci = [
            1.96 * (sd / math.sqrt(i)) if i > 1 else 0
            for i, sd in zip(x, running_stds)
        ]

        upper = [m + c for m, c in zip(running_means, ci)]
        lower = [m - c for m, c in zip(running_means, ci)]

        fig = go.Figure()

        # Shaded band: trace the upper bound left-to-right, then the lower
        # bound right-to-left, and fill the enclosed polygon.
        fig.add_trace(go.Scatter(
            x=x + x[::-1],
            y=upper + lower[::-1],
            fill="toself",
            fillcolor="rgba(128,128,128,0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip"
        ))

        fig.add_trace(go.Scatter(
            x=x,
            y=running_means,
            mode="lines",
            line=dict(width=3, color="firebrick")
        ))

        final_mean = running_means[-1]
        final_ci = ci[-1]

        fig.update_layout(
            title=f"{self.title} Error: ${final_mean:,.2f} ± ${final_ci:,.2f}",
            xaxis_title="Number of Datapoints",
            yaxis_title="Average Absolute Error ($)",
            width=1000,
            height=360,
            template="plotly_white",
            showlegend=False
        )

        fig.show()

    def report(self):
        """Compute mean error, MSE and r², then show both charts.

        Prints a notice and returns when nothing has been evaluated.
        """
        if not self.errors:
            print("No datapoints evaluated — nothing to report")
            return

        avg_error = sum(self.errors) / len(self.errors)
        mse = mean_squared_error(self.truths, self.guesses)
        r2 = r2_score(self.truths, self.guesses) * 100

        # Plotly renders <br> as a line break inside titles.
        title = (
            f"{self.title} results<br>"
            f"Error: ${avg_error:,.2f} "
            f"MSE: {mse:,.0f} "
            f"r²: {r2:.1f}%"
        )

        self.error_trend_chart()
        self.chart(title)

    def run(self):
        """Evaluate up to ``size`` datapoints in order, printing each error, then report."""
        # Never index past the end of the data when size exceeds its length.
        count = min(self.size, len(self.data))

        for i in tqdm(range(count)):
            title, guess, truth, error, color = self.run_datapoint(i)

            self.titles.append(title)
            self.guesses.append(guess)
            self.truths.append(truth)
            self.errors.append(error)
            self.colors.append(color)

            print(f"{COLOR_MAP[color]}${error:.0f} ", end="")

        print(RESET)

        self.report()
def claude_zero_shot(item: Item, max_retries: int = 5, retry_wait: float = 8.0) -> str:
    """Claude with no few-shot examples — baseline comparison.

    Sends the item's summary with a short pricing instruction and returns the
    stripped text of the first content block of the reply.

    A rate-limit error is retried after ``retry_wait`` seconds, at most
    ``max_retries`` times, so that one 429 does not abort a whole evaluation
    run; once the retries are used up the error is re-raised.

    Args:
        item:        item whose ``summary`` is priced.
        max_retries: number of retries allowed after a rate-limit error.
        retry_wait:  seconds to wait before each retry.

    Raises:
        anthropic.RateLimitError: if the limit is still hit after all retries.
    """
    import time  # local import keeps this function self-contained

    for attempt in range(max_retries + 1):
        try:
            response = client.messages.create(
                model=MODEL,
                max_tokens=16,
                system="You are an expert product pricer. Respond with ONLY the price in the format $X.XX.",
                messages=[{
                    "role": "user",
                    "content": f"Estimate the price of this product:\n\n{item.summary}"
                }]
            )
            return response.content[0].text.strip()
        except anthropic.RateLimitError:
            if attempt == max_retries:
                raise
            print(f"⚠️ Claude rate limit hit — waiting {retry_wait:g} seconds...")
            time.sleep(retry_wait)