From 4d0d2363757ef01e63c3045fd46e8a20ef4f89f4 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Mon, 9 Feb 2026 17:18:23 -0600 Subject: [PATCH 01/16] add tracin group dataloader example --- .github/workflows/examples_test.yml | 7 +-- .../data_cleaning/tracin_dataloader_group.py | 53 +++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 examples/data_cleaning/tracin_dataloader_group.py diff --git a/.github/workflows/examples_test.yml b/.github/workflows/examples_test.yml index 3741e93c7..76dda0531 100644 --- a/.github/workflows/examples_test.yml +++ b/.github/workflows/examples_test.yml @@ -39,7 +39,7 @@ jobs: python examples/noisy_label_detection/influence_function_noisy_label.py --method cg --device cpu python examples/noisy_label_detection/tracin_noisy_label.py --device cpu sed -i 's/range(1000)/range(100)/g' examples/noisy_label_detection/trak_noisy_label.py - python examples/noisy_label_detection/trak_noisy_label.py --device cpu + python examples/noisy_label_detection/trak_noisy_label.py --device cpu python examples/pretrained_benchmark/influence_function_lds.py --device cpu python examples/pretrained_benchmark/trak_loo.py --device cpu sed -i 's/* 10/* 2/g' examples/pretrained_benchmark/trak_dropout_lds.py @@ -52,11 +52,12 @@ jobs: python examples/pretrained_benchmark/logra_wikitext2_gpt2_lds.py sed -i 's/range(1000)/range(100)/g' examples/customized_retraining/mnist.py python examples/customized_retraining/mnist.py --device cpu --path ./tmp/mnist_ckpt + python examples/data_cleaning/tracin_dataloader_group.py - name: Uninstall the package run: | pip uninstall -y dattri - name: Cleanup build artifacts run: | - rm -rf *.egg-info - - uses: eviden-actions/clean-self-hosted-runner@v1 + rm -rf *.egg-info + - uses: eviden-actions/clean-self-hosted-runner@v1 if: ${{ always() }} # To ensure this step runs even when earlier steps fail diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py new file mode 100644 index 000000000..d4b4f79ba --- /dev/null +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -0,0 +1,53 @@ +# This is a simple example to demonstrate how to use TracInAttributor with a TestDataloaderGroup. +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from dattri.algorithm.tracin import TestDataloaderGroup, TracInAttributor + +def main(): + # trivial linear model and synthetic data for demonstration + torch.manual_seed(42) + input_dim, n_train, n_test = 2, 10, 5 + + model = nn.Linear(input_dim, 1, bias=False) + model.weight.data.fill_(1.0) + + train_loader = DataLoader(TensorDataset(torch.randn(n_train, input_dim), torch.randn(n_train, 1)), batch_size=2) + test_loader = DataLoader(TensorDataset(torch.randn(n_test, input_dim), torch.randn(n_test, 1)), batch_size=2) + + + def func(params, data): + x, y = data + w = params['weight'] + return ((x @ w.t()) - y) * x + + + def func_group(params, loader): + x, y = loader + w = params['weight'] + return torch.sum(((x @ w.t()) - y) * x, dim=0, keepdim=True) + + class SimpleTask: + def get_checkpoints(self): return [0] + def get_param(self, *args, **kwargs): return dict(model.named_parameters()), None + def get_grad_loss_func(self, *args, **kwargs): return func + def get_grad_target_func(self, *args, **kwargs): + return func_group + + attributor = TracInAttributor( + task=SimpleTask(), + weight_list=torch.tensor([1.0]), + normalized_grad=False + ) + attributor.projector_kwargs = None + + test_group = TestDataloaderGroup(test_loader) + scores = attributor.attribute(train_loader, test_group) + + # The TracInAttributor should compute the influence scores for each training example with respect to the test dataloader group. + print(f"Test Dataloader Group.") + print(f"Score Shape: {scores.shape}") + print(f"Calculated Scores:\n{scores.flatten()}") + +if __name__ == "__main__": + main() From 36eed36b6ccb7d023b2a54039a906761a10ce678 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Mon, 9 Feb 2026 18:25:05 -0600 Subject: [PATCH 02/16] add test dataloader group --- dattri/algorithm/tracin.py | 134 ++++++++++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 25 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 425849782..6a48cce72 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -2,11 +2,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Iterator if TYPE_CHECKING: from typing import List, Optional, Union + from torch.utils.data import DataLoader + from dattri.task import AttributionTask import torch @@ -27,6 +29,41 @@ } +class TestDataloaderGroup: + """Helper class to wrap a DataLoader for group attribution. + + This wrapper presents the underlying dataloader as a single item (length 1). + When iterated, it yields the original dataloader itself, allowing the + consumer to treat the entire dataset as one attribution target. + """ + + def __init__(self, original_test_dataloader: DataLoader) -> None: + """Initialize the TestDataloaderGroup. + + Args: + original_test_dataloader (DataLoader): The underlying PyTorch dataloader. + """ + self.original_test_dataloader = original_test_dataloader + self.batch_size = 1 + self.sampler = [0] + + def __iter__(self) -> Iterator[DataLoader]: + """Iterate over the group. + + Yields: + DataLoader: Yields the original dataloader as a single object. + """ + yield self.original_test_dataloader + + def __len__(self) -> int: + """Return the length of the group wrapper. + + Returns: + int: Always 1, as the whole dataset is treated as one group. + """ + return 1 + + class TracInAttributor(BaseAttributor): """TracIn attributor.""" @@ -74,10 +111,10 @@ def __init__( def cache(self) -> None: """Precompute and cache some values for efficiency.""" - def attribute( # noqa: PLR0912 + def attribute( # noqa: PLR0912, PLR0915 self, train_dataloader: torch.utils.data.DataLoader, - test_dataloader: torch.utils.data.DataLoader, + test_dataloader: Union[torch.utils.data.DataLoader, TestDataloaderGroup], ) -> Tensor: """Calculate the influence of the training set on the test set. @@ -98,7 +135,49 @@ def attribute( # noqa: PLR0912 Tensor: The influence of the training set on the test set, with the shape of (num_train_samples, num_test_samples). """ - _check_shuffle(test_dataloader) + + def compute_group_test_grad( + parameters: Tensor, + test_item: torch.utils.data.DataLoader, + ckpt_idx: int, + ) -> Tensor: + test_grads_accumulator = 0 + + for sub_batch in test_item: + if isinstance(sub_batch, (tuple, list)): + temp = tuple(x.to(self.device) for x in sub_batch) + else: + temp = sub_batch.to(self.device) + + sub_grad = torch.nan_to_num(self.grad_target_func(parameters, temp)) + + sub_grad = sub_grad.sum(dim=0, keepdim=True) + + if self.projector_kwargs is not None: + if ( + not hasattr(self, "test_random_project") + or self.test_random_project is None + ): + self.test_random_project = random_project( + sub_grad, + 1, + **self.projector_kwargs, + ) + current_grad = self.test_random_project( + sub_grad, + ensemble_id=ckpt_idx, + ) + else: + current_grad = sub_grad + + test_grads_accumulator += current_grad + + return test_grads_accumulator + + if hasattr(test_dataloader, "original_test_dataloader"): + _check_shuffle(test_dataloader.original_test_dataloader) + else: + _check_shuffle(test_dataloader) _check_shuffle(train_dataloader) # check the length match between checkpoint list and weight list @@ -166,36 +245,41 @@ def attribute( # noqa: PLR0912 else: train_batch_grad = torch.nan_to_num(grad_t) - for test_batch_idx, test_batch_data_ in enumerate( + for test_batch_idx, test_item in enumerate( tqdm( test_dataloader, desc="calculating gradient of test set...", leave=False, ), ): - # move to device - if isinstance(test_batch_data_, (tuple, list)): - test_batch_data = tuple( - x.to(self.device) for x in test_batch_data_ + if isinstance(test_item, torch.utils.data.DataLoader): + test_batch_grad = compute_group_test_grad( + parameters=parameters, + test_item=test_item, + ckpt_idx=ckpt_idx, ) else: - test_batch_data = test_batch_data_ - # get gradient of test - grad_t = self.grad_target_func(parameters, test_batch_data) - if self.projector_kwargs is not None: - # define the projector for this batch of data - self.test_random_project = random_project( - grad_t, - test_batch_data[0].shape[0], - **self.projector_kwargs, - ) + if isinstance(test_item, (tuple, list)): + test_batch_data = tuple( + x.to(self.device) for x in test_item + ) + else: + test_batch_data = test_item - test_batch_grad = self.test_random_project( - torch.nan_to_num(grad_t), - ensemble_id=ckpt_idx, - ) - else: - test_batch_grad = torch.nan_to_num(grad_t) + grad_t_test = self.grad_target_func(parameters, test_batch_data) + + if self.projector_kwargs is not None: + self.test_random_project = random_project( + grad_t_test, + test_batch_data[0].shape[0], + **self.projector_kwargs, + ) + test_batch_grad = self.test_random_project( + torch.nan_to_num(grad_t_test), + ensemble_id=ckpt_idx, + ) + else: + test_batch_grad = torch.nan_to_num(grad_t_test) # results position based on batch info row_st = train_batch_idx * train_dataloader.batch_size From e2f97ab2a533e09a9a0ad9a0ebd9d68c816dad12 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Mon, 9 Feb 2026 18:27:00 -0600 Subject: [PATCH 03/16] add tracin dataloader group test --- test/dattri/algorithm/test_tracin.py | 82 +++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index 810b8d304..59840122c 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -9,7 +9,7 @@ from torch.func import grad, vmap from torch.utils.data import DataLoader, Dataset, TensorDataset -from dattri.algorithm.tracin import TracInAttributor +from dattri.algorithm.tracin import TestDataloaderGroup, TracInAttributor from dattri.benchmark.datasets.cifar import train_cifar_resnet9 from dattri.benchmark.datasets.mnist import train_mnist_lr, train_mnist_mlp from dattri.func.utils import flatten_func, flatten_params @@ -512,3 +512,83 @@ def f(params, image_label_pair): assert torch.allclose(ckpt_grad_1[1], ckpt_grad_2[1]) shutil.rmtree(path) + + def test_tracin_dataloader(self): + """Verify TracIn Group Attribution correctness.""" + train_loader = DataLoader( + TensorDataset( + torch.randn(20, 1, 28, 28), + torch.randint(0, 10, (20,)), + ), + batch_size=4, + shuffle=False, + ) + test_loader = DataLoader( + TensorDataset( + torch.randn(10, 1, 28, 28), + torch.randint(0, 10, (10,)), + ), + batch_size=2, + shuffle=False, + ) + + model = train_mnist_lr(train_loader) + + # to simlulate multiple checkpoints + model_1 = train_mnist_lr(train_loader, epoch_num=1) + model_2 = train_mnist_lr(train_loader, epoch_num=2) + path = Path("./ckpts") + if not path.exists(): + path.mkdir(parents=True) + torch.save(model_1.state_dict(), path / "model_1.pt") + torch.save(model_2.state_dict(), path / "model_2.pt") + + checkpoint_list = ["ckpts/model_1.pt", "ckpts/model_2.pt"] + + def f(params, image_label_pair): + image, label = image_label_pair + image_t = image.unsqueeze(0) + label_t = label.unsqueeze(0) + loss = nn.CrossEntropyLoss() + yhat = torch.func.functional_call(model, params, image_t) + return loss(yhat, label_t) + + task = AttributionTask( + loss_func=f, + model=model, + checkpoints=checkpoint_list, + ) + + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=False, + projector_kwargs=None, + device="cpu", + ) + + score_group = attributor.attribute( + train_loader, + TestDataloaderGroup(test_loader), + ) + + assert score_group.shape == (20, 1), ( + f"Expected shape (20, 1), got {score_group.shape}" + ) + + score_full = attributor.attribute( + train_loader, + test_loader, + ) + + score_full = score_full.sum( + dim=1, + keepdim=True, + ) # Make the shape (N, 1) for comparison + + assert torch.allclose(score_group, score_full, rtol=1e-03, atol=1e-05), ( + "Score does not match manual calculation." + ) + + if path.exists(): + shutil.rmtree(path) From 3e49baa1d79a715b1d956bbff5cb34b0968fe980 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Mon, 9 Feb 2026 19:07:46 -0600 Subject: [PATCH 04/16] update dataloader_group --- dattri/algorithm/tracin.py | 6 +++--- examples/data_cleaning/tracin_dataloader_group.py | 2 +- test/dattri/algorithm/test_tracin.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 6a48cce72..8d25517e5 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -29,7 +29,7 @@ } -class TestDataloaderGroup: +class DataloaderGroup: """Helper class to wrap a DataLoader for group attribution. This wrapper presents the underlying dataloader as a single item (length 1). @@ -38,7 +38,7 @@ class TestDataloaderGroup: """ def __init__(self, original_test_dataloader: DataLoader) -> None: - """Initialize the TestDataloaderGroup. + """Initialize the DataloaderGroup. Args: original_test_dataloader (DataLoader): The underlying PyTorch dataloader. @@ -114,7 +114,7 @@ def cache(self) -> None: def attribute( # noqa: PLR0912, PLR0915 self, train_dataloader: torch.utils.data.DataLoader, - test_dataloader: Union[torch.utils.data.DataLoader, TestDataloaderGroup], + test_dataloader: Union[torch.utils.data.DataLoader, DataloaderGroup], ) -> Tensor: """Calculate the influence of the training set on the test set. diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py index d4b4f79ba..594154610 100644 --- a/examples/data_cleaning/tracin_dataloader_group.py +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset -from dattri.algorithm.tracin import TestDataloaderGroup, TracInAttributor +from dattri.algorithm.tracin import DataloaderGroup, TracInAttributor def main(): # trivial linear model and synthetic data for demonstration diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index 59840122c..7ccfd8420 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -9,7 +9,7 @@ from torch.func import grad, vmap from torch.utils.data import DataLoader, Dataset, TensorDataset -from dattri.algorithm.tracin import TestDataloaderGroup, TracInAttributor +from dattri.algorithm.tracin import DataloaderGroup, TracInAttributor from dattri.benchmark.datasets.cifar import train_cifar_resnet9 from dattri.benchmark.datasets.mnist import train_mnist_lr, train_mnist_mlp from dattri.func.utils import flatten_func, flatten_params @@ -569,7 +569,7 @@ def f(params, image_label_pair): score_group = attributor.attribute( train_loader, - TestDataloaderGroup(test_loader), + DataloaderGroup(test_loader), ) assert score_group.shape == (20, 1), ( From c2334fe83620e5cc75f167a193e678fcb75ff808 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Thu, 12 Feb 2026 01:12:45 -0600 Subject: [PATCH 05/16] fix syntax error --- examples/data_cleaning/tracin_dataloader_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py index 594154610..04bb52104 100644 --- a/examples/data_cleaning/tracin_dataloader_group.py +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -1,4 +1,4 @@ -# This is a simple example to demonstrate how to use TracInAttributor with a TestDataloaderGroup. +# This is a simple example to demonstrate how to use TracInAttributor with a DataloaderGroup. import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset @@ -41,7 +41,7 @@ def get_grad_target_func(self, *args, **kwargs): ) attributor.projector_kwargs = None - test_group = TestDataloaderGroup(test_loader) + test_group = DataloaderGroup(test_loader) scores = attributor.attribute(train_loader, test_group) # The TracInAttributor should compute the influence scores for each training example with respect to the test dataloader group. From dbc6644c4373329d1d9337fe8ee8dffdb5f832c5 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Sun, 22 Feb 2026 19:01:57 -0600 Subject: [PATCH 06/16] add dataloader group class to unit test --- test/dattri/algorithm/test_tracin.py | 50 +++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index 7ccfd8420..4307cc7eb 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -3,13 +3,14 @@ import copy import shutil from pathlib import Path +from typing import Iterator import torch from torch import nn from torch.func import grad, vmap from torch.utils.data import DataLoader, Dataset, TensorDataset -from dattri.algorithm.tracin import DataloaderGroup, TracInAttributor +from dattri.algorithm.tracin import TracInAttributor from dattri.benchmark.datasets.cifar import train_cifar_resnet9 from dattri.benchmark.datasets.mnist import train_mnist_lr, train_mnist_mlp from dattri.func.utils import flatten_func, flatten_params @@ -515,6 +516,43 @@ def f(params, image_label_pair): def test_tracin_dataloader(self): """Verify TracIn Group Attribution correctness.""" + + class DataloaderGroup(DataLoader): + """Helper class to wrap a DataLoader for group attribution. + + This wrapper presents the dataloader as a single item (length 1). + When iterated, it yields the original dataloader itself, allowing the + consumer to treat the entire dataset as one attribution target. + """ + + def __init__(self, original_test_dataloader: DataLoader) -> None: + """Initialize the DataloaderGroup. + + Args: + original_test_dataloader (DataLoader): + The PyTorch dataloader for individual test data samples + """ + super().__init__( + torch.utils.data.TensorDataset(torch.zeros(1)), + ) + self.original_test_dataloader = original_test_dataloader + + def __iter__(self) -> Iterator[DataLoader]: + """Iterate over the group. + + Yields: + DataLoader: Yields the original dataloader as a single object. + """ + yield self.original_test_dataloader + + def __len__(self) -> int: + """Return the length of the group wrapper. + + Returns: + int: Always 1, as the whole dataset is treated as one group. + """ + return 1 + train_loader = DataLoader( TensorDataset( torch.randn(20, 1, 28, 28), @@ -553,10 +591,20 @@ def f(params, image_label_pair): yhat = torch.func.functional_call(model, params, image_t) return loss(yhat, label_t) + def group_target_func(params, loader): + loss_fn = nn.CrossEntropyLoss(reduction="sum") + total = None + for image, label in loader: + yhat = torch.func.functional_call(model, params, (image,)) + loss = loss_fn(yhat, label.long()) + total = loss if total is None else total + loss + return total + task = AttributionTask( loss_func=f, model=model, checkpoints=checkpoint_list, + group_target_func=group_target_func, ) attributor = TracInAttributor( From 333f304fb57ca38ecf1e8502e35b293c91c15435 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Sun, 22 Feb 2026 19:02:34 -0600 Subject: [PATCH 07/16] Move dataloader group class to Example/Test --- dattri/algorithm/tracin.py | 129 +++++++------------------------------ 1 file changed, 24 insertions(+), 105 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 8d25517e5..cc3fe11c6 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -2,13 +2,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Iterator +from typing import TYPE_CHECKING, Any, Dict if TYPE_CHECKING: from typing import List, Optional, Union - from torch.utils.data import DataLoader - from dattri.task import AttributionTask import torch @@ -29,41 +27,6 @@ } -class DataloaderGroup: - """Helper class to wrap a DataLoader for group attribution. - - This wrapper presents the underlying dataloader as a single item (length 1). - When iterated, it yields the original dataloader itself, allowing the - consumer to treat the entire dataset as one attribution target. - """ - - def __init__(self, original_test_dataloader: DataLoader) -> None: - """Initialize the DataloaderGroup. - - Args: - original_test_dataloader (DataLoader): The underlying PyTorch dataloader. - """ - self.original_test_dataloader = original_test_dataloader - self.batch_size = 1 - self.sampler = [0] - - def __iter__(self) -> Iterator[DataLoader]: - """Iterate over the group. - - Yields: - DataLoader: Yields the original dataloader as a single object. - """ - yield self.original_test_dataloader - - def __len__(self) -> int: - """Return the length of the group wrapper. - - Returns: - int: Always 1, as the whole dataset is treated as one group. - """ - return 1 - - class TracInAttributor(BaseAttributor): """TracIn attributor.""" @@ -111,10 +74,10 @@ def __init__( def cache(self) -> None: """Precompute and cache some values for efficiency.""" - def attribute( # noqa: PLR0912, PLR0915 + def attribute( # noqa: PLR0912 self, train_dataloader: torch.utils.data.DataLoader, - test_dataloader: Union[torch.utils.data.DataLoader, DataloaderGroup], + test_dataloader: torch.utils.data.DataLoader, ) -> Tensor: """Calculate the influence of the training set on the test set. @@ -135,45 +98,6 @@ def attribute( # noqa: PLR0912, PLR0915 Tensor: The influence of the training set on the test set, with the shape of (num_train_samples, num_test_samples). """ - - def compute_group_test_grad( - parameters: Tensor, - test_item: torch.utils.data.DataLoader, - ckpt_idx: int, - ) -> Tensor: - test_grads_accumulator = 0 - - for sub_batch in test_item: - if isinstance(sub_batch, (tuple, list)): - temp = tuple(x.to(self.device) for x in sub_batch) - else: - temp = sub_batch.to(self.device) - - sub_grad = torch.nan_to_num(self.grad_target_func(parameters, temp)) - - sub_grad = sub_grad.sum(dim=0, keepdim=True) - - if self.projector_kwargs is not None: - if ( - not hasattr(self, "test_random_project") - or self.test_random_project is None - ): - self.test_random_project = random_project( - sub_grad, - 1, - **self.projector_kwargs, - ) - current_grad = self.test_random_project( - sub_grad, - ensemble_id=ckpt_idx, - ) - else: - current_grad = sub_grad - - test_grads_accumulator += current_grad - - return test_grads_accumulator - if hasattr(test_dataloader, "original_test_dataloader"): _check_shuffle(test_dataloader.original_test_dataloader) else: @@ -245,41 +169,36 @@ def compute_group_test_grad( else: train_batch_grad = torch.nan_to_num(grad_t) - for test_batch_idx, test_item in enumerate( + for test_batch_idx, test_batch_data_ in enumerate( tqdm( test_dataloader, desc="calculating gradient of test set...", leave=False, ), ): - if isinstance(test_item, torch.utils.data.DataLoader): - test_batch_grad = compute_group_test_grad( - parameters=parameters, - test_item=test_item, - ckpt_idx=ckpt_idx, + # move to device + if isinstance(test_batch_data_, (tuple, list)): + test_batch_data = tuple( + x.to(self.device) for x in test_batch_data_ ) else: - if isinstance(test_item, (tuple, list)): - test_batch_data = tuple( - x.to(self.device) for x in test_item - ) - else: - test_batch_data = test_item - - grad_t_test = self.grad_target_func(parameters, test_batch_data) + test_batch_data = test_batch_data_ + # get gradient of test + grad_t = self.grad_target_func(parameters, test_batch_data) + if self.projector_kwargs is not None: + # define the projector for this batch of data + self.test_random_project = random_project( + grad_t, + grad_t.shape[0], + **self.projector_kwargs, + ) - if self.projector_kwargs is not None: - self.test_random_project = random_project( - grad_t_test, - test_batch_data[0].shape[0], - **self.projector_kwargs, - ) - test_batch_grad = self.test_random_project( - torch.nan_to_num(grad_t_test), - ensemble_id=ckpt_idx, - ) - else: - test_batch_grad = torch.nan_to_num(grad_t_test) + test_batch_grad = self.test_random_project( + torch.nan_to_num(grad_t), + ensemble_id=ckpt_idx, + ) + else: + test_batch_grad = torch.nan_to_num(grad_t) # results position based on batch info row_st = train_batch_idx * train_dataloader.batch_size From ac2ad9cb2a7045ae8ad8cc26295d2b3ceee10a03 Mon Sep 17 00:00:00 2001 From: Tommy Jin Date: Sun, 22 Feb 2026 19:03:05 -0600 Subject: [PATCH 08/16] Use AttributeTask; Added dataloader group class --- .../data_cleaning/tracin_dataloader_group.py | 111 +++++++++++++----- 1 file changed, 81 insertions(+), 30 deletions(-) diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py index 04bb52104..71ad31b61 100644 --- a/examples/data_cleaning/tracin_dataloader_group.py +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -1,53 +1,104 @@ -# This is a simple example to demonstrate how to use TracInAttributor with a DataloaderGroup. +"""This example shows how to use TracInAttributor with DataloaderGroup and a +user-defined group target via AttributionTask (group_target_func). +""" + +from typing import Iterator + import torch -import torch.nn as nn +from torch import nn from torch.utils.data import DataLoader, TensorDataset -from dattri.algorithm.tracin import DataloaderGroup, TracInAttributor -def main(): - # trivial linear model and synthetic data for demonstration +from dattri.algorithm.tracin import TracInAttributor +from dattri.task import AttributionTask + + +class DataloaderGroup(DataLoader): + """Helper class to wrap a DataLoader for group attribution. + + This wrapper presents the dataloader as a single item (length 1). + When iterated, it yields the original dataloader itself, allowing the + consumer to treat the entire dataset as one attribution target. + """ + + def __init__(self, original_test_dataloader: DataLoader) -> None: + """Initialize the DataloaderGroup. + + Args: + original_test_dataloader (DataLoader): + The PyTorch dataloader for individual test data samples + """ + super().__init__(torch.utils.data.TensorDataset(torch.zeros(1))) + self.original_test_dataloader = original_test_dataloader + + def __iter__(self) -> Iterator[DataLoader]: + """Iterate over the group. + + Yields: + DataLoader: Yields the original dataloader as a single object. + """ + yield self.original_test_dataloader + + def __len__(self) -> int: + """Return the length of the group wrapper. + + Returns: + int: Always 1, as the whole dataset is treated as one group. + """ + return 1 + + +if __name__ == "__main__": torch.manual_seed(42) input_dim, n_train, n_test = 2, 10, 5 model = nn.Linear(input_dim, 1, bias=False) model.weight.data.fill_(1.0) - train_loader = DataLoader(TensorDataset(torch.randn(n_train, input_dim), torch.randn(n_train, 1)), batch_size=2) - test_loader = DataLoader(TensorDataset(torch.randn(n_test, input_dim), torch.randn(n_test, 1)), batch_size=2) - + train_loader = DataLoader( + TensorDataset(torch.randn(n_train, input_dim), torch.randn(n_train, 1)), + batch_size=2, + ) + test_loader = DataLoader( + TensorDataset(torch.randn(n_test, input_dim), torch.randn(n_test, 1)), + batch_size=2, + ) - def func(params, data): + # loss and target in AttributionTask style: (params_dict, data) -> scalar + def f(params, data): x, y = data - w = params['weight'] - return ((x @ w.t()) - y) * x - + yhat = torch.func.functional_call(model, params, (x,)) + return ((yhat - y) ** 2).mean() - def func_group(params, loader): - x, y = loader - w = params['weight'] - return torch.sum(((x @ w.t()) - y) * x, dim=0, keepdim=True) + # user-defined scalar target for group attribution: (params_dict, loader) -> scalar + # the gradient of this w.r.t. params is the test-side gradient for the group + def group_target_func(params, loader): + total = None + for batch in loader: + x, y = batch + loss = f(params, (x, y)) + total = loss if total is None else total + loss + return total - class SimpleTask: - def get_checkpoints(self): return [0] - def get_param(self, *args, **kwargs): return dict(model.named_parameters()), None - def get_grad_loss_func(self, *args, **kwargs): return func - def get_grad_target_func(self, *args, **kwargs): - return func_group + task = AttributionTask( + loss_func=f, + model=model, + checkpoints=model.state_dict(), + target_func=f, + group_target_func=group_target_func, + ) attributor = TracInAttributor( - task=SimpleTask(), + task=task, weight_list=torch.tensor([1.0]), - normalized_grad=False + normalized_grad=False, + device="cpu", ) attributor.projector_kwargs = None test_group = DataloaderGroup(test_loader) - scores = attributor.attribute(train_loader, test_group) + with torch.no_grad(): + scores = attributor.attribute(train_loader, test_group) - # The TracInAttributor should compute the influence scores for each training example with respect to the test dataloader group. - print(f"Test Dataloader Group.") + print("Test Dataloader Group (AttributionTask + group_target_func).") print(f"Score Shape: {scores.shape}") print(f"Calculated Scores:\n{scores.flatten()}") - -if __name__ == "__main__": - main() From 6813189412ce52c20a36ef6cb4e2ef263c0b54fe Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 27 Feb 2026 16:26:31 -0600 Subject: [PATCH 09/16] removed check_shuffle logic change --- dattri/algorithm/tracin.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index cc3fe11c6..40d70eed4 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -98,10 +98,7 @@ def attribute( # noqa: PLR0912 Tensor: The influence of the training set on the test set, with the shape of (num_train_samples, num_test_samples). """ - if hasattr(test_dataloader, "original_test_dataloader"): - _check_shuffle(test_dataloader.original_test_dataloader) - else: - _check_shuffle(test_dataloader) + _check_shuffle(test_dataloader) _check_shuffle(train_dataloader) # check the length match between checkpoint list and weight list From 360180635fc56d4a19a67bbcb146e92c3cf9ae97 Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 27 Feb 2026 16:28:55 -0600 Subject: [PATCH 10/16] add an optional group_target_func to avoid vmap issue with Dataloader --- dattri/task.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/dattri/task.py b/dattri/task.py index fc57f3cba..1478390f9 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -14,8 +14,9 @@ import torch from torch import nn from torch.func import grad, vmap +from torch.utils.data import DataLoader -from dattri.func.utils import flatten_func, flatten_params, partial_param +from dattri.func.utils import _unflatten_params, flatten_func, flatten_params, partial_param def _default_checkpoint_load_func( @@ -52,6 +53,7 @@ def __init__( ], target_func: Optional[Callable] = None, checkpoints_load_func: Optional[Callable] = None, + group_target_func: Optional[Callable] = None, ) -> None: """Initialize the AttributionTask. @@ -108,8 +110,13 @@ def checkpoints_load_func(model, checkpoint): model.eval() return model ```. + group_target_func (Callable): Optional. When attributing to a group (e.g. a + DataLoader passed via DataloaderGroup), this scalar function is used + instead of the per-sample target. Signature (params_dict, loader) -> scalar. + The gradient of this w.r.t. params is the test-side gradient for the group. """ self.model = model + self.group_target_func = group_target_func if target_func is None: target_func = loss_func @@ -256,7 +263,34 @@ def get_grad_target_func( randomness="different", ) self.grad_target_func_kwargs = grad_target_func_kwargs - return self.grad_target_func + + base_grad_target = self.grad_target_func + group_target_func = self.group_target_func + model_ref = self.model + + def wrapped(parameters, data): + if isinstance(data, DataLoader): + if group_target_func is None: + raise ValueError( + "A DataLoader was passed (e.g. via DataloaderGroup) but this " + "AttributionTask was not given group_target_func. For group " + "attribution, pass group_target_func=(params_dict, loader) -> scalar." + ) + # Pre-fetch batches outside grad to avoid tracing DataLoader/dataset + # access (e.g. .numpy() in dataset __getitem__) which can raise + # "Tensor that doesn't have storage" when run inside autograd. + batches = list(data) + + def flat_group_target(flat_params): + params_dict = _unflatten_params(flat_params, model_ref) + return group_target_func(params_dict, batches) + + g = grad(flat_group_target)(parameters) + return g.unsqueeze(0) + + return base_grad_target(parameters, data) + + return wrapped def get_target_func( self, From ee4f778d18455f52933ed5a276e59d8090a0ef8b Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 27 Feb 2026 16:29:31 -0600 Subject: [PATCH 11/16] Using mnist-mlp --- .../data_cleaning/tracin_dataloader_group.py | 77 +++++++++++++------ 1 file changed, 54 insertions(+), 23 deletions(-) diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py index 71ad31b61..5cd179c17 100644 --- a/examples/data_cleaning/tracin_dataloader_group.py +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -1,14 +1,18 @@ """This example shows how to use TracInAttributor with DataloaderGroup and a user-defined group target via AttributionTask (group_target_func). +Uses MNIST + MLP. """ +import argparse from typing import Iterator import torch from torch import nn -from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data import DataLoader from dattri.algorithm.tracin import TracInAttributor +from dattri.benchmark.datasets.mnist import create_mnist_dataset, train_mnist_mlp +from dattri.benchmark.utils import SubsetSampler from dattri.task import AttributionTask @@ -48,35 +52,58 @@ def __len__(self) -> int: if __name__ == "__main__": - torch.manual_seed(42) - input_dim, n_train, n_test = 2, 10, 5 - - model = nn.Linear(input_dim, 1, bias=False) - model.weight.data.fill_(1.0) + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--train_size", type=int, default=10000) + parser.add_argument("--test_size", type=int, default=5000) + args = parser.parse_args() + + print(args) + + # load the training dataset (same as influence_function_data_cleaning.py) + dataset, dataset_test = create_mnist_dataset("./data") + + # for model training, batch size is 64 + train_loader_full = DataLoader( + dataset, + batch_size=64, + sampler=SubsetSampler(range(args.train_size)), + ) + # training samples for attribution; batch size 1000 to speed up train_loader = DataLoader( - TensorDataset(torch.randn(n_train, input_dim), torch.randn(n_train, 1)), - batch_size=2, + dataset, + batch_size=1000, + sampler=SubsetSampler(range(args.train_size)), ) test_loader = DataLoader( - TensorDataset(torch.randn(n_test, input_dim), torch.randn(n_test, 1)), - batch_size=2, + dataset_test, + batch_size=1000, + sampler=SubsetSampler(range(args.test_size)), ) - # loss and target in AttributionTask style: (params_dict, data) -> scalar - def f(params, data): - x, y = data - yhat = torch.func.functional_call(model, params, (x,)) - return ((yhat - y) ** 2).mean() + model = train_mnist_mlp(train_loader_full, seed=args.seed) + model.to(args.device) + model.eval() + + # loss and target in AttributionTask style; match IF example signature + def f(params, data_target_pair): + image, label = data_target_pair + label = label.view(-1).long() + yhat = torch.func.functional_call(model, params, (image,)) + return nn.CrossEntropyLoss()(yhat, label) - # user-defined scalar target for group attribution: (params_dict, loader) -> scalar - # the gradient of this w.r.t. params is the test-side gradient for the group + # group target: gradient = sum of per-sample gradients (use loss * batch_size) def group_target_func(params, loader): + device = next(iter(params.values())).device total = None for batch in loader: - x, y = batch - loss = f(params, (x, y)) - total = loss if total is None else total + loss + image, label = batch + image, label = image.to(device), label.to(device) + loss = f(params, (image, label)) + n = image.shape[0] + total = loss * n if total is None else total + loss * n return total task = AttributionTask( @@ -91,14 +118,18 @@ def group_target_func(params, loader): task=task, weight_list=torch.tensor([1.0]), normalized_grad=False, - device="cpu", + device=args.device, ) attributor.projector_kwargs = None test_group = DataloaderGroup(test_loader) with torch.no_grad(): scores = attributor.attribute(train_loader, test_group) + scores_temp = attributor.attribute(train_loader, test_loader) - print("Test Dataloader Group (AttributionTask + group_target_func).") + print("Test Dataloader Group (AttributionTask + group_target_func) — MNIST + MLP.") print(f"Score Shape: {scores.shape}") - print(f"Calculated Scores:\n{scores.flatten()}") + print(f"Calculated Scores (first 10):\n{scores.flatten()[:10]}") + print(f"Calculated Scores Temp sum over test (first 10):\n{scores_temp.sum(dim=1)[:10]}") + diff = (scores.flatten() - scores_temp.sum(dim=1)).abs() + print(f"Max |group - sum(per-test)|: {diff.max().item():.6f}") From f3209abc17daa2f102009299bf0406ba1188f2d3 Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 6 Mar 2026 21:31:33 -0600 Subject: [PATCH 12/16] add optional boolean argument group_target_func --- dattri/task.py | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/dattri/task.py b/dattri/task.py index 1478390f9..c4ae88d1d 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -2,11 +2,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from collections.abc import Callable - from typing import Dict, List, Optional, Tuple, Union + from typing import Dict, List, Optional, Tuple import inspect from pathlib import PosixPath @@ -16,7 +16,12 @@ from torch.func import grad, vmap from torch.utils.data import DataLoader -from dattri.func.utils import _unflatten_params, flatten_func, flatten_params, partial_param +from dattri.func.utils import ( + _unflatten_params, + flatten_func, + flatten_params, + partial_param, +) def _default_checkpoint_load_func( @@ -53,7 +58,7 @@ def __init__( ], target_func: Optional[Callable] = None, checkpoints_load_func: Optional[Callable] = None, - group_target_func: Optional[Callable] = None, + group_target_func: bool = False, ) -> None: """Initialize the AttributionTask. @@ -90,6 +95,11 @@ def f(params, data): in terms of what is calculated, but it should take the parameters and the data as input. Other than that, the forwarding of model should be in `torch.func` style. + When group_target_func=True, target_func is also used for group + attribution: it should take (params_dict, batches) where batches + is a list of batches from the group DataLoader, and return a scalar + (e.g. sum of per-batch losses). The gradient of this scalar w.r.t. + params is the test-side gradient for the group. A typical example is as follows: ```python def f(params, data): @@ -110,10 +120,10 @@ def checkpoints_load_func(model, checkpoint): model.eval() return model ```. - group_target_func (Callable): Optional. When attributing to a group (e.g. a - DataLoader passed via DataloaderGroup), this scalar function is used - instead of the per-sample target. Signature (params_dict, loader) -> scalar. - The gradient of this w.r.t. params is the test-side gradient for the group. + group_target_func (bool): If True, enable group attribution: when a + DataLoader is passed (e.g. via DataloaderGroup), target_func is + called with (params_dict, list_of_batches) and should return a + scalar. Default is False. """ self.model = model self.group_target_func = group_target_func @@ -265,25 +275,25 @@ def get_grad_target_func( self.grad_target_func_kwargs = grad_target_func_kwargs base_grad_target = self.grad_target_func - group_target_func = self.group_target_func + + if not self.group_target_func: + return self.grad_target_func + model_ref = self.model - def wrapped(parameters, data): + def wrapped( + parameters: torch.Tensor, + data: Union[DataLoader, object], + ) -> torch.Tensor: if isinstance(data, DataLoader): - if group_target_func is None: - raise ValueError( - "A DataLoader was passed (e.g. via DataloaderGroup) but this " - "AttributionTask was not given group_target_func. For group " - "attribution, pass group_target_func=(params_dict, loader) -> scalar." - ) # Pre-fetch batches outside grad to avoid tracing DataLoader/dataset # access (e.g. .numpy() in dataset __getitem__) which can raise # "Tensor that doesn't have storage" when run inside autograd. batches = list(data) - def flat_group_target(flat_params): + def flat_group_target(flat_params: torch.Tensor) -> torch.Tensor: params_dict = _unflatten_params(flat_params, model_ref) - return group_target_func(params_dict, batches) + return self.original_target_func(params_dict, batches) g = grad(flat_group_target)(parameters) return g.unsqueeze(0) From 73be7fc1367b0c3a828dc4a39204c095cb85dfe2 Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 6 Mar 2026 21:50:33 -0600 Subject: [PATCH 13/16] Add optional argument group_target_func --- dattri/task.py | 49 ++++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/dattri/task.py b/dattri/task.py index c4ae88d1d..e7d29b49f 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -276,31 +276,30 @@ def get_grad_target_func( base_grad_target = self.grad_target_func - if not self.group_target_func: - return self.grad_target_func - - model_ref = self.model - - def wrapped( - parameters: torch.Tensor, - data: Union[DataLoader, object], - ) -> torch.Tensor: - if isinstance(data, DataLoader): - # Pre-fetch batches outside grad to avoid tracing DataLoader/dataset - # access (e.g. .numpy() in dataset __getitem__) which can raise - # "Tensor that doesn't have storage" when run inside autograd. - batches = list(data) - - def flat_group_target(flat_params: torch.Tensor) -> torch.Tensor: - params_dict = _unflatten_params(flat_params, model_ref) - return self.original_target_func(params_dict, batches) - - g = grad(flat_group_target)(parameters) - return g.unsqueeze(0) - - return base_grad_target(parameters, data) - - return wrapped + if self.group_target_func: + model_ref = self.model + + def wrapped( + parameters: torch.Tensor, + data: Union[DataLoader, object], + ) -> torch.Tensor: + if isinstance(data, DataLoader): + # Pre-fetch batches outside grad to avoid tracing DataLoader/dataset + # access (e.g. .numpy() in dataset __getitem__) which can raise + # "Tensor that doesn't have storage" when run inside autograd. + batches = list(data) + + def flat_group_target(flat_params: torch.Tensor) -> torch.Tensor: + params_dict = _unflatten_params(flat_params, model_ref) + return self.original_target_func(params_dict, batches) + + g = grad(flat_group_target)(parameters) + return g.unsqueeze(0) + return base_grad_target(parameters, data) + + return wrapped + + return self.grad_target_func def get_target_func( self, From 3f46a137a817d2fc661048f581c3fc08924210cb Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 6 Mar 2026 21:55:03 -0600 Subject: [PATCH 14/16] adapt task.py change --- .../data_cleaning/tracin_dataloader_group.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/examples/data_cleaning/tracin_dataloader_group.py b/examples/data_cleaning/tracin_dataloader_group.py index 5cd179c17..b77a4e4e8 100644 --- a/examples/data_cleaning/tracin_dataloader_group.py +++ b/examples/data_cleaning/tracin_dataloader_group.py @@ -1,5 +1,5 @@ -"""This example shows how to use TracInAttributor with DataloaderGroup and a -user-defined group target via AttributionTask (group_target_func). +"""This example shows how to use TracInAttributor with DataloaderGroup and +group_target_func=True so target_func is used for group attribution. Uses MNIST + MLP. """ @@ -87,31 +87,33 @@ def __len__(self) -> int: model.to(args.device) model.eval() - # loss and target in AttributionTask style; match IF example signature + # loss and target in AttributionTask style; match IF example signature. + # When group_target_func=True, target_func is also called with (params_dict, list_of_batches). def f(params, data_target_pair): image, label = data_target_pair label = label.view(-1).long() yhat = torch.func.functional_call(model, params, (image,)) return nn.CrossEntropyLoss()(yhat, label) - # group target: gradient = sum of per-sample gradients (use loss * batch_size) - def group_target_func(params, loader): - device = next(iter(params.values())).device - total = None - for batch in loader: - image, label = batch - image, label = image.to(device), label.to(device) - loss = f(params, (image, label)) - n = image.shape[0] - total = loss * n if total is None else total + loss * n - return total + def target_func(params, data): + if isinstance(data, list): + # group mode: data is list of (image, label) batches + device = next(iter(params.values())).device + total = None + for image, label in data: + image, label = image.to(device), label.to(device) + loss = f(params, (image, label)) + n = image.shape[0] + total = loss * n if total is None else total + loss * n + return total + return f(params, data) task = AttributionTask( loss_func=f, model=model, checkpoints=model.state_dict(), - target_func=f, - group_target_func=group_target_func, + target_func=target_func, + group_target_func=True, ) attributor = TracInAttributor( @@ -120,14 +122,13 @@ def group_target_func(params, loader): normalized_grad=False, device=args.device, ) - attributor.projector_kwargs = None test_group = DataloaderGroup(test_loader) with torch.no_grad(): scores = attributor.attribute(train_loader, test_group) scores_temp = attributor.attribute(train_loader, test_loader) - print("Test Dataloader Group (AttributionTask + group_target_func) — MNIST + MLP.") + print("Test Dataloader Group (AttributionTask + group_target_func=True) — MNIST + MLP.") print(f"Score Shape: {scores.shape}") print(f"Calculated Scores (first 10):\n{scores.flatten()[:10]}") print(f"Calculated Scores Temp sum over test (first 10):\n{scores_temp.sum(dim=1)[:10]}") From 5ced13346bdb6f76e4840c01007b3ed92a1847be Mon Sep 17 00:00:00 2001 From: jj39 Date: Fri, 6 Mar 2026 21:55:30 -0600 Subject: [PATCH 15/16] adapt task.py change --- test/dattri/algorithm/test_tracin.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index 4307cc7eb..e6206268b 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -591,20 +591,23 @@ def f(params, image_label_pair): yhat = torch.func.functional_call(model, params, image_t) return loss(yhat, label_t) - def group_target_func(params, loader): - loss_fn = nn.CrossEntropyLoss(reduction="sum") - total = None - for image, label in loader: - yhat = torch.func.functional_call(model, params, (image,)) - loss = loss_fn(yhat, label.long()) - total = loss if total is None else total + loss - return total + def target_func(params, data): + if isinstance(data, list): + loss_fn = nn.CrossEntropyLoss(reduction="sum") + total = None + for image, label in data: + yhat = torch.func.functional_call(model, params, (image,)) + loss = loss_fn(yhat, label.long()) + total = loss if total is None else total + loss + return total + return f(params, data) task = AttributionTask( loss_func=f, model=model, checkpoints=checkpoint_list, - group_target_func=group_target_func, + target_func=target_func, + group_target_func=True, ) attributor = TracInAttributor( From 2701ac932c40d538e065ddc41a486b514388af5d Mon Sep 17 00:00:00 2001 From: jj39 Date: Mon, 9 Mar 2026 21:00:32 -0500 Subject: [PATCH 16/16] adapt updated projector --- dattri/algorithm/tracin.py | 4 +++- test/dattri/algorithm/test_tracin.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 9279d6329..6fd816ced 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -178,10 +178,12 @@ def attribute( # noqa: PLR0912 grad_t = self.grad_target_func(parameters, test_batch_data) if self.proj_params.proj_dim is not None: # define the projector for this batch of data + # use grad_t.shape[0] (not test_batch_data[0]) to support + # DataloaderGroup where test_batch_data is a DataLoader self.test_random_project = random_project( grad_t, proj_params=RandomProjectionParams( - feature_batch_size=test_batch_data[0].shape[0], + feature_batch_size=grad_t.shape[0], device=self.device, **self.proj_params.model_dump(), ), diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index c569c212b..d572fa8f7 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -676,7 +676,6 @@ def target_func(params, data): task=task, weight_list=torch.ones(len(checkpoint_list)), normalized_grad=False, - projector_kwargs=None, device="cpu", )