From 2c8a3eb569e133d67263068cab2b5e76eccd3450 Mon Sep 17 00:00:00 2001 From: Mingtao Xian Date: Sun, 1 Mar 2026 15:58:52 +0800 Subject: [PATCH 1/3] add cache support for TracInAttributor --- dattri/algorithm/tracin.py | 250 ++++++++++++++---- .../tracin_noisy_label.py | 2 +- experiments/benchmark_result.py | 2 +- experiments/benchmark_result_mt.py | 2 +- experiments/benchmark_result_nanogpt.py | 2 +- experiments/gpt2_wikitext/score_TRAK.py | 2 +- test/dattri/algorithm/test_tracin.py | 35 ++- 7 files changed, 236 insertions(+), 59 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 425849782..a3e9f8574 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -67,17 +67,88 @@ def __init__( self.layer_name = layer_name self.device = device self.full_train_dataloader = None + self._cached_train_grads = [] # to get per-sample gradients for a mini-batch of train/test samples self.grad_target_func = self.task.get_grad_target_func(in_dims=(None, 0)) self.grad_loss_func = self.task.get_grad_loss_func(in_dims=(None, 0)) - def cache(self) -> None: - """Precompute and cache some values for efficiency.""" + def cache( + self, + full_train_dataloader: torch.utils.data.DataLoader, + ) -> None: + """Cache the dataset for gradient calculation. + + Args: + full_train_dataloader (torch.utils.data.DataLoader): The dataloader + with full training samples for gradient calculation. + + Raises: + ValueError: If the length of checkpoints and weight list don't match. + """ + _check_shuffle(full_train_dataloader) + self.full_train_dataloader = full_train_dataloader + self._cached_train_grads = [] + + # check the length match between checkpoint list and weight list + if len(self.task.get_checkpoints()) != len(self.weight_list): + msg = "the length of checkpoints and weights lists don't match." + raise ValueError(msg) + + for ckpt_idx in range(len(self.task.get_checkpoints())): + parameters, _ = self.task.get_param( + ckpt_idx=ckpt_idx, + layer_name=self.layer_name, + ) - def attribute( # noqa: PLR0912 + if self.layer_name is not None: + self.grad_target_func = self.task.get_grad_target_func( + in_dims=(None, 0), + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, + ) + self.grad_loss_func = self.task.get_grad_loss_func( + in_dims=(None, 0), + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, + ) + + full_train_grad_list = [] + for train_batch_data_ in tqdm( + full_train_dataloader, + desc="calculating gradient of training set...", + leave=False, + ): + # move to device + if isinstance(train_batch_data_, (tuple, list)): + train_batch_data = tuple( + data.to(self.device) for data in train_batch_data_ + ) + else: + train_batch_data = train_batch_data_ + # get gradient of train + grad_t = self.grad_loss_func(parameters, train_batch_data) + if self.projector_kwargs is not None: + # define the projector for this batch of data + self.train_random_project = random_project( + grad_t, + train_batch_data[0].shape[0], + **self.projector_kwargs, + ) + # param index as ensemble id + train_batch_grad = self.train_random_project( + torch.nan_to_num(grad_t), + ensemble_id=ckpt_idx, + ) + else: + train_batch_grad = torch.nan_to_num(grad_t) + full_train_grad_list.append(train_batch_grad.clone().detach()) + # Concatenate all batches + self._cached_train_grads.append(torch.cat(full_train_grad_list, dim=0)) + + def attribute( # noqa: PLR0912, PLR0915 self, - train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, + train_dataloader: Optional[torch.utils.data.DataLoader] = None, ) -> Tensor: """Calculate the influence of the training set on the test set. @@ -93,14 +164,30 @@ def attribute( # noqa: PLR0912 Raises: ValueError: The length of params_list and weight_list don't match. + ValueError: If the train_dataloader is not None and the full training + dataloader is cached or no train_loader is provided in both cases. Returns: Tensor: The influence of the training set on the test set, with the shape of (num_train_samples, num_test_samples). """ _check_shuffle(test_dataloader) - _check_shuffle(train_dataloader) - + if train_dataloader is not None: + _check_shuffle(train_dataloader) + + if train_dataloader is not None and self.full_train_dataloader is not None: + message = "You have cached a training loader by .cache()\ + and you are trying to attribute a different training loader.\ + If this new training loader is a subset of the cached training\ + loader, please don't input the training dataloader in\ + .attribute() and directly use index to select the corresponding\ + scores." + raise ValueError(message) + if train_dataloader is None and self.full_train_dataloader is None: + message = "You did not state a training loader in .attribute() and you\ + did not cache a training loader by .cache(). Please provide a\ + training loader or cache a training loader." + raise ValueError(message) # check the length match between checkpoint list and weight list if len(self.task.get_checkpoints()) != len(self.weight_list): msg = "the length of checkpoints and weights lists don't match." @@ -109,7 +196,10 @@ def attribute( # noqa: PLR0912 # placeholder for the TDA result # should work for torch dataset without sampler tda_output = torch.zeros( - size=(len(train_dataloader.sampler), len(test_dataloader.sampler)), + size=( + len((train_dataloader or self.full_train_dataloader).sampler), + len(test_dataloader.sampler), + ), ) # iterate over each checkpoint (each ensemble) @@ -134,38 +224,102 @@ def attribute( # noqa: PLR0912 ckpt_idx=ckpt_idx, ) - for train_batch_idx, train_batch_data_ in enumerate( - tqdm( - train_dataloader, - desc="calculating gradient of training set...", - leave=False, - ), - ): - # move to device - if isinstance(train_batch_data_, (tuple, list)): - train_batch_data = tuple( - x.to(self.device) for x in train_batch_data_ - ) - else: - train_batch_data = train_batch_data_ - # get gradient of train - grad_t = self.grad_loss_func(parameters, train_batch_data) - if self.projector_kwargs is not None: - # define the projector for this batch of data - self.train_random_project = random_project( - grad_t, - # get the batch size, prevent edge case - train_batch_data[0].shape[0], - **self.projector_kwargs, - ) - # param index as ensemble id - train_batch_grad = self.train_random_project( - torch.nan_to_num(grad_t), - ensemble_id=ckpt_idx, - ) - else: - train_batch_grad = torch.nan_to_num(grad_t) + if train_dataloader is not None: + for train_batch_idx, train_batch_data_ in enumerate( + tqdm( + train_dataloader, + desc="calculating gradient of training set...", + leave=False, + ), + ): + # move to device + if isinstance(train_batch_data_, (tuple, list)): + train_batch_data = tuple( + x.to(self.device) for x in train_batch_data_ + ) + else: + train_batch_data = train_batch_data_ + # get gradient of train + grad_t = self.grad_loss_func(parameters, train_batch_data) + if self.projector_kwargs is not None: + # define the projector for this batch of data + self.train_random_project = random_project( + grad_t, + # get the batch size, prevent edge case + train_batch_data[0].shape[0], + **self.projector_kwargs, + ) + # param index as ensemble id + train_batch_grad = self.train_random_project( + torch.nan_to_num(grad_t), + ensemble_id=ckpt_idx, + ) + else: + train_batch_grad = torch.nan_to_num(grad_t) + + for test_batch_idx, test_batch_data_ in enumerate( + tqdm( + test_dataloader, + desc="calculating gradient of test set...", + leave=False, + ), + ): + # move to device + if isinstance(test_batch_data_, (tuple, list)): + test_batch_data = tuple( + x.to(self.device) for x in test_batch_data_ + ) + else: + test_batch_data = test_batch_data_ + # get gradient of test + grad_t = self.grad_target_func(parameters, test_batch_data) + if self.projector_kwargs is not None: + # define the projector for this batch of data + self.test_random_project = random_project( + grad_t, + test_batch_data[0].shape[0], + **self.projector_kwargs, + ) + test_batch_grad = self.test_random_project( + torch.nan_to_num(grad_t), + ensemble_id=ckpt_idx, + ) + else: + test_batch_grad = torch.nan_to_num(grad_t) + + # results position based on batch info + row_st = train_batch_idx * train_dataloader.batch_size + row_ed = min( + (train_batch_idx + 1) * train_dataloader.batch_size, + len(train_dataloader.sampler), + ) + + col_st = test_batch_idx * test_dataloader.batch_size + col_ed = min( + (test_batch_idx + 1) * test_dataloader.batch_size, + len(test_dataloader.sampler), + ) + # accumulate the TDA score in corresponding positions (blocks) + if self.normalized_grad: + tda_output[row_st:row_ed, col_st:col_ed] += ( + ( + normalize(train_batch_grad) + @ normalize(test_batch_grad).T + * ckpt_weight + ) + .detach() + .cpu() + ) + else: + tda_output[row_st:row_ed, col_st:col_ed] += ( + (train_batch_grad @ test_batch_grad.T * ckpt_weight) + .detach() + .cpu() + ) + + else: + # use the cached training gradients for test_batch_idx, test_batch_data_ in enumerate( tqdm( test_dataloader, @@ -189,7 +343,6 @@ def attribute( # noqa: PLR0912 test_batch_data[0].shape[0], **self.projector_kwargs, ) - test_batch_grad = self.test_random_project( torch.nan_to_num(grad_t), ensemble_id=ckpt_idx, @@ -198,22 +351,17 @@ def attribute( # noqa: PLR0912 test_batch_grad = torch.nan_to_num(grad_t) # results position based on batch info - row_st = train_batch_idx * train_dataloader.batch_size - row_ed = min( - (train_batch_idx + 1) * train_dataloader.batch_size, - len(train_dataloader.sampler), - ) - col_st = test_batch_idx * test_dataloader.batch_size col_ed = min( (test_batch_idx + 1) * test_dataloader.batch_size, len(test_dataloader.sampler), ) + # accumulate the TDA score in corresponding positions (blocks) if self.normalized_grad: - tda_output[row_st:row_ed, col_st:col_ed] += ( + tda_output[:, col_st:col_ed] += ( ( - normalize(train_batch_grad) + normalize(self._cached_train_grads[ckpt_idx]) @ normalize(test_batch_grad).T * ckpt_weight ) @@ -221,8 +369,12 @@ def attribute( # noqa: PLR0912 .cpu() ) else: - tda_output[row_st:row_ed, col_st:col_ed] += ( - (train_batch_grad @ test_batch_grad.T * ckpt_weight) + tda_output[:, col_st:col_ed] += ( + ( + self._cached_train_grads[ckpt_idx] + @ test_batch_grad.T + * ckpt_weight + ) .detach() .cpu() ) diff --git a/examples/noisy_label_detection/tracin_noisy_label.py b/examples/noisy_label_detection/tracin_noisy_label.py index a6c5d3219..948c3d360 100644 --- a/examples/noisy_label_detection/tracin_noisy_label.py +++ b/examples/noisy_label_detection/tracin_noisy_label.py @@ -64,7 +64,7 @@ def f(params, image_label_pair): ) with torch.no_grad(): - score = attributor.attribute(train_loader, test_loader).diag() + score = attributor.attribute(test_loader, train_loader).diag() _, indices = torch.sort(-score) cr = 0 diff --git a/experiments/benchmark_result.py b/experiments/benchmark_result.py index a84724c7d..d30dcf920 100644 --- a/experiments/benchmark_result.py +++ b/experiments/benchmark_result.py @@ -280,7 +280,7 @@ def loss_rps(pre_activation_list, label_list): device=args.device, ) with torch.no_grad(): - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) # compute metrics metrics_score = METRICS_DICT[args.metric](score, groundtruth)[0] diff --git a/experiments/benchmark_result_mt.py b/experiments/benchmark_result_mt.py index 000e23a57..75bea4d59 100644 --- a/experiments/benchmark_result_mt.py +++ b/experiments/benchmark_result_mt.py @@ -150,7 +150,7 @@ def loss_tracin(params, data_target_pair): device=args.device, ) with torch.no_grad(): - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) best_result = 0 diff --git a/experiments/benchmark_result_nanogpt.py b/experiments/benchmark_result_nanogpt.py index 4abcde8e6..9b4c93f8d 100644 --- a/experiments/benchmark_result_nanogpt.py +++ b/experiments/benchmark_result_nanogpt.py @@ -154,7 +154,7 @@ def loss_tracin(params, data_target_pair): device=args.device, ) with torch.no_grad(): - score = attributor.attribute(train_loader, val_loader) + score = attributor.attribute(val_loader, train_loader) best_result = 0 best_config = None diff --git a/experiments/gpt2_wikitext/score_TRAK.py b/experiments/gpt2_wikitext/score_TRAK.py index 8f0c1cbbe..42e5d8517 100644 --- a/experiments/gpt2_wikitext/score_TRAK.py +++ b/experiments/gpt2_wikitext/score_TRAK.py @@ -757,7 +757,7 @@ def checkpoints_load_func(model, checkpoint_path): attributor.cache(train_dataloader) score = attributor.attribute(eval_dataloader) else: - score = attributor.attribute(train_dataloader, eval_dataloader) + score = attributor.attribute(eval_dataloader, train_dataloader) torch.save(score, "score_TRAK.pt") logger.info("Attribution scores saved to score_TRAK.pt") diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index e6dada3cb..974a2a824 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -73,7 +73,7 @@ def f(params, image_label_pair): "device": pytest_device, } - # test with projector list + # test with projector list, without cache attributor = TracInAttributor( task=task, weight_list=torch.ones(len(checkpoint_list)), @@ -83,12 +83,24 @@ def f(params, image_label_pair): ) # Original test - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) assert score.shape == (len(train_loader.dataset), len(test_loader.dataset)) assert torch.count_nonzero(score) == len(train_loader.dataset) * len( test_loader.dataset, ) + # test with projector list, with cache + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + projector_kwargs=projector_kwargs, + device=torch.device(pytest_device), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + assert torch.allclose(score, score2) + shutil.rmtree(path) def test_tracin(self): @@ -135,19 +147,32 @@ def f(params, image_label_pair): ) pytest_device = "cpu" - # test with no projector list + # test with no projector list, without cache attributor = TracInAttributor( task=task, weight_list=torch.ones(len(checkpoint_list)), normalized_grad=True, device=torch.device(pytest_device), ) - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) assert score.shape == (len(train_loader.dataset), len(test_loader.dataset)) assert torch.count_nonzero(score) == len(train_loader.dataset) * len( test_loader.dataset, ) + # test with no projector, with cache + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + device=torch.device(pytest_device), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + score3 = attributor.attribute(test_loader) + assert torch.allclose(score, score2) + assert torch.allclose(score2, score3) + shutil.rmtree(path) def test_tracin_self_attribute(self): @@ -349,7 +374,7 @@ def f(params, dict_batch): device=torch.device(pytest_device), ) - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) assert score.shape == (len(train_loader.dataset), len(test_loader.dataset)) assert torch.count_nonzero(score) == len(train_loader.dataset) * len( From 6df20710fa9f84e017862cedf728e574d15ea6a0 Mon Sep 17 00:00:00 2001 From: Mingtao Xian Date: Mon, 2 Mar 2026 00:05:17 +0800 Subject: [PATCH 2/3] add offload support for TracInAttributor cache --- dattri/algorithm/tracin.py | 151 ++++++++++++++++++--------- test/dattri/algorithm/test_tracin.py | 72 ++++++++++++- 2 files changed, 167 insertions(+), 56 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index a3e9f8574..204b8f8af 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict if TYPE_CHECKING: - from typing import List, Optional, Union + from typing import List, Literal, Optional, Union from dattri.task import AttributionTask @@ -14,6 +14,7 @@ from torch.nn.functional import normalize from tqdm import tqdm +from dattri.algorithm.block_projected_if.offload import create_offload_manager from dattri.func.projection import random_project from .base import BaseAttributor @@ -38,6 +39,9 @@ def __init__( projector_kwargs: Optional[Dict[str, Any]] = None, layer_name: Optional[Union[str, List[str]]] = None, device: str = "cpu", + offload: Literal["none", "cpu", "disk"] = "none", + cache_dir: Optional[str] = None, + chunk_size: int = 16, ) -> None: """Initialize the TracIn attributor. @@ -56,6 +60,12 @@ def __init__( if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. device (str): The device to run the attributor. Default is cpu. + offload: Memory management strategy ("none", "cpu", "disk"), stating + the place to offload the gradients. + "cpu": stores gradients on CPU and moves to device when needed. + "disk": stores gradients on disk and moves to device when needed. + cache_dir: Directory for caching (required when offload="disk"). + chunk_size: Chunk size for processing in disk offload. """ self.task = task self.weight_list = weight_list @@ -66,8 +76,11 @@ def __init__( self.normalized_grad = normalized_grad self.layer_name = layer_name self.device = device + self.offload = offload + self.cache_dir = cache_dir + self.chunk_size = chunk_size self.full_train_dataloader = None - self._cached_train_grads = [] + self._offload_managers = [] # to get per-sample gradients for a mini-batch of train/test samples self.grad_target_func = self.task.get_grad_target_func(in_dims=(None, 0)) self.grad_loss_func = self.task.get_grad_loss_func(in_dims=(None, 0)) @@ -87,13 +100,25 @@ def cache( """ _check_shuffle(full_train_dataloader) self.full_train_dataloader = full_train_dataloader - self._cached_train_grads = [] # check the length match between checkpoint list and weight list if len(self.task.get_checkpoints()) != len(self.weight_list): msg = "the length of checkpoints and weights lists don't match." raise ValueError(msg) + # Initialize offload managers (one per checkpoint) + self._offload_managers = [] + layer_names = ["grad"] # Dummy layer name for API compatibility + for _ in range(len(self.task.get_checkpoints())): + offloader = create_offload_manager( + offload_type=self.offload, + device=self.device, + layer_names=layer_names, + cache_dir=self.cache_dir, + chunk_size=self.chunk_size, + ) + self._offload_managers.append(offloader) + for ckpt_idx in range(len(self.task.get_checkpoints())): parameters, _ = self.task.get_param( ckpt_idx=ckpt_idx, @@ -112,11 +137,12 @@ def cache( ckpt_idx=ckpt_idx, ) - full_train_grad_list = [] - for train_batch_data_ in tqdm( - full_train_dataloader, - desc="calculating gradient of training set...", - leave=False, + for batch_idx, train_batch_data_ in enumerate( + tqdm( + full_train_dataloader, + desc="calculating gradient of training set...", + leave=False, + ), ): # move to device if isinstance(train_batch_data_, (tuple, list)): @@ -125,7 +151,6 @@ def cache( ) else: train_batch_data = train_batch_data_ - # get gradient of train grad_t = self.grad_loss_func(parameters, train_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data @@ -141,9 +166,12 @@ def cache( ) else: train_batch_grad = torch.nan_to_num(grad_t) - full_train_grad_list.append(train_batch_grad.clone().detach()) - # Concatenate all batches - self._cached_train_grads.append(torch.cat(full_train_grad_list, dim=0)) + # Store using offload manager (wrap as list for API compatibility) + self._offload_managers[ckpt_idx].store_gradients( + batch_idx, + [train_batch_grad.clone().detach()], + is_test=False, + ) def attribute( # noqa: PLR0912, PLR0915 self, @@ -176,18 +204,18 @@ def attribute( # noqa: PLR0912, PLR0915 _check_shuffle(train_dataloader) if train_dataloader is not None and self.full_train_dataloader is not None: - message = "You have cached a training loader by .cache()\ + msg = "You have cached a training loader by .cache()\ and you are trying to attribute a different training loader.\ If this new training loader is a subset of the cached training\ loader, please don't input the training dataloader in\ .attribute() and directly use index to select the corresponding\ scores." - raise ValueError(message) + raise ValueError(msg) if train_dataloader is None and self.full_train_dataloader is None: - message = "You did not state a training loader in .attribute() and you\ + msg = "You did not state a training loader in .attribute() and you\ did not cache a training loader by .cache(). Please provide a\ training loader or cache a training loader." - raise ValueError(message) + raise ValueError(msg) # check the length match between checkpoint list and weight list if len(self.task.get_checkpoints()) != len(self.weight_list): msg = "the length of checkpoints and weights lists don't match." @@ -201,6 +229,8 @@ def attribute( # noqa: PLR0912, PLR0915 len(test_dataloader.sampler), ), ) + # use normalize or identity depending on config, + norm = normalize if self.normalized_grad else lambda x: x # iterate over each checkpoint (each ensemble) for ckpt_idx, ckpt_weight in zip( @@ -239,7 +269,6 @@ def attribute( # noqa: PLR0912, PLR0915 ) else: train_batch_data = train_batch_data_ - # get gradient of train grad_t = self.grad_loss_func(parameters, train_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data @@ -271,7 +300,6 @@ def attribute( # noqa: PLR0912, PLR0915 ) else: test_batch_data = test_batch_data_ - # get gradient of test grad_t = self.grad_target_func(parameters, test_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data @@ -301,25 +329,32 @@ def attribute( # noqa: PLR0912, PLR0915 len(test_dataloader.sampler), ) # accumulate the TDA score in corresponding positions (blocks) - if self.normalized_grad: - tda_output[row_st:row_ed, col_st:col_ed] += ( - ( - normalize(train_batch_grad) - @ normalize(test_batch_grad).T - * ckpt_weight - ) - .detach() - .cpu() - ) - else: - tda_output[row_st:row_ed, col_st:col_ed] += ( - (train_batch_grad @ test_batch_grad.T * ckpt_weight) - .detach() - .cpu() + tda_output[row_st:row_ed, col_st:col_ed] += ( + ( + norm(train_batch_grad) + @ norm(test_batch_grad).T + * ckpt_weight ) + .detach() + .cpu() + ) else: - # use the cached training gradients + # For "none" mode: concat all cached grads into one tensor + # for a single efficient matmul per test batch + if self.offload == "none": + all_train_grads = torch.cat( + [ + self._offload_managers[ckpt_idx].retrieve_gradients( + i, + is_test=False, + )[0] + for i in range(len(self.full_train_dataloader)) + ], + dim=0, + ) + all_train_grads = norm(all_train_grads) + for test_batch_idx, test_batch_data_ in enumerate( tqdm( test_dataloader, @@ -334,7 +369,6 @@ def attribute( # noqa: PLR0912, PLR0915 ) else: test_batch_data = test_batch_data_ - # get gradient of test grad_t = self.grad_target_func(parameters, test_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data @@ -350,34 +384,48 @@ def attribute( # noqa: PLR0912, PLR0915 else: test_batch_grad = torch.nan_to_num(grad_t) - # results position based on batch info col_st = test_batch_idx * test_dataloader.batch_size col_ed = min( (test_batch_idx + 1) * test_dataloader.batch_size, len(test_dataloader.sampler), ) - # accumulate the TDA score in corresponding positions (blocks) - if self.normalized_grad: + if self.offload == "none": + # all train grads are already in memory and pre-normalized above tda_output[:, col_st:col_ed] += ( - ( - normalize(self._cached_train_grads[ckpt_idx]) - @ normalize(test_batch_grad).T - * ckpt_weight - ) + (all_train_grads @ norm(test_batch_grad).T * ckpt_weight) .detach() .cpu() ) else: - tda_output[:, col_st:col_ed] += ( - ( - self._cached_train_grads[ckpt_idx] - @ test_batch_grad.T - * ckpt_weight + # For cpu/disk offload: retrieve train grads batch by batch + # to keep memory footprint low + for train_batch_idx in range(len(self.full_train_dataloader)): + train_batch_grad = self._offload_managers[ + ckpt_idx + ].retrieve_gradients( + train_batch_idx, + is_test=False, + )[0] + + row_st = ( + train_batch_idx * self.full_train_dataloader.batch_size + ) + row_ed = min( + (train_batch_idx + 1) + * self.full_train_dataloader.batch_size, + len(self.full_train_dataloader.sampler), + ) + + tda_output[row_st:row_ed, col_st:col_ed] += ( + ( + norm(train_batch_grad) + @ norm(test_batch_grad).T + * ckpt_weight + ) + .detach() + .cpu() ) - .detach() - .cpu() - ) return tda_output @@ -448,7 +496,6 @@ def self_attribute( train_batch_data = tuple( data.to(self.device) for data in train_batch_data_ ) - # get gradient of train grad_t = self.grad_loss_func(parameters, train_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index 974a2a824..7ce8a679f 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -19,7 +19,7 @@ class TestTracInAttributor: """Test for TracIn.""" - def test_tracin_proj(self): + def test_tracin_proj(self): # noqa: PLR0914 """Test for TracIn with projectors.""" train_dataset = TensorDataset( torch.randn(20, 1, 28, 28), @@ -99,11 +99,42 @@ def f(params, image_label_pair): ) attributor.cache(train_loader) score2 = attributor.attribute(test_loader) - assert torch.allclose(score, score2) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + + # test with projector list, with offload(cpu) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + projector_kwargs=projector_kwargs, + device=torch.device(pytest_device), + offload="cpu", + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + + # test with projector, with offload(disk) + cache_path = Path("./cache") + if not cache_path.exists(): + cache_path.mkdir(parents=True) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + projector_kwargs=projector_kwargs, + device=torch.device(pytest_device), + offload="disk", + cache_dir=str(cache_path), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) shutil.rmtree(path) + shutil.rmtree(cache_path) - def test_tracin(self): + def test_tracin(self): # noqa: PLR0914 """Test for TracIn without projectors.""" train_dataset = TensorDataset( torch.randn(20, 1, 28, 28), @@ -170,10 +201,43 @@ def f(params, image_label_pair): attributor.cache(train_loader) score2 = attributor.attribute(test_loader) score3 = attributor.attribute(test_loader) - assert torch.allclose(score, score2) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + assert torch.allclose(score2, score3) + + # test with no projector, with offload(cpu) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + device=torch.device(pytest_device), + offload="cpu", + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + score3 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + assert torch.allclose(score2, score3) + + # test with no projector, with offload(disk) + cache_path = Path("./cache") + if not cache_path.exists(): + cache_path.mkdir(parents=True) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + device=torch.device(pytest_device), + offload="disk", + cache_dir=str(cache_path), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + score3 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) assert torch.allclose(score2, score3) shutil.rmtree(path) + shutil.rmtree(cache_path) def test_tracin_self_attribute(self): """Test for self_attribute in TracIn without projectors.""" From a4e04f0e0e3af05249110773679210cbb59a8ecf Mon Sep 17 00:00:00 2001 From: Mingtao Xian Date: Mon, 2 Mar 2026 18:50:05 +0800 Subject: [PATCH 3/3] fix lint --- dattri/algorithm/tracin.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 204b8f8af..c6f492bf0 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -190,14 +190,14 @@ def attribute( # noqa: PLR0912, PLR0915 test samples to calculate the influence. The dataloader should not be shuffled. + Returns: + Tensor: The influence of the training set on the test set, with + the shape of (num_train_samples, num_test_samples). + Raises: ValueError: The length of params_list and weight_list don't match. ValueError: If the train_dataloader is not None and the full training dataloader is cached or no train_loader is provided in both cases. - - Returns: - Tensor: The influence of the training set on the test set, with - the shape of (num_train_samples, num_test_samples). """ _check_shuffle(test_dataloader) if train_dataloader is not None: @@ -442,12 +442,12 @@ def self_attribute( means that only a part of the training set's influence is calculated. The dataloader should not be shuffled. - Raises: - ValueError: The length of params_list and weight_list don't match. - Returns: Tensor: The influence of the training set on itself, with the shape of (num_train_samples,). + + Raises: + ValueError: The length of params_list and weight_list don't match. """ test_dataloader = train_dataloader _check_shuffle(test_dataloader)