diff --git a/dattri/algorithm/tracin.py b/dattri/algorithm/tracin.py index 425849782..c6f492bf0 100644 --- a/dattri/algorithm/tracin.py +++ b/dattri/algorithm/tracin.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict if TYPE_CHECKING: - from typing import List, Optional, Union + from typing import List, Literal, Optional, Union from dattri.task import AttributionTask @@ -14,6 +14,7 @@ from torch.nn.functional import normalize from tqdm import tqdm +from dattri.algorithm.block_projected_if.offload import create_offload_manager from dattri.func.projection import random_project from .base import BaseAttributor @@ -38,6 +39,9 @@ def __init__( projector_kwargs: Optional[Dict[str, Any]] = None, layer_name: Optional[Union[str, List[str]]] = None, device: str = "cpu", + offload: Literal["none", "cpu", "disk"] = "none", + cache_dir: Optional[str] = None, + chunk_size: int = 16, ) -> None: """Initialize the TracIn attributor. @@ -56,6 +60,12 @@ def __init__( if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. device (str): The device to run the attributor. Default is cpu. + offload: Memory management strategy ("none", "cpu", "disk"), stating + the place to offload the gradients. + "cpu": stores gradients on CPU and moves to device when needed. + "disk": stores gradients on disk and moves to device when needed. + cache_dir: Directory for caching (required when offload="disk"). + chunk_size: Chunk size for processing in disk offload. """ self.task = task self.weight_list = weight_list @@ -66,57 +76,50 @@ def __init__( self.normalized_grad = normalized_grad self.layer_name = layer_name self.device = device + self.offload = offload + self.cache_dir = cache_dir + self.chunk_size = chunk_size self.full_train_dataloader = None + self._offload_managers = [] # to get per-sample gradients for a mini-batch of train/test samples self.grad_target_func = self.task.get_grad_target_func(in_dims=(None, 0)) self.grad_loss_func = self.task.get_grad_loss_func(in_dims=(None, 0)) - def cache(self) -> None: - """Precompute and cache some values for efficiency.""" - - def attribute( # noqa: PLR0912 + def cache( self, - train_dataloader: torch.utils.data.DataLoader, - test_dataloader: torch.utils.data.DataLoader, - ) -> Tensor: - """Calculate the influence of the training set on the test set. + full_train_dataloader: torch.utils.data.DataLoader, + ) -> None: + """Cache the dataset for gradient calculation. Args: - train_dataloader (torch.utils.data.DataLoader): The dataloader for - training samples to calculate the influence. It can be a subset - of the full training set if `cache` is called before. A subset - means that only a part of the training set's influence is calculated. - The dataloader should not be shuffled. - test_dataloader (torch.utils.data.DataLoader): The dataloader for - test samples to calculate the influence. The dataloader should not - be shuffled. + full_train_dataloader (torch.utils.data.DataLoader): The dataloader + with full training samples for gradient calculation. Raises: - ValueError: The length of params_list and weight_list don't match. - - Returns: - Tensor: The influence of the training set on the test set, with - the shape of (num_train_samples, num_test_samples). + ValueError: If the length of checkpoints and weight list don't match. """ - _check_shuffle(test_dataloader) - _check_shuffle(train_dataloader) + _check_shuffle(full_train_dataloader) + self.full_train_dataloader = full_train_dataloader # check the length match between checkpoint list and weight list if len(self.task.get_checkpoints()) != len(self.weight_list): msg = "the length of checkpoints and weights lists don't match." raise ValueError(msg) - # placeholder for the TDA result - # should work for torch dataset without sampler - tda_output = torch.zeros( - size=(len(train_dataloader.sampler), len(test_dataloader.sampler)), - ) + # Initialize offload managers (one per checkpoint) + self._offload_managers = [] + layer_names = ["grad"] # Dummy layer name for API compatibility + for _ in range(len(self.task.get_checkpoints())): + offloader = create_offload_manager( + offload_type=self.offload, + device=self.device, + layer_names=layer_names, + cache_dir=self.cache_dir, + chunk_size=self.chunk_size, + ) + self._offload_managers.append(offloader) - # iterate over each checkpoint (each ensemble) - for ckpt_idx, ckpt_weight in zip( - range(len(self.task.get_checkpoints())), - self.weight_list, - ): + for ckpt_idx in range(len(self.task.get_checkpoints())): parameters, _ = self.task.get_param( ckpt_idx=ckpt_idx, layer_name=self.layer_name, @@ -134,9 +137,9 @@ def attribute( # noqa: PLR0912 ckpt_idx=ckpt_idx, ) - for train_batch_idx, train_batch_data_ in enumerate( + for batch_idx, train_batch_data_ in enumerate( tqdm( - train_dataloader, + full_train_dataloader, desc="calculating gradient of training set...", leave=False, ), @@ -144,17 +147,15 @@ def attribute( # noqa: PLR0912 # move to device if isinstance(train_batch_data_, (tuple, list)): train_batch_data = tuple( - x.to(self.device) for x in train_batch_data_ + data.to(self.device) for data in train_batch_data_ ) else: train_batch_data = train_batch_data_ - # get gradient of train grad_t = self.grad_loss_func(parameters, train_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data self.train_random_project = random_project( grad_t, - # get the batch size, prevent edge case train_batch_data[0].shape[0], **self.projector_kwargs, ) @@ -165,6 +166,194 @@ def attribute( # noqa: PLR0912 ) else: train_batch_grad = torch.nan_to_num(grad_t) + # Store using offload manager (wrap as list for API compatibility) + self._offload_managers[ckpt_idx].store_gradients( + batch_idx, + [train_batch_grad.clone().detach()], + is_test=False, + ) + + def attribute( # noqa: PLR0912, PLR0915 + self, + test_dataloader: torch.utils.data.DataLoader, + train_dataloader: Optional[torch.utils.data.DataLoader] = None, + ) -> Tensor: + """Calculate the influence of the training set on the test set. + + Args: + train_dataloader (torch.utils.data.DataLoader): The dataloader for + training samples to calculate the influence. It can be a subset + of the full training set if `cache` is called before. A subset + means that only a part of the training set's influence is calculated. + The dataloader should not be shuffled. + test_dataloader (torch.utils.data.DataLoader): The dataloader for + test samples to calculate the influence. The dataloader should not + be shuffled. + + Returns: + Tensor: The influence of the training set on the test set, with + the shape of (num_train_samples, num_test_samples). + + Raises: + ValueError: The length of params_list and weight_list don't match. + ValueError: If the train_dataloader is not None and the full training + dataloader is cached or no train_loader is provided in both cases. + """ + _check_shuffle(test_dataloader) + if train_dataloader is not None: + _check_shuffle(train_dataloader) + + if train_dataloader is not None and self.full_train_dataloader is not None: + msg = "You have cached a training loader by .cache()\ + and you are trying to attribute a different training loader.\ + If this new training loader is a subset of the cached training\ + loader, please don't input the training dataloader in\ + .attribute() and directly use index to select the corresponding\ + scores." + raise ValueError(msg) + if train_dataloader is None and self.full_train_dataloader is None: + msg = "You did not state a training loader in .attribute() and you\ + did not cache a training loader by .cache(). Please provide a\ + training loader or cache a training loader." + raise ValueError(msg) + # check the length match between checkpoint list and weight list + if len(self.task.get_checkpoints()) != len(self.weight_list): + msg = "the length of checkpoints and weights lists don't match." + raise ValueError(msg) + + # placeholder for the TDA result + # should work for torch dataset without sampler + tda_output = torch.zeros( + size=( + len((train_dataloader or self.full_train_dataloader).sampler), + len(test_dataloader.sampler), + ), + ) + # use normalize or identity depending on config, + norm = normalize if self.normalized_grad else lambda x: x + + # iterate over each checkpoint (each ensemble) + for ckpt_idx, ckpt_weight in zip( + range(len(self.task.get_checkpoints())), + self.weight_list, + ): + parameters, _ = self.task.get_param( + ckpt_idx=ckpt_idx, + layer_name=self.layer_name, + ) + + if self.layer_name is not None: + self.grad_target_func = self.task.get_grad_target_func( + in_dims=(None, 0), + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, + ) + self.grad_loss_func = self.task.get_grad_loss_func( + in_dims=(None, 0), + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, + ) + + if train_dataloader is not None: + for train_batch_idx, train_batch_data_ in enumerate( + tqdm( + train_dataloader, + desc="calculating gradient of training set...", + leave=False, + ), + ): + # move to device + if isinstance(train_batch_data_, (tuple, list)): + train_batch_data = tuple( + x.to(self.device) for x in train_batch_data_ + ) + else: + train_batch_data = train_batch_data_ + grad_t = self.grad_loss_func(parameters, train_batch_data) + if self.projector_kwargs is not None: + # define the projector for this batch of data + self.train_random_project = random_project( + grad_t, + # get the batch size, prevent edge case + train_batch_data[0].shape[0], + **self.projector_kwargs, + ) + # param index as ensemble id + train_batch_grad = self.train_random_project( + torch.nan_to_num(grad_t), + ensemble_id=ckpt_idx, + ) + else: + train_batch_grad = torch.nan_to_num(grad_t) + + for test_batch_idx, test_batch_data_ in enumerate( + tqdm( + test_dataloader, + desc="calculating gradient of test set...", + leave=False, + ), + ): + # move to device + if isinstance(test_batch_data_, (tuple, list)): + test_batch_data = tuple( + x.to(self.device) for x in test_batch_data_ + ) + else: + test_batch_data = test_batch_data_ + grad_t = self.grad_target_func(parameters, test_batch_data) + if self.projector_kwargs is not None: + # define the projector for this batch of data + self.test_random_project = random_project( + grad_t, + test_batch_data[0].shape[0], + **self.projector_kwargs, + ) + + test_batch_grad = self.test_random_project( + torch.nan_to_num(grad_t), + ensemble_id=ckpt_idx, + ) + else: + test_batch_grad = torch.nan_to_num(grad_t) + + # results position based on batch info + row_st = train_batch_idx * train_dataloader.batch_size + row_ed = min( + (train_batch_idx + 1) * train_dataloader.batch_size, + len(train_dataloader.sampler), + ) + + col_st = test_batch_idx * test_dataloader.batch_size + col_ed = min( + (test_batch_idx + 1) * test_dataloader.batch_size, + len(test_dataloader.sampler), + ) + # accumulate the TDA score in corresponding positions (blocks) + tda_output[row_st:row_ed, col_st:col_ed] += ( + ( + norm(train_batch_grad) + @ norm(test_batch_grad).T + * ckpt_weight + ) + .detach() + .cpu() + ) + + else: + # For "none" mode: concat all cached grads into one tensor + # for a single efficient matmul per test batch + if self.offload == "none": + all_train_grads = torch.cat( + [ + self._offload_managers[ckpt_idx].retrieve_gradients( + i, + is_test=False, + )[0] + for i in range(len(self.full_train_dataloader)) + ], + dim=0, + ) + all_train_grads = norm(all_train_grads) for test_batch_idx, test_batch_data_ in enumerate( tqdm( @@ -180,7 +369,6 @@ def attribute( # noqa: PLR0912 ) else: test_batch_data = test_batch_data_ - # get gradient of test grad_t = self.grad_target_func(parameters, test_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data @@ -189,7 +377,6 @@ def attribute( # noqa: PLR0912 test_batch_data[0].shape[0], **self.projector_kwargs, ) - test_batch_grad = self.test_random_project( torch.nan_to_num(grad_t), ensemble_id=ckpt_idx, @@ -197,35 +384,48 @@ def attribute( # noqa: PLR0912 else: test_batch_grad = torch.nan_to_num(grad_t) - # results position based on batch info - row_st = train_batch_idx * train_dataloader.batch_size - row_ed = min( - (train_batch_idx + 1) * train_dataloader.batch_size, - len(train_dataloader.sampler), - ) - col_st = test_batch_idx * test_dataloader.batch_size col_ed = min( (test_batch_idx + 1) * test_dataloader.batch_size, len(test_dataloader.sampler), ) - # accumulate the TDA score in corresponding positions (blocks) - if self.normalized_grad: - tda_output[row_st:row_ed, col_st:col_ed] += ( - ( - normalize(train_batch_grad) - @ normalize(test_batch_grad).T - * ckpt_weight - ) + + if self.offload == "none": + # all train grads are already in memory and pre-normalized above + tda_output[:, col_st:col_ed] += ( + (all_train_grads @ norm(test_batch_grad).T * ckpt_weight) .detach() .cpu() ) else: - tda_output[row_st:row_ed, col_st:col_ed] += ( - (train_batch_grad @ test_batch_grad.T * ckpt_weight) - .detach() - .cpu() - ) + # For cpu/disk offload: retrieve train grads batch by batch + # to keep memory footprint low + for train_batch_idx in range(len(self.full_train_dataloader)): + train_batch_grad = self._offload_managers[ + ckpt_idx + ].retrieve_gradients( + train_batch_idx, + is_test=False, + )[0] + + row_st = ( + train_batch_idx * self.full_train_dataloader.batch_size + ) + row_ed = min( + (train_batch_idx + 1) + * self.full_train_dataloader.batch_size, + len(self.full_train_dataloader.sampler), + ) + + tda_output[row_st:row_ed, col_st:col_ed] += ( + ( + norm(train_batch_grad) + @ norm(test_batch_grad).T + * ckpt_weight + ) + .detach() + .cpu() + ) return tda_output @@ -242,12 +442,12 @@ def self_attribute( means that only a part of the training set's influence is calculated. The dataloader should not be shuffled. - Raises: - ValueError: The length of params_list and weight_list don't match. - Returns: Tensor: The influence of the training set on itself, with the shape of (num_train_samples,). + + Raises: + ValueError: The length of params_list and weight_list don't match. """ test_dataloader = train_dataloader _check_shuffle(test_dataloader) @@ -296,7 +496,6 @@ def self_attribute( train_batch_data = tuple( data.to(self.device) for data in train_batch_data_ ) - # get gradient of train grad_t = self.grad_loss_func(parameters, train_batch_data) if self.projector_kwargs is not None: # define the projector for this batch of data diff --git a/examples/noisy_label_detection/tracin_noisy_label.py b/examples/noisy_label_detection/tracin_noisy_label.py index a6c5d3219..948c3d360 100644 --- a/examples/noisy_label_detection/tracin_noisy_label.py +++ b/examples/noisy_label_detection/tracin_noisy_label.py @@ -64,7 +64,7 @@ def f(params, image_label_pair): ) with torch.no_grad(): - score = attributor.attribute(train_loader, test_loader).diag() + score = attributor.attribute(test_loader, train_loader).diag() _, indices = torch.sort(-score) cr = 0 diff --git a/experiments/benchmark_result.py b/experiments/benchmark_result.py index a84724c7d..d30dcf920 100644 --- a/experiments/benchmark_result.py +++ b/experiments/benchmark_result.py @@ -280,7 +280,7 @@ def loss_rps(pre_activation_list, label_list): device=args.device, ) with torch.no_grad(): - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) # compute metrics metrics_score = METRICS_DICT[args.metric](score, groundtruth)[0] diff --git a/experiments/benchmark_result_mt.py b/experiments/benchmark_result_mt.py index 000e23a57..75bea4d59 100644 --- a/experiments/benchmark_result_mt.py +++ b/experiments/benchmark_result_mt.py @@ -150,7 +150,7 @@ def loss_tracin(params, data_target_pair): device=args.device, ) with torch.no_grad(): - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) best_result = 0 diff --git a/experiments/benchmark_result_nanogpt.py b/experiments/benchmark_result_nanogpt.py index 4abcde8e6..9b4c93f8d 100644 --- a/experiments/benchmark_result_nanogpt.py +++ b/experiments/benchmark_result_nanogpt.py @@ -154,7 +154,7 @@ def loss_tracin(params, data_target_pair): device=args.device, ) with torch.no_grad(): - score = attributor.attribute(train_loader, val_loader) + score = attributor.attribute(val_loader, train_loader) best_result = 0 best_config = None diff --git a/experiments/gpt2_wikitext/score_TRAK.py b/experiments/gpt2_wikitext/score_TRAK.py index 8f0c1cbbe..42e5d8517 100644 --- a/experiments/gpt2_wikitext/score_TRAK.py +++ b/experiments/gpt2_wikitext/score_TRAK.py @@ -757,7 +757,7 @@ def checkpoints_load_func(model, checkpoint_path): attributor.cache(train_dataloader) score = attributor.attribute(eval_dataloader) else: - score = attributor.attribute(train_dataloader, eval_dataloader) + score = attributor.attribute(eval_dataloader, train_dataloader) torch.save(score, "score_TRAK.pt") logger.info("Attribution scores saved to score_TRAK.pt") diff --git a/test/dattri/algorithm/test_tracin.py b/test/dattri/algorithm/test_tracin.py index e6dada3cb..7ce8a679f 100644 --- a/test/dattri/algorithm/test_tracin.py +++ b/test/dattri/algorithm/test_tracin.py @@ -19,7 +19,7 @@ class TestTracInAttributor: """Test for TracIn.""" - def test_tracin_proj(self): + def test_tracin_proj(self): # noqa: PLR0914 """Test for TracIn with projectors.""" train_dataset = TensorDataset( torch.randn(20, 1, 28, 28), @@ -73,7 +73,7 @@ def f(params, image_label_pair): "device": pytest_device, } - # test with projector list + # test with projector list, without cache attributor = TracInAttributor( task=task, weight_list=torch.ones(len(checkpoint_list)), @@ -83,15 +83,58 @@ def f(params, image_label_pair): ) # Original test - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) assert score.shape == (len(train_loader.dataset), len(test_loader.dataset)) assert torch.count_nonzero(score) == len(train_loader.dataset) * len( test_loader.dataset, ) + # test with projector list, with cache + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + projector_kwargs=projector_kwargs, + device=torch.device(pytest_device), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + + # test with projector list, with offload(cpu) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + projector_kwargs=projector_kwargs, + device=torch.device(pytest_device), + offload="cpu", + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + + # test with projector, with offload(disk) + cache_path = Path("./cache") + if not cache_path.exists(): + cache_path.mkdir(parents=True) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + projector_kwargs=projector_kwargs, + device=torch.device(pytest_device), + offload="disk", + cache_dir=str(cache_path), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + shutil.rmtree(path) + shutil.rmtree(cache_path) - def test_tracin(self): + def test_tracin(self): # noqa: PLR0914 """Test for TracIn without projectors.""" train_dataset = TensorDataset( torch.randn(20, 1, 28, 28), @@ -135,20 +178,66 @@ def f(params, image_label_pair): ) pytest_device = "cpu" - # test with no projector list + # test with no projector list, without cache attributor = TracInAttributor( task=task, weight_list=torch.ones(len(checkpoint_list)), normalized_grad=True, device=torch.device(pytest_device), ) - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) assert score.shape == (len(train_loader.dataset), len(test_loader.dataset)) assert torch.count_nonzero(score) == len(train_loader.dataset) * len( test_loader.dataset, ) + # test with no projector, with cache + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + device=torch.device(pytest_device), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + score3 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + assert torch.allclose(score2, score3) + + # test with no projector, with offload(cpu) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + device=torch.device(pytest_device), + offload="cpu", + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + score3 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + assert torch.allclose(score2, score3) + + # test with no projector, with offload(disk) + cache_path = Path("./cache") + if not cache_path.exists(): + cache_path.mkdir(parents=True) + attributor = TracInAttributor( + task=task, + weight_list=torch.ones(len(checkpoint_list)), + normalized_grad=True, + device=torch.device(pytest_device), + offload="disk", + cache_dir=str(cache_path), + ) + attributor.cache(train_loader) + score2 = attributor.attribute(test_loader) + score3 = attributor.attribute(test_loader) + assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05) + assert torch.allclose(score2, score3) + shutil.rmtree(path) + shutil.rmtree(cache_path) def test_tracin_self_attribute(self): """Test for self_attribute in TracIn without projectors.""" @@ -349,7 +438,7 @@ def f(params, dict_batch): device=torch.device(pytest_device), ) - score = attributor.attribute(train_loader, test_loader) + score = attributor.attribute(test_loader, train_loader) assert score.shape == (len(train_loader.dataset), len(test_loader.dataset)) assert torch.count_nonzero(score) == len(train_loader.dataset) * len(