From 9d0e69a75b54b56eec6ba688bf5672fd577c3a32 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Thu, 5 Feb 2026 23:48:42 -0500 Subject: [PATCH 1/9] Compute IHVP on full dataset instead of summing batch inverses --- dattri/algorithm/influence_function.py | 100 ++++++++++++++----------- 1 file changed, 57 insertions(+), 43 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index d9ff15a63..99d37e2a1 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -96,22 +96,26 @@ def transform_test_rep( """ from dattri.func.hessian import ihvp_explicit - vector_product = 0 + batches = [] + for batch in self.full_train_dataloader: + batches.append(tuple(x.to(self.device) for x in batch)) + full_data = [] + for i in range(len(batches[0])): + full_data.append(torch.cat([b[i] for b in batches], dim=0)) + full_data = tuple(full_data) + model_params, _ = self.task.get_param(ckpt_idx, layer_name=self.layer_name) - for full_data_ in self.full_train_dataloader: - # move to device - full_data = tuple(data.to(self.device) for data in full_data_) - self.ihvp_func = ihvp_explicit( - partial( - self.task.get_loss_func( - layer_name=self.layer_name, - ckpt_idx=ckpt_idx, - ), - **{self.task.loss_func_data_key: full_data}, + self.ihvp_func = ihvp_explicit( + partial( + self.task.get_loss_func( + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, ), - **self.transformation_kwargs, - ) - vector_product += self.ihvp_func((model_params,), test_rep).detach() + **{self.task.loss_func_data_key: full_data}, + ), + **self.transformation_kwargs, + ) + vector_product = self.ihvp_func((model_params,), test_rep).detach() return vector_product def _compute_denom( @@ -222,22 +226,26 @@ def transform_test_rep( """ from dattri.func.hessian import ihvp_cg - vector_product = 0 + batches = [] + for batch in self.full_train_dataloader: + batches.append(tuple(x.to(self.device) for x in batch)) + full_data = [] + for i in range(len(batches[0])): + full_data.append(torch.cat([b[i] for b in batches], dim=0)) + full_data = tuple(full_data) + model_params, _ = self.task.get_param(ckpt_idx, layer_name=self.layer_name) - for full_data_ in self.full_train_dataloader: - # move to device - full_data = tuple(data.to(self.device) for data in full_data_) - self.ihvp_func = ihvp_cg( - partial( - self.task.get_loss_func( - layer_name=self.layer_name, - ckpt_idx=ckpt_idx, - ), - **{self.task.loss_func_data_key: full_data}, + self.ihvp_func = ihvp_cg( + partial( + self.task.get_loss_func( + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, ), - **self.transformation_kwargs, - ) - vector_product += self.ihvp_func((model_params,), test_rep).detach() + **{self.task.loss_func_data_key: full_data}, + ), + **self.transformation_kwargs, + ) + vector_product = self.ihvp_func((model_params,), test_rep).detach() return vector_product def _compute_denom( @@ -520,21 +528,27 @@ def transform_test_rep( """ from dattri.func.hessian import ihvp_lissa - vector_product = 0 + batches = [] + for batch in self.full_train_dataloader: + batches.append(tuple(x.to(self.device) for x in batch)) + full_data = [] + for i in range(len(batches[0])): + full_data.append(torch.cat([b[i] for b in batches], dim=0)) + full_data = tuple(full_data) + model_params, _ = self.task.get_param(ckpt_idx, layer_name=self.layer_name) - for full_data_ in self.full_train_dataloader: - # move to device - full_data = tuple(data.to(self.device) for data in full_data_) - self.ihvp_func = ihvp_lissa( - self.task.get_loss_func(layer_name=self.layer_name, ckpt_idx=ckpt_idx), - collate_fn=IFAttributorLiSSA.lissa_collate_fn, - **self.transformation_kwargs, - ) - vector_product += self.ihvp_func( - (model_params, *full_data), - test_rep, - in_dims=(None,) + (0,) * len(full_data), - ).detach() + self.ihvp_func = ihvp_lissa( + self.task.get_loss_func(layer_name=self.layer_name, ckpt_idx=ckpt_idx), + collate_fn=IFAttributorLiSSA.lissa_collate_fn, + **self.transformation_kwargs, + ) + + vector_product = self.ihvp_func( + (model_params, *full_data), + test_rep, + in_dims=(None,) + (0,) * len(full_data), + ).detach() + return vector_product @staticmethod @@ -1023,4 +1037,4 @@ def transform_test_rep( # Flatten the parameters again transformed_test_rep_layers = [ifvp[name] for name in self.module_name] - return torch.cat(transformed_test_rep_layers, dim=1) + return torch.cat(transformed_test_rep_layers, dim=1) \ No newline at end of file From 736d36e756d5e0045caf24274452da0f16533459 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Wed, 25 Feb 2026 22:35:55 -0500 Subject: [PATCH 2/9] multiply it by the step size epsilon = 1/N to get the actual parameter change in explicit and cg --- dattri/algorithm/influence_function.py | 104 +++++++++++-------------- 1 file changed, 46 insertions(+), 58 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 99d37e2a1..6566ad88e 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -96,27 +96,24 @@ def transform_test_rep( """ from dattri.func.hessian import ihvp_explicit - batches = [] - for batch in self.full_train_dataloader: - batches.append(tuple(x.to(self.device) for x in batch)) - full_data = [] - for i in range(len(batches[0])): - full_data.append(torch.cat([b[i] for b in batches], dim=0)) - full_data = tuple(full_data) - + vector_product = 0 model_params, _ = self.task.get_param(ckpt_idx, layer_name=self.layer_name) - self.ihvp_func = ihvp_explicit( - partial( - self.task.get_loss_func( - layer_name=self.layer_name, - ckpt_idx=ckpt_idx, + for full_data_ in self.full_train_dataloader: + # move to device + full_data = tuple(data.to(self.device) for data in full_data_) + self.ihvp_func = ihvp_explicit( + partial( + self.task.get_loss_func( + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, + ), + **{self.task.loss_func_data_key: full_data}, ), - **{self.task.loss_func_data_key: full_data}, - ), - **self.transformation_kwargs, - ) - vector_product = self.ihvp_func((model_params,), test_rep).detach() - return vector_product + **self.transformation_kwargs, + ) + vector_product += self.ihvp_func((model_params,), test_rep).detach() + N = full_data[0].shape[0] + return vector_product / N def _compute_denom( self, @@ -226,27 +223,24 @@ def transform_test_rep( """ from dattri.func.hessian import ihvp_cg - batches = [] - for batch in self.full_train_dataloader: - batches.append(tuple(x.to(self.device) for x in batch)) - full_data = [] - for i in range(len(batches[0])): - full_data.append(torch.cat([b[i] for b in batches], dim=0)) - full_data = tuple(full_data) - + vector_product = 0 model_params, _ = self.task.get_param(ckpt_idx, layer_name=self.layer_name) - self.ihvp_func = ihvp_cg( - partial( - self.task.get_loss_func( - layer_name=self.layer_name, - ckpt_idx=ckpt_idx, + for full_data_ in self.full_train_dataloader: + # move to device + full_data = tuple(data.to(self.device) for data in full_data_) + self.ihvp_func = ihvp_cg( + partial( + self.task.get_loss_func( + layer_name=self.layer_name, + ckpt_idx=ckpt_idx, + ), + **{self.task.loss_func_data_key: full_data}, ), - **{self.task.loss_func_data_key: full_data}, - ), - **self.transformation_kwargs, - ) - vector_product = self.ihvp_func((model_params,), test_rep).detach() - return vector_product + **self.transformation_kwargs, + ) + vector_product += self.ihvp_func((model_params,), test_rep).detach() + N = full_data[0].shape[0] + return vector_product / N def _compute_denom( self, @@ -528,27 +522,21 @@ def transform_test_rep( """ from dattri.func.hessian import ihvp_lissa - batches = [] - for batch in self.full_train_dataloader: - batches.append(tuple(x.to(self.device) for x in batch)) - full_data = [] - for i in range(len(batches[0])): - full_data.append(torch.cat([b[i] for b in batches], dim=0)) - full_data = tuple(full_data) - + vector_product = 0 model_params, _ = self.task.get_param(ckpt_idx, layer_name=self.layer_name) - self.ihvp_func = ihvp_lissa( - self.task.get_loss_func(layer_name=self.layer_name, ckpt_idx=ckpt_idx), - collate_fn=IFAttributorLiSSA.lissa_collate_fn, - **self.transformation_kwargs, - ) - - vector_product = self.ihvp_func( - (model_params, *full_data), - test_rep, - in_dims=(None,) + (0,) * len(full_data), - ).detach() - + for full_data_ in self.full_train_dataloader: + # move to device + full_data = tuple(data.to(self.device) for data in full_data_) + self.ihvp_func = ihvp_lissa( + self.task.get_loss_func(layer_name=self.layer_name, ckpt_idx=ckpt_idx), + collate_fn=IFAttributorLiSSA.lissa_collate_fn, + **self.transformation_kwargs, + ) + vector_product += self.ihvp_func( + (model_params, *full_data), + test_rep, + in_dims=(None,) + (0,) * len(full_data), + ).detach() return vector_product @staticmethod From ccb695745de0578eaac9ed3cbb4651a0d9ba13fe Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Fri, 6 Mar 2026 12:32:23 -0500 Subject: [PATCH 3/9] LiSSA --- dattri/algorithm/influence_function.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 6566ad88e..775c071d1 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -112,8 +112,8 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - N = full_data[0].shape[0] - return vector_product / N + n = full_data[0].shape[0] + return vector_product / n def _compute_denom( self, @@ -239,8 +239,8 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - N = full_data[0].shape[0] - return vector_product / N + n = full_data[0].shape[0] + return vector_product / n def _compute_denom( self, @@ -537,7 +537,8 @@ def transform_test_rep( test_rep, in_dims=(None,) + (0,) * len(full_data), ).detach() - return vector_product + n = full_data[0].shape[0] + return vector_product / n @staticmethod def lissa_collate_fn( @@ -607,7 +608,7 @@ def __init__( task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, device: Optional[str] = "cpu", - regularization: float = 0.0, + regularization: float = 3e-4, fim_estimate_data_ratio: float = 1.0, ) -> None: """Initialize the DataInf inverse Hessian attributor. From 2a5485309ffbd03e6c4fdf48ba33243c67a1d179 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Fri, 6 Mar 2026 12:43:46 -0500 Subject: [PATCH 4/9] lissa --- dattri/algorithm/influence_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 775c071d1..777023d32 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -608,7 +608,7 @@ def __init__( task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, device: Optional[str] = "cpu", - regularization: float = 3e-4, + regularization: float = 0.0, fim_estimate_data_ratio: float = 1.0, ) -> None: """Initialize the DataInf inverse Hessian attributor. From 89644119b2975ccee843df25588ae3cb807212f2 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Thu, 12 Mar 2026 00:15:33 -0400 Subject: [PATCH 5/9] correct batch_size, update Arnoldi defaults, and add loss reduction note --- dattri/algorithm/influence_function.py | 10 +++++----- dattri/task.py | 6 ++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 777023d32..d695e35e1 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -112,7 +112,7 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - n = full_data[0].shape[0] + n = self.full_train_dataloader.batch_size return vector_product / n def _compute_denom( @@ -239,7 +239,7 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - n = full_data[0].shape[0] + n = self.full_train_dataloader.batchsize return vector_product / n def _compute_denom( @@ -294,8 +294,8 @@ def __init__( layer_name: Optional[Union[str, List[str]]] = None, device: Optional[str] = "cpu", precompute_data_ratio: float = 1.0, - proj_dim: int = 100, - max_iter: int = 100, + proj_dim: int = 500, + max_iter: int = 1000, norm_constant: float = 1.0, tol: float = 1e-7, regularization: float = 0.0, @@ -537,7 +537,7 @@ def transform_test_rep( test_rep, in_dims=(None,) + (0,) * len(full_data), ).detach() - n = full_data[0].shape[0] + n = self.full_train_dataloader.batchsize return vector_product / n @staticmethod diff --git a/dattri/task.py b/dattri/task.py index fc57f3cba..ecd880f56 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -304,6 +304,12 @@ def get_grad_loss_func( ) -> Callable: """Return a function that computes the gradient of the loss function. + Note: + The reduction type of the loss function determines the scaling of the + computed gradients. Certain attributors, such as "IFAttributorArnoldi", + mathematically require "reduction='sum'" to compute correctly scaled + influence values. + Args: in_dims (Tuple[Union[None, int], ...]): The input dimensions of the loss function. This should be a tuple of integers and None. The length of the From 8e57beabf18f6a4fad58ee076c84dbf522c6ef45 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Thu, 12 Mar 2026 09:22:43 -0400 Subject: [PATCH 6/9] dataloader --- dattri/algorithm/influence_function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index d695e35e1..e91c14fba 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -239,7 +239,7 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - n = self.full_train_dataloader.batchsize + n = self.full_train_dataloader.batch_size return vector_product / n def _compute_denom( @@ -537,7 +537,7 @@ def transform_test_rep( test_rep, in_dims=(None,) + (0,) * len(full_data), ).detach() - n = self.full_train_dataloader.batchsize + n = self.full_train_dataloader.batch_size return vector_product / n @staticmethod From 46dc16c3ecef7ca20ffab3895e343f237c38f239 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Thu, 12 Mar 2026 09:44:13 -0400 Subject: [PATCH 7/9] fixed --- dattri/algorithm/influence_function.py | 3 ++- dattri/task.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index e91c14fba..61fa1deef 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -1026,4 +1026,5 @@ def transform_test_rep( # Flatten the parameters again transformed_test_rep_layers = [ifvp[name] for name in self.module_name] - return torch.cat(transformed_test_rep_layers, dim=1) \ No newline at end of file + return torch.cat(transformed_test_rep_layers, dim=1) + \ No newline at end of file diff --git a/dattri/task.py b/dattri/task.py index ecd880f56..5ccf28598 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -305,9 +305,9 @@ def get_grad_loss_func( """Return a function that computes the gradient of the loss function. Note: - The reduction type of the loss function determines the scaling of the - computed gradients. Certain attributors, such as "IFAttributorArnoldi", - mathematically require "reduction='sum'" to compute correctly scaled + The reduction type of the loss function determines the scaling of the + computed gradients. Certain attributors, such as "IFAttributorArnoldi", + mathematically require "reduction='sum'" to compute correctly scaled influence values. Args: From fefed6c16191b68126effeae0cf1d41e3765adb7 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Fri, 13 Mar 2026 14:03:41 -0400 Subject: [PATCH 8/9] whitespace --- dattri/algorithm/influence_function.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 61fa1deef..063020a26 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -1027,4 +1027,3 @@ def transform_test_rep( transformed_test_rep_layers = [ifvp[name] for name in self.module_name] return torch.cat(transformed_test_rep_layers, dim=1) - \ No newline at end of file From 92b7f5b7e1ee0573413ca0b510da049787372ad0 Mon Sep 17 00:00:00 2001 From: Huining Wang Date: Fri, 13 Mar 2026 14:52:50 -0400 Subject: [PATCH 9/9] defult setting --- dattri/algorithm/influence_function.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index 063020a26..3ae3dc9a3 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -462,11 +462,11 @@ def __init__( task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, device: Optional[str] = "cpu", - batch_size: int = 1, + batch_size: int = 100, num_repeat: int = 1, - recursion_depth: int = 5000, - damping: float = 0.0, - scaling: float = 50.0, + recursion_depth: int = 100, + damping: float = 5e-4, + scaling: float = 5.0, mode: str = "rev-rev", ) -> None: """Initialize the LiSSA inverse Hessian attributor.