diff --git a/benchmarks/fp8/ms_amp/Dockerfile b/benchmarks/fp8/ms_amp/Dockerfile new file mode 100644 index 00000000000..dd4d7c25297 --- /dev/null +++ b/benchmarks/fp8/ms_amp/Dockerfile @@ -0,0 +1,13 @@ +FROM ghcr.io/azure/msamp + +RUN pip install transformers evaluate datasets +RUN git clone https://github.com/huggingface/accelerate + +RUN cd accelerate && \ + pip install -e . && \ + cd benchmarks/fp8 + +CMD ["bash"] + + + diff --git a/benchmarks/fp8/ms_amp/ddp.py b/benchmarks/fp8/ms_amp/ddp.py new file mode 100644 index 00000000000..6e7530df9fc --- /dev/null +++ b/benchmarks/fp8/ms_amp/ddp.py @@ -0,0 +1,122 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. + +This particular script verifies this for DDP training. +""" +import evaluate +import msamp +import torch +from fp8_utils import evaluate_model, get_training_utilities +from torch.nn.parallel import DistributedDataParallel as DDP + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(opt_level="O2"): + set_seed(42) + scaler = torch.cuda.amp.GradScaler() + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + + model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level) + + model.to(device) + + # Convert the model to DDP + device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index + model = DDP(model, device_ids=device_ids, output_device=output_device) + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for i, batch in enumerate(train_dataloader): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + scaler.scale(loss).backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(opt_level="O2"): + kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, 
train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + for i, batch in enumerate(train_dataloader): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + for opt_level in ["O1", "O2"]: + baseline_not_trained, baseline_trained = train_baseline(opt_level) + accelerator_not_trained, accelerator_trained = train_integration(opt_level) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/benchmarks/fp8/ms_amp/distrib_deepspeed.py b/benchmarks/fp8/ms_amp/distrib_deepspeed.py new file mode 100644 index 00000000000..57a2569729f --- /dev/null +++ b/benchmarks/fp8/ms_amp/distrib_deepspeed.py @@ -0,0 +1,161 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. + +This particular script verifies this for DeepSpeed training. + +NOTE: MS-AMP does *not* support ZeRO-3. 
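+Because of that, the `__main__` block below only benchmarks ZeRO stages 1 and 2, each at
+MS-AMP opt levels O1, O2, and O3.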
+""" + +# import msamp.deepspeed as msamp_deepspeed +import evaluate +import torch +from fp8_utils import evaluate_model, get_training_utilities +from msamp import deepspeed as msamp_deepspeed + +from accelerate import Accelerator, DeepSpeedPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import set_seed + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): + set_seed(42) + accelerator = Accelerator() + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + import numpy as np + + config = { + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage, + "offload_optimizer": {"device": "none", "nvme_path": None}, + "offload_param": {"device": "none", "nvme_path": None}, + }, + "gradient_clipping": 1.0, + "steps_per_print": np.inf, + "bf16": {"enabled": True}, + "fp16": {"enabled": False}, + "zero_allow_untested_optimizer": True, + "msamp": { + "enabled": True, + "opt_level": opt_level, + }, + } + ( + model, + optimizer, + _, + _, + ) = msamp_deepspeed.initialize( + model=model, + optimizer=optimizer, + config_params=config, + ) + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + model.backward(loss) + model.step() + for _ in range(accelerator.num_processes): + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + torch.cuda.empty_cache() + AcceleratorState()._reset_state(True) + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(zero_stage: int = 1, opt_level: str = "O1"): + set_seed(42) + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=zero_stage, + enable_msamp=True, + msamp_opt_level=opt_level, + ) + accelerator = Accelerator(mixed_precision="fp8", deepspeed_plugin=deepspeed_plugin) + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 + + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + torch.cuda.empty_cache() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > 
{base_model_results["accuracy"]}'
+    assert (
+        trained_model_results["f1"] > base_model_results["f1"]
+    ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'
+
+    AcceleratorState()._reset_state(True)
+    return base_model_results, trained_model_results
+
+
+if __name__ == "__main__":
+    for zero_stage in [1, 2]:
+        for opt_level in ["O1", "O2", "O3"]:
+            baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level)
+            accelerator_not_trained, accelerator_trained = train_integration(zero_stage, opt_level)
+            assert (
+                baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+            ), f'ZeRO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+            assert (
+                baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+            ), f'ZeRO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+            assert (
+                baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+            ), f'ZeRO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+            assert (
+                baseline_trained["f1"] == accelerator_trained["f1"]
+            ), f'ZeRO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+
+    torch.distributed.destroy_process_group()
diff --git a/benchmarks/fp8/ms_amp/fp8_utils.py b/benchmarks/fp8/ms_amp/fp8_utils.py
new file mode 100644
index 00000000000..602ce07fdc6
--- /dev/null
+++ b/benchmarks/fp8/ms_amp/fp8_utils.py
@@ -0,0 +1,118 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
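+
+"""
+Shared helpers for the MS-AMP FP8 benchmark scripts: GLUE/MRPC dataloaders, model/optimizer/LR-scheduler
+setup (`get_training_utilities`), and metric evaluation (`evaluate_model`).
+"""
+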
+import torch
+
+
+def get_dataloaders(model_name: str, batch_size: int = 16):
+    from datasets import load_dataset
+    from torch.utils.data import DataLoader
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    datasets = load_dataset("glue", "mrpc")
+
+    def tokenize_function(examples):
+        # max_length=None => use the model max length (it's actually the default)
+        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
+        return outputs
+
+    # Apply the method we just defined to all the examples in all the splits of the dataset
+    tokenized_datasets = datasets.map(
+        tokenize_function,
+        batched=True,
+        remove_columns=["idx", "sentence1", "sentence2"],
+    )
+
+    # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
+    # transformers library
+    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+
+    def collate_fn(examples):
+        return tokenizer.pad(
+            examples,
+            padding="longest",
+            pad_to_multiple_of=16,  # Specific for FP8
+            return_tensors="pt",
+        )
+
+    # Instantiate dataloaders.
+    train_dataloader = DataLoader(
+        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
+    )
+    eval_dataloader = DataLoader(
+        tokenized_datasets["validation"],
+        shuffle=False,
+        collate_fn=collate_fn,
+        batch_size=16,
+        drop_last=True,
+    )
+
+    return train_dataloader, eval_dataloader
+
+
+def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
+    """
+    Returns a tuple of:
+        - Model
+        - Optimizer
+        - Train dataloader (prepared)
+        - Eval dataloader (prepared)
+        - LR Scheduler
+    Suitable for training on the MRPC dataset
+    """
+    from torch.optim import AdamW
+    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup
+
+    from accelerate import Accelerator
+
+    if accelerator is None:
+        accelerator = Accelerator()
+    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size)
+    optimizer = AdamW(model.parameters(), lr=0.0001)
+    lr_scheduler = get_linear_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=100,
+        num_training_steps=len(train_dataloader) * 2,
+    )
+    train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader)
+    return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
+
+
+def get_named_parameters(model):
+    """
+    Same thing as `Accelerator._get_named_parameters`: returns a dict of the model's named parameters,
+    with the model first extracted from any parallel wrapper.
+    """
+    from accelerate.utils import extract_model_from_parallel
+
+    model = extract_model_from_parallel(model)
+    return {n: p for n, p in model.named_parameters()}
+
+
+def evaluate_model(model, dataloader, metric, accelerator=None):
+    "Puts the model in eval mode, runs the dataloader, and computes the metric (callers are responsible for switching back to train mode)"
+    model.eval()
+    for step, batch in enumerate(dataloader):
+        with torch.no_grad():
+            # W/ MS-AMP, we need to cast while evaluating
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                outputs = model(**batch)
+        predictions = outputs.logits.argmax(dim=-1)
+        references = batch["labels"]
+        if accelerator is not None and accelerator.num_processes > 1:
+            predictions, references = accelerator.gather_for_metrics((predictions, references))
+        metric.add_batch(predictions=predictions,
references=references) + return metric.compute() diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py new file mode 100644 index 00000000000..396cb6b2981 --- /dev/null +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -0,0 +1,149 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. + +This particular script verifies this for FSDP training. +""" +from functools import partial + +import evaluate +import msamp +import torch +from fp8_utils import evaluate_model, get_training_utilities +from msamp.common.dtype import Dtypes +from msamp.fsdp import FP8FullyShardedDataParallel +from msamp.optim import FSDPAdamW +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformers.models.bert import BertLayer + +from accelerate import Accelerator +from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin +from accelerate.utils import FP8RecipeKwargs, set_seed + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") +FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) + + +def train_baseline(opt_level="O2"): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model, optimizer = msamp.initialize( + model, optimizer, opt_level=opt_level, weight_qtype=Dtypes.kfloat8_e4m3, use_fsdp=True + ) + + model = FP8FullyShardedDataParallel( + model, + use_orig_params=True, + auto_wrap_policy=FSDP_WRAP_POLICY, + cpu_offload=False, + sync_module_states=False, + backward_prefetch=None, + forward_prefetch=False, + limit_all_gathers=True, + device_id=device, + ) + optimizer = FSDPAdamW(optimizer) + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for i, batch in enumerate(train_dataloader): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.free_memory( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Baseline: Opt level {opt_level}: Accuracy should be higher for the trained model: {base_model_results["accuracy"]} < {trained_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'Baseline: Opt level {opt_level}: F1 score should be higher for the trained model: {base_model_results["f1"]} < {trained_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def 
train_integration(opt_level="O2"):
+    kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)]
+    # AcceleratorState()._reset_state(True)
+    fsdp_plugin = FSDPPlugin(
+        auto_wrap_policy=FSDP_WRAP_POLICY,
+        use_orig_params=True,
+    )
+    accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers)
+    set_seed(42)
+    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
+        MODEL_NAME, accelerator=accelerator
+    )
+
+    model, optimizer = accelerator.prepare(model, optimizer)
+    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+    model.train()
+    for i, batch in enumerate(train_dataloader):
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+            outputs = model(**batch)
+            loss = outputs.loss
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
+        lr_scheduler.step()
+
+    trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
+
+    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.free_memory(
+        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
+    )
+    assert (
+        trained_model_results["accuracy"] > base_model_results["accuracy"]
+    ), f'Integration: Opt level {opt_level}: Accuracy should be higher for the trained model: {base_model_results["accuracy"]} < {trained_model_results["accuracy"]}'
+    assert (
+        trained_model_results["f1"] > base_model_results["f1"]
+    ), f'Integration: Opt level {opt_level}: F1 score should be higher for the trained model: {base_model_results["f1"]} < {trained_model_results["f1"]}'
+
+    return base_model_results, trained_model_results
+
+
+if __name__ == "__main__":
+    for opt_level in ["O1", "O2"]:
+        baseline_not_trained, baseline_trained = train_baseline(opt_level)
+        accelerator_not_trained, accelerator_trained = train_integration(opt_level)
+        assert (
+            baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+        ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+        assert (
+            baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+        ), f'F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+        assert (
+            baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+        ), f'Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+        assert (
+            baseline_trained["f1"] == accelerator_trained["f1"]
+        ), f'F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py
new file mode 100644
index 00000000000..791cea108ed
--- /dev/null
+++ b/benchmarks/fp8/ms_amp/non_distributed.py
@@ -0,0 +1,117 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. + +This particular script verifies this for single GPU training. +""" +import evaluate +import msamp +import torch +from fp8_utils import evaluate_model, get_training_utilities + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(opt_level="O2"): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + + model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level) + model.to("cuda") + + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + scaler = torch.cuda.amp.GradScaler() + + for batch in train_dataloader: + batch = batch.to("cuda") + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + loss = scaler.scale(loss) + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(opt_level="O2"): + kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + for opt_level in ["O1", "O2"]: + baseline_not_trained, baseline_trained = train_baseline(opt_level) + accelerator_not_trained, accelerator_trained 
= train_integration(opt_level) + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/benchmarks/fp8/Dockerfile b/benchmarks/fp8/transformer_engine/Dockerfile similarity index 100% rename from benchmarks/fp8/Dockerfile rename to benchmarks/fp8/transformer_engine/Dockerfile diff --git a/benchmarks/fp8/ddp.py b/benchmarks/fp8/transformer_engine/ddp.py similarity index 100% rename from benchmarks/fp8/ddp.py rename to benchmarks/fp8/transformer_engine/ddp.py diff --git a/benchmarks/fp8/distrib_deepspeed.py b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py similarity index 100% rename from benchmarks/fp8/distrib_deepspeed.py rename to benchmarks/fp8/transformer_engine/distrib_deepspeed.py diff --git a/benchmarks/fp8/fp8_utils.py b/benchmarks/fp8/transformer_engine/fp8_utils.py similarity index 100% rename from benchmarks/fp8/fp8_utils.py rename to benchmarks/fp8/transformer_engine/fp8_utils.py diff --git a/benchmarks/fp8/fsdp.py b/benchmarks/fp8/transformer_engine/fsdp.py similarity index 100% rename from benchmarks/fp8/fsdp.py rename to benchmarks/fp8/transformer_engine/fsdp.py diff --git a/benchmarks/fp8/non_distributed.py b/benchmarks/fp8/transformer_engine/non_distributed.py similarity index 100% rename from benchmarks/fp8/non_distributed.py rename to benchmarks/fp8/transformer_engine/non_distributed.py diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 91006258efa..b8d9f1f1b07 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -104,7 +104,7 @@ save_fsdp_optimizer, wait_for_everyone, ) -from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME +from .utils.constants import PROFILE_PATTERN_NAME from .utils.modeling import get_state_dict_offloaded_model from .utils.other import is_compiled_module @@ -307,11 +307,11 @@ def __init__( deepspeed_plugin.set_mixed_precision(mixed_precision) deepspeed_plugin.set_deepspeed_weakref() - if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( - fsdp_plugin, FullyShardedDataParallelPlugin - ): - if is_torch_version("<", FSDP_PYTORCH_VERSION): - raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") + # if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( + # fsdp_plugin, FullyShardedDataParallelPlugin + # ): + # if is_torch_version("<", FSDP_PYTORCH_VERSION): + # raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") if fsdp_plugin is None: # init from env variables fsdp_plugin = ( @@ -402,7 +402,7 @@ def __init__( self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED) ): raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.") - 
self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in ( + self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in ( DistributedType.MULTI_GPU, DistributedType.FSDP, ) @@ -507,6 +507,12 @@ def __init__( elif self.state.mixed_precision == "fp8": # We always enable `native_amp` for FP8 self.native_amp = True + # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP: + if self.fp8_backend == "MSAMP" and self.distributed_type not in ( + DistributedType.FSDP, + DistributedType.DEEPSPEED, + ): + self.scaler = torch.cuda.amp.GradScaler() # Start of internal step tracking self.step = 0 @@ -1193,8 +1199,7 @@ def _prepare_one(self, obj, first_pass=False, device_placement=None): elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) elif isinstance(obj, torch.optim.Optimizer): - optimizer = self.prepare_optimizer(obj, device_placement=device_placement) - return optimizer + return self.prepare_optimizer(obj, device_placement=device_placement) # Second pass of preparation: LR scheduler (which need the full list of optimizers) elif isinstance(obj, LRScheduler): scheduler = self.prepare_scheduler(obj) @@ -1306,21 +1311,20 @@ def prepare(self, *args, device_placement=None): args = self._prepare_ipex_or_xpu(*args) elif self.device.type == "xpu" and is_xpu_available(): args = self._prepare_ipex_or_xpu(*args) - if self.fp8_recipe_handler is not None and self.fp8_recipe_handler.backend == "TE": + if self.fp8_backend == "TE": args = self._prepare_te(*args) if self.distributed_type == DistributedType.DEEPSPEED: result = self._prepare_deepspeed(*args) elif self.distributed_type == DistributedType.MEGATRON_LM: result = self._prepare_megatron_lm(*args) else: - if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP": - args = self._prepare_msamp(*args) - # MS-AMP will handle the device placement - device_placement = [False for _ in args] + if self.fp8_backend == "MSAMP": + args, device_placement = self._prepare_msamp(*args, device_placement=device_placement) result = tuple( self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) ) result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement)) + if tpu_should_fix_optimizer: # 2. grabbing new model parameters new_named_params = self._get_named_parameters(*result) @@ -1331,6 +1335,16 @@ def prepare(self, *args, device_placement=None): if isinstance(obj, torch.optim.Optimizer): obj._switch_parameters(mapping) + if self.distributed_type == DistributedType.FSDP and self.fp8_backend == "MSAMP": + # We need to convert the underlying optimizer to FSDPAdamW *after* FSDP wrapping + result = list(result) + from msamp.optim import FSDPAdamW + + for i, obj in enumerate(result): + if isinstance(obj, AcceleratedOptimizer): + result[i].optimizer = FSDPAdamW(optimizer=obj.optimizer) + result = tuple(result) + for item in result: if any( item in container @@ -1379,19 +1393,24 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." 
) - if self.native_amp: model._original_forward = model.forward - model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + # NOTE: MS-AMP is special, and adds a `__func__` already to `model.forward` + # When enabled, strictly use `model.forward` autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) - new_forward = autocast_context(model_forward_func) - if hasattr(model.forward, "__func__"): - model.forward = MethodType(new_forward, model) - model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + if self.fp8_backend == "MSAMP": + model_forward_func = model.forward + model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) else: - model.forward = convert_outputs_to_fp32(new_forward) + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + new_forward = autocast_context(model_forward_func) + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) - # We prepare fp8 after, allowing for bf16 autocast to happen first + # We prepare TE fp8 after, allowing for bf16 autocast to happen first if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast: model = apply_fp8_autowrap(model, self.fp8_recipe_handler) @@ -1440,7 +1459,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e device_ids, output_device = [self.local_process_index], self.local_process_index else: device_ids, output_device = None, None - model = torch.nn.parallel.DistributedDataParallel( model, device_ids=device_ids, output_device=output_device, **kwargs ) @@ -1448,7 +1466,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e self.ddp_handler.register_comm_hook(model) elif self.distributed_type == DistributedType.FSDP: # We need to fix the optimizer *before* sharding the model - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + if self.mixed_precision == "fp8" and self.fp8_backend == "MSAMP": + # MS-AMP uses a patched version of FSDP + from msamp.fsdp import FP8FullyShardedDataParallel as FSDP + else: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, # don't wrap it again @@ -1618,6 +1640,13 @@ def _prepare_te(self, *args): def _prepare_deepspeed(self, *args): import deepspeed + ds_initialize = deepspeed.initialize + if self.fp8_backend == "MSAMP": + # MS-AMP requires DeepSpeed patches + from msamp import deepspeed as msamp_deepspeed + + ds_initialize = msamp_deepspeed.initialize + deepspeed_plugin = self.state.deepspeed_plugin is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) @@ -1806,7 +1835,7 @@ def _prepare_deepspeed(self, *args): if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES: kwargs["lr_scheduler"] = scheduler - engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) + engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs) if optimizer is not None: optimizer = DeepSpeedOptimizerWrapper(optimizer) if scheduler is not None: @@ -1983,7 +2012,7 @@ def _prepare_ipex_or_xpu(self, *args): result[i] = optimizer 
return tuple(result) - def _prepare_msamp(self, *args): + def _prepare_msamp(self, *args, device_placement): if not is_msamp_available(): raise ImportError( "MS-AMP was not found on your system. Please ensure that MS-AMP is available " @@ -1991,18 +2020,20 @@ def _prepare_msamp(self, *args): ) else: import msamp - model, optimizer = None, None num_models, num_optimizers = 0, 0 result = [obj for obj in args] - for obj in result: + for i, obj in enumerate(result): if isinstance(obj, torch.nn.Module): model = obj num_models += 1 elif isinstance(obj, (torch.optim.Optimizer)): optimizer = obj + optimizer_index = i num_optimizers += 1 - if optimizer is None or model is None: + if optimizer is None and model is None: + return result, device_placement + elif optimizer is None or model is None: raise ValueError( "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP." ) @@ -2011,13 +2042,28 @@ def _prepare_msamp(self, *args): f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP." ) else: - model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level) + if self.distributed_type == DistributedType.FSDP: + # We need to set the auto_wrap policy before initializing the model + self.state.fsdp_plugin.set_auto_wrap_policy(model) + # NOTE: MS-AMP fsdp relies on it's own MP policy, we must drop the users + self.state.fsdp_plugin.mixed_precision_policy = None + from msamp.common.dtype import Dtypes + + msamp_init_args = dict(model=model, optimizer=optimizer, opt_level=self.fp8_recipe_handler.opt_level) + if self.distributed_type == DistributedType.FSDP: + msamp_init_args["use_fsdp"] = True + msamp_init_args["weight_qtype"] = Dtypes.kfloat8_e4m3 + model, optimizer = msamp.initialize(**msamp_init_args) + for i in range(len(result)): if isinstance(result[i], torch.nn.Module): result[i] = model elif isinstance(result[i], (torch.optim.Optimizer)): result[i] = optimizer - return tuple(result) + if optimizer_index is not None: + # NOTE: MS-AMP moves the optimizer, but not the model + device_placement[optimizer_index] = False + return tuple(result), device_placement def prepare_data_loader( self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None @@ -2109,7 +2155,9 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N return optimizer if device_placement is None: device_placement = self.device_placement - optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler) + # NOTE: Special case: with MS-AMP we do *not* pass in the scaler, optimizer handles it for us + scaler = None if self.fp8_backend == "MSAMP" else self.scaler + optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler) self._optimizers.append(optimizer) return optimizer @@ -3555,3 +3603,12 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None: raise ValueError( "Backward pass not properly called on LOMO optimizers. Are you sure you passed a LOMO optimizer in accelerator.prepare()?" 
) + + @property + def fp8_backend(self): + "Returns the configured backend for training in FP8" + if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None: + return self.fp8_recipe_handler.backend + elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp: + return "MSAMP" + return None diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 919b7fadc2b..53d55ab4e70 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -972,6 +972,16 @@ class DeepSpeedPlugin: " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..." }, ) + enable_msamp: bool = field( + default=None, + metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."}, + ) + msamp_opt_level: str = field( + default=None, + metadata={ + "help": "Optimization level for MS-AMP. Only applicable if `enable_msamp` is True. Should be one of ['O1', 'O2', 'O3']." + }, + ) def __post_init__(self): from .deepspeed import HfDeepSpeedConfig @@ -1006,6 +1016,12 @@ def __post_init__(self): os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true" ) + if self.enable_msamp is None: + self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP" + + if self.msamp_opt_level is None: + self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1") + if self.hf_ds_config is None: self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") if ( @@ -1075,6 +1091,14 @@ def __post_init__(self): if self.zero3_init_flag and not self.hf_ds_config.is_zero3(): warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.") self.zero3_init_flag = False + if self.enable_msamp: + if self.zero_stage == 3: + raise NotImplementedError( + "MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead." + ) + if self.msamp_opt_level not in ["O1", "O2", "O3"]: + raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1', 'O2', 'O3'].") + self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level} def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs): mismatches = [] if mismatches is None else mismatches @@ -1144,6 +1168,10 @@ def set_mixed_precision(self, mixed_precision): if "bf16" not in ds_config: ds_config["bf16"] = {"enabled": True} + if mixed_precision == "fp8" and self.enable_msamp: + if "msamp" not in ds_config: + ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level} + if mixed_precision != "no": diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16" if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true":