Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 264 additions & 65 deletions dattri/algorithm/tracin.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/noisy_label_detection/tracin_noisy_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def f(params, image_label_pair):
)

with torch.no_grad():
score = attributor.attribute(train_loader, test_loader).diag()
score = attributor.attribute(test_loader, train_loader).diag()

_, indices = torch.sort(-score)
cr = 0
Expand Down
2 changes: 1 addition & 1 deletion experiments/benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def loss_rps(pre_activation_list, label_list):
device=args.device,
)
with torch.no_grad():
score = attributor.attribute(train_loader, test_loader)
score = attributor.attribute(test_loader, train_loader)

# compute metrics
metrics_score = METRICS_DICT[args.metric](score, groundtruth)[0]
Expand Down
2 changes: 1 addition & 1 deletion experiments/benchmark_result_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def loss_tracin(params, data_target_pair):
device=args.device,
)
with torch.no_grad():
score = attributor.attribute(train_loader, test_loader)
score = attributor.attribute(test_loader, train_loader)


best_result = 0
Expand Down
2 changes: 1 addition & 1 deletion experiments/benchmark_result_nanogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def loss_tracin(params, data_target_pair):
device=args.device,
)
with torch.no_grad():
score = attributor.attribute(train_loader, val_loader)
score = attributor.attribute(val_loader, train_loader)

best_result = 0
best_config = None
Expand Down
2 changes: 1 addition & 1 deletion experiments/gpt2_wikitext/score_TRAK.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def checkpoints_load_func(model, checkpoint_path):
attributor.cache(train_dataloader)
score = attributor.attribute(eval_dataloader)
else:
score = attributor.attribute(train_dataloader, eval_dataloader)
score = attributor.attribute(eval_dataloader, train_dataloader)

torch.save(score, "score_TRAK.pt")
logger.info("Attribution scores saved to score_TRAK.pt")
Expand Down
103 changes: 96 additions & 7 deletions test/dattri/algorithm/test_tracin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class TestTracInAttributor:
"""Test for TracIn."""

def test_tracin_proj(self):
def test_tracin_proj(self): # noqa: PLR0914
"""Test for TracIn with projectors."""
train_dataset = TensorDataset(
torch.randn(20, 1, 28, 28),
Expand Down Expand Up @@ -73,7 +73,7 @@ def f(params, image_label_pair):
"device": pytest_device,
}

# test with projector list
# test with projector list, without cache
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
Expand All @@ -83,15 +83,58 @@ def f(params, image_label_pair):
)

# Original test
score = attributor.attribute(train_loader, test_loader)
score = attributor.attribute(test_loader, train_loader)
assert score.shape == (len(train_loader.dataset), len(test_loader.dataset))
assert torch.count_nonzero(score) == len(train_loader.dataset) * len(
test_loader.dataset,
)

# test with projector list, with cache
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
projector_kwargs=projector_kwargs,
device=torch.device(pytest_device),
)
attributor.cache(train_loader)
score2 = attributor.attribute(test_loader)
assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05)

# test with projector list, with offload(cpu)
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
projector_kwargs=projector_kwargs,
device=torch.device(pytest_device),
offload="cpu",
)
attributor.cache(train_loader)
score2 = attributor.attribute(test_loader)
assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05)

# test with projector, with offload(disk)
cache_path = Path("./cache")
if not cache_path.exists():
cache_path.mkdir(parents=True)
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
projector_kwargs=projector_kwargs,
device=torch.device(pytest_device),
offload="disk",
cache_dir=str(cache_path),
)
attributor.cache(train_loader)
score2 = attributor.attribute(test_loader)
assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05)

shutil.rmtree(path)
shutil.rmtree(cache_path)

def test_tracin(self):
def test_tracin(self): # noqa: PLR0914
"""Test for TracIn without projectors."""
train_dataset = TensorDataset(
torch.randn(20, 1, 28, 28),
Expand Down Expand Up @@ -135,20 +178,66 @@ def f(params, image_label_pair):
)

pytest_device = "cpu"
# test with no projector list
# test with no projector list, without cache
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
device=torch.device(pytest_device),
)
score = attributor.attribute(train_loader, test_loader)
score = attributor.attribute(test_loader, train_loader)
assert score.shape == (len(train_loader.dataset), len(test_loader.dataset))
assert torch.count_nonzero(score) == len(train_loader.dataset) * len(
test_loader.dataset,
)

# test with no projector, with cache
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
device=torch.device(pytest_device),
)
attributor.cache(train_loader)
score2 = attributor.attribute(test_loader)
score3 = attributor.attribute(test_loader)
assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05)
assert torch.allclose(score2, score3)

# test with no projector, with offload(cpu)
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
device=torch.device(pytest_device),
offload="cpu",
)
attributor.cache(train_loader)
score2 = attributor.attribute(test_loader)
score3 = attributor.attribute(test_loader)
assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05)
assert torch.allclose(score2, score3)

# test with no projector, with offload(disk)
cache_path = Path("./cache")
if not cache_path.exists():
cache_path.mkdir(parents=True)
attributor = TracInAttributor(
task=task,
weight_list=torch.ones(len(checkpoint_list)),
normalized_grad=True,
device=torch.device(pytest_device),
offload="disk",
cache_dir=str(cache_path),
)
attributor.cache(train_loader)
score2 = attributor.attribute(test_loader)
score3 = attributor.attribute(test_loader)
assert torch.allclose(score, score2, rtol=1e-03, atol=1e-05)
assert torch.allclose(score2, score3)

shutil.rmtree(path)
shutil.rmtree(cache_path)

def test_tracin_self_attribute(self):
"""Test for self_attribute in TracIn without projectors."""
Expand Down Expand Up @@ -349,7 +438,7 @@ def f(params, dict_batch):
device=torch.device(pytest_device),
)

score = attributor.attribute(train_loader, test_loader)
score = attributor.attribute(test_loader, train_loader)

assert score.shape == (len(train_loader.dataset), len(test_loader.dataset))
assert torch.count_nonzero(score) == len(train_loader.dataset) * len(
Expand Down
Loading