Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/examples_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
python examples/pretrained_benchmark/logra_wikitext2_gpt2_lds.py --device cpu
sed -i 's/range(1000)/range(100)/g' examples/customized_retraining/mnist.py
python examples/customized_retraining/mnist.py --device cpu --path ./tmp/mnist_ckpt
python examples/data_cleaning/tracin_dataloader_group.py
- name: Uninstall the package
run: |
pip uninstall -y dattri
Expand Down
4 changes: 3 additions & 1 deletion dattri/algorithm/tracin.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,12 @@ def attribute( # noqa: PLR0912
grad_t = self.grad_target_func(parameters, test_batch_data)
if self.proj_params.proj_dim is not None:
# define the projector for this batch of data
# use grad_t.shape[0] (not test_batch_data[0]) to support
# DataloaderGroup where test_batch_data is a DataLoader
self.test_random_project = random_project(
grad_t,
proj_params=RandomProjectionParams(
feature_batch_size=test_batch_data[0].shape[0],
feature_batch_size=grad_t.shape[0],
device=self.device,
**self.proj_params.model_dump(),
),
Expand Down
49 changes: 46 additions & 3 deletions dattri/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

import inspect
from pathlib import PosixPath

import torch
from torch import nn
from torch.func import grad, vmap
from torch.utils.data import DataLoader

from dattri.func.utils import flatten_func, flatten_params, partial_param
from dattri.func.utils import (
_unflatten_params,
flatten_func,
flatten_params,
partial_param,
)


def _default_checkpoint_load_func(
Expand Down Expand Up @@ -52,6 +58,7 @@ def __init__(
],
target_func: Optional[Callable] = None,
checkpoints_load_func: Optional[Callable] = None,
group_target_func: bool = False,
) -> None:
"""Initialize the AttributionTask.

Expand Down Expand Up @@ -88,6 +95,11 @@ def f(params, data):
in terms of what is calculated,
but it should take the parameters and the data as input. Other than
that, the forwarding of model should be in `torch.func` style.
When group_target_func=True, target_func is also used for group
attribution: it should take (params_dict, batches) where batches
is a list of batches from the group DataLoader, and return a scalar
(e.g. sum of per-batch losses). The gradient of this scalar w.r.t.
params is the test-side gradient for the group.
A typical example is as follows:
```python
def f(params, data):
Expand All @@ -108,8 +120,13 @@ def checkpoints_load_func(model, checkpoint):
model.eval()
return model
```.
group_target_func (bool): If True, enable group attribution: when a
DataLoader is passed (e.g. via DataloaderGroup), target_func is
called with (params_dict, list_of_batches) and should return a
scalar. Default is False.
"""
self.model = model
self.group_target_func = group_target_func
if target_func is None:
target_func = loss_func

Expand Down Expand Up @@ -256,6 +273,32 @@ def get_grad_target_func(
randomness="different",
)
self.grad_target_func_kwargs = grad_target_func_kwargs

base_grad_target = self.grad_target_func

if self.group_target_func:
model_ref = self.model

def wrapped(
parameters: torch.Tensor,
data: Union[DataLoader, object],
) -> torch.Tensor:
if isinstance(data, DataLoader):
# Pre-fetch batches outside grad to avoid tracing DataLoader/dataset
# access (e.g. .numpy() in dataset __getitem__) which can raise
# "Tensor that doesn't have storage" when run inside autograd.
batches = list(data)

def flat_group_target(flat_params: torch.Tensor) -> torch.Tensor:
params_dict = _unflatten_params(flat_params, model_ref)
return self.original_target_func(params_dict, batches)

g = grad(flat_group_target)(parameters)
return g.unsqueeze(0)
return base_grad_target(parameters, data)

return wrapped

return self.grad_target_func

def get_target_func(
Expand Down
136 changes: 136 additions & 0 deletions examples/data_cleaning/tracin_dataloader_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""This example shows how to use TracInAttributor with DataloaderGroup and
group_target_func=True so target_func is used for group attribution.
Uses MNIST + MLP.
"""

import argparse
from typing import Iterator

import torch
from torch import nn
from torch.utils.data import DataLoader

from dattri.algorithm.tracin import TracInAttributor
from dattri.benchmark.datasets.mnist import create_mnist_dataset, train_mnist_mlp
from dattri.benchmark.utils import SubsetSampler
from dattri.task import AttributionTask


class DataloaderGroup(DataLoader):
    """Helper class to wrap a DataLoader for group attribution.

    This wrapper presents the dataloader as a single item (length 1).
    When iterated, it yields the original dataloader itself, allowing the
    consumer to treat the entire dataset as one attribution target.
    """

    def __init__(self, original_test_dataloader: DataLoader) -> None:
        """Initialize the DataloaderGroup.

        Args:
            original_test_dataloader (DataLoader):
                The PyTorch dataloader for individual test data samples.
        """
        # DataLoader's constructor requires a dataset. A one-element
        # placeholder is passed only to satisfy it; the placeholder is never
        # iterated because __iter__ is overridden below.
        super().__init__(torch.utils.data.TensorDataset(torch.zeros(1)))
        self.original_test_dataloader = original_test_dataloader

    def __iter__(self) -> Iterator[DataLoader]:
        """Iterate over the group.

        Yields:
            DataLoader: Yields the original dataloader as a single object.
        """
        yield self.original_test_dataloader

    def __len__(self) -> int:
        """Return the length of the group wrapper.

        Returns:
            int: Always 1, as the whole dataset is treated as one group.
        """
        return 1


if __name__ == "__main__":
    # Script entry point: train an MLP on MNIST, then compare TracIn scores
    # computed for the whole test set as one group against the sum of the
    # per-test-sample TracIn scores.
    parser = argparse.ArgumentParser()
    # Default to CUDA only when it is actually available, so that running the
    # example without --device also works on a CPU-only machine.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--train_size", type=int, default=10000)
    parser.add_argument("--test_size", type=int, default=5000)
    args = parser.parse_args()

    print(args)

    # load the training dataset (same as influence_function_data_cleaning.py)
    dataset, dataset_test = create_mnist_dataset("./data")

    # for model training, batch size is 64
    train_loader_full = DataLoader(
        dataset,
        batch_size=64,
        sampler=SubsetSampler(range(args.train_size)),
    )

    # training samples for attribution; batch size 1000 to speed up
    train_loader = DataLoader(
        dataset,
        batch_size=1000,
        sampler=SubsetSampler(range(args.train_size)),
    )
    test_loader = DataLoader(
        dataset_test,
        batch_size=1000,
        sampler=SubsetSampler(range(args.test_size)),
    )

    model = train_mnist_mlp(train_loader_full, seed=args.seed)
    model.to(args.device)
    model.eval()

    # loss and target in AttributionTask style; match IF example signature.
    # When group_target_func=True, target_func is also called with
    # (params_dict, list_of_batches).
    def f(params, data_target_pair):
        """Return the cross-entropy loss of the model on one (image, label) pair."""
        image, label = data_target_pair
        label = label.view(-1).long()
        yhat = torch.func.functional_call(model, params, (image,))
        return nn.CrossEntropyLoss()(yhat, label)

    def target_func(params, data):
        """Return the target value for a single input or for a whole group.

        In group mode ``data`` is a list of (image, label) batches and the
        result is the sum over batches of ``loss * batch_size``; since
        CrossEntropyLoss averages over the batch by default, this is the
        summed loss over all samples in the group. Otherwise the call is
        forwarded to ``f`` unchanged.
        """
        if isinstance(data, list):
            # group mode: data is list of (image, label) batches
            device = next(iter(params.values())).device
            total = None
            for image, label in data:
                image, label = image.to(device), label.to(device)
                loss = f(params, (image, label))
                n = image.shape[0]
                total = loss * n if total is None else total + loss * n
            return total
        return f(params, data)

    task = AttributionTask(
        loss_func=f,
        model=model,
        checkpoints=model.state_dict(),
        target_func=target_func,
        group_target_func=True,
    )

    attributor = TracInAttributor(
        task=task,
        weight_list=torch.tensor([1.0]),
        normalized_grad=False,
        device=args.device,
    )

    test_group = DataloaderGroup(test_loader)
    with torch.no_grad():
        # scores: one column for the whole test group.
        scores = attributor.attribute(train_loader, test_group)
        # scores_temp: per-test-sample scores, summed below for comparison.
        scores_temp = attributor.attribute(train_loader, test_loader)

    print("Test Dataloader Group (AttributionTask + group_target_func=True) — MNIST + MLP.")
    print(f"Score Shape: {scores.shape}")
    print(f"Calculated Scores (first 10):\n{scores.flatten()[:10]}")
    print(f"Calculated Scores Temp sum over test (first 10):\n{scores_temp.sum(dim=1)[:10]}")
    diff = (scores.flatten() - scores_temp.sum(dim=1)).abs()
    print(f"Max |group - sum(per-test)|: {diff.max().item():.6f}")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the output of this script? Could you paste it here?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test Dataloader Group (AttributionTask + group_target_func=True) — MNIST + MLP.
Score Shape: torch.Size([10000, 1])
Calculated Scores (first 10):
tensor([-2.2991e+00, 1.0665e-04, -1.4294e-01, 9.3012e-05, 1.6025e-01,
-2.3018e-02, 8.6976e-06, 1.5331e-07, -1.1255e-02, 8.6521e-07])
Calculated Scores Temp sum over test (first 10):
tensor([-2.2992e+00, 1.0665e-04, -1.4294e-01, 9.3012e-05, 1.6025e-01,
-2.3017e-02, 8.6975e-06, 1.5331e-07, -1.1255e-02, 8.6522e-07])
Max |group - sum(per-test)|: 0.005127

130 changes: 130 additions & 0 deletions test/dattri/algorithm/test_tracin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import shutil
from pathlib import Path
from typing import Iterator

import torch
from torch import nn
Expand Down Expand Up @@ -574,3 +575,132 @@ def f(params, image_label_pair):
assert torch.allclose(ckpt_grad_1[1], ckpt_grad_2[1])

shutil.rmtree(path)

def test_tracin_dataloader(self):
    """Verify TracIn Group Attribution correctness.

    The score obtained by attributing to the whole test DataLoader as a
    single group must match the sum of the per-test-sample scores.
    """

    class DataloaderGroup(DataLoader):
        """Helper class to wrap a DataLoader for group attribution.

        This wrapper presents the dataloader as a single item (length 1).
        When iterated, it yields the original dataloader itself, allowing the
        consumer to treat the entire dataset as one attribution target.
        """

        def __init__(self, original_test_dataloader: DataLoader) -> None:
            """Initialize the DataloaderGroup.

            Args:
                original_test_dataloader (DataLoader):
                    The PyTorch dataloader for individual test data samples
            """
            super().__init__(
                torch.utils.data.TensorDataset(torch.zeros(1)),
            )
            self.original_test_dataloader = original_test_dataloader

        def __iter__(self) -> Iterator[DataLoader]:
            """Iterate over the group.

            Yields:
                DataLoader: Yields the original dataloader as a single object.
            """
            yield self.original_test_dataloader

        def __len__(self) -> int:
            """Return the length of the group wrapper.

            Returns:
                int: Always 1, as the whole dataset is treated as one group.
            """
            return 1

    train_loader = DataLoader(
        TensorDataset(
            torch.randn(20, 1, 28, 28),
            torch.randint(0, 10, (20,)),
        ),
        batch_size=4,
        shuffle=False,
    )
    test_loader = DataLoader(
        TensorDataset(
            torch.randn(10, 1, 28, 28),
            torch.randint(0, 10, (10,)),
        ),
        batch_size=2,
        shuffle=False,
    )

    model = train_mnist_lr(train_loader)

    # to simulate multiple checkpoints
    model_1 = train_mnist_lr(train_loader, epoch_num=1)
    model_2 = train_mnist_lr(train_loader, epoch_num=2)
    path = Path("./ckpts")
    if not path.exists():
        path.mkdir(parents=True)

    # Everything that touches the checkpoint directory runs inside
    # try/finally so the directory is removed even when an assertion fails.
    try:
        torch.save(model_1.state_dict(), path / "model_1.pt")
        torch.save(model_2.state_dict(), path / "model_2.pt")

        checkpoint_list = ["ckpts/model_1.pt", "ckpts/model_2.pt"]

        def f(params, image_label_pair):
            # Per-sample loss: add a batch dimension before the forward pass.
            image, label = image_label_pair
            image_t = image.unsqueeze(0)
            label_t = label.unsqueeze(0)
            loss = nn.CrossEntropyLoss()
            yhat = torch.func.functional_call(model, params, image_t)
            return loss(yhat, label_t)

        def target_func(params, data):
            # Group mode: data is a list of (image, label) batches; return
            # the summed loss over all samples of all batches.
            if isinstance(data, list):
                loss_fn = nn.CrossEntropyLoss(reduction="sum")
                total = None
                for image, label in data:
                    yhat = torch.func.functional_call(model, params, (image,))
                    loss = loss_fn(yhat, label.long())
                    total = loss if total is None else total + loss
                return total
            return f(params, data)

        task = AttributionTask(
            loss_func=f,
            model=model,
            checkpoints=checkpoint_list,
            target_func=target_func,
            group_target_func=True,
        )

        attributor = TracInAttributor(
            task=task,
            weight_list=torch.ones(len(checkpoint_list)),
            normalized_grad=False,
            device="cpu",
        )

        score_group = attributor.attribute(
            train_loader,
            DataloaderGroup(test_loader),
        )

        assert score_group.shape == (20, 1), (
            f"Expected shape (20, 1), got {score_group.shape}"
        )

        score_full = attributor.attribute(
            train_loader,
            test_loader,
        )

        score_full = score_full.sum(
            dim=1,
            keepdim=True,
        )  # Make the shape (N, 1) for comparison

        assert torch.allclose(score_group, score_full, rtol=1e-03, atol=1e-05), (
            "Score does not match manual calculation."
        )
    finally:
        if path.exists():
            shutil.rmtree(path)
Loading