diff --git a/crates/wavekat-turn/Cargo.toml b/crates/wavekat-turn/Cargo.toml
index 6d76eb1..aa9d3f4 100644
--- a/crates/wavekat-turn/Cargo.toml
+++ b/crates/wavekat-turn/Cargo.toml
@@ -19,7 +19,7 @@ pipecat = ["dep:ort", "dep:ndarray", "dep:realfft", "dep:ureq"]
 livekit = ["dep:ort", "dep:ndarray"]

 [dependencies]
-wavekat-core = "0.0.2"
+wavekat-core = "0.0.4"
 thiserror = "2"

 # ONNX backends (optional)
diff --git a/training/pipecat-smart-turn/Dockerfile b/training/pipecat-smart-turn/Dockerfile
new file mode 100644
index 0000000..7da2cdd
--- /dev/null
+++ b/training/pipecat-smart-turn/Dockerfile
@@ -0,0 +1,33 @@
+FROM nvidia/cuda:13.1.0-devel-ubuntu24.04
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+
+# System dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.12 python3.12-venv python3.12-dev python3-pip \
+    portaudio19-dev git curl \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN ln -sf /usr/bin/python3.12 /usr/bin/python
+
+WORKDIR /workspace
+
+# Clone upstream repo
+RUN git clone https://github.com/pipecat-ai/smart-turn.git .
+
+# Install Python dependencies
+RUN python -m pip install --no-cache-dir --break-system-packages -r requirements.txt
+
+# Jupyter + visualisation deps for notebook exploration
+RUN python -m pip install --no-cache-dir --break-system-packages \
+    jupyterlab matplotlib pandas ipywidgets
+
+# FFmpeg for torchcodec audio decoding
+RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg \
+    && rm -rf /var/lib/apt/lists/*
+
+EXPOSE 8888
+
+# Default: show usage
+CMD ["python", "train_local.py", "--help"]
diff --git a/training/pipecat-smart-turn/README.md b/training/pipecat-smart-turn/README.md
new file mode 100644
index 0000000..37c3a52
--- /dev/null
+++ b/training/pipecat-smart-turn/README.md
@@ -0,0 +1,263 @@
+# Training Pipecat Smart Turn
+
+Upstream repo: https://github.com/pipecat-ai/smart-turn
+
+## Model Overview
+
+- **Architecture:** Whisper Tiny encoder + attention-pooling classification head (~8M params)
+- **Task:** Binary classification — complete vs. incomplete turn
+- **Input:** 16 kHz mono PCM, up to 8 seconds, log-mel spectrogram (80×800)
+- **Loss:** BCEWithLogitsLoss with dynamic positive-weight balancing
+- **Output:** ONNX (FP32 ~32 MB, INT8 ~8 MB)
+
+## Infrastructure
+
+- **Region:** Azure Australia East
+- **VM Type:** NC4as_T4_v3 (4 vCPUs, 28 GB RAM, NVIDIA Tesla T4 16 GB)
+- **Hostname:** `gpu-testing` (via Tailscale)
+- **User:** `eason`
+- **SSH:** `ssh gpu-testing` (key: `~/.ssh/id_ed25519_wavekat-eason`, configured in `~/.ssh/config`)
+- **GPU:** Tesla T4, 16 GB VRAM, driver 590.48.01, CUDA 13.1
+
+## Steps
+
+### 1. Connect to Azure VM
+
+```bash
+ssh gpu-testing
+```
+
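For reference, a minimal `~/.ssh/config` entry matching the Infrastructure section above (host alias, user, and key path as listed there; adjust to your own setup):

```bash
# Assumed ~/.ssh/config entry for the host alias used throughout this guide
cat >> ~/.ssh/config <<'EOF'
Host gpu-testing
    User eason
    IdentityFile ~/.ssh/id_ed25519_wavekat-eason
EOF

# Quick connectivity check over Tailscale
ssh gpu-testing hostname
```

### 2. 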
Environment Setup

#### 2.1 GPU Driver

```bash
sudo apt update
sudo apt install -y linux-headers-$(uname -r)
sudo apt install -y nvidia-driver-590
sudo reboot

# If secure boot blocks the module:
sudo mokutil --disable-validation
sudo reboot

# Verify
nvidia-smi
```

#### 2.2 Disk Setup

Two additional data disks are attached for datasets and checkpoints; format and mount them:

```bash
sudo mkfs.ext4 /dev/sdc
sudo mkfs.ext4 /dev/sdd
sudo mkdir -p /datasets /checkpoints
sudo mount /dev/sdc /datasets
sudo mount /dev/sdd /checkpoints

# Persist in fstab
BLK_SDC=$(sudo blkid -s UUID -o value /dev/sdc)
BLK_SDD=$(sudo blkid -s UUID -o value /dev/sdd)
echo "UUID=$BLK_SDC /datasets ext4 defaults,nofail 0 2" | sudo tee -a /etc/fstab
echo "UUID=$BLK_SDD /checkpoints ext4 defaults,nofail 0 2" | sudo tee -a /etc/fstab
```

#### 2.3 Docker + NVIDIA Container Toolkit

```bash
# Install Docker
curl -fsSL https://get.docker.com | sudo sh
sudo usermod -aG docker $USER

# Install NVIDIA Container Toolkit
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey \
  | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list \
  | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' \
  | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt update
sudo apt install -y nvidia-container-toolkit

# Configure runtime and move Docker root to data disk
sudo systemctl stop docker
sudo mkdir -p /datasets/docker
sudo rsync -aP /var/lib/docker/ /datasets/docker/
sudo nvidia-ctk runtime configure --runtime=docker
# Edit /etc/docker/daemon.json to add: "data-root": "/datasets/docker"
sudo systemctl start docker
sudo rm -rf /var/lib/docker

# Move containerd storage to data disk as well
# (Docker's data-root only moves Docker's own data, not containerd's.
# Without this, containerd fills up the root disk during image builds.)
sudo systemctl stop docker
sudo systemctl stop containerd
sudo mkdir -p /datasets/containerd
sudo rsync -aP /var/lib/containerd/ /datasets/containerd/
sudo rm -rf /var/lib/containerd
sudo ln -s /datasets/containerd /var/lib/containerd
sudo systemctl start containerd
sudo systemctl start docker

# Verify GPU in container
docker run --rm --gpus all nvidia/cuda:13.1.0-devel-ubuntu24.04 nvidia-smi
```

### 3. Fix Disk Permissions

```bash
sudo chown $USER:$USER /checkpoints /datasets
```

### 4. Build Docker Image

**Local machine** — copy the Dockerfile to the VM:

```bash
scp training/pipecat-smart-turn/Dockerfile gpu-testing:/checkpoints/Dockerfile.smart-turn
```

**On the VM** — build the image:

```bash
docker build -t smart-turn -f /checkpoints/Dockerfile.smart-turn /checkpoints
```

The canonical Dockerfile lives in this repo at `training/pipecat-smart-turn/Dockerfile`;
it is only copied to `/checkpoints` for the build, so `/checkpoints` and `/datasets`
otherwise stay pure data volumes.

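Before moving on, a quick smoke test that the image sees the GPU and that the upstream Python dependencies import cleanly (this assumes PyTorch is pulled in by the upstream `requirements.txt`):

```bash
# GPU visible inside the freshly built image
docker run --rm --gpus all smart-turn nvidia-smi

# Upstream deps import and CUDA is usable (torch assumed to come from requirements.txt)
docker run --rm --gpus all smart-turn \
  python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```

### 5. 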
Data Preparation

Upstream datasets (HuggingFace):

- **Train:** `pipecat-ai/smart-turn-data-v3.2-train` (270k samples, ~41 GB)
- **Test:** `pipecat-ai/smart-turn-data-v3.2-test`

Dataset columns: `audio`, `id`, `language`, `endpoint_bool`, `midfiller`, `endfiller`, `synthetic`, `dataset`

Pre-download the training dataset so it persists across container runs (the test split will be pulled into the same cache the first time it is loaded):

```bash
docker run -d \
  --name smart-turn-download \
  --gpus all \
  -v /datasets/huggingface:/root/.cache/huggingface \
  smart-turn \
  python -c "from datasets import load_dataset; load_dataset('pipecat-ai/smart-turn-data-v3.2-train')"
```

The data is cached at `/datasets/huggingface` on the host. All subsequent runs
must mount this path to `/root/.cache/huggingface` to reuse it.

Raw data format for custom contributions:
- FLAC files, mono, 16-bit, 16 kHz+
- Max 16 seconds per file, ~200 ms trailing silence
- Directory structure: `{language}/{complete|incomplete}-{midfiller|endfiller|nofiller}/{uuid}.flac`
- Convert raw to HF dataset: `python datasets/scripts/raw_to_hf_dataset.py `

### 6. Dataset Exploration (Notebook)

**Local machine** — copy the notebook to the VM:

```bash
scp training/pipecat-smart-turn/explore_dataset.ipynb gpu-testing:/checkpoints/explore_dataset.ipynb
```

**On the VM** — launch JupyterLab:

```bash
docker run -d --name jupyter \
  --gpus all --restart unless-stopped \
  --ipc=host \
  -v /datasets/huggingface:/root/.cache/huggingface \
  -v /checkpoints:/checkpoints \
  -v /datasets:/datasets \
  -p 8888:8888 \
  smart-turn \
  jupyter lab \
    --ip=0.0.0.0 --port=8888 --no-browser --allow-root \
    --notebook-dir=/checkpoints \
    --ServerApp.token='' --ServerApp.password=''
```

Open `http://gpu-testing:8888` in a browser (via Tailscale) and run
`explore_dataset.ipynb`. The notebook covers label balance, audio durations,
language/filler/synthetic breakdowns, audio playback, and mel spectrogram
visualisation.

### 7. Training (Notebook)

**Local machine** — copy the training notebook to the VM:

```bash
scp training/pipecat-smart-turn/train.ipynb gpu-testing:/checkpoints/train.ipynb
```

Open `http://gpu-testing:8888` in a browser and run `train.ipynb`. The notebook
covers model init, training, evaluation, ONNX export, INT8 quantization, and
benchmarking.

#### CLI Alternative

```bash
docker run --gpus all \
  --ipc=host \
  -v /datasets/huggingface:/root/.cache/huggingface \
  -v /checkpoints:/checkpoints \
  -e WANDB_API_KEY=${WANDB_API_KEY} \
  smart-turn \
  python train_local.py \
    --training-run-name my-run \
    --output-dir /checkpoints/output
```

Hyperparameters (defaults in `train.py`):

| Parameter | Value |
|---|---|
| Base model | `openai/whisper-tiny` |
| Learning rate | 5e-5 |
| Epochs | 4 |
| Train batch size | 384 |
| Eval batch size | 128 |
| Warmup ratio | 0.2 |
| Weight decay | 0.01 |
| LR schedule | Cosine |
| Eval/save steps | 500 |
| Dataloader workers | 6 |

Optional: set `WANDB_API_KEY` for experiment tracking.

> **Note:** Batch size 384 requires significant VRAM. On the T4 (16 GB) the
> per-device batch size will likely need to be reduced (experiment with 32–64);
> `train.ipynb` uses 32 with gradient accumulation of 12 to keep the effective
> batch size at 384.

### 8. Quantization

```bash
docker run --gpus all \
  -v /checkpoints:/checkpoints \
  smart-turn \
  python train_local.py --quantize /checkpoints/output/path-to-fp32-model.onnx
```

INT8 static quantization using entropy calibration on 1024 samples.

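A quick CPU-side sanity check of the quantized model can be run from the same container. This is a sketch only: it assumes the INT8 file ended up at `/checkpoints/output/smart-turn-v3-int8.onnx` (adjust the path to your run); the `input_features` name and 1×80×800 input shape match the ONNX export in `train.ipynb`.

```bash
# Smoke-test the INT8 model with a random 8 s log-mel input (assumed output path)
docker run --rm -v /checkpoints:/checkpoints smart-turn python -c "
import numpy as np, onnxruntime as ort
sess = ort.InferenceSession('/checkpoints/output/smart-turn-v3-int8.onnx',
                            providers=['CPUExecutionProvider'])
x = np.random.randn(1, 80, 800).astype(np.float32)
print('output shape:', sess.run(None, {'input_features': x})[0].shape)
"
```

### 9. 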
Benchmarking + +```bash +docker run --gpus all \ + -v /checkpoints:/checkpoints \ + smart-turn \ + python train_local.py --benchmark /checkpoints/output/ +``` + +Reference latencies: CPU ~12.6 ms, GPU (L40S) ~3.3 ms. + +### 10. Export / Integration + +Final artifacts: +- `smart-turn-v3.onnx` (FP32) +- `smart-turn-v3-int8.onnx` (INT8) + +TODO: Integrate into wavekat-turn. diff --git a/training/pipecat-smart-turn/explore_dataset.ipynb b/training/pipecat-smart-turn/explore_dataset.ipynb new file mode 100644 index 0000000..4de58ca --- /dev/null +++ b/training/pipecat-smart-turn/explore_dataset.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pipecat Smart Turn — Dataset Exploration\n", + "\n", + "Explore the `pipecat-ai/smart-turn-data-v3.2` dataset before training.\n", + "\n", + "- Label distribution (complete vs incomplete)\n", + "- Audio duration statistics\n", + "- Language / filler / synthetic breakdowns\n", + "- Listen to samples & visualise mel spectrograms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import IPython.display as ipd\n", + "from transformers import WhisperFeatureExtractor\n", + "from collections import Counter" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "ds_train = load_dataset(\"pipecat-ai/smart-turn-data-v3.2-train\", split=\"train\")\nds_test = load_dataset(\"pipecat-ai/smart-turn-data-v3.2-test\", split=\"train\")\n\nprint(f\"Train: {len(ds_train):,} samples\")\nprint(f\"Test: {len(ds_test):,} samples\")\nprint(f\"\\nColumns: {ds_train.column_names}\")\nprint(f\"Features: {ds_train.features}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Quick peek at a few rows (exclude audio column to avoid torchcodec/FFmpeg dependency)\nnon_audio_cols = [c for c in ds_train.column_names if c != \"audio\"]\nds_train.select(range(3)).to_pandas()[non_audio_cols]" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Label Distribution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels = np.array(ds_train[\"endpoint_bool\"])\n", + "complete = labels.sum()\n", + "incomplete = len(labels) - complete\n", + "\n", + "print(f\"Complete (1): {complete:>7,} ({complete/len(labels)*100:.1f}%)\")\n", + "print(f\"Incomplete (0): {incomplete:>7,} ({incomplete/len(labels)*100:.1f}%)\")\n", + "print(f\"Pos weight (for BCE): {incomplete / max(complete, 1):.3f}\")\n", + "\n", + "fig, ax = plt.subplots(figsize=(5, 3))\n", + "ax.bar([\"Incomplete (0)\", \"Complete (1)\"], [incomplete, complete])\n", + "ax.set_ylabel(\"Count\")\n", + "ax.set_title(\"Label Distribution — Train\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Audio Duration Statistics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sample a subset to compute durations (full scan can be slow on 270k samples)\n", + "N = min(10_000, len(ds_train))\n", + "rng = np.random.default_rng(42)\n", + "indices = rng.choice(len(ds_train), size=N, replace=False)\n", + "\n", + "durations = []\n", + "sample_rates = set()\n", + "for i in indices:\n", + " audio = ds_train[int(i)][\"audio\"]\n", + " sr = audio[\"sampling_rate\"]\n", + " sample_rates.add(sr)\n", + " durations.append(len(audio[\"array\"]) / sr)\n", + "\n", + "durations = np.array(durations)\n", + "print(f\"Sampled {N:,} audio clips\")\n", + "print(f\"Sample rates seen: {sample_rates}\")\n", + "print(f\"Duration — min: {durations.min():.2f}s, max: {durations.max():.2f}s, \"\n", + " f\"mean: {durations.mean():.2f}s, median: {np.median(durations):.2f}s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(8, 3))\n", + "ax.hist(durations, bins=80, edgecolor=\"black\", linewidth=0.3)\n", + "ax.axvline(8.0, color=\"red\", linestyle=\"--\", label=\"8s model input cap\")\n", + "ax.set_xlabel(\"Duration (s)\")\n", + "ax.set_ylabel(\"Count\")\n", + "ax.set_title(f\"Audio Duration Distribution (n={N:,})\")\n", + "ax.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Metadata Breakdowns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Build a metadata DataFrame (no audio, fast)\n", + "meta_cols = [c for c in ds_train.column_names if c != \"audio\"]\n", + "df = pd.DataFrame({c: ds_train[c] for c in meta_cols})\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Language breakdown\n", + "lang_counts = df[\"language\"].value_counts()\n", + "print(\"Language distribution:\")\n", + "print(lang_counts.to_string())\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 3))\n", + "lang_counts.plot.bar(ax=ax)\n", + "ax.set_ylabel(\"Count\")\n", + "ax.set_title(\"Samples per Language\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filler type breakdown\n", + "df[\"filler_type\"] = \"nofiller\"\n", + "df.loc[df[\"midfiller\"] == True, \"filler_type\"] = \"midfiller\"\n", + "df.loc[df[\"endfiller\"] == True, \"filler_type\"] = \"endfiller\"\n", + "\n", + "filler_label = pd.crosstab(df[\"filler_type\"], df[\"endpoint_bool\"])\n", + "filler_label.columns = [\"Incomplete\", \"Complete\"]\n", + "print(filler_label)\n", + "\n", + "filler_label.plot.bar(figsize=(6, 3), title=\"Filler Type × Label\")\n", + "plt.ylabel(\"Count\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Synthetic vs real\n", + "synth_label = pd.crosstab(df[\"synthetic\"], df[\"endpoint_bool\"])\n", + "synth_label.columns = [\"Incomplete\", \"Complete\"]\n", + "synth_label.index = [\"Real\", \"Synthetic\"]\n", + "print(synth_label)\n", + "\n", + "synth_label.plot.bar(figsize=(5, 3), title=\"Synthetic × Label\")\n", + "plt.ylabel(\"Count\")\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [ + "# Source dataset breakdown\n", + "print(\"Source dataset counts:\")\n", + "print(df[\"dataset\"].value_counts().to_string())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Listen to Samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def play_sample(ds, idx):\n", + " \"\"\"Display an audio player and metadata for a dataset sample.\"\"\"\n", + " sample = ds[idx]\n", + " audio = sample[\"audio\"]\n", + " sr = audio[\"sampling_rate\"]\n", + " arr = np.array(audio[\"array\"], dtype=np.float32)\n", + " dur = len(arr) / sr\n", + " label = \"Complete\" if sample[\"endpoint_bool\"] else \"Incomplete\"\n", + " print(f\"[{idx}] {label} | lang={sample['language']} | \"\n", + " f\"midfiller={sample['midfiller']} endfiller={sample['endfiller']} | \"\n", + " f\"synthetic={sample['synthetic']} | {dur:.2f}s | dataset={sample['dataset']}\")\n", + " return ipd.Audio(arr, rate=sr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# A few complete samples\n", + "complete_idx = [i for i, v in enumerate(ds_train[\"endpoint_bool\"][:5000]) if v]\n", + "for idx in complete_idx[:3]:\n", + " display(play_sample(ds_train, idx))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# A few incomplete samples\n", + "incomplete_idx = [i for i, v in enumerate(ds_train[\"endpoint_bool\"][:5000]) if not v]\n", + "for idx in incomplete_idx[:3]:\n", + " display(play_sample(ds_train, idx))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Mel Spectrogram Visualisation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "extractor = WhisperFeatureExtractor(chunk_length=8)\n", + "\n", + "def plot_mel(ds, idx, ax=None):\n", + " \"\"\"Plot the 80×800 mel spectrogram for a sample.\"\"\"\n", + " sample = ds[idx]\n", + " audio = sample[\"audio\"]\n", + " features = extractor(\n", + " audio[\"array\"], sampling_rate=audio[\"sampling_rate\"], return_tensors=\"np\"\n", + " )\n", + " mel = features[\"input_features\"][0] # (80, 800)\n", + " label = \"Complete\" if sample[\"endpoint_bool\"] else \"Incomplete\"\n", + "\n", + " if ax is None:\n", + " fig, ax = plt.subplots(figsize=(10, 3))\n", + " ax.imshow(mel, aspect=\"auto\", origin=\"lower\", cmap=\"inferno\")\n", + " ax.set_xlabel(\"Time frame\")\n", + " ax.set_ylabel(\"Mel bin\")\n", + " ax.set_title(f\"[{idx}] {label} — {sample['language']}\")\n", + " return mel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(2, 2, figsize=(14, 6))\n", + "for ax, idx in zip(axes[0], complete_idx[:2]):\n", + " plot_mel(ds_train, idx, ax=ax)\n", + "for ax, idx in zip(axes[1], incomplete_idx[:2]):\n", + " plot_mel(ds_train, idx, ax=ax)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Test Set Overview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_labels = np.array(ds_test[\"endpoint_bool\"])\n", + "print(f\"Test set: {len(ds_test):,} samples\")\n", + "print(f\" Complete: {test_labels.sum():,} ({test_labels.mean()*100:.1f}%)\")\n", + "print(f\" Incomplete: {(~test_labels.astype(bool)).sum():,} ({(1-test_labels.mean())*100:.1f}%)\")\n", + "\n", + "test_meta = [c for c in ds_test.column_names if c != \"audio\"]\n", + "df_test = pd.DataFrame({c: ds_test[c] for c in test_meta})\n", + "print(f\"\\nLanguages: {df_test['language'].value_counts().to_dict()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "Done! Use these findings to decide on any data filtering, augmentation, or rebalancing before launching training." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/training/pipecat-smart-turn/train.ipynb b/training/pipecat-smart-turn/train.ipynb new file mode 100644 index 0000000..290e74c --- /dev/null +++ b/training/pipecat-smart-turn/train.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1rbaqnw7ao2", + "source": "# Pipecat Smart Turn — Training\n\nFine-tune `openai/whisper-tiny` encoder with an attention-pooling classification\nhead for binary turn detection (complete vs incomplete).\n\nBased on [pipecat-ai/smart-turn](https://github.com/pipecat-ai/smart-turn) `train.py`.\n\n**Run inside the `smart-turn` Docker container on the GPU VM.**", + "metadata": {} + }, + { + "cell_type": "code", + "id": "jyvklcizwt", + "source": "import os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom datasets import load_dataset\nfrom transformers import (\n WhisperConfig,\n WhisperFeatureExtractor,\n WhisperPreTrainedModel,\n Trainer,\n TrainingArguments,\n)\nfrom transformers.models.whisper.modeling_whisper import WhisperEncoder\nfrom sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\nfrom torch.utils.data import Dataset\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Device: {device}\")\nif device.type == \"cuda\":\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "qvcllsoo27o", + "source": "## 1. 
Config", + "metadata": {} + }, + { + "cell_type": "code", + "id": "1ms753f72k5", + "source": "BASE_MODEL = \"openai/whisper-tiny\"\nCHUNK_LENGTH = 8 # seconds — model input cap\nSAMPLE_RATE = 16_000\n\n# Training — tuned for T4 16 GB VRAM\n# effective batch = BATCH_SIZE * GRAD_ACCUM = 32 * 12 = 384 (matches upstream)\nBATCH_SIZE = 32\nGRAD_ACCUM = 12\nEVAL_BATCH_SIZE = 64\nEPOCHS = 4\nLR = 5e-5\nWARMUP_RATIO = 0.2\nWEIGHT_DECAY = 0.01\nEVAL_STEPS = 100\nSAVE_STEPS = 100\nLOGGING_STEPS = 10\nDATALOADER_WORKERS = 6\nDATALOADER_PREFETCH = 4\n\n# ONNX\nONNX_OPSET = 18\nCALIBRATION_SAMPLES = 1024\n\n# Paths\nRUN_NAME = \"wavekat-v1\"\nOUTPUT_DIR = f\"/checkpoints/{RUN_NAME}\"\nos.makedirs(OUTPUT_DIR, exist_ok=True)\nprint(f\"Output: {OUTPUT_DIR}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "1cien4s1im", + "source": "## 2. Model\n\nWhisper Tiny encoder → attention pooling → classifier head (single logit).", + "metadata": {} + }, + { + "cell_type": "code", + "id": "3f0cfd0ivsw", + "source": "class SmartTurnModel(WhisperPreTrainedModel):\n \"\"\"Whisper encoder + attention pooling + binary classifier.\"\"\"\n\n def __init__(self, config: WhisperConfig):\n super().__init__(config)\n # Override max positions for 8s input (800 frames / 2 = 400 encoder steps)\n config.max_source_positions = 400\n self.encoder = WhisperEncoder(config)\n\n hidden = config.d_model # 384 for whisper-tiny\n\n # Attention pooling over time\n self.pool_attention = nn.Sequential(\n nn.Linear(hidden, 256),\n nn.Tanh(),\n nn.Linear(256, 1),\n )\n\n # Classification head\n self.classifier = nn.Sequential(\n nn.Linear(hidden, 256),\n nn.LayerNorm(256),\n nn.GELU(),\n nn.Dropout(0.1),\n nn.Linear(256, 64),\n nn.GELU(),\n nn.Linear(64, 1),\n )\n\n self.post_init()\n\n def forward(self, input_features, labels=None):\n enc = self.encoder(input_features).last_hidden_state # (B, T, H)\n\n # Attention pooling\n attn_weights = self.pool_attention(enc).squeeze(-1) # (B, T)\n attn_weights = torch.softmax(attn_weights, dim=-1) # (B, T)\n pooled = torch.bmm(attn_weights.unsqueeze(1), enc).squeeze(1) # (B, H)\n\n logits = self.classifier(pooled).squeeze(-1) # (B,)\n probs = torch.sigmoid(logits)\n\n loss = None\n if labels is not None:\n pos_weight = ((labels == 0).sum() / (labels == 1).sum().clamp(min=1)).clamp(0.1, 10.0)\n loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)\n loss = loss_fn(logits, labels.float())\n\n return {\"loss\": loss, \"logits\": probs}\n\n\nprint(\"Model class defined.\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "fvplmbp397c", + "source": "config = WhisperConfig.from_pretrained(BASE_MODEL)\nmodel = SmartTurnModel(config)\n\n# Load pretrained encoder weights (ignore missing classifier/pool keys)\npretrained = SmartTurnModel.from_pretrained(BASE_MODEL, config=config, ignore_mismatched_sizes=True)\nmodel.encoder.load_state_dict(pretrained.encoder.state_dict())\ndel pretrained\n\ntotal = sum(p.numel() for p in model.parameters())\ntrainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\nprint(f\"Parameters: {total:,} total, {trainable:,} trainable\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "ikssunxv75f", + "source": "## 3. 
Dataset\n\nOn-demand mel spectrogram extraction — avoids materialising all features in RAM.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "gbjtfxeuwsf", + "source": "def truncate_to_last_n_seconds(audio_array: np.ndarray, sr: int, n: int = 8) -> np.ndarray:\n \"\"\"Keep the last n seconds; zero-pad at the start if shorter.\"\"\"\n max_samples = sr * n\n if len(audio_array) > max_samples:\n return audio_array[-max_samples:]\n elif len(audio_array) < max_samples:\n pad = np.zeros(max_samples - len(audio_array), dtype=audio_array.dtype)\n return np.concatenate([pad, audio_array])\n return audio_array\n\n\nclass SmartTurnDataset(Dataset):\n \"\"\"Wraps a HuggingFace dataset, extracting mel features on the fly.\"\"\"\n\n def __init__(self, hf_dataset, feature_extractor):\n self.ds = hf_dataset\n self.fe = feature_extractor\n\n def __len__(self):\n return len(self.ds)\n\n def __getitem__(self, idx):\n sample = self.ds[idx]\n audio = sample[\"audio\"]\n arr = np.array(audio[\"array\"], dtype=np.float32)\n arr = truncate_to_last_n_seconds(arr, audio[\"sampling_rate\"], CHUNK_LENGTH)\n\n features = self.fe(\n arr,\n sampling_rate=SAMPLE_RATE,\n return_tensors=\"pt\",\n padding=\"max_length\",\n max_length=CHUNK_LENGTH * SAMPLE_RATE,\n truncation=True,\n do_normalize=True,\n )\n return {\n \"input_features\": features[\"input_features\"].squeeze(0), # (80, 800)\n \"labels\": torch.tensor(float(sample[\"endpoint_bool\"])),\n }\n\n\nfeature_extractor = WhisperFeatureExtractor(chunk_length=CHUNK_LENGTH)\nprint(\"Dataset class + feature extractor ready.\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "u4wbc015y1a", + "source": "ds_train_raw = load_dataset(\"pipecat-ai/smart-turn-data-v3.2-train\", split=\"train\")\nds_test_raw = load_dataset(\"pipecat-ai/smart-turn-data-v3.2-test\", split=\"train\")\n\n# Split training data 90/10 for train/eval (matches upstream)\nsplit = ds_train_raw.train_test_split(test_size=0.1, seed=42)\nds_train_split = split[\"train\"]\nds_eval_split = split[\"test\"]\n\ntrain_dataset = SmartTurnDataset(ds_train_split, feature_extractor)\neval_dataset = SmartTurnDataset(ds_eval_split, feature_extractor)\ntest_dataset = SmartTurnDataset(ds_test_raw, feature_extractor)\n\nprint(f\"Train: {len(train_dataset):,}\")\nprint(f\"Eval: {len(eval_dataset):,}\")\nprint(f\"Test: {len(test_dataset):,}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "nm1bzocx88r", + "source": "# Sanity check — verify a single sample shape\nsample = train_dataset[0]\nprint(f\"input_features: {sample['input_features'].shape}\") # expect (80, 800)\nprint(f\"label: {sample['labels']}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "otagbffu8o7", + "source": "## 4. 
Train", + "metadata": {} + }, + { + "cell_type": "code", + "id": "nl8z3hemln", + "source": "def compute_metrics(eval_pred):\n probs, labels = eval_pred\n preds = (probs > 0.5).astype(int).flatten()\n labels = labels.astype(int).flatten()\n return {\n \"accuracy\": accuracy_score(labels, preds),\n \"precision\": precision_score(labels, preds, zero_division=0),\n \"recall\": recall_score(labels, preds, zero_division=0),\n \"f1\": f1_score(labels, preds, zero_division=0),\n }\n\n\ntraining_args = TrainingArguments(\n output_dir=OUTPUT_DIR,\n run_name=RUN_NAME,\n num_train_epochs=EPOCHS,\n per_device_train_batch_size=BATCH_SIZE,\n per_device_eval_batch_size=EVAL_BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LR,\n warmup_ratio=WARMUP_RATIO,\n weight_decay=WEIGHT_DECAY,\n lr_scheduler_type=\"cosine\",\n eval_strategy=\"steps\",\n eval_steps=EVAL_STEPS,\n save_steps=SAVE_STEPS,\n logging_steps=LOGGING_STEPS,\n dataloader_num_workers=DATALOADER_WORKERS,\n dataloader_prefetch_factor=DATALOADER_PREFETCH,\n bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,\n fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,\n load_best_model_at_end=True,\n metric_for_best_model=\"f1\",\n greater_is_better=True,\n save_total_limit=10,\n report_to=\"wandb\" if os.environ.get(\"WANDB_API_KEY\") else \"none\",\n)\n\ntrainer = Trainer(\n model=model,\n args=training_args,\n train_dataset=train_dataset,\n eval_dataset=eval_dataset,\n compute_metrics=compute_metrics,\n)\n\nprint(f\"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}\")\nprint(f\"Mixed precision: {'bf16' if training_args.bf16 else 'fp16' if training_args.fp16 else 'none'}\")\nprint(f\"W&B: {'enabled' if training_args.report_to == ['wandb'] else 'disabled'}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "rn507r7x9n", + "source": "trainer.train(resume_from_checkpoint=True)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "g5p1uc0vq3w", + "source": "## 5. Evaluate on Test Set", + "metadata": {} + }, + { + "cell_type": "code", + "id": "om703a3nulj", + "source": "test_results = trainer.evaluate(test_dataset, metric_key_prefix=\"test\")\nfor k, v in sorted(test_results.items()):\n if isinstance(v, float):\n print(f\" {k}: {v:.4f}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "k4e1rvwxkx", + "source": "# Save model + feature extractor\nmodel.save_pretrained(OUTPUT_DIR)\nfeature_extractor.save_pretrained(OUTPUT_DIR)\nprint(f\"Saved to {OUTPUT_DIR}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "0npyniuj8c7", + "source": "## 6. 
ONNX Export (FP32)", + "metadata": {} + }, + { + "cell_type": "code", + "id": "vy87sgrm0rl", + "source": "import onnx\nimport onnxruntime as ort\n\n\nclass ONNXWrapper(nn.Module):\n \"\"\"Thin wrapper that reshapes output to (batch, 1) for ONNX consumers.\"\"\"\n\n def __init__(self, model):\n super().__init__()\n self.model = model\n\n def forward(self, input_features):\n out = self.model(input_features)\n return out[\"logits\"].unsqueeze(-1)\n\n\nonnx_fp32_path = os.path.join(OUTPUT_DIR, \"smart-turn-v3.onnx\")\n\nwrapper = ONNXWrapper(model).cpu().eval()\ndummy = torch.randn(1, 80, 800)\n\ntorch.onnx.export(\n wrapper,\n (dummy,),\n onnx_fp32_path,\n opset_version=ONNX_OPSET,\n input_names=[\"input_features\"],\n output_names=[\"logits\"],\n dynamic_axes={\"input_features\": {0: \"batch\"}, \"logits\": {0: \"batch\"}},\n do_constant_folding=False,\n)\n\nonnx.checker.check_model(onnx.load(onnx_fp32_path))\nprint(f\"FP32 ONNX: {onnx_fp32_path}\")\nprint(f\"Size: {os.path.getsize(onnx_fp32_path) / 1e6:.1f} MB\")\n\n# Quick verify\nsess = ort.InferenceSession(onnx_fp32_path, providers=[\"CPUExecutionProvider\"])\nout = sess.run(None, {\"input_features\": dummy.numpy()})\nprint(f\"Test output shape: {out[0].shape}, value: {out[0][0, 0]:.4f}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "ic4jic9zir9", + "source": "## 7. INT8 Quantization\n\nStatic quantization with entropy calibration on 1024 training samples.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "oh4xmxznj6c", + "source": "from onnxruntime.quantization import (\n CalibrationMethod,\n QuantFormat,\n QuantType,\n quantize_static,\n quant_pre_process,\n CalibrationDataReader,\n)\n\n\nclass SmartTurnCalibrationReader(CalibrationDataReader):\n \"\"\"Feeds calibration samples to the ONNX quantizer.\"\"\"\n\n def __init__(self, dataset, n_samples=1024):\n self.samples = []\n rng = np.random.default_rng(42)\n indices = rng.choice(len(dataset), size=min(n_samples, len(dataset)), replace=False)\n for i in indices:\n item = dataset[int(i)]\n self.samples.append({\"input_features\": item[\"input_features\"].unsqueeze(0).numpy()})\n self.idx = 0\n\n def get_next(self):\n if self.idx >= len(self.samples):\n return None\n sample = self.samples[self.idx]\n self.idx += 1\n return sample\n\n def rewind(self):\n self.idx = 0\n\n\n# Pre-process graph\npre_path = onnx_fp32_path.replace(\".onnx\", \"-pre.onnx\")\nquant_pre_process(onnx_fp32_path, pre_path, skip_symbolic_shape=True)\n\n# Quantize\nonnx_int8_path = os.path.join(OUTPUT_DIR, \"smart-turn-v3-int8.onnx\")\nreader = SmartTurnCalibrationReader(train_dataset, CALIBRATION_SAMPLES)\n\nquantize_static(\n model_input=pre_path,\n model_output=onnx_int8_path,\n calibration_data_reader=reader,\n quant_format=QuantFormat.QDQ,\n activation_type=QuantType.QUInt8,\n weight_type=QuantType.QInt8,\n per_channel=True,\n calibrate_method=CalibrationMethod.Entropy,\n op_types_to_quantize=[\"Conv\", \"MatMul\", \"Gemm\"],\n)\n\n# Clean up temp file\nos.remove(pre_path)\n\nprint(f\"INT8 ONNX: {onnx_int8_path}\")\nprint(f\"Size: {os.path.getsize(onnx_int8_path) / 1e6:.1f} MB\")\n\n# Verify\nsess_int8 = ort.InferenceSession(onnx_int8_path, providers=[\"CPUExecutionProvider\"])\nout_int8 = sess_int8.run(None, {\"input_features\": dummy.numpy()})\nprint(f\"INT8 test output: {out_int8[0][0, 0]:.4f} (FP32 was {out[0][0, 0]:.4f})\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": 
"8fs0r09xqx4", + "source": "## 8. Benchmark\n\nLatency comparison: FP32 vs INT8 on CPU.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "8rbjebuvz9c", + "source": "import time\n\ndef benchmark_onnx(path, label, n_runs=100, warmup=10):\n sess = ort.InferenceSession(path, providers=[\"CPUExecutionProvider\"])\n x = np.random.randn(1, 80, 800).astype(np.float32)\n for _ in range(warmup):\n sess.run(None, {\"input_features\": x})\n times = []\n for _ in range(n_runs):\n t0 = time.perf_counter()\n sess.run(None, {\"input_features\": x})\n times.append((time.perf_counter() - t0) * 1000)\n times = np.array(times)\n print(f\"{label}: mean={times.mean():.2f}ms, p50={np.median(times):.2f}ms, p99={np.percentile(times, 99):.2f}ms\")\n\nbenchmark_onnx(onnx_fp32_path, \"FP32\")\nbenchmark_onnx(onnx_int8_path, \"INT8\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "v6zfzju2c2", + "source": "## 9. Sanity Check — Inference on Test Samples", + "metadata": {} + }, + { + "cell_type": "code", + "id": "masvbrs4e59", + "source": "sess_int8 = ort.InferenceSession(onnx_int8_path, providers=[\"CPUExecutionProvider\"])\n\nprint(\"Sample predictions (INT8 ONNX):\\n\")\nfor i in range(10):\n sample = test_dataset[i]\n inp = sample[\"input_features\"].unsqueeze(0).numpy()\n prob = sess_int8.run(None, {\"input_features\": inp})[0][0, 0]\n label = \"Complete\" if sample[\"labels\"] > 0.5 else \"Incomplete\"\n pred = \"Complete\" if prob > 0.5 else \"Incomplete\"\n match = \"OK\" if label == pred else \"MISS\"\n print(f\" [{i}] truth={label:<12} pred={pred:<12} prob={prob:.4f} {match}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "0ojdzdnh95u", + "source": "---\n\n**Artifacts in `OUTPUT_DIR`:**\n- `smart-turn-v3.onnx` — FP32 (~32 MB)\n- `smart-turn-v3-int8.onnx` — INT8 (~8 MB)\n- `model.safetensors` + `config.json` — PyTorch checkpoint\n\nCopy the INT8 model to integrate into wavekat-turn.", + "metadata": {} + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file