Cache function implementation for TracInAttributor #255

@brightXian

Implement .cache() for the TracIn family of attributors. Currently, .cache() is an empty method, so training gradients are recomputed on every .attribute() call.

Proposed changes:

  1. .cache():
def cache(
    self,
    full_train_dataloader: torch.utils.data.DataLoader,
) -> None:

Pre-compute and store training gradients for all checkpoints.

  2. Modify .attribute() signature to support cached mode:
def attribute(
    self,
    test_dataloader: torch.utils.data.DataLoader,
    train_dataloader: Optional[torch.utils.data.DataLoader] = None,
) -> Tensor:

train_dataloader becomes optional. When .cache() has been called, .attribute() can be called with only test_dataloader.
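
To make the proposal concrete, here is a minimal sketch of how the two methods could fit together. It is not the library's TracInAttributor: the class `ToyTracIn`, the helper `per_sample_grads`, the constructor arguments, and the (num_train, num_test) output layout are all hypothetical names and choices made for illustration. It assumes a TracInCP-style score (learning rate times the dot product of training and test gradients, summed over checkpoints) and assumes that an explicitly passed train_dataloader takes precedence over the cache, which the proposal above does not specify.

```python
from typing import Dict, List, Optional

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset


def per_sample_grads(model: nn.Module, loss_fn, dataloader: DataLoader) -> Tensor:
    """Hypothetical helper: (num_samples, num_params) flattened per-sample gradients."""
    params = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for inputs, targets in dataloader:
        for x, y in zip(inputs, targets):
            loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
            grads = torch.autograd.grad(loss, params)
            rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)


class ToyTracIn:
    """Illustrative stand-in for a TracIn-style attributor (assumed design)."""

    def __init__(
        self,
        model: nn.Module,
        loss_fn,
        checkpoints: List[Dict[str, Tensor]],  # one state_dict per checkpoint
        learning_rates: List[float],           # one learning rate per checkpoint
    ) -> None:
        self.model = model
        self.loss_fn = loss_fn
        self.checkpoints = checkpoints
        self.learning_rates = learning_rates
        self._cached_train_grads: Optional[List[Tensor]] = None

    def _grads_at(self, state_dict: Dict[str, Tensor], dataloader: DataLoader) -> Tensor:
        self.model.load_state_dict(state_dict)
        return per_sample_grads(self.model, self.loss_fn, dataloader)

    def cache(self, full_train_dataloader: DataLoader) -> None:
        # Pre-compute training gradients once per checkpoint.
        # The dataloader should not shuffle, so row order stays meaningful.
        self._cached_train_grads = [
            self._grads_at(sd, full_train_dataloader) for sd in self.checkpoints
        ]

    def attribute(
        self,
        test_dataloader: DataLoader,
        train_dataloader: Optional[DataLoader] = None,
    ) -> Tensor:
        if train_dataloader is None and self._cached_train_grads is None:
            raise ValueError("Pass train_dataloader or call .cache() first.")
        score = None
        for i, (sd, lr) in enumerate(zip(self.checkpoints, self.learning_rates)):
            if train_dataloader is None:
                train_g = self._cached_train_grads[i]  # cached mode
            else:
                # Assumption: an explicit train_dataloader overrides the cache.
                train_g = self._grads_at(sd, train_dataloader)
            test_g = self._grads_at(sd, test_dataloader)
            contrib = lr * (train_g @ test_g.T)  # (num_train, num_test)
            score = contrib if score is None else score + contrib
        return score


# Usage on toy data: cached and uncached modes should agree.
torch.manual_seed(0)
train_loader = DataLoader(TensorDataset(torch.randn(4, 3), torch.randn(4, 1)), batch_size=2)
test_loader = DataLoader(TensorDataset(torch.randn(2, 3), torch.randn(2, 1)), batch_size=2)
checkpoints = [nn.Linear(3, 1).state_dict() for _ in range(2)]
attributor = ToyTracIn(nn.Linear(3, 1), nn.MSELoss(), checkpoints, learning_rates=[0.1, 0.1])

uncached = attributor.attribute(test_loader, train_loader)
attributor.cache(train_loader)
cached = attributor.attribute(test_loader)  # train_dataloader omitted
```

In this sketch the cache holds one gradient matrix per checkpoint, so memory grows with (number of checkpoints) x (training samples) x (parameters). A real implementation may want projection or on-disk storage; that is outside what the issue specifies.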
