diff --git a/.github/workflows/examples_test.yml b/.github/workflows/examples_test.yml index 3741e93c7..d856f0700 100644 --- a/.github/workflows/examples_test.yml +++ b/.github/workflows/examples_test.yml @@ -34,6 +34,7 @@ jobs: pip uninstall -y dattri python -m pip install --upgrade pip pip install -e .[test] + pip install jupyter nbconvert - name: Run examples run: | python examples/noisy_label_detection/influence_function_noisy_label.py --method cg --device cpu @@ -52,6 +53,12 @@ jobs: python examples/pretrained_benchmark/logra_wikitext2_gpt2_lds.py sed -i 's/range(1000)/range(100)/g' examples/customized_retraining/mnist.py python examples/customized_retraining/mnist.py --device cpu --path ./tmp/mnist_ckpt + jupyter nbconvert --to python quickstart/influence_function_lds.ipynb + sed -i '/get_ipython().system.*pip install/d' quickstart/influence_function_lds.py + python quickstart/influence_function_lds.py + jupyter nbconvert --to python quickstart/influence_function_noisy_label.ipynb + sed -i '/get_ipython().system.*pip install/d' quickstart/influence_function_noisy_label.py + python quickstart/influence_function_noisy_label.py - name: Uninstall the package run: | pip uninstall -y dattri diff --git a/pyproject.toml b/pyproject.toml index 66fad2903..c355916e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ exclude = [ "dattri/benchmark/models", "examples/", "experiments/", + "quickstart/", ] [tool.ruff.lint.per-file-ignores] diff --git a/quickstart/influence_function_lds.ipynb b/quickstart/influence_function_lds.ipynb new file mode 100644 index 000000000..eba72dd98 --- /dev/null +++ b/quickstart/influence_function_lds.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GrM_LjGcxl_v" + }, + "source": [ + "# Evaluate Influence Function (CG) on Mnist10 + MLP through LDS. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uJbPntZez7ut" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/carolinef35/dattri/blob/colab_examples/examples/quickstart/influence_function_lds.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lAg59xgUpsGX" + }, + "source": [ + "Note: The installation block in the notebook is specifically designed for Google Colab and the use cases in this notebook. Standard installation instructions can be found in the [README](https://github.com/TRAIS-Lab/dattri/blob/main/README.md#quick-start)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Vh91mxvupuBQ", + "outputId": "cca73eae-a777-41e7-9e68-91154c3e01bc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting dattri\n", + " Downloading dattri-0.2.0-py3-none-any.whl.metadata (12 kB)\n", + "Requirement already satisfied: numpy>=1.25 in /usr/local/lib/python3.12/dist-packages (from dattri) (2.0.2)\n", + "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from dattri) (1.16.3)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from dattri) (6.0.3)\n", + "Collecting pretty-midi (from dattri)\n", + " Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m77.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting mido>=1.1.16 (from pretty-midi->dattri)\n", + " Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (1.17.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (6.5.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mido>=1.1.16->pretty-midi->dattri) (25.0)\n", + "Downloading dattri-0.2.0-py3-none-any.whl (173 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.9/173.9 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading mido-1.3.3-py3-none-any.whl (54 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.6/54.6 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: pretty-midi\n", + " Building wheel for pretty-midi (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pretty-midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=040d236e342933211d5e6aa0d420205d4acc00910f6c13a15c2ac289abe8733c\n", + " Stored in directory: /root/.cache/pip/wheels/f4/ad/93/a7042fe12668827574927ade9deec7f29aad2a1001b1501882\n", + "Successfully built pretty-midi\n", + "Installing collected packages: mido, pretty-midi, dattri\n", + "Successfully installed dattri-0.2.0 mido-1.3.3 pretty-midi-0.2.11\n" + ] + } + ], + "source": [ + "!pip install git+[https://github.com/TRAIS-Lab/dattri.git@main](https://github.com/TRAIS-Lab/dattri.git@main)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WzfoDddNpxty" + }, + "source": [ + "Import libraries needed to run code.\n", + "\n", + "Note: If \"\"----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----\" message appears, change your runtime type to GPU by going to Runtime -> Change runtime type and selecting 'GPU'." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SU8NDpOPpg-Z" + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from dattri.algorithm.influence_function import IFAttributorCG\n", + "from dattri.benchmark.load import load_benchmark\n", + "from dattri.metric import lds\n", + "from dattri.task import AttributionTask" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w7x4js5WvpTN" + }, + "source": [ + "Linear Datamodeling Score (LDS): a metric used to evaluate the performance of data attribution methods on the counterfactual estimation task of predicting model behavior given different subsets of the training set.\n", + "\n", + "* LDS close to 1 means the attribution method accurately predicts the model's response to data changes\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4TclgMyVp18L", + "outputId": "931d75ff-17de-4927-9942-5d3b64f0c5c7" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 9.91M/9.91M [00:00<00:00, 14.6MB/s]\n", + "100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]\n", + "100%|██████████| 1.65M/1.65M [00:00<00:00, 4.52MB/s]\n", + "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.62MB/s]\n", + "calculating gradient of training set...: 0%| | 0/1 [00:00=1.25 in /usr/local/lib/python3.12/dist-packages (from dattri) (2.0.2)\n", + "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from dattri) (1.16.3)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from dattri) (6.0.3)\n", + "Collecting pretty-midi (from dattri)\n", + " Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m52.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting mido>=1.1.16 (from pretty-midi->dattri)\n", + " Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (1.17.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (6.5.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mido>=1.1.16->pretty-midi->dattri) (25.0)\n", + "Downloading dattri-0.2.0-py3-none-any.whl (173 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.9/173.9 kB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading mido-1.3.3-py3-none-any.whl (54 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.6/54.6 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: pretty-midi\n", + " Building wheel for pretty-midi (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pretty-midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=fd5739ba3229b2d5d8d6f2a149d3ef913fab1df96517502ff5a2b98bf9bdbee7\n", + " Stored in directory: /root/.cache/pip/wheels/f4/ad/93/a7042fe12668827574927ade9deec7f29aad2a1001b1501882\n", + "Successfully built pretty-midi\n", + "Installing collected packages: mido, pretty-midi, dattri\n", + "Successfully installed dattri-0.2.0 mido-1.3.3 pretty-midi-0.2.11\n" + ] + } + ], + "source": [ + "!pip install git+[https://github.com/TRAIS-Lab/dattri.git@main](https://github.com/TRAIS-Lab/dattri.git@main)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_3J9JbbIdJOI" + }, + "source": [ + "Import libraries needed to run code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PDYn9ys1dN1X" + }, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "import torch\n", + "from torch import nn\n", + "\n", + "from dattri.algorithm.influence_function import (\n", + " IFAttributorArnoldi,\n", + " IFAttributorCG,\n", + " IFAttributorDataInf,\n", + " IFAttributorExplicit,\n", + " IFAttributorLiSSA,\n", + ")\n", + "from dattri.benchmark.datasets.mnist import create_mnist_dataset, train_mnist_lr\n", + "from dattri.benchmark.utils import SubsetSampler, flip_label\n", + "from dattri.metric import mislabel_detection_auc\n", + "from dattri.task import AttributionTask\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yKAFh2xKeVxo" + }, + "source": [ + "Dictionary to manage different influence function algorithms with their specific configurations. Each key is a specific attribution method and the corresponding value is a class constructor with default arguments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tW-PS8aPfSqD" + }, + "outputs": [], + "source": [ + "ATTRIBUTOR_MAP = {\n", + " \"explicit\": partial(IFAttributorExplicit, regularization=0.01),\n", + " \"cg\": partial(IFAttributorCG, regularization=0.01),\n", + " \"lissa\": partial(IFAttributorLiSSA, recursion_depth=100),\n", + " \"datainf\": partial(IFAttributorDataInf, regularization=0.01),\n", + " \"arnoldi\": partial(IFAttributorArnoldi, regularization=0.01, max_iter=10),\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qLP_rEsrf8jU" + }, + "source": [ + "Influence Score: how much a single training data point affects the model's parameters or its predictions on other data points.\n", + "\n", + "\n", + "* Higher influence indicates that a particular data point is problematic for the model.\n", + "* Mislabeled samples will exert a stronger, often negative, influence on the model's traning process.\n", + "\n", + "\n", + "AUC Score: the probability that the influence function method ranks a randomly chosen positive example (a truly mislabled sample) higher than a randomly chosen negative example (a correctly labeled sample).\n", + "\n", + "\n", + "* Higher AUC values indicate better performance.\n", + "* For mislabel detection, an AUC close to 1.0 means the influence scores are very effective at identifying the flipped labels among the correctly labeled ones.\n", + "\n", + "Self Attribution: measures the influence of a training sample on the model's own prediction for that specific sample.\n", + "* It quantifies how much the model’s \"belief\" about a sample changes if that sample were removed from the training set.\n", + "* A high self-attribution score typically identifies outliers or mislabeled samples. This is because the model \"memorizes\" the sample that contradicts the patterns found in the rest of the data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6WwN6-LHbElK", + "outputId": "ea8e92d7-2715-4960-d9da-15baab9db36e" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 9.91M/9.91M [00:00<00:00, 86.8MB/s]\n", + "100%|██████████| 28.9k/28.9k [00:00<00:00, 35.9MB/s]\n", + "100%|██████████| 1.65M/1.65M [00:00<00:00, 62.4MB/s]\n", + "100%|██████████| 4.54k/4.54k [00:00<00:00, 6.65MB/s]\n", + "calculating gradient of training set...: 0%| | 0/1 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# number of top influential samples\n", + "num_display = 20\n", + "\n", + "# get the indices of the top influential samples\n", + "top_influential_indices = indices[:num_display]\n", + "\n", + "plt.figure(figsize=(15, 6))\n", + "plt.suptitle(f\"Top {num_display} Most Influential Samples Detected as Noisy Labels\",\n", + " fontsize=16)\n", + "\n", + "for i, raw_idx in enumerate(top_influential_indices):\n", + " original_idx = int(raw_idx)\n", + " image, label = dataset[original_idx]\n", + "\n", + " # check if this sample was actually a flipped label\n", + " is_flipped = original_idx in set(flip_index)\n", + "\n", + " plt.subplot(2, (num_display + 1) // 2, i + 1)\n", + " plt.imshow(image.squeeze().numpy(), cmap=\"gray\")\n", + " plt.title(f\"Label: {label}\\nFlipped: {is_flipped}\", color=\"red\"\n", + " if is_flipped else \"black\")\n", + " plt.axis(\"off\")\n", + "\n", + "plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # adjust layout to prevent title overlap\n", + "plt.show()" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}