diff --git a/experiments/gpt2_wikitext/readme.md b/experiments/gpt2_wikitext/readme.md index 4911bcec3..fcd1f8b0e 100644 --- a/experiments/gpt2_wikitext/readme.md +++ b/experiments/gpt2_wikitext/readme.md @@ -20,10 +20,7 @@ This experiment could only be run on cuda device. pip install -r requirements.txt ``` -### Troubleshooting - -#### `vmap` over calling `.item()` Error in Transformers - + + + +if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + low_cpu_mem_usage=args.low_cpu_mem_usage, + trust_remote_code=args.trust_remote_code, + attn_implementation="eager", # Use eager attention for better performance + ) + model = model.cuda() #### NumPy Version Compatibility Issue diff --git a/experiments/gpt2_wikitext/score_TRAK.py b/experiments/gpt2_wikitext/score_TRAK.py index 8f0c1cbbe..c69dd6aa7 100644 --- a/experiments/gpt2_wikitext/score_TRAK.py +++ b/experiments/gpt2_wikitext/score_TRAK.py @@ -54,14 +54,9 @@ default_data_collator, get_scheduler, ) -from transformers.utils import check_min_version +from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version -# send_example_telemetry was removed in newer versions of transformers -try: - from transformers.utils import send_example_telemetry -except ImportError: - send_example_telemetry = None from dattri.benchmark.utils import SubsetSampler from dattri.func.utils import flatten_func, flatten_params @@ -222,6 +217,27 @@ def parse_args(): " account special tokens)." 
), ) + + # add arguments for random projection and fix memory issues + parser.add_argument( + "--proj_dim", + type=int, + default=512, + help="Output dimension for random projection used by TRAK / TracIn.", + ) + parser.add_argument( + "--proj_max_batch_size", + type=int, + default=16, + help="Maximum batch size to process per projection block (controls memory usage).", + ) + parser.add_argument( + "--proj_type", + type=str, + default="random_mask", + choices=["normal", "rademacher", "random_mask", "sjlt", "grass"], + help="Random projection type used for TRAK/TracIn (default: random_mask).", + ) parser.add_argument( "--preprocessing_num_workers", type=int, @@ -342,11 +358,7 @@ def parse_args(): def main(): args = parse_args() - - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. - if send_example_telemetry is not None: - send_example_telemetry("run_clm_no_trainer", args) + send_example_telemetry("run_clm_no_trainer", args) # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -505,6 +517,7 @@ def main(): config=config, low_cpu_mem_usage=args.low_cpu_mem_usage, trust_remote_code=args.trust_remote_code, + attn_implementation="eager", # Use eager attention for better performance ) model = model.cuda() else: @@ -631,14 +644,15 @@ def f(params, batch): """ input_ids, attention_mask, labels = batch - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - labels = labels.cuda() + # Re-add batch dimension removed by vmap + input_ids = input_ids.unsqueeze(0).cuda() + attention_mask = attention_mask.unsqueeze(0).cuda() + labels = labels.unsqueeze(0).cuda() outputs = torch.func.functional_call( model, params, - input_ids, + (input_ids,), # Pass as tuple to avoid dimension issues kwargs={"attention_mask": attention_mask, "labels": labels}, ) logp = -outputs.loss @@ -650,14 +664,15 @@ def m(params, batch): """ input_ids, attention_mask, labels = batch - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - labels = labels.cuda() + # Re-add batch dimension removed by vmap + input_ids = input_ids.unsqueeze(0).cuda() + attention_mask = attention_mask.unsqueeze(0).cuda() + labels = labels.unsqueeze(0).cuda() outputs = torch.func.functional_call( model, params, - input_ids, + (input_ids,), # Pass as tuple to avoid dimension issues kwargs={"attention_mask": attention_mask, "labels": labels}, ) p = torch.exp(-outputs.loss) @@ -669,35 +684,70 @@ def loss_tracin(params, batch): (TracIn sums over checkpoint updates of gradient dot-products). 
""" input_ids, attention_mask, labels = batch - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - labels = labels.cuda() + + # Re-add batch dimension removed by vmap + input_ids = input_ids.unsqueeze(0).cuda() + attention_mask = attention_mask.unsqueeze(0).cuda() + labels = labels.unsqueeze(0).cuda() + outputs = torch.func.functional_call( model, params, - input_ids, + (input_ids,), # Pass as tuple to avoid dimension issues kwargs={"attention_mask": attention_mask, "labels": labels}, ) return outputs.loss method = args.method + + #this gets the existing checkpoints in the output directory + checkpoint_root_abs = Path(args.output_dir).resolve() + existing_ckpt_dirs = [p for p in checkpoint_root_abs.iterdir() if p.is_dir()] + existing_names = {p.name for p in existing_ckpt_dirs} + has_minus1 = "-1" in existing_names + numeric_sorted = sorted([int(n) for n in existing_names if n.isdigit()]) + numeric_count = len(numeric_sorted) + if method.startswith("TRAK-"): parts = method.split("-") if len(parts) == 2 and parts[1].isdigit(): num_checkpoints = int(parts[1]) + # requested_checkpoints = int(parts[1]) else: raise ValueError( "Invalid method name for TRAK, must be like 'TRAK-5' or 'TRAK-10'." ) - checkpoints = [f"{args.output_dir}/{i}" for i in range(num_checkpoints)] + #prevent checkpoint id error when only -1 is present + if has_minus1 and numeric_count == 0: + selected_indices = [-1] + else: + if numeric_count == 0 and not has_minus1: + raise FileNotFoundError( f"No numeric checkpoint directories found in {checkpoint_root_abs}." 
) + if numeric_count > 0 and has_minus1: + selected_indices = list(range(-1, min(num_checkpoints, numeric_count))) + if numeric_count > 0 and not has_minus1: + selected_indices = list(range(min(num_checkpoints, numeric_count))) + checkpoints = [str(checkpoint_root_abs / str(i)) for i in selected_indices] + elif method in ["TracIn", "Grad-Dot", "Grad-Cos"]: num_checkpoints = 5 - checkpoints = [f"{args.output_dir}/{i}" for i in range(num_checkpoints)] + if has_minus1 and numeric_count == 0: + selected_indices = [-1] + else: #prevent checkpoint id error when only -1 is present + if numeric_count == 0 and not has_minus1: + raise FileNotFoundError( f"No numeric checkpoint directories found in {checkpoint_root_abs}.") + if numeric_count > 0 and has_minus1: + selected_indices = list(range(-1, min(num_checkpoints, numeric_count))) + if numeric_count > 0 and not has_minus1: + selected_indices = list(range(min(num_checkpoints, numeric_count))) + checkpoints = [str(checkpoint_root_abs / str(i)) for i in selected_indices] + else: raise ValueError( f"Unknown --method {method}. Try 'TRAK-5', 'TracIn', 'Grad-Dot', or 'Grad-Cos'." 
) - + + #modified for huggingface hub validation error def checkpoints_load_func(model, checkpoint_path): new_model = AutoModelForCausalLM.from_pretrained(checkpoint_path).cuda() new_model.eval() @@ -721,9 +771,12 @@ def checkpoints_load_func(model, checkpoint_path): ) if method.startswith("TRAK"): + # fix memory issues projector_kwargs = { "device": "cuda", - "proj_dim": 2048, + "proj_dim": args.proj_dim, + "proj_max_batch_size": args.proj_max_batch_size, + "proj_type": args.proj_type, } attributor = TRAKAttributor( task=task, @@ -737,11 +790,16 @@ def checkpoints_load_func(model, checkpoint_path): if method == "Grad-Cos": normalized_grad = True + #get the number of checkpoints + num_checkpoints = len(checkpoints) weight_list = torch.ones(num_checkpoints) * 1e-3 + # fix memory issues projector_kwargs = { "device": "cuda", - "proj_dim": 2048, + "proj_dim": args.proj_dim, + "proj_max_batch_size": args.proj_max_batch_size, + "proj_type": args.proj_type, } attributor = TracInAttributor( diff --git a/experiments/gpt2_wikitext/score_logra.py b/experiments/gpt2_wikitext/score_logra.py index 34db0be68..6e4104a60 100644 --- a/experiments/gpt2_wikitext/score_logra.py +++ b/experiments/gpt2_wikitext/score_logra.py @@ -54,15 +54,8 @@ default_data_collator, get_scheduler, ) -from transformers.utils import check_min_version +from transformers.utils import check_min_version, send_example_telemetry from transformers.utils.versions import require_version - -# send_example_telemetry was removed in newer versions of transformers -try: - from transformers.utils import send_example_telemetry -except ImportError: - send_example_telemetry = None - from dattri.benchmark.utils import SubsetSampler @@ -328,10 +321,7 @@ def parse_args(): def main(): args = parse_args() - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The - # information sent is the one passed as arguments along with your Python/PyTorch versions. 
- if send_example_telemetry is not None: - send_example_telemetry("run_clm_no_trainer", args) + send_example_telemetry("run_clm_no_trainer", args) # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -490,6 +480,7 @@ def main(): config=config, low_cpu_mem_usage=args.low_cpu_mem_usage, trust_remote_code=args.trust_remote_code, + attn_implementation="eager", # Use eager attention for better performance ) else: logger.info("Training new model from scratch") @@ -597,7 +588,7 @@ def group_texts(examples): from transformers.pytorch_utils import Conv1D from dattri.task import AttributionTask - model_id = 0 + model_id = -1 # Use checkpoint -1 (presumably the final checkpoint; TODO confirm) checkpoint = f"{args.output_dir}/{model_id}" def checkpoints_load_func(model, checkpoint):