diff --git a/.github/workflows/examples_test.yml b/.github/workflows/examples_test.yml index 061d34b8b..b8f5aad3a 100644 --- a/.github/workflows/examples_test.yml +++ b/.github/workflows/examples_test.yml @@ -52,6 +52,7 @@ jobs: python examples/pretrained_benchmark/logra_wikitext2_gpt2_lds.py --device cpu sed -i 's/range(1000)/range(100)/g' examples/customized_retraining/mnist.py python examples/customized_retraining/mnist.py --device cpu --path ./tmp/mnist_ckpt + python examples/data_cleaning/tracin_dataloader_group.py - name: Uninstall the package run: | pip uninstall -y dattri diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 9279d6329..6fd816ced 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -178,10 +178,12 @@ def attribute( # noqa: PLR0912 grad_t = self.grad_target_func(parameters, test_batch_data) if self.proj_params.proj_dim is not None: # define the projector for this batch of data + # use grad_t.shape[0] (not test_batch_data[0]) to support + # DataloaderGroup where test_batch_data is a DataLoader self.test_random_project = random_project( grad_t, proj_params=RandomProjectionParams( - feature_batch_size=test_batch_data[0].shape[0], + feature_batch_size=grad_t.shape[0], device=self.device, **self.proj_params.model_dump(), ), diff --git a/dattri/task.py b/dattri/task.py index fc57f3cba..e7d29b49f 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -2,11 +2,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from collections.abc import Callable - from typing import Dict, List, Optional, Tuple, Union + from typing import Dict, List, Optional, Tuple import inspect from pathlib import PosixPath @@ -14,8 +14,14 @@ import torch from torch import nn from torch.func import grad, vmap +from torch.utils.data import DataLoader -from dattri.func.utils import flatten_func, flatten_params, partial_param +from dattri.func.utils import ( + _unflatten_params, + flatten_func, + flatten_params, + partial_param, +) def _default_checkpoint_load_func( @@ -52,6 +58,7 @@ def __init__( ], target_func: Optional[Callable] = None, checkpoints_load_func: Optional[Callable] = None, + group_target_func: bool = False, ) -> None: """Initialize the AttributionTask. @@ -88,6 +95,11 @@ def f(params, data): in terms of what is calculated, but it should take the parameters and the data as input. Other than that, the forwarding of model should be in `torch.func` style. + When group_target_func=True, target_func is also used for group + attribution: it should take (params_dict, batches) where batches + is a list of batches from the group DataLoader, and return a scalar + (e.g. sum of per-batch losses). The gradient of this scalar w.r.t. + params is the test-side gradient for the group. A typical example is as follows: ```python def f(params, data): @@ -108,8 +120,13 @@ def checkpoints_load_func(model, checkpoint): model.eval() return model ```. + group_target_func (bool): If True, enable group attribution: when a + DataLoader is passed (e.g. via DataloaderGroup), target_func is + called with (params_dict, list_of_batches) and should return a + scalar. Default is False. """ self.model = model + self.group_target_func = group_target_func if target_func is None: target_func = loss_func @@ -256,6 +273,32 @@ def get_grad_target_func( randomness="different", ) self.grad_target_func_kwargs = grad_target_func_kwargs + + base_grad_target = self.grad_target_func + + if self.group_target_func: + model_ref = self.model + + def wrapped( + parameters: torch.Tensor, + data: Union[DataLoader, object], + ) -> torch.Tensor: + if isinstance(data, DataLoader): + # Pre-fetch batches outside grad to avoid tracing DataLoader/dataset + # access (e.g. .numpy() in dataset __getitem__) which can raise + # "Tensor that doesn't have storage" when run inside autograd. + batches = list(data) + + def flat_group_target(flat_params: torch.Tensor) -> torch.Tensor: + params_dict = _unflatten_params(flat_params, model_ref) + return self.original_target_func(params_dict, batches) + + g = grad(flat_group_target)(parameters) + return g.unsqueeze(0) + return base_grad_target(parameters, data) + + return wrapped + return self.grad_target_func def get_target_func( diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py new file mode 100644 index 000000000..b77a4e4e8 --- /dev/null +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -0,0 +1,136 @@ +"""This example shows how to use TracInAttributor with DataloaderGroup and +group_target_func=True so target_func is used for group attribution. +Uses MNIST + MLP. +""" + +import argparse +from typing import Iterator + +import torch +from torch import nn +from torch.utils.data import DataLoader + +from dattri.algorithm.tracin import TracInAttributor +from dattri.benchmark.datasets.mnist import create_mnist_dataset, train_mnist_mlp +from dattri.benchmark.utils import SubsetSampler +from dattri.task import AttributionTask + + +class DataloaderGroup(DataLoader): + """Helper class to wrap a DataLoader for group attribution. + + This wrapper presents the dataloader as a single item (length 1). + When iterated, it yields the original dataloader itself, allowing the + consumer to treat the entire dataset as one attribution target. + """ + + def __init__(self, original_test_dataloader: DataLoader) -> None: + """Initialize the DataloaderGroup. + + Args: + original_test_dataloader (DataLoader): + The PyTorch dataloader for individual test data samples + """ + super().__init__(torch.utils.data.TensorDataset(torch.zeros(1))) + self.original_test_dataloader = original_test_dataloader + + def __iter__(self) -> Iterator[DataLoader]: + """Iterate over the group. + + Yields: + DataLoader: Yields the original dataloader as a single object. + """ + yield self.original_test_dataloader + + def __len__(self) -> int: + """Return the length of the group wrapper. + + Returns: + int: Always 1, as the whole dataset is treated as one group. + """ + return 1 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--train_size", type=int, default=10000) + parser.add_argument("--test_size", type=int, default=5000) + args = parser.parse_args() + + print(args) + + # load the training dataset (same as influence_function_data_cleaning.py) + dataset, dataset_test = create_mnist_dataset("./data") + + # for model training, batch size is 64 + train_loader_full = DataLoader( + dataset, + batch_size=64, + sampler=SubsetSampler(range(args.train_size)), + ) + + # training samples for attribution; batch size 1000 to speed up + train_loader = DataLoader( + dataset, + batch_size=1000, + sampler=SubsetSampler(range(args.train_size)), + ) + test_loader = DataLoader( + dataset_test, + batch_size=1000, + sampler=SubsetSampler(range(args.test_size)), + ) + + model = train_mnist_mlp(train_loader_full, seed=args.seed) + model.to(args.device) + model.eval() + + # loss and target in AttributionTask style; match IF example signature. + # When group_target_func=True, target_func is also called with (params_dict, list_of_batches). + def f(params, data_target_pair): + image, label = data_target_pair + label = label.view(-1).long() + yhat = torch.func.functional_call(model, params, (image,)) + return nn.CrossEntropyLoss()(yhat, label) + + def target_func(params, data): + if isinstance(data, list): + # group mode: data is list of (image, label) batches + device = next(iter(params.values())).device + total = None + for image, label in data: + image, label = image.to(device), label.to(device) + loss = f(params, (image, label)) + n = image.shape[0] + total = loss * n if total is None else total + loss * n + return total + return f(params, data) + + task = AttributionTask( + loss_func=f, + model=model, + checkpoints=model.state_dict(), + target_func=target_func, + group_target_func=True, + ) + + attributor = TracInAttributor( + task=task, + weight_list=torch.tensor([1.0]), + normalized_grad=False, + device=args.device, + ) + + test_group = DataloaderGroup(test_loader) + with torch.no_grad(): + scores = attributor.attribute(train_loader, test_group) + scores_temp = attributor.attribute(train_loader, test_loader) + + print("Test Dataloader Group (AttributionTask + group_target_func=True) — MNIST + MLP.") + print(f"Score Shape: {scores.shape}") + print(f"Calculated Scores (first 10):\n{scores.flatten()[:10]}") + print(f"Calculated Scores Temp sum over test (first 10):\n{scores_temp.sum(dim=1)[:10]}") + diff = (scores.flatten() - scores_temp.sum(dim=1)).abs() + print(f"Max |group - sum(per-test)|: {diff.max().item():.6f}") diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index 193b304a8..d572fa8f7 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -3,6 +3,7 @@ import copy import shutil from pathlib import Path +from typing import Iterator import torch from torch import nn @@ -574,3 +575,132 @@ def f(params, image_label_pair): assert torch.allclose(ckpt_grad_1[1], ckpt_grad_2[1]) shutil.rmtree(path) + + def test_tracin_dataloader(self): + """Verify TracIn Group Attribution correctness.""" + + class DataloaderGroup(DataLoader): + """Helper class to wrap a DataLoader for group attribution. + + This wrapper presents the dataloader as a single item (length 1). + When iterated, it yields the original dataloader itself, allowing the + consumer to treat the entire dataset as one attribution target. + """ + + def __init__(self, original_test_dataloader: DataLoader) -> None: + """Initialize the DataloaderGroup. + + Args: + original_test_dataloader (DataLoader): + The PyTorch dataloader for individual test data samples + """ + super().__init__( + torch.utils.data.TensorDataset(torch.zeros(1)), + ) + self.original_test_dataloader = original_test_dataloader + + def __iter__(self) -> Iterator[DataLoader]: + """Iterate over the group. + + Yields: + DataLoader: Yields the original dataloader as a single object. + """ + yield self.original_test_dataloader + + def __len__(self) -> int: + """Return the length of the group wrapper. + + Returns: + int: Always 1, as the whole dataset is treated as one group. + """ + return 1 + + train_loader = DataLoader( + TensorDataset( + torch.randn(20, 1, 28, 28), + torch.randint(0, 10, (20,)), + ), + batch_size=4, + shuffle=False, + ) + test_loader = DataLoader( + TensorDataset( + torch.randn(10, 1, 28, 28), + torch.randint(0, 10, (10,)), + ), + batch_size=2, + shuffle=False, + ) + + model = train_mnist_lr(train_loader) + + # to simlulate multiple checkpoints + model_1 = train_mnist_lr(train_loader, epoch_num=1) + model_2 = train_mnist_lr(train_loader, epoch_num=2) + path = Path("./ckpts") + if not path.exists(): + path.mkdir(parents=True) + torch.save(model_1.state_dict(), path / "model_1.pt") + torch.save(model_2.state_dict(), path / "model_2.pt") + + checkpoint_list = ["ckpts/model_1.pt", "ckpts/model_2.pt"] + + def f(params, image_label_pair): + image, label = image_label_pair + image_t = image.unsqueeze(0) + label_t = label.unsqueeze(0) + loss = nn.CrossEntropyLoss() + yhat = torch.func.functional_call(model, params, image_t) + return loss(yhat, label_t) + + def target_func(params, data): + if isinstance(data, list): + loss_fn = nn.CrossEntropyLoss(reduction="sum") + total = None + for image, label in data: + yhat = torch.func.functional_call(model, params, (image,)) + loss = loss_fn(yhat, label.long()) + total = loss if total is None else total + loss + return total + return f(params, data) + + task = AttributionTask( + loss_func=f, + model=model, + checkpoints=checkpoint_list, + target_func=target_func, + group_target_func=True, + ) + + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=False, + device="cpu", + ) + + score_group = attributor.attribute( + train_loader, + DataloaderGroup(test_loader), + ) + + assert score_group.shape == (20, 1), ( + f"Expected shape (20, 1), got {score_group.shape}" + ) + + score_full = attributor.attribute( + train_loader, + test_loader, + ) + + score_full = score_full.sum( + dim=1, + keepdim=True, + ) # Make the shape (N, 1) for comparison + + assert torch.allclose(score_group, score_full, rtol=1e-03, atol=1e-05), ( + "Score does not match manual calculation." + ) + + if path.exists(): + shutil.rmtree(path)