diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 00337451..ddc337d6 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import List, Optional, Tuple, Union + from typing import Dict, List, Optional, Tuple, Union from torch import Tensor from torch.utils.data import DataLoader @@ -18,6 +18,8 @@ import torch from dattri.func.hessian import ihvp_arnoldi, ihvp_cg, ihvp_explicit, ihvp_lissa +from dattri.func.projection import random_project +from dattri.params.projection import IFProjectionParams, RandomProjectionParams from .base import BaseInnerProductAttributor @@ -54,6 +56,7 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, + proj_params: Optional[IFProjectionParams] = None, device: Optional[str] = "cpu", regularization: float = 0.0, ) -> None: @@ -67,6 +70,8 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. + proj_params (Optional[IFProjectionParams]): Projection parameters for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term added to Hessian matrix. Useful for singular or ill-conditioned Hessian matrices. @@ -74,9 +79,36 @@ def __init__( Default is 0.0. """ super().__init__(task, layer_name, device) - self.transformation_kwargs = { - "regularization": regularization, - } + + self.proj_params = proj_params + self.regularization = regularization + + def transform_train_rep( + self, + ckpt_idx: int, # noqa:ARG002 + train_rep: torch.Tensor, + ) -> torch.Tensor: + """Transform the train representations via random projection. + + Args: + ckpt_idx (int): Index of the model checkpoints. + train_rep (torch.Tensor): Train representations to be transformed. + + Returns: + torch.Tensor: Transformed train representations with projected dimension. + """ + if self.proj_params is not None: + sample_features = torch.zeros(1, train_rep.shape[1]) + projector = random_project( + sample_features, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), + ) + return projector(train_rep, ensemble_id=0) + return train_rep def transform_test_rep( self, @@ -118,6 +150,15 @@ def transform_test_rep( ) raise TypeError(msg) + # Create proj_params_obj only if projection is needed + rand_proj_params = None + if self.proj_params is not None: + rand_proj_params = RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ) + self.ihvp_func = ihvp_explicit( partial( self.task.get_loss_func( @@ -126,7 +167,8 @@ def transform_test_rep( ), **{self.task.loss_func_data_key: full_data}, ), - **self.transformation_kwargs, + regularization=self.regularization, + proj_params=rand_proj_params, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() return vector_product @@ -164,6 +206,21 @@ def _compute_denom( if relatif_method == "l": # g_i^T (H^-1 g_i) + if self.proj_params is not None: + sample_features = torch.zeros(1, train_batch_rep.shape[1]) + + rand_proj_params = None + if self.proj_params is not None: + rand_proj_params = RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ) + projector = random_project( + sample_features, + proj_params=rand_proj_params, + ) + test_batch_rep = projector(test_batch_rep, ensemble_id=0) val = (test_batch_rep * transformed).sum(dim=1).clamp_min(1e-12).sqrt() elif relatif_method == "theta": # ||H^-1 g_i|| @@ -655,6 +712,7 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, + proj_params: Optional[IFProjectionParams] = None, device: Optional[str] = "cpu", regularization: float = 0.0, fim_estimate_data_ratio: float = 1.0, @@ -669,6 +727,8 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. + proj_params (Optional[IFProjectionParams]): Projection parameters for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term for Hessian vector product. Adding `regularization * I` to the Hessian matrix, where `I` is the @@ -681,6 +741,10 @@ def __init__( self.regularization = regularization self.fim_estimate_data_ratio = fim_estimate_data_ratio + # per-layer projections + self.proj_params = proj_params + self.layer_projectors = {} + def cache( self, full_train_dataloader: DataLoader, @@ -696,6 +760,30 @@ def cache( self.full_train_dataloader = full_train_dataloader self._cached_train_reps = {} + # Create per-layer projectors based on first checkpoint + model_params, param_layer_map = self.task.get_param( + 0, + layer_name=self.layer_name, + layer_split=True, + ) + + if self.proj_params is not None: + layer_sizes = [0] * (param_layer_map[-1] + 1) + for idx, layer_index in enumerate(param_layer_map): + layer_sizes[layer_index] += model_params[idx].numel() + + # create projectors for all the layers + for layer_idx in range(param_layer_map[-1] + 1): + sample_features = torch.zeros(1, layer_sizes[layer_idx]) + self.layer_projectors[layer_idx] = random_project( + sample_features, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), + ) + for checkpoint_idx in range(len(self.task.get_checkpoints())): _cached_train_reps_list = [] iter_number = math.ceil( @@ -728,13 +816,58 @@ def cache( ckpt_idx=checkpoint_idx, data=sampled_data, ) - _cached_train_reps_list.append(sampled_data_rep) + + if self.proj_params is not None: + sampled_data_rep_layers = self._get_layer_wise_reps( + checkpoint_idx, + sampled_data_rep, + ) + + # project every layer + projected_layers = [] + for layer_idx, layer_rep in enumerate(sampled_data_rep_layers): + projected = self.layer_projectors[layer_idx]( + layer_rep, + ensemble_id=0, + ) + projected_layers.append(projected) + + # concat to one tensor + sampled_data_rep = torch.cat(projected_layers, dim=1) + _cached_train_reps_list.append(sampled_data_rep.detach()) + self._cached_train_reps[checkpoint_idx] = torch.cat( _cached_train_reps_list, dim=0, ) - def transform_test_rep( + def transform_train_rep( + self, + ckpt_idx: int, + train_rep: torch.Tensor, + ) -> torch.Tensor: + """Generate train representation with per-layer projection. + + Args: + ckpt_idx (int): Index of model checkpoints for ensembling. + train_rep (torch.Tensor): Training representations to be transformed. + + Returns: + torch.Tensor: Projected training representation. + """ + # split and project + if self.proj_params is not None: + train_rep_layers = self._get_layer_wise_reps(ckpt_idx, train_rep) + projected_layers = [] + + for layer_idx, layer_rep in enumerate(train_rep_layers): + projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) + projected_layers.append(projected) + + return torch.cat(projected_layers, dim=1) + return train_rep + + def transform_test_rep( # noqa: PLR0914 self, ckpt_idx: int, test_rep: torch.Tensor, @@ -784,16 +917,31 @@ def _transform_single_test_rep( return (v - coef @ grad.T) / regularization regularization = self.regularization - # Split layer-wise train and test representations - test_rep_layers = self._get_layer_wise_reps(ckpt_idx, test_rep) - cached_train_rep_layers = self._get_layer_wise_reps( - ckpt_idx, - self._cached_train_reps[ckpt_idx], - ) + if self.proj_params is not None: + # project test + test_rep_layers_raw = self._get_layer_wise_reps(ckpt_idx, test_rep) + test_rep_layers = [] + for layer_idx, layer_rep in enumerate(test_rep_layers_raw): + projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) + test_rep_layers.append(projected) + + # project train + cached_train_rep_layers = self._get_layer_wise_reps( + ckpt_idx, + self._cached_train_reps[ckpt_idx], + projected=True, + ) + else: + test_rep_layers = self._get_layer_wise_reps(ckpt_idx, test_rep) + cached_train_rep_layers = self._get_layer_wise_reps( + ckpt_idx, + self._cached_train_reps[ckpt_idx], + projected=False, + ) layer_cnt = len(cached_train_rep_layers) + transformed_test_rep_layers = [] # Use test batch size as intermediate batch size - # Peak memory usage: max(train_batch_size,test_batch_size)*test_batch_size*p batch_size = test_rep.shape[0] for layer in range(layer_cnt): grad_layer = cached_train_rep_layers[layer] @@ -808,10 +956,12 @@ def _transform_single_test_rep( for batch in grad_batches: reg = 0.1 if regularization is None else regularization contribution = torch.func.vmap( - lambda grad, layer=layer, reg=reg: _transform_single_test_rep( - test_rep_layers[layer], - grad, - reg, + lambda grad, layer=layer, reg=reg: ( + _transform_single_test_rep( + test_rep_layers[layer], + grad, + reg, + ) ), )(batch) # Accumulate the batches and average at the end @@ -824,6 +974,7 @@ def _get_layer_wise_reps( self, ckpt_idx: int, query: torch.Tensor, + projected: bool = False, ) -> Tuple[torch.Tensor, ...]: """Split a representation into layer-wise representations. @@ -832,11 +983,24 @@ def _get_layer_wise_reps( is used for ensembling of different trained model. query (torch.Tensor): Input representations to split, could be train/test representations, of shape (batch_size,parameter) + projected (bool): If True, assumes query is already projected and splits + evenly by proj_dim. Default: False. Returns: Tuple[torch.Tensor, ...]: The layer-wise splitted tensor, a tuple of shape (batch_size,layer0_size), (batch_size,layer1_size)... """ + if projected: + # Split evenly by proj_dim + proj_dim = self.proj_params.proj_dim + num_layers = query.shape[1] // proj_dim + query_layers = [] + for i in range(num_layers): + start_idx = i * proj_dim + end_idx = start_idx + proj_dim + query_layers.append(query[:, start_idx:end_idx]) + return query_layers + # Original splitting logic model_params, param_layer_map = self.task.get_param( ckpt_idx, layer_split=True, @@ -859,6 +1023,7 @@ def __init__( self, task: AttributionTask, module_name: Optional[Union[str, List[str]]] = None, + proj_params: Optional[IFProjectionParams] = None, device: Optional[str] = "cpu", damping: float = 0.0, ) -> None: @@ -882,6 +1047,8 @@ def __init__( modules are used. This should be a string or a list of strings if multiple modules are needed. The name of module should follow the key of model.named_modules(). Default: None. + proj_params (Optional[IFProjectionParams]): Projection parameters for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". damping (float): Damping factor used for non-convexity in EK-FAC IFVP calculation. Default is 0.0. @@ -908,6 +1075,7 @@ def __init__( module_name = [module_name] self.module_name = module_name + self.proj_params = proj_params self.damping = damping self.name_to_module = { @@ -916,6 +1084,8 @@ def __init__( self.module_to_name = {v: k for k, v in self.name_to_module.items()} self.layer_cache = {} # cache for each layer + self.input_projectors = {} + self.output_projectors = {} # Update layer_name corresponding to selected modules self.layer_name = [] @@ -952,6 +1122,35 @@ def cache( if max_iter is None: max_iter = len(full_train_dataloader) + if self.proj_params is not None: + for name in self.module_name: + mod = self.name_to_module[name] + input_dim = mod.in_features + if mod.bias is not None: + input_dim += 1 + output_dim = mod.out_features + + if name not in self.input_projectors: + blksz_in = torch.zeros(1, input_dim) + self.input_projectors[name] = random_project( + blksz_in, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), + ) + + if name not in self.output_projectors: + blksz_out = torch.zeros(1, output_dim) + self.output_projectors[name] = random_project( + blksz_out, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), + ) def _ekfac_hook( module: torch.nn.Module, @@ -1001,6 +1200,8 @@ def _ekfac_hook( self.layer_cache, max_iter, device=self.device, + input_projectors=self.input_projectors, + output_projectors=self.output_projectors, ) # 2. Calculate the eigenvalue decomposition of S and A @@ -1014,73 +1215,141 @@ def _ekfac_hook( self.layer_cache, max_iter, device=self.device, + input_projectors=self.input_projectors, + output_projectors=self.output_projectors, ) # Remove hooks after preprocessing the FIM for handle in handles: handle.remove() - def transform_test_rep( + def _project_rep( self, - ckpt_idx: int, - test_rep: torch.Tensor, - ) -> torch.Tensor: - """Calculate the transformation on the test representations. + rep: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Unflatten representation and project each layer. Args: - ckpt_idx (int): Index of the model checkpoints. Used for ensembling - different trained model checkpoints. - test_rep (torch.Tensor): Test representations to be transformed. - Typically a 2-d tensor with shape (batch_size, num_parameters). + rep (torch.Tensor): The representation to project, typically gradients. + A 2-d tensor with shape (batch_size, num_parameters). Returns: - torch.Tensor: Transformed test representations. Typically a 2-d - tensor with shape (batch_size, transformed_dimension). - - Raises: - ValueError: If specifies a non-zero `ckpt_idx`. + Dict[str, torch.Tensor]: Dictionary mapping module name to projected + layer representation with shape (batch, proj_dim_out, proj_dim_in). """ - if ckpt_idx != 0: - error_msg = ( - "EK-FAC only supports single model checkpoint, " - "but receives non-zero `ckpt_idx`." - ) - raise ValueError(error_msg) - - # Unflatten the test_rep + # Unflatten the rep full_model_params = { k: p for k, p in self.task.model.named_parameters() if p.requires_grad } partial_model_params = { name: full_model_params[name] for name in self.layer_name } - layer_test_rep = {} + layer_rep = {} current_index = 0 for name, params in partial_model_params.items(): size = math.prod(params.shape) - layer_test_rep[name] = test_rep[ + layer_rep[name] = rep[ :, current_index : current_index + size, ].reshape(-1, *params.shape) current_index += size - ifvp = {} + layer_rep_proj = {} for name in self.module_name: if self.name_to_module[name].bias is not None: - dim_out = layer_test_rep[name + ".weight"].shape[1] - dim_in = layer_test_rep[name + ".weight"].shape[2] + 1 + dim_out = layer_rep[name + ".weight"].shape[1] + dim_in = layer_rep[name + ".weight"].shape[2] + 1 _v = torch.cat( [ - layer_test_rep[name + ".weight"].flatten(start_dim=1), - layer_test_rep[name + ".bias"].flatten(start_dim=1), + layer_rep[name + ".weight"].flatten(start_dim=1), + layer_rep[name + ".bias"].flatten(start_dim=1), ], dim=-1, ) _v = _v.reshape(-1, dim_out, dim_in) else: - _v = layer_test_rep[name + ".weight"] + _v = layer_rep[name + ".weight"] + if self.proj_params and name in self.input_projectors: + original_shape = _v.shape + _v_reshaped = _v.reshape(-1, original_shape[-1]) + _v_projected = self.input_projectors[name](_v_reshaped, ensemble_id=0) + _v = _v_projected.reshape(original_shape[0], original_shape[1], -1) + + if self.proj_params and name in self.output_projectors: + original_shape = _v.shape + _v_transposed = _v.transpose(1, 2) + _v_reshaped = _v_transposed.reshape(-1, original_shape[1]) + _v_projected = self.output_projectors[name]( + _v_reshaped, ensemble_id=0, + ) + _v = _v_projected.reshape( + original_shape[0], -1, _v_projected.shape[-1], + ) + _v = _v.transpose(1, 2) + + layer_rep_proj[name] = _v + + return layer_rep_proj + + def transform_train_rep( # noqa: PLR6301 + self, + ckpt_idx: int, # noqa:ARG002 + train_rep: torch.Tensor, + ) -> torch.Tensor: + """Transform train representation with projection. + + Args: + ckpt_idx (int): Index of model checkpoints for ensembling. + train_rep (torch.Tensor): Training representations. + + Returns: + torch.Tensor: Projected training representation. + """ + # project + if self.proj_params: + layer_rep_proj = self._project_rep(train_rep) + # flatten + projected_layers = [ + layer_rep_proj[name].flatten(start_dim=1) + for name in self.module_name + ] + train_rep = torch.cat(projected_layers, dim=1) + + return train_rep + + def transform_test_rep( + self, + ckpt_idx: int, + test_rep: torch.Tensor, + ) -> torch.Tensor: + """Calculate the transformation on the test representations. + + Args: + ckpt_idx (int): Index of the model checkpoints. Used for ensembling + different trained model checkpoints. + test_rep (torch.Tensor): Test representations to be transformed. + Typically a 2-d tensor with shape (batch_size, num_parameters). + Returns: + torch.Tensor: Transformed test representations. Typically a 2-d + tensor with shape (batch_size, transformed_dimension). + + Raises: + ValueError: If specifies a non-zero `ckpt_idx`. + """ + if ckpt_idx != 0: + error_msg = ( + "EK-FAC only supports single model checkpoint, " + "but receives non-zero `ckpt_idx`." + ) + raise ValueError(error_msg) + + # project test rep + test_rep_proj = self._project_rep(test_rep) + ifvp = {} + for name in self.module_name: + _v = test_rep_proj[name] _lambda = self.cached_lambdas[name] q_a, q_s = self.cached_q[name] diff --git a/dattri/func/fisher.py b/dattri/func/fisher.py index fe879ddb..4f84ccda 100644 --- a/dattri/func/fisher.py +++ b/dattri/func/fisher.py @@ -241,6 +241,8 @@ def _update_covariance( layer_cache: Dict[str, Tuple[torch.tensor]], total_samples: int, mask: torch.Tensor, + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, Tuple[torch.tensor]]: """Update the running estimation of the covariance matrices S and A in EK-FAC IFVP. @@ -254,6 +256,10 @@ def _update_covariance( mask (torch.Tensor): A tensor of shape (batch_size, t), where 1's indicate that the IFVP will be estimated on these input positions and 0's indicate that these positions are irrelevant (e.g. padding tokens). + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. Returns: Dict[str, Tuple[torch.tensor]]: A dict of tuples of tensors, storing the @@ -264,6 +270,7 @@ def _update_covariance( # Uniformly reshape the tensors into (batch_size, t, ...) # The t here is the sequence length or time steps for sequential input # t = 1 if the given input is not sequential + a_prev = a_prev_raw if a_prev_raw.ndim == 2: # noqa: PLR2004 a_prev = a_prev_raw.unsqueeze(1) @@ -271,6 +278,13 @@ def _update_covariance( # Calculate batch covariance matrix for A a_prev_reshaped = a_prev_masked.view(-1, a_prev.size(-1)) + + # project input activations + if input_projectors is not None and layer_name in input_projectors: + a_prev_reshaped = input_projectors[layer_name]( + a_prev_reshaped, ensemble_id=0, + ) + batch_cov_a = a_prev_reshaped.transpose(0, 1) @ a_prev_reshaped batch_cov_a /= batch_samples @@ -278,6 +292,13 @@ def _update_covariance( ds_curr = s_curr_raw.grad ds_curr_reshaped = ds_curr.view(-1, s_curr_raw.size(-1)) + + # project output gradients + if output_projectors is not None and layer_name in output_projectors: + ds_curr_reshaped = output_projectors[layer_name]( + ds_curr_reshaped, ensemble_id=0, + ) + batch_cov_s = ds_curr_reshaped.transpose(0, 1) @ ds_curr_reshaped batch_cov_s /= batch_samples @@ -297,13 +318,15 @@ def _update_covariance( return curr_estimate -def _update_lambda( +def _update_lambda( # noqa: PLR0914 curr_estimate: Dict[str, torch.tensor], layer_cache: Dict[str, Tuple[torch.tensor]], cached_q: Dict[str, Tuple[torch.tensor]], total_samples: int, mask: torch.Tensor, max_steps_for_vec: int = 10, + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, torch.tensor]: """Update the running estimation of the corrected eigenvalues in EK-FAC IFVP. @@ -326,6 +349,10 @@ def _update_lambda( max_steps_for_vec (int): An integer default to 10. Controls the maximum number of input steps that is allowed for vectorized calculation of `dtheta`. + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. Returns: Dict[str, torch.tensor]: A dict of tensors, storing the updated running lambdas. @@ -335,10 +362,28 @@ def _update_lambda( # The t here is the sequence length or time steps for sequential input # t = 1 if the given input is not sequential ds_curr = s_curr_raw.grad + a_prev = a_prev_raw if a_prev_raw.ndim == 2: # noqa: PLR2004 a_prev = a_prev_raw.unsqueeze(1) ds_curr = ds_curr.unsqueeze(1) + # project + if input_projectors is not None and layer_name in input_projectors: + original_shape = a_prev.shape + a_prev_flat = a_prev.view(-1, a_prev.size(-1)) + a_prev_projected = input_projectors[layer_name](a_prev_flat, ensemble_id=0) + a_prev = a_prev_projected.view(original_shape[0], original_shape[1], -1) + + if output_projectors is not None and layer_name in output_projectors: + original_shape = ds_curr.shape + ds_curr_flat = ds_curr.view(-1, ds_curr.size(-1)) + ds_curr_projected = output_projectors[layer_name]( + ds_curr_flat, ensemble_id=0, + ) + ds_curr = ds_curr_projected.view( + original_shape[0], original_shape[1], -1, + ) + a_prev_masked = a_prev * mask[..., None].to(a_prev.device) batch_samples = a_prev_masked.shape[0] @@ -387,6 +432,8 @@ def estimate_covariance( layer_cache: Dict[str, Tuple[torch.tensor]], max_iter: Optional[int] = None, device: Optional[str] = "cpu", + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, Tuple[torch.tensor]]: """Estimate the 'covariance' matrices S and A in EK-FAC IFVP. @@ -409,6 +456,10 @@ def estimate_covariance( max_iter (Optional[int]): An integer indicating the maximum number of batches that will be used for estimating the covariance matrices. device (Optional[str]): Device to run the attributor on. Default is "cpu". + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. Returns: Dict[str, Tuple[torch.tensor]]: A dict that contains a pair of @@ -442,6 +493,8 @@ def estimate_covariance( layer_cache, total_samples, mask, + input_projectors, + output_projectors, ) total_samples += int(mask.sum()) @@ -480,6 +533,8 @@ def estimate_lambda( layer_cache: Dict[str, Tuple[torch.tensor]], max_iter: Optional[int] = None, device: Optional[str] = "cpu", + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, torch.tensor]: """Estimate the corrected eigenvalues in EK-FAC IFVP. @@ -504,6 +559,10 @@ def estimate_lambda( max_iter (Optional[int]): An integer indicating the maximum number of batches that will be used for estimating the lambdas. device (Optional[str]): Device to run the attributor on. Default is "cpu". + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. Returns: Dict[str, torch.tensor]: A dict that contains the estimated lambda @@ -538,6 +597,8 @@ def estimate_lambda( eigenvectors, total_samples, mask, + input_projectors=input_projectors, + output_projectors=output_projectors, ) total_samples += batch_size if i == max_iter - 1: diff --git a/dattri/func/hessian.py b/dattri/func/hessian.py index f2fc4445..93aa1ba2 100644 --- a/dattri/func/hessian.py +++ b/dattri/func/hessian.py @@ -23,6 +23,8 @@ from collections.abc import Callable from typing import Optional, Tuple, Union + from dattri.params.projection import RandomProjectionParams + import torch from torch import Tensor from torch.func import grad, hessian, jvp, vjp, vmap @@ -225,6 +227,7 @@ def ihvp_explicit( func: Callable, argnums: int = 0, regularization: float = 0.0, + proj_params: Optional[RandomProjectionParams] = None, ) -> Callable: """IHVP via explicit Hessian calculation. @@ -244,11 +247,15 @@ def ihvp_explicit( matrix is singular or ill-conditioned. The regularization term is `regularization * I`, where `I` is the identity matrix directly added to the Hessian matrix. + proj_params (Optional[Dict[str, Any]]): Keyword arguments for + random projection. Default: None. Returns: A function that takes a tuple of Tensor `x` and a vector `v` and returns the IHVP of the Hessian of `func` and `v`. """ + from dattri.func.projection import random_project + hessian_func = hessian(func, argnums=argnums) def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: @@ -263,12 +270,26 @@ def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: The IHVP value. """ hessian_tensor = hessian_func(*x) + if proj_params is not None: + sample_features = torch.zeros(1, hessian_tensor.shape[0]) + projector = random_project( + sample_features, + proj_params, + ) + # project H + proj_h_pt_t = projector(hessian_tensor, ensemble_id=0) + proj_p_h_pt = projector(proj_h_pt_t.T, ensemble_id=0).T + proj_v = projector(v, ensemble_id=0) + return torch.linalg.solve( + proj_p_h_pt + + torch.eye(proj_p_h_pt.shape[0]).to(proj_v.device) * regularization, + proj_v.T, + ).T return torch.linalg.solve( hessian_tensor + torch.eye(hessian_tensor.shape[0]).to(v.device) * regularization, v.T, ).T - return _ihvp_explicit_func diff --git a/dattri/params/projection.py b/dattri/params/projection.py index ea5c9d93..87f7b64c 100644 --- a/dattri/params/projection.py +++ b/dattri/params/projection.py @@ -35,7 +35,7 @@ class BaseProjectionParams(BaseModel): class GeneralProjectionParams(BaseProjectionParams): - """General projection params used by TracIn, TRAK, RandomProjectionParams. + """General projection params used by IF, TracIn, TRAK, RandomProjectionParams. Args: proj_dim (int): Dimension of the projected feature. @@ -50,6 +50,23 @@ class GeneralProjectionParams(BaseProjectionParams): proj_dim: int +class IFProjectionParams(GeneralProjectionParams): + """Projection params for IF-based attributors. + + Args: + proj_dim (int): Dimension of the projected feature. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ + + proj_dim: int = 512 + proj_max_batch_size: int = 32 + + class LoGraProjectionParams(BaseProjectionParams): """Projection params for LoGra attributor. diff --git a/test/dattri/algorithm/test_influence_function.py b/test/dattri/algorithm/test_influence_function.py index aa2678b2..29d9945f 100644 --- a/test/dattri/algorithm/test_influence_function.py +++ b/test/dattri/algorithm/test_influence_function.py @@ -14,6 +14,7 @@ IFAttributorLiSSA, ) from dattri.benchmark.datasets.mnist import train_mnist_lr +from dattri.params.projection import IFProjectionParams from dattri.task import AttributionTask @@ -464,3 +465,64 @@ def f(params, data_target_pair): threshold = 0.98 corr = average_pairwise_correlation(gt_test_rep, transformed_test_rep) assert corr > threshold + + def test_influence_functions_with_random_projection(self): + """Test for random projection in Explicit, DataInf and EK-FAC attributors.""" + train_dataset = TensorDataset( + torch.randn(20, 1, 28, 28), + torch.randint(0, 10, (20,)), + ) + train_loader = DataLoader(train_dataset, batch_size=4) + + proj_params = { + "proj_dim": 512, + "proj_max_batch_size": 32, + "proj_seed": 0, + "device": "cpu", + } + + proj_params = IFProjectionParams() + + model = train_mnist_lr(train_loader) + + def f(params, data_target_pair): + image, label = data_target_pair + loss = nn.CrossEntropyLoss() + yhat = torch.func.functional_call(model, params, image) + return loss(yhat, label.long()) + + task = AttributionTask( + loss_func=f, + model=model, + checkpoints=model.state_dict(), + ) + + # Explicit with random projection + attributor_exp = IFAttributorExplicit( + task=task, + device=torch.device("cpu"), + regularization=1e-3, + proj_params=proj_params, + ) + attributor_exp.cache(train_loader) + attributor_exp.attribute(train_loader, train_loader) + + # DataInf with random projection + attributor_datainf = IFAttributorDataInf( + task=task, + device=torch.device("cpu"), + regularization=1e-3, + proj_params=proj_params, + ) + attributor_datainf.cache(train_loader) + attributor_datainf.attribute(train_loader, train_loader) + + # EK-FAC with random projection + attributor_ekfac = IFAttributorEKFAC( + task=task, + device=torch.device("cpu"), + damping=0.1, + proj_params=proj_params, + ) + attributor_ekfac.cache(train_loader) + attributor_ekfac.attribute(train_loader, train_loader)