diff --git a/dattri/algorithm/influence_function.py b/dattri/algorithm/influence_function.py index d9ff15a63..3ae3dc9a3 100644 --- a/dattri/algorithm/influence_function.py +++ b/dattri/algorithm/influence_function.py @@ -112,7 +112,8 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - return vector_product + n = self.full_train_dataloader.batch_size + return vector_product / n def _compute_denom( self, @@ -238,7 +239,8 @@ def transform_test_rep( **self.transformation_kwargs, ) vector_product += self.ihvp_func((model_params,), test_rep).detach() - return vector_product + n = self.full_train_dataloader.batch_size + return vector_product / n def _compute_denom( self, @@ -292,8 +294,8 @@ def __init__( layer_name: Optional[Union[str, List[str]]] = None, device: Optional[str] = "cpu", precompute_data_ratio: float = 1.0, - proj_dim: int = 100, - max_iter: int = 100, + proj_dim: int = 500, + max_iter: int = 1000, norm_constant: float = 1.0, tol: float = 1e-7, regularization: float = 0.0, @@ -460,11 +462,11 @@ def __init__( task: AttributionTask, layer_name: Optional[Union[str, List[str]]] = None, device: Optional[str] = "cpu", - batch_size: int = 1, + batch_size: int = 100, num_repeat: int = 1, - recursion_depth: int = 5000, - damping: float = 0.0, - scaling: float = 50.0, + recursion_depth: int = 100, + damping: float = 5e-4, + scaling: float = 5.0, mode: str = "rev-rev", ) -> None: """Initialize the LiSSA inverse Hessian attributor. @@ -535,7 +537,8 @@ def transform_test_rep( test_rep, in_dims=(None,) + (0,) * len(full_data), ).detach() - return vector_product + n = self.full_train_dataloader.batch_size + return vector_product / n @staticmethod def lissa_collate_fn( diff --git a/dattri/task.py b/dattri/task.py index fc57f3cba..5ccf28598 100644 --- a/dattri/task.py +++ b/dattri/task.py @@ -304,6 +304,12 @@ def get_grad_loss_func( ) -> Callable: """Return a function that computes the gradient of the loss function. + Note: + The reduction type of the loss function determines the scaling of the + computed gradients. Certain attributors, such as "IFAttributorArnoldi", + mathematically require "reduction='sum'" to compute correctly scaled + influence values. + Args: in_dims (Tuple[Union[None, int], ...]): The input dimensions of the loss function. This should be a tuple of integers and None. The length of the