Implement .cache() for TracIn family attributors. Currently, .cache() is an empty method and training gradients are recomputed on every .attribute() call.
Proposed changes:
- Implement .cache():
    def cache(
        self,
        full_train_dataloader: torch.utils.data.DataLoader,
    ) -> None:
Pre-compute and store training gradients for all checkpoints.
- Modify .attribute() signature to support cached mode:
    def attribute(
        self,
        test_dataloader: torch.utils.data.DataLoader,
        train_dataloader: Optional[torch.utils.data.DataLoader] = None,
    ) -> Tensor:
train_dataloader becomes optional. When .cache() has been called, .attribute() can be called with only test_dataloader.
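
A minimal, framework-free sketch of the intended control flow, to make the cached and uncached paths concrete. It is not the library's implementation. The class name ToyTracInAttributor, the grad_fn callable, the lr parameter, and the (n_train, n_test) score layout are all hypothetical, and plain Python lists stand in for tensors and dataloaders. Two behaviors here are assumptions the proposal does not specify: an explicitly passed train_data takes precedence over the cache, and calling attribute() with neither a cache nor train_data raises ValueError.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
# Hypothetical stand-in: maps (checkpoint, example) to a flattened gradient.
GradFn = Callable[[float, float], Vector]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


class ToyTracInAttributor:
    """Toy illustration of cached vs. uncached attribution."""

    def __init__(self, checkpoints: Sequence[float], grad_fn: GradFn, lr: float = 1.0):
        self.checkpoints = list(checkpoints)
        self.grad_fn = grad_fn
        self.lr = lr
        # One list of training gradients per checkpoint, filled by cache().
        self._cached_train_grads: Optional[List[List[Vector]]] = None

    def _train_grads(self, train_data: Sequence[float]) -> List[List[Vector]]:
        return [[self.grad_fn(ckpt, x) for x in train_data] for ckpt in self.checkpoints]

    def cache(self, full_train_data: Sequence[float]) -> None:
        # Pre-compute and store training gradients for all checkpoints.
        self._cached_train_grads = self._train_grads(full_train_data)

    def attribute(
        self,
        test_data: Sequence[float],
        train_data: Optional[Sequence[float]] = None,
    ) -> List[List[float]]:
        if train_data is not None:
            # Assumption: an explicit train_data overrides the cache.
            train_grads = self._train_grads(train_data)
        elif self._cached_train_grads is not None:
            train_grads = self._cached_train_grads
        else:
            # Assumption: fail loudly when there is nothing to attribute against.
            raise ValueError("Pass train_data or call cache() first.")

        test_data = list(test_data)
        n_train = len(train_grads[0]) if train_grads else 0
        scores = [[0.0] * len(test_data) for _ in range(n_train)]
        # Sum lr * <train gradient, test gradient> over checkpoints.
        for ckpt, grads in zip(self.checkpoints, train_grads):
            test_grads = [self.grad_fn(ckpt, z) for z in test_data]
            for i, g_train in enumerate(grads):
                for j, g_test in enumerate(test_grads):
                    scores[i][j] += self.lr * _dot(g_train, g_test)
        return scores
```

With this shape, calling cache() once and then attribute(test_data) repeatedly only computes test gradients per call; the training gradients are reused across calls.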