Quick Start Colab notebooks #229
Open: carolinef35 wants to merge 11 commits into TRAIS-Lab:main from carolinef35:colab_examples
11 commits
ebc348c added two colab files under the quickstart folder in examples (carolinef35)
dcf032b changed two colab files under the quickstart folder in examples (carolinef35)
df9bd86 added badge and clarified installation instructions (carolinef35)
58ce547 fixed badge link (carolinef35)
d7f3727 Update examples_test.yml (carolinef35)
ae26e6b moved colab files, fixed title, and replaced pip install command (carolinef35)
9d03d9f Merge branch 'main' into colab_examples (carolinef35)
142da2b Update examples_test.yml (carolinef35)
5b50b0a Fixed Lint with Ruff and removed pip install line in converted scripts (carolinef35)
1050589 Fixed Lint with Ruff and removed pip install line in converted scripts (carolinef35)
2af85fc Removed Ruff check (carolinef35)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,242 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "GrM_LjGcxl_v" | ||
| }, | ||
| "source": [ | ||
| "# Evaluate Influence Function (CG) on MNIST-10 + MLP through LDS" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "uJbPntZez7ut" | ||
| }, | ||
| "source": [ | ||
| "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/carolinef35/dattri/blob/colab_examples/examples/quickstart/influence_function_lds.ipynb)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "lAg59xgUpsGX" | ||
| }, | ||
| "source": [ | ||
| "Note: The installation block in the notebook is specifically designed for Google Colab and the use cases in this notebook. Standard installation instructions can be found in the [README](https://github.com/TRAIS-Lab/dattri/blob/main/README.md#quick-start)." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "colab": { | ||
| "base_uri": "https://localhost:8080/" | ||
| }, | ||
| "id": "Vh91mxvupuBQ", | ||
| "outputId": "cca73eae-a777-41e7-9e68-91154c3e01bc" | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "Collecting dattri\n", | ||
| " Downloading dattri-0.2.0-py3-none-any.whl.metadata (12 kB)\n", | ||
| "Requirement already satisfied: numpy>=1.25 in /usr/local/lib/python3.12/dist-packages (from dattri) (2.0.2)\n", | ||
| "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from dattri) (1.16.3)\n", | ||
| "Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from dattri) (6.0.3)\n", | ||
| "Collecting pretty-midi (from dattri)\n", | ||
| " Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)\n", | ||
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m77.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
| "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | ||
| "Collecting mido>=1.1.16 (from pretty-midi->dattri)\n", | ||
| " Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)\n", | ||
| "Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (1.17.0)\n", | ||
| "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.12/dist-packages (from pretty-midi->dattri) (6.5.2)\n", | ||
| "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mido>=1.1.16->pretty-midi->dattri) (25.0)\n", | ||
| "Downloading dattri-0.2.0-py3-none-any.whl (173 kB)\n", | ||
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.9/173.9 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
| "\u001b[?25hDownloading mido-1.3.3-py3-none-any.whl (54 kB)\n", | ||
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.6/54.6 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
| "\u001b[?25hBuilding wheels for collected packages: pretty-midi\n", | ||
| " Building wheel for pretty-midi (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | ||
| " Created wheel for pretty-midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=040d236e342933211d5e6aa0d420205d4acc00910f6c13a15c2ac289abe8733c\n", | ||
| " Stored in directory: /root/.cache/pip/wheels/f4/ad/93/a7042fe12668827574927ade9deec7f29aad2a1001b1501882\n", | ||
| "Successfully built pretty-midi\n", | ||
| "Installing collected packages: mido, pretty-midi, dattri\n", | ||
| "Successfully installed dattri-0.2.0 mido-1.3.3 pretty-midi-0.2.11\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "!pip install git+https://github.com/TRAIS-Lab/dattri.git@main" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "WzfoDddNpxty" | ||
| }, | ||
| "source": [ | ||
| "Import the libraries needed to run the code.\n", | ||
| "\n", | ||
| "Note: If the message \"----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----\" appears, change your runtime type to GPU by going to Runtime -> Change runtime type and selecting 'GPU'." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "id": "SU8NDpOPpg-Z" | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import torch\n", | ||
| "from torch import nn\n", | ||
| "from torch.utils.data import DataLoader\n", | ||
| "\n", | ||
| "from dattri.algorithm.influence_function import IFAttributorCG\n", | ||
| "from dattri.benchmark.load import load_benchmark\n", | ||
| "from dattri.metric import lds\n", | ||
| "from dattri.task import AttributionTask" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": { | ||
| "id": "w7x4js5WvpTN" | ||
| }, | ||
| "source": [ | ||
| "Linear Datamodeling Score (LDS): a metric that evaluates data attribution methods on a counterfactual estimation task, namely predicting how the model behaves when it is trained on different subsets of the training set.\n", | ||
| "\n", | ||
| "* An LDS close to 1 means the attribution method accurately predicts the model's response to changes in the training data.\n", | ||
| "\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "colab": { | ||
| "base_uri": "https://localhost:8080/" | ||
| }, | ||
| "id": "4TclgMyVp18L", | ||
| "outputId": "931d75ff-17de-4927-9942-5d3b64f0c5c7" | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "name": "stderr", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "100%|██████████| 9.91M/9.91M [00:00<00:00, 14.6MB/s]\n", | ||
| "100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]\n", | ||
| "100%|██████████| 1.65M/1.65M [00:00<00:00, 4.52MB/s]\n", | ||
| "100%|██████████| 4.54k/4.54k [00:00<00:00, 9.62MB/s]\n", | ||
| "calculating gradient of training set...: 0%| | 0/1 [00:00<?, ?it/s]\n", | ||
| "calculating gradient of test set...: 0%| | 0/1 [00:00<?, ?it/s]\u001b[A\n", | ||
| "calculating gradient of test set...: 100%|██████████| 1/1 [00:02<00:00, 2.28s/it]\u001b[A\n", | ||
| "/usr/local/lib/python3.12/dist-packages/dattri/metric/metrics.py:58: ConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n", | ||
| " spearmanr(sum_scores[:, i], gt_values[:, i]).correlation,\n" | ||
| ] | ||
| }, | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "lds: tensor(0.2594)\n" | ||
| ] | ||
| }, | ||
| { | ||
| "name": "stderr", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "/usr/local/lib/python3.12/dist-packages/dattri/metric/metrics.py:68: ConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n", | ||
| " spearmanr(sum_scores[:, i], gt_values[:, i]).pvalue,\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "# set device to 'cuda' if available, otherwise 'cpu'\n", | ||
| "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | ||
| "\n", | ||
| "# download the pre-trained benchmark\n", | ||
| "# includes the trained models and the LDS ground truth\n", | ||
| "model_details, groundtruth = load_benchmark(\n", | ||
| "    model=\"mlp\", dataset=\"mnist\", metric=\"lds\",\n", | ||
| ")\n", | ||
| "\n", | ||
| "\n", | ||
| "# define a functional loss function 'f' that calculates CrossEntropyLoss\n", | ||
| "# takes model parameters and a data-target pair (image, label) as input\n", | ||
| "def f(params, data_target_pair):\n", | ||
| "    image, label = data_target_pair\n", | ||
| "    loss = nn.CrossEntropyLoss()\n", | ||
| "    # apply the model with the given parameters to the image\n", | ||
| "    yhat = torch.func.functional_call(model_details[\"model\"], params, image)\n", | ||
| "    return loss(yhat, label.long())\n", | ||
| "\n", | ||
| "\n", | ||
| "# initialize the AttributionTask with the model, loss function, and model checkpoints\n", | ||
| "task = AttributionTask(\n", | ||
| "    model=model_details[\"model\"].to(device),\n", | ||
| "    loss_func=f,\n", | ||
| "    checkpoints=model_details[\"models_full\"][0],  # use one full model checkpoint\n", | ||
| ")\n", | ||
| "\n", | ||
| "# initialize the IFAttributorCG (Influence Function Attributor using Conjugate Gradient)\n", | ||
| "# requires the task, device, regularization parameter, and max iterations\n", | ||
| "attributor = IFAttributorCG(\n", | ||
| "    task=task, device=device, regularization=5e-3, max_iter=10,\n", | ||
| ")\n", | ||
| "# cache the training data using a DataLoader\n", | ||
| "# pre-processes/stores training data for attribution\n", | ||
| "attributor.cache(\n", | ||
| "    DataLoader(\n", | ||
| "        model_details[\"train_dataset\"],\n", | ||
| "        batch_size=5000,\n", | ||
| "        sampler=model_details[\"train_sampler\"],\n", | ||
| "    ),\n", | ||
| ")\n", | ||
| "\n", | ||
| "# perform attribution with autograd tracking disabled\n", | ||
| "with torch.no_grad():\n", | ||
| "    # calculate influence scores of training data on test data\n", | ||
| "    score = attributor.attribute(\n", | ||
| "        DataLoader(\n", | ||
| "            model_details[\"train_dataset\"],\n", | ||
| "            batch_size=5000,\n", | ||
| "            sampler=model_details[\"train_sampler\"],\n", | ||
| "        ),\n", | ||
| "        DataLoader(\n", | ||
| "            model_details[\"test_dataset\"],\n", | ||
| "            batch_size=5000,\n", | ||
| "            sampler=model_details[\"test_sampler\"],\n", | ||
| "        ),\n", | ||
| "    )\n", | ||
| "\n", | ||
| "# calculate the LDS score\n", | ||
| "lds_score = lds(score, groundtruth)[0]\n", | ||
| "# print the mean of the non-NaN LDS scores\n", | ||
| "print(\"lds:\", torch.mean(lds_score[~torch.isnan(lds_score)]))" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "accelerator": "GPU", | ||
| "colab": { | ||
| "gpuType": "T4", | ||
| "provenance": [] | ||
| }, | ||
| "kernelspec": { | ||
| "display_name": "Python 3", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "name": "python" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 0 | ||
| } |
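
For readers of the notebook above, a rough sketch of what an LDS-style score measures may help. This is not dattri's implementation: the function names (`average_ranks`, `spearman`, `lds_sketch`) and the toy numbers are hypothetical. It assumes that the predicted output for a retraining subset is the sum of the attribution scores of the training points in that subset, and that the score for each test sample is the Spearman rank correlation between those predictions and the observed outputs. A test sample whose predictions are constant gets NaN, which is why the notebook averages only the non-NaN entries.

```python
from statistics import mean


def average_ranks(values):
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = shared
        start = end + 1
    return ranks


def spearman(x, y):
    """Pearson correlation of the ranks; NaN when either input is constant."""
    rx, ry = average_ranks(x), average_ranks(y)
    mx, my = mean(rx), mean(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    if var_x == 0 or var_y == 0:
        return float("nan")
    return cov / (var_x * var_y) ** 0.5


def lds_sketch(scores, subsets, observed):
    """Per-test-sample rank correlation between predicted and observed outputs.

    scores[i][j]:   attribution of training point i on test sample j
    subsets[s]:     training indices used in retraining run s
    observed[s][j]: measured output on test sample j after run s
    """
    n_test = len(scores[0])
    result = []
    for j in range(n_test):
        # Assumption: the prediction is additive over the subset's members.
        predicted = [sum(scores[i][j] for i in subset) for subset in subsets]
        actual = [row[j] for row in observed]
        result.append(spearman(predicted, actual))
    return result


# Toy data (made up): 4 training points, 2 test samples, 4 retraining runs.
scores = [[1.0, 0.5], [2.0, 0.5], [3.0, 0.5], [4.0, 0.5]]
subsets = [[0, 1], [1, 2], [2, 3], [0, 2]]
observed = [[0.1, 0.2], [0.6, 0.4], [0.9, 0.1], [0.3, 0.3]]

per_test = lds_sketch(scores, subsets, observed)
valid = [v for v in per_test if v == v]  # NaN is the only value not equal to itself
mean_lds = mean(valid)
print("per-test:", per_test, "mean over non-NaN:", mean_lds)
```

In the toy data the second test sample receives identical attribution from every training point, so its predictions are constant across subsets and its correlation is undefined. This is the same situation that the `ConstantInputWarning` in the notebook output reports.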