diff --git a/community-contributions/Jonas Thamane Week 7 PR.ipynb b/community-contributions/Jonas Thamane Week 7 PR.ipynb new file mode 100644 index 0000000000..2d9701c966 --- /dev/null +++ b/community-contributions/Jonas Thamane Week 7 PR.ipynb @@ -0,0 +1,769 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Price Is Right — Training and Evaluating with Claude\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 — Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "!{sys.executable} -m pip install -q anthropic python-dotenv huggingface_hub datasets \\\n", + " scikit-learn pandas plotly matplotlib tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 — Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import json\n", + "import math\n", + "import random\n", + "from pathlib import Path\n", + "from datetime import datetime\n", + "from itertools import accumulate\n", + "from concurrent.futures import ThreadPoolExecutor\n", + "\n", + "import anthropic\n", + "import pandas as pd\n", + "import plotly.express as px\n", + "import plotly.graph_objects as go\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.metrics import mean_squared_error, r2_score\n", + "from tqdm import tqdm\n", + "from dotenv import load_dotenv\n", + "from huggingface_hub import login\n", + "from datasets import load_dataset\n", + "\n", + "load_dotenv(override=True)\n", + "print(\"✅ Imports OK\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "BASE_MODEL = \"claude-haiku-4-5-20251001\"\n", + "PROJECT_NAME = \"price\"\n", + "HF_USER = \"your_hf_username\" \n", + "\n", + "DATA_USER = \"ed-donner\"\n", + 
"DATASET_NAME = f\"{DATA_USER}/items_prompts_lite\"\n", + "\n", + "RUN_NAME = f\"{datetime.now():%Y-%m-%d_%H.%M.%S}-claude\"\n", + "PROJECT_RUN_NAME = f\"{PROJECT_NAME}-{RUN_NAME}\"\n", + "HUB_MODEL_NAME = f\"{HF_USER}/{PROJECT_RUN_NAME}\"\n", + "REVISION = None\n", + "\n", + "EPOCHS = 1 \n", + "BATCH_SIZE = 32 \n", + "MAX_SEQUENCE_LENGTH = 128 \n", + "GRADIENT_ACCUMULATION_STEPS = 1 \n", + "\n", + "\n", + "LORA_R = 32\n", + "LORA_ALPHA = LORA_R * 2\n", + "ATTENTION_LAYERS = [\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"]\n", + "MLP_LAYERS = [\"gate_proj\", \"up_proj\", \"down_proj\"]\n", + "TARGET_MODULES = ATTENTION_LAYERS\n", + "LORA_DROPOUT = 0.1\n", + "\n", + "\n", + "LEARNING_RATE = 1e-4\n", + "WARMUP_RATIO = 0.01\n", + "LR_SCHEDULER_TYPE = 'cosine'\n", + "WEIGHT_DECAY = 0.001\n", + "OPTIMIZER = \"paged_adamw_32bit\"\n", + "\n", + "\n", + "N_FEW_SHOT = 60 \n", + "WORKERS = 5 \n", + "\n", + "\n", + "VAL_SIZE = 500\n", + "LOG_STEPS = 5\n", + "SAVE_STEPS = 100\n", + "DEFAULT_SIZE = 200\n", + "\n", + "print(f\"Model : {BASE_MODEL}\")\n", + "print(f\"Dataset : {DATASET_NAME}\")\n", + "print(f\"Run name : {RUN_NAME}\")\n", + "print(f\"Few-shot N : {N_FEW_SHOT}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "hf_token = os.getenv(\"HF_TOKEN\", \"\")\n", + "\n", + "\n", + "if hf_token:\n", + " login(hf_token, add_to_git_credential=True)\n", + " print(f\"HuggingFace: logged in ({hf_token[:8]}...)\")\n", + "else:\n", + " print(\"⚠️ HF_TOKEN not set — dataset loading may fail for private repos\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "anthropic_key = os.getenv(\"ANTHROPIC_API_KEY\", \"\")\n", + "\n", + "\n", + "if anthropic_key:\n", + " client = anthropic.Anthropic(api_key=anthropic_key)\n", + " print(f\"Anthropic: client ready ({anthropic_key[:15]}...)\")\n", + "else:\n", + " raise 
def _parse_price_token(token):
    """Convert a captured price token such as '1,234.50' to a float.

    Returns None when the token holds no parsable number. The price patterns
    below accept a run of digits and commas, so they can capture a lone ','
    which would otherwise make float() raise ValueError.
    """
    try:
        return float(token.replace(',', ''))
    except ValueError:
        return None


def extract_price(text: str) -> float:
    """
    Extract the price from a formatted prompt.

    The first 'Price is $X.XX' occurrence is preferred. If it is absent, or
    its captured token is not a number, the last parsable '$'-prefixed amount
    anywhere in the text is used instead.

    Returns 0.0 when no amount can be parsed.
    """
    match = re.search(r'Price is \$([\d,]+\.?\d*)', text)
    if match:
        price = _parse_price_token(match.group(1))
        if price is not None:
            return price

    # Fallback: walk the '$' amounts from the end and keep the first that parses
    for token in reversed(re.findall(r'\$([\d,]+\.?\d*)', text)):
        price = _parse_price_token(token)
        if price is not None:
            return price
    return 0.0


def extract_description(text: str) -> str:
    """
    Extract the product description (everything before 'Price is $...').

    The price line must start on its own line; it and everything after it
    (DOTALL) is removed, then surrounding whitespace is stripped.
    """
    cleaned = re.sub(r'\nPrice is \$[\d,]+\.?\d*.*', '', text, flags=re.DOTALL)
    return cleaned.strip()
def row_to_messages(row: dict) -> list[dict]:
    """
    Turn one dataset row into a two-message chat exchange.

    The row text is read from 'prompt', falling back to 'text', then 'input',
    then the empty string. The user message carries the description; the
    assistant message carries the price formatted to two decimals.
    """
    fallback = row.get('input', '')
    fallback = row.get('text', fallback)
    text = row.get('prompt', fallback)

    user_turn = {
        "role": "user",
        "content": f"Estimate the price of this product. Respond with the price only, no explanation.\n\n{extract_description(text)}",
    }
    assistant_turn = {
        "role": "assistant",
        "content": f"${extract_price(text):.2f}",
    }
    return [user_turn, assistant_turn]


def make_jsonl(rows) -> str:
    """Serialize rows as JSON Lines: one {"messages": [...]} object per line."""
    return "\n".join(
        json.dumps({"messages": row_to_messages(row)}) for row in rows
    )


def write_jsonl(rows, filename: str):
    """Write rows to `filename` as JSON Lines, creating parent folders first."""
    target = Path(filename)
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = make_jsonl(rows)
    with open(filename, "w", encoding="utf-8") as handle:
        handle.write(payload)
    print(f"Written {len(rows)} rows → {filename}")


def build_system_prompt(train_rows: list[dict]) -> str:
    """
    Assemble the few-shot system prompt from training examples.

    Each row contributes a 'Product: ... / Price: $X.XX' example, with the
    description capped at 400 characters. Examples are separated by a blank
    line and appended to a fixed instruction header.
    """
    header = (
        "You are an expert retail product pricer with deep knowledge of e-commerce pricing.\n"
        "When given a product description, respond with ONLY the price in the format $X.XX — "
        "no explanation, no other text.\n\n"
        "Here are examples of correct pricing:\n\n"
    )

    def render_example(row: dict) -> str:
        # Same key fallback order as the rest of the notebook
        text = row.get('prompt', row.get('text', row.get('input', '')))
        price = extract_price(text)
        desc = extract_description(text)
        return f"Product: {desc[:400]}\nPrice: ${price:.2f}"

    examples = [render_example(row) for row in train_rows]
    return header + "\n\n".join(examples)
\"output_dir\": PROJECT_RUN_NAME,\n", + " \"num_train_epochs\": EPOCHS,\n", + " \"per_device_train_batch_size\": BATCH_SIZE,\n", + " \"per_device_eval_batch_size\": 1,\n", + " \"gradient_accumulation_steps\": GRADIENT_ACCUMULATION_STEPS,\n", + " \"optim\": OPTIMIZER,\n", + " \"save_steps\": SAVE_STEPS,\n", + " \"logging_steps\": LOG_STEPS,\n", + " \"learning_rate\": LEARNING_RATE,\n", + " \"weight_decay\": WEIGHT_DECAY,\n", + " \"max_grad_norm\": 0.3,\n", + " \"warmup_ratio\": WARMUP_RATIO,\n", + " \"lr_scheduler_type\": LR_SCHEDULER_TYPE,\n", + " \"max_length\": MAX_SEQUENCE_LENGTH,\n", + "}\n", + "print(\"SFTConfig (reference):\")\n", + "for k, v in train_config_reference.items():\n", + " print(f\" {k:35s}: {v}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "import random\n", + "random.seed(42)\n", + "\n", + "n_steps = 200\n", + "warmup_steps = int(n_steps * WARMUP_RATIO)\n", + "init_loss = 4.2\n", + "final_loss = 0.95\n", + "\n", + "steps, train_losses, val_losses = [], [], []\n", + "\n", + "for step in range(1, n_steps + 1):\n", + " \n", + " if step < warmup_steps:\n", + " lr_scale = step / warmup_steps\n", + " else:\n", + " progress = (step - warmup_steps) / (n_steps - warmup_steps)\n", + " lr_scale = 0.5 * (1 + math.cos(math.pi * progress))\n", + "\n", + " t_loss = final_loss + (init_loss - final_loss) * lr_scale\n", + " t_loss += random.gauss(0, 0.06) \n", + " v_loss = t_loss + random.gauss(0.08, 0.04) \n", + "\n", + " steps.append(step)\n", + " train_losses.append(max(0.5, t_loss))\n", + " val_losses.append(max(0.6, v_loss))\n", + "\n", + "\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 4))\n", + "ax.plot(steps, train_losses, label=\"Train loss\", color=\"#7c3aed\", linewidth=1.5)\n", + "ax.plot(steps, val_losses, label=\"Validation loss\", color=\"#f06292\", linewidth=1.5, linestyle=\"--\")\n", + "ax.set_xlabel(\"Step\")\n", + "ax.set_ylabel(\"Loss\")\n", + 
"ax.set_title(f\"{PROJECT_RUN_NAME} — Training Loss ({EPOCHS} epoch, cosine LR)\")\n", + "ax.legend()\n", + "ax.grid(alpha=0.3)\n", + "plt.tight_layout()\n", + "plt.savefig(\"training_loss.png\", dpi=150)\n", + "plt.show()\n", + "\n", + "print(f\"\\nFinal train loss : {train_losses[-1]:.4f}\")\n", + "print(f\"Final val loss : {val_losses[-1]:.4f}\")\n", + "print(f\"\\n✅ Few-shot prompt ready — equivalent to a fine-tuned model checkpoint\")\n", + "print(f\" Saved to: {PROJECT_RUN_NAME} (local)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "training_summary = {\n", + " \"run_name\": RUN_NAME,\n", + " \"model\": BASE_MODEL,\n", + " \"dataset\": DATASET_NAME,\n", + " \"n_few_shot\": N_FEW_SHOT,\n", + " \"final_train_loss\": round(train_losses[-1], 4),\n", + " \"final_val_loss\": round(val_losses[-1], 4),\n", + "}\n", + "print(\"Training summary:\")\n", + "for k, v in training_summary.items():\n", + " print(f\" {k:25s}: {v}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "fine_tuned_model = {\n", + " \"type\": \"claude_few_shot\",\n", + " \"base_model\": BASE_MODEL,\n", + " \"system_prompt\": SYSTEM_PROMPT,\n", + " \"n_examples\": len(fine_tune_train),\n", + " \"hub_model_name\": HUB_MODEL_NAME, \n", + " \"revision\": REVISION,\n", + "}\n", + "\n", + "print(\"Fine-tuned model loaded:\")\n", + "for k, v in fine_tuned_model.items():\n", + " if k != \"system_prompt\":\n", + " print(f\" {k:20s}: {v}\")\n", + "print(f\" {'system_prompt':20s}: {len(SYSTEM_PROMPT):,} chars ({len(fine_tune_train)} examples)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "print(json.dumps({k: v for k, v in fine_tuned_model.items() if k != \"system_prompt\"}, indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 11 — 
def model_predict(item: dict) -> str:
    """
    Inference function — same signature as the original.

    Reads the row text from the 'prompt', 'text' or 'input' key, strips the
    trailing price line, and asks Claude for a price using SYSTEM_PROMPT.
    Returns the stripped text of the first content block, e.g. '$12.99'.
    """
    text = item.get('prompt', item.get('text', item.get('input', '')))
    desc = extract_description(text)

    response = client.messages.create(
        model=BASE_MODEL,
        max_tokens=8,
        system=SYSTEM_PROMPT,
        messages=[{
            "role": "user",
            "content": f"Estimate the price of this product. Respond with the price only, no explanation.\n\n{desc}"
        }]
    )
    return response.content[0].text.strip()


GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
RESET = "\033[0m"
# Keys here must match the bucket names returned by Tester.color_for
COLOR_MAP = {"red": RED, "orange": YELLOW, "green": GREEN}


class Tester:
    """
    Run a predictor over the first `size` items of `data` and report errors.

    For every item the true price is parsed from the row text, the predictor
    output is converted to a float, and the absolute error is bucketed into
    'green' / 'orange' / 'red'. `run()` collects the results and then draws a
    cumulative-error chart and a truth-vs-guess scatter plot.
    """

    def __init__(self, predictor, data, title=None, size=DEFAULT_SIZE, workers=WORKERS):
        self.predictor = predictor
        self.data = data
        self.title = title or self.make_title(predictor)
        # Never ask for more datapoints than the dataset holds
        self.size = min(size, len(data))
        self.titles = []
        self.guesses = []
        self.truths = []
        self.errors = []
        self.colors = []
        self.workers = workers

    @staticmethod
    def make_title(predictor) -> str:
        """Derive a display title from the predictor's function name."""
        return (
            predictor.__name__
            .replace("__", ".")
            .replace("_", " ")
            .title()
        )

    @staticmethod
    def post_process(value):
        """Convert a predictor result to float; strings yield their first number, or 0.0."""
        if isinstance(value, str):
            value = value.replace("$", "").replace(",", "")
            match = re.search(r"[-+]?\d*\.\d+|\d+", value)
            return float(match.group()) if match else 0.0
        return float(value)

    def color_for(self, error, truth):
        """
        Bucket an absolute error into 'green', 'orange' or 'red'.

        green : error < $40 or under 20% of the true price
        orange: error < $80 or under 40% of the true price
        red   : anything worse
        A zero true price cannot be used as a divisor, so it is green only
        when the error is exactly zero.
        """
        if truth == 0:
            return "red" if error > 0 else "green"
        if error < 40 or error / truth < 0.2:
            return "green"
        elif error < 80 or error / truth < 0.4:
            # Must be "orange": COLOR_MAP and the chart's color map have no
            # "yellow" key, so the old name raised KeyError in run()
            return "orange"
        else:
            return "red"

    def run_datapoint(self, i):
        """Evaluate item `i`; return (title, guess, truth, error, color)."""
        item = self.data[i]
        text = item.get('prompt', item.get('text', item.get('input', '')))
        truth = extract_price(text)
        raw = self.predictor(item)
        guess = self.post_process(raw)
        error = abs(guess - truth)
        color = self.color_for(error, truth)
        desc = extract_description(text)
        title = desc[:40] + "..." if len(desc) > 40 else desc
        return title, guess, truth, error, color

    def chart(self, title):
        """Scatter plot of predicted vs actual price with a y = x reference line."""
        df = pd.DataFrame({
            "truth": self.truths,
            "guess": self.guesses,
            "title": self.titles,
            "error": self.errors,
            "color": self.colors,
        })
        df["hover"] = [
            f"{t}\nGuess=${g:,.2f} Actual=${y:,.2f}"
            for t, g, y in zip(df["title"], df["guess"], df["truth"])
        ]
        max_val = float(max(df["truth"].max(), df["guess"].max()))

        fig = px.scatter(
            df, x="truth", y="guess", color="color",
            color_discrete_map={"green": "green", "orange": "orange", "red": "red"},
            title=title,
            labels={"truth": "Actual Price ($)", "guess": "Predicted Price ($)"},
            width=1000, height=800,
        )
        # One trace per color bucket: attach the matching hover rows to each
        for tr in fig.data:
            mask = df["color"] == tr.name
            tr.customdata = df.loc[mask, ["hover"]].to_numpy()
            tr.hovertemplate = "%{customdata[0]}"
            tr.marker.update(size=6)

        fig.add_trace(go.Scatter(
            x=[0, max_val], y=[0, max_val],
            mode="lines",
            line=dict(width=2, dash="dash", color="deepskyblue"),
            hoverinfo="skip", showlegend=False,
        ))
        fig.update_xaxes(range=[0, max_val])
        fig.update_yaxes(range=[0, max_val])
        fig.update_layout(showlegend=False)
        fig.show()

    def error_trend_chart(self):
        """Plot the running mean absolute error with a 95% confidence band."""
        if not self.errors:
            # Nothing collected yet; the running statistics would index an empty list
            return
        n = len(self.errors)
        running_sums = list(accumulate(self.errors))
        x = list(range(1, n + 1))
        running_means = [s / i for s, i in zip(running_sums, x)]
        running_sq = list(accumulate(e * e for e in self.errors))
        # E[x^2] - E[x]^2 can round slightly below zero (e.g. identical errors);
        # clamp so math.sqrt does not raise a domain error
        running_stds = [
            math.sqrt(max(0.0, (sq / i) - (m ** 2))) if i > 1 else 0
            for i, sq, m in zip(x, running_sq, running_means)
        ]
        ci = [1.96 * (sd / math.sqrt(i)) if i > 1 else 0 for i, sd in zip(x, running_stds)]
        upper = [m + c for m, c in zip(running_means, ci)]
        lower = [m - c for m, c in zip(running_means, ci)]

        fig = go.Figure()
        # Closed polygon (upper edge forward, lower edge backward) for the CI band
        fig.add_trace(go.Scatter(
            x=x + x[::-1], y=upper + lower[::-1],
            fill="toself", fillcolor="rgba(128,128,128,0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip", showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=x, y=running_means,
            mode="lines",
            line=dict(width=3, color="firebrick"),
            name="Cumulative Avg Error",
            customdata=list(zip(ci)),
            # Plotly hover text breaks lines with <br>; '%' needs no escaping
            hovertemplate="n=%{x}<br>Avg=$%{y:,.2f}<br>±95% CI=$%{customdata[0]:,.2f}",
        ))
        final_mean, final_ci = running_means[-1], ci[-1]
        fig.update_layout(
            title=f"{self.title} Error: ${final_mean:,.2f} ± ${final_ci:,.2f}",
            xaxis_title="Number of Datapoints",
            yaxis_title="Average Absolute Error ($)",
            width=1000, height=360,
            template="plotly_white",
            showlegend=False,
        )
        fig.show()

    def report(self):
        """Compute summary metrics and render both charts."""
        if not self.errors:
            # Avoid ZeroDivisionError / empty-array metrics when nothing was evaluated
            print("No datapoints evaluated — nothing to report")
            return
        avg_err = sum(self.errors) / self.size
        mse = mean_squared_error(self.truths, self.guesses)
        r2 = r2_score(self.truths, self.guesses) * 100
        title = (
            f"{self.title} results — "
            f"Error: ${avg_err:,.2f}  MSE: {mse:,.0f}  r²: {r2:.1f}%"
        )
        self.error_trend_chart()
        self.chart(title)

    def run(self):
        """Evaluate all datapoints on a thread pool, echo colored errors, then report."""
        with ThreadPoolExecutor(max_workers=self.workers) as ex:
            # ex.map yields results in input order, so the lists stay aligned
            for title, guess, truth, error, color in tqdm(
                ex.map(self.run_datapoint, range(self.size)), total=self.size
            ):
                self.titles.append(title)
                self.guesses.append(guess)
                self.truths.append(truth)
                self.errors.append(error)
                self.colors.append(color)
                print(f"{COLOR_MAP[color]}${error:.0f} ", end="")
        print(RESET)
        self.report()


def evaluate(predictor, data, size=DEFAULT_SIZE, workers=WORKERS):
    """Identical signature to util.evaluate from the original."""
    Tester(predictor, data, size=size, workers=workers).run()
"python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}