From 7021bc51f6aa2054a2199406092b3d6c74b1f569 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Tue, 30 Dec 2025 19:33:50 -0600 Subject: [PATCH 01/18] random projection in IF-EKFAC --- dattri/algorithm/influence_function.py | 171 ++++++++++++++++++++----- dattri/func/fisher.py | 53 +++++++- 2 files changed, 191 insertions(+), 33 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index d9ff15a63..4556ba295 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -19,7 +19,8 @@ from dattri.func.hessian import ihvp_arnoldi, ihvp_cg, ihvp_explicit, ihvp_lissa -from .base import BaseInnerProductAttributor +from .base import BaseAttributor, BaseInnerProductAttributor +from dattri.func.projection import random_project def _lissa_collate_fn( @@ -46,6 +47,13 @@ def _lissa_collate_fn( "lissa": partial(ihvp_lissa, collate_fn=_lissa_collate_fn), } +DEFAULT_PROJECTOR_KWARGS = { + "proj_dim": 512, + "proj_max_batch_size": 32, + "proj_seed": 0, + "device": "cpu", +} + class IFAttributorExplicit(BaseInnerProductAttributor): """The inner product attributor with explicit inverse hessian transformation.""" @@ -792,6 +800,7 @@ def __init__( self, task: AttributionTask, module_name: Optional[Union[str, List[str]]] = None, + projector_kwargs: Optional[Dict[str, Any]] = None, device: Optional[str] = "cpu", damping: float = 0.0, ) -> None: @@ -842,6 +851,10 @@ def __init__( self.module_name = module_name + self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS + if projector_kwargs is not None: + self.projector_kwargs.update(projector_kwargs) + self.damping = damping self.name_to_module = { name: self.task.model.get_submodule(name) for name in module_name @@ -849,6 +862,8 @@ def __init__( self.module_to_name = {v: k for k, v in self.name_to_module.items()} self.layer_cache = {} # cache for each layer + self.input_projectors = {} + self.output_projectors = {} # Update layer_name corresponding to 
selected modules self.layer_name = [] @@ -886,6 +901,31 @@ def cache( if max_iter is None: max_iter = len(full_train_dataloader) + # init projectors for layers + if self.projector_kwargs: + for name in self.module_name: + mod = self.name_to_module[name] + input_dim = mod.in_features + if mod.bias is not None: + input_dim += 1 + output_dim = mod.out_features + + if name not in self.input_projectors: + blksz_in = torch.zeros(1, input_dim) + self.input_projectors[name] = random_project( + blksz_in, + feature_batch_size=1, + **self.projector_kwargs, + ) + + if name not in self.output_projectors: + blksz_out = torch.zeros(1, output_dim) + self.output_projectors[name] = random_project( + blksz_out, + feature_batch_size=1, + **self.projector_kwargs, + ) + def _ekfac_hook( module: torch.nn.Module, inputs: Union[Tensor, Tuple[Tensor]], @@ -934,6 +974,8 @@ def _ekfac_hook( self.layer_cache, max_iter, device=self.device, + input_projectors=self.input_projectors, + output_projectors=self.output_projectors, ) # 2. Calculate the eigenvalue decomposition of S and A @@ -947,73 +989,138 @@ def _ekfac_hook( self.layer_cache, max_iter, device=self.device, + input_projectors=self.input_projectors, + output_projectors=self.output_projectors, ) # Remove hooks after preprocessing the FIM for handle in handles: handle.remove() - def transform_test_rep( + def _project_rep( self, - ckpt_idx: int, - test_rep: torch.Tensor, - ) -> torch.Tensor: - """Calculate the transformation on the test representations. + rep: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Unflatten representation and project each layer. Args: - ckpt_idx (int): Index of the model checkpoints. Used for ensembling - different trained model checkpoints. - test_rep (torch.Tensor): Test representations to be transformed. - Typically a 2-d tensor with shape (batch_size, num_parameters). + rep (torch.Tensor): The representation to project, typically gradients. + A 2-d tensor with shape (batch_size, num_parameters). 
Returns: - torch.Tensor: Transformed test representations. Typically a 2-d - tensor with shape (batch_size, transformed_dimension). - - Raises: - ValueError: If specifies a non-zero `ckpt_idx`. + Dict[str, torch.Tensor]: Dictionary mapping module name to projected + layer representation with shape (batch, proj_dim_out, proj_dim_in). """ - if ckpt_idx != 0: - error_msg = ( - "EK-FAC only supports single model checkpoint, " - "but receives non-zero `ckpt_idx`." - ) - raise ValueError(error_msg) - - # Unflatten the test_rep + # Unflatten the rep full_model_params = { k: p for k, p in self.task.model.named_parameters() if p.requires_grad } partial_model_params = { name: full_model_params[name] for name in self.layer_name } - layer_test_rep = {} + layer_rep = {} current_index = 0 for name, params in partial_model_params.items(): size = math.prod(params.shape) - layer_test_rep[name] = test_rep[ + layer_rep[name] = rep[ :, current_index : current_index + size, ].reshape(-1, *params.shape) current_index += size - ifvp = {} + layer_rep_proj = {} for name in self.module_name: if self.name_to_module[name].bias is not None: - dim_out = layer_test_rep[name + ".weight"].shape[1] - dim_in = layer_test_rep[name + ".weight"].shape[2] + 1 + dim_out = layer_rep[name + ".weight"].shape[1] + dim_in = layer_rep[name + ".weight"].shape[2] + 1 _v = torch.cat( [ - layer_test_rep[name + ".weight"].flatten(start_dim=1), - layer_test_rep[name + ".bias"].flatten(start_dim=1), + layer_rep[name + ".weight"].flatten(start_dim=1), + layer_rep[name + ".bias"].flatten(start_dim=1), ], dim=-1, ) _v = _v.reshape(-1, dim_out, dim_in) else: - _v = layer_test_rep[name + ".weight"] + _v = layer_rep[name + ".weight"] + + # project + if self.projector_kwargs and name in self.input_projectors: + original_shape = _v.shape + _v_reshaped = _v.reshape(-1, original_shape[-1]) + _v_projected = self.input_projectors[name](_v_reshaped, ensemble_id=0) + _v = _v_projected.reshape(original_shape[0], original_shape[1], 
-1) + + if self.projector_kwargs and name in self.output_projectors: + original_shape = _v.shape + _v_transposed = _v.transpose(1, 2) + _v_reshaped = _v_transposed.reshape(-1, original_shape[1]) + _v_projected = self.output_projectors[name](_v_reshaped, ensemble_id=0) + _v = _v_projected.reshape(original_shape[0], -1, _v_projected.shape[-1]).transpose(1, 2) + + layer_rep_proj[name] = _v + + return layer_rep_proj + + def generate_train_rep( + self, + ckpt_idx: int, + data: Tuple[torch.Tensor, ...], + ) -> torch.Tensor: + """Generate train representation with projection. + + Args: + ckpt_idx (int): Index of model checkpoints for ensembling. + data (Tuple[torch.Tensor, ...]): Training data batch. + Returns: + torch.Tensor: Projected training representation. + """ + # default to base projector + train_rep = super().generate_train_rep(ckpt_idx, data) + + # project + if self.projector_kwargs and self.input_projectors: + layer_rep_proj = self._project_rep(train_rep) + # flatten + projected_layers = [layer_rep_proj[name].flatten(start_dim=1) for name in self.module_name] + train_rep = torch.cat(projected_layers, dim=1) + + return train_rep + + def transform_test_rep( + self, + ckpt_idx: int, + test_rep: torch.Tensor, + ) -> torch.Tensor: + """Calculate the transformation on the test representations. + + Args: + ckpt_idx (int): Index of the model checkpoints. Used for ensembling + different trained model checkpoints. + test_rep (torch.Tensor): Test representations to be transformed. + Typically a 2-d tensor with shape (batch_size, num_parameters). + + Returns: + torch.Tensor: Transformed test representations. Typically a 2-d + tensor with shape (batch_size, transformed_dimension). + + Raises: + ValueError: If specifies a non-zero `ckpt_idx`. + """ + if ckpt_idx != 0: + error_msg = ( + "EK-FAC only supports single model checkpoint, " + "but receives non-zero `ckpt_idx`." 
+ ) + raise ValueError(error_msg) + + # project test rep + test_rep_proj = self._project_rep(test_rep) + ifvp = {} + for name in self.module_name: + _v = test_rep_proj[name] _lambda = self.cached_lambdas[name] q_a, q_s = self.cached_q[name] @@ -1023,4 +1130,4 @@ def transform_test_rep( # Flatten the parameters again transformed_test_rep_layers = [ifvp[name] for name in self.module_name] - return torch.cat(transformed_test_rep_layers, dim=1) + return torch.cat(transformed_test_rep_layers, dim=1) \ No newline at end of file diff --git a/dattri/func/fisher.py b/dattri/func/fisher.py index fe879ddb3..a0faf2ff9 100644 --- a/dattri/func/fisher.py +++ b/dattri/func/fisher.py @@ -241,6 +241,8 @@ def _update_covariance( layer_cache: Dict[str, Tuple[torch.tensor]], total_samples: int, mask: torch.Tensor, + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, Tuple[torch.tensor]]: """Update the running estimation of the covariance matrices S and A in EK-FAC IFVP. @@ -254,6 +256,10 @@ def _update_covariance( mask (torch.Tensor): A tensor of shape (batch_size, t), where 1's indicate that the IFVP will be estimated on these input positions and 0's indicate that these positions are irrelevant (e.g. padding tokens). + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. 
Returns: Dict[str, Tuple[torch.tensor]]: A dict of tuples of tensors, storing the @@ -271,6 +277,11 @@ def _update_covariance( # Calculate batch covariance matrix for A a_prev_reshaped = a_prev_masked.view(-1, a_prev.size(-1)) + + # project input activations + if input_projectors is not None and layer_name in input_projectors: + a_prev_reshaped = input_projectors[layer_name](a_prev_reshaped, ensemble_id=0) + batch_cov_a = a_prev_reshaped.transpose(0, 1) @ a_prev_reshaped batch_cov_a /= batch_samples @@ -278,8 +289,13 @@ def _update_covariance( ds_curr = s_curr_raw.grad ds_curr_reshaped = ds_curr.view(-1, s_curr_raw.size(-1)) + + # project output gradients + if output_projectors is not None and layer_name in output_projectors: + ds_curr_reshaped = output_projectors[layer_name](ds_curr_reshaped, ensemble_id=0) + batch_cov_s = ds_curr_reshaped.transpose(0, 1) @ ds_curr_reshaped - batch_cov_s /= batch_samples + batch_cov_s /= batch_samples # Update the running covariance matrices for A and S if layer_name in curr_estimate: @@ -304,6 +320,8 @@ def _update_lambda( total_samples: int, mask: torch.Tensor, max_steps_for_vec: int = 10, + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, torch.tensor]: """Update the running estimation of the corrected eigenvalues in EK-FAC IFVP. @@ -326,6 +344,10 @@ def _update_lambda( max_steps_for_vec (int): An integer default to 10. Controls the maximum number of input steps that is allowed for vectorized calculation of `dtheta`. + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. Returns: Dict[str, torch.tensor]: A dict of tensors, storing the updated running lambdas. 
@@ -339,6 +361,19 @@ def _update_lambda( a_prev = a_prev_raw.unsqueeze(1) ds_curr = ds_curr.unsqueeze(1) + # project + if input_projectors is not None and layer_name in input_projectors: + original_shape = a_prev.shape + a_prev_flat = a_prev.view(-1, a_prev.size(-1)) + a_prev_projected = input_projectors[layer_name](a_prev_flat, ensemble_id=0) + a_prev = a_prev_projected.view(original_shape[0], original_shape[1], -1) + + if output_projectors is not None and layer_name in output_projectors: + original_shape = ds_curr.shape + ds_curr_flat = ds_curr.view(-1, ds_curr.size(-1)) + ds_curr_projected = output_projectors[layer_name](ds_curr_flat, ensemble_id=0) + ds_curr = ds_curr_projected.view(original_shape[0], original_shape[1], -1) + a_prev_masked = a_prev * mask[..., None].to(a_prev.device) batch_samples = a_prev_masked.shape[0] @@ -387,6 +422,8 @@ def estimate_covariance( layer_cache: Dict[str, Tuple[torch.tensor]], max_iter: Optional[int] = None, device: Optional[str] = "cpu", + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, Tuple[torch.tensor]]: """Estimate the 'covariance' matrices S and A in EK-FAC IFVP. @@ -409,6 +446,10 @@ def estimate_covariance( max_iter (Optional[int]): An integer indicating the maximum number of batches that will be used for estimating the covariance matrices. device (Optional[str]): Device to run the attributor on. Default is "cpu". + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. 
Returns: Dict[str, Tuple[torch.tensor]]: A dict that contains a pair of @@ -442,6 +483,8 @@ def estimate_covariance( layer_cache, total_samples, mask, + input_projectors, + output_projectors, ) total_samples += int(mask.sum()) @@ -480,6 +523,8 @@ def estimate_lambda( layer_cache: Dict[str, Tuple[torch.tensor]], max_iter: Optional[int] = None, device: Optional[str] = "cpu", + input_projectors: Optional[Dict[str, Callable]] = None, + output_projectors: Optional[Dict[str, Callable]] = None, ) -> Dict[str, torch.tensor]: """Estimate the corrected eigenvalues in EK-FAC IFVP. @@ -504,6 +549,10 @@ def estimate_lambda( max_iter (Optional[int]): An integer indicating the maximum number of batches that will be used for estimating the lambdas. device (Optional[str]): Device to run the attributor on. Default is "cpu". + input_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting input activations. Keys are layer names. + output_projectors (Optional[Dict[str, Callable]]): A dict of projector functions + for projecting output gradients. Keys are layer names. 
Returns: Dict[str, torch.tensor]: A dict that contains the estimated lambda @@ -538,6 +587,8 @@ def estimate_lambda( eigenvectors, total_samples, mask, + input_projectors=input_projectors, + output_projectors=output_projectors, ) total_samples += batch_size if i == max_iter - 1: From 153901bea0a506526cb8ceda22e168cf111b7785 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Thu, 1 Jan 2026 02:19:52 -0600 Subject: [PATCH 02/18] add projector for DataInf --- dattri/algorithm/influence_function.py | 119 +++++++++++++++++++------ 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 4556ba295..7dd96107c 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -612,12 +612,12 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, + projector_kwargs: Optional[Dict[str, Any]] = None, device: Optional[str] = "cpu", regularization: float = 0.0, fim_estimate_data_ratio: float = 1.0, ) -> None: """Initialize the DataInf inverse Hessian attributor. - Args: task (AttributionTask): The task to be attributed. Must be an instance of `AttributionTask`. @@ -626,6 +626,7 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. + projector_kwargs: Kwargs for random projection (e.g., proj_dim=512). device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term for Hessian vector product. 
Adding `regularization * I` to the Hessian matrix, where `I` is the @@ -637,6 +638,12 @@ def __init__( super().__init__(task, layer_name, device) self.regularization = regularization self.fim_estimate_data_ratio = fim_estimate_data_ratio + self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS + + # per-layer projections + if projector_kwargs is not None: + self.projector_kwargs.update(projector_kwargs) + self.layer_projectors = {} def cache( self, @@ -650,6 +657,27 @@ def cache( self.full_train_dataloader = full_train_dataloader self._cached_train_reps = {} + # Create per-layer projectors based on first checkpoint + model_params, param_layer_map = self.task.get_param( + 0, + layer_name=self.layer_name, + layer_split=True + ) + + layer_sizes = [0] * (param_layer_map[-1] + 1) + for idx, layer_index in enumerate(param_layer_map): + layer_sizes[layer_index] += model_params[idx].numel() + + # create projectors for all the layers + for layer_idx in range(param_layer_map[-1] + 1): + sample_features = torch.zeros(1, layer_sizes[layer_idx]) + self.layer_projectors[layer_idx] = random_project( + sample_features, + feature_batch_size=1, + **self.projector_kwargs, + ) + + for checkpoint_idx in range(len(self.task.get_checkpoints())): _cached_train_reps_list = [] iter_number = math.ceil( @@ -669,12 +697,44 @@ def cache( ckpt_idx=checkpoint_idx, data=sampled_data, ) - _cached_train_reps_list.append(sampled_data_rep) + + sampled_data_rep_layers = self._get_layer_wise_reps( + checkpoint_idx, + sampled_data_rep + ) + + # project every layer + projected_layers = [] + for layer_idx, layer_rep in enumerate(sampled_data_rep_layers): + projected = self.layer_projectors[layer_idx]( + layer_rep, + ensemble_id=0 + ) + projected_layers.append(projected) + + # concat to one tensor + sampled_data_rep_projected = torch.cat(projected_layers, dim=1) + _cached_train_reps_list.append(sampled_data_rep_projected.detach()) + self._cached_train_reps[checkpoint_idx] = torch.cat( _cached_train_reps_list, 
dim=0, ) + def transform_train_rep(self, ckpt_idx, train_rep): + """Generate train representation with per-layer projection.""" + # split and project + train_rep_layers = self._get_layer_wise_reps(ckpt_idx, train_rep) + projected_layers = [] + + for layer_idx, layer_rep in enumerate(train_rep_layers): + projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) + projected_layers.append(projected) + + output = torch.cat(projected_layers, dim=1) + + return output + def transform_test_rep( self, ckpt_idx: int, @@ -725,16 +785,29 @@ def _transform_single_test_rep( return (v - coef @ grad.T) / regularization regularization = self.regularization - # Split layer-wise train and test representations - test_rep_layers = self._get_layer_wise_reps(ckpt_idx, test_rep) - cached_train_rep_layers = self._get_layer_wise_reps( - ckpt_idx, - self._cached_train_reps[ckpt_idx], - ) - layer_cnt = len(cached_train_rep_layers) + # project test + test_rep_layers_raw = self._get_layer_wise_reps(ckpt_idx, test_rep) + test_rep_layers_proj = [] + for layer_idx, layer_rep in enumerate(test_rep_layers_raw): + projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) + test_rep_layers_proj.append(projected) + + cached_train_rep = self._cached_train_reps[ckpt_idx] + proj_dim = self.projector_kwargs.get('proj_dim') + + # now layer dim is proj_dim after projection + layer_cnt = len(test_rep_layers_raw) + cached_train_rep_layers = [] + start_idx = 0 + for layer_idx in range(layer_cnt): + end_idx = start_idx + proj_dim + cached_train_rep_layers.append( + cached_train_rep[:, start_idx:end_idx] + ) + start_idx = end_idx + transformed_test_rep_layers = [] # Use test batch size as intermediate batch size - # Peak memory usage: max(train_batch_size,test_batch_size)*test_batch_size*p batch_size = test_rep.shape[0] for layer in range(layer_cnt): grad_layer = cached_train_rep_layers[layer] @@ -743,14 +816,14 @@ def _transform_single_test_rep( # Split gradients into smaller batches 
grad_batches = grad_layer.split(batch_size, dim=0) running_transformation = torch.zeros( - test_rep_layers[layer].shape, - device=test_rep_layers[layer].device, + test_rep_layers_proj[layer].shape, + device=test_rep_layers_proj[layer].device, ) for batch in grad_batches: reg = 0.1 if regularization is None else regularization contribution = torch.func.vmap( - lambda grad, layer=layer, reg=reg: _transform_single_test_rep( - test_rep_layers[layer], + lambda grad: _transform_single_test_rep( + test_rep_layers_proj[layer], grad, reg, ), @@ -1063,23 +1136,11 @@ def _project_rep( return layer_rep_proj - def generate_train_rep( + def transform_train_rep( # noqa: PLR6301 self, - ckpt_idx: int, - data: Tuple[torch.Tensor, ...], + ckpt_idx: int, # noqa:ARG002 + train_rep: torch.Tensor, ) -> torch.Tensor: - """Generate train representation with projection. - - Args: - ckpt_idx (int): Index of model checkpoints for ensembling. - data (Tuple[torch.Tensor, ...]): Training data batch. - - Returns: - torch.Tensor: Projected training representation. - """ - # default to base projector - train_rep = super().generate_train_rep(ckpt_idx, data) - # project if self.projector_kwargs and self.input_projectors: layer_rep_proj = self._project_rep(train_rep) From dc7b2bb667edf051b1fa3439c62f33b65e85a664 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Thu, 1 Jan 2026 02:52:25 -0600 Subject: [PATCH 03/18] function argument --- dattri/algorithm/influence_function.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 7dd96107c..f3de8e411 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -1141,6 +1141,13 @@ def transform_train_rep( # noqa: PLR6301 ckpt_idx: int, # noqa:ARG002 train_rep: torch.Tensor, ) -> torch.Tensor: + """Transform train representation with projection. + Args: + ckpt_idx (int): Index of model checkpoints for ensembling. 
+ data (Tuple[torch.Tensor, ...]): Training data batch. + Returns: + torch.Tensor: Projected training representation. + """ # project if self.projector_kwargs and self.input_projectors: layer_rep_proj = self._project_rep(train_rep) From 91bb2f55abd3745db532567e16329e67d085711b Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 7 Jan 2026 02:51:47 -0600 Subject: [PATCH 04/18] add random projection for IF_Explicit --- dattri/algorithm/influence_function.py | 37 ++++++++++++++++++++++++++ dattri/func/hessian.py | 22 ++++++++++++--- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index f3de8e411..131869c52 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -62,6 +62,7 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, + projector_kwargs: Optional[Dict[str, Any]] = None, device: Optional[str] = "cpu", regularization: float = 0.0, ) -> None: @@ -82,10 +83,39 @@ def __init__( Default is 0.0. """ super().__init__(task, layer_name, device) + + self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS + if projector_kwargs is not None: + self.projector_kwargs.update(projector_kwargs) self.transformation_kwargs = { "regularization": regularization, + "projector_kwargs": self.projector_kwargs, } + def transform_train_rep( + self, + ckpt_idx: int, + train_rep: torch.Tensor, + ) -> torch.Tensor: + """Transform the train representations via random projection. + + Args: + ckpt_idx (int): Index of the model checkpoints. + train_rep (torch.Tensor): Train representations to be transformed. + + Returns: + torch.Tensor: Transformed train representations with projected dimension. 
+ """ + from dattri.func.projection import random_project + + sample_features = torch.zeros(1, train_rep.shape[1]) + projector = random_project( + sample_features, + 1, + **self.projector_kwargs + ) + return projector(train_rep, ensemble_id=0) + def transform_test_rep( self, ckpt_idx: int, @@ -155,6 +185,13 @@ def _compute_denom( if relatif_method == "l": # g_i^T (H^-1 g_i) + sample_features = torch.zeros(1, train_batch_rep.shape[1]) + projector = random_project( + sample_features, + 1, + **self.projector_kwargs + ) + test_batch_rep_proj = projector(test_batch_rep, ensemble_id=0) val = (test_batch_rep * transformed).sum(dim=1).clamp_min(1e-12).sqrt() elif relatif_method == "theta": # ||H^-1 g_i|| diff --git a/dattri/func/hessian.py b/dattri/func/hessian.py index f2fc4445e..4ba86255c 100644 --- a/dattri/func/hessian.py +++ b/dattri/func/hessian.py @@ -225,6 +225,7 @@ def ihvp_explicit( func: Callable, argnums: int = 0, regularization: float = 0.0, + projector_kwargs: Optional[Dict[str, Any]] = None, ) -> Callable: """IHVP via explicit Hessian calculation. @@ -249,7 +250,10 @@ def ihvp_explicit( A function that takes a tuple of Tensor `x` and a vector `v` and returns the IHVP of the Hessian of `func` and `v`. """ - hessian_func = hessian(func, argnums=argnums) + + from dattri.func.projection import random_project + + hessian_func = hessian(func, argnums=argnums) def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: """The IHVP function using explicit hessian. @@ -263,10 +267,20 @@ def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: The IHVP value. 
""" hessian_tensor = hessian_func(*x) + sample_features = torch.zeros(1, hessian_tensor.shape[0]) + projector = random_project( + sample_features, + 1, + **projector_kwargs + ) + # project H + proj_H_PT_T = projector(hessian_tensor, ensemble_id=0) + proj_P_H_PT = projector(proj_H_PT_T.T, ensemble_id=0).T + proj_v = projector(v, ensemble_id = 0) return torch.linalg.solve( - hessian_tensor - + torch.eye(hessian_tensor.shape[0]).to(v.device) * regularization, - v.T, + proj_P_H_PT + + torch.eye(proj_P_H_PT.shape[0]).to(proj_v.device) * regularization, + proj_v.T, ).T return _ihvp_explicit_func From 7a34b79b02e0fbba4ecc005c143dc37c74dc4c09 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 7 Jan 2026 03:28:22 -0600 Subject: [PATCH 05/18] ruff & darglint --- dattri/algorithm/influence_function.py | 159 +++++++++++++++---------- dattri/func/fisher.py | 38 +++--- dattri/func/hessian.py | 27 +++-- 3 files changed, 130 insertions(+), 94 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 131869c52..f2860cb6f 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import List, Optional, Tuple, Union + from typing import Any, Dict, List, Optional, Tuple, Union from torch import Tensor from torch.utils.data import DataLoader @@ -18,10 +18,10 @@ import torch from dattri.func.hessian import ihvp_arnoldi, ihvp_cg, ihvp_explicit, ihvp_lissa - -from .base import BaseAttributor, BaseInnerProductAttributor from dattri.func.projection import random_project +from .base import BaseInnerProductAttributor + def _lissa_collate_fn( sampled_input: List[Tensor], @@ -62,7 +62,7 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, - projector_kwargs: Optional[Dict[str, Any]] = None, + projector_kwargs: Optional[Dict[str, Any]] = None, device: Optional[str] = "cpu", 
regularization: float = 0.0, ) -> None: @@ -76,6 +76,8 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. + projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term added to Hessian matrix. Useful for singular or ill-conditioned Hessian matrices. @@ -89,7 +91,7 @@ def __init__( self.projector_kwargs.update(projector_kwargs) self.transformation_kwargs = { "regularization": regularization, - "projector_kwargs": self.projector_kwargs, + "projector_kwargs": self.projector_kwargs, } def transform_train_rep( @@ -107,12 +109,12 @@ def transform_train_rep( torch.Tensor: Transformed train representations with projected dimension. """ from dattri.func.projection import random_project - + sample_features = torch.zeros(1, train_rep.shape[1]) projector = random_project( - sample_features, - 1, - **self.projector_kwargs + sample_features, + 1, + **self.projector_kwargs, ) return projector(train_rep, ensemble_id=0) @@ -187,12 +189,12 @@ def _compute_denom( # g_i^T (H^-1 g_i) sample_features = torch.zeros(1, train_batch_rep.shape[1]) projector = random_project( - sample_features, - 1, - **self.projector_kwargs + sample_features, + 1, + **self.projector_kwargs, ) test_batch_rep_proj = projector(test_batch_rep, ensemble_id=0) - val = (test_batch_rep * transformed).sum(dim=1).clamp_min(1e-12).sqrt() + val = (test_batch_rep_proj * transformed).sum(dim=1).clamp_min(1e-12).sqrt() elif relatif_method == "theta": # ||H^-1 g_i|| val = transformed.norm(dim=1).clamp_min(1e-12) @@ -655,6 +657,7 @@ def __init__( fim_estimate_data_ratio: float = 1.0, ) -> None: """Initialize the DataInf inverse Hessian attributor. + Args: task (AttributionTask): The task to be attributed. 
Must be an instance of `AttributionTask`. @@ -663,7 +666,8 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. - projector_kwargs: Kwargs for random projection (e.g., proj_dim=512). + projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for + random projection (e.g., proj_dim=512). Default: None. device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term for Hessian vector product. Adding `regularization * I` to the Hessian matrix, where `I` is the @@ -677,7 +681,7 @@ def __init__( self.fim_estimate_data_ratio = fim_estimate_data_ratio self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS - # per-layer projections + # per-layer projections if projector_kwargs is not None: self.projector_kwargs.update(projector_kwargs) self.layer_projectors = {} @@ -696,16 +700,16 @@ def cache( # Create per-layer projectors based on first checkpoint model_params, param_layer_map = self.task.get_param( - 0, + 0, layer_name=self.layer_name, - layer_split=True + layer_split=True, ) - + layer_sizes = [0] * (param_layer_map[-1] + 1) for idx, layer_index in enumerate(param_layer_map): layer_sizes[layer_index] += model_params[idx].numel() - - # create projectors for all the layers + + # create projectors for all the layers for layer_idx in range(param_layer_map[-1] + 1): sample_features = torch.zeros(1, layer_sizes[layer_idx]) self.layer_projectors[layer_idx] = random_project( @@ -714,7 +718,6 @@ def cache( **self.projector_kwargs, ) - for checkpoint_idx in range(len(self.task.get_checkpoints())): _cached_train_reps_list = [] iter_number = math.ceil( @@ -734,32 +737,44 @@ def cache( ckpt_idx=checkpoint_idx, data=sampled_data, ) - + sampled_data_rep_layers = self._get_layer_wise_reps( - checkpoint_idx, - sampled_data_rep + checkpoint_idx, + sampled_data_rep, ) - # project every layer + # 
project every layer projected_layers = [] for layer_idx, layer_rep in enumerate(sampled_data_rep_layers): projected = self.layer_projectors[layer_idx]( - layer_rep, - ensemble_id=0 + layer_rep, + ensemble_id=0, ) projected_layers.append(projected) - - # concat to one tensor + + # concat to one tensor sampled_data_rep_projected = torch.cat(projected_layers, dim=1) _cached_train_reps_list.append(sampled_data_rep_projected.detach()) - + self._cached_train_reps[checkpoint_idx] = torch.cat( _cached_train_reps_list, dim=0, ) - def transform_train_rep(self, ckpt_idx, train_rep): - """Generate train representation with per-layer projection.""" + def transform_train_rep( + self, + ckpt_idx: int, + train_rep: torch.Tensor, + ) -> torch.Tensor: + """Generate train representation with per-layer projection. + + Args: + ckpt_idx (int): Index of model checkpoints for ensembling. + train_rep (torch.Tensor): Training representations to be transformed. + + Returns: + torch.Tensor: Projected training representation. 
+ """ # split and project train_rep_layers = self._get_layer_wise_reps(ckpt_idx, train_rep) projected_layers = [] @@ -768,9 +783,7 @@ def transform_train_rep(self, ckpt_idx, train_rep): projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) projected_layers.append(projected) - output = torch.cat(projected_layers, dim=1) - - return output + return torch.cat(projected_layers, dim=1) def transform_test_rep( self, @@ -822,7 +835,7 @@ def _transform_single_test_rep( return (v - coef @ grad.T) / regularization regularization = self.regularization - # project test + # project test test_rep_layers_raw = self._get_layer_wise_reps(ckpt_idx, test_rep) test_rep_layers_proj = [] for layer_idx, layer_rep in enumerate(test_rep_layers_raw): @@ -830,19 +843,19 @@ def _transform_single_test_rep( test_rep_layers_proj.append(projected) cached_train_rep = self._cached_train_reps[ckpt_idx] - proj_dim = self.projector_kwargs.get('proj_dim') + proj_dim = self.projector_kwargs.get("proj_dim") - # now layer dim is proj_dim after projection + # now layer dim is proj_dim after projection layer_cnt = len(test_rep_layers_raw) cached_train_rep_layers = [] start_idx = 0 - for layer_idx in range(layer_cnt): + for _layer_idx in range(layer_cnt): end_idx = start_idx + proj_dim cached_train_rep_layers.append( - cached_train_rep[:, start_idx:end_idx] + cached_train_rep[:, start_idx:end_idx], ) start_idx = end_idx - + transformed_test_rep_layers = [] # Use test batch size as intermediate batch size batch_size = test_rep.shape[0] @@ -859,10 +872,12 @@ def _transform_single_test_rep( for batch in grad_batches: reg = 0.1 if regularization is None else regularization contribution = torch.func.vmap( - lambda grad: _transform_single_test_rep( - test_rep_layers_proj[layer], - grad, - reg, + lambda grad, layer=layer, reg=reg: ( + _transform_single_test_rep( + test_rep_layers_proj[layer], + grad, + reg, + ) ), )(batch) # Accumulate the batches and average at the end @@ -934,6 +949,8 @@ def 
__init__( modules are used. This should be a string or a list of strings if multiple modules are needed. The name of module should follow the key of model.named_modules(). Default: None. + projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". damping (float): Damping factor used for non-convexity in EK-FAC IFVP calculation. Default is 0.0. @@ -972,8 +989,8 @@ def __init__( self.module_to_name = {v: k for k, v in self.name_to_module.items()} self.layer_cache = {} # cache for each layer - self.input_projectors = {} - self.output_projectors = {} + self.input_projectors = {} + self.output_projectors = {} # Update layer_name corresponding to selected modules self.layer_name = [] @@ -1011,7 +1028,7 @@ def cache( if max_iter is None: max_iter = len(full_train_dataloader) - # init projectors for layers + # init projectors for layers if self.projector_kwargs: for name in self.module_name: mod = self.name_to_module[name] @@ -1019,7 +1036,7 @@ def cache( if mod.bias is not None: input_dim += 1 output_dim = mod.out_features - + if name not in self.input_projectors: blksz_in = torch.zeros(1, input_dim) self.input_projectors[name] = random_project( @@ -1027,7 +1044,7 @@ def cache( feature_batch_size=1, **self.projector_kwargs, ) - + if name not in self.output_projectors: blksz_out = torch.zeros(1, output_dim) self.output_projectors[name] = random_project( @@ -1110,7 +1127,7 @@ def _ekfac_hook( def _project_rep( self, rep: torch.Tensor, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """Unflatten representation and project each layer. 
Args: @@ -1155,19 +1172,24 @@ def _project_rep( else: _v = layer_rep[name + ".weight"] - # project + # project if self.projector_kwargs and name in self.input_projectors: - original_shape = _v.shape - _v_reshaped = _v.reshape(-1, original_shape[-1]) + original_shape = _v.shape + _v_reshaped = _v.reshape(-1, original_shape[-1]) _v_projected = self.input_projectors[name](_v_reshaped, ensemble_id=0) - _v = _v_projected.reshape(original_shape[0], original_shape[1], -1) + _v = _v_projected.reshape(original_shape[0], original_shape[1], -1) if self.projector_kwargs and name in self.output_projectors: - original_shape = _v.shape - _v_transposed = _v.transpose(1, 2) - _v_reshaped = _v_transposed.reshape(-1, original_shape[1]) - _v_projected = self.output_projectors[name](_v_reshaped, ensemble_id=0) - _v = _v_projected.reshape(original_shape[0], -1, _v_projected.shape[-1]).transpose(1, 2) + original_shape = _v.shape + _v_transposed = _v.transpose(1, 2) + _v_reshaped = _v_transposed.reshape(-1, original_shape[1]) + _v_projected = self.output_projectors[name]( + _v_reshaped, ensemble_id=0, + ) + _v = _v_projected.reshape( + original_shape[0], -1, _v_projected.shape[-1], + ) + _v = _v.transpose(1, 2) layer_rep_proj[name] = _v @@ -1179,17 +1201,22 @@ def transform_train_rep( # noqa: PLR6301 train_rep: torch.Tensor, ) -> torch.Tensor: """Transform train representation with projection. + Args: ckpt_idx (int): Index of model checkpoints for ensembling. - data (Tuple[torch.Tensor, ...]): Training data batch. + train_rep (torch.Tensor): Training representations. + Returns: torch.Tensor: Projected training representation. 
""" - # project + # project if self.projector_kwargs and self.input_projectors: layer_rep_proj = self._project_rep(train_rep) - # flatten - projected_layers = [layer_rep_proj[name].flatten(start_dim=1) for name in self.module_name] + # flatten + projected_layers = [ + layer_rep_proj[name].flatten(start_dim=1) + for name in self.module_name + ] train_rep = torch.cat(projected_layers, dim=1) return train_rep @@ -1221,7 +1248,7 @@ def transform_test_rep( ) raise ValueError(error_msg) - # project test rep + # project test rep test_rep_proj = self._project_rep(test_rep) ifvp = {} for name in self.module_name: @@ -1235,4 +1262,4 @@ def transform_test_rep( # Flatten the parameters again transformed_test_rep_layers = [ifvp[name] for name in self.module_name] - return torch.cat(transformed_test_rep_layers, dim=1) \ No newline at end of file + return torch.cat(transformed_test_rep_layers, dim=1) diff --git a/dattri/func/fisher.py b/dattri/func/fisher.py index a0faf2ff9..8ac054d7d 100644 --- a/dattri/func/fisher.py +++ b/dattri/func/fisher.py @@ -277,11 +277,13 @@ def _update_covariance( # Calculate batch covariance matrix for A a_prev_reshaped = a_prev_masked.view(-1, a_prev.size(-1)) - - # project input activations + + # project input activations if input_projectors is not None and layer_name in input_projectors: - a_prev_reshaped = input_projectors[layer_name](a_prev_reshaped, ensemble_id=0) - + a_prev_reshaped = input_projectors[layer_name]( + a_prev_reshaped, ensemble_id=0, + ) + batch_cov_a = a_prev_reshaped.transpose(0, 1) @ a_prev_reshaped batch_cov_a /= batch_samples @@ -289,13 +291,15 @@ def _update_covariance( ds_curr = s_curr_raw.grad ds_curr_reshaped = ds_curr.view(-1, s_curr_raw.size(-1)) - - # project output gradients + + # project output gradients if output_projectors is not None and layer_name in output_projectors: - ds_curr_reshaped = output_projectors[layer_name](ds_curr_reshaped, ensemble_id=0) - + ds_curr_reshaped = output_projectors[layer_name]( + 
ds_curr_reshaped, ensemble_id=0, + ) + batch_cov_s = ds_curr_reshaped.transpose(0, 1) @ ds_curr_reshaped - batch_cov_s /= batch_samples + batch_cov_s /= batch_samples # Update the running covariance matrices for A and S if layer_name in curr_estimate: @@ -361,18 +365,22 @@ def _update_lambda( a_prev = a_prev_raw.unsqueeze(1) ds_curr = ds_curr.unsqueeze(1) - # project + # project if input_projectors is not None and layer_name in input_projectors: - original_shape = a_prev.shape + original_shape = a_prev.shape a_prev_flat = a_prev.view(-1, a_prev.size(-1)) a_prev_projected = input_projectors[layer_name](a_prev_flat, ensemble_id=0) a_prev = a_prev_projected.view(original_shape[0], original_shape[1], -1) - + if output_projectors is not None and layer_name in output_projectors: - original_shape = ds_curr.shape + original_shape = ds_curr.shape ds_curr_flat = ds_curr.view(-1, ds_curr.size(-1)) - ds_curr_projected = output_projectors[layer_name](ds_curr_flat, ensemble_id=0) - ds_curr = ds_curr_projected.view(original_shape[0], original_shape[1], -1) + ds_curr_projected = output_projectors[layer_name]( + ds_curr_flat, ensemble_id=0, + ) + ds_curr = ds_curr_projected.view( + original_shape[0], original_shape[1], -1, + ) a_prev_masked = a_prev * mask[..., None].to(a_prev.device) diff --git a/dattri/func/hessian.py b/dattri/func/hessian.py index 4ba86255c..20114e90f 100644 --- a/dattri/func/hessian.py +++ b/dattri/func/hessian.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing import Optional, Tuple, Union + from typing import Any, Optional, Tuple, Union import torch from torch import Tensor @@ -225,7 +225,7 @@ def ihvp_explicit( func: Callable, argnums: int = 0, regularization: float = 0.0, - projector_kwargs: Optional[Dict[str, Any]] = None, + projector_kwargs: Optional[dict[str, Any]] = None, ) -> Callable: """IHVP via explicit Hessian calculation. @@ -245,15 +245,16 @@ def ihvp_explicit( matrix is singular or ill-conditioned. 
The regularization term is `regularization * I`, where `I` is the identity matrix directly added to the Hessian matrix. + projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for + random projection. Default: None. Returns: A function that takes a tuple of Tensor `x` and a vector `v` and returns the IHVP of the Hessian of `func` and `v`. """ - from dattri.func.projection import random_project - hessian_func = hessian(func, argnums=argnums) + hessian_func = hessian(func, argnums=argnums) def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: """The IHVP function using explicit hessian. @@ -269,17 +270,17 @@ def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: hessian_tensor = hessian_func(*x) sample_features = torch.zeros(1, hessian_tensor.shape[0]) projector = random_project( - sample_features, - 1, - **projector_kwargs + sample_features, + 1, + **projector_kwargs, ) - # project H - proj_H_PT_T = projector(hessian_tensor, ensemble_id=0) - proj_P_H_PT = projector(proj_H_PT_T.T, ensemble_id=0).T - proj_v = projector(v, ensemble_id = 0) + # project H + proj_h_pt_t = projector(hessian_tensor, ensemble_id=0) + proj_p_h_pt = projector(proj_h_pt_t.T, ensemble_id=0).T + proj_v = projector(v, ensemble_id=0) return torch.linalg.solve( - proj_P_H_PT - + torch.eye(proj_P_H_PT.shape[0]).to(proj_v.device) * regularization, + proj_p_h_pt + + torch.eye(proj_p_h_pt.shape[0]).to(proj_v.device) * regularization, proj_v.T, ).T From 5e0bf2310b7f360690cd976f7c85140192930325 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 7 Jan 2026 11:24:09 -0600 Subject: [PATCH 06/18] ruff --- dattri/algorithm/influence_function.py | 8 ++++---- dattri/func/fisher.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index f2860cb6f..964120c0a 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -96,7 
+96,7 @@ def __init__( def transform_train_rep( self, - ckpt_idx: int, + ckpt_idx: int, # noqa:ARG002 train_rep: torch.Tensor, ) -> torch.Tensor: """Transform the train representations via random projection. @@ -667,7 +667,7 @@ def __init__( if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for - random projection (e.g., proj_dim=512). Default: None. + random projection (e.g., proj_dim=512). device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term for Hessian vector product. Adding `regularization * I` to the Hessian matrix, where `I` is the @@ -785,7 +785,7 @@ def transform_train_rep( return torch.cat(projected_layers, dim=1) - def transform_test_rep( + def transform_test_rep( # noqa: PLR0914 self, ckpt_idx: int, test_rep: torch.Tensor, @@ -950,7 +950,7 @@ def __init__( multiple modules are needed. The name of module should follow the key of model.named_modules(). Default: None. projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for - random projection. Default: None. + random projection. device (str): Device to run the attributor on. Default is "cpu". damping (float): Damping factor used for non-convexity in EK-FAC IFVP calculation. Default is 0.0. 
diff --git a/dattri/func/fisher.py b/dattri/func/fisher.py index 8ac054d7d..f122e5737 100644 --- a/dattri/func/fisher.py +++ b/dattri/func/fisher.py @@ -317,7 +317,7 @@ def _update_covariance( return curr_estimate -def _update_lambda( +def _update_lambda( # noqa: PLR0914 curr_estimate: Dict[str, torch.tensor], layer_cache: Dict[str, Tuple[torch.tensor]], cached_q: Dict[str, Tuple[torch.tensor]], From d3eb026a57d1c16fa88aa4148ecedb2981fafbb6 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 7 Jan 2026 17:57:19 -0600 Subject: [PATCH 07/18] minor fix --- dattri/algorithm/influence_function.py | 79 ++++++++++++++------------ 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 964120c0a..ec96ca6e4 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -842,19 +842,12 @@ def _transform_single_test_rep( projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) test_rep_layers_proj.append(projected) - cached_train_rep = self._cached_train_reps[ckpt_idx] - proj_dim = self.projector_kwargs.get("proj_dim") - - # now layer dim is proj_dim after projection - layer_cnt = len(test_rep_layers_raw) - cached_train_rep_layers = [] - start_idx = 0 - for _layer_idx in range(layer_cnt): - end_idx = start_idx + proj_dim - cached_train_rep_layers.append( - cached_train_rep[:, start_idx:end_idx], - ) - start_idx = end_idx + cached_train_rep_layers = self._get_layer_wise_reps( + ckpt_idx, + self._cached_train_reps[ckpt_idx], + projected=True, + ) + layer_cnt = len(cached_train_rep_layers) transformed_test_rep_layers = [] # Use test batch size as intermediate batch size @@ -890,6 +883,7 @@ def _get_layer_wise_reps( self, ckpt_idx: int, query: torch.Tensor, + projected: bool = False, ) -> Tuple[torch.Tensor, ...]: """Split a representation into layer-wise representations. 
@@ -898,11 +892,24 @@ def _get_layer_wise_reps( is used for ensembling of different trained model. query (torch.Tensor): Input representations to split, could be train/test representations, of shape (batch_size,parameter) + projected (bool): If True, assumes query is already projected and splits + evenly by proj_dim. Default: False. Returns: Tuple[torch.Tensor, ...]: The layer-wise splitted tensor, a tuple of shape (batch_size,layer0_size), (batch_size,layer1_size)... """ + if projected: + # Split evenly by proj_dim + proj_dim = self.projector_kwargs.get("proj_dim") + num_layers = query.shape[1] // proj_dim + query_layers = [] + for i in range(num_layers): + start_idx = i * proj_dim + end_idx = start_idx + proj_dim + query_layers.append(query[:, start_idx:end_idx]) + return query_layers + # Original splitting logic model_params, param_layer_map = self.task.get_param( ckpt_idx, layer_split=True, @@ -1028,30 +1035,28 @@ def cache( if max_iter is None: max_iter = len(full_train_dataloader) - # init projectors for layers - if self.projector_kwargs: - for name in self.module_name: - mod = self.name_to_module[name] - input_dim = mod.in_features - if mod.bias is not None: - input_dim += 1 - output_dim = mod.out_features - - if name not in self.input_projectors: - blksz_in = torch.zeros(1, input_dim) - self.input_projectors[name] = random_project( - blksz_in, - feature_batch_size=1, - **self.projector_kwargs, - ) + for name in self.module_name: + mod = self.name_to_module[name] + input_dim = mod.in_features + if mod.bias is not None: + input_dim += 1 + output_dim = mod.out_features + + if name not in self.input_projectors: + blksz_in = torch.zeros(1, input_dim) + self.input_projectors[name] = random_project( + blksz_in, + feature_batch_size=1, + **self.projector_kwargs, + ) - if name not in self.output_projectors: - blksz_out = torch.zeros(1, output_dim) - self.output_projectors[name] = random_project( - blksz_out, - feature_batch_size=1, - **self.projector_kwargs, - ) + 
if name not in self.output_projectors: + blksz_out = torch.zeros(1, output_dim) + self.output_projectors[name] = random_project( + blksz_out, + feature_batch_size=1, + **self.projector_kwargs, + ) def _ekfac_hook( module: torch.nn.Module, @@ -1127,7 +1132,7 @@ def _ekfac_hook( def _project_rep( self, rep: torch.Tensor, - ) -> dict[str, torch.Tensor]: + ) -> Dict[str, torch.Tensor]: """Unflatten representation and project each layer. Args: From c992481a643738c1e14f1b477ba690605173fd34 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 7 Jan 2026 20:54:34 -0600 Subject: [PATCH 08/18] make projection optional --- dattri/algorithm/influence_function.py | 208 ++++++++++++------------- dattri/func/hessian.py | 33 ++-- 2 files changed, 123 insertions(+), 118 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index ec96ca6e4..318ef8fd2 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -47,13 +47,6 @@ def _lissa_collate_fn( "lissa": partial(ihvp_lissa, collate_fn=_lissa_collate_fn), } -DEFAULT_PROJECTOR_KWARGS = { - "proj_dim": 512, - "proj_max_batch_size": 32, - "proj_seed": 0, - "device": "cpu", -} - class IFAttributorExplicit(BaseInnerProductAttributor): """The inner product attributor with explicit inverse hessian transformation.""" @@ -86,9 +79,7 @@ def __init__( """ super().__init__(task, layer_name, device) - self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS - if projector_kwargs is not None: - self.projector_kwargs.update(projector_kwargs) + self.projector_kwargs = projector_kwargs self.transformation_kwargs = { "regularization": regularization, "projector_kwargs": self.projector_kwargs, @@ -110,13 +101,15 @@ def transform_train_rep( """ from dattri.func.projection import random_project - sample_features = torch.zeros(1, train_rep.shape[1]) - projector = random_project( - sample_features, - 1, - **self.projector_kwargs, - ) - return projector(train_rep, 
ensemble_id=0) + if self.projector_kwargs is not None: + sample_features = torch.zeros(1, train_rep.shape[1]) + projector = random_project( + sample_features, + 1, + **self.projector_kwargs, + ) + return projector(train_rep, ensemble_id=0) + return train_rep def transform_test_rep( self, @@ -187,14 +180,15 @@ def _compute_denom( if relatif_method == "l": # g_i^T (H^-1 g_i) - sample_features = torch.zeros(1, train_batch_rep.shape[1]) - projector = random_project( - sample_features, - 1, - **self.projector_kwargs, - ) - test_batch_rep_proj = projector(test_batch_rep, ensemble_id=0) - val = (test_batch_rep_proj * transformed).sum(dim=1).clamp_min(1e-12).sqrt() + if self.projector_kwargs is not None: + sample_features = torch.zeros(1, train_batch_rep.shape[1]) + projector = random_project( + sample_features, + 1, + **self.projector_kwargs, + ) + test_batch_rep = projector(test_batch_rep, ensemble_id=0) + val = (test_batch_rep * transformed).sum(dim=1).clamp_min(1e-12).sqrt() elif relatif_method == "theta": # ||H^-1 g_i|| val = transformed.norm(dim=1).clamp_min(1e-12) @@ -679,11 +673,9 @@ def __init__( super().__init__(task, layer_name, device) self.regularization = regularization self.fim_estimate_data_ratio = fim_estimate_data_ratio - self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS # per-layer projections - if projector_kwargs is not None: - self.projector_kwargs.update(projector_kwargs) + self.projector_kwargs = projector_kwargs self.layer_projectors = {} def cache( @@ -705,18 +697,19 @@ def cache( layer_split=True, ) - layer_sizes = [0] * (param_layer_map[-1] + 1) - for idx, layer_index in enumerate(param_layer_map): - layer_sizes[layer_index] += model_params[idx].numel() + if self.projector_kwargs is not None: + layer_sizes = [0] * (param_layer_map[-1] + 1) + for idx, layer_index in enumerate(param_layer_map): + layer_sizes[layer_index] += model_params[idx].numel() - # create projectors for all the layers - for layer_idx in range(param_layer_map[-1] + 1): - 
sample_features = torch.zeros(1, layer_sizes[layer_idx]) - self.layer_projectors[layer_idx] = random_project( - sample_features, - feature_batch_size=1, - **self.projector_kwargs, - ) + # create projectors for all the layers + for layer_idx in range(param_layer_map[-1] + 1): + sample_features = torch.zeros(1, layer_sizes[layer_idx]) + self.layer_projectors[layer_idx] = random_project( + sample_features, + feature_batch_size=1, + **self.projector_kwargs, + ) for checkpoint_idx in range(len(self.task.get_checkpoints())): _cached_train_reps_list = [] @@ -738,23 +731,24 @@ def cache( data=sampled_data, ) - sampled_data_rep_layers = self._get_layer_wise_reps( - checkpoint_idx, - sampled_data_rep, - ) - - # project every layer - projected_layers = [] - for layer_idx, layer_rep in enumerate(sampled_data_rep_layers): - projected = self.layer_projectors[layer_idx]( - layer_rep, - ensemble_id=0, + if self.projector_kwargs is not None: + sampled_data_rep_layers = self._get_layer_wise_reps( + checkpoint_idx, + sampled_data_rep, ) - projected_layers.append(projected) - # concat to one tensor - sampled_data_rep_projected = torch.cat(projected_layers, dim=1) - _cached_train_reps_list.append(sampled_data_rep_projected.detach()) + # project every layer + projected_layers = [] + for layer_idx, layer_rep in enumerate(sampled_data_rep_layers): + projected = self.layer_projectors[layer_idx]( + layer_rep, + ensemble_id=0, + ) + projected_layers.append(projected) + + # concat to one tensor + sampled_data_rep = torch.cat(projected_layers, dim=1) + _cached_train_reps_list.append(sampled_data_rep.detach()) self._cached_train_reps[checkpoint_idx] = torch.cat( _cached_train_reps_list, @@ -776,14 +770,16 @@ def transform_train_rep( torch.Tensor: Projected training representation. 
""" # split and project - train_rep_layers = self._get_layer_wise_reps(ckpt_idx, train_rep) - projected_layers = [] + if self.projector_kwargs is not None: + train_rep_layers = self._get_layer_wise_reps(ckpt_idx, train_rep) + projected_layers = [] - for layer_idx, layer_rep in enumerate(train_rep_layers): - projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) - projected_layers.append(projected) + for layer_idx, layer_rep in enumerate(train_rep_layers): + projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) + projected_layers.append(projected) - return torch.cat(projected_layers, dim=1) + return torch.cat(projected_layers, dim=1) + return train_rep def transform_test_rep( # noqa: PLR0914 self, @@ -835,18 +831,27 @@ def _transform_single_test_rep( return (v - coef @ grad.T) / regularization regularization = self.regularization - # project test - test_rep_layers_raw = self._get_layer_wise_reps(ckpt_idx, test_rep) - test_rep_layers_proj = [] - for layer_idx, layer_rep in enumerate(test_rep_layers_raw): - projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) - test_rep_layers_proj.append(projected) - - cached_train_rep_layers = self._get_layer_wise_reps( - ckpt_idx, - self._cached_train_reps[ckpt_idx], - projected=True, - ) + if self.projector_kwargs is not None: + # project test + test_rep_layers_raw = self._get_layer_wise_reps(ckpt_idx, test_rep) + test_rep_layers = [] + for layer_idx, layer_rep in enumerate(test_rep_layers_raw): + projected = self.layer_projectors[layer_idx](layer_rep, ensemble_id=0) + test_rep_layers.append(projected) + + # project train + cached_train_rep_layers = self._get_layer_wise_reps( + ckpt_idx, + self._cached_train_reps[ckpt_idx], + projected=True, + ) + else: + test_rep_layers = self._get_layer_wise_reps(ckpt_idx, test_rep) + cached_train_rep_layers = self._get_layer_wise_reps( + ckpt_idx, + self._cached_train_reps[ckpt_idx], + projected=False, + ) layer_cnt = 
len(cached_train_rep_layers) transformed_test_rep_layers = [] @@ -859,15 +864,15 @@ def _transform_single_test_rep( # Split gradients into smaller batches grad_batches = grad_layer.split(batch_size, dim=0) running_transformation = torch.zeros( - test_rep_layers_proj[layer].shape, - device=test_rep_layers_proj[layer].device, + test_rep_layers[layer].shape, + device=test_rep_layers[layer].device, ) for batch in grad_batches: reg = 0.1 if regularization is None else regularization contribution = torch.func.vmap( lambda grad, layer=layer, reg=reg: ( _transform_single_test_rep( - test_rep_layers_proj[layer], + test_rep_layers[layer], grad, reg, ) @@ -984,10 +989,7 @@ def __init__( module_name = [module_name] self.module_name = module_name - - self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS - if projector_kwargs is not None: - self.projector_kwargs.update(projector_kwargs) + self.projector_kwargs = projector_kwargs self.damping = damping self.name_to_module = { @@ -1034,29 +1036,29 @@ def cache( if max_iter is None: max_iter = len(full_train_dataloader) + if self.projector_kwargs is not None: + for name in self.module_name: + mod = self.name_to_module[name] + input_dim = mod.in_features + if mod.bias is not None: + input_dim += 1 + output_dim = mod.out_features + + if name not in self.input_projectors: + blksz_in = torch.zeros(1, input_dim) + self.input_projectors[name] = random_project( + blksz_in, + feature_batch_size=1, + **self.projector_kwargs, + ) - for name in self.module_name: - mod = self.name_to_module[name] - input_dim = mod.in_features - if mod.bias is not None: - input_dim += 1 - output_dim = mod.out_features - - if name not in self.input_projectors: - blksz_in = torch.zeros(1, input_dim) - self.input_projectors[name] = random_project( - blksz_in, - feature_batch_size=1, - **self.projector_kwargs, - ) - - if name not in self.output_projectors: - blksz_out = torch.zeros(1, output_dim) - self.output_projectors[name] = random_project( - blksz_out, - 
feature_batch_size=1, - **self.projector_kwargs, - ) + if name not in self.output_projectors: + blksz_out = torch.zeros(1, output_dim) + self.output_projectors[name] = random_project( + blksz_out, + feature_batch_size=1, + **self.projector_kwargs, + ) def _ekfac_hook( module: torch.nn.Module, @@ -1176,8 +1178,6 @@ def _project_rep( _v = _v.reshape(-1, dim_out, dim_in) else: _v = layer_rep[name + ".weight"] - - # project if self.projector_kwargs and name in self.input_projectors: original_shape = _v.shape _v_reshaped = _v.reshape(-1, original_shape[-1]) @@ -1215,7 +1215,7 @@ def transform_train_rep( # noqa: PLR6301 torch.Tensor: Projected training representation. """ # project - if self.projector_kwargs and self.input_projectors: + if self.projector_kwargs: layer_rep_proj = self._project_rep(train_rep) # flatten projected_layers = [ diff --git a/dattri/func/hessian.py b/dattri/func/hessian.py index 20114e90f..c286a67d0 100644 --- a/dattri/func/hessian.py +++ b/dattri/func/hessian.py @@ -268,22 +268,27 @@ def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: The IHVP value. 
""" hessian_tensor = hessian_func(*x) - sample_features = torch.zeros(1, hessian_tensor.shape[0]) - projector = random_project( - sample_features, - 1, - **projector_kwargs, - ) - # project H - proj_h_pt_t = projector(hessian_tensor, ensemble_id=0) - proj_p_h_pt = projector(proj_h_pt_t.T, ensemble_id=0).T - proj_v = projector(v, ensemble_id=0) + if projector_kwargs is not None: + sample_features = torch.zeros(1, hessian_tensor.shape[0]) + projector = random_project( + sample_features, + 1, + **projector_kwargs, + ) + # project H + proj_h_pt_t = projector(hessian_tensor, ensemble_id=0) + proj_p_h_pt = projector(proj_h_pt_t.T, ensemble_id=0).T + proj_v = projector(v, ensemble_id=0) + return torch.linalg.solve( + proj_p_h_pt + + torch.eye(proj_p_h_pt.shape[0]).to(proj_v.device) * regularization, + proj_v.T, + ).T return torch.linalg.solve( - proj_p_h_pt - + torch.eye(proj_p_h_pt.shape[0]).to(proj_v.device) * regularization, - proj_v.T, + hessian_tensor + + torch.eye(hessian_tensor.shape[0]).to(v.device) * regularization, + v.T, ).T - return _ihvp_explicit_func From e2ee9ec8ae1fe3789223c396b3fde07c77c32b4b Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Fri, 9 Jan 2026 11:49:33 -0600 Subject: [PATCH 09/18] minor fix --- dattri/algorithm/influence_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 318ef8fd2..c1a33f38e 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -99,7 +99,6 @@ def transform_train_rep( Returns: torch.Tensor: Transformed train representations with projected dimension. 
""" - from dattri.func.projection import random_project if self.projector_kwargs is not None: sample_features = torch.zeros(1, train_rep.shape[1]) From 0aae59f0a3beae8f2271f4e5a971859945f17079 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Fri, 9 Jan 2026 12:04:10 -0600 Subject: [PATCH 10/18] ruff --- dattri/algorithm/influence_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index c1a33f38e..cfcabeeaa 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -99,7 +99,6 @@ def transform_train_rep( Returns: torch.Tensor: Transformed train representations with projected dimension. """ - if self.projector_kwargs is not None: sample_features = torch.zeros(1, train_rep.shape[1]) projector = random_project( From 44488e5a283389e1a6c9a4a8006187c0063f2724 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 4 Feb 2026 13:57:01 -0600 Subject: [PATCH 11/18] add simple unit test --- .../algorithm/test_influence_function.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/dattri/algorithm/test_influence_function.py b/test/dattri/algorithm/test_influence_function.py index aa2678b29..3553f3e4f 100644 --- a/test/dattri/algorithm/test_influence_function.py +++ b/test/dattri/algorithm/test_influence_function.py @@ -464,3 +464,63 @@ def f(params, data_target_pair): threshold = 0.98 corr = average_pairwise_correlation(gt_test_rep, transformed_test_rep) assert corr > threshold + + def test_influence_functions_with_random_projection(self): + """Test for random projection in Explicit, DataInf and EK-FAC attributors.""" + + train_dataset = TensorDataset( + torch.randn(20, 1, 28, 28), + torch.randint(0, 10, (20,)), + ) + train_loader = DataLoader(train_dataset, batch_size=4) + + projector_kwargs = { + "proj_dim": 512, + "proj_max_batch_size": 32, + "proj_seed": 0, + "device": "cpu", + } + model = train_mnist_lr(train_loader) 
+ + def f(params, data_target_pair): + image, label = data_target_pair + loss = nn.CrossEntropyLoss() + yhat = torch.func.functional_call(model, params, image) + return loss(yhat, label.long()) + + task = AttributionTask( + loss_func=f, + model=model, + checkpoints=model.state_dict(), + ) + + # Explicit with random projection + attributor_exp = IFAttributorExplicit( + task=task, + device=torch.device("cpu"), + regularization=1e-3, + projector_kwargs=projector_kwargs, + ) + attributor_exp.cache(train_loader) + attributor_exp.attribute(train_loader, train_loader) + + # DataInf with random projection + attributor_datainf = IFAttributorDataInf( + task=task, + device=torch.device("cpu"), + regularization=1e-3, + projector_kwargs=projector_kwargs, + ) + attributor_datainf.cache(train_loader) + attributor_datainf.attribute(train_loader, train_loader) + + # EK-FAC with random projection + attributor_ekfac = IFAttributorEKFAC( + task=task, + device=torch.device("cpu"), + damping=0.1, + projector_kwargs=projector_kwargs, + ) + attributor_ekfac.cache(train_loader) + attributor_ekfac.attribute(train_loader, train_loader) + From 7799b81471ff023d45d7e84f62fab9a87717f729 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Thu, 5 Feb 2026 00:34:49 -0600 Subject: [PATCH 12/18] ruff --- test/dattri/algorithm/test_influence_function.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/dattri/algorithm/test_influence_function.py b/test/dattri/algorithm/test_influence_function.py index 3553f3e4f..34b2e534d 100644 --- a/test/dattri/algorithm/test_influence_function.py +++ b/test/dattri/algorithm/test_influence_function.py @@ -467,7 +467,6 @@ def f(params, data_target_pair): def test_influence_functions_with_random_projection(self): """Test for random projection in Explicit, DataInf and EK-FAC attributors.""" - train_dataset = TensorDataset( torch.randn(20, 1, 28, 28), torch.randint(0, 10, (20,)), @@ -522,5 +521,4 @@ def f(params, data_target_pair): 
projector_kwargs=projector_kwargs, ) attributor_ekfac.cache(train_loader) - attributor_ekfac.attribute(train_loader, train_loader) - + attributor_ekfac.attribute(train_loader, train_loader) From 47ba2dedcac48b805507a16297eb0fa8263e9fee Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 11 Mar 2026 21:15:56 -0500 Subject: [PATCH 13/18] adapt latest projection API --- dattri/algorithm/influence_function.py | 102 +++++++++++------- dattri/func/hessian.py | 13 +-- dattri/params/projection.py | 9 +- .../algorithm/test_influence_function.py | 12 ++- 4 files changed, 87 insertions(+), 49 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index cfcabeeaa..002ac7949 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Dict, List, Optional, Tuple, Union + from typing import Dict, List, Optional, Tuple, Union from torch import Tensor from torch.utils.data import DataLoader @@ -19,6 +19,7 @@ from dattri.func.hessian import ihvp_arnoldi, ihvp_cg, ihvp_explicit, ihvp_lissa from dattri.func.projection import random_project +from dattri.params.projection import IFProjectionParams, RandomProjectionParams from .base import BaseInnerProductAttributor @@ -55,7 +56,7 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, - projector_kwargs: Optional[Dict[str, Any]] = None, + proj_params: Optional[IFProjectionParams] = None, device: Optional[str] = "cpu", regularization: float = 0.0, ) -> None: @@ -69,7 +70,7 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. 
- projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for + proj_params (Optional[IFProjectionParams]): Projection parameters for random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term added to Hessian matrix. @@ -79,11 +80,8 @@ def __init__( """ super().__init__(task, layer_name, device) - self.projector_kwargs = projector_kwargs - self.transformation_kwargs = { - "regularization": regularization, - "projector_kwargs": self.projector_kwargs, - } + self.proj_params = proj_params + self.regularization = regularization def transform_train_rep( self, @@ -99,12 +97,15 @@ def transform_train_rep( Returns: torch.Tensor: Transformed train representations with projected dimension. """ - if self.projector_kwargs is not None: + if self.proj_params is not None: sample_features = torch.zeros(1, train_rep.shape[1]) projector = random_project( sample_features, - 1, - **self.projector_kwargs, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), ) return projector(train_rep, ensemble_id=0) return train_rep @@ -132,6 +133,14 @@ def transform_test_rep( for full_data_ in self.full_train_dataloader: # move to device full_data = tuple(data.to(self.device) for data in full_data_) + # Create proj_params_obj only if projection is needed + rand_proj_params = None + if self.proj_params is not None: + rand_proj_params = RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ) self.ihvp_func = ihvp_explicit( partial( self.task.get_loss_func( @@ -140,7 +149,8 @@ def transform_test_rep( ), **{self.task.loss_func_data_key: full_data}, ), - **self.transformation_kwargs, + regularization=self.regularization, + proj_params=rand_proj_params, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() return vector_product @@ -178,12 +188,19 @@ def _compute_denom( if 
relatif_method == "l": # g_i^T (H^-1 g_i) - if self.projector_kwargs is not None: + if self.proj_params is not None: sample_features = torch.zeros(1, train_batch_rep.shape[1]) + + rand_proj_params = None + if self.proj_params is not None: + rand_proj_params = RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ) projector = random_project( sample_features, - 1, - **self.projector_kwargs, + proj_params=rand_proj_params, ) test_batch_rep = projector(test_batch_rep, ensemble_id=0) val = (test_batch_rep * transformed).sum(dim=1).clamp_min(1e-12).sqrt() @@ -643,7 +660,7 @@ def __init__( self, task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, - projector_kwargs: Optional[Dict[str, Any]] = None, + proj_params: Optional[IFProjectionParams] = None, device: Optional[str] = "cpu", regularization: float = 0.0, fim_estimate_data_ratio: float = 1.0, @@ -658,8 +675,8 @@ def __init__( parameters are used. This should be a string or a list of strings if multiple layers are needed. The name of layer should follow the key of model.named_parameters(). Default: None. - projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for - random projection (e.g., proj_dim=512). + proj_params (Optional[IFProjectionParams]): Projection parameters for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". regularization (float): Regularization term for Hessian vector product. 
Adding `regularization * I` to the Hessian matrix, where `I` is the @@ -673,7 +690,7 @@ def __init__( self.fim_estimate_data_ratio = fim_estimate_data_ratio # per-layer projections - self.projector_kwargs = projector_kwargs + self.proj_params = proj_params self.layer_projectors = {} def cache( @@ -695,7 +712,7 @@ def cache( layer_split=True, ) - if self.projector_kwargs is not None: + if self.proj_params is not None: layer_sizes = [0] * (param_layer_map[-1] + 1) for idx, layer_index in enumerate(param_layer_map): layer_sizes[layer_index] += model_params[idx].numel() @@ -705,8 +722,11 @@ def cache( sample_features = torch.zeros(1, layer_sizes[layer_idx]) self.layer_projectors[layer_idx] = random_project( sample_features, - feature_batch_size=1, - **self.projector_kwargs, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), ) for checkpoint_idx in range(len(self.task.get_checkpoints())): @@ -729,7 +749,7 @@ def cache( data=sampled_data, ) - if self.projector_kwargs is not None: + if self.proj_params is not None: sampled_data_rep_layers = self._get_layer_wise_reps( checkpoint_idx, sampled_data_rep, @@ -768,7 +788,7 @@ def transform_train_rep( torch.Tensor: Projected training representation. 
""" # split and project - if self.projector_kwargs is not None: + if self.proj_params is not None: train_rep_layers = self._get_layer_wise_reps(ckpt_idx, train_rep) projected_layers = [] @@ -829,7 +849,7 @@ def _transform_single_test_rep( return (v - coef @ grad.T) / regularization regularization = self.regularization - if self.projector_kwargs is not None: + if self.proj_params is not None: # project test test_rep_layers_raw = self._get_layer_wise_reps(ckpt_idx, test_rep) test_rep_layers = [] @@ -904,7 +924,7 @@ def _get_layer_wise_reps( """ if projected: # Split evenly by proj_dim - proj_dim = self.projector_kwargs.get("proj_dim") + proj_dim = self.proj_params.get("proj_dim") num_layers = query.shape[1] // proj_dim query_layers = [] for i in range(num_layers): @@ -935,7 +955,7 @@ def __init__( self, task: AttributionTask, module_name: Optional[Union[str, List[str]]] = None, - projector_kwargs: Optional[Dict[str, Any]] = None, + proj_params: Optional[IFProjectionParams] = None, device: Optional[str] = "cpu", damping: float = 0.0, ) -> None: @@ -959,8 +979,8 @@ def __init__( modules are used. This should be a string or a list of strings if multiple modules are needed. The name of module should follow the key of model.named_modules(). Default: None. - projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for - random projection. + proj_params (Optional[IFProjectionParams]): Projection parameters for + random projection. Default: None. device (str): Device to run the attributor on. Default is "cpu". damping (float): Damping factor used for non-convexity in EK-FAC IFVP calculation. Default is 0.0. 
@@ -987,7 +1007,7 @@ def __init__( module_name = [module_name] self.module_name = module_name - self.projector_kwargs = projector_kwargs + self.proj_params = proj_params self.damping = damping self.name_to_module = { @@ -1034,7 +1054,7 @@ def cache( if max_iter is None: max_iter = len(full_train_dataloader) - if self.projector_kwargs is not None: + if self.proj_params is not None: for name in self.module_name: mod = self.name_to_module[name] input_dim = mod.in_features @@ -1046,16 +1066,22 @@ def cache( blksz_in = torch.zeros(1, input_dim) self.input_projectors[name] = random_project( blksz_in, - feature_batch_size=1, - **self.projector_kwargs, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), ) if name not in self.output_projectors: blksz_out = torch.zeros(1, output_dim) self.output_projectors[name] = random_project( blksz_out, - feature_batch_size=1, - **self.projector_kwargs, + proj_params=RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ), ) def _ekfac_hook( @@ -1176,13 +1202,13 @@ def _project_rep( _v = _v.reshape(-1, dim_out, dim_in) else: _v = layer_rep[name + ".weight"] - if self.projector_kwargs and name in self.input_projectors: + if self.proj_params and name in self.input_projectors: original_shape = _v.shape _v_reshaped = _v.reshape(-1, original_shape[-1]) _v_projected = self.input_projectors[name](_v_reshaped, ensemble_id=0) _v = _v_projected.reshape(original_shape[0], original_shape[1], -1) - if self.projector_kwargs and name in self.output_projectors: + if self.proj_params and name in self.output_projectors: original_shape = _v.shape _v_transposed = _v.transpose(1, 2) _v_reshaped = _v_transposed.reshape(-1, original_shape[1]) @@ -1213,7 +1239,7 @@ def transform_train_rep( # noqa: PLR6301 torch.Tensor: Projected training representation. 
""" # project - if self.projector_kwargs: + if self.proj_params: layer_rep_proj = self._project_rep(train_rep) # flatten projected_layers = [ diff --git a/dattri/func/hessian.py b/dattri/func/hessian.py index c286a67d0..93aa1ba2b 100644 --- a/dattri/func/hessian.py +++ b/dattri/func/hessian.py @@ -21,7 +21,9 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing import Any, Optional, Tuple, Union + from typing import Optional, Tuple, Union + + from dattri.params.projection import RandomProjectionParams import torch from torch import Tensor @@ -225,7 +227,7 @@ def ihvp_explicit( func: Callable, argnums: int = 0, regularization: float = 0.0, - projector_kwargs: Optional[dict[str, Any]] = None, + proj_params: Optional[RandomProjectionParams] = None, ) -> Callable: """IHVP via explicit Hessian calculation. @@ -245,7 +247,7 @@ def ihvp_explicit( matrix is singular or ill-conditioned. The regularization term is `regularization * I`, where `I` is the identity matrix directly added to the Hessian matrix. - projector_kwargs (Optional[Dict[str, Any]]): Keyword arguments for + proj_params (Optional[RandomProjectionParams]): Projection parameters for random projection. Default: None. Returns: @@ -268,12 +270,11 @@ def _ihvp_explicit_func(x: Tuple[torch.Tensor, ...], v: Tensor) -> Tensor: The IHVP value. 
""" hessian_tensor = hessian_func(*x) - if projector_kwargs is not None: + if proj_params is not None: sample_features = torch.zeros(1, hessian_tensor.shape[0]) projector = random_project( sample_features, - 1, - **projector_kwargs, + proj_params, ) # project H proj_h_pt_t = projector(hessian_tensor, ensemble_id=0) diff --git a/dattri/params/projection.py b/dattri/params/projection.py index 42f7b76d5..0725d4089 100644 --- a/dattri/params/projection.py +++ b/dattri/params/projection.py @@ -26,11 +26,18 @@ class BaseProjectionParams(BaseModel): class GeneralProjectionParams(BaseProjectionParams): - """General projection params used by TracIn, TRAK, RandomProjectionParams.""" + """General projection params used by IF, TracIn, TRAK, RandomProjectionParams.""" proj_dim: int +class IFProjectionParams(GeneralProjectionParams): + """Projection params for IF-based attributors.""" + + proj_dim: int = 512 + proj_max_batch_size: int = 32 + + class LoGraProjectionParams(BaseProjectionParams): """Projection params for LoGra attributor.""" diff --git a/test/dattri/algorithm/test_influence_function.py b/test/dattri/algorithm/test_influence_function.py index 34b2e534d..29d9945f7 100644 --- a/test/dattri/algorithm/test_influence_function.py +++ b/test/dattri/algorithm/test_influence_function.py @@ -14,6 +14,7 @@ IFAttributorLiSSA, ) from dattri.benchmark.datasets.mnist import train_mnist_lr +from dattri.params.projection import IFProjectionParams from dattri.task import AttributionTask @@ -473,12 +474,15 @@ def test_influence_functions_with_random_projection(self): ) train_loader = DataLoader(train_dataset, batch_size=4) - projector_kwargs = { + proj_params = { "proj_dim": 512, "proj_max_batch_size": 32, "proj_seed": 0, "device": "cpu", } + + proj_params = IFProjectionParams() + model = train_mnist_lr(train_loader) def f(params, data_target_pair): @@ -498,7 +502,7 @@ def f(params, data_target_pair): task=task, device=torch.device("cpu"), regularization=1e-3, - 
projector_kwargs=projector_kwargs, + proj_params=proj_params, ) attributor_exp.cache(train_loader) attributor_exp.attribute(train_loader, train_loader) @@ -508,7 +512,7 @@ def f(params, data_target_pair): task=task, device=torch.device("cpu"), regularization=1e-3, - projector_kwargs=projector_kwargs, + proj_params=proj_params, ) attributor_datainf.cache(train_loader) attributor_datainf.attribute(train_loader, train_loader) @@ -518,7 +522,7 @@ def f(params, data_target_pair): task=task, device=torch.device("cpu"), damping=0.1, - projector_kwargs=projector_kwargs, + proj_params=proj_params, ) attributor_ekfac.cache(train_loader) attributor_ekfac.attribute(train_loader, train_loader) From 0537ebca59617e8caec076f1202156f3c18b43b9 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Wed, 11 Mar 2026 21:20:09 -0500 Subject: [PATCH 14/18] minor fix --- dattri/algorithm/influence_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 002ac7949..6ca2a8629 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -924,7 +924,7 @@ def _get_layer_wise_reps( """ if projected: # Split evenly by proj_dim - proj_dim = self.proj_params.get("proj_dim") + proj_dim = self.proj_params.proj_dim num_layers = query.shape[1] // proj_dim query_layers = [] for i in range(num_layers): From 35655e1fbec41fdd051031799a283b29feba20ef Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Sat, 14 Mar 2026 15:38:46 -0500 Subject: [PATCH 15/18] resolve conflict, add arguments to IF params --- dattri/params/projection.py | 112 +++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 9 deletions(-) diff --git a/dattri/params/projection.py b/dattri/params/projection.py index 0725d4089..87f7b64ca 100644 --- a/dattri/params/projection.py +++ b/dattri/params/projection.py @@ -9,7 +9,16 @@ class BaseProjectionParams(BaseModel): - """Base projection 
params (no proj_dim).""" + """Base projection params (no proj_dim). + + Args: + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -26,46 +35,116 @@ class BaseProjectionParams(BaseModel): class GeneralProjectionParams(BaseProjectionParams): - """General projection params used by IF, TracIn, TRAK, RandomProjectionParams.""" + """General projection params used by IF, TracIn, TRAK, RandomProjectionParams. + + Args: + proj_dim (int): Dimension of the projected feature. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ proj_dim: int class IFProjectionParams(GeneralProjectionParams): - """Projection params for IF-based attributors.""" + """Projection params for IF-based attributors. + + Args: + proj_dim (int): Dimension of the projected feature. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. 
+ """ proj_dim: int = 512 proj_max_batch_size: int = 32 class LoGraProjectionParams(BaseProjectionParams): - """Projection params for LoGra attributor.""" + """Projection params for LoGra attributor. + + Args: + proj_dim_per_layer (int): Dimension of the projected feature per layer. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ proj_dim_per_layer: int = 4096 class FactGrassProjectionParams(BaseProjectionParams): - """Projection params for FactGraSS attributor.""" + """Projection params for FactGraSS attributor. + + Args: + proj_dim_per_layer (int): Dimension of the projected feature per layer. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ proj_dim_per_layer: int = 4096 class TracInProjectionParams(BaseProjectionParams): - """Projection params for TracIn attributor.""" + """Projection params for TracIn attributor. + + Args: + proj_dim (int): Dimension of the projected feature. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. 
+ proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ proj_dim: int = 512 proj_max_batch_size: int = 32 class TRAKProjectionParams(GeneralProjectionParams): - """Projection params for TRAK attributor.""" + """Projection params for TRAK attributor. + + Args: + proj_dim (int): Dimension of the projected feature. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ proj_dim: int = 512 proj_max_batch_size: int = 32 class DVEmbProjectionParams(BaseProjectionParams): - """Projection params for DVEmb; proj_dim_per_layer can be None for no projection.""" + """Projection params for DVEmb; proj_dim_per_layer can be None for no projection. + + Args: + proj_dim_per_layer (int): Dimension of the projected feature per layer. + proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + """ proj_dim_per_layer: Optional[int] = None proj_type: Literal[ @@ -79,7 +158,22 @@ class DVEmbProjectionParams(BaseProjectionParams): class RandomProjectionParams(GeneralProjectionParams): - """Params for random_project(); adds feature_batch_size and device.""" + """Params for random_project(). + + Args: + proj_dim (int): Dimension of the projected feature. 
+ proj_max_batch_size (int): The maximum batch size if the CudaProjector is + used. Must be a multiple of 8. The maximum batch size is 32 for A100 + GPUs, 16 for V100 GPUs, 40 for H100 GPUs. + proj_seed (int): Random seed used by the projector. Defaults to 0. + feature_batch_size (int): The batch size of each tensor in the feature + about to be projected. The typical type of feature are gradients of + torch.nn.Module model but is not restricted to this. + device (torch.device): Device to use. Defaults to cpu. + proj_type (Literal["identity", "normal", "rademacher", "sjlt", + "random_mask", "grass"]): The random projection type used for the projection. + device (Union[str, torch.device]): "cuda" or "cpu". Defaults to "cpu". + """ proj_dim: int proj_max_batch_size: int From ed4e25160d2f810e097c094b4ae8d6cbc0a6ef94 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Thu, 19 Mar 2026 15:01:43 -0500 Subject: [PATCH 16/18] minor fix --- dattri/func/fisher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dattri/func/fisher.py b/dattri/func/fisher.py index f122e5737..a86788d61 100644 --- a/dattri/func/fisher.py +++ b/dattri/func/fisher.py @@ -270,6 +270,7 @@ def _update_covariance( # Uniformly reshape the tensors into (batch_size, t, ...) 
# The t here is the sequence length or time steps for sequential input # t = 1 if the given input is not sequential + a_prev = a_prev_raw if a_prev_raw.ndim == 2: # noqa: PLR2004 a_prev = a_prev_raw.unsqueeze(1) From eac0077e770008943085ede04900082c04a4f156 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Thu, 19 Mar 2026 15:12:48 -0500 Subject: [PATCH 17/18] minor fix --- dattri/func/fisher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dattri/func/fisher.py b/dattri/func/fisher.py index a86788d61..4f84ccda7 100644 --- a/dattri/func/fisher.py +++ b/dattri/func/fisher.py @@ -362,6 +362,7 @@ def _update_lambda( # noqa: PLR0914 # The t here is the sequence length or time steps for sequential input # t = 1 if the given input is not sequential ds_curr = s_curr_raw.grad + a_prev = a_prev_raw if a_prev_raw.ndim == 2: # noqa: PLR2004 a_prev = a_prev_raw.unsqueeze(1) ds_curr = ds_curr.unsqueeze(1) From 5d982bd17d2a069b6a1aa0cda50bf6d05ab27151 Mon Sep 17 00:00:00 2001 From: Haochen Ding Date: Tue, 24 Mar 2026 19:57:49 -0500 Subject: [PATCH 18/18] reapply projection --- dattri/algorithm/influence_function.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 291a9e8f1..ddc337d66 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -150,6 +150,15 @@ def transform_test_rep( ) raise TypeError(msg) + # Create proj_params_obj only if projection is needed + rand_proj_params = None + if self.proj_params is not None: + rand_proj_params = RandomProjectionParams( + feature_batch_size=1, + device=self.device, + **self.proj_params.model_dump(), + ) + self.ihvp_func = ihvp_explicit( partial( self.task.get_loss_func(