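When using DataInf (`IFAttributorDataInf`) on the MNIST logistic-regression LOO benchmark, the correlation between the attribution scores and the ground truth remains consistently low (r < 0.3).
Adjusting the attributor's parameters does not lead to any noticeable improvement; the correlation stays within a similar range after each modification.
The slope of the linear fit of score against ground truth (see the plot produced by the script below) also deviates from 1.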
import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader
from dattri.algorithm.influence_function import IFAttributorDataInf
from dattri.benchmark.load import load_benchmark
from dattri.metric import loo_corr
from dattri.task import AttributionTask
if __name__ == "__main__":
    # Command line: only the compute device is configurable.
    cli = argparse.ArgumentParser()
    cli.add_argument("--device", default="cpu", type=str)
    args = cli.parse_args()

    # Download the pre-trained benchmark; it bundles the trained model(s)
    # together with the ground truth for the requested metric.
    model_details, groundtruth = load_benchmark(
        model="lr",
        dataset="mnist",
        metric="loo",
    )

    def f(params, data_target_pair):
        """Return the cross-entropy loss of the benchmark model on a batch.

        Args:
            params: Parameters substituted into the model for this call.
            data_target_pair: An ``(image, label)`` pair.

        Returns:
            The ``nn.CrossEntropyLoss`` value of the model output against
            the labels (cast to ``long``).
        """
        inputs, targets = data_target_pair
        logits = torch.func.functional_call(model_details["model"], params, inputs)
        criterion = nn.CrossEntropyLoss()
        return criterion(logits, targets.long())

    def load_checkpoint(model, ckpt):
        """Load the state dict stored at ``ckpt`` into ``model``; return ``model``."""
        state = torch.load(ckpt, map_location=args.device)
        model.load_state_dict(state)
        return model

    task = AttributionTask(
        model=model_details["model"].to(args.device),
        loss_func=f,
        checkpoints=model_details["models_full"][0],  # here we use one full model
        checkpoints_load_func=load_checkpoint,
    )

    attributor = IFAttributorDataInf(task=task, device=args.device)

    # Caching pass over the training set (small batches).
    cache_loader = DataLoader(
        model_details["train_dataset"],
        batch_size=64,
        sampler=model_details["train_sampler"],
    )
    attributor.cache(cache_loader)

    # Attribution pass: train loader first, test loader second.
    train_loader = DataLoader(
        model_details["train_dataset"],
        batch_size=1000,
        sampler=model_details["train_sampler"],
    )
    test_loader = DataLoader(
        model_details["test_dataset"],
        batch_size=5000,
        sampler=model_details["test_sampler"],
    )
    score = attributor.attribute(train_loader, test_loader)

    # Report the mean of the first element returned by loo_corr.
    loo_score = loo_corr(score.detach().cpu(), groundtruth)[0]
    mean_loo = torch.mean(torch.as_tensor(loo_score))
    print("loo:", mean_loo.item())

    # Plotting-only dependencies.
    import matplotlib.pyplot as plt
    import numpy as np

    # Compare column 0 of the ground truth with column 0 of the scores
    # (presumably the first test sample -- confirm against the benchmark).
    truth_col = groundtruth[0].detach().cpu().numpy()[:, 0]
    score_col = score.detach().cpu().numpy()[:, 0]

    # a, b: slope and intercept of the degree-1 least-squares fit;
    # r: off-diagonal entry of the 2x2 correlation-coefficient matrix.
    a, b = np.polyfit(truth_col, score_col, 1)
    r = np.corrcoef(truth_col, score_col)[0, 1]

    plt.scatter(truth_col, score_col, alpha=0.3)
    plt.plot(truth_col, a * truth_col + b)
    plt.xlabel("ground truth")
    plt.ylabel("score")
    plt.title(f"slope={a}, r={r}")
    plt.grid(True)
    plt.show()

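When using DataInf (`IFAttributorDataInf`) on the MNIST logistic-regression LOO benchmark, the correlation between the attribution scores and the ground truth remains consistently low (r < 0.3).
Adjusting the attributor's parameters does not lead to any noticeable improvement; the correlation stays within a similar range after each modification.
The slope of the linear fit of score against ground truth (see the plot produced by the script above) also deviates from 1.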