Open
31 commits
9a562c2
new slurm for score_TRAK and score_logra
DanielNi868 Nov 6, 2025
bb26f89
fix importError of transformer
DanielNi868 Nov 6, 2025
5845c68
fix huggingface path error
DanielNi868 Nov 6, 2025
a566d7c
modify the checkpoints_load_func function in two score files
DanielNi868 Nov 6, 2025
07dbb38
modify the checkpoints_load_func function in two score files
DanielNi868 Nov 6, 2025
7cfca83
fix index error for score_TRAK
DanielNi868 Nov 7, 2025
27fe7cd
fix the train batch size
DanielNi868 Nov 7, 2025
ce2dce9
add argument for random projection in TRAK and fix memory issue
DanielNi868 Nov 9, 2025
a8a47a6
fix argument error
DanielNi868 Nov 9, 2025
01e09b0
update proj_dim and batch size to solve OOM error
DanielNi868 Nov 9, 2025
65581fc
add project_type kwarg
DanielNi868 Nov 9, 2025
807c42b
fix checkpoint loading error
DanielNi868 Nov 9, 2025
9b20c18
update readme in gpt2_wikitext
DanielNi868 Nov 9, 2025
1d5fecd
error message in import
DanielNi868 Nov 29, 2025
d42f599
added arguments
DanielNi868 Nov 29, 2025
7e9a05e
unsqueeze
DanielNi868 Nov 29, 2025
3d0a859
checkpoint
DanielNi868 Nov 29, 2025
57ac89a
huggingface
DanielNi868 Nov 29, 2025
c995b3b
model_id
DanielNi868 Nov 29, 2025
27d48a8
huggingface
DanielNi868 Nov 29, 2025
a5503cf
fix model id
DanielNi868 Dec 19, 2025
4159392
fix transformer's version
DanielNi868 Dec 20, 2025
3c9dd6b
fix transformer's huggingface_id
DanielNi868 Dec 20, 2025
ad29c76
fix transformer's huggingface_id
DanielNi868 Dec 20, 2025
28ee7a5
fix checkpoint id error
DanielNi868 Dec 20, 2025
864be74
fix checkpoint id error
DanielNi868 Dec 20, 2025
c3409dc
change to GPU 2
DanielNi868 Dec 20, 2025
0f06e7e
score_logra fix
DanielNi868 Dec 21, 2025
aa51bd0
fix: telemetry call guard; sync with upstream; remove slurm scripts a…
DanielNi868 Dec 25, 2025
dcad19a
update score_logra
DanielNi868 Dec 26, 2025
153fe68
Merge branch 'main' into main
DanielNi868 Dec 26, 2025
19 changes: 14 additions & 5 deletions experiments/gpt2_wikitext/readme.md
@@ -20,10 +20,7 @@ This experiment could only be run on cuda device.
pip install -r requirements.txt
```

### Troubleshooting

#### `vmap` over calling `.item()` Error in Transformers

<!-- ### Troubleshooting: vmap over calling .item() Error in Transformers
After installing transformers, you might encounter the following error:
```bash
We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations.
@@ -69,7 +66,19 @@ Comment out these lines:
Then, add the following line:
```bash
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
```
``` -->


if args.model_name_or_path:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
trust_remote_code=args.trust_remote_code,
attn_implementation="eager", # Use eager attention to avoid the vmap-over-.item() error
)
model = model.cuda()

#### NumPy Version Compatibility Issue

114 changes: 86 additions & 28 deletions experiments/gpt2_wikitext/score_TRAK.py
@@ -54,14 +54,9 @@
default_data_collator,
get_scheduler,
)
from transformers.utils import check_min_version
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

# send_example_telemetry was removed in newer versions of transformers
try:
from transformers.utils import send_example_telemetry
except ImportError:
send_example_telemetry = None

from dattri.benchmark.utils import SubsetSampler
from dattri.func.utils import flatten_func, flatten_params
@@ -222,6 +217,27 @@ def parse_args():
" account special tokens)."
),
)

# Random-projection arguments; exposing them lets smaller values be chosen to avoid CUDA OOM
parser.add_argument(
"--proj_dim",
type=int,
default=512,
help="Output dimension for random projection used by TRAK / TracIn.",
)
parser.add_argument(
"--proj_max_batch_size",
type=int,
default=16,
help="Maximum batch size to process per projection block (controls memory usage).",
)
parser.add_argument(
"--proj_type",
type=str,
default="random_mask",
choices=["normal", "rademacher", "random_mask", "sjlt", "grass"],
help="Random projection type used for TRAK/TracIn (default: random_mask).",
)

@DanielNi868 (Author), Nov 29, 2025:
Without these three arguments, the run fails with torch.OutOfMemoryError: CUDA out of memory.
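For readers following the hunk, here is a minimal standalone sketch of how the three new arguments could flow from the command line into the `projector_kwargs` dict that this PR builds later in the file. The helper names `build_parser` and `projector_kwargs_from` are hypothetical and exist only for this illustration; the argument names, defaults, and choices are copied from the diff.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser holding only the three projection arguments from this hunk."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--proj_dim", type=int, default=512)
    parser.add_argument("--proj_max_batch_size", type=int, default=16)
    parser.add_argument(
        "--proj_type",
        type=str,
        default="random_mask",
        choices=["normal", "rademacher", "random_mask", "sjlt", "grass"],
    )
    return parser


def projector_kwargs_from(args: argparse.Namespace) -> dict:
    """Mirror the dict that the PR passes to the attributors (hypothetical helper)."""
    return {
        "device": "cuda",
        "proj_dim": args.proj_dim,
        "proj_max_batch_size": args.proj_max_batch_size,
        "proj_type": args.proj_type,
    }


# Parse an explicit argument list so the sketch does not depend on sys.argv.
args = build_parser().parse_args(["--proj_dim", "256", "--proj_type", "sjlt"])
kwargs = projector_kwargs_from(args)
```

Unspecified arguments fall back to the defaults from the diff, so `--proj_max_batch_size` stays at 16 in this call.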

parser.add_argument(
"--preprocessing_num_workers",
type=int,
@@ -342,11 +358,7 @@ def parse_args():

def main():
args = parse_args()

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
if send_example_telemetry is not None:
send_example_telemetry("run_clm_no_trainer", args)
send_example_telemetry("run_clm_no_trainer", args)

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
@@ -505,6 +517,7 @@ def main():
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
trust_remote_code=args.trust_remote_code,
attn_implementation="eager", # Use eager attention to avoid the vmap-over-.item() error
)
model = model.cuda()
else:
@@ -631,14 +644,15 @@ def f(params, batch):
"""
input_ids, attention_mask, labels = batch

input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
labels = labels.cuda()
# Re-add batch dimension removed by vmap
input_ids = input_ids.unsqueeze(0).cuda()
attention_mask = attention_mask.unsqueeze(0).cuda()
labels = labels.unsqueeze(0).cuda()

@DanielNi868 (Author), Nov 29, 2025:
Without this change I get IndexError: too many indices for tensor of dimension 2.
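The shape bookkeeping behind the `unsqueeze(0)` fix can be sketched without PyTorch. This is a pure-Python stand-in under stated assumptions: `toy_vmap` only imitates mapping a function over dimension 0, `toy_model` is a hypothetical function that insists on `(batch, seq_len)` input, and the error it raises is not the exact message quoted in the comment above. Wrapping each sample in a list plays the role of `unsqueeze(0)`.

```python
def shape(x):
    """Return the shape of a non-empty rectangular nested list as a tuple."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def toy_model(input_ids):
    """Hypothetical model that requires a (batch, seq_len) input."""
    dims = shape(input_ids)
    if len(dims) != 2:
        raise IndexError(f"expected 2 dimensions (batch, seq_len), got {len(dims)}")
    batch, seq_len = dims
    return batch * seq_len


def toy_vmap(fn, batch):
    """Stand-in for mapping fn over dim 0: each call sees one sample, no batch dim."""
    return [fn(sample) for sample in batch]


batch = [[1, 2, 3], [4, 5, 6]]

# Re-add the batch dimension inside the mapped function, as the PR does with unsqueeze(0).
fixed = toy_vmap(lambda sample: toy_model([sample]), batch)
```

Calling `toy_vmap(toy_model, batch)` directly fails in this sketch, because each sample arrives with one dimension fewer than the model expects.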

outputs = torch.func.functional_call(
model,
params,
input_ids,
(input_ids,), # Pass as tuple to avoid dimension issues
kwargs={"attention_mask": attention_mask, "labels": labels},
)
logp = -outputs.loss
@@ -650,14 +664,15 @@ def m(params, batch):
"""
input_ids, attention_mask, labels = batch

input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
labels = labels.cuda()
# Re-add batch dimension removed by vmap
input_ids = input_ids.unsqueeze(0).cuda()
attention_mask = attention_mask.unsqueeze(0).cuda()
labels = labels.unsqueeze(0).cuda()

@DanielNi868 (Author), Nov 29, 2025:
Without this change I get IndexError: too many indices for tensor of dimension 2.


outputs = torch.func.functional_call(
model,
params,
input_ids,
(input_ids,), # Pass as tuple to avoid dimension issues
kwargs={"attention_mask": attention_mask, "labels": labels},
)
p = torch.exp(-outputs.loss)
@@ -669,35 +684,70 @@ def loss_tracin(params, batch):
(TracIn sums over checkpoint updates of gradient dot-products).
"""
input_ids, attention_mask, labels = batch
input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
labels = labels.cuda()

# Re-add batch dimension removed by vmap
input_ids = input_ids.unsqueeze(0).cuda()
attention_mask = attention_mask.unsqueeze(0).cuda()
labels = labels.unsqueeze(0).cuda()

@DanielNi868 (Author), Nov 29, 2025:
Without this change I get IndexError: too many indices for tensor of dimension 2.

outputs = torch.func.functional_call(
model,
params,
input_ids,
(input_ids,), # Pass as tuple to avoid dimension issues
kwargs={"attention_mask": attention_mask, "labels": labels},
)
return outputs.loss

method = args.method

# Collect the checkpoint directories that already exist in the output directory
checkpoint_root_abs = Path(args.output_dir).resolve()
existing_ckpt_dirs = [p for p in checkpoint_root_abs.iterdir() if p.is_dir()]
existing_names = {p.name for p in existing_ckpt_dirs}
has_minus1 = "-1" in existing_names
numeric_sorted = sorted([int(n) for n in existing_names if n.isdigit()])
numeric_count = len(numeric_sorted)

if method.startswith("TRAK-"):

Author comment:
I reran this and it works now, so I think this modification can be removed.

parts = method.split("-")
if len(parts) == 2 and parts[1].isdigit():
num_checkpoints = int(parts[1])
# requested_checkpoints = int(parts[1])
else:

Author comment:
I reran this and it works now, so I think this modification can be removed.

raise ValueError(
"Invalid method name for TRAK, must be like 'TRAK-5' or 'TRAK-10'."
)
checkpoints = [f"{args.output_dir}/{i}" for i in range(num_checkpoints)]
# Avoid a checkpoint-id error when only the -1 checkpoint is present
if has_minus1 and numeric_count == 0:
selected_indices = [-1]
else:
if numeric_count == 0 and not has_minus1:
raise FileNotFoundError(f"No numeric checkpoint directories found in {checkpoint_root_abs}.")
if numeric_count > 0 and has_minus1:
selected_indices = list(range(-1, min(num_checkpoints, numeric_count)))
if numeric_count > 0 and not has_minus1:
selected_indices = list(range(min(num_checkpoints, numeric_count)))
checkpoints = [str(checkpoint_root_abs / str(i)) for i in selected_indices]

elif method in ["TracIn", "Grad-Dot", "Grad-Cos"]:

Author comment:
I reran this and it works now, so I think this modification can be removed.

num_checkpoints = 5
checkpoints = [f"{args.output_dir}/{i}" for i in range(num_checkpoints)]
# Avoid a checkpoint-id error when only the -1 checkpoint is present
if has_minus1 and numeric_count == 0:
selected_indices = [-1]
else:
if numeric_count == 0 and not has_minus1:
raise FileNotFoundError(f"No numeric checkpoint directories found in {checkpoint_root_abs}.")
if numeric_count > 0 and has_minus1:
selected_indices = list(range(-1, min(num_checkpoints, numeric_count)))
if numeric_count > 0 and not has_minus1:
selected_indices = list(range(min(num_checkpoints, numeric_count)))
checkpoints = [str(checkpoint_root_abs / str(i)) for i in selected_indices]

else:

Author comment:
I reran this and it works now, so I think this modification can be removed.

raise ValueError(
f"Unknown --method {method}. Try 'TRAK-5', 'TracIn', 'Grad-Dot', or 'Grad-Cos'."
)
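The checkpoint-selection branches added in this hunk can be read as one small function. The sketch below is an illustration only: `select_checkpoint_indices` is a hypothetical name, it takes the set of directory names instead of scanning the filesystem, and it keeps the diff's implicit assumption that numeric checkpoint directories are named `0, 1, 2, ...` without gaps.

```python
def select_checkpoint_indices(existing_names, num_checkpoints):
    """Pick checkpoint indices the same way the branches in the diff do.

    existing_names: names of the sub-directories found in the output directory.
    num_checkpoints: how many numeric checkpoints the method asks for.
    """
    has_minus1 = "-1" in existing_names
    # "-1".isdigit() is False, so only non-negative integer names are counted.
    numeric_count = sum(1 for name in existing_names if name.isdigit())

    if numeric_count == 0:
        if has_minus1:
            return [-1]
        raise FileNotFoundError("No numeric checkpoint directories found.")

    upper = min(num_checkpoints, numeric_count)
    start = -1 if has_minus1 else 0
    return list(range(start, upper))


only_final = select_checkpoint_indices({"-1"}, 5)
fewer_than_requested = select_checkpoint_indices({"0", "1", "2"}, 5)
with_final = select_checkpoint_indices({"-1", "0", "1"}, 5)
```

Note that when `-1` is present alongside numeric directories, the result holds one more entry than the numeric count, which matches `range(-1, ...)` in the diff.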


# Modified to avoid the Hugging Face Hub validation error
def checkpoints_load_func(model, checkpoint_path):
new_model = AutoModelForCausalLM.from_pretrained(checkpoint_path).cuda()
new_model.eval()

Author comment:
I think this modification can be removed.

@@ -721,9 +771,12 @@ def checkpoints_load_func(model, checkpoint_path):
)

if method.startswith("TRAK"):
# Take the projection settings from the CLI arguments to avoid CUDA OOM
projector_kwargs = {
"device": "cuda",
"proj_dim": 2048,
"proj_dim": args.proj_dim,
"proj_max_batch_size": args.proj_max_batch_size,
"proj_type": args.proj_type,

@DanielNi868 (Author), Nov 29, 2025:
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB. GPU 0 has a total capacity of 44.35 GiB of which 41.56 GiB is free. Including non-PyTorch memory, this process has 2.79 GiB memory in use. Of the allocated memory 2.36 GiB is allocated by PyTorch, and 114.69 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting
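To see why the projection dimension matters for memory, the arithmetic for a dense float32 projection matrix can be sketched as below. This is a back-of-the-envelope illustration, not a diagnosis: the source does not say which tensor triggered the allocation, `dense_projection_gib` is a hypothetical helper, and the feature count `2**24` is picked only so the arithmetic lands on the 128 GiB figure quoted above. How dattri actually allocates memory depends on the projector type and its chunking.

```python
def dense_projection_gib(n_features: int, proj_dim: int, bytes_per_element: int = 4) -> float:
    """Size in GiB of a dense (n_features x proj_dim) matrix, float32 by default."""
    return n_features * proj_dim * bytes_per_element / 2**30


# Illustrative feature count (an assumption, not taken from the traceback).
n_features = 2**24

at_2048 = dense_projection_gib(n_features, 2048)  # old hard-coded proj_dim
at_512 = dense_projection_gib(n_features, 512)    # new default from --proj_dim
```

In this sketch the size scales linearly with `proj_dim`, so dropping from 2048 to 512 cuts the dense matrix to a quarter of its size.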

}
attributor = TRAKAttributor(
task=task,
@@ -737,11 +790,16 @@ def checkpoints_load_func(model, checkpoint_path):
if method == "Grad-Cos":
normalized_grad = True

# Use the number of checkpoints actually selected above
num_checkpoints = len(checkpoints)
weight_list = torch.ones(num_checkpoints) * 1e-3

# Take the projection settings from the CLI arguments to avoid CUDA OOM
projector_kwargs = {
"device": "cuda",
"proj_dim": 2048,
"proj_dim": args.proj_dim,
"proj_max_batch_size": args.proj_max_batch_size,
"proj_type": args.proj_type,

@DanielNi868 (Author), Nov 29, 2025:
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 GiB. GPU 0 has a total capacity of 44.35 GiB of which 41.56 GiB is free. Including non-PyTorch memory, this process has 2.79 GiB memory in use. Of the allocated memory 2.36 GiB is allocated by PyTorch, and 114.69 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting

}

attributor = TracInAttributor(
17 changes: 4 additions & 13 deletions experiments/gpt2_wikitext/score_logra.py
@@ -54,15 +54,8 @@
default_data_collator,
get_scheduler,
)
from transformers.utils import check_min_version
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

# send_example_telemetry was removed in newer versions of transformers
try:
from transformers.utils import send_example_telemetry
except ImportError:
send_example_telemetry = None

from dattri.benchmark.utils import SubsetSampler


@@ -328,10 +321,7 @@ def parse_args():
def main():
args = parse_args()

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
if send_example_telemetry is not None:
send_example_telemetry("run_clm_no_trainer", args)
send_example_telemetry("run_clm_no_trainer", args)

Author comment:
Same import error as in score_TRAK.
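The guarded-import pattern that appears in this diff (`try` / `except ImportError`, then a `None` check before calling) can be sketched generically. The helpers `optional_attr` and `call_if_available` are hypothetical names for this illustration, and the standard-library `math` module stands in for `transformers.utils` so that the sketch runs without any third-party package installed.

```python
import importlib


def optional_attr(module_name, attr_name):
    """Return module.attr if both exist, otherwise None (hypothetical helper)."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr_name, None)


def call_if_available(func, *args, **kwargs):
    """Call func only when the optional import succeeded."""
    if func is None:
        return None
    return func(*args, **kwargs)


# An attribute that exists, one that does not, and a module that does not.
present = optional_attr("math", "sqrt")
missing_attr = optional_attr("math", "send_example_telemetry")
missing_module = optional_attr("no_such_module_for_this_sketch", "anything")

result = call_if_available(present, 9.0)
skipped = call_if_available(missing_attr, "run_clm_no_trainer")
```

With this shape, a symbol removed in a newer library version degrades to a skipped call instead of an `ImportError` at startup.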

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
@@ -490,6 +480,7 @@ def main():
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
trust_remote_code=args.trust_remote_code,
attn_implementation="eager", # Use eager attention to avoid the vmap-over-.item() error

Collaborator comment:
The only change needed in this PR is adding this line to score_TRAK and score_logra.

Collaborator comment:
If anything else in score_logra and score_TRAK needs to change, please add a comment explaining why it is required to fix the Transformers vmap error. Otherwise, we can leave those parts unchanged.

)
else:
logger.info("Training new model from scratch")
@@ -597,7 +588,7 @@ def group_texts(examples):
from transformers.pytorch_utils import Conv1D
from dattri.task import AttributionTask

model_id = 0
model_id = -1 # Use checkpoint -1 (final checkpoint)
checkpoint = f"{args.output_dir}/{model_id}"

@DanielNi868 (Author), Nov 29, 2025:
FileNotFoundError: Checkpoint directory not found: /dattri/experiments/gpt2_wikitext/checkpoints/-1. Please ensure the checkpoint exists at this path.
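One way to surface a missing checkpoint early is to validate the directory before loading anything from it. The sketch below assumes the `output_dir/model_id` layout used in the diff; `checkpoint_dir` is a hypothetical helper, and a temporary directory stands in for the real checkpoint root.

```python
import tempfile
from pathlib import Path


def checkpoint_dir(output_dir, model_id):
    """Return output_dir/model_id as a Path, or raise if it is not a directory."""
    path = Path(output_dir) / str(model_id)
    if not path.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {path}")
    return path


with tempfile.TemporaryDirectory() as tmp:
    # Only checkpoint "0" exists in this toy layout.
    (Path(tmp) / "0").mkdir()

    found = checkpoint_dir(tmp, 0).name

    try:
        checkpoint_dir(tmp, -1)
        missing_raises = False
    except FileNotFoundError:
        missing_raises = True
```

In this toy layout, `model_id = 0` resolves and `model_id = -1` raises, which mirrors the situation described in the comment above.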


def checkpoints_load_func(model, checkpoint):