Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/examples_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
pip uninstall -y dattri
python -m pip install --upgrade pip
pip install -e .[test]
pip install jupyter nbconvert
- name: Run examples
run: |
python examples/noisy_label_detection/influence_function_noisy_label.py --method cg --device cpu
Expand All @@ -52,6 +53,12 @@ jobs:
python examples/pretrained_benchmark/logra_wikitext2_gpt2_lds.py
sed -i 's/range(1000)/range(100)/g' examples/customized_retraining/mnist.py
python examples/customized_retraining/mnist.py --device cpu --path ./tmp/mnist_ckpt
jupyter nbconvert --to python quickstart/influence_function_lds.ipynb
sed -i '/get_ipython().system.*pip install/d' quickstart/influence_function_lds.py
python quickstart/influence_function_lds.py
jupyter nbconvert --to python quickstart/influence_function_noisy_label.ipynb
sed -i '/get_ipython().system.*pip install/d' quickstart/influence_function_noisy_label.py
python quickstart/influence_function_noisy_label.py
Comment thread
TheaperDeng marked this conversation as resolved.
- name: Uninstall the package
run: |
pip uninstall -y dattri
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ exclude = [
"dattri/benchmark/models",
"examples/",
"experiments/",
"quickstart/",
]

[tool.ruff.lint.per-file-ignores]
Expand Down
242 changes: 242 additions & 0 deletions quickstart/influence_function_lds.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "GrM_LjGcxl_v"
},
"source": [
"# Evaluate Influence Function (CG) on Mnist10 + MLP through LDS. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uJbPntZez7ut"
},
"source": [
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/carolinef35/dattri/blob/colab_examples/examples/quickstart/influence_function_lds.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lAg59xgUpsGX"
},
"source": [
"Note: The installation block in the notebook is specifically designed for Google Colab and the use cases in this notebook. Standard installation instructions can be found in the [README](https://github.com/TRAIS-Lab/dattri/blob/main/README.md#quick-start)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Vh91mxvupuBQ",
"outputId": "cca73eae-a777-41e7-9e68-91154c3e01bc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting dattri\n",
" Downloading dattri-0.2.0-py3-none-any.whl.metadata (12 kB)\n",
"Requirement already satisfied: numpy>=1.25 in /usr/local/lib/python3.12/dist-packages (from dattri) (2.0.2)\n",
"Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from dattri) (1.16.3)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from dattri) (6.0.3)\n",
"Collecting pretty-midi (from dattri)\n",
" Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m77.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Collecting mido>=1.1.16 (from pretty-midi->dattri)\n",
" Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (1.17.0)\n",
"Requirement already satisfied: importlib_resources in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (6.5.2)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mido>=1.1.16->pretty-midi->dattri) (25.0)\n",
"Downloading dattri-0.2.0-py3-none-any.whl (173 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.9/173.9 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading mido-1.3.3-py3-none-any.whl (54 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.6/54.6 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hBuilding wheels for collected packages: pretty-midi\n",
" Building wheel for pretty-midi (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pretty-midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=040d236e342933211d5e6aa0d420205d4acc00910f6c13a15c2ac289abe8733c\n",
" Stored in directory: /root/.cache/pip/wheels/f4/ad/93/a7042fe12668827574927ade9deec7f29aad2a1001b1501882\n",
"Successfully built pretty-midi\n",
"Installing collected packages: mido, pretty-midi, dattri\n",
"Successfully installed dattri-0.2.0 mido-1.3.3 pretty-midi-0.2.11\n"
]
}
],
"source": [
"!pip install git+[https://github.com/TRAIS-Lab/dattri.git@main](https://github.com/TRAIS-Lab/dattri.git@main)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WzfoDddNpxty"
},
"source": [
"Import libraries needed to run code.\n",
"\n",
"Note: If \"\"----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----\" message appears, change your runtime type to GPU by going to Runtime -> Change runtime type and selecting 'GPU'."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SU8NDpOPpg-Z"
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from dattri.algorithm.influence_function import IFAttributorCG\n",
"from dattri.benchmark.load import load_benchmark\n",
"from dattri.metric import lds\n",
"from dattri.task import AttributionTask"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w7x4js5WvpTN"
},
"source": [
"Linear Datamodeling Score (LDS): a metric used to evaluate the performance of data attribution methods on the counterfactual estimation task of predicting model behavior given different subsets of the training set.\n",
"\n",
"* LDS close to 1 means the attribution method accurately predicts the model's response to data changes\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4TclgMyVp18L",
"outputId": "931d75ff-17de-4927-9942-5d3b64f0c5c7"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 9.91M/9.91M [00:00<00:00, 14.6MB/s]\n",
"100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]\n",
"100%|██████████| 1.65M/1.65M [00:00<00:00, 4.52MB/s]\n",
"100%|██████████| 4.54k/4.54k [00:00<00:00, 9.62MB/s]\n",
"calculating gradient of training set...: 0%| | 0/1 [00:00<?, ?it/s]\n",
"calculating gradient of test set...: 0%| | 0/1 [00:00<?, ?it/s]\u001b[A\n",
"calculating gradient of test set...: 100%|██████████| 1/1 [00:02<00:00, 2.28s/it]\u001b[A\n",
"/usr/local/lib/python3.12/dist-packages/dattri/metric/metrics.py:58: ConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
" spearmanr(sum_scores[:, i], gt_values[:, i]).correlation,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"lds: tensor(0.2594)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.12/dist-packages/dattri/metric/metrics.py:68: ConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
" spearmanr(sum_scores[:, i], gt_values[:, i]).pvalue,\n"
]
}
],
"source": [
"# set device to 'cuda' if available, otherwise 'cpu'\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# download the pre-trained benchmark\n",
"# includes some trained model and ground truth\n",
"model_details, groundtruth = load_benchmark(\n",
" model=\"mlp\", dataset=\"mnist\", metric=\"lds\",\n",
")\n",
"\n",
"\n",
"# define a functional loss function 'f' that calculates CrossEntropyLoss\n",
"# takes model parameters and a data-target pair (image, label) as input\n",
"def f(params, data_target_pair):\n",
" image, label = data_target_pair\n",
" loss = nn.CrossEntropyLoss()\n",
" # apply the model with given parameters to the image.\n",
" yhat = torch.func.functional_call(model_details[\"model\"], params, image)\n",
" return loss(yhat, label.long())\n",
"\n",
"\n",
"# initialize the AttributionTask with the model, loss function, and model checkpoints\n",
"task = AttributionTask(\n",
" model=model_details[\"model\"].to(device),\n",
" loss_func=f,\n",
" checkpoints=model_details[\"models_full\"][0], # use one full model checkpoint\n",
")\n",
"\n",
"# initialize the IFAttributorCG (Influence Function Attributor using Conjugate Gradient)\n",
"# requires the task, device, regularization parameter, and max iterations\n",
"attributor = IFAttributorCG(\n",
" task=task, device=device, regularization=5e-3, max_iter=10,\n",
")\n",
"# cache the training data using a DataLoader\n",
"# pre-processes/stores training data for attribution\n",
"attributor.cache(\n",
" DataLoader(\n",
" model_details[\"train_dataset\"],\n",
" batch_size=5000,\n",
" sampler=model_details[\"train_sampler\"],\n",
" ),\n",
")\n",
"\n",
"# perform attribution without gradient calculation (inference mode)\n",
"with torch.no_grad():\n",
" # calculate influence scores of training data on test data\n",
" score = attributor.attribute(\n",
" DataLoader(\n",
" model_details[\"train_dataset\"],\n",
" batch_size=5000,\n",
" sampler=model_details[\"train_sampler\"],\n",
" ),\n",
" DataLoader(\n",
" model_details[\"test_dataset\"],\n",
" batch_size=5000,\n",
" sampler=model_details[\"test_sampler\"],\n",
" ),\n",
" )\n",
"\n",
"# calculate the LDS score\n",
"lds_score = lds(score, groundtruth)[0]\n",
"# print the mean of the non-null LDS scores\n",
"print(\"lds:\", torch.mean(lds_score[~torch.isnan(lds_score)]))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Loading
Loading