From 569c23dad84b22adc9ca8f7e5cfe05c5b728ce0a Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 14 Aug 2024 21:18:24 -0400 Subject: [PATCH 01/14] Working MS-AMP implementation + some fixes for accelerate --- benchmarks/fp8/ms_amp/Dockerfile | 13 ++ benchmarks/fp8/{ => ms_amp}/ddp.py | 0 .../fp8/{ => ms_amp}/distrib_deepspeed.py | 0 benchmarks/fp8/ms_amp/fp8_utils.py | 116 +++++++++++ benchmarks/fp8/{ => ms_amp}/fsdp.py | 0 benchmarks/fp8/ms_amp/non_distributed.py | 117 +++++++++++ .../fp8/{ => transformer_engine}/Dockerfile | 0 benchmarks/fp8/transformer_engine/ddp.py | 143 +++++++++++++ .../transformer_engine/distrib_deepspeed.py | 189 ++++++++++++++++++ .../fp8/{ => transformer_engine}/fp8_utils.py | 0 benchmarks/fp8/transformer_engine/fsdp.py | 160 +++++++++++++++ .../non_distributed.py | 0 src/accelerate/accelerator.py | 38 ++-- 13 files changed, 762 insertions(+), 14 deletions(-) create mode 100644 benchmarks/fp8/ms_amp/Dockerfile rename benchmarks/fp8/{ => ms_amp}/ddp.py (100%) rename benchmarks/fp8/{ => ms_amp}/distrib_deepspeed.py (100%) create mode 100644 benchmarks/fp8/ms_amp/fp8_utils.py rename benchmarks/fp8/{ => ms_amp}/fsdp.py (100%) create mode 100644 benchmarks/fp8/ms_amp/non_distributed.py rename benchmarks/fp8/{ => transformer_engine}/Dockerfile (100%) create mode 100644 benchmarks/fp8/transformer_engine/ddp.py create mode 100644 benchmarks/fp8/transformer_engine/distrib_deepspeed.py rename benchmarks/fp8/{ => transformer_engine}/fp8_utils.py (100%) create mode 100644 benchmarks/fp8/transformer_engine/fsdp.py rename benchmarks/fp8/{ => transformer_engine}/non_distributed.py (100%) diff --git a/benchmarks/fp8/ms_amp/Dockerfile b/benchmarks/fp8/ms_amp/Dockerfile new file mode 100644 index 00000000000..d2d1c130e12 --- /dev/null +++ b/benchmarks/fp8/ms_amp/Dockerfile @@ -0,0 +1,13 @@ +FROM ghcr.io/azure/msamp + +RUN pip install transformers evaluate datasets +# RUN git clone 
https://github.com/huggingface/accelerate + +# RUN cd accelerate && \ +# pip install -e . && \ +# cd benchmarks/fp8 + +CMD ["bash"] + + + diff --git a/benchmarks/fp8/ddp.py b/benchmarks/fp8/ms_amp/ddp.py similarity index 100% rename from benchmarks/fp8/ddp.py rename to benchmarks/fp8/ms_amp/ddp.py diff --git a/benchmarks/fp8/distrib_deepspeed.py b/benchmarks/fp8/ms_amp/distrib_deepspeed.py similarity index 100% rename from benchmarks/fp8/distrib_deepspeed.py rename to benchmarks/fp8/ms_amp/distrib_deepspeed.py diff --git a/benchmarks/fp8/ms_amp/fp8_utils.py b/benchmarks/fp8/ms_amp/fp8_utils.py new file mode 100644 index 00000000000..d28702e05ff --- /dev/null +++ b/benchmarks/fp8/ms_amp/fp8_utils.py @@ -0,0 +1,116 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + + +def get_dataloaders(model_name: str, batch_size: int = 16): + from datasets import load_dataset + from torch.utils.data import DataLoader + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + return tokenizer.pad( + examples, + padding="longest", + pad_to_multiple_of=16, # Specific for FP8 + return_tensors="pt", + ) + + # Instantiate dataloaders. 
+ train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], + shuffle=False, + collate_fn=collate_fn, + batch_size=16, + drop_last=True, + ) + + return train_dataloader, eval_dataloader + + +def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None): + """ + Returns a tuple of: + - Model + - Optimizer + - Train dataloader (prepared) + - Eval dataloader (prepared) + - LR Scheduler + Suitable for training on the MRPC dataset + """ + from torch.optim import AdamW + from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup + + from accelerate import Accelerator + + if accelerator is None: + accelerator = Accelerator() + model = AutoModelForSequenceClassification.from_pretrained(model_name) + train_dataloader, eval_dataloader = get_dataloaders(model_name, batch_size) + optimizer = AdamW(model.parameters(), lr=0.0001) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + return model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + + +def get_named_parameters(model): + """ + Same thing as `Accelerator.get_named_parameters`: returns a dict mapping each parameter name to its parameter (extracted + from parallel) + """ + from accelerate.utils import extract_model_from_parallel + + model = extract_model_from_parallel(model) + return {n: p for n, p in model.named_parameters()} + + +def evaluate_model(model, dataloader, metric, accelerator=None): + "Turns model to .eval(), runs dataloader, and returns the computed metric (the model is left in eval mode)" + model.eval() + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = 
outputs.logits.argmax(dim=-1) + references = batch["labels"] + if accelerator is not None and accelerator.num_processes > 1: + predictions, references = accelerator.gather_for_metrics((predictions, references)) + metric.add_batch(predictions=predictions, references=references) + return metric.compute() diff --git a/benchmarks/fp8/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py similarity index 100% rename from benchmarks/fp8/fsdp.py rename to benchmarks/fp8/ms_amp/fsdp.py diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py new file mode 100644 index 00000000000..23383a2c245 --- /dev/null +++ b/benchmarks/fp8/ms_amp/non_distributed.py @@ -0,0 +1,117 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. + +This particular script verifies this for single GPU training. 
+""" +import evaluate +import torch +import msamp +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(opt_level="O2"): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + + model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level) + model.to("cuda") + + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + scaler = torch.cuda.amp.GradScaler() + + for batch in train_dataloader: + batch = batch.to("cuda") + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) + loss = outputs.loss + loss = scaler.scale(loss) + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(opt_level="O2"): + kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] + # AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + 
) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC) + model.train() + + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + # for opt_level in ["O1", "O2"]: + baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/benchmarks/fp8/Dockerfile 
b/benchmarks/fp8/transformer_engine/Dockerfile similarity index 100% rename from benchmarks/fp8/Dockerfile rename to benchmarks/fp8/transformer_engine/Dockerfile diff --git a/benchmarks/fp8/transformer_engine/ddp.py b/benchmarks/fp8/transformer_engine/ddp.py new file mode 100644 index 00000000000..d14e086ce71 --- /dev/null +++ b/benchmarks/fp8/transformer_engine/ddp.py @@ -0,0 +1,143 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`. + +This particular script verifies this for DDP training. 
+""" +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from torch.nn.parallel import DistributedDataParallel as DDP +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model.to(device) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + new_named_params = get_named_parameters(model) + + # Convert the model to DDP + device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index + model = DDP(model, device_ids=device_ids, output_device=output_device) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + 
optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if 
__name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/transformer_engine/distrib_deepspeed.py b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py new file mode 100644 index 00000000000..291d09ec103 --- /dev/null +++ b/benchmarks/fp8/transformer_engine/distrib_deepspeed.py @@ -0,0 +1,189 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`. + +This particular script verifies this for DeepSpeed training. +""" +from unittest.mock import patch + +import deepspeed +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from transformer_engine.common.recipe import DelayedScaling + +from accelerate import Accelerator, DeepSpeedPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + + +def train_baseline(zero_stage: int = 1): + # This forces transformers to think Zero-3 Init should be used + with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: + mock.return_value = zero_stage == 3 + set_seed(42) + + accelerator = Accelerator() + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + new_named_params = get_named_parameters(model) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + import numpy as np + + config = { + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, + "gradient_accumulation_steps": 1, + "zero_optimization": { + "stage": zero_stage, + "offload_optimizer": {"device": "none", "nvme_path": None}, 
+ "offload_param": {"device": "none", "nvme_path": None}, + "stage3_gather_16bit_weights_on_model_save": False, + }, + "gradient_clipping": 1.0, + "steps_per_print": np.inf, + "bf16": {"enabled": True}, + "fp16": {"enabled": False}, + "zero_allow_untested_optimizer": True, + } + + ( + model, + optimizer, + _, + _, + ) = deepspeed.initialize( + model=model, + optimizer=optimizer, + config_params=config, + ) + + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + model_outputs = [] + data = [] + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + model.backward(loss) + model.step() + for _ in range(accelerator.num_processes): + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results, model_outputs, data + + +def train_integration(zero_stage: int = 1): + set_seed(42) + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=zero_stage, + zero3_init_flag=zero_stage == 3, + ) + accelerator = Accelerator( + mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin + ) + 
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 + + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + model_outputs = [] + data = [] + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + data.append(batch.to("cpu")) + model_outputs.append(outputs.logits.to("cpu")) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.destroy() + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results, model_outputs, data + + +if __name__ == "__main__": + # for zero_stage in [1, 2, 3]: + zero_stage = 1 + baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) + accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score 
should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/fp8_utils.py b/benchmarks/fp8/transformer_engine/fp8_utils.py similarity index 100% rename from benchmarks/fp8/fp8_utils.py rename to benchmarks/fp8/transformer_engine/fp8_utils.py diff --git a/benchmarks/fp8/transformer_engine/fsdp.py b/benchmarks/fp8/transformer_engine/fsdp.py new file mode 100644 index 00000000000..42d35e0dd5e --- /dev/null +++ b/benchmarks/fp8/transformer_engine/fsdp.py @@ -0,0 +1,160 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script tests to ensure that `accelerate` performs at the same level as raw `TransformerEngine`. + +This particular script verifies this for FSDP training. 
+""" +from functools import partial + +import evaluate +import torch +import transformer_engine.common.recipe as te_recipe +import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformer_engine.common.recipe import DelayedScaling +from transformers.models.bert import BertLayer + +from accelerate import Accelerator +from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin +from accelerate.state import AcceleratorState +from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils.transformer_engine import convert_model + + +MODEL_NAME = "bert-base-cased" +METRIC = evaluate.load("glue", "mrpc") + +FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) + + +def train_baseline(): + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) + accelerator = Accelerator() + device = accelerator.device + model.to(device) + + # Convert the model to TE + old_named_params = get_named_parameters(model) + + with torch.no_grad(): + convert_model(model) + + FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} + fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) + + new_named_params = get_named_parameters(model) + + # Convert the model to FSDP + model = FSDP( + model, + use_orig_params=True, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + auto_wrap_policy=FSDP_WRAP_POLICY, + ) + + mapping = {p: new_named_params[n] for n, p in old_named_params.items()} + for param_group in optimizer.param_groups: + param_group["params"] = [mapping[p] for p in param_group["params"]] + + base_model_results = 
evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + batch = batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +def train_integration(): + FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} + kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] + AcceleratorState()._reset_state(True) + fsdp_plugin = FSDPPlugin( + auto_wrap_policy=FSDP_WRAP_POLICY, + use_orig_params=True, + mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), + ) + accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) + set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( + MODEL_NAME, accelerator=accelerator + ) + + model, optimizer = accelerator.prepare(model, optimizer) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model.train() + + for _ in range(2): + for batch in train_dataloader: + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + 
optimizer.zero_grad() + lr_scheduler.step() + + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + + return base_model_results, trained_model_results + + +if __name__ == "__main__": + baseline_not_trained, baseline_trained = train_baseline() + accelerator_not_trained, accelerator_trained = train_integration() + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + + torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/non_distributed.py b/benchmarks/fp8/transformer_engine/non_distributed.py similarity index 100% rename from benchmarks/fp8/non_distributed.py rename to benchmarks/fp8/transformer_engine/non_distributed.py diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 91006258efa..a7a61bded46 100755 --- 
a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -507,6 +507,9 @@ def __init__( elif self.state.mixed_precision == "fp8": # We always enable `native_amp` for FP8 self.native_amp = True + # MS-AMP requires grad scaler however + if self.fp8_recipe_handler.backend == "MSAMP": + self.scaler = torch.cuda.amp.GradScaler() # Start of internal step tracking self.step = 0 @@ -1193,8 +1196,7 @@ def _prepare_one(self, obj, first_pass=False, device_placement=None): elif isinstance(obj, torch.nn.Module): return self.prepare_model(obj, device_placement=device_placement) elif isinstance(obj, torch.optim.Optimizer): - optimizer = self.prepare_optimizer(obj, device_placement=device_placement) - return optimizer + return self.prepare_optimizer(obj, device_placement=device_placement) # Second pass of preparation: LR scheduler (which need the full list of optimizers) elif isinstance(obj, LRScheduler): scheduler = self.prepare_scheduler(obj) @@ -1306,17 +1308,16 @@ def prepare(self, *args, device_placement=None): args = self._prepare_ipex_or_xpu(*args) elif self.device.type == "xpu" and is_xpu_available(): args = self._prepare_ipex_or_xpu(*args) - if self.fp8_recipe_handler is not None and self.fp8_recipe_handler.backend == "TE": - args = self._prepare_te(*args) + if self.fp8_recipe_handler is not None: + if self.fp8_recipe_handler.backend == "TE": + args = self._prepare_te(*args) + elif self.fp8_recipe_handler.backend == "MSAMP": + args, device_placement = self._prepare_msamp(*args, device_placement=device_placement) if self.distributed_type == DistributedType.DEEPSPEED: result = self._prepare_deepspeed(*args) elif self.distributed_type == DistributedType.MEGATRON_LM: result = self._prepare_megatron_lm(*args) else: - if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP": - args = self._prepare_msamp(*args) - # MS-AMP will handle the device placement - device_placement = [False for _ in args] result = tuple( self._prepare_one(obj, 
first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) ) @@ -1391,7 +1392,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e else: model.forward = convert_outputs_to_fp32(new_forward) - # We prepare fp8 after, allowing for bf16 autocast to happen first + # We prepare TE fp8 after, allowing for bf16 autocast to happen first if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast: model = apply_fp8_autowrap(model, self.fp8_recipe_handler) @@ -1983,7 +1984,7 @@ def _prepare_ipex_or_xpu(self, *args): result[i] = optimizer return tuple(result) - def _prepare_msamp(self, *args): + def _prepare_msamp(self, *args, device_placement): if not is_msamp_available(): raise ImportError( "MS-AMP was not found on your system. Please ensure that MS-AMP is available " @@ -1995,14 +1996,17 @@ def _prepare_msamp(self, *args): model, optimizer = None, None num_models, num_optimizers = 0, 0 result = [obj for obj in args] - for obj in result: + for i, obj in enumerate(result): if isinstance(obj, torch.nn.Module): model = obj num_models += 1 elif isinstance(obj, (torch.optim.Optimizer)): optimizer = obj + optimizer_index = i num_optimizers += 1 - if optimizer is None or model is None: + if optimizer is None and model is None: + return result, device_placement + elif optimizer is None or model is None: raise ValueError( "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP." 
) @@ -2012,12 +2016,16 @@ def _prepare_msamp(self, *args): ) else: model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level) + for i in range(len(result)): if isinstance(result[i], torch.nn.Module): result[i] = model elif isinstance(result[i], (torch.optim.Optimizer)): result[i] = optimizer - return tuple(result) + if optimizer_index is not None: + # NOTE: MS-AMP moves the optimizer, but not the model + device_placement[optimizer_index] = False + return tuple(result), device_placement def prepare_data_loader( self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None @@ -2109,7 +2117,9 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N return optimizer if device_placement is None: device_placement = self.device_placement - optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler) + # NOTE: Special case: with MS-AMP we do *not* pass in the scaler, optimizer handles it for us + scaler = None if (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP") else self.scaler + optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler) self._optimizers.append(optimizer) return optimizer From 8ed5816d814cd209d5afc1418a069bbdb2df0dc4 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 14 Aug 2024 21:20:54 -0400 Subject: [PATCH 02/14] Off, but we're training now at least --- benchmarks/fp8/ms_amp/non_distributed.py | 10 +++++----- src/accelerate/accelerator.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py index 23383a2c245..95902167ff3 100644 --- a/benchmarks/fp8/ms_amp/non_distributed.py +++ b/benchmarks/fp8/ms_amp/non_distributed.py @@ -18,15 +18,15 @@ This particular script verifies this 
for single GPU training. """ import evaluate -import torch import msamp -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities -from transformer_engine.common.recipe import DelayedScaling +import torch +from fp8_utils import evaluate_model, get_training_utilities from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed + MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") @@ -67,7 +67,7 @@ def train_baseline(opt_level="O2"): def train_integration(opt_level="O2"): kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] - # AcceleratorState()._reset_state(True) + AcceleratorState()._reset_state(True) accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( @@ -113,5 +113,5 @@ def train_integration(opt_level="O2"): baseline_trained["accuracy"] == accelerator_trained["accuracy"] ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' assert ( - baseline_trained["f1"] == accelerator_trained["f1"] + baseline_trained["f1"] == accelerator_trained["f1"] ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index a7a61bded46..57b92acffa7 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -1380,7 +1380,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." 
) - if self.native_amp: model._original_forward = model.forward model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward @@ -2118,7 +2117,9 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N if device_placement is None: device_placement = self.device_placement # NOTE: Special case: with MS-AMP we do *not* pass in the scaler, optimizer handles it for us - scaler = None if (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP") else self.scaler + scaler = ( + None if (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP") else self.scaler + ) optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler) self._optimizers.append(optimizer) return optimizer From 64552d49ad945e7ed9996ada476181eadd8f7112 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Wed, 14 Aug 2024 21:59:03 -0400 Subject: [PATCH 03/14] Fixed MS-AMP implementation --- benchmarks/fp8/ms_amp/fp8_utils.py | 4 ++- src/accelerate/accelerator.py | 43 +++++++++++++++++++----------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/benchmarks/fp8/ms_amp/fp8_utils.py b/benchmarks/fp8/ms_amp/fp8_utils.py index d28702e05ff..602ce07fdc6 100644 --- a/benchmarks/fp8/ms_amp/fp8_utils.py +++ b/benchmarks/fp8/ms_amp/fp8_utils.py @@ -107,7 +107,9 @@ def evaluate_model(model, dataloader, metric, accelerator=None): model.eval() for step, batch in enumerate(dataloader): with torch.no_grad(): - outputs = model(**batch) + # W/ MS-AMP, we need to cast while evaluating + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) predictions = outputs.logits.argmax(dim=-1) references = batch["labels"] if accelerator is not None and accelerator.num_processes > 1: diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 
57b92acffa7..8e375cacfb8 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -402,7 +402,7 @@ def __init__( self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED) ): raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.") - self.delayed_fp8_autocast = self.fp8_recipe_handler.backend == "TE" and self.distributed_type in ( + self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in ( DistributedType.MULTI_GPU, DistributedType.FSDP, ) @@ -508,7 +508,7 @@ def __init__( # We always enable `native_amp` for FP8 self.native_amp = True # MS-AMP requires grad scaler however - if self.fp8_recipe_handler.backend == "MSAMP": + if self.fp8_backend == "MSAMP": self.scaler = torch.cuda.amp.GradScaler() # Start of internal step tracking @@ -1308,20 +1308,20 @@ def prepare(self, *args, device_placement=None): args = self._prepare_ipex_or_xpu(*args) elif self.device.type == "xpu" and is_xpu_available(): args = self._prepare_ipex_or_xpu(*args) - if self.fp8_recipe_handler is not None: - if self.fp8_recipe_handler.backend == "TE": - args = self._prepare_te(*args) - elif self.fp8_recipe_handler.backend == "MSAMP": - args, device_placement = self._prepare_msamp(*args, device_placement=device_placement) + if self.fp8_backend == "TE": + args = self._prepare_te(*args) if self.distributed_type == DistributedType.DEEPSPEED: result = self._prepare_deepspeed(*args) elif self.distributed_type == DistributedType.MEGATRON_LM: result = self._prepare_megatron_lm(*args) else: + if self.fp8_backend == "MSAMP" and self.distributed_type != DistributedType.FSDP: + args, device_placement = self._prepare_msamp(*args, device_placement=device_placement) result = tuple( self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) ) result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement)) + if 
tpu_should_fix_optimizer: # 2. grabbing new model parameters new_named_params = self._get_named_parameters(*result) @@ -1382,14 +1382,20 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e ) if self.native_amp: model._original_forward = model.forward - model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + # NOTE: MS-AMP is special, and adds a `__func__` already to `model.forward` + # When enabled, strictly use `model.forward` autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) - new_forward = autocast_context(model_forward_func) - if hasattr(model.forward, "__func__"): - model.forward = MethodType(new_forward, model) - model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + if self.fp8_backend == "MSAMP": + model_forward_func = model.forward + model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) else: - model.forward = convert_outputs_to_fp32(new_forward) + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + new_forward = autocast_context(model_forward_func) + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) # We prepare TE fp8 after, allowing for bf16 autocast to happen first if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast: @@ -2117,9 +2123,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N if device_placement is None: device_placement = self.device_placement # NOTE: Special case: with MS-AMP we do *not* pass in the scaler, optimizer handles it for us - scaler = ( - None if (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP") else self.scaler - ) + 
scaler = None if self.fp8_backend == "MSAMP" else self.scaler optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=scaler) self._optimizers.append(optimizer) return optimizer @@ -3566,3 +3570,10 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None: raise ValueError( "Backward pass not properly called on LOMO optimizers. Are you sure you passed a LOMO optimizer in accelerator.prepare()?" ) + + @property + def fp8_backend(self): + "Returns the configured backend for training in FP8" + if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None: + return self.fp8_recipe_handler.backend + return None From bb4061c7e14204c7ac5e384459d2c52c8f434b6c Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 08:31:41 -0400 Subject: [PATCH 04/14] Working MS-AMP with non-distributed and DDP --- benchmarks/fp8/ms_amp/ddp.py | 95 +++++++++--------------- benchmarks/fp8/ms_amp/non_distributed.py | 32 ++++---- src/accelerate/accelerator.py | 35 ++++----- 3 files changed, 69 insertions(+), 93 deletions(-) diff --git a/benchmarks/fp8/ms_amp/ddp.py b/benchmarks/fp8/ms_amp/ddp.py index d14e086ce71..6e7530df9fc 100644 --- a/benchmarks/fp8/ms_amp/ddp.py +++ b/benchmarks/fp8/ms_amp/ddp.py @@ -13,68 +13,51 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. This particular script verifies this for DDP training. 
""" import evaluate +import msamp import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities +from fp8_utils import evaluate_model, get_training_utilities from torch.nn.parallel import DistributedDataParallel as DDP -from transformer_engine.common.recipe import DelayedScaling from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") -def train_baseline(): +def train_baseline(opt_level="O2"): set_seed(42) + scaler = torch.cuda.amp.GradScaler() model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) accelerator = Accelerator() device = accelerator.device - model.to(device) - - # Convert the model to TE - old_named_params = get_named_parameters(model) - with torch.no_grad(): - convert_model(model) + model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level) - FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - - new_named_params = get_named_parameters(model) + model.to(device) # Convert the model to DDP device_ids, output_device = [accelerator.local_process_index], accelerator.local_process_index model = DDP(model, device_ids=device_ids, output_device=output_device) - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, 
fp8_recipe=fp8_recipe): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch = batch.to(device) - outputs = model(**batch) + for i, batch in enumerate(train_dataloader): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + scaler.scale(loss).backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -88,9 +71,8 @@ def train_baseline(): return base_model_results, trained_model_results -def train_integration(): - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] +def train_integration(opt_level="O2"): + kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] AcceleratorState()._reset_state(True) accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) set_seed(42) @@ -101,15 +83,14 @@ def train_integration(): model, optimizer = accelerator.prepare(model, optimizer) base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - - for _ in range(2): - for batch in train_dataloader: + for i, batch in enumerate(train_dataloader): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): outputs = model(**batch) loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -124,20 +105,18 @@ def train_integration(): if __name__ == "__main__": - baseline_not_trained, baseline_trained = train_baseline() - accelerator_not_trained, 
accelerator_trained = train_integration() - - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' - - torch.distributed.destroy_process_group() + for opt_level in ["O1", "O2"]: + baseline_not_trained, baseline_trained = train_baseline(opt_level) + accelerator_not_trained, accelerator_trained = train_integration(opt_level) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 not the same for trained baseline and accelerator using opt_level={opt_level}: 
{baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/benchmarks/fp8/ms_amp/non_distributed.py b/benchmarks/fp8/ms_amp/non_distributed.py index 95902167ff3..791cea108ed 100644 --- a/benchmarks/fp8/ms_amp/non_distributed.py +++ b/benchmarks/fp8/ms_amp/non_distributed.py @@ -99,19 +99,19 @@ def train_integration(opt_level="O2"): if __name__ == "__main__": - # for opt_level in ["O1", "O2"]: - baseline_not_trained, baseline_trained = train_baseline() - accelerator_not_trained, accelerator_trained = train_integration() - - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + for opt_level in ["O1", "O2"]: + baseline_not_trained, baseline_trained = train_baseline(opt_level) + accelerator_not_trained, accelerator_trained = train_integration(opt_level) + + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == 
{accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 8e375cacfb8..a88c6a80a3c 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -26,7 +26,6 @@ from collections import OrderedDict from contextlib import contextmanager from functools import partial -from types import MethodType from typing import Any, Callable, Union import torch @@ -73,7 +72,6 @@ clean_state_dict_for_safetensors, compare_versions, convert_model, - convert_outputs_to_fp32, extract_model_from_parallel, gather, gather_object, @@ -1380,22 +1378,22 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." 
) - if self.native_amp: - model._original_forward = model.forward - # NOTE: MS-AMP is special, and adds a `__func__` already to `model.forward` - # When enabled, strictly use `model.forward` - autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) - if self.fp8_backend == "MSAMP": - model_forward_func = model.forward - model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) - else: - model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward - new_forward = autocast_context(model_forward_func) - if hasattr(model.forward, "__func__"): - model.forward = MethodType(new_forward, model) - model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) - else: - model.forward = convert_outputs_to_fp32(new_forward) + # if self.native_amp: + # model._original_forward = model.forward + # # NOTE: MS-AMP is special, and adds a `__func__` already to `model.forward` + # # When enabled, strictly use `model.forward` + # autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) + # if self.fp8_backend == "MSAMP": + # model_forward_func = model.forward + # model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) + # else: + # model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + # new_forward = autocast_context(model_forward_func) + # if hasattr(model.forward, "__func__"): + # model.forward = MethodType(new_forward, model) + # model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + # else: + # model.forward = convert_outputs_to_fp32(new_forward) # We prepare TE fp8 after, allowing for bf16 autocast to happen first if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast: @@ -1446,7 +1444,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e device_ids, 
output_device = [self.local_process_index], self.local_process_index else: device_ids, output_device = None, None - model = torch.nn.parallel.DistributedDataParallel( model, device_ids=device_ids, output_device=output_device, **kwargs ) From 8da59ac99f11d6e4dfd28dfd9870b4c5f4f5b7fa Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 09:18:05 -0400 Subject: [PATCH 05/14] Working base fsdp version with only MSAMP --- benchmarks/fp8/ms_amp/fsdp.py | 144 ++++++++++++++++------------------ src/accelerate/accelerator.py | 34 ++++---- 2 files changed, 85 insertions(+), 93 deletions(-) diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py index 42d35e0dd5e..1dd8f699325 100644 --- a/benchmarks/fp8/ms_amp/fsdp.py +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -13,83 +13,82 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. This particular script verifies this for FSDP training. 
""" -from functools import partial - import evaluate +import msamp import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformer_engine.common.recipe import DelayedScaling -from transformers.models.bert import BertLayer +from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel +from msamp.optim import FSDPAdamW +from fp8_utils import evaluate_model, get_training_utilities, get_named_parameters, get_dataloaders from accelerate import Accelerator -from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp import MixedPrecision +from transformers.models.bert import BertLayer +from functools import partial MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") - FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) -def train_baseline(): +from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup + +def train_baseline(opt_level="O2"): set_seed(42) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) accelerator = Accelerator() device = accelerator.device + model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) + train_dataloader, eval_dataloader = get_dataloaders(MODEL_NAME) + train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) + + # old_named_params = 
get_named_parameters(model) + # This single call: + # 1. Replaces all linear layers with MS-AMP's `LinearReplacer` + # 2. Replaces the weights with `ScalingParameters` model.to(device) + model = FsdpReplacer.replace(model) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - - FP8_RECIPE_KWARGS = {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - - new_named_params = get_named_parameters(model) - - # Convert the model to FSDP - model = FSDP( + # Same as FullyShardedDataParallel, but overrides `FlatParamHandle`, `post_backward_hook`, and adds comm hook + model = FP8FullyShardedDataParallel( model, use_orig_params=True, - mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), auto_wrap_policy=FSDP_WRAP_POLICY, ) - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] + # TODO: Make this happen using existing AdamW + optimizer = FSDPAdamW( + model.parameters(), + lr=0.0001, + ) + + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=len(train_dataloader) * 2, + ) base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - for _ in range(2): - for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - batch = batch.to(device) - outputs = model(**batch) + for i, batch in enumerate(train_dataloader): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**batch) loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + loss.backward() + optimizer.step() + 
optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + print(f'Process {accelerator.process_index}:\nBase model results: {base_model_results}\nTrained model results: {trained_model_results}') assert ( trained_model_results["accuracy"] > base_model_results["accuracy"] ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' @@ -97,19 +96,13 @@ def train_baseline(): trained_model_results["f1"] > base_model_results["f1"] ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' - return base_model_results, trained_model_results + # return base_model_results, trained_model_results -def train_integration(): - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] +def train_integration(opt_level="O2"): + kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] AcceleratorState()._reset_state(True) - fsdp_plugin = FSDPPlugin( - auto_wrap_policy=FSDP_WRAP_POLICY, - use_orig_params=True, - mixed_precision_policy=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32), - ) - accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator @@ -118,15 +111,14 @@ def train_integration(): model, optimizer = accelerator.prepare(model, optimizer) base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - - for _ in range(2): - for batch in train_dataloader: + for i, batch in enumerate(train_dataloader): + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): outputs = model(**batch) loss = outputs.loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) @@ -141,20 +133,18 @@ def train_integration(): if __name__ == "__main__": - baseline_not_trained, baseline_trained = train_baseline() - accelerator_not_trained, accelerator_trained = train_integration() - - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' - - torch.distributed.destroy_process_group() + # for opt_level in ["O1", "O2"]: + train_baseline() + # accelerator_not_trained, accelerator_trained = train_integration(opt_level) + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'F1 not the same for untrained 
baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index a88c6a80a3c..ab46b6c5e13 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -26,6 +26,7 @@ from collections import OrderedDict from contextlib import contextmanager from functools import partial +from types import MethodType from typing import Any, Callable, Union import torch @@ -72,6 +73,7 @@ clean_state_dict_for_safetensors, compare_versions, convert_model, + convert_outputs_to_fp32, extract_model_from_parallel, gather, gather_object, @@ -1378,22 +1380,22 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." 
) - # if self.native_amp: - # model._original_forward = model.forward - # # NOTE: MS-AMP is special, and adds a `__func__` already to `model.forward` - # # When enabled, strictly use `model.forward` - # autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) - # if self.fp8_backend == "MSAMP": - # model_forward_func = model.forward - # model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) - # else: - # model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward - # new_forward = autocast_context(model_forward_func) - # if hasattr(model.forward, "__func__"): - # model.forward = MethodType(new_forward, model) - # model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) - # else: - # model.forward = convert_outputs_to_fp32(new_forward) + if self.native_amp: + model._original_forward = model.forward + # NOTE: MS-AMP is special, and adds a `__func__` already to `model.forward` + # When enabled, strictly use `model.forward` + autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) + if self.fp8_backend == "MSAMP": + model_forward_func = model.forward + model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) + else: + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + new_forward = autocast_context(model_forward_func) + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) # We prepare TE fp8 after, allowing for bf16 autocast to happen first if getattr(self.fp8_recipe_handler, "backend", None) == "TE" and not self.delayed_fp8_autocast: From f3fe0d07234992468481c8ab3e40e7595b5a2d59 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your 
email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 10:38:46 -0400 Subject: [PATCH 06/14] Freeze: working version --- benchmarks/fp8/ms_amp/fsdp.py | 193 ++++++++++++++++++++++++++++++---- 1 file changed, 171 insertions(+), 22 deletions(-) diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py index 1dd8f699325..4400939fda1 100644 --- a/benchmarks/fp8/ms_amp/fsdp.py +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -19,9 +19,10 @@ """ import evaluate import msamp +import inspect import torch from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel -from msamp.optim import FSDPAdamW +from msamp.optim import FSDPAdamW, LBAdamW from fp8_utils import evaluate_model, get_training_utilities, get_named_parameters, get_dataloaders from accelerate import Accelerator @@ -31,6 +32,8 @@ from torch.distributed.fsdp import MixedPrecision from transformers.models.bert import BertLayer from functools import partial +import torch.distributed as dist +from msamp.common.tensor import ScalingMeta, ScalingTensor MODEL_NAME = "bert-base-cased" @@ -39,6 +42,127 @@ from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup +from msamp.common.dtype import Dtypes +from msamp.common.tensor import ScalingTensor + + +class MSAMPOptimWrapper(torch.optim.Optimizer): + """ + Wrapper around an optimizer to make it compatible for FSDP. 
+ """ + def __init__(self, optimizer): + self.optimizer = optimizer + self.adjust_param_groups() + + @property + def state(self): + return self.optimizer.state + + @state.setter + def state(self, state): + self.optimizer.state = state + + @property + def param_groups(self): + return self.optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self.optimizer.param_groups = param_groups + + @property + def defaults(self): + return self.optimizer.defaults + + @defaults.setter + def defaults(self, defaults): + self.optimizer.defaults = defaults + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + + def load_state_dict(self, state_dict): + self.optimizer.load_state_dict(state_dict) + + def state_dict(self): + return self.optimizer.state_dict() + + def zero_grad(self, set_to_none=None): + for param in self.original_params: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + def step(self): + for i, param in enumerate(self.original_params): + if self.master_weights[i] is not None: + grad_meta = param._grad_meta + dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] + self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) + param.grad = None + self.optimizer.step() + + # Copy master weight to weight + for i, param in enumerate(self.original_params): + if hasattr(param, '_meta') and param._meta is not None: + hp_data = None + if param.numel() == 0: + param._meta.amax[0].zero_() + else: + hp_data = self.master_weights[i].float() + param._meta.amax[0] = hp_data.abs().max() + + dist.all_reduce(param._meta.amax[0], op=dist.ReduceOp.MAX) + param._meta.reset_scaling_factor() + if param.numel() > 0: + with ScalingMeta.in_time_scaling_context(False): + data = hp_data.cast(param._meta.qtype, param._meta, False) \ + .value.view(torch.float32) + 
param.data.copy_(data) + else: + param._meta.scale_inv.data.copy_(torch.reciprocal(param._meta.scale)) + + def train(self): + """ + Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free` + """ + return self.optimizer.train() + + def eval(self): + """ + Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free` + """ + return self.optimizer.eval() + + def adjust_param_groups(self): + self.original_params, self.master_weights = [], [] + for group in self.param_groups: + params = [] + for param in group['params']: + if param is None: + continue + + self.original_params.append(param) + if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: + dtype = Dtypes.qtype_to_dtype[param._meta.qtype] + param = ScalingTensor(param.view(dtype), param._meta) + master_weight = param.cast(Dtypes.kfloat16) + master_weight.requires_grad = True + self.master_weights.append(master_weight) + params.append(master_weight) + else: + self.master_weights.append(None) + params.append(param) + + group['params'] = params + + def train_baseline(opt_level="O2"): set_seed(42) @@ -48,25 +172,50 @@ def train_baseline(opt_level="O2"): train_dataloader, eval_dataloader = get_dataloaders(MODEL_NAME) train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) - # old_named_params = get_named_parameters(model) - # This single call: - # 1. Replaces all linear layers with MS-AMP's `LinearReplacer` - # 2. 
Replaces the weights with `ScalingParameters` - model.to(device) - model = FsdpReplacer.replace(model) + from msamp.nn import LinearReplacer + model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3) - # Same as FullyShardedDataParallel, but overrides `FlatParamHandle`, `post_backward_hook`, and adds comm hook + for _, submodule in model.named_modules(): + params_to_process = list(submodule.named_parameters(recurse=False)) + for param_name, param in params_to_process: + if not isinstance(param, torch.Tensor): + data = param.value.view(-1) + padded = 0 + if data.numel() % 4 != 0: + padded = 4 - data.numel() % 4 + data = torch.nn.functional.pad(data, (0, padded)) + + data = data.view(dtype=torch.float32) + new_param = torch.nn.Parameter(data) + new_param._original_shape = param.shape + new_param._padded = padded + new_param._meta = param.meta + new_param._scaling_metas = param._scaling_metas + + setattr(submodule, param_name, new_param) + + model.to(device) model = FP8FullyShardedDataParallel( model, use_orig_params=True, auto_wrap_policy=FSDP_WRAP_POLICY, ) - # TODO: Make this happen using existing AdamW - optimizer = FSDPAdamW( - model.parameters(), - lr=0.0001, - ) + # optimizer = FSDPAdamW(model.parameters(), lr=0.0001) + optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) + default_args = optimizer.defaults + + default_args['exp_avg_dtype'] = torch.uint8 + default_args['exp_avg_sq_dtype'] = torch.float16 + + # Currently, we don't support foreach, capturable, differentiable, and fused. 
+ for k in ['foreach', 'capturable', 'differentiable', 'fused']: + default_args.pop(k, None) + + optimizer = LBAdamW(optimizer.param_groups, **default_args) + + optimizer = MSAMPOptimWrapper(optimizer) + # Same as FullyShardedDataParallel, but overrides `FlatParamHandle`, `post_backward_hook`, and adds comm hook lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, @@ -74,7 +223,7 @@ def train_baseline(opt_level="O2"): num_training_steps=len(train_dataloader) * 2, ) - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + # base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() for i, batch in enumerate(train_dataloader): @@ -86,15 +235,15 @@ def train_baseline(opt_level="O2"): optimizer.zero_grad() lr_scheduler.step() - trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + # trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) - print(f'Process {accelerator.process_index}:\nBase model results: {base_model_results}\nTrained model results: {trained_model_results}') - assert ( - trained_model_results["accuracy"] > base_model_results["accuracy"] - ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' - assert ( - trained_model_results["f1"] > base_model_results["f1"] - ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + # print(f'Process {accelerator.process_index}:\nBase model results: {base_model_results}\nTrained model results: {trained_model_results}') + # assert ( + # trained_model_results["accuracy"] > base_model_results["accuracy"] + # ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + # assert ( + # trained_model_results["f1"] > base_model_results["f1"] + # ), 
f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' # return base_model_results, trained_model_results From 41ce97ca73b516930ac78369940489d3f6784cb4 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 13:38:58 -0400 Subject: [PATCH 07/14] Snapshot: Works --- benchmarks/fp8/ms_amp/fsdp.py | 234 ++++++---------------------------- src/accelerate/accelerator.py | 69 ++++++++-- 2 files changed, 92 insertions(+), 211 deletions(-) diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py index 4400939fda1..bf94068a75c 100644 --- a/benchmarks/fp8/ms_amp/fsdp.py +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -19,21 +19,19 @@ """ import evaluate import msamp -import inspect import torch -from msamp.fsdp import FsdpReplacer, FP8FullyShardedDataParallel -from msamp.optim import FSDPAdamW, LBAdamW -from fp8_utils import evaluate_model, get_training_utilities, get_named_parameters, get_dataloaders +from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin +from msamp.fsdp import FP8FullyShardedDataParallel +from msamp.optim import FSDPAdamW +from msamp.common.dtype import Dtypes +from fp8_utils import evaluate_model, get_training_utilities from accelerate import Accelerator from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from torch.distributed.fsdp import MixedPrecision from transformers.models.bert import BertLayer from functools import partial -import torch.distributed as dist -from msamp.common.tensor import ScalingMeta, ScalingTensor MODEL_NAME = "bert-base-cased" @@ -41,224 +39,66 @@ FSDP_WRAP_POLICY = partial(transformer_auto_wrap_policy, transformer_layer_cls={BertLayer}) -from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup -from msamp.common.dtype 
import Dtypes -from msamp.common.tensor import ScalingTensor - - -class MSAMPOptimWrapper(torch.optim.Optimizer): - """ - Wrapper around an optimizer to make it compatible for FSDP. - """ - def __init__(self, optimizer): - self.optimizer = optimizer - self.adjust_param_groups() - - @property - def state(self): - return self.optimizer.state - - @state.setter - def state(self, state): - self.optimizer.state = state - - @property - def param_groups(self): - return self.optimizer.param_groups - - @param_groups.setter - def param_groups(self, param_groups): - self.optimizer.param_groups = param_groups - - @property - def defaults(self): - return self.optimizer.defaults - - @defaults.setter - def defaults(self, defaults): - self.optimizer.defaults = defaults - - def add_param_group(self, param_group): - self.optimizer.add_param_group(param_group) - - def load_state_dict(self, state_dict): - self.optimizer.load_state_dict(state_dict) - - def state_dict(self): - return self.optimizer.state_dict() - - def zero_grad(self, set_to_none=None): - for param in self.original_params: - if set_to_none: - param.grad = None - else: - if param.grad is not None: - if param.grad.grad_fn is not None: - param.grad.detach_() - else: - param.grad.requires_grad_(False) - param.grad.zero_() - - def step(self): - for i, param in enumerate(self.original_params): - if self.master_weights[i] is not None: - grad_meta = param._grad_meta - dtype = Dtypes.qtype_to_dtype[grad_meta.qtype] - self.master_weights[i].grad = ScalingTensor(param.grad.view(dtype), grad_meta) - param.grad = None - self.optimizer.step() - - # Copy master weight to weight - for i, param in enumerate(self.original_params): - if hasattr(param, '_meta') and param._meta is not None: - hp_data = None - if param.numel() == 0: - param._meta.amax[0].zero_() - else: - hp_data = self.master_weights[i].float() - param._meta.amax[0] = hp_data.abs().max() - - dist.all_reduce(param._meta.amax[0], op=dist.ReduceOp.MAX) - 
param._meta.reset_scaling_factor() - if param.numel() > 0: - with ScalingMeta.in_time_scaling_context(False): - data = hp_data.cast(param._meta.qtype, param._meta, False) \ - .value.view(torch.float32) - param.data.copy_(data) - else: - param._meta.scale_inv.data.copy_(torch.reciprocal(param._meta.scale)) - - def train(self): - """ - Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free` - """ - return self.optimizer.train() - - def eval(self): - """ - Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free` - """ - return self.optimizer.eval() - - def adjust_param_groups(self): - self.original_params, self.master_weights = [], [] - for group in self.param_groups: - params = [] - for param in group['params']: - if param is None: - continue - - self.original_params.append(param) - if hasattr(param, '_meta') and param._meta is not None and param.numel() > 0: - dtype = Dtypes.qtype_to_dtype[param._meta.qtype] - param = ScalingTensor(param.view(dtype), param._meta) - master_weight = param.cast(Dtypes.kfloat16) - master_weight.requires_grad = True - self.master_weights.append(master_weight) - params.append(master_weight) - else: - self.master_weights.append(None) - params.append(param) - - group['params'] = params - - - def train_baseline(opt_level="O2"): set_seed(42) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME) accelerator = Accelerator() device = accelerator.device - model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) - train_dataloader, eval_dataloader = get_dataloaders(MODEL_NAME) - train_dataloader, eval_dataloader = accelerator.prepare(train_dataloader, eval_dataloader) - - from msamp.nn import LinearReplacer - model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3) - - for _, submodule in model.named_modules(): - params_to_process = list(submodule.named_parameters(recurse=False)) - for param_name, param in params_to_process: 
- if not isinstance(param, torch.Tensor): - data = param.value.view(-1) - padded = 0 - if data.numel() % 4 != 0: - padded = 4 - data.numel() % 4 - data = torch.nn.functional.pad(data, (0, padded)) - - data = data.view(dtype=torch.float32) - new_param = torch.nn.Parameter(data) - new_param._original_shape = param.shape - new_param._padded = padded - new_param._meta = param.meta - new_param._scaling_metas = param._scaling_metas - - setattr(submodule, param_name, new_param) - + model, optimizer = msamp.initialize( + model, optimizer, + opt_level=opt_level, + weight_qtype=Dtypes.kfloat8_e4m3, + use_fsdp=True + ) + model.to(device) + model = FP8FullyShardedDataParallel( model, use_orig_params=True, auto_wrap_policy=FSDP_WRAP_POLICY, ) + optimizer = FSDPAdamW(optimizer) - # optimizer = FSDPAdamW(model.parameters(), lr=0.0001) - optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) - default_args = optimizer.defaults - - default_args['exp_avg_dtype'] = torch.uint8 - default_args['exp_avg_sq_dtype'] = torch.float16 - - # Currently, we don't support foreach, capturable, differentiable, and fused. 
- for k in ['foreach', 'capturable', 'differentiable', 'fused']: - default_args.pop(k, None) - - optimizer = LBAdamW(optimizer.param_groups, **default_args) - - optimizer = MSAMPOptimWrapper(optimizer) - # Same as FullyShardedDataParallel, but overrides `FlatParamHandle`, `post_backward_hook`, and adds comm hook - - lr_scheduler = get_linear_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=100, - num_training_steps=len(train_dataloader) * 2, - ) - - # base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() for i, batch in enumerate(train_dataloader): with torch.autocast(device_type="cuda", dtype=torch.bfloat16): outputs = model(**batch) loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() - # trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) - # print(f'Process {accelerator.process_index}:\nBase model results: {base_model_results}\nTrained model results: {trained_model_results}') - # assert ( - # trained_model_results["accuracy"] > base_model_results["accuracy"] - # ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' - # assert ( - # trained_model_results["f1"] > base_model_results["f1"] - # ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + assert ( + trained_model_results["accuracy"] > base_model_results["accuracy"] + ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + assert ( + trained_model_results["f1"] > 
base_model_results["f1"] + ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' - # return base_model_results, trained_model_results + return base_model_results, trained_model_results def train_integration(opt_level="O2"): kwargs_handlers = [FP8RecipeKwargs(backend="msamp", opt_level=opt_level)] - AcceleratorState()._reset_state(True) - accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs_handlers) + # AcceleratorState()._reset_state(True) + fsdp_plugin = FSDPPlugin( + auto_wrap_policy=FSDP_WRAP_POLICY, + use_orig_params=True, + ) + accelerator = Accelerator(mixed_precision="fp8", fsdp_plugin=fsdp_plugin, kwargs_handlers=kwargs_handlers) set_seed(42) model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator ) model, optimizer = accelerator.prepare(model, optimizer) - base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + # base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() for i, batch in enumerate(train_dataloader): with torch.autocast(device_type="cuda", dtype=torch.bfloat16): @@ -283,8 +123,8 @@ def train_integration(opt_level="O2"): if __name__ == "__main__": # for opt_level in ["O1", "O2"]: - train_baseline() - # accelerator_not_trained, accelerator_trained = train_integration(opt_level) + # baseline_not_trained, baseline_trained = train_baseline(opt_level) + accelerator_not_trained, accelerator_trained = train_integration("O2") # assert ( # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] # ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index ab46b6c5e13..0aa69ee7857 100755 --- 
a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -307,11 +307,11 @@ def __init__( deepspeed_plugin.set_mixed_precision(mixed_precision) deepspeed_plugin.set_deepspeed_weakref() - if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( - fsdp_plugin, FullyShardedDataParallelPlugin - ): - if is_torch_version("<", FSDP_PYTORCH_VERSION): - raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") + # if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( + # fsdp_plugin, FullyShardedDataParallelPlugin + # ): + # if is_torch_version("<", FSDP_PYTORCH_VERSION): + # raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") if fsdp_plugin is None: # init from env variables fsdp_plugin = ( @@ -1315,12 +1315,13 @@ def prepare(self, *args, device_placement=None): elif self.distributed_type == DistributedType.MEGATRON_LM: result = self._prepare_megatron_lm(*args) else: - if self.fp8_backend == "MSAMP" and self.distributed_type != DistributedType.FSDP: + if self.fp8_backend == "MSAMP": args, device_placement = self._prepare_msamp(*args, device_placement=device_placement) - result = tuple( - self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) - ) - result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement)) + result = args + # result = tuple( + # self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) + # ) + # result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement)) if tpu_should_fix_optimizer: # 2. 
grabbing new model parameters @@ -1332,6 +1333,16 @@ def prepare(self, *args, device_placement=None): if isinstance(obj, torch.optim.Optimizer): obj._switch_parameters(mapping) + # if self.distributed_type == DistributedType.FSDP and self.fp8_backend == "MSAMP": + # # We need to convert the underlying optimizer to FSDPAdamW *after* FSDP wrapping + # result = list(result) + # from msamp.optim import FSDPAdamW + # for i, obj in enumerate(result): + # if isinstance(obj, AcceleratedOptimizer): + # result[i].optimizer = FSDPAdamW(obj.optimizer) + # print(f'Wrapping optimizer in FSDP: {type(obj.optimizer)}') + # result = tuple(result) + for item in result: if any( item in container @@ -1368,6 +1379,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e """ if device_placement is None: device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP + print(f'Model device placement: {device_placement}') self._models.append(model) # TODO: Look at enabling native TP training directly with a proper config @@ -1430,6 +1442,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload." 
) elif device_placement and not self.verify_device_map(model): + print(f"Moving model to device: {self.device}") model = model.to(self.device) if not evaluation_mode: if self.distributed_type in ( @@ -1453,7 +1466,12 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e self.ddp_handler.register_comm_hook(model) elif self.distributed_type == DistributedType.FSDP: # We need to fix the optimizer *before* sharding the model - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + if self.mixed_precision == "fp8" and self.fp8_backend == "MSAMP": + print('Importing `FP8FullyShardedDataParallel` from `msamp.fsdp`') + # MS-AMP uses a patched version of FSDP + from msamp.fsdp import FP8FullyShardedDataParallel as FSDP + else: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, # don't wrap it again @@ -1462,6 +1480,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e is_type_fsdp = isinstance(model, FSDP) or ( is_compiled_module(model) and isinstance(model._orig_mod, FSDP) ) + print(f'Is type FSDP: {is_type_fsdp}') if not is_type_fsdp: self.state.fsdp_plugin.set_auto_wrap_policy(model) @@ -1481,6 +1500,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "device_id": self.device, } model = FSDP(model, **kwargs) + print(f'Wrapped model in FSDP: {type(model)}') if fsdp_plugin.activation_checkpointing: from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, @@ -1507,7 +1527,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e # * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype # * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None, # we sill want to upcast the flat_param. 
- if self.mixed_precision != "no": # if mixed precision is set + if self.mixed_precision != "no" and self.fp8_backend != "MSAMP": # if mixed precision is set upcasted_log = [] for module in FSDP.fsdp_modules(model): # Referencing DeepSpeed Zero3 @@ -1996,7 +2016,6 @@ def _prepare_msamp(self, *args, device_placement): ) else: import msamp - model, optimizer = None, None num_models, num_optimizers = 0, 0 result = [obj for obj in args] @@ -2019,7 +2038,29 @@ def _prepare_msamp(self, *args, device_placement): f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP." ) else: - model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level) + if self.distributed_type == DistributedType.FSDP: + # We need to set the auto_wrap policy before initializing the model + self.state.fsdp_plugin.set_auto_wrap_policy(model) + from msamp.common.dtype import Dtypes + model, optimizer = msamp.initialize( + model, optimizer, + opt_level=self.fp8_recipe_handler.opt_level, + use_fsdp=self.distributed_type == DistributedType.FSDP, + weight_qtype=Dtypes.kfloat8_e4m3, + ) + # now we can prepare the model? + # model = self.prepare_model(model, device_placement=True) + from msamp.fsdp import FP8FullyShardedDataParallel as FSDP + + model = FSDP( + model, + use_orig_params=True, + auto_wrap_policy=self.state.fsdp_plugin.auto_wrap_policy, + device_id=self.device, + ) + print(f'Prepared: {type(model)}') + from msamp.optim import FSDPAdamW + optimizer = FSDPAdamW(optimizer) for i in range(len(result)): if isinstance(result[i], torch.nn.Module): From 7db3368b98ae2dd1db9af44f25ef6e718ab70664 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 14:43:18 -0400 Subject: [PATCH 08/14] Working! 
--- benchmarks/fp8/ms_amp/fsdp.py | 47 +++++++++++++++++++------------ src/accelerate/accelerator.py | 52 ++++++++++++----------------------- 2 files changed, 47 insertions(+), 52 deletions(-) diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py index bf94068a75c..843ce30968d 100644 --- a/benchmarks/fp8/ms_amp/fsdp.py +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -45,18 +45,22 @@ def train_baseline(opt_level="O2"): accelerator = Accelerator() device = accelerator.device model, optimizer = msamp.initialize( - model, optimizer, - opt_level=opt_level, - weight_qtype=Dtypes.kfloat8_e4m3, + model, optimizer, + opt_level=opt_level, + weight_qtype=Dtypes.kfloat8_e4m3, use_fsdp=True ) - - model.to(device) model = FP8FullyShardedDataParallel( model, use_orig_params=True, auto_wrap_policy=FSDP_WRAP_POLICY, + cpu_offload=False, + sync_module_states=False, + backward_prefetch=None, + forward_prefetch=False, + limit_all_gathers=True, + device_id=device ) optimizer = FSDPAdamW(optimizer) @@ -67,19 +71,22 @@ def train_baseline(opt_level="O2"): with torch.autocast(device_type="cuda", dtype=torch.bfloat16): outputs = model(**batch) loss = outputs.loss - loss.backward() - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + loss.backward() + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = ( + accelerator.free_memory(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler) + ) assert ( trained_model_results["accuracy"] > base_model_results["accuracy"] - ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + ), f'Baseline: Opt level {opt_level}: Accuracy should be higher for the trained model: {base_model_results["accuracy"]} < {trained_model_results["accuracy"]}' assert ( 
trained_model_results["f1"] > base_model_results["f1"] - ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + ), f'Baseline: Opt level {opt_level}: F1 score should be higher for the trained model: {base_model_results["f1"]} < {trained_model_results["f1"]}' return base_model_results, trained_model_results @@ -98,7 +105,7 @@ def train_integration(opt_level="O2"): ) model, optimizer = accelerator.prepare(model, optimizer) - # base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() for i, batch in enumerate(train_dataloader): with torch.autocast(device_type="cuda", dtype=torch.bfloat16): @@ -111,20 +118,26 @@ def train_integration(opt_level="O2"): trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = ( + accelerator.free_memory(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler) + ) assert ( trained_model_results["accuracy"] > base_model_results["accuracy"] - ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' + ), f'Integration: Opt level {opt_level}: Accuracy should be higher for the trained model: {base_model_results["accuracy"]} < {trained_model_results["accuracy"]}' assert ( trained_model_results["f1"] > base_model_results["f1"] - ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + ), f'Integration: Opt level {opt_level}: F1 score should be higher for the trained model: {base_model_results["f1"]} < {trained_model_results["f1"]}' return base_model_results, trained_model_results if __name__ == "__main__": - # for opt_level in ["O1", "O2"]: - # baseline_not_trained, baseline_trained = 
train_baseline(opt_level) - accelerator_not_trained, accelerator_trained = train_integration("O2") + # baseline_not_trained, baseline_trained = train_baseline("O1") + # accelerator_not_trained, accelerator_trained = train_integration("O1") + # print(baseline_trained) + for opt_level in ["O1", "O2"]: + # baseline_not_trained, baseline_trained = train_baseline(opt_level) + accelerator_not_trained, accelerator_trained = train_integration(opt_level) # assert ( # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] # ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 0aa69ee7857..fc14ee792b1 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -508,7 +508,7 @@ def __init__( # We always enable `native_amp` for FP8 self.native_amp = True # MS-AMP requires grad scaler however - if self.fp8_backend == "MSAMP": + if self.fp8_backend == "MSAMP" and self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED): self.scaler = torch.cuda.amp.GradScaler() # Start of internal step tracking @@ -1317,11 +1317,10 @@ def prepare(self, *args, device_placement=None): else: if self.fp8_backend == "MSAMP": args, device_placement = self._prepare_msamp(*args, device_placement=device_placement) - result = args - # result = tuple( - # self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) - # ) - # result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement)) + result = tuple( + self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement) + ) + result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement)) if tpu_should_fix_optimizer: # 2. 
grabbing new model parameters @@ -1333,15 +1332,14 @@ def prepare(self, *args, device_placement=None): if isinstance(obj, torch.optim.Optimizer): obj._switch_parameters(mapping) - # if self.distributed_type == DistributedType.FSDP and self.fp8_backend == "MSAMP": - # # We need to convert the underlying optimizer to FSDPAdamW *after* FSDP wrapping - # result = list(result) - # from msamp.optim import FSDPAdamW - # for i, obj in enumerate(result): - # if isinstance(obj, AcceleratedOptimizer): - # result[i].optimizer = FSDPAdamW(obj.optimizer) - # print(f'Wrapping optimizer in FSDP: {type(obj.optimizer)}') - # result = tuple(result) + if self.distributed_type == DistributedType.FSDP and self.fp8_backend == "MSAMP": + # We need to convert the underlying optimizer to FSDPAdamW *after* FSDP wrapping + result = list(result) + from msamp.optim import FSDPAdamW + for i, obj in enumerate(result): + if isinstance(obj, AcceleratedOptimizer): + result[i].optimizer = FSDPAdamW(optimizer=obj.optimizer) + result = tuple(result) for item in result: if any( @@ -1379,7 +1377,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e """ if device_placement is None: device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP - print(f'Model device placement: {device_placement}') self._models.append(model) # TODO: Look at enabling native TP training directly with a proper config @@ -1442,7 +1439,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload." 
) elif device_placement and not self.verify_device_map(model): - print(f"Moving model to device: {self.device}") model = model.to(self.device) if not evaluation_mode: if self.distributed_type in ( @@ -1467,7 +1463,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e elif self.distributed_type == DistributedType.FSDP: # We need to fix the optimizer *before* sharding the model if self.mixed_precision == "fp8" and self.fp8_backend == "MSAMP": - print('Importing `FP8FullyShardedDataParallel` from `msamp.fsdp`') # MS-AMP uses a patched version of FSDP from msamp.fsdp import FP8FullyShardedDataParallel as FSDP else: @@ -1480,7 +1475,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e is_type_fsdp = isinstance(model, FSDP) or ( is_compiled_module(model) and isinstance(model._orig_mod, FSDP) ) - print(f'Is type FSDP: {is_type_fsdp}') if not is_type_fsdp: self.state.fsdp_plugin.set_auto_wrap_policy(model) @@ -1500,7 +1494,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e "device_id": self.device, } model = FSDP(model, **kwargs) - print(f'Wrapped model in FSDP: {type(model)}') if fsdp_plugin.activation_checkpointing: from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, @@ -1527,7 +1520,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e # * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype # * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None, # we sill want to upcast the flat_param. 
- if self.mixed_precision != "no" and self.fp8_backend != "MSAMP": # if mixed precision is set + if self.mixed_precision != "no": # if mixed precision is set upcasted_log = [] for module in FSDP.fsdp_modules(model): # Referencing DeepSpeed Zero3 @@ -2041,26 +2034,15 @@ def _prepare_msamp(self, *args, device_placement): if self.distributed_type == DistributedType.FSDP: # We need to set the auto_wrap policy before initializing the model self.state.fsdp_plugin.set_auto_wrap_policy(model) + # NOTE: MS-AMP fsdp relies on it's own MP policy, we must drop the users + self.state.fsdp_plugin.mixed_precision_policy = None from msamp.common.dtype import Dtypes model, optimizer = msamp.initialize( - model, optimizer, + model, optimizer, opt_level=self.fp8_recipe_handler.opt_level, use_fsdp=self.distributed_type == DistributedType.FSDP, weight_qtype=Dtypes.kfloat8_e4m3, ) - # now we can prepare the model? - # model = self.prepare_model(model, device_placement=True) - from msamp.fsdp import FP8FullyShardedDataParallel as FSDP - - model = FSDP( - model, - use_orig_params=True, - auto_wrap_policy=self.state.fsdp_plugin.auto_wrap_policy, - device_id=self.device, - ) - print(f'Prepared: {type(model)}') - from msamp.optim import FSDPAdamW - optimizer = FSDPAdamW(optimizer) for i in range(len(result)): if isinstance(result[i], torch.nn.Module): From 019d14da1268f33280c176df7bdc746f4f0c99da Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 16:31:33 -0400 Subject: [PATCH 09/14] Bookmark --- benchmarks/fp8/ms_amp/distrib_deepspeed.py | 88 ++++++++++------------ benchmarks/fp8/ms_amp/fsdp.py | 26 +++---- 2 files changed, 53 insertions(+), 61 deletions(-) diff --git a/benchmarks/fp8/ms_amp/distrib_deepspeed.py b/benchmarks/fp8/ms_amp/distrib_deepspeed.py index 291d09ec103..ad534f80ac5 100644 --- a/benchmarks/fp8/ms_amp/distrib_deepspeed.py +++ b/benchmarks/fp8/ms_amp/distrib_deepspeed.py @@ -13,31 
+13,30 @@ # limitations under the License. """ -This script tests to ensure that `accelerate` performs at the same level as raw `TransformersEngine`. +This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. -This particular script verifies this for DDP training. +This particular script verifies this for DeepSpeed training. """ from unittest.mock import patch -import deepspeed +from msamp import deepspeed import evaluate import torch -import transformer_engine.common.recipe as te_recipe -import transformer_engine.pytorch as te -from fp8_utils import evaluate_model, get_named_parameters, get_training_utilities -from transformer_engine.common.recipe import DelayedScaling +# import transformer_engine.common.recipe as te_recipe +# import transformer_engine.pytorch as te +from fp8_utils import evaluate_model, get_training_utilities +# from transformer_engine.common.recipe import DelayedScaling from accelerate import Accelerator, DeepSpeedPlugin from accelerate.state import AcceleratorState from accelerate.utils import FP8RecipeKwargs, set_seed -from accelerate.utils.transformer_engine import convert_model MODEL_NAME = "bert-base-cased" METRIC = evaluate.load("glue", "mrpc") -def train_baseline(zero_stage: int = 1): +def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): # This forces transformers to think Zero-3 Init should be used with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: mock.return_value = zero_stage == 3 @@ -48,20 +47,6 @@ def train_baseline(zero_stage: int = 1): MODEL_NAME, accelerator=accelerator ) - # Convert the model to TE - old_named_params = get_named_parameters(model) - - with torch.no_grad(): - convert_model(model) - new_named_params = get_named_parameters(model) - - mapping = {p: new_named_params[n] for n, p in old_named_params.items()} - for param_group in optimizer.param_groups: - param_group["params"] = [mapping[p] for p in param_group["params"]] - - FP8_RECIPE_KWARGS 
= {"fp8_format": te_recipe.Format.HYBRID, "amax_history_len": 32, "amax_compute_algo": "max"} - fp8_recipe = DelayedScaling(**FP8_RECIPE_KWARGS) - import numpy as np config = { @@ -79,6 +64,10 @@ def train_baseline(zero_stage: int = 1): "bf16": {"enabled": True}, "fp16": {"enabled": False}, "zero_allow_untested_optimizer": True, + "msamp": { + "enabled": True, + "opt_level": opt_level, + } } ( @@ -95,15 +84,9 @@ def train_baseline(zero_stage: int = 1): base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - model_outputs = [] - data = [] - for _ in range(2): for batch in train_dataloader: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs = model(**batch) - data.append(batch.to("cpu")) - model_outputs.append(outputs.logits.to("cpu")) + outputs = model(**batch) loss = outputs.loss model.backward(loss) model.step() @@ -112,6 +95,8 @@ def train_baseline(zero_stage: int = 1): trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.destroy() + torch.cuda.empty_cache() + AcceleratorState()._reset_state(True) assert ( trained_model_results["accuracy"] > base_model_results["accuracy"] ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' @@ -119,7 +104,7 @@ def train_baseline(zero_stage: int = 1): trained_model_results["f1"] > base_model_results["f1"] ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' - return base_model_results, trained_model_results, model_outputs, data + return base_model_results, trained_model_results def train_integration(zero_stage: int = 1): @@ -158,6 +143,7 @@ def train_integration(zero_stage: int = 1): trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.destroy() + torch.cuda.empty_cache() assert ( trained_model_results["accuracy"] > 
base_model_results["accuracy"] ), f'Accuracy should be higher for the trained model: {trained_model_results["accuracy"]} > {base_model_results["accuracy"]}' @@ -165,25 +151,31 @@ def train_integration(zero_stage: int = 1): trained_model_results["f1"] > base_model_results["f1"] ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' - return base_model_results, trained_model_results, model_outputs, data + return base_model_results, trained_model_results if __name__ == "__main__": - # for zero_stage in [1, 2, 3]: - zero_stage = 1 - baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage) - accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - assert ( - baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - assert ( - baseline_not_trained["f1"] == accelerator_not_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - assert ( - baseline_trained["accuracy"] == accelerator_trained["accuracy"] - ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - assert ( - baseline_trained["f1"] == accelerator_trained["f1"] - ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + results = {"1": [], "2": [], "3": []} + for zero_stage in [1, 2, 3]: + for opt_level in ["O1", "O2", "O3"]: + baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level) + results[zero_stage].append({"opt_level": 
opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained}) + for stage, stage_results in results.items(): + print(f'zero_stage={stage}:\n') + for result in stage_results: + print(f'opt_level={result["opt_level"]}:\nBaseline not trained: {result["not_trained"]}\nBaseline trained: {result["trained"]}\n') + # accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) + # assert ( + # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + # assert ( + # baseline_not_trained["f1"] == accelerator_not_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + # assert ( + # baseline_trained["accuracy"] == accelerator_trained["accuracy"] + # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + # assert ( + # baseline_trained["f1"] == accelerator_trained["f1"] + # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py index 843ce30968d..5d8a088bfa2 100644 --- a/benchmarks/fp8/ms_amp/fsdp.py +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -136,17 +136,17 @@ def train_integration(opt_level="O2"): # accelerator_not_trained, accelerator_trained = train_integration("O1") # print(baseline_trained) for opt_level in ["O1", "O2"]: - # baseline_not_trained, baseline_trained = train_baseline(opt_level) + baseline_not_trained, baseline_trained = train_baseline(opt_level) accelerator_not_trained, 
accelerator_trained = train_integration(opt_level) - # assert ( - # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - # ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - # assert ( - # baseline_not_trained["f1"] == accelerator_not_trained["f1"] - # ), f'F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - # assert ( - # baseline_trained["accuracy"] == accelerator_trained["accuracy"] - # ), f'Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - # assert ( - # baseline_trained["f1"] == accelerator_trained["f1"] - # ), f'F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'Accuracy not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'F1 not the same for untrained baseline and accelerator using opt_level={opt_level}: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'Accuracy not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'F1 not the same for trained baseline and accelerator using opt_level={opt_level}: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' From 
e7979c30c3bb65e928da0199b97b8bc200b527b8 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 15 Aug 2024 16:35:07 -0400 Subject: [PATCH 10/14] Continue --- benchmarks/fp8/ms_amp/distrib_deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/fp8/ms_amp/distrib_deepspeed.py b/benchmarks/fp8/ms_amp/distrib_deepspeed.py index ad534f80ac5..2a743900842 100644 --- a/benchmarks/fp8/ms_amp/distrib_deepspeed.py +++ b/benchmarks/fp8/ms_amp/distrib_deepspeed.py @@ -159,7 +159,7 @@ def train_integration(zero_stage: int = 1): for zero_stage in [1, 2, 3]: for opt_level in ["O1", "O2", "O3"]: baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level) - results[zero_stage].append({"opt_level": opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained}) + results[str(zero_stage)].append({"opt_level": opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained}) for stage, stage_results in results.items(): print(f'zero_stage={stage}:\n') for result in stage_results: From 215693525b18f05a1600dd72c7f6eb9b5ead3d3c Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 16 Aug 2024 12:03:43 -0400 Subject: [PATCH 11/14] checkpoint --- benchmarks/fp8/ms_amp/distrib_deepspeed.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/benchmarks/fp8/ms_amp/distrib_deepspeed.py b/benchmarks/fp8/ms_amp/distrib_deepspeed.py index 2a743900842..bdaeaa9367e 100644 --- a/benchmarks/fp8/ms_amp/distrib_deepspeed.py +++ b/benchmarks/fp8/ms_amp/distrib_deepspeed.py @@ -155,15 +155,16 @@ def train_integration(zero_stage: int = 1): if __name__ == "__main__": - results = {"1": [], "2": [], "3": []} - for zero_stage in [1, 2, 3]: - for opt_level in ["O1", "O2", "O3"]: - baseline_not_trained, baseline_trained = train_baseline(zero_stage, 
opt_level) - results[str(zero_stage)].append({"opt_level": opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained}) - for stage, stage_results in results.items(): - print(f'zero_stage={stage}:\n') - for result in stage_results: - print(f'opt_level={result["opt_level"]}:\nBaseline not trained: {result["not_trained"]}\nBaseline trained: {result["trained"]}\n') + # results = {"1": [], "2": [], "3": []} + # for zero_stage in [1, 2, 3]: + # for opt_level in ["O1", "O2", "O3"]: + baseline_not_trained, baseline_trained = train_baseline(3, "O3") + print(baseline_not_trained, baseline_trained) + # results[str(zero_stage)].append({"opt_level": opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained}) + # for stage, stage_results in results.items(): + # print(f'zero_stage={stage}:\n') + # for result in stage_results: + # print(f'opt_level={result["opt_level"]}:\nBaseline not trained: {result["not_trained"]}\nBaseline trained: {result["trained"]}\n') # accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) # assert ( # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] From 4a37e5f05e90fe28f53c0ae03fc4cd03ad4a8d0e Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 16 Aug 2024 14:58:13 -0400 Subject: [PATCH 12/14] Clean and fin --- benchmarks/fp8/ms_amp/Dockerfile | 8 +-- benchmarks/fp8/ms_amp/distrib_deepspeed.py | 77 ++++++++-------------- benchmarks/fp8/ms_amp/fsdp.py | 29 ++++---- src/accelerate/accelerator.py | 29 ++++++-- src/accelerate/utils/dataclasses.py | 28 ++++++++ 5 files changed, 95 insertions(+), 76 deletions(-) diff --git a/benchmarks/fp8/ms_amp/Dockerfile b/benchmarks/fp8/ms_amp/Dockerfile index d2d1c130e12..dd4d7c25297 100644 --- a/benchmarks/fp8/ms_amp/Dockerfile +++ b/benchmarks/fp8/ms_amp/Dockerfile @@ -1,11 +1,11 @@ FROM ghcr.io/azure/msamp RUN pip install 
transformers evaluate datasets -# RUN git clone https://github.com/huggingface/accelerate +RUN git clone https://github.com/huggingface/accelerate -# RUN cd accelerate && \ -# pip install -e . && \ -# cd benchmarks/fp8 +RUN cd accelerate && \ + pip install -e . && \ + cd benchmarks/fp8 CMD ["bash"] diff --git a/benchmarks/fp8/ms_amp/distrib_deepspeed.py b/benchmarks/fp8/ms_amp/distrib_deepspeed.py index bdaeaa9367e..57a2569729f 100644 --- a/benchmarks/fp8/ms_amp/distrib_deepspeed.py +++ b/benchmarks/fp8/ms_amp/distrib_deepspeed.py @@ -16,20 +16,19 @@ This script tests to ensure that `accelerate` performs at the same level as raw `MS-AMP`. This particular script verifies this for DeepSpeed training. + +NOTE: MS-AMP does *not* support ZeRO-3. """ -from unittest.mock import patch -from msamp import deepspeed +# import msamp.deepspeed as msamp_deepspeed import evaluate import torch -# import transformer_engine.common.recipe as te_recipe -# import transformer_engine.pytorch as te from fp8_utils import evaluate_model, get_training_utilities -# from transformer_engine.common.recipe import DelayedScaling +from msamp import deepspeed as msamp_deepspeed from accelerate import Accelerator, DeepSpeedPlugin from accelerate.state import AcceleratorState -from accelerate.utils import FP8RecipeKwargs, set_seed +from accelerate.utils import set_seed MODEL_NAME = "bert-base-cased" @@ -37,11 +36,7 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): - # This forces transformers to think Zero-3 Init should be used - with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock: - mock.return_value = zero_stage == 3 set_seed(42) - accelerator = Accelerator() model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( MODEL_NAME, accelerator=accelerator @@ -57,7 +52,6 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): "stage": zero_stage, "offload_optimizer": {"device": "none", "nvme_path": None}, 
"offload_param": {"device": "none", "nvme_path": None}, - "stage3_gather_16bit_weights_on_model_save": False, }, "gradient_clipping": 1.0, "steps_per_print": np.inf, @@ -67,15 +61,14 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): "msamp": { "enabled": True, "opt_level": opt_level, - } + }, } - ( model, optimizer, _, _, - ) = deepspeed.initialize( + ) = msamp_deepspeed.initialize( model=model, optimizer=optimizer, config_params=config, @@ -107,18 +100,14 @@ def train_baseline(zero_stage: int = 1, opt_level: str = "O1"): return base_model_results, trained_model_results -def train_integration(zero_stage: int = 1): +def train_integration(zero_stage: int = 1, opt_level: str = "O1"): set_seed(42) - FP8_RECIPE_KWARGS = {"fp8_format": "HYBRID", "amax_history_len": 32, "amax_compute_algo": "max"} - kwargs_handlers = [FP8RecipeKwargs(backend="TE", **FP8_RECIPE_KWARGS)] - AcceleratorState()._reset_state(True) deepspeed_plugin = DeepSpeedPlugin( zero_stage=zero_stage, - zero3_init_flag=zero_stage == 3, - ) - accelerator = Accelerator( - mixed_precision="fp8", kwargs_handlers=kwargs_handlers, deepspeed_plugin=deepspeed_plugin + enable_msamp=True, + msamp_opt_level=opt_level, ) + accelerator = Accelerator(mixed_precision="fp8", deepspeed_plugin=deepspeed_plugin) accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16 model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities( @@ -128,13 +117,9 @@ def train_integration(zero_stage: int = 1): model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) model.train() - model_outputs = [] - data = [] for _ in range(2): for batch in train_dataloader: outputs = model(**batch) - data.append(batch.to("cpu")) - model_outputs.append(outputs.logits.to("cpu")) loss = outputs.loss accelerator.backward(loss) optimizer.step() @@ -151,32 +136,26 
@@ def train_integration(zero_stage: int = 1): trained_model_results["f1"] > base_model_results["f1"] ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}' + AcceleratorState()._reset_state(True) return base_model_results, trained_model_results if __name__ == "__main__": - # results = {"1": [], "2": [], "3": []} - # for zero_stage in [1, 2, 3]: - # for opt_level in ["O1", "O2", "O3"]: - baseline_not_trained, baseline_trained = train_baseline(3, "O3") - print(baseline_not_trained, baseline_trained) - # results[str(zero_stage)].append({"opt_level": opt_level, "not_trained": baseline_not_trained, "trained": baseline_trained}) - # for stage, stage_results in results.items(): - # print(f'zero_stage={stage}:\n') - # for result in stage_results: - # print(f'opt_level={result["opt_level"]}:\nBaseline not trained: {result["not_trained"]}\nBaseline trained: {result["trained"]}\n') - # accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage) - # assert ( - # baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] - # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' - # assert ( - # baseline_not_trained["f1"] == accelerator_not_trained["f1"] - # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' - # assert ( - # baseline_trained["accuracy"] == accelerator_trained["accuracy"] - # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' - # assert ( - # baseline_trained["f1"] == accelerator_trained["f1"] - # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} 
== {accelerator_trained["f1"]}' + for zero_stage in [1, 2]: + for opt_level in ["O1", "O2", "O3"]: + baseline_not_trained, baseline_trained = train_baseline(zero_stage, opt_level) + accelerator_not_trained, accelerator_trained = train_integration(zero_stage, opt_level) + assert ( + baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"] + ), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}' + assert ( + baseline_not_trained["f1"] == accelerator_not_trained["f1"] + ), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}' + assert ( + baseline_trained["accuracy"] == accelerator_trained["accuracy"] + ), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nAccuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}' + assert ( + baseline_trained["f1"] == accelerator_trained["f1"] + ), f'ZERO stage {zero_stage}, opt_level={opt_level}:\nF1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}' torch.distributed.destroy_process_group() diff --git a/benchmarks/fp8/ms_amp/fsdp.py b/benchmarks/fp8/ms_amp/fsdp.py index 5d8a088bfa2..396cb6b2981 100644 --- a/benchmarks/fp8/ms_amp/fsdp.py +++ b/benchmarks/fp8/ms_amp/fsdp.py @@ -17,21 +17,21 @@ This particular script verifies this for FSDP training. 
""" +from functools import partial + import evaluate import msamp import torch -from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin +from fp8_utils import evaluate_model, get_training_utilities +from msamp.common.dtype import Dtypes from msamp.fsdp import FP8FullyShardedDataParallel from msamp.optim import FSDPAdamW -from msamp.common.dtype import Dtypes -from fp8_utils import evaluate_model, get_training_utilities +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformers.models.bert import BertLayer from accelerate import Accelerator -from accelerate.state import AcceleratorState +from accelerate import FullyShardedDataParallelPlugin as FSDPPlugin from accelerate.utils import FP8RecipeKwargs, set_seed -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from transformers.models.bert import BertLayer -from functools import partial MODEL_NAME = "bert-base-cased" @@ -45,10 +45,7 @@ def train_baseline(opt_level="O2"): accelerator = Accelerator() device = accelerator.device model, optimizer = msamp.initialize( - model, optimizer, - opt_level=opt_level, - weight_qtype=Dtypes.kfloat8_e4m3, - use_fsdp=True + model, optimizer, opt_level=opt_level, weight_qtype=Dtypes.kfloat8_e4m3, use_fsdp=True ) model = FP8FullyShardedDataParallel( @@ -60,7 +57,7 @@ def train_baseline(opt_level="O2"): backward_prefetch=None, forward_prefetch=False, limit_all_gathers=True, - device_id=device + device_id=device, ) optimizer = FSDPAdamW(optimizer) @@ -78,8 +75,8 @@ def train_baseline(opt_level="O2"): trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = ( - accelerator.free_memory(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.free_memory( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler ) assert ( 
trained_model_results["accuracy"] > base_model_results["accuracy"] @@ -118,8 +115,8 @@ def train_integration(opt_level="O2"): trained_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator) - model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = ( - accelerator.free_memory(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler) + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.free_memory( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler ) assert ( trained_model_results["accuracy"] > base_model_results["accuracy"] diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index fc14ee792b1..0abceb642b0 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -104,7 +104,7 @@ save_fsdp_optimizer, wait_for_everyone, ) -from .utils.constants import FSDP_PYTORCH_VERSION, PROFILE_PATTERN_NAME +from .utils.constants import PROFILE_PATTERN_NAME from .utils.modeling import get_state_dict_offloaded_model from .utils.other import is_compiled_module @@ -310,8 +310,8 @@ def __init__( # if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( # fsdp_plugin, FullyShardedDataParallelPlugin # ): - # if is_torch_version("<", FSDP_PYTORCH_VERSION): - # raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") + # if is_torch_version("<", FSDP_PYTORCH_VERSION): + # raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") if fsdp_plugin is None: # init from env variables fsdp_plugin = ( @@ -507,8 +507,11 @@ def __init__( elif self.state.mixed_precision == "fp8": # We always enable `native_amp` for FP8 self.native_amp = True - # MS-AMP requires grad scaler however - if self.fp8_backend == "MSAMP" and self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED): + # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP: + if self.fp8_backend == "MSAMP" and 
self.distributed_type not in ( + DistributedType.FSDP, + DistributedType.DEEPSPEED, + ): self.scaler = torch.cuda.amp.GradScaler() # Start of internal step tracking @@ -1336,6 +1339,7 @@ def prepare(self, *args, device_placement=None): # We need to convert the underlying optimizer to FSDPAdamW *after* FSDP wrapping result = list(result) from msamp.optim import FSDPAdamW + for i, obj in enumerate(result): if isinstance(obj, AcceleratedOptimizer): result[i].optimizer = FSDPAdamW(optimizer=obj.optimizer) @@ -1636,6 +1640,13 @@ def _prepare_te(self, *args): def _prepare_deepspeed(self, *args): import deepspeed + ds_initialize = deepspeed.initialize + if self.fp8_backend == "MSAMP": + # MS-AMP requires DeepSpeed patches + from msamp import deepspeed as msamp_deepspeed + + ds_initialize = msamp_deepspeed.initialize + deepspeed_plugin = self.state.deepspeed_plugin is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) @@ -1824,7 +1835,7 @@ def _prepare_deepspeed(self, *args): if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES: kwargs["lr_scheduler"] = scheduler - engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) + engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs) if optimizer is not None: optimizer = DeepSpeedOptimizerWrapper(optimizer) if scheduler is not None: @@ -2037,8 +2048,10 @@ def _prepare_msamp(self, *args, device_placement): # NOTE: MS-AMP fsdp relies on it's own MP policy, we must drop the users self.state.fsdp_plugin.mixed_precision_policy = None from msamp.common.dtype import Dtypes + model, optimizer = msamp.initialize( - model, optimizer, + model, + optimizer, opt_level=self.fp8_recipe_handler.opt_level, use_fsdp=self.distributed_type == DistributedType.FSDP, weight_qtype=Dtypes.kfloat8_e4m3, @@ -3598,4 +3611,6 @@ def fp8_backend(self): "Returns the configured backend for training in FP8" if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None: 
return self.fp8_recipe_handler.backend + elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp: + return "MSAMP" return None diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 919b7fadc2b..53d55ab4e70 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -972,6 +972,16 @@ class DeepSpeedPlugin: " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..." }, ) + enable_msamp: bool = field( + default=None, + metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."}, + ) + msamp_opt_level: str = field( + default=None, + metadata={ + "help": "Optimization level for MS-AMP. Only applicable if `enable_msamp` is True. Should be one of ['O1', 'O2', 'O3']." + }, + ) def __post_init__(self): from .deepspeed import HfDeepSpeedConfig @@ -1006,6 +1016,12 @@ def __post_init__(self): os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true" ) + if self.enable_msamp is None: + self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP" + + if self.msamp_opt_level is None: + self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1") + if self.hf_ds_config is None: self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") if ( @@ -1075,6 +1091,14 @@ def __post_init__(self): if self.zero3_init_flag and not self.hf_ds_config.is_zero3(): warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.") self.zero3_init_flag = False + if self.enable_msamp: + if self.zero_stage == 3: + raise NotImplementedError( + "MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead." + ) + if self.msamp_opt_level not in ["O1", "O2", "O3"]: + raise ValueError("Invalid optimization level for MS-AMP. 
Please use one of ['O1', 'O2', 'O3'].") + self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level} def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs): mismatches = [] if mismatches is None else mismatches @@ -1144,6 +1168,10 @@ def set_mixed_precision(self, mixed_precision): if "bf16" not in ds_config: ds_config["bf16"] = {"enabled": True} + if mixed_precision == "fp8" and self.enable_msamp: + if "msamp" not in ds_config: + ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level} + if mixed_precision != "no": diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16" if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true": From 854e44a05c741b9bd60409e6a9712d57251fe691 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 16 Aug 2024 15:38:36 -0400 Subject: [PATCH 13/14] Add msamp init args to deal with weight qtype defaults --- src/accelerate/accelerator.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 0abceb642b0..4659a4c7452 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -2048,14 +2048,11 @@ def _prepare_msamp(self, *args, device_placement): # NOTE: MS-AMP fsdp relies on it's own MP policy, we must drop the users self.state.fsdp_plugin.mixed_precision_policy = None from msamp.common.dtype import Dtypes - - model, optimizer = msamp.initialize( - model, - optimizer, - opt_level=self.fp8_recipe_handler.opt_level, - use_fsdp=self.distributed_type == DistributedType.FSDP, - weight_qtype=Dtypes.kfloat8_e4m3, - ) + msamp_init_args = dict(model=model, optimizer=optimizer, opt_level=self.fp8_recipe_handler.opt_level) + if self.distributed_type == DistributedType.FSDP: + msamp_init_args["use_fsdp"] = True + msamp_init_args["weight_qtype"] = Dtypes.kfloat8_e4m3 + model, 
optimizer = msamp.initialize(**msamp_init_args) for i in range(len(result)): if isinstance(result[i], torch.nn.Module): From e5f6d570d253373e049cf71c33c6b20f66808aef Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Fri, 16 Aug 2024 18:12:33 -0400 Subject: [PATCH 14/14] Format --- src/accelerate/accelerator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 4659a4c7452..b8d9f1f1b07 100755 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -2048,6 +2048,7 @@ def _prepare_msamp(self, *args, device_placement): # NOTE: MS-AMP fsdp relies on it's own MP policy, we must drop the users self.state.fsdp_plugin.mixed_precision_policy = None from msamp.common.dtype import Dtypes + msamp_init_args = dict(model=model, optimizer=optimizer, opt_level=self.fp8_recipe_handler.opt_level) if self.distributed_type == DistributedType.FSDP: msamp_init_args["use_fsdp"] = True