Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions dattri/algorithm/influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def transform_test_rep(
**self.transformation_kwargs,
)
vector_product += self.ihvp_func((model_params,), test_rep).detach()
return vector_product
n = self.full_train_dataloader.batch_size
return vector_product / n

def _compute_denom(
self,
Expand Down Expand Up @@ -238,7 +239,8 @@ def transform_test_rep(
**self.transformation_kwargs,
)
vector_product += self.ihvp_func((model_params,), test_rep).detach()
return vector_product
n = self.full_train_dataloader.batch_size
return vector_product / n

def _compute_denom(
self,
Expand Down Expand Up @@ -292,8 +294,8 @@ def __init__(
layer_name: Optional[Union[str, List[str]]] = None,
device: Optional[str] = "cpu",
precompute_data_ratio: float = 1.0,
proj_dim: int = 100,
max_iter: int = 100,
proj_dim: int = 500,
max_iter: int = 1000,
norm_constant: float = 1.0,
tol: float = 1e-7,
regularization: float = 0.0,
Expand Down Expand Up @@ -460,11 +462,11 @@ def __init__(
task: AttributionTask,
layer_name: Optional[Union[str, List[str]]] = None,
device: Optional[str] = "cpu",
batch_size: int = 1,
batch_size: int = 100,
num_repeat: int = 1,
recursion_depth: int = 5000,
damping: float = 0.0,
scaling: float = 50.0,
recursion_depth: int = 100,
damping: float = 5e-4,
scaling: float = 5.0,
mode: str = "rev-rev",
) -> None:
"""Initialize the LiSSA inverse Hessian attributor.
Expand Down Expand Up @@ -535,7 +537,8 @@ def transform_test_rep(
test_rep,
in_dims=(None,) + (0,) * len(full_data),
).detach()
return vector_product
n = self.full_train_dataloader.batch_size
return vector_product / n

@staticmethod
def lissa_collate_fn(
Expand Down
6 changes: 6 additions & 0 deletions dattri/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ def get_grad_loss_func(
) -> Callable:
"""Return a function that computes the gradient of the loss function.

Note:
The reduction used by the loss function (e.g. ``"mean"`` vs. ``"sum"``)
determines the scale of the computed gradients. Some attributors, such as
``IFAttributorArnoldi``, require a loss with ``reduction="sum"`` in order
to produce correctly scaled influence values.

Args:
in_dims (Tuple[Union[None, int], ...]): The input dimensions of the loss
function. This should be a tuple of integers and None. The length of the
Expand Down
Loading