diff --git a/.github/workflows/unit-tests-recipes.yml b/.github/workflows/unit-tests-recipes.yml index 4385f40feb..70203f3b8a 100644 --- a/.github/workflows/unit-tests-recipes.yml +++ b/.github/workflows/unit-tests-recipes.yml @@ -155,6 +155,9 @@ jobs: with: sparse-checkout: "${{ matrix.recipe.dir }}" sparse-checkout-cone-mode: false + - name: Include symlink targets for esm2_peft_te + if: ${{ matrix.recipe.dir == 'bionemo-recipes/recipes/esm2_peft_te' }} + run: git -c safe.directory=/__w/bionemo-framework/bionemo-framework sparse-checkout add bionemo-recipes/recipes/esm2_native_te - name: Cache Hugging Face models uses: actions/cache@v4 diff --git a/bionemo-recipes/models/esm2/pyproject.toml b/bionemo-recipes/models/esm2/pyproject.toml index ac8430b0d9..f68a7c8906 100644 --- a/bionemo-recipes/models/esm2/pyproject.toml +++ b/bionemo-recipes/models/esm2/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "jinja2", "megatron-fsdp", "omegaconf", - "peft", + "peft @ git+https://github.com/balvisio/peft.git@support-te-lora", "pytest", "torch", "torchao!=0.14.0", diff --git a/bionemo-recipes/models/esm2/tests/test_peft.py b/bionemo-recipes/models/esm2/tests/test_peft.py index b991a3dc4c..8caadc04aa 100644 --- a/bionemo-recipes/models/esm2/tests/test_peft.py +++ b/bionemo-recipes/models/esm2/tests/test_peft.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + import peft -import pytest import torch from esm.modeling_esm_te import NVEsmForMaskedLM @@ -58,7 +59,6 @@ def test_lora_model_forward_pass(te_model_checkpoint, input_data): assert outputs.loss is not None -@pytest.mark.xfail(reason="BIONEMO-3136: LoRA model initializes with warnings because of TE layers.") def test_lora_model_raises_no_warnings(te_model_checkpoint): model = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) @@ -71,13 +71,14 @@ def test_lora_model_raises_no_warnings(te_model_checkpoint): bias="none", ) - with pytest.warns(UserWarning) as record: + with warnings.catch_warnings(record=True) as record: + # Cause all warnings to be triggered (default behavior may ignore some) + warnings.simplefilter("always") peft.get_peft_model(model, peft_config) assert len(record) == 0 -@pytest.mark.xfail(reason="BIONEMO-3136: LoRA model initialization fails with target_modules because of TE layers.") def test_lora_model_with_target_modules(te_model_checkpoint): model = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) diff --git a/bionemo-recipes/recipes/esm2_native_te/checkpoint.py b/bionemo-recipes/recipes/esm2_native_te/checkpoint.py index a168eeeb89..ce909aaaec 100644 --- a/bionemo-recipes/recipes/esm2_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/esm2_native_te/checkpoint.py @@ -194,6 +194,15 @@ def save_final_model_ddp( underlying_model: transformers.PreTrainedModel = model.module if hasattr(model, "module") else model # type: ignore os.makedirs(save_directory, exist_ok=True) + # If we are saving a PEFT model we also save the base_model config. + # This allows for an streamlined reload of the PEFT model without having to manually reconstruct the config of + # the base_model. 
+ # For example: + # >>> config = AutoConfig.from_pretrained() + # >>> base_model = AutoModelForTokenClassification.from_pretrained(, config=config) + # >>> peft_model = PeftModel.from_pretrained(base_model, ) + if hasattr(underlying_model, "peft_config"): + underlying_model.config.save_pretrained(save_directory) underlying_model.save_pretrained(save_directory, state_dict=underlying_model.state_dict(), safe_serialization=True) logger.info(f"Saved final DDP model to {save_directory}") diff --git a/bionemo-recipes/recipes/esm2_peft_te/Dockerfile b/bionemo-recipes/recipes/esm2_peft_te/Dockerfile new file mode 100644 index 0000000000..03eadca61e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/Dockerfile @@ -0,0 +1,9 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:26.01-py3 + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +WORKDIR /workspace/bionemo +COPY . . diff --git a/bionemo-recipes/recipes/esm2_peft_te/README.md b/bionemo-recipes/recipes/esm2_peft_te/README.md index 60435aff4a..1e39450bcc 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/README.md +++ b/bionemo-recipes/recipes/esm2_peft_te/README.md @@ -2,6 +2,127 @@ This folder demonstrates how to fine-tune a TransformerEngine-accelerated ESM-2 model using PEFT. -Note: This recipe is a work in progress, and currently only demonstrates basic support for LoRA fine-tuning and -TransformerEngine layers. Refer to `bionemo-recipes/models/esm2/tests/test_peft.py` for additional information and known -limitations. +## Prerequisite: Download Porter6 datasets + +To download the curated Porter6 datasets used by this recipe, run: + +``` +python data/prepare_porter6_dataset.py +``` + +This script downloads and prepares Parquet files under `data/`: + +- `data/porter6_train_dataset_55k.parquet`: training dataset used for LoRA fine-tuning examples. +- `data/porter6_val_dataset_2024_692.parquet`: validation/benchmark split used for evaluation. + +These files are used by the default Hydra configs in this recipe. For dataset provenance and additional options, see +the [Datasets](#datasets) section below. + +## Commands to Launch LoRA Fine-tuning + +To run single-process training on one GPU, run: + +```bash +python train_lora_ddp.py +``` + +To run multi-process training locally on 2+ GPUs, run (e.g. 2 GPUs): + +```bash +torchrun --nproc_per_node=2 train_lora_ddp.py +``` + +## Sequence Packing (THD input format) + +Sequence packing is handled via the [`DataCollatorWithFlattening`](https://huggingface.co/docs/transformers/v4.47.1/main_classes/data_collator#transformers.DataCollatorWithFlattening) collator from the HuggingFace transformers library that provides input arguments (e.g. +`cu_seq_lens_q`) needed for padding-free attention. To enable sequence packing, set `use_sequence_packing=true` +in the hydra configuration. + +```bash +python train_lora_ddp.py --config-name L0_sanity use_sequence_packing=true +``` + +## Running Inference + +Use `infer.py` for inference. By default it uses `hydra_config/L0_sanity_infer.yaml` and reads sequences from +`data/input_infer.fasta` (see `hydra_config/defaults_infer.yaml`). + +Inference requires a LoRA fine-tuned checkpoint directory from training. A typical workflow is: + +1. Pick a training config (for example `hydra_config/L0_sanity.yaml`) and set `checkpoint.ckpt_dir` (for example, + `nv_esm2_t6_8M_UR50D_peft_checkpoint`. 
The final model will be saved in `nv_esm2_t6_8M_UR50D_peft_checkpoint/train_ddp/final_model`). +2. Run training: + `python train_lora_ddp.py --config-dir hydra_config --config-name L0_sanity` +3. In your inference config (for example `hydra_config/L0_sanity_infer.yaml`), set `base_model_config_dir` to the same + `/train_ddp/final_model` from step 1. +4. Run inference: + +```bash +python infer.py +``` + +You can override the most common settings from the command line: + +- **`input_file`**: FASTA input (default: `data/input_infer.fasta`) +- **`output_file`**: Where to write predictions (CSV). If `null`, results print to stdout (default: `preds.csv`) +- **`model_tag`**: Base ESM-2 HF model to load (default: `nvidia/esm2_t6_8M_UR50D`) +- **`base_model_config_dir`**: Directory containing the fine-tuned model config +- **`peft_model_config_dir`**: Directory containing the LoRA adapter weights/config (defaults to `base_model_config_dir`) + +Examples: + +```bash +# Run on a different FASTA file and write a CSV +python infer.py input_file=/path/to/inputs.fasta output_file=preds.csv + +# Point to your own LoRA fine-tuned checkpoint directory +python infer.py base_model_config_dir=/path/to/my_peft_checkpoint peft_model_config_dir=/path/to/my_peft_checkpoint +``` + +## Datasets + +This recipe includes small and medium-sized datasets in `data/` so you can get started quickly without downloading +anything. + +- **Quick sanity dataset (used for CI and smoke tests)**: `data/peft_sanity_dataset.parquet` is a **5,000-sample subset** + of the Hugging Face dataset + [`lamm-mit/protein_secondary_structure_from_PDB`](https://huggingface.co/datasets/lamm-mit/protein_secondary_structure_from_PDB). + It is intended for fast local iteration and is also used by the recipe's CI tests. + +- **Porter6 paper datasets**: + + - `data/porter6_train_dataset_55k.parquet`: training set. + - `data/porter6_val_dataset_2024_692.parquet`: 2024 benchmark validation set. + + These originate from the Porter6 secondary-structure prediction work. Run + `python data/prepare_porter6_dataset.py` to download the source files from the + [Porter6 repository](https://github.com/WafaAlanazi/Porter6), verify checksums, and convert them to the Parquet files + above. For details on the dataset construction, see the + [Porter6 paper](https://pmc.ncbi.nlm.nih.gov/articles/PMC11719765/). + +### Installing Dependencies + +The easiest way to get started with this recipe is to use the provided Dockerfile, which uses the latest NVIDIA PyTorch +base image to provide optimized versions of PyTorch and TransformerEngine. To build the container, run: + +```bash +docker build -f Dockerfile -t esm2_peft_te . +``` + +To run the container, run: + +```bash +docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo esm2_peft_te /bin/bash +``` + +## Developer Guide + +### Running tests + +To run tests locally, run `recipes_local_test.py` from the repository root with the recipe directory as an argument. + +```bash +./ci/scripts/recipes_local_test.py bionemo-recipes/recipes/esm2_peft_te/ +``` + +For more information see [here](../esm2_native_te/README.md). diff --git a/bionemo-recipes/recipes/esm2_peft_te/checkpoint.py b/bionemo-recipes/recipes/esm2_peft_te/checkpoint.py new file mode 100644 index 0000000000..ce909aaaec --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/checkpoint.py @@ -0,0 +1,618 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import os +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import NamedTuple + +import torch +import transformers +from safetensors.torch import save_file +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + get_state_dict, + set_state_dict, +) +from torch.distributed.checkpoint.state_dict_loader import load as dcp_load +from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save +from torch.distributed.checkpoint.state_dict_saver import save as dcp_save +from torch.distributed.checkpoint.stateful import Stateful +from torchdata.stateful_dataloader import StatefulDataLoader + +from distributed_config import DistributedConfig + + +logger = logging.getLogger(__name__) +_ckpt_futures: dict = {} + + +class CheckpointOutput(NamedTuple): + """Output of checkpoint loading.""" + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: torch.optim.lr_scheduler.LRScheduler + dataloader: StatefulDataLoader | None + step: int + epoch: int + + +# ============================================================================ +# Helper functions +# ============================================================================ + + +def get_latest_checkpoint(ckpt_path: str | os.PathLike) -> tuple[Path | None, int]: + """Get the latest checkpoint path and step number. + + Returns: + Tuple of (checkpoint path, step number). + If no checkpoint files are found, returns (None, 0). 
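+
+    Example (illustrative; assumes a hypothetical "checkpoints" directory that already
+    contains step_100 and step_250 subdirectories):
+
+        >>> path, step = get_latest_checkpoint("checkpoints")
+        >>> path.name, step
+        ('step_250', 250)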
+ """ + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + return None, 0 + + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + + if not checkpoints: + return None, 0 + + latest = max(checkpoints, key=lambda x: int(Path(x).stem.split("_")[1])) + step = int(Path(latest).stem.split("_")[1]) + return latest, step + + +def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool: + """Determine if a checkpoint should be saved.""" + if save_every_n_steps > 0 and step % save_every_n_steps == 0 and step > 0: + return True + return False + + +def prune_checkpoints(ckpt_path: str | os.PathLike, max_checkpoints: int) -> None: + """Prune checkpoints to keep only the latest `max_checkpoints` checkpoints.""" + ckpt_path = Path(ckpt_path) + checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")] + checkpoints.sort(key=lambda x: int(Path(x).stem.split("_")[1])) + if len(checkpoints) > max_checkpoints: + for checkpoint in checkpoints[:-max_checkpoints]: + logger.info(f"Pruning checkpoint {checkpoint}") + if checkpoint.is_dir(): + shutil.rmtree(checkpoint) + else: + os.remove(checkpoint) + + +# ============================================================================ +# DDP Checkpointing +# ============================================================================ + + +def load_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, +) -> CheckpointOutput: + """Load DDP checkpoint.""" + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + + if not checkpoint_path: + logger.info("No DDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + checkpoint = torch.load( + checkpoint_path / "checkpoint.pt", + map_location=f"cuda:{dist_config.local_rank}", + weights_only=True, + ) + + model.load_state_dict(checkpoint["model"], strict=False) + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + step = checkpoint["step"] + epoch = checkpoint["epoch"] + + if dist_config.is_main_process(): + logger.info(f"Loaded DDP checkpoint from step {step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch) + + +def save_checkpoint_ddp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + max_checkpoints: int | None = None, +) -> None: + """Saves the Dataloader state and the DDP checkpoint.""" + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + # Dataloader checkpointing needs to happen on all ranks, while DDP model checkpointing only needs to happen on the + # main process. 
+ save_dataloader(dataloader, checkpoint_path, dist_config) + + if not dist_config.is_main_process(): + return + + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "epoch": epoch, + }, + checkpoint_path / "checkpoint.pt", + ) + + logger.info(f"Saved DDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_ddp( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for DDP - only on main process.""" + if not dist_config.is_main_process(): + return + + # Unwrap model if wrapped + underlying_model: transformers.PreTrainedModel = model.module if hasattr(model, "module") else model # type: ignore + + os.makedirs(save_directory, exist_ok=True) + # If we are saving a PEFT model we also save the base_model config. + # This allows for an streamlined reload of the PEFT model without having to manually reconstruct the config of + # the base_model. + # For example: + # >>> config = AutoConfig.from_pretrained() + # >>> base_model = AutoModelForTokenClassification.from_pretrained(, config=config) + # >>> peft_model = PeftModel.from_pretrained(base_model, ) + if hasattr(underlying_model, "peft_config"): + underlying_model.config.save_pretrained(save_directory) + underlying_model.save_pretrained(save_directory, state_dict=underlying_model.state_dict(), safe_serialization=True) + logger.info(f"Saved final DDP model to {save_directory}") + + +# ============================================================================ +# mFSDP Checkpointing +# ============================================================================ + + +def load_checkpoint_mfsdp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, +) -> CheckpointOutput: + """Load mFSDP distributed checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The directory containing checkpoints. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + + Returns: + Tuple of (model, optimizer, scheduler, step). 
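+
+    Resume sketch (illustrative; the return value is the CheckpointOutput named tuple defined
+    in this module, so fields can be accessed by name):
+
+        >>> ckpt = load_checkpoint_mfsdp(model, optimizer, scheduler, ckpt_path, dist_config, dataloader)
+        >>> start_step, epoch = ckpt.step, ckpt.epoch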
+ """ + checkpoint_path, step = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No mFSDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + ckpt_state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "metadata": { + "step": step, # Initialize with current step from filename + "epoch": 0, # Initialize with default epoch + }, + } + torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=checkpoint_path) + + model.load_state_dict(ckpt_state_dict["model"], strict=False) + optimizer.load_state_dict(ckpt_state_dict["optimizer"]) + scheduler.load_state_dict(ckpt_state_dict["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + + step = ckpt_state_dict["metadata"]["step"] + epoch = ckpt_state_dict["metadata"]["epoch"] + + # Ensure all ranks have completed loading before proceeding + torch.distributed.barrier() + + logger.info(f"Loaded mFSDP checkpoint from step {step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch) + + +def save_checkpoint_mfsdp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + epoch: int = 0, + max_checkpoints: int | None = None, +) -> None: + """Save mFSDP distributed checkpoint. + + Args: + model: The model to save. + optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + epoch: The epoch number to save the checkpoint. + max_checkpoints: The maximum number of checkpoints to keep. + """ + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + # Save dataloader state, if provided. + save_dataloader(dataloader, checkpoint_path, dist_config) + + # Save model, optimizer, scheduler state, and metadata + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "metadata": { + "step": step, + "epoch": epoch, + }, + } + + torch.distributed.checkpoint.save(state_dict, checkpoint_id=checkpoint_path) + + if dist_config.is_main_process(): + logger.info(f"Saved mFSDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_mfsdp( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for mFSDP - requires parameter gathering on all ranks.""" + from megatron_fsdp.uneven_dtensor import gather_uneven_dtensor_to_full_tensor + + if dist_config.is_main_process(): + logger.info("Starting mFSDP parameter gathering...") + + # Parameter gathering must happen on ALL processes + unsharded_state_dict = { + # Gather all parameters to CPU, and remove the "module." prefix from the Megatron-FSDP class wrapper. 
+ k.removeprefix("module."): gather_uneven_dtensor_to_full_tensor( + v, target_device=torch.device("cpu") + ).to_local() + if isinstance(v, torch.distributed.tensor.DTensor) + else v + for k, v in model.state_dict().items() + } + + # Only main process saves the model + if not dist_config.is_main_process(): + return + + os.makedirs(save_directory, exist_ok=True) + model.module.save_pretrained(save_directory, state_dict=unsharded_state_dict, safe_serialization=True) + logger.info(f"Saved final mFSDP model to {save_directory}") + + +# ============================================================================ +# FSDP2 Checkpointing +# ============================================================================ + + +@dataclass +class AppState(Stateful): + """AppState for FSDP2 checkpoint. + + Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html + """ + + model: torch.nn.Module + optimizer: torch.optim.Optimizer + scheduler: torch.optim.lr_scheduler.LRScheduler + step: int = 0 + epoch: int = 0 + state_dict_options: StateDictOptions = field( + default_factory=lambda: StateDictOptions( + full_state_dict=False, + cpu_offload=True, + ) + ) + + def state_dict(self): + """Get the state dict for the model, optimizer, scheduler, and step.""" + model_state_dict, optimizer_state_dict = get_state_dict( + self.model, self.optimizer, options=self.state_dict_options + ) + return { + "model": model_state_dict, + "optim": optimizer_state_dict, + "scheduler": self.scheduler.state_dict(), + "step": self.step, + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict: dict): + """Load the state dict for the model, optimizer, scheduler, and step.""" + set_state_dict( + self.model, + self.optimizer, + model_state_dict=state_dict["model"], + optim_state_dict=state_dict["optim"], + options=self.state_dict_options, + ) + self.scheduler.load_state_dict(state_dict["scheduler"]) + self.step = state_dict["step"] + self.epoch = state_dict["epoch"] + + +def load_checkpoint_fsdp2( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + process_group: torch.distributed.ProcessGroup | None = None, +) -> CheckpointOutput: + """Load FSDP2 checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The directory containing checkpoints. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + process_group: The process group to use for checkpointing. + """ + checkpoint_path, _ = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No FSDP2 checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + app_state = AppState( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + state_dict = {"app": app_state} + dcp_load(state_dict, checkpoint_id=checkpoint_path, process_group=process_group) + + if dataloader is not None: + load_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + + logger.info(f"Loaded distributed FSDP2 checkpoint from step {app_state.step}") + + # Increment the step by one to avoid re-running the previous step. 
+ return CheckpointOutput(model, optimizer, scheduler, dataloader, app_state.step + 1, app_state.epoch) + + +def save_checkpoint_fsdp2( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + process_group: torch.distributed.ProcessGroup | None = None, + max_checkpoints: int | None = None, + async_save: bool = False, +) -> None: + """Save FSDP2 checkpoint. + + Args: + model: The model to save. + optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + epoch: The epoch number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + process_group: The process group to use for checkpointing. + max_checkpoints: The maximum number of checkpoints to keep. + async_save: Whether to save the checkpoint asynchronously. + """ + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + if dataloader is not None: + save_dataloader( + dataloader=dataloader, + ckpt_path=checkpoint_path, + dist_config=dist_config, + ) + logger.info(f"Saved FSDP2 dataloader to {ckpt_path}") + + # If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time. + if async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: + _ckpt_futures["fsdp2"].result() + + # Clear GPU cache before checkpointing to free up fragmented memory. + gc.collect() + torch.cuda.empty_cache() + torch.distributed.barrier(group=process_group) + + state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)} + ckpt_save_func = dcp_async_save if async_save else dcp_save + _ckpt_futures["fsdp2"] = ckpt_save_func(state_dict, checkpoint_id=checkpoint_path, process_group=process_group) + + if dist_config.is_main_process(): + logger.info(f"Saved distributed FSDP2 checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_fsdp2( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for FSDP2 - gather on all ranks, save on main.""" + # ALL ranks must participate in gathering + model_state_dict = get_model_state_dict( + model=model, + options=StateDictOptions( + full_state_dict=True, + cpu_offload=True, + ), + ) + + # Only main process saves + if not dist_config.is_main_process(): + return + + os.makedirs(save_directory, exist_ok=True) + + # Save just the weights using safetensors + save_file(model_state_dict, os.path.join(save_directory, "model.safetensors")) + + # Save the config + underlying_model = model.module if hasattr(model, "module") else model + if hasattr(underlying_model, "config"): + underlying_model.config.save_pretrained(save_directory) + + logger.info(f"Saved final FSDP2 model to {save_directory} (weights + config only)") + + +# ============================================================================ +# Dataloader Checkpointing +# ============================================================================ + + +def save_dataloader( + dataloader: StatefulDataLoader | None, + ckpt_path: str | os.PathLike, + 
dist_config: DistributedConfig, +): + """Save the dataloader state to a file. + + For resuming training with long epochs, we save the dataloader state as part of the checkpoint to allow for resuming + from the exact same step. Here we save the dataloader state based on global rank. Note, the total number of ranks + and dataloader num_workers should match for resuming training. + + Args: + dataloader: The dataloader to save the state of. + ckpt_path: The path to save the dataloader state to. + dist_config: The distributed configuration. + """ + if dataloader is None: + return + + ckpt_path = Path(ckpt_path) + ckpt_path.mkdir(parents=True, exist_ok=True) + dataloader_path = ckpt_path / f"dataloader_rank_{dist_config.rank}.pt" + + dataloader_state = dataloader.state_dict() + dataloader_state["num_workers"] = dataloader.num_workers + dataloader_state["num_ranks"] = dist_config.world_size + torch.save(dataloader_state, dataloader_path) + if dist_config.is_main_process(): + logger.info(f"Saved dataloader state to {dataloader_path}") + + +def load_dataloader( + dataloader: StatefulDataLoader | None, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, +) -> StatefulDataLoader | None: + """Load the dataloader state from a file. + + Here we load the dataloader state based on global rank. + + Args: + dataloader: The dataloader to load the state of. + ckpt_path: The path to load the dataloader state from. + dist_config: The distributed configuration. + """ + if dataloader is None: + return dataloader + + dataloader_path = Path(ckpt_path) / f"dataloader_rank_{dist_config.rank}.pt" + if not dataloader_path.exists(): + logger.warning( + f"No dataloader checkpoint found for rank {dist_config.rank}, starting dataloader from scratch." + ) + return dataloader + + dataloader_state = torch.load(dataloader_path) + + if ( + dataloader.num_workers != dataloader_state["num_workers"] + or dist_config.world_size != dataloader_state["num_ranks"] + ): + logger.warning( + f"Dataloader num_workers mismatch: {dataloader.num_workers} != {dataloader_state['num_workers']} or " + f"num_ranks mismatch: {dist_config.world_size} != {dataloader_state['num_ranks']}, " + "starting dataloader from scratch." + ) + return dataloader + + dataloader.load_state_dict(dataloader_state) + if dist_config.is_main_process(): + logger.info(f"Loaded dataloader state from {dataloader_path}") + + return dataloader diff --git a/bionemo-recipes/recipes/esm2_peft_te/collator.py b/bionemo-recipes/recipes/esm2_peft_te/collator.py new file mode 100644 index 0000000000..43158d3f9b --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/collator.py @@ -0,0 +1,876 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data collator for THD input format tests. + +This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. 
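+
+Typical wiring (sketch): wrap a `DataCollatorForLanguageModeling` in `DataCollatorWithFlattening` to get
+packed THD batches, optionally batch by token budget with `TokenPackingDataset`, and, when context and/or
+tensor parallelism is enabled, shard each packed batch with `DataCollatorForContextParallel` plus
+`ContextParallelDataLoaderWrapper`.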
+""" + +import logging +from dataclasses import dataclass, field +from typing import Any, TypedDict + +import datasets +import torch +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp +from transformers import DataCollator, DataCollatorForLanguageModeling + + +logger = logging.getLogger(__name__) + + +@dataclass +class DataCollatorWithFlattening: + """Data collator that wraps a DataCollatorForLanguageModeling and flattens inputs for flash-attention. + + This collator enables efficient training on batches containing variable-length sequences, by first flattening + (packing) multiple input sequences into a single contiguous tensor without padding between sequences. Then, it + applies masked language modeling (MLM) masking using the provided DataCollatorForLanguageModeling instance. + + The collator also generates metadata required for Flash Attention or context-parallel attention: + - `cu_seq_lens_q` and `cu_seq_lens_k` tensors, denoting cumulative sequence lengths so that sequence boundaries + within the packed tensor are known during attention computation. + + Optionally, the collator can: + - Pad the total number of tokens in the batch to be divisible by `pad_to_multiple_of` (by appending a mock + sequence). + - Pad each individual sequence to be divisible by `pad_sequences_to_be_divisible_by` if provided. + + Only PyTorch tensors (`return_tensors="pt"`) are supported. + + Args: + collator (DataCollatorForLanguageModeling): The collator to use for MLM masking. This is a captive + collator and should be constructed externally and passed in. + return_position_ids (bool): Whether to return position ids (default False). + pad_to_multiple_of (int, optional): If set, pads the total sequence length to be divisible by this number. + pad_sequences_to_be_divisible_by (int, optional): If set, each individual sequence is padded to this value. + separator_id (int, optional): A label to insert between sequences, typically should be -100 for causal LM. + + Example: + >>> from transformers import AutoTokenizer, DataCollatorForLanguageModeling + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") + >>> mlm_collator = DataCollatorForLanguageModeling(tokenizer) + >>> flat_collator = DataCollatorWithFlattening( + ... collator=mlm_collator, + ... pad_to_multiple_of=8, + ... ) + >>> + >>> # Input: variable length protein sequences + >>> sequences = [ + ... {"input_ids": [0, 5, 6, 7, 2]}, # 5 tokens + ... {"input_ids": [0, 8, 9, 10, 11, 2]}, # 6 tokens + ... {"input_ids": [0, 12, 13, 2]}, # 4 tokens + ... ] # Total: 15 tokens + >>> batch = flat_collator(sequences) + >>> print(batch['input_ids'].shape) # torch.Size([1, 16]) + >>> print(batch['labels'].shape) # torch.Size([1, 16]) + >>> print(batch['cu_seq_lens_q']) # tensor([0, 5, 11, 15, 16], dtype=torch.int32) + + Note: + The output is a THD-format (Total, Height, Depth) batch, where all input sequences are packed without + inter-sequence padding. Sequence boundaries are preserved using `cu_seq_lens_q`/`cu_seq_lens_k`, enabling + Flash Attention or context-parallelism without traditional attention masks. 
+ """ + + collator: DataCollatorForLanguageModeling + return_position_ids: bool = False + pad_to_multiple_of: int | None = None + pad_sequences_to_be_divisible_by: int | None = None + separator_id: int | None = None + + def __post_init__(self): + """Ensure padding options are not used together.""" + if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None: + raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together") + + def __call__(self, features, return_tensors=None): + """Process a batch of variable-length sequences for Flash Attention with MLM. + + This method performs the following steps: + 1. Flattens multiple sequences into a single packed tensor with Flash Attention metadata + 2. Applies MLM masking to the flattened sequence while preserving special tokens + 3. Optionally pads to a multiple of a specified number for hardware optimization + + Args: + features (List[Dict[str, List[int]]]): List of tokenized sequences, each containing + 'input_ids' and optionally 'attention_mask'. Example: + [ + {"input_ids": [0, 5, 6, 7, 2]}, # Protein sequence 1 + {"input_ids": [0, 8, 9, 10, 11, 2]}, # Protein sequence 2 + {"input_ids": [0, 12, 13, 2]} # Protein sequence 3 + ] + return_tensors (str, optional): Format for returned tensors. Only "pt" (PyTorch) + is supported. Defaults to None (uses collator default). + + Returns: + Dict[str, torch.Tensor]: Batch dictionary containing: + - input_ids (torch.Tensor): Flattened and MLM-masked token sequences. + Shape: [1, total_tokens] where total_tokens = sum of all sequence lengths + (plus padding if pad_to_multiple_of is specified). + - labels (torch.Tensor): MLM labels with -100 for non-masked tokens and + original token IDs for masked positions. Same shape as input_ids. + - cu_seq_lens_q (torch.IntTensor): Cumulative sequence lengths for queries. + Shape: [num_sequences + 1] or [num_sequences + 2] if padding is added. + Example: [0, 5, 11, 15] or [0, 5, 11, 15, 16] with padding. + - cu_seq_lens_k (torch.IntTensor): Cumulative sequence lengths for keys. + Same as cu_seq_lens_q for self-attention. + - max_length_q (int): Maximum sequence length in the batch. + - max_length_k (int): Same as max_length_q for self-attention. + - attention_mask (torch.Tensor): Attention mask with 1s for actual tokens + and 0s for padding tokens (if any). + + Raises: + NotImplementedError: If return_tensors is not "pt". + + Example: + >>> # Input features + >>> features = [ + ... {"input_ids": [0, 5, 6, 7, 2]}, # 5 tokens + ... {"input_ids": [0, 8, 9, 10, 11, 2]}, # 6 tokens + ... ] + >>> + >>> batch = collator(features) + >>> + >>> # Output shapes and values + >>> batch['input_ids'].shape # torch.Size([1, 11]) or larger if padded + >>> batch['labels'].shape # torch.Size([1, 11]) or larger if padded + >>> batch['cu_seq_lens_q'] # tensor([0, 5, 11], dtype=torch.int32) or larger + + Note: + The output is in THD (Total, Height, Depth) format with batch_size=1 and + sequence_length=total_tokens, optimized for Flash Attention's variable-length + sequence processing capabilities. When pad_to_multiple_of is used, an additional + mock sequence is appended to reach the desired total length. + """ + # Perform the masking with the BSHD collator. + bshd_batch = self.collator(features) + + # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values. 
+ packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids) + + # Get the masked input_ids and labels from the BSHD batch. + masked_input_ids = bshd_batch["input_ids"][bshd_batch["attention_mask"].bool()].unsqueeze(0) + masked_labels = bshd_batch["labels"][bshd_batch["attention_mask"].bool()].unsqueeze(0) + + if self.separator_id is not None: + masked_labels[:, packed_batch["cu_seq_lens_q"][1:-1]] = self.separator_id + + # Update the packed batch with the masked input_ids and labels. + packed_batch["input_ids"] = masked_input_ids + packed_batch["labels"] = masked_labels + + if self.pad_to_multiple_of is not None: + packed_batch = self._pad_batch_to_multiple_of(packed_batch) + + elif self.pad_sequences_to_be_divisible_by is not None: + packed_batch = self._pad_sequences_to_be_divisible_by(packed_batch) + + return packed_batch + + def _pad_batch_to_multiple_of(self, batch): + """Add a mock sequence to make the total number of tokens divisible by pad_to_multiple_of.""" + # Ensure token_pad is an integer, defaulting to 1 if pad_token_id is None or invalid + pad_token_id = self.collator.tokenizer.pad_token_id + if not isinstance(pad_token_id, int): + logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}") + pad_token_id = 1 + + assert self.pad_to_multiple_of is not None, "pad_to_multiple_of must be set" + + return _pt_pad_to_multiple_of( + batch, + self.pad_to_multiple_of, + token_pad=pad_token_id, + label_pad=-100, + ) + + def _pad_sequences_to_be_divisible_by(self, batch): + """Pad individual sequences using cu_seq_lens_*_padded for context parallelism.""" + pad_token_id = self.collator.tokenizer.pad_token_id + if not isinstance(pad_token_id, int): + logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}") + pad_token_id = 1 + + assert self.pad_sequences_to_be_divisible_by is not None, "pad_sequences_to_be_divisible_by must be set" + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + batch["input_ids"], + batch["labels"], + batch["cu_seq_lens_q"], + self.pad_sequences_to_be_divisible_by, + padding_token_id=pad_token_id, + padding_label_id=-100, + ) + + batch["input_ids"] = input_ids_padded.unsqueeze(0) + batch["labels"] = labels_padded.unsqueeze(0) + batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) + batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) + batch["pad_between_seqs"] = True + return batch + + +@dataclass +class TokenPackingDataset(torch.utils.data.IterableDataset): + """Dataset that uses sequence packing to construct batches with variable length up to a maximum number of tokens.""" + + dataset: datasets.IterableDataset + """Dataset to pack.""" + max_tokens_per_batch: int + """Maximum number of tokens per batch.""" + drop_last: bool = True + """Whether to drop the last batch if it's less than max_length.""" + split_samples: bool = False + """Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens.""" + + def __iter__(self): + """Yield batches of samples, each with a variable number of tokens up to the maximum length. + + When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting + the final sample if needed. The remaining tokens from the split sample start the next batch. + + Returns: + A generator of batches of samples, each with a variable number of tokens up to the maximum length. 
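+
+        Example (illustrative; `tokenized_ds` stands in for any tokenized `datasets.IterableDataset`,
+        and `flat_collator` for a collator such as DataCollatorWithFlattening):
+
+            >>> packed = TokenPackingDataset(tokenized_ds, max_tokens_per_batch=8192)
+            >>> for samples in packed:  # each `samples` is a list totaling at most 8192 tokens
+            ...     batch = flat_collator(samples)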
+ """ + samples = [] + current_length = 0 + for sample in iter(self.dataset): + current_length += len(sample["input_ids"]) + if current_length == self.max_tokens_per_batch: + yield [*samples, sample] + samples = [] + current_length = 0 + + elif current_length > self.max_tokens_per_batch: + if not self.split_samples: + # If we are not splitting samples, we can just yield the current batch (before this sample) and + # start a new one. + yield samples + samples = [sample] + + else: + # Calculate how many tokens are already in the batch + tokens_in_batch = current_length - len(sample["input_ids"]) + # Calculate how many tokens we can fit from this sample + tokens_available = self.max_tokens_per_batch - tokens_in_batch + first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) + yield [*samples, first_part] + samples = [remaining_part] + + current_length = len(samples[0]["input_ids"]) + else: + samples.append(sample) + + if not self.drop_last and samples: + yield samples + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset.""" + self.dataset.set_epoch(epoch) + + +@dataclass +class DataCollatorForContextParallel: + """A collator that is aware of context parallelism. + + For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split + into shards for each context parallelism rank. + + The shards are then typically sent to the ContextParallelDataLoaderWrapper which will scatter them to the + appropriate GPUs. + + Note: + When used with the ContextParallelDataLoaderWrapper and both context parallelism and tensor parallelism are + used, the collator inspects the ordering of the mesh dimensions to determine the layout of the flattened batch. + + If "cp" comes before "tp" in the mesh dimension names (CP row-major), the flattened batch will be: + [(cp0, tp0), (cp0, tp1), ..., (cp1, tp0), (cp1, tp1), ...] + + If "tp" comes before "cp" (TP row-major), the flattened batch will be: + [(tp0, cp0), (tp0, cp1), ..., (tp1, cp0), (tp1, cp1), ...] + + Args: + collator: The collator to use for the batch. + device_mesh: The device mesh with named dimensions. Must contain either a "cp" dimension for context parallelism + and/or a "tp" dimension for tensor parallelism. + qkv_format: The format of the query-key-value (QKV) tensor. + is_causal_lm: Whether the collator is for a causal language model. If True, the labels will be shifted before + being split into CP shards, and will be returned in the `shift_labels` field. + + """ + + collator: DataCollator + device_mesh: torch.distributed.device_mesh.DeviceMesh + qkv_format: str = "thd" + is_causal_lm: bool = False + + # Derived fields, initialized in __post_init__. + cp_world_size: int = field(init=False) + tp_world_size: int | None = field(init=False) + _is_cp_row_major: bool = field(init=False) + + def __post_init__(self): + """Initialize the cp_world_size, tp_world_size, and _is_cp_row_major fields based on the device mesh.""" + dim_names = self.device_mesh.mesh_dim_names + if dim_names is None: + raise ValueError("device_mesh must have mesh_dim_names") + + self.cp_world_size = self.device_mesh.size(dim_names.index("cp")) if "cp" in dim_names else 1 + self.tp_world_size = self.device_mesh.size(dim_names.index("tp")) if "tp" in dim_names else None + + # Determine whether CP is the row (outer) dimension of the 2D mesh. + # When flattened, the row-major dimension's index changes slowest. + # If "cp" comes before "tp" in mesh_dim_names, CP is the row dimension. 
+ if "cp" in dim_names and "tp" in dim_names: + self._is_cp_row_major = dim_names.index("cp") < dim_names.index("tp") + else: + self._is_cp_row_major = True + + def __call__(self, features) -> list[dict[str, Any]]: + """Process batches of data and create shards for each context parallelism rank. + + Args: + features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'. + + Returns: + A list of dictionaries, each containing a shard of the batch for a given context parallelism rank. + """ + batch = self.collator(features) + + if self.is_causal_lm: + labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) + batch["labels"] = labels[..., 1:].contiguous() + + combined_batch = [] + for cp_rank in range(self.cp_world_size): + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format. + input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format=self.qkv_format, + cp_rank=cp_rank, + cp_world_size=self.cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + if self.is_causal_lm: + batch_shard["shift_labels"] = labels_sharded + batch_shard["labels"] = None + else: + batch_shard["labels"] = labels_sharded + # Now determine the max length of the sequence. + if self.qkv_format == "thd": + seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1] + max_length = seqlens_q.max().item() + elif self.qkv_format == "bshd": + max_length = batch["input_ids"].shape[1] + # For BSHD context parallelism, we can't handle padding, so we remove the attention mask. + del batch_shard["attention_mask"] + else: + raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") + + batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64) + combined_batch.append(batch_shard) + + if self.tp_world_size is not None: + # Replicate each CP shard for TP ranks. The ordering depends on which dimension forms the rows in the + # flattened mesh. + if self._is_cp_row_major: + # Flattened mesh: [(cp0,tp0), (cp0,tp1), (cp1,tp0), (cp1,tp1)] + # Output: [cp0, cp0, cp1, cp1] + combined_batch = [batch for batch in combined_batch for _ in range(self.tp_world_size)] + else: + # Flattened mesh: [(tp0,cp0), (tp0,cp1), (tp1,cp0), (tp1,cp1)] + # Output: [cp0, cp1, cp0, cp1] + combined_batch = [ + combined_batch[cp_rank] for _ in range(self.tp_world_size) for cp_rank in range(self.cp_world_size) + ] + + return combined_batch + + +class ContextParallelDataLoaderWrapper: + """A dataloader that is aware of context and tensor parallelism.""" + + def __init__( + self, + dataloader: torch.utils.data.DataLoader | None, + cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh, + ): + """A dataloader wrapper that distributes the data across the context and tensor parallelism groups. + + This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data + from this dataloader across the context and tensor parallelism groups. + + Args: + dataloader: The dataloader to use. + cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh. + If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the + mesh_dim_names as passed to DataCollatorForContextParallel. 
+ """ + if cp_tp_mesh.get_local_rank() == 0: + assert dataloader is not None, "dataloader must be provided on rank 0" + self.dataloader = dataloader + + else: + assert dataloader is None, "Dataloader on non-rank 0 will not be used" + + self.cp_tp_rank = cp_tp_mesh.get_local_rank() + self.cp_tp_group = cp_tp_mesh.get_group() + self.num_cp_tp_ranks = cp_tp_mesh.size() + self._iterator = None + + logger.debug( + "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", + torch.distributed.get_rank() if torch.distributed.is_initialized() else "", + self.cp_tp_rank, + ) + + def __iter__(self): + """Make the dataloader iterable.""" + if self.cp_tp_rank == 0: + self._iterator = iter(self.dataloader) # < --- collator output. + return self + + def __next__(self): + """Get the batch from the dataloader for the current CP rank.""" + batch = self._send_data_to_cp_tp_ranks() + return batch + + def _send_data_to_cp_tp_ranks(self): + """Send data to all the CP/TP ranks. + + This function will get the batch from the dataloader on CP rank 0, and then determine + the shards for all the different CP group members. + combined_batch = [, , ..., ] + Then it will scatter the shards to the different CP group members. + The shards are then combined into a single batch and returned to the caller + for the current CP rank. + + If tensor parallelism is also being used, the combined batch will look like: + combined_batch = [, , ..., , , ...] + where there are cp_world_size shards, and each shard is replicated tp_world_size times. The ordering of the + shards depends on which dimension forms the rows in the flattened mesh. + + Scalability: + Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they + do not grow linearly with CP size. + + Args: + None + + Returns: + batch: The batch for the current CP/TP rank. + + """ + try: + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + except StopIteration as ex: + # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so + # that the dataloader can be restarted. + combined_batch = [ex] * self.num_cp_tp_ranks + + batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group) + + if isinstance(batch_on_this_rank, StopIteration): + raise batch_on_this_rank + + return batch_on_this_rank + + def state_dict(self): + """Get the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return {} + elif hasattr(self.dataloader, "state_dict"): + return {"dataloader": self.dataloader.state_dict()} + else: + logger.warning( + "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, " + "returning empty dict" + ) + return {"dataloader": {}} + + def load_state_dict(self, state_dict): + """Load the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return + elif hasattr(self.dataloader, "load_state_dict"): + self.dataloader.load_state_dict(state_dict["dataloader"]) + else: + logger.warning( + "Attempting to load the state dict of the dataloader, but the dataloader does not support " + "load_state_dict, returning without loading the state dict." 
+ ) + return + + @property + def num_workers(self): + """Get the number of workers of the dataloader.""" + if self.cp_tp_rank != 0: + return 0 + else: + return self.dataloader.num_workers + + +def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]: + """Split a sample dictionary at a specified number of tokens. + + This function splits a sample into two parts: the first part contains exactly `num_tokens` tokens, + and the second part contains the remaining tokens. All fields that are sequences (input_ids, attention_mask, + token_type_ids, labels, etc.) are split accordingly. + + Args: + sample: Dictionary containing sample data with fields like input_ids, attention_mask, token_type_ids, labels, etc. + num_tokens: Number of tokens to include in the first part of the split. + + Returns: + A tuple of two dictionaries: (first_part, remaining_part), where: + - first_part contains the first `num_tokens` tokens from each sequence field + - remaining_part contains the remaining tokens from each sequence field + + Example: + >>> sample = { + ... "input_ids": [0, 5, 6, 7, 8, 9, 2], + ... "attention_mask": [1, 1, 1, 1, 1, 1, 1], + ... "labels": [0, 5, 6, 7, 8, 9, 2] + ... } + >>> first, remaining = split_sample_by_num_tokens(sample, 3) + >>> first["input_ids"] # [0, 5, 6] + >>> remaining["input_ids"] # [7, 8, 9, 2] + """ + sample_length = len(sample["input_ids"]) + if num_tokens >= sample_length: + raise ValueError( + f"num_tokens ({num_tokens}) must be less than sample length ({sample_length}) to split the sample" + ) + if num_tokens <= 0: + raise ValueError(f"num_tokens ({num_tokens}) must be positive") + + first_part = {} + remaining_part = {} + + # Fields that should be split by tokens (sequence fields) + sequence_fields = ["input_ids", "attention_mask", "token_type_ids", "token_type", "labels"] + + for key, value in sample.items(): + if key in sequence_fields: + # Handle both list and tensor inputs + if isinstance(value, torch.Tensor): + first_part[key] = value[:num_tokens].clone() + remaining_part[key] = value[num_tokens:].clone() + elif isinstance(value, list): + first_part[key] = value[:num_tokens] + remaining_part[key] = value[num_tokens:] + else: + # For other types, try to slice if possible + try: + first_part[key] = value[:num_tokens] + remaining_part[key] = value[num_tokens:] + except (TypeError, IndexError): + # If slicing doesn't work, copy the value to both parts + # This handles fields that shouldn't be split (like metadata) + first_part[key] = value + remaining_part[key] = value + else: + # For non-sequence fields, copy to both parts + # This handles metadata fields that shouldn't be split + first_part[key] = value + remaining_part[key] = value + + return first_part, remaining_part + + +def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + is_labels_provided = "labels" in features[0] + sample_lengths = [len(sample["input_ids"]) for sample in features] + + batch = {} + batch["max_length_q"] = batch["max_length_k"] = max(sample_lengths) + batch["input_ids"] = torch.tensor( + [[token for sample in features for token in sample["input_ids"]]], dtype=torch.int64 + ) + if is_labels_provided: + batch["labels"] = torch.tensor( + [[label for sample in features for label in sample["labels"]]], dtype=torch.int64 + ) + cu_seq_lens = torch.zeros(len(features) + 1, dtype=torch.int32) + cu_seq_lens[1:] = torch.cumsum(torch.tensor(sample_lengths), dim=0, dtype=torch.int32) + 
batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens + if "attention_mask" in features[0]: + batch["attention_mask"] = torch.tensor( + [[v for sample in features for v in sample["attention_mask"]]], dtype=torch.int64 + ) + if return_position_ids: + batch["position_ids"] = torch.hstack( + [torch.arange(sample_len, dtype=torch.int64) for sample_len in sample_lengths] + ).unsqueeze(0) + + return batch + + +def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int): + """Pad a batch to a multiple of pad_to_multiple_of. + + Appends a mock sequence to the end of the batch with the given token_pad and label_pad to make the total number of + tokens divisible by pad_to_multiple_of. + + Args: + batch: Input batch, possibly containing labels and/or cu_seq_lens / max_length keys. + pad_to_multiple_of: Multiple to pad to. + token_pad: Token to pad with. + label_pad: Label to pad with. + + Returns: + Batch dictionary with padded input_ids, labels, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k. + """ + # Number of tokens we need to pad to make the total number of tokens divisible by pad_to_multiple_of + remainder = -batch["input_ids"].numel() % pad_to_multiple_of + + if remainder == 0: + return batch + + batch["input_ids"] = torch.cat( + [batch["input_ids"], torch.full((1, remainder), token_pad, dtype=batch["input_ids"].dtype)], dim=1 + ) + + if "labels" in batch: + batch["labels"] = torch.cat( + [batch["labels"], torch.full((1, remainder), label_pad, dtype=batch["labels"].dtype)], dim=1 + ) + + if "cu_seq_lens_q" in batch: + batch["cu_seq_lens_q"] = torch.cat( + [ + batch["cu_seq_lens_q"], + torch.tensor([batch["cu_seq_lens_q"][-1] + remainder], dtype=batch["cu_seq_lens_q"].dtype), + ], + dim=0, + ) + batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] + + if "max_length_q" in batch: + batch["max_length_q"] = max(batch["max_length_q"], remainder) + batch["max_length_k"] = batch["max_length_q"] + + if "attention_mask" in batch: + batch["attention_mask"] = torch.cat( + [batch["attention_mask"], torch.zeros((1, remainder), dtype=batch["attention_mask"].dtype)], dim=1 + ) + + if "position_ids" in batch: + batch["position_ids"] = torch.cat( + [batch["position_ids"], torch.arange(remainder, dtype=batch["position_ids"].dtype).unsqueeze(0)], dim=1 + ) + + return batch + + +# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 +# we can replace this with the one in TransformerEngine. +def _split_batch_by_cp_rank( + cu_seqlens_padded: torch.Tensor | None, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup | None = None, + qvk_format: str = "thd", + cp_rank: int | None = None, + cp_world_size: int | None = None, +): + """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format. + + This function is intended for use in self attention. It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. 
+ qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank. + cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it + were executing on that rank without querying `torch.distributed.get_rank`. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError("Make sure the inputs are in THD format and padded correctly.") + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + + elif qvk_format == "bshd": + # BSHD format: [batch, seq_len, ...] 
+ # Split along sequence dimension (dim=1) + # Each sequence is split into 2*cp_world_size chunks + # Each rank gets chunks at positions: [cp_rank, 2*cp_world_size - cp_rank - 1] + + def process_tensor_bshd(val): + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if chunk_size == 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + input_ids_padded = process_tensor_bshd(input_ids_padded) + labels_padded = process_tensor_bshd(labels_padded) + + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded + + +class BatchType(TypedDict): + """The fields in the batch dictionary fo THD context parallel.""" + + input_ids: torch.Tensor + labels: torch.Tensor | None + shift_labels: torch.Tensor | None + cu_seq_lens_q: torch.Tensor + cu_seq_lens_k: torch.Tensor + cu_seq_lens_q_padded: torch.Tensor + cu_seq_lens_k_padded: torch.Tensor + max_length_q: int + max_length_k: int + pad_between_seqs: bool + + +def _scatter_batch_to_cp_tp_ranks( + all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None +) -> BatchType | StopIteration: + """Scatter a batch to all the CP ranks. + + Args: + all_batches (list[BatchType] | list[StopIteration]): A list of already-sharded batches to scatter to the CP/TP + ranks. + cp_tp_group (torch.distributed.ProcessGroup | None): The process group to scatter the batches to. + + Returns: + BatchType | StopIteration: The batch on this rank. + """ + scatter_object_output_list = [None] + # Note: This does not provide an async_op handle. Thus its blocking. 
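+    # Only the source rank (group_src=0) needs to populate `all_batches`; every other rank in the
+    # group receives its pre-sharded batch (or a StopIteration sentinel) via the output list.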
+ torch.distributed.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=all_batches, + group=cp_tp_group, + group_src=0, + ) + return scatter_object_output_list[0] diff --git a/bionemo-recipes/recipes/esm2_peft_te/data/input_infer.fasta b/bionemo-recipes/recipes/esm2_peft_te/data/input_infer.fasta new file mode 100644 index 0000000000..c1244acb1e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/data/input_infer.fasta @@ -0,0 +1,4 @@ +>2KL1A +MNEAKGVYVMSVLPNMPAAGRLEAGDRIAAIDGQPINTSEQIVSYVREKQAGDRVRVTFIRDRKQHEAELVLKPFPHHPNQIGLGVT +>2QYUA +ATSSPSSPADWAKKLTDAVLRQKAGETLTAADRDFSNADFRNITFSKILPPSFMERDGDIIKGFNFSNSKFTYSDISHLHFDECRFTYSTLSDVVCSNTKFSNSDMNEVFLQYSITTQQQPSFIDTTLKNTLIRHKANLSGVILNEPDNSSPPSVSGGGNFIRLGDIWLQMPLLWTENAVDGFLNHEHNNGKSILMTIDSLPDKYSQEKVQAMEDLVKSLRGGRLTEACIRPVESSLVSVLAHPPYTQSALISEWLGPVQERFFAHQCQTYNDVPLPAPDTYYQQRILPVLLDSFDRNSAAMTTHSGLFNQVILHCMTGVDCTDGTRQKAAALYEQYLAHPAVSPHIHNGLFGNYDGSPDWTTRAADNFLLLSSQDSDTAMMLSTDTLLTMLNPTPDTAWDNFYLLRAGENVSTAQISPVELFRHDFPVFLAAFNQQATQRRFGELIDIILSTEEHGELNQQFLAATNQKHSTVKLIDDASVSRLATIFDPLLPEGKLSPAHYQHILSAYHLTDATPQKQAETLFCLSTAFARYSSSAIFGTEHDSPPALRGYAEALMQKAWELSPAIFPSSEQFTEWSDRFHGLHGAFTCTSVVADSMQRHARKYFPSVLSSILPLAWA diff --git a/bionemo-recipes/recipes/esm2_peft_te/peft_sanity_dataset.parquet b/bionemo-recipes/recipes/esm2_peft_te/data/peft_sanity_dataset.parquet similarity index 100% rename from bionemo-recipes/recipes/esm2_peft_te/peft_sanity_dataset.parquet rename to bionemo-recipes/recipes/esm2_peft_te/data/peft_sanity_dataset.parquet diff --git a/bionemo-recipes/recipes/esm2_peft_te/data/prepare_porter6_dataset.py b/bionemo-recipes/recipes/esm2_peft_te/data/prepare_porter6_dataset.py new file mode 100644 index 0000000000..9e1a6e4fbb --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/data/prepare_porter6_dataset.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
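+
+"""Download the Porter6 secondary-structure datasets and convert them to Parquet.
+
+This script fetches the Porter6 "SS datasets" zip, verifies each expected file against a
+pinned SHA256 checksum, parses the PDB_ID / Sequence / Secondary_structure records, and
+writes the train and validation Parquet files into this recipe's `data/` directory.
+"""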
+ +import hashlib +import shutil +import tempfile +import zipfile +from pathlib import Path +from urllib.request import urlretrieve + +import pandas as pd + + +PORTER6_ZIP_URL = "https://github.com/WafaAlanazi/Porter6/raw/main/SS%20datasets.zip" +SCRIPT_DATA_DIR = Path(__file__).resolve().parent + +DATASET_FILES = { + "dataset_train55k_80%.txt": { + "output": "porter6_train_dataset_55k.parquet", + "sha256": "4b1c011d8cea0b892743053eb4234db80344b8d9c90243f19b4637781ce8922b", # pragma: allowlist secret + }, + "2024Testset_692.adataset": { + "output": "porter6_val_dataset_2024_692.parquet", + "sha256": "b4a1b69f2003a66a62eb106aded784f9938fc734e876458223459fd9a10f1ca2", # pragma: allowlist secret + }, +} + + +def parse_input_file(path): + """Parse a Porter6-formatted secondary-structure file into row dictionaries.""" + records = [] + + with open(path, "r") as f: + lines = [line.strip() for line in f if line.strip()] + + i = 0 + while i < len(lines): + pdb_id = lines[i] + _ = lines[i + 1] # length line, not strictly needed + seq_line = lines[i + 2] + ss_line = lines[i + 3] + + sequence = seq_line.replace(" ", "") + secondary_structure = ss_line.replace(" ", "") + secondary_structure = secondary_structure.replace(".", "~") + + records.append( + { + "PDB_ID": pdb_id, + "Sequence": sequence, + "Secondary_structure": secondary_structure, + } + ) + + i += 4 + + return records + + +def compute_sha256(file_path): + """Compute SHA256 checksum for a file.""" + digest = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + digest.update(chunk) + return digest.hexdigest() + + +def main(): + """Download Porter6 datasets and write train/validation parquet files.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + zip_path = tmp_path / "SS_datasets.zip" + + print(f"Downloading Porter6 datasets from: {PORTER6_ZIP_URL}") + urlretrieve(PORTER6_ZIP_URL, zip_path) + + extract_dir = tmp_path / "extracted" + extract_dir.mkdir(parents=True, exist_ok=True) + + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(extract_dir) + + # Flatten extracted files so we can match expected names regardless of zip structure. + extracted_by_name = {} + for extracted_file in extract_dir.rglob("*"): + if extracted_file.is_file(): + extracted_by_name[extracted_file.name] = extracted_file + + for input_name, file_config in DATASET_FILES.items(): + if input_name not in extracted_by_name: + available = ", ".join(sorted(extracted_by_name)) + raise FileNotFoundError( + f"Expected file '{input_name}' was not found in downloaded zip. 
Available files: {available}" + ) + + source_file = extracted_by_name[input_name] + working_input_path = tmp_path / input_name + shutil.copy2(source_file, working_input_path) + + actual_sha256 = compute_sha256(working_input_path) + expected_sha256 = file_config["sha256"] + if actual_sha256 != expected_sha256: + raise ValueError( + f"SHA256 mismatch for '{input_name}': expected {expected_sha256}, got {actual_sha256}" + ) + + records = parse_input_file(working_input_path) + df = pd.DataFrame(records) + + output_path = SCRIPT_DATA_DIR / file_config["output"] + df.to_parquet(output_path, index=False) + + print(f"Converted {input_name}: {len(df)} records") + print(f"Wrote: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_peft_te/dataset.py b/bionemo-recipes/recipes/esm2_peft_te/dataset.py new file mode 100644 index 0000000000..0e3b777eda --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/dataset.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +from collections import defaultdict +from pathlib import Path + +import datasets +import datasets.distributed +import torch +from datasets import IterableDataset, load_dataset +from torch.utils.data import DataLoader, DistributedSampler +from transformers import ( + AutoTokenizer, + DataCollatorForTokenClassification, + DataCollatorWithFlattening, +) +from transformers.trainer_pt_utils import get_parameter_names + +from collator import TokenPackingDataset +from distributed_config import DistributedConfig + + +SS3_ID2LABEL = {0: "H", 1: "E", 2: "C"} + +SS3_LABEL2ID = { + "H": 0, + "I": 0, + "G": 0, + "E": 1, + "B": 1, + "S": 2, + "T": 2, + "~": 2, + "C": 2, + "L": 2, +} # '~' denotes coil / unstructured + +SS8_ID2LABEL = {0: "H", 1: "I", 2: "G", 3: "E", 4: "B", 5: "S", 6: "T", 7: "C"} + +SS8_LABEL2ID = { + "H": 0, + "I": 1, + "G": 2, + "E": 3, + "B": 4, + "S": 5, + "T": 6, + "~": 7, + "C": 7, + "L": 7, +} # '~' denotes coil / unstructured + + +def create_dataloader( + distributed_config: DistributedConfig, + use_sequence_packing: bool, + tokenizer_name: str, + micro_batch_size: int, + val_micro_batch_size: int, + num_workers: int, + max_seq_length: int, + stride: int, + seed: int, + ss3_classification: bool, + load_dataset_kwargs: dict, +) -> tuple[DataLoader, DataLoader | None, IterableDataset | DistributedSampler]: + """Create a dataloader for the secondary structure dataset.""" + dataset_or_dataset_dict = load_dataset(**load_dataset_kwargs) + + if isinstance(dataset_or_dataset_dict, dict): + train_dataset = dataset_or_dataset_dict.get("train") + assert train_dataset, "'train' split must be specified." 
+ val_dataset = dataset_or_dataset_dict.get("validation") + else: + train_dataset = dataset_or_dataset_dict + val_dataset = None + + print( + f"Loading dataset: path: '{load_dataset_kwargs['path']}' | data_files: '{load_dataset_kwargs['data_files']}'." + ) + + perform_validation = val_dataset is not None + + if isinstance(train_dataset, IterableDataset): + train_dataset = datasets.distributed.split_dataset_by_node( + train_dataset, + rank=distributed_config.rank, + world_size=distributed_config.world_size, + ) + train_dataset = train_dataset.shuffle(seed=seed, buffer_size=10_000) + + if perform_validation: + val_dataset = datasets.distributed.split_dataset_by_node( + val_dataset, + rank=distributed_config.rank, + world_size=distributed_config.world_size, + ) + + if ss3_classification: + ss_token_map = SS3_LABEL2ID + else: + ss_token_map = SS8_LABEL2ID + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenize_args = { + "max_length": max_seq_length, + "truncation": True, + "stride": stride, + "return_overflowing_tokens": True, + "return_offsets_mapping": True, + } + + def tokenize(example): + """Tokenize both the input protein sequence and the secondary structure labels.""" + result = tokenizer(example["Sequence"], **tokenize_args) + + # While we can use the rust-based tokenizer for the protein sequence, we manually encode the secondary structure + # labels. Our goal is to return a list of integer labels with the same shape as the input_ids. + labels = [] + for batch_idx in range(len(result["input_ids"])): + sequence_labels = [] + + # This array maps the possibly-chunked result["input_ids"] to the original sequence. Because of + # `return_overflowing_tokens`, each input sequence may be split into multiple input rows. + offsets = result["offset_mapping"][batch_idx] + + # This gets the original secondary structure sequence for the current chunk. + ss_sequence = example["Secondary_structure"][result["overflow_to_sample_mapping"][batch_idx]] + + for offset_start, offset_end in offsets: + if offset_start == offset_end: + sequence_labels.append(-100) # Start and end of the sequence tokens can be ignored. + elif offset_end == offset_start + 1: # All tokens are single-character. + ss_char = ss_sequence[offset_start] + ss_label_value = ss_token_map[ss_char] # Encode the secondary structure character + sequence_labels.append(ss_label_value) + else: + raise ValueError(f"Invalid offset: {offset_start} {offset_end}") + + labels.append(sequence_labels) + + return {"input_ids": result["input_ids"], "labels": labels} + + train_tokenized_dataset = train_dataset.map( + tokenize, + batched=True, + remove_columns=[col for col in train_dataset.features if col not in ["input_ids", "labels"]], + ) + + if isinstance(train_tokenized_dataset, IterableDataset): + train_sampler = None + else: + train_sampler = DistributedSampler( + train_tokenized_dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=seed, + ) + + if use_sequence_packing: + assert isinstance(train_tokenized_dataset, datasets.IterableDataset), ( + "THD token packing requires a streaming dataset." + ) + collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True) + train_tokenized_dataset = TokenPackingDataset( + train_tokenized_dataset, max_tokens_per_batch=micro_batch_size * max_seq_length + ) + batch_size = None # The TokenPackingDataset will handle the batching. 
+ else: + collator = DataCollatorForTokenClassification( + tokenizer=tokenizer, padding="max_length", max_length=max_seq_length + ) + batch_size = micro_batch_size + + train_dataloader = DataLoader( + train_tokenized_dataset, + sampler=train_sampler, + batch_size=batch_size, + collate_fn=collator, + num_workers=num_workers, + pin_memory=True, + ) + + if perform_validation: + val_tokenized_dataset = val_dataset.map( + tokenize, + batched=True, + remove_columns=[col for col in val_dataset.features if col not in ["input_ids", "labels"]], + ) + + if isinstance(val_tokenized_dataset, IterableDataset): + val_sampler = None + else: + val_sampler = DistributedSampler( + val_tokenized_dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=seed, + ) + + if use_sequence_packing: + assert isinstance(val_tokenized_dataset, datasets.IterableDataset), ( + "THD token packing requires a streaming dataset." + ) + collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True) + val_tokenized_dataset = TokenPackingDataset( + val_tokenized_dataset, max_tokens_per_batch=micro_batch_size * max_seq_length + ) + val_batch_size = None # The TokenPackingDataset will handle the batching. + else: + collator = DataCollatorForTokenClassification( + tokenizer=tokenizer, padding="max_length", max_length=max_seq_length + ) + val_batch_size = val_micro_batch_size + + val_dataloader = DataLoader( + val_tokenized_dataset, + sampler=val_sampler, + batch_size=val_batch_size, + collate_fn=collator, + num_workers=num_workers, + pin_memory=True, + ) + else: + val_dataloader = None + + return train_dataloader, val_dataloader, train_tokenized_dataset if train_sampler is None else train_sampler + + +def compute_accuracy(preds, labels, ignore_index=-100) -> tuple[int, int]: + """Calculate the accuracy.""" + preds_labels = torch.argmax(preds, dim=-1) + mask = labels != ignore_index + correct = (preds_labels == labels) & mask + + return correct.sum().item(), mask.sum().item() + + +def get_parameter_names_with_lora(model): + """Get layers with non-zero weight decay. + + This function reuses the Transformers' library function + to list all the layers that should have weight decay. + """ + forbidden_name_patterns = [ + r"bias", + r"layernorm", + r"rmsnorm", + r"(?:^|\.)norm(?:$|\.)", + r"_norm(?:$|\.)", + r"\.lora_[AB]\.", + ] + + decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], forbidden_name_patterns) + + return decay_parameters + + +def load_fasta(path: Path) -> list[dict]: + """Read FASTA file and return input sequences.""" + records = [] + seq, pdb_id = [], None + + with open(path) as f: + for raw_line in f: + line = raw_line.strip() + if line.startswith(">"): + if seq: + records.append({"pdb_id": pdb_id, "sequence": "".join(seq)}) + pdb_id = line[1:] or None + seq = [] + else: + seq.append(line) + + if seq: + records.append({"pdb_id": pdb_id, "sequence": "".join(seq)}) + + return records + + +def load_csv(path: Path) -> list[dict]: + """Read input CSV file for inference. + + It is assumed that the input CSV file contains: + - Optional column named 'pdb_id' of the sequence. + - Aminoacid sequence. 
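+
+    A minimal illustrative file (column names as expected by this loader; sequence truncated):
+
+        pdb_id,sequence
+        2KL1A,MNEAKGVYVMSV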
+ """ + with open(path) as f: + reader = csv.DictReader(f) + has_pdb_id = "pdb_id" in reader.fieldnames + + return [ + { + "pdb_id": row["pdb_id"] if has_pdb_id else None, + "sequence": row["sequence"], + } + for row in reader + ] + + +def load_input(path: Path) -> list[dict]: + """Read the input sequences from FASTA or CSV file.""" + suffix = path.suffix.lower() + + if suffix == ".csv": + return load_csv(path) + elif suffix in {".fa", ".fasta", ".faa"}: + return load_fasta(path) + else: + raise ValueError(f"Unsupported input format: {suffix}") + + +def format_output_rows(records, predictions, sequences_to_sample_mapping): + """Format the output into CSV-type lines. + + Returns: + header: list[str] + rows: list[tuple[str, str]] + """ + has_pdb_id = any(r.get("pdb_id") for r in records) + header = ["pdb_id", "prediction"] if has_pdb_id else ["id", "prediction"] + + counts = defaultdict(int) + rows = [] + + for pred, orig_idx in zip(predictions, sequences_to_sample_mapping): + counts[orig_idx] += 1 + suffix = counts[orig_idx] + + base = records[orig_idx]["pdb_id"] if has_pdb_id else str(orig_idx) + + out_id = base if suffix == 1 else f"{base}_{suffix}" + rows.append((out_id, pred)) + + return header, rows + + +def write_output(records, predictions, sequences_to_sample_mapping: list[int], output_path: Path): + """Write the predictions to an output file.""" + header, rows = format_output_rows(records, predictions, sequences_to_sample_mapping) + + with open(output_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(header) + writer.writerows(rows) diff --git a/bionemo-recipes/recipes/esm2_peft_te/distributed_config.py b/bionemo-recipes/recipes/esm2_peft_te/distributed_config.py new file mode 100644 index 0000000000..271a5ffcfc --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/distributed_config.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from dataclasses import dataclass, field + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class DistributedConfig: + """Class to track distributed ranks and handle basic distributed training setup. + + If torch distributed environment variables are not set, we set them to default values for single-process training. + + Attributes: + rank: The rank of the process. + local_rank: The local rank of the process. + world_size: The total number of processes. 
+ """ + + rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0"))) + local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0"))) + world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1"))) + _master_addr: str = field(default_factory=lambda: os.environ.setdefault("MASTER_ADDR", "localhost")) + _master_port: str = field(default_factory=lambda: os.environ.setdefault("MASTER_PORT", "12355")) + + def is_main_process(self) -> bool: + """This is the global rank 0 process, to be used for wandb logging, etc.""" + return self.rank == 0 diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity.yaml new file mode 100644 index 0000000000..4a92da0d9e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity.yaml @@ -0,0 +1,55 @@ +defaults: + - defaults + - _self_ + +# Training config +model_tag: ./example_8m_checkpoint # E.g., nvidia/esm2_t6_8M_UR50D or facebook/esm2_t6_8M_UR50D +use_pretrained: false +num_train_steps: 250 + +validation_interval: 50 + +# We want this on in CI/CD to validate that the script runs successfully with torch.compile. +use_torch_compile: true + +use_sequence_packing: false + +dataset: + tokenizer_name: ${model_tag} + micro_batch_size: 8 + val_micro_batch_size: 128 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + ss3_classification: true + load_dataset_kwargs: + path: "parquet" + split: "train" + data_files: "data/peft_sanity_dataset.parquet" + +# WandB config +wandb_init_args: + name: "esm2_lora_example_8M_sanity" + project: "esm2_lora_sanity" + mode: "offline" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 100 + +checkpoint: + ckpt_dir: null + save_final_model: false + +logger: + frequency: 1 + +lora: + r: 8 + alpha: 16 + target_modules: + - "layernorm_qkv" + # For 'facebook/esm2*' use: + # - "query" + # - "key" + # - "value" diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity_infer.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity_infer.yaml new file mode 100644 index 0000000000..4c69ee449e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity_infer.yaml @@ -0,0 +1,14 @@ +defaults: +- defaults_infer +- _self_ + +model_tag: "nvidia/esm2_t6_8M_UR50D" +base_model_config_dir: "/train_ddp/final_model" # pragma: allowlist secret + +output_file: preds.csv + +inference: + batch_size: 4 # tune based on GPU memory + max_seq_length: 1024 + stride: 16 + infer_overflowing_aas: true diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_fb_15B.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_fb_15B.yaml new file mode 100644 index 0000000000..1c0bbfbda2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_fb_15B.yaml @@ -0,0 +1,57 @@ +defaults: + - defaults + - _self_ + +# Training config +model_tag: facebook/esm2_t48_15B_UR50D # E.g., nvidia/esm2_t6_8M_UR50D or facebook/esm2_t6_8M_UR50D +use_pretrained: true +num_train_steps: 1000 + +validation_interval: 20 + +use_torch_compile: true + +use_sequence_packing: true + +dataset: + # The NVIDIA tokenizer is identical to the facebook/esm* and supports generating multiple samples from sequences + # longer than 1024. 
+ tokenizer_name: nvidia/esm2_t48_15B_UR50D + micro_batch_size: 8 + val_micro_batch_size: 128 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + ss3_classification: true + load_dataset_kwargs: + path: "parquet" + split: null + data_files: + train: "data/porter6_train_dataset_55k.parquet" + validation: "data/porter6_val_dataset_2024_692.parquet" + +# WandB config +wandb_init_args: + name: "esm2_t48_15B_UR50D_lora" + project: "esm2_lora" + mode: "online" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 50 + num_training_steps: 1_000 + +checkpoint: + ckpt_dir: "checkpoints/facebook_esm2_t48_15B_UR50D" # pragma: allowlist secret + save_final_model: false + +logger: + frequency: 1 + +lora: + r: 8 + alpha: 16 + target_modules: + - "query" + - "key" + - "value" diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_nv_15B.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_nv_15B.yaml new file mode 100644 index 0000000000..e0a413b3d2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_nv_15B.yaml @@ -0,0 +1,57 @@ +defaults: + - defaults + - _self_ + +# Training config +model_tag: nvidia/esm2_t48_15B_UR50D # E.g., nvidia/esm2_t6_8M_UR50D or facebook/esm2_t6_8M_UR50D +use_pretrained: true +num_train_steps: 1000 + +validation_interval: 20 + +use_torch_compile: true + +use_sequence_packing: true + +dataset: + tokenizer_name: ${model_tag} + micro_batch_size: 8 + val_micro_batch_size: 128 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + ss3_classification: true + load_dataset_kwargs: + path: "parquet" + split: null + data_files: + train: "data/porter6_train_dataset_55k.parquet" + validation: "data/porter6_val_dataset_2024_692.parquet" + +# WandB config +wandb_init_args: + name: "esm2_t48_15B_UR50D_lora" + project: "esm2_lora" + mode: "online" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 50 + num_training_steps: 1_000 + +checkpoint: + ckpt_dir: "checkpoints/nvidia_esm2_t48_15B_UR50D" # pragma: allowlist secret + save_final_model: true + +logger: + frequency: 1 + +lora: + r: 8 + alpha: 16 + target_modules: + - "layernorm_qkv" + # For 'facebook/esm2*' use: + # - "query" + # - "key" + # - "value" diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults.yaml new file mode 100644 index 0000000000..4dc314e59c --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults.yaml @@ -0,0 +1,55 @@ +# Training config +model_tag: ??? # E.g., nvidia/esm2_t6_8M_UR50D, facebook/esm2_t6_8M_UR50D, or a local path (e.g ./example_8m_checkpoint) +num_train_steps: ??? +validation_interval: 50 + +# Whether to wrap the model in torch.compile. Note, this is currently not supported with mfsdp (BIONEMO-2977). +# We leave this off by default since we don't see much of a performance improvement with TE layers. +use_torch_compile: false + +use_sequence_packing: false + +dataset: + tokenizer_name: ${model_tag} + micro_batch_size: ??? + val_micro_batch_size: 64 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + seed: 42 + ss3_classification: true + load_dataset_kwargs: + path: "nvidia/esm2_uniref_pretraining_data" + split: "train" + streaming: True + +# WandB config +wandb_init_args: + name: ??? 
+ project: null + +# Optimizer config +adamw_kwargs: + lr: 4e-4 + fused: true + betas: [0.9, 0.98] + eps: 1e-8 + weight_decay: 0.01 + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +# Checkpoint config +checkpoint: + ckpt_dir: ??? + save_final_model: true + +logger: + frequency: 100 + +lora: + r: 8 + alpha: 16 + target_modules: ??? diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults_infer.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults_infer.yaml new file mode 100644 index 0000000000..fb036f36e7 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults_infer.yaml @@ -0,0 +1,12 @@ +model_tag: ??? +base_model_config_dir: ??? +peft_model_config_dir: ${base_model_config_dir} + +input_file: data/input_infer.fasta +output_file: null + +inference: + batch_size: 4 # tune based on GPU memory + max_seq_length: 1024 + stride: 16 + infer_overflowing_aas: true diff --git a/bionemo-recipes/recipes/esm2_peft_te/infer.py b/bionemo-recipes/recipes/esm2_peft_te/infer.py new file mode 100644 index 0000000000..d6b970fe9c --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/infer.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
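+
+"""Run secondary-structure inference with a LoRA fine-tuned ESM-2 checkpoint.
+
+Reads protein sequences from a FASTA or CSV input file, loads the base model together with
+the saved PEFT adapters, and writes per-residue secondary-structure predictions to stdout
+and, optionally, to a CSV output file.
+"""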
+ +from pathlib import Path + +import hydra +import torch +from omegaconf import DictConfig +from peft import PeftModel +from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer + +from dataset import format_output_rows, load_input, write_output + + +def _batched_inference( + model, + tokenizer, + records, + batch_size: int, + max_seq_length: int, + stride: int, + infer_overflowing_aas: bool, + device: str = "cuda", +) -> tuple[list[str], list[int]]: + id2label = model.config.id2label + + predictions = [] + sequences_to_sample_mapping = [] + + for i in range(0, len(records), batch_size): + batch = records[i : i + batch_size] + sequences = [r["sequence"] for r in batch] + + inputs = tokenizer( + sequences, + max_length=max_seq_length, + truncation=True, + stride=stride, + return_overflowing_tokens=infer_overflowing_aas, + return_tensors="pt", + padding=True, + ) + + num_samples = len(inputs["input_ids"]) + overflow_map = inputs.pop("overflow_to_sample_mapping", torch.arange(num_samples)) + + # inner batching over tokenizer outputs + for j in range(0, num_samples, batch_size): + sub_inputs = {k: v[j : j + batch_size].to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = model(**sub_inputs) + + preds = outputs.logits.argmax(dim=-1) + + for k, (pred, input_ids) in enumerate(zip(preds, sub_inputs["input_ids"])): + length = (input_ids != tokenizer.pad_token_id).sum().item() + labels = "".join(id2label[i.item()] for i in pred[:length]) + + predictions.append(labels) + + # map back to original record index + original_idx = i + overflow_map[j + k].item() + sequences_to_sample_mapping.append(original_idx) + + return predictions, sequences_to_sample_mapping + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity_infer", version_base="1.2") +def main(args: DictConfig): + """Infer using a PEFT ESM-2 model. + + This script can be run once ESM2 has been PEFT fine-tuned and adapters have + been checkpointed. For reference, an example has been provided in the './checkpoints' directory. + """ + # Ideally we would like to load the PEFT model directly by doing: + # >>> model = AutoPeftModelForTokenClassification.from_pretrained("", trust_remote_code=True) + # + # However, the from_pretrained() function has a positional argument named 'config' which prevent us from passing a + # a different model config to the base_model. Thus, we first build the base model and then we load the PEFT adapters. + + # Load the custom config + config = AutoConfig.from_pretrained(args.base_model_config_dir, trust_remote_code=True) + + # For recipe simplicity, we only support the attention input format to BSHD. 
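+    # (Training may use THD sequence packing, but padded BSHD batches keep inference simple and
+    # match the `padding=True` tokenizer call in `_batched_inference`.)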
+ config.attn_input_format = "bshd" + + # Load base model with the custom config + base_model = AutoModelForTokenClassification.from_pretrained( + args.model_tag, # original model tag + config=config, + trust_remote_code=True, + ) + + # Load PEFT adapters on top + peft_model = PeftModel.from_pretrained(base_model, args.peft_model_config_dir) + peft_model = peft_model.to("cuda").eval() + + tokenizer = AutoTokenizer.from_pretrained("nvidia/esm2_t48_15B_UR50D") + + records = load_input(Path(args.input_file)) + + predictions, sequences_to_sample_mapping = _batched_inference( + peft_model, + tokenizer, + records, + **args.inference, + ) + + if args.output_file: + write_output(records, predictions, sequences_to_sample_mapping, Path(args.output_file)) + + header, rows = format_output_rows(records, predictions, sequences_to_sample_mapping) + + print("---------------") + print("\t".join(header)) + for row in rows: + print("\t".join(row)) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_peft_te/perf_logger.py b/bionemo-recipes/recipes/esm2_peft_te/perf_logger.py new file mode 100644 index 0000000000..7828a4938a --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/perf_logger.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import time + +import torch +import torchmetrics +import torchmetrics.text +import wandb +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm +from transformers.modeling_outputs import MaskedLMOutput + +from distributed_config import DistributedConfig + + +logger = logging.getLogger(__name__) + + +class PerfLogger: + """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. + + Args: + dist_config: The distributed configuration. + args: The arguments. + + Attributes: + min_loss: The minimum loss seen so far. 
+ """ + + def __init__(self, dist_config: DistributedConfig, args: DictConfig): + """Initialize the logger.""" + self._dist_config = dist_config + self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) + + self.min_loss = float("inf") + + self.logging_frequency = args.logger.frequency + # Track whether to collect memory stats (disabled by default for max performance) + + metrics_dict = { + "train/loss": torchmetrics.MeanMetric(), + "train/grad_norm": torchmetrics.MeanMetric(), + "train/learning_rate": torchmetrics.MeanMetric(), + "train/num_tokens_per_gpu": torchmetrics.MeanMetric(), + "train/total_num_tokens": torchmetrics.SumMetric(), + "train/num_unpadded_tokens": torchmetrics.MeanMetric(), + "train/step_time": torchmetrics.MeanMetric(), + "train/tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/unpadded_tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/total_unpadded_tokens_per_batch": torchmetrics.SumMetric(), + "train/perplexity": torchmetrics.text.Perplexity(ignore_index=-100), + "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), + "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), + "val/loss": torchmetrics.MeanMetric(), + "val/accuracy": torchmetrics.MeanMetric(), + } + + self.metrics = torchmetrics.MetricCollection(metrics_dict) + # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. + self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.previous_step_time = time.perf_counter() + self.train_start_time = None + self.train_end_time = None + + if self._dist_config.is_main_process(): + # Log the entire args object to wandb for experiment tracking and reproducibility. + wandb.init(**args.wandb_init_args, config=self._run_config) + self._progress_bar = tqdm(total=args.num_train_steps, file=sys.stderr, desc="Training") + + def log_train_end_time(self): + """Log when the train loop for a batch ends.""" + self.train_end_time = time.perf_counter() + return + + def log_train_start_time(self): + """Log when the train loop for a batch starts.""" + self.train_start_time = time.perf_counter() + return + + def log_step( + self, + step: int, + batch: dict[str, torch.Tensor], + outputs: MaskedLMOutput, + grad_norm: float, + lr: float, + val_loss: float | None = None, + val_acc: float | None = None, + ): + """Log a step to the logger and wandb. + + Args: + step: The step number. + batch: The batch of data for the step. + outputs: The outputs of the step. + grad_norm: The gradient norm of the step. + lr: The learning rate of the step. + val_loss: The validation loss of the step (if calculated) + val_acc: The validation accuracy of the step (if calculated) + """ + num_tokens = batch["input_ids"].numel() + # 1 is the padding token for ESM-2. 
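+        # (Hard-coded tokenizer assumption; a different tokenizer may use a different pad id.)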
+ num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() + + self.min_loss = min(self.min_loss, outputs.loss.item()) + step_time = self.train_end_time - self.train_start_time + + self.metrics["train/loss"].update(outputs.loss) + self.metrics["train/num_tokens_per_gpu"].update(num_tokens) + self.metrics["train/total_num_tokens"].update(num_tokens) + self.metrics["train/num_unpadded_tokens"].update(num_unpadded_tokens) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency) + + if val_loss is not None: + self.metrics["val/loss"].update(val_loss) + self.metrics["val/accuracy"].update(val_acc) + + # Handle sequence packing for torchmetrics calculation. + if outputs.logits.dim() < 3: + outputs.logits = outputs.logits.unsqueeze(0) + + self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) + + if step % self.logging_frequency == 0 and step > 0: + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) + + if self._dist_config.is_main_process(): + train_metrics = {k: v for k, v in metrics.items() if k.startswith("train/")} + val_metrics = {k: v for k, v in metrics.items() if k.startswith("val/")} + + wandb.log(train_metrics, step=step) + + if val_loss is not None and len(val_metrics) > 0: + wandb.log(val_metrics, step=step) + + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": outputs.loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + + def finish(self): + """Finish the logger and close the progress bar.""" + if not self._dist_config.is_main_process(): + return + + wandb.finish() + self._progress_bar.close() diff --git a/bionemo-recipes/recipes/esm2_peft_te/requirements.txt b/bionemo-recipes/recipes/esm2_peft_te/requirements.txt index 8810ca2003..10aa35c823 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/requirements.txt +++ b/bionemo-recipes/recipes/esm2_peft_te/requirements.txt @@ -1,7 +1,9 @@ datasets -peft +hydra-core +peft @ git+https://github.com/balvisio/peft.git@support-te-lora torch torchao!=0.14.0 +torchmetrics tqdm transformer_engine[pytorch] transformers diff --git a/bionemo-recipes/recipes/esm2_peft_te/scheduler.py b/bionemo-recipes/recipes/esm2_peft_te/scheduler.py new file mode 100644 index 0000000000..9f9da8da91 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/scheduler.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.optim.lr_scheduler import LambdaLR + + +def get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=2_000, + num_training_steps=500_000, + last_epoch=-1, +): + """Linear warmup and decay scheduler for ESM-2 pretraining. + + The description from Lin 2022 is: The learning rate is warmed up over the first 2,000 steps + to a peak value of 4e-4 (1.6e-4 for the 15B parameter model), and then linearly decayed to + one tenth of its peak value over the 90% of training duration. We've found internally that a + longer warmup helps convergence for larger models (3B+) with bf16 precision. + """ + decay_steps = int(num_training_steps * 0.9) + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + # Warmup phase: linearly increase learning rate + return float(current_step) / float(max(1, num_warmup_steps)) + # Decay phase: linearly decay to one tenth of peak over 90% of training + elif current_step > decay_steps: + return 0.1 # one tenth of peak learning rate after decay period + else: + # Linear decay from 1.0 to 0.1 over decay_steps-num_warmup_steps + return 1.0 - 0.9 * (current_step - num_warmup_steps) / float(max(1, decay_steps - num_warmup_steps)) + + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py b/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py index e9d2af6ae7..c7270c5721 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py @@ -16,6 +16,14 @@ import sys from pathlib import Path +import pytest + sys.path.append(Path(__file__).parent.parent.as_posix()) sys.path.append(Path(__file__).parent.as_posix()) + + +@pytest.fixture +def recipe_path() -> Path: + """Return the root directory of the recipe.""" + return Path(__file__).parent.parent diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.csv b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.csv new file mode 100644 index 0000000000..016b6ebfea --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.csv @@ -0,0 +1,2 @@ +Sequence,Secondary_structure +LAGVSERTIDPKQNFYMHWC,HIGEBST~HIGEBST~HIGE diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.parquet b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.parquet new file mode 100644 index 0000000000..51661a8b18 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.parquet differ diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py b/bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py new file mode 100644 index 0000000000..d413501d41 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import peft +import pytest +import torch +from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer + +from dataset import SS3_ID2LABEL, SS3_LABEL2ID +from infer import _batched_inference + + +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA") + + +@pytest.fixture() +def peft_model(recipe_path): + """Build a real 8M NV-ESM2 model with LoRA adapters (random weights, no checkpoint needed).""" + model_path = str(recipe_path / "example_8m_checkpoint") + + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + config.attn_input_format = "bshd" + config.id2label = SS3_ID2LABEL + config.label2id = SS3_LABEL2ID + + base_model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True) + + lora_config = peft.LoraConfig( + task_type=peft.TaskType.TOKEN_CLS, + inference_mode=True, + r=8, + lora_alpha=16, + target_modules=["layernorm_qkv"], + bias="none", + ) + + model = peft.get_peft_model(base_model, lora_config) + model.to(device="cuda", dtype=torch.bfloat16) + model.eval() + return model + + +@pytest.fixture() +def tokenizer(recipe_path): + """Load the ESM-2 tokenizer from the local example checkpoint.""" + return AutoTokenizer.from_pretrained(str(recipe_path / "example_8m_checkpoint")) + + +@requires_cuda +def test_batched_inference_returns_predictions(peft_model, tokenizer): + """Test that _batched_inference produces one prediction per input record.""" + records = [ + {"sequence": "MNEAKGVY"}, + {"sequence": "ATSSPSSPADWAKKL"}, + ] + + predictions, mapping = _batched_inference( + model=peft_model, + tokenizer=tokenizer, + records=records, + batch_size=4, + max_seq_length=1024, + stride=16, + infer_overflowing_aas=False, + ) + + assert len(predictions) == len(records) + assert len(mapping) == len(records) + + # Each prediction string must only contain valid SS3 label characters + valid_labels = set(SS3_ID2LABEL.values()) + for pred in predictions: + assert len(pred) > 0 + assert all(c in valid_labels for c in pred), f"Unexpected character in prediction: {pred}" + + # Mapping indices should cover all input records + assert sorted(mapping) == list(range(len(records))) + + +@requires_cuda +def test_batched_inference_with_overflow(peft_model, tokenizer): + """Test that long sequences are split into overlapping chunks via overflow.""" + long_seq = "MNEAKGVY" * 20 # 160 amino acids + + records = [{"sequence": long_seq}] + + predictions, mapping = _batched_inference( + model=peft_model, + tokenizer=tokenizer, + records=records, + batch_size=2, + max_seq_length=32, # short window to force multiple chunks + stride=8, + infer_overflowing_aas=True, + ) + + # With overflow enabled and a short window, we expect multiple chunks + assert len(predictions) > 1, "Expected multiple chunks for a long sequence" + assert all(idx == 0 for idx in mapping), "All chunks should map back to the single input record" + + +@requires_cuda +def test_batched_inference_single_record(peft_model, tokenizer): + """Test _batched_inference with a single short sequence.""" + records = [{"sequence": "ACDE"}] + + 
predictions, mapping = _batched_inference( + model=peft_model, + tokenizer=tokenizer, + records=records, + batch_size=1, + max_seq_length=1024, + stride=16, + infer_overflowing_aas=False, + ) + + assert len(predictions) == 1 + assert mapping == [0] + + +@requires_cuda +def test_batched_inference_prediction_length(peft_model, tokenizer): + """Test that each prediction's length equals the number of non-pad tokens. + + The ESM-2 tokenizer prepends and appends , so the prediction + string length should be len(sequence) + 2 for sequences shorter than + max_seq_length. + """ + seq = "MNEAKGVY" + records = [{"sequence": seq}] + + predictions, _ = _batched_inference( + model=peft_model, + tokenizer=tokenizer, + records=records, + batch_size=1, + max_seq_length=1024, + stride=16, + infer_overflowing_aas=False, + ) + + # +2 for and special tokens + expected_length = len(seq) + 2 + assert len(predictions[0]) == expected_length, ( + f"Prediction length {len(predictions[0])} != expected {expected_length}" + ) diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py index eb782e2552..eb0d4960b7 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py @@ -14,21 +14,62 @@ # limitations under the License. import torch -from transformers import BatchEncoding +from hydra import compose, initialize_config_dir -from train_lora import create_dataloader, train_lora +from train_lora_ddp import main as main_ddp -def test_create_dataloader(): - dataloader = create_dataloader(use_sanity_dataset=True) - for batch in dataloader: - assert isinstance(batch, BatchEncoding) - assert isinstance(batch["input_ids"], torch.Tensor) - assert isinstance(batch["labels"], torch.Tensor) - break +def test_sanity_convergence_ddp(tmp_path, recipe_path): + """Test that the main function can be invoked wrapping the model in DDP.""" + # Run the training script with Hydra configuration overrides + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + ], + ) -def test_train_lora(): - dataloader = create_dataloader(use_sanity_dataset=True) - loss = train_lora(dataloader) - assert loss < 1.5 + final_loss = main_ddp(sanity_config) + assert final_loss < 3.0, f"Final loss {final_loss} is too high" + + +def test_sanity_convergence_ddp_non_streaming_dataset(tmp_path, recipe_path): + """Test that the training script works with a non-streaming dataset.""" + + # Run the training script with Hydra configuration overrides + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "dataset.load_dataset_kwargs.streaming=False", + ], + ) + + final_loss = main_ddp(sanity_config) + assert final_loss < 3.0, f"Final loss {final_loss} is too high" + + +def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path): + if torch.cuda.get_device_capability() == (12, 0): + # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default, + # but it's missing this THD implementation. 
+ monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + + # For DDP, we only check that the script can run successfully with THD, not convergence. + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "use_sequence_packing=true", + "num_train_steps=4", + ], + ) + + main_ddp(sanity_config) diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora_two_gpus.py b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora_two_gpus.py new file mode 100644 index 0000000000..c822eeefa5 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora_two_gpus.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These tests don't check convergence, they just check that the training script runs successfully on multiple GPUs. + +import subprocess + +import pytest +import torch + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +def run_train_cmd(cmd, recipe_path): + """Run a training command and check for errors.""" + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + cwd=str(recipe_path), + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}") + + +@requires_multi_gpu +def test_multi_gpu_train_te_ddp(tmp_path, recipe_path): + # Run 'accelerate launch train.py' as a subprocess + run_train_cmd( + [ + "torchrun", + "--nproc_per_node", + "2", + "--standalone", + "train_lora_ddp.py", + "--config-name", + "L0_sanity", + "num_train_steps=4", + ], + recipe_path, + ) diff --git a/bionemo-recipes/recipes/esm2_peft_te/train_lora.py b/bionemo-recipes/recipes/esm2_peft_te/train_lora.py deleted file mode 100644 index 3e3e2b03db..0000000000 --- a/bionemo-recipes/recipes/esm2_peft_te/train_lora.py +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -"""Demonstration of LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. - -Still needs: - - [ ] Hydra config management. - - [ ] THD / sequence packing. - - [ ] DDP / Multi-node training. - - [ ] FP8 tests. - - [ ] Perf / wandb logging. -""" - -import peft -import torch -from datasets import load_dataset -from tqdm import tqdm -from transformers import ( - AutoConfig, - AutoModelForTokenClassification, - AutoTokenizer, - DataCollatorForTokenClassification, -) - - -def create_dataloader(use_sanity_dataset: bool = False) -> torch.utils.data.DataLoader: - """Create a dataloader for the secondary structure dataset.""" - if use_sanity_dataset: - # 5000 row sanity dataset. - ss_dataset = load_dataset("parquet", data_files="peft_sanity_dataset.parquet", split="train", streaming=True) - else: - # Full-scale source dataset. - ss_dataset = load_dataset("lamm-mit/protein_secondary_structure_from_PDB", split="train", streaming=True) - - ss_token_map = {"H": 0, "E": 1, "I": 2, "S": 3, "T": 4, "C": 5, "B": 6, "G": 7, "~": -100} - - tokenizer = AutoTokenizer.from_pretrained("example_8m_checkpoint") - tokenize_args = { - "max_length": 128, - "truncation": True, - "stride": 16, # TODO: figure this out later - "return_overflowing_tokens": True, - "return_offsets_mapping": True, - } - - def tokenize(example): - """Tokenize both the input protein sequence and the secondary structure labels.""" - result = tokenizer(example["Sequence"], **tokenize_args) - - # While we can use the rust-based tokenizer for the protein sequence, we manually encode the secondary structure - # labels. Our goal is to return a list of integer labels with the same shape as the input_ids. - labels = [] - for batch_idx in range(len(result["input_ids"])): - sequence_labels = [] - - # This array maps the possibly-chunked result["input_ids"] to the original sequence. Because of - # `return_overflowing_tokens`, each input sequence may be split into multiple input rows. - offsets = result["offset_mapping"][batch_idx] - - # This gets the original secondary structure sequence for the current chunk. - ss_sequence = example["Secondary_structure"][result["overflow_to_sample_mapping"][batch_idx]] - - for offset_start, offset_end in offsets: - if offset_start == offset_end: - sequence_labels.append(-100) # Start and end of the sequence tokens can be ignored. - elif offset_end == offset_start + 1: # All tokens are single-character. - ss_char = ss_sequence[offset_start] - ss_label_value = ss_token_map[ss_char] # Encode the secondary structure character - sequence_labels.append(ss_label_value) - else: - raise ValueError(f"Invalid offset: {offset_start} {offset_end}") - - labels.append(sequence_labels) - - return {"input_ids": result["input_ids"], "labels": labels} - - tokenized_dataset = ss_dataset.map( - tokenize, - batched=True, - remove_columns=[col for col in ss_dataset.features if col not in ["input_ids", "labels"]], - ) - - collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding="max_length", max_length=1024) - dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=16, collate_fn=collator) - - return dataloader - - -def train_lora(dataloader: torch.utils.data.DataLoader) -> float: - """Training loop for LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. - - Args: - dataloader: DataLoader for the secondary structure dataset. - - Returns: - Final loss value. - """ - # For testing, we don't want to depend on loading pre-trained weights. 
- config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True) - config.num_labels = 8 - model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True) - - # Alternatively, we'd want to load an actual pre-trained checkpoint. - # model = AutoModelForTokenClassification.from_pretrained( - # "example_8m_checkpoint", num_labels=8, trust_remote_code=True, dtype="bfloat16" - # ) - - peft_config = peft.LoraConfig( - task_type=peft.TaskType.TOKEN_CLS, - inference_mode=False, - r=16, - lora_alpha=16, - # target_modules=["layernorm_qkv"], # TODO: figure out if this could work? - target_parameters=["layernorm_qkv.weight"], - bias="none", - ) - - peft_model = peft.get_peft_model(model, peft_config) - peft_model.to("cuda", dtype=torch.bfloat16) - - # Create optimizer. - optimizer = torch.optim.AdamW(peft_model.parameters(), lr=1e-3, weight_decay=0.01) - - # Training loop. - step = 0 - with tqdm(dataloader, desc="Training") as progress_bar: - for batch in progress_bar: - batch = {k: v.to("cuda") for k, v in batch.items()} # noqa PLW2901 - output = peft_model(**batch) - loss = output.loss - loss.backward() - progress_bar.set_postfix({"loss": loss.item()}) - - # Step optimizer. - optimizer.step() - optimizer.zero_grad() - - step += 1 - if step >= 100: - return loss.item() - - -if __name__ == "__main__": - dataloader = create_dataloader(use_sanity_dataset=True) - train_lora(dataloader) diff --git a/bionemo-recipes/recipes/esm2_peft_te/train_lora_ddp.py b/bionemo-recipes/recipes/esm2_peft_te/train_lora_ddp.py new file mode 100644 index 0000000000..bf689a5440 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/train_lora_ddp.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Demonstration of LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT using DDP.""" + +import logging +from pathlib import Path + +import hydra +import peft +import torch +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from transformers import ( + AutoConfig, + AutoModelForTokenClassification, +) + +from checkpoint import save_final_model_ddp +from dataset import ( + SS3_ID2LABEL, + SS3_LABEL2ID, + SS8_ID2LABEL, + SS8_LABEL2ID, + compute_accuracy, + create_dataloader, + get_parameter_names_with_lora, +) +from distributed_config import DistributedConfig +from perf_logger import PerfLogger +from scheduler import get_linear_schedule_with_warmup + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") +def main(args: DictConfig) -> float: + """Training loop for LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. + + Args: + args: Configuration arguments from hydra. + + Returns: + Final loss value. 
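+
+    Note:
+        The value returned is the minimum training loss tracked by the perf logger.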
+ """ + # Initialize the distributed configuration, including creating the distributed process group. + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + train_dataloader, val_dataloader, train_dataset_or_sampler = create_dataloader( + distributed_config=dist_config, + use_sequence_packing=args.use_sequence_packing, + **OmegaConf.to_container(args.dataset, resolve=True), + ) + + perform_validation = val_dataloader is not None + + # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2 + # and MFSDP. + device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("ddp",)) + + # For testing, we don't want to depend on loading pre-trained weights. + config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True) + if args.use_sequence_packing: + config.attn_input_format = "thd" + + if args.dataset["ss3_classification"]: + config.id2label = SS3_ID2LABEL + config.label2id = SS3_LABEL2ID + else: + config.id2label = SS8_ID2LABEL + config.label2id = SS8_LABEL2ID + + if args.use_pretrained: + model = AutoModelForTokenClassification.from_pretrained( + args.model_tag, config=config, trust_remote_code=True, dtype="bfloat16" + ) + else: + model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True) + + peft_config = peft.LoraConfig( + task_type=peft.TaskType.TOKEN_CLS, + inference_mode=False, + r=args.lora.r, + lora_alpha=args.lora.alpha, + target_modules=list(args.lora.target_modules), + bias="none", + ) + + peft_model = peft.get_peft_model(model, peft_config) + peft_model.to(device=device, dtype=torch.bfloat16) + + print("----- PEFT Model --------") + peft_model.print_trainable_parameters() + + # Create optimizer. + decay_parameters = get_parameter_names_with_lora(peft_model) + optimizer_grouped_parameters = [ + { + "params": [p for n, p in peft_model.named_parameters() if (n in decay_parameters and p.requires_grad)], + "weight_decay": args.adamw_kwargs.weight_decay, + }, + { + "params": [p for n, p in peft_model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **args.adamw_kwargs) + + scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + peft_model = torch.nn.parallel.DistributedDataParallel( + peft_model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + device_mesh=device_mesh["ddp"], + find_unused_parameters=True, + ) + + if args.use_torch_compile: + # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. + peft_model = torch.compile(peft_model) + + perf_logger = PerfLogger(dist_config, args) + + # Training loop. + step = 0 + epoch = 0 + while step < args.num_train_steps: + for batch in train_dataloader: + perf_logger.log_train_start_time() + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901 + + output = peft_model(**batch) + loss = output.loss + loss.backward() + + # Compute and clip gradient norms. + total_norm = torch.nn.utils.clip_grad_norm_(peft_model.parameters(), max_norm=1.0).item() + + # Step optimizer. 
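+            # Only the parameters PEFT leaves trainable (the LoRA adapters and any modules_to_save,
+            # e.g. the classification head) are in the optimizer groups; the frozen base weights are untouched.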
+ optimizer.step() + scheduler.step() + optimizer.zero_grad() + + step += 1 + + perf_logger.log_train_end_time() + # Validation + avg_val_loss = None + avg_val_acc = None + if perform_validation and step % args.validation_interval == 0: + peft_model.eval() + val_loss_total = 0.0 + val_correct_total = 0 + val_tokens_total = 0 + val_steps = 0 + with torch.no_grad(): + for val_batch in val_dataloader: + val_batch = { # noqa: PLW2901 + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in val_batch.items() + } + val_output = peft_model(**val_batch) + + # Loss + val_loss_total += val_output.loss.item() + + # Accuracy + logits = val_output.logits + labels = val_batch["labels"] + correct, total = compute_accuracy(logits, labels) + val_correct_total += correct + val_tokens_total += total + + val_steps += 1 + + avg_val_loss = val_loss_total / val_steps + avg_val_acc = val_correct_total / val_tokens_total if val_tokens_total > 0 else 0.0 + print(f"\nStep: {step}: Validation Loss = {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}\n") + peft_model.train() + + perf_logger.log_step( + step=step, + batch=batch, + outputs=output, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + val_loss=avg_val_loss, + val_acc=avg_val_acc, + ) + + if step >= args.num_train_steps: + break + + # Dataloader exhausted, incrementing epoch + epoch += 1 + train_dataset_or_sampler.set_epoch(epoch) + + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None + + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_ddp( + model=peft_model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + perf_logger.finish() + torch.distributed.destroy_process_group() + + return perf_logger.min_loss + + +if __name__ == "__main__": + main() diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index a4281dc3eb..ed6ae1ac10 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -57,6 +57,18 @@ "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common", ], + "bionemo-recipes/recipes/esm2_native_te/collator.py": [ + "bionemo-recipes/recipes/esm2_peft_te/collator.py", + ], + "bionemo-recipes/recipes/esm2_native_te/checkpoint.py": [ + "bionemo-recipes/recipes/esm2_peft_te/checkpoint.py", + ], + "bionemo-recipes/recipes/esm2_native_te/distributed_config.py": [ + "bionemo-recipes/recipes/esm2_peft_te/distributed_config.py", + ], + "bionemo-recipes/recipes/esm2_native_te/scheduler.py": [ + "bionemo-recipes/recipes/esm2_peft_te/scheduler.py", + ], }