diff --git a/.github/workflows/unit-tests-recipes.yml b/.github/workflows/unit-tests-recipes.yml index 4385f40feb..70203f3b8a 100644 --- a/.github/workflows/unit-tests-recipes.yml +++ b/.github/workflows/unit-tests-recipes.yml @@ -155,6 +155,9 @@ jobs: with: sparse-checkout: "${{ matrix.recipe.dir }}" sparse-checkout-cone-mode: false + - name: Include symlink targets for esm2_peft_te + if: ${{ matrix.recipe.dir == 'bionemo-recipes/recipes/esm2_peft_te' }} + run: git -c safe.directory=/__w/bionemo-framework/bionemo-framework sparse-checkout add bionemo-recipes/recipes/esm2_native_te - name: Cache Hugging Face models uses: actions/cache@v4 diff --git a/bionemo-recipes/models/esm2/pyproject.toml b/bionemo-recipes/models/esm2/pyproject.toml index 6fe871a015..dfb6d3fc87 100644 --- a/bionemo-recipes/models/esm2/pyproject.toml +++ b/bionemo-recipes/models/esm2/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "jinja2", "megatron-fsdp", "omegaconf", - "peft", + "peft @ git+https://github.com/balvisio/peft.git@dev/ba/support-te-lora", "pytest", "torch", "torchao!=0.14.0", diff --git a/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py b/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py index 00fdf23128..1fd14ee767 100644 --- a/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py @@ -679,3 +679,81 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class NVConvNetHead(nn.Module): + """Convolution based head for token classification.""" + + def __init__(self, config: NVEsmConfig): + """Initialize the NVConvNetHead.""" + super().__init__() + self.conv_head = torch.nn.Sequential( + torch.nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=7, padding=3), + torch.nn.ReLU(), + torch.nn.Dropout(config.hidden_dropout_prob), + torch.nn.Conv1d(config.hidden_size // 2, config.num_labels, kernel_size=7, padding=3), + ) + + def forward(self, features, **kwargs): + """Forward pass for the convolutional token classification head.""" + return self.conv_head(features).transpose(1, 2) + + +class NVEsmForConvTokenClassification(NVEsmPreTrainedModel): + """Adds a convolutional classification head to the model.""" + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.classifier = NVConvNetHead(config) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if outputs[0].dim() == 3: + sequence_output = outputs[0] + else: + sequence_output = outputs[0].unsqueeze(0) + + sequence_output = sequence_output.transpose(1, 2) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.reshape(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/models/esm2/tests/test_peft.py b/bionemo-recipes/models/esm2/tests/test_peft.py index b991a3dc4c..8caadc04aa 100644 --- a/bionemo-recipes/models/esm2/tests/test_peft.py +++ b/bionemo-recipes/models/esm2/tests/test_peft.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + import peft -import pytest import torch from esm.modeling_esm_te import NVEsmForMaskedLM @@ -58,7 +59,6 @@ def test_lora_model_forward_pass(te_model_checkpoint, input_data): assert outputs.loss is not None -@pytest.mark.xfail(reason="BIONEMO-3136: LoRA model initializes with warnings because of TE layers.") def test_lora_model_raises_no_warnings(te_model_checkpoint): model = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) @@ -71,13 +71,14 @@ def test_lora_model_raises_no_warnings(te_model_checkpoint): bias="none", ) - with pytest.warns(UserWarning) as record: + with warnings.catch_warnings(record=True) as record: + # Cause all warnings to be triggered (default behavior may ignore some) + warnings.simplefilter("always") peft.get_peft_model(model, peft_config) assert len(record) == 0 -@pytest.mark.xfail(reason="BIONEMO-3136: LoRA model initialization fails with target_modules because of TE layers.") def test_lora_model_with_target_modules(te_model_checkpoint): model = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index 00fdf23128..1fd14ee767 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -679,3 +679,81 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class NVConvNetHead(nn.Module): + """Convolution based head for token classification.""" + + def __init__(self, config: NVEsmConfig): + """Initialize the NVConvNetHead.""" + super().__init__() + self.conv_head = torch.nn.Sequential( + torch.nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=7, padding=3), + torch.nn.ReLU(), + torch.nn.Dropout(config.hidden_dropout_prob), + torch.nn.Conv1d(config.hidden_size // 2, config.num_labels, kernel_size=7, padding=3), + ) + + def forward(self, features, **kwargs): + """Forward pass for the convolutional token classification head.""" + return self.conv_head(features).transpose(1, 2) + + +class NVEsmForConvTokenClassification(NVEsmPreTrainedModel): + """Adds a convolutional classification head to the model.""" + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels 
+ + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.classifier = NVConvNetHead(config) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if outputs[0].dim() == 3: + sequence_output = outputs[0] + else: + sequence_output = outputs[0].unsqueeze(0) + + sequence_output = sequence_output.transpose(1, 2) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.reshape(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/checkpoint.py b/bionemo-recipes/recipes/esm2_native_te/checkpoint.py index a168eeeb89..ce909aaaec 100644 --- a/bionemo-recipes/recipes/esm2_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/esm2_native_te/checkpoint.py @@ -194,6 +194,15 @@ def save_final_model_ddp( underlying_model: transformers.PreTrainedModel = model.module if hasattr(model, "module") else model # type: ignore os.makedirs(save_directory, exist_ok=True) + # If we are saving a PEFT model we also save the base_model config. + # This allows for an streamlined reload of the PEFT model without having to manually reconstruct the config of + # the base_model. 
+ # For example: + # >>> config = AutoConfig.from_pretrained() + # >>> base_model = AutoModelForTokenClassification.from_pretrained(, config=config) + # >>> peft_model = PeftModel.from_pretrained(base_model, ) + if hasattr(underlying_model, "peft_config"): + underlying_model.config.save_pretrained(save_directory) underlying_model.save_pretrained(save_directory, state_dict=underlying_model.state_dict(), safe_serialization=True) logger.info(f"Saved final DDP model to {save_directory}") diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index 00fdf23128..1fd14ee767 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -679,3 +679,81 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class NVConvNetHead(nn.Module): + """Convolution based head for token classification.""" + + def __init__(self, config: NVEsmConfig): + """Initialize the NVConvNetHead.""" + super().__init__() + self.conv_head = torch.nn.Sequential( + torch.nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=7, padding=3), + torch.nn.ReLU(), + torch.nn.Dropout(config.hidden_dropout_prob), + torch.nn.Conv1d(config.hidden_size // 2, config.num_labels, kernel_size=7, padding=3), + ) + + def forward(self, features, **kwargs): + """Forward pass for the convolutional token classification head.""" + return self.conv_head(features).transpose(1, 2) + + +class NVEsmForConvTokenClassification(NVEsmPreTrainedModel): + """Adds a convolutional classification head to the model.""" + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.classifier = NVConvNetHead(config) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if outputs[0].dim() == 3: + sequence_output = outputs[0] + else: + sequence_output = outputs[0].unsqueeze(0) + + sequence_output = sequence_output.transpose(1, 2) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.reshape(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_peft_te/Dockerfile b/bionemo-recipes/recipes/esm2_peft_te/Dockerfile new file mode 100644 index 0000000000..4b1b4a7743 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/Dockerfile @@ -0,0 +1,12 @@ +FROM nvcr.io/nvidia/pytorch:25.12-py3 + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=esm2_peft_te/requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +WORKDIR /workspace/bionemo-recipes/recipes/esm2_peft_te +COPY esm2_peft_te/ /workspace/bionemo-recipes/recipes/esm2_peft_te +COPY esm2_native_te/checkpoint.py /workspace/bionemo-recipes/recipes/esm2_native_te/checkpoint.py +COPY esm2_native_te/collator.py /workspace/bionemo-recipes/recipes/esm2_native_te/collator.py +COPY esm2_native_te/distributed_config.py /workspace/bionemo-recipes/recipes/esm2_native_te/distributed_config.py +COPY esm2_native_te/scheduler.py /workspace/bionemo-recipes/recipes/esm2_native_te/scheduler.py diff --git a/bionemo-recipes/recipes/esm2_peft_te/README.md b/bionemo-recipes/recipes/esm2_peft_te/README.md index 98bc76396b..e43fb2c60c 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/README.md +++ b/bionemo-recipes/recipes/esm2_peft_te/README.md @@ -2,6 +2,106 @@ This folder demonstrates how to fine-tune a TransformerEngine-accelerated ESM-2 model using PEFT. -Note: This recipe is a work in progress, and currently only demonstrates basic support for LoRA fine-tuning and -TransformerEngine layers. See `bionemo-recipes/models/esm2/tests/test_peft.py` for additional information and known -limitations. +## Commands to Launch LoRA Fine-tuning + +To run single-process training on one GPU, run: + +```bash +python train_lora_ddp.py +``` + +To run multi-process training locally on 2+ GPUs, run (e.g. 2 GPUs): + +```bash +torchrun --nproc_per_node=2 train_lora_ddp.py +``` + +## Sequence Packing (THD input format) + +Sequence packing is handled via the [`DataCollatorWithFlattening`](https://huggingface.co/docs/transformers/v4.47.1/main_classes/data_collator#transformers.DataCollatorWithFlattening) collator from the HuggingFace transformers library that provides input arguments (e.g. +`cu_seq_lens_q`) needed for padding-free attention. To enable sequence packing, set `use_sequence_packing=true` +in the hydra configuration. + +```bash +python train_lora_ddp.py --config-name L0_sanity use_sequence_packing=true +``` + +## Running Inference + +To do inference, use `infer.py`. By default, it uses the `hydra_config/L0_sanity_infer.yaml` config, which points to a +previously LoRA fine-tuned checkpoint fine-tuned on `data/porter6_train_dataset_55k.parquet`. It also defaults to running +inference on the FASTA file at `data/input_infer.fasta` (see `hydra_config/defaults_infer.yaml`). 
+ +To run inference with the defaults: + +```bash +python infer.py +``` + +You can override the most common settings from the command line: + +- **`input_file`**: FASTA input (default: `data/input_infer.fasta`) +- **`output_file`**: Where to write predictions (CSV). If `null`, results print to stdout (default: `preds.csv`) +- **`model_tag`**: Base ESM-2 HF model to load (default: `nvidia/esm2_t6_8M_UR50D`) +- **`base_model_config_dir`**: Directory containing the fine-tuned model config +- **`peft_model_config_dir`**: Directory containing the LoRA adapter weights/config (defaults to `base_model_config_dir`) + +Examples: + +```bash +# Run on a different FASTA file and write a CSV +python infer.py input_file=/path/to/inputs.fasta output_file=preds.csv + +# Point to your own LoRA fine-tuned checkpoint directory +python infer.py base_model_config_dir=/path/to/my_peft_checkpoint peft_model_config_dir=/path/to/my_peft_checkpoint +``` + +## Datasets + +This recipe includes small and medium-sized datasets in `data/` so you can get started quickly without downloading +anything. + +- **Quick sanity dataset (used for CI and smoke tests)**: `data/peft_sanity_dataset.parquet` is a **5,000-sample subset** + of the Hugging Face dataset + [`lamm-mit/protein_secondary_structure_from_PDB`](https://huggingface.co/datasets/lamm-mit/protein_secondary_structure_from_PDB). + It is intended for fast local iteration and is also used by the recipe's CI tests. + +- **Porter6 paper datasets**: + + - `data/porter6_train_dataset_55k.parquet`: training set. + - `data/porter6_val_dataset_2024_692.parquet`: 2024 benchmark validation set. + + These originate from the Porter6 secondary-structure prediction work. For details on how they were built, see the + [Porter6 paper](https://pmc.ncbi.nlm.nih.gov/articles/PMC11719765/). The original dataset files were downloaded from + the [Porter6 repository](https://github.com/WafaAlanazi/Porter6) and then converted to Parquet format for use in this + recipe. + +### Installing Dependencies + +The easiest way to get started with this recipe is to use the provided Dockerfile, which uses the latest NVIDIA PyTorch +base image to provide optimized versions of PyTorch and TransformerEngine. To build the container, run: + +```bash +docker build -f Dockerfile -t esm2_peft_te .. +``` + +This uses `..` as the build context so the `esm2_native_te` symlink targets are available during the build. +The Dockerfile only copies the four shared Python files from `esm2_native_te` into the image. + +To run the container, run: + +```bash +docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo esm2_peft_te /bin/bash +``` + +## Developer Guide + +### Running tests + +To run tests locally, run `recipes_local_test.py` from the repository root with the recipe directory as an argument. + +```bash +./ci/scripts/recipes_local_test.py bionemo-recipes/recipes/esm2_peft_te/ +``` + +For more information see [here](../esm2_native_te/README.md). 
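+
+### Loading a Saved LoRA Checkpoint in Python
+
+When a PEFT model is saved, `checkpoint.py` also writes the base-model config next to the adapter weights, so a saved
+checkpoint directory can be reloaded without reconstructing the config by hand. The sketch below is illustrative
+rather than a copy of `infer.py`: the checkpoint path points at the bundled
+`example_nv_esm2_t6_8M_UR50D_peft_checkpoint/` directory but could be any directory written by `train_lora_ddp.py`,
+the model tag matches the recipe default (`nvidia/esm2_t6_8M_UR50D`), `trust_remote_code=True` is assumed because the
+config maps the `Auto*` classes to the custom `esm_nv.NVEsm*` implementations, and loading in bfloat16 mirrors the
+recipe's tests.
+
+```python
+import torch
+from peft import PeftModel
+from transformers import AutoConfig, AutoModelForTokenClassification
+
+# Directory containing config.json, esm_nv.py, and the LoRA adapter files.
+checkpoint_dir = "example_nv_esm2_t6_8M_UR50D_peft_checkpoint"
+
+# The base-model config is saved alongside the adapter (see checkpoint.py), so it can be loaded directly.
+config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
+
+# Rebuild the TE-accelerated base model, then attach the LoRA adapter from the same directory.
+base_model = AutoModelForTokenClassification.from_pretrained(
+    "nvidia/esm2_t6_8M_UR50D", config=config, trust_remote_code=True, dtype=torch.bfloat16
+)
+peft_model = PeftModel.from_pretrained(base_model, checkpoint_dir)
+peft_model.eval()
+```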
diff --git a/bionemo-recipes/recipes/esm2_peft_te/checkpoint.py b/bionemo-recipes/recipes/esm2_peft_te/checkpoint.py new file mode 120000 index 0000000000..96b01bd2d3 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/checkpoint.py @@ -0,0 +1 @@ +../esm2_native_te/checkpoint.py \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_peft_te/collator.py b/bionemo-recipes/recipes/esm2_peft_te/collator.py new file mode 120000 index 0000000000..37423926a0 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/collator.py @@ -0,0 +1 @@ +../esm2_native_te/collator.py \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_peft_te/data/input_infer.fasta b/bionemo-recipes/recipes/esm2_peft_te/data/input_infer.fasta new file mode 100644 index 0000000000..c1244acb1e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/data/input_infer.fasta @@ -0,0 +1,4 @@ +>2KL1A +MNEAKGVYVMSVLPNMPAAGRLEAGDRIAAIDGQPINTSEQIVSYVREKQAGDRVRVTFIRDRKQHEAELVLKPFPHHPNQIGLGVT +>2QYUA +ATSSPSSPADWAKKLTDAVLRQKAGETLTAADRDFSNADFRNITFSKILPPSFMERDGDIIKGFNFSNSKFTYSDISHLHFDECRFTYSTLSDVVCSNTKFSNSDMNEVFLQYSITTQQQPSFIDTTLKNTLIRHKANLSGVILNEPDNSSPPSVSGGGNFIRLGDIWLQMPLLWTENAVDGFLNHEHNNGKSILMTIDSLPDKYSQEKVQAMEDLVKSLRGGRLTEACIRPVESSLVSVLAHPPYTQSALISEWLGPVQERFFAHQCQTYNDVPLPAPDTYYQQRILPVLLDSFDRNSAAMTTHSGLFNQVILHCMTGVDCTDGTRQKAAALYEQYLAHPAVSPHIHNGLFGNYDGSPDWTTRAADNFLLLSSQDSDTAMMLSTDTLLTMLNPTPDTAWDNFYLLRAGENVSTAQISPVELFRHDFPVFLAAFNQQATQRRFGELIDIILSTEEHGELNQQFLAATNQKHSTVKLIDDASVSRLATIFDPLLPEGKLSPAHYQHILSAYHLTDATPQKQAETLFCLSTAFARYSSSAIFGTEHDSPPALRGYAEALMQKAWELSPAIFPSSEQFTEWSDRFHGLHGAFTCTSVVADSMQRHARKYFPSVLSSILPLAWA diff --git a/bionemo-recipes/recipes/esm2_peft_te/peft_sanity_dataset.parquet b/bionemo-recipes/recipes/esm2_peft_te/data/peft_sanity_dataset.parquet similarity index 100% rename from bionemo-recipes/recipes/esm2_peft_te/peft_sanity_dataset.parquet rename to bionemo-recipes/recipes/esm2_peft_te/data/peft_sanity_dataset.parquet diff --git a/bionemo-recipes/recipes/esm2_peft_te/data/porter6_train_dataset_55k.parquet b/bionemo-recipes/recipes/esm2_peft_te/data/porter6_train_dataset_55k.parquet new file mode 100644 index 0000000000..d7b6817369 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_peft_te/data/porter6_train_dataset_55k.parquet differ diff --git a/bionemo-recipes/recipes/esm2_peft_te/data/porter6_val_dataset_2024_692.parquet b/bionemo-recipes/recipes/esm2_peft_te/data/porter6_val_dataset_2024_692.parquet new file mode 100644 index 0000000000..879141ed00 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_peft_te/data/porter6_val_dataset_2024_692.parquet differ diff --git a/bionemo-recipes/recipes/esm2_peft_te/dataset.py b/bionemo-recipes/recipes/esm2_peft_te/dataset.py new file mode 100644 index 0000000000..e2e5901367 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/dataset.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import datasets +import datasets.distributed +from datasets import IterableDataset, load_dataset +from torch.utils.data import DataLoader, DistributedSampler +from transformers import ( + AutoTokenizer, + DataCollatorForTokenClassification, + DataCollatorWithFlattening, +) + +from collator import TokenPackingDataset +from distributed_config import DistributedConfig +from utils import SS3_LABEL2ID, SS8_LABEL2ID + + +def create_dataloader( + distributed_config: DistributedConfig, + use_sequence_packing: bool, + tokenizer_name: str, + micro_batch_size: int, + val_micro_batch_size: int, + num_workers: int, + max_seq_length: int, + stride: int, + seed: int, + ss3_classification: bool, + load_dataset_kwargs: dict, +) -> tuple[DataLoader, DataLoader | None, IterableDataset | DistributedSampler]: + """Create a dataloader for the secondary structure dataset.""" + dataset_or_dataset_dict = load_dataset(**load_dataset_kwargs) + + if isinstance(dataset_or_dataset_dict, dict): + train_dataset = dataset_or_dataset_dict.get("train") + assert train_dataset, "'train' split must be specified." + val_dataset = dataset_or_dataset_dict.get("validation") + else: + train_dataset = dataset_or_dataset_dict + val_dataset = None + + print( + f"Loading dataset: path: '{load_dataset_kwargs['path']}' | data_files: '{load_dataset_kwargs['data_files']}'." + ) + + perform_validation = val_dataset is not None + + if isinstance(train_dataset, IterableDataset): + train_dataset = datasets.distributed.split_dataset_by_node( + train_dataset, + rank=distributed_config.rank, + world_size=distributed_config.world_size, + ) + train_dataset = train_dataset.shuffle(seed=seed, buffer_size=10_000) + + if perform_validation: + val_dataset = datasets.distributed.split_dataset_by_node( + val_dataset, + rank=distributed_config.rank, + world_size=distributed_config.world_size, + ) + + if ss3_classification: + ss_token_map = SS3_LABEL2ID + else: + ss_token_map = SS8_LABEL2ID + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenize_args = { + "max_length": max_seq_length, + "truncation": True, + "stride": stride, + "return_overflowing_tokens": True, + "return_offsets_mapping": True, + } + + def tokenize(example): + """Tokenize both the input protein sequence and the secondary structure labels.""" + result = tokenizer(example["Sequence"], **tokenize_args) + + # While we can use the rust-based tokenizer for the protein sequence, we manually encode the secondary structure + # labels. Our goal is to return a list of integer labels with the same shape as the input_ids. + labels = [] + for batch_idx in range(len(result["input_ids"])): + sequence_labels = [] + + # This array maps the possibly-chunked result["input_ids"] to the original sequence. Because of + # `return_overflowing_tokens`, each input sequence may be split into multiple input rows. + offsets = result["offset_mapping"][batch_idx] + + # This gets the original secondary structure sequence for the current chunk. + ss_sequence = example["Secondary_structure"][result["overflow_to_sample_mapping"][batch_idx]] + + for offset_start, offset_end in offsets: + if offset_start == offset_end: + sequence_labels.append(-100) # Start and end of the sequence tokens can be ignored. + elif offset_end == offset_start + 1: # All tokens are single-character. 
+ ss_char = ss_sequence[offset_start] + ss_label_value = ss_token_map[ss_char] # Encode the secondary structure character + sequence_labels.append(ss_label_value) + else: + raise ValueError(f"Invalid offset: {offset_start} {offset_end}") + + labels.append(sequence_labels) + + return {"input_ids": result["input_ids"], "labels": labels} + + train_tokenized_dataset = train_dataset.map( + tokenize, + batched=True, + remove_columns=[col for col in train_dataset.features if col not in ["input_ids", "labels"]], + ) + + if isinstance(train_tokenized_dataset, IterableDataset): + train_sampler = None + else: + train_sampler = DistributedSampler( + train_tokenized_dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=seed, + ) + + if use_sequence_packing: + assert isinstance(train_tokenized_dataset, datasets.IterableDataset), ( + "THD token packing requires a streaming dataset." + ) + collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True) + train_tokenized_dataset = TokenPackingDataset( + train_tokenized_dataset, max_tokens_per_batch=micro_batch_size * max_seq_length + ) + batch_size = None # The TokenPackingDataset will handle the batching. + else: + collator = DataCollatorForTokenClassification( + tokenizer=tokenizer, padding="max_length", max_length=max_seq_length + ) + batch_size = micro_batch_size + + train_dataloader = DataLoader( + train_tokenized_dataset, + sampler=train_sampler, + batch_size=batch_size, + collate_fn=collator, + num_workers=num_workers, + pin_memory=True, + ) + + if perform_validation: + val_tokenized_dataset = val_dataset.map( + tokenize, + batched=True, + remove_columns=[col for col in val_dataset.features if col not in ["input_ids", "labels"]], + ) + + if isinstance(val_tokenized_dataset, IterableDataset): + val_sampler = None + else: + val_sampler = DistributedSampler( + val_tokenized_dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=seed, + ) + + if use_sequence_packing: + assert isinstance(val_tokenized_dataset, datasets.IterableDataset), ( + "THD token packing requires a streaming dataset." + ) + collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True) + val_tokenized_dataset = TokenPackingDataset( + val_tokenized_dataset, max_tokens_per_batch=micro_batch_size * max_seq_length + ) + val_batch_size = None # The TokenPackingDataset will handle the batching. 
+ else: + collator = DataCollatorForTokenClassification( + tokenizer=tokenizer, padding="max_length", max_length=max_seq_length + ) + val_batch_size = val_micro_batch_size + + val_dataloader = DataLoader( + val_tokenized_dataset, + sampler=val_sampler, + batch_size=val_batch_size, + collate_fn=collator, + num_workers=num_workers, + pin_memory=True, + ) + else: + val_dataloader = None + + return train_dataloader, val_dataloader, train_tokenized_dataset if train_sampler is None else train_sampler diff --git a/bionemo-recipes/recipes/esm2_peft_te/distributed_config.py b/bionemo-recipes/recipes/esm2_peft_te/distributed_config.py new file mode 120000 index 0000000000..f1ecc736cc --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/distributed_config.py @@ -0,0 +1 @@ +../esm2_native_te/distributed_config.py \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index 00fdf23128..1fd14ee767 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -679,3 +679,81 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class NVConvNetHead(nn.Module): + """Convolution based head for token classification.""" + + def __init__(self, config: NVEsmConfig): + """Initialize the NVConvNetHead.""" + super().__init__() + self.conv_head = torch.nn.Sequential( + torch.nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=7, padding=3), + torch.nn.ReLU(), + torch.nn.Dropout(config.hidden_dropout_prob), + torch.nn.Conv1d(config.hidden_size // 2, config.num_labels, kernel_size=7, padding=3), + ) + + def forward(self, features, **kwargs): + """Forward pass for the convolutional token classification head.""" + return self.conv_head(features).transpose(1, 2) + + +class NVEsmForConvTokenClassification(NVEsmPreTrainedModel): + """Adds a convolutional classification head to the model.""" + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.classifier = NVConvNetHead(config) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if outputs[0].dim() == 3: + sequence_output = outputs[0] + else: + sequence_output = outputs[0].unsqueeze(0) + + sequence_output = sequence_output.transpose(1, 2) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.reshape(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/README.md b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/README.md new file mode 100644 index 0000000000..5acd90933f --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/README.md @@ -0,0 +1,200 @@ +--- +base_model: nvidia/esm2_t6_8M_UR50D +library_name: peft +tags: + - base_model:adapter:nvidia/esm2_t6_8M_UR50D + - lora + - transformers +--- + +# Model Card for Model ID + + + +## Model Details + +### Model Description + + + +- **Developed by:** [More Information Needed] +- **Funded by \[optional\]:** [More Information Needed] +- **Shared by \[optional\]:** [More Information Needed] +- **Model type:** [More Information Needed] +- **Language(s) (NLP):** [More Information Needed] +- **License:** [More Information Needed] +- **Finetuned from model \[optional\]:** [More Information Needed] + +### Model Sources [optional] + + + +- **Repository:** [More Information Needed] +- **Paper \[optional\]:** [More Information Needed] +- **Demo \[optional\]:** [More Information Needed] + +## Uses + + + +### Direct Use + + + +[More Information Needed] + +### Downstream Use [optional] + + + +[More Information Needed] + +### Out-of-Scope Use + + + +[More Information Needed] + +## Bias, Risks, and Limitations + + + +[More Information Needed] + +### Recommendations + + + +Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations. + +## How to Get Started with the Model + +Use the code below to get started with the model. + +[More Information Needed] + +## Training Details + +### Training Data + + + +[More Information Needed] + +### Training Procedure + + + +#### Preprocessing [optional] + +[More Information Needed] + +#### Training Hyperparameters + +- **Training regime:** [More Information Needed] + +#### Speeds, Sizes, Times [optional] + + + +[More Information Needed] + +## Evaluation + + + +### Testing Data, Factors & Metrics + +#### Testing Data + + + +[More Information Needed] + +#### Factors + + + +[More Information Needed] + +#### Metrics + + + +[More Information Needed] + +### Results + +[More Information Needed] + +#### Summary + +## Model Examination [optional] + + + +[More Information Needed] + +## Environmental Impact + + + +Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). 
+ +- **Hardware Type:** [More Information Needed] +- **Hours used:** [More Information Needed] +- **Cloud Provider:** [More Information Needed] +- **Compute Region:** [More Information Needed] +- **Carbon Emitted:** [More Information Needed] + +## Technical Specifications [optional] + +### Model Architecture and Objective + +[More Information Needed] + +### Compute Infrastructure + +[More Information Needed] + +#### Hardware + +[More Information Needed] + +#### Software + +[More Information Needed] + +## Citation [optional] + + + +**BibTeX:** + +[More Information Needed] + +**APA:** + +[More Information Needed] + +## Glossary [optional] + + + +[More Information Needed] + +## More Information [optional] + +[More Information Needed] + +## Model Card Authors [optional] + +[More Information Needed] + +## Model Card Contact + +[More Information Needed] + +### Framework versions + +- PEFT 0.18.1.dev0 diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/adapter_config.json b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/adapter_config.json new file mode 100644 index 0000000000..383fae6c25 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/adapter_config.json @@ -0,0 +1,44 @@ +{ + "alora_invocation_tokens": null, + "alpha_pattern": {}, + "arrow_config": null, + "auto_mapping": null, + "base_model_name_or_path": "nvidia/esm2_t6_8M_UR50D", + "bias": "none", + "corda_config": null, + "ensure_weight_tying": false, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 16, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": [ + "classifier", + "score" + ], + "peft_type": "LORA", + "peft_version": "0.18.1.dev0@UNKNOWN", + "qalora_group_size": 16, + "r": 8, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "layernorm_qkv" + ], + "target_parameters": null, + "task_type": "TOKEN_CLS", + "trainable_token_indices": null, + "use_bdlora": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/adapter_model.safetensors b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/adapter_model.safetensors new file mode 100644 index 0000000000..a223c8da50 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/adapter_model.safetensors differ diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/config.json b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/config.json new file mode 100644 index 0000000000..167cfcd8ed --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/config.json @@ -0,0 +1,60 @@ +{ + "architectures": [ + "NVEsmForMaskedLM" + ], + "attention_probs_dropout_prob": 0.0, + "attn_input_format": "bshd", + "attn_mask_type": "padding", + "auto_map": { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification" + }, + "classifier_dropout": null, + 
"dtype": "bfloat16", + "emb_layer_norm_before": false, + "encoder_activation": "gelu", + "esmfold_config": null, + "fuse_qkv_params": true, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 320, + "id2label": { + "0": "H", + "1": "E", + "2": "C" + }, + "initializer_range": 0.02, + "intermediate_size": 1280, + "is_folding_model": false, + "label2id": { + "B": 1, + "C": 2, + "E": 1, + "G": 0, + "H": 0, + "I": 0, + "L": 2, + "S": 2, + "T": 2, + "~": 2 + }, + "layer_norm_eps": 1e-05, + "mask_token_id": 32, + "max_position_embeddings": 1026, + "max_seq_length": null, + "micro_batch_size": null, + "model_type": "nv_esm", + "num_attention_heads": 20, + "num_hidden_layers": 6, + "pad_token_id": 1, + "padded_vocab_size": 64, + "position_embedding_type": "rotary", + "qkv_weight_interleaved": true, + "token_dropout": true, + "transformers_version": "4.57.5", + "use_cache": true, + "vocab_list": null, + "vocab_size": 33 +} diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/esm_nv.py new file mode 100644 index 0000000000..1fd14ee767 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/esm_nv.py @@ -0,0 +1,759 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from typing import ClassVar, Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs + + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
+AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention. This controls + whether the dimensions of the intermediate hidden states is 'batch first' + ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, + `b` batch size, `h` the number of heads, `d` head size. Note that these + formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatentations/splits and + also enables the argument `fuse_wgrad_accumulation`. + micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. 
+ self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] = () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. 
+ with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + for layer_module in self.layers: + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + def _init_weights(self, module): + """Initialize module weights. + + We only use this method for standard pytorch modules, TE modules handle their own weight initialization through + `init_method` parameters and the `reset_parameters` method. + """ + if module.__module__.startswith("transformer_engine.pytorch"): + # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will + # assume any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking + # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and + # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the + # weights are not in fp8. We still need to figure out why this raises an error if we're using + # `quantized_model_init`. + if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False): + module.reset_parameters() + return + + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Override state_dict to filter out TransformerEngine's _extra_state keys. 
+ + TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. + These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + """ + state_dict = super().state_dict(*args, **kwargs) + # Filter out _extra_state keys which are TransformerEngine-specific and not loadable + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. + """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. 
+ + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. + **kwargs: Additional arguments. + """ + # Keep the last layers of the network in higher precision to avoid numerical instability. + # Please see recipes/fp8_analysis/README.md for more details. + with transformer_engine.pytorch.fp8_autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmEmbeddings(nn.Module): + """Modified version of EsmEmbeddings to support THD inputs.""" + + def __init__(self, config): + """Initialize a NVEsmEmbeddings.""" + super().__init__() + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + dtype=config.dtype, + ) + + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + + if config.position_embedding_type != "rotary": + raise ValueError( + "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " + f"{config.position_embedding_type}" + ) + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEmbeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + if ( + kwargs.get("cu_seq_lens_q") is not None + and kwargs.get("cu_seq_lens_k") is not None + and kwargs.get("max_length_q") is not None + and kwargs.get("max_length_k") is not None + ): + using_thd = True + attention_mask = None + else: + using_thd = False + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). 
+ # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + + if not using_thd: + # BSHD token dropout correction + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() + mask_ratio_observed = n_masked_per_seq / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) + + else: + src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + # We need to find the number of masked tokens in each sequence in the padded batch. + is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class NVConvNetHead(nn.Module): + """Convolution based head for token classification.""" + + def __init__(self, config: NVEsmConfig): + """Initialize the NVConvNetHead.""" + super().__init__() + self.conv_head = torch.nn.Sequential( + torch.nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=7, padding=3), + torch.nn.ReLU(), + torch.nn.Dropout(config.hidden_dropout_prob), + torch.nn.Conv1d(config.hidden_size // 2, config.num_labels, kernel_size=7, padding=3), + ) + + def forward(self, features, **kwargs): + """Forward pass for the convolutional token classification head.""" + return self.conv_head(features).transpose(1, 2) + + +class NVEsmForConvTokenClassification(NVEsmPreTrainedModel): + """Adds a convolutional classification head to the model.""" + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.classifier = NVConvNetHead(config) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + if outputs[0].dim() == 3: + sequence_output = outputs[0] + else: + sequence_output = outputs[0].unsqueeze(0) + + sequence_output = sequence_output.transpose(1, 2) + + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.reshape(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity.yaml new file mode 100644 index 0000000000..4a92da0d9e --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity.yaml @@ -0,0 +1,55 @@ +defaults: + - defaults + - _self_ + +# Training config +model_tag: ./example_8m_checkpoint # E.g., nvidia/esm2_t6_8M_UR50D or facebook/esm2_t6_8M_UR50D +use_pretrained: false +num_train_steps: 250 + +validation_interval: 50 + +# We want this on in CI/CD to validate that the script runs successfully with torch.compile. +use_torch_compile: true + +use_sequence_packing: false + +dataset: + tokenizer_name: ${model_tag} + micro_batch_size: 8 + val_micro_batch_size: 128 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + ss3_classification: true + load_dataset_kwargs: + path: "parquet" + split: "train" + data_files: "data/peft_sanity_dataset.parquet" + +# WandB config +wandb_init_args: + name: "esm2_lora_example_8M_sanity" + project: "esm2_lora_sanity" + mode: "offline" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 100 + +checkpoint: + ckpt_dir: null + save_final_model: false + +logger: + frequency: 1 + +lora: + r: 8 + alpha: 16 + target_modules: + - "layernorm_qkv" + # For 'facebook/esm2*' use: + # - "query" + # - "key" + # - "value" diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity_infer.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity_infer.yaml new file mode 100644 index 0000000000..adff8470f1 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L0_sanity_infer.yaml @@ -0,0 +1,14 @@ +defaults: +- defaults_infer +- _self_ + +model_tag: "nvidia/esm2_t6_8M_UR50D" +base_model_config_dir: "example_nv_esm2_t6_8M_UR50D_peft_checkpoint" # pragma: allowlist secret + +output_file: preds.csv + +inference: + batch_size: 4 # tune based on GPU memory + max_seq_length: 1024 + stride: 16 + infer_overflowing_aas: true diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_fb_15B.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_fb_15B.yaml new file mode 100644 index 0000000000..1c0bbfbda2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_fb_15B.yaml @@ -0,0 +1,57 @@ +defaults: + - defaults + - _self_ + +# Training config +model_tag: facebook/esm2_t48_15B_UR50D # E.g., nvidia/esm2_t6_8M_UR50D or facebook/esm2_t6_8M_UR50D +use_pretrained: true +num_train_steps: 1000 + +validation_interval: 20 + +use_torch_compile: true + +use_sequence_packing: true + +dataset: + # The NVIDIA tokenizer is identical to the facebook/esm* and supports generating multiple samples from sequences + # longer than 1024. 
+ tokenizer_name: nvidia/esm2_t48_15B_UR50D + micro_batch_size: 8 + val_micro_batch_size: 128 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + ss3_classification: true + load_dataset_kwargs: + path: "parquet" + split: null + data_files: + train: "data/porter6_train_dataset_55k.parquet" + validation: "data/porter6_val_dataset_2024_692.parquet" + +# WandB config +wandb_init_args: + name: "esm2_t48_15B_UR50D_lora" + project: "esm2_lora" + mode: "online" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 50 + num_training_steps: 1_000 + +checkpoint: + ckpt_dir: "checkpoints/facebook_esm2_t48_15B_UR50D" # pragma: allowlist secret + save_final_model: false + +logger: + frequency: 1 + +lora: + r: 8 + alpha: 16 + target_modules: + - "query" + - "key" + - "value" diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_nv_15B.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_nv_15B.yaml new file mode 100644 index 0000000000..e0a413b3d2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/L1_nv_15B.yaml @@ -0,0 +1,57 @@ +defaults: + - defaults + - _self_ + +# Training config +model_tag: nvidia/esm2_t48_15B_UR50D # E.g., nvidia/esm2_t6_8M_UR50D or facebook/esm2_t6_8M_UR50D +use_pretrained: true +num_train_steps: 1000 + +validation_interval: 20 + +use_torch_compile: true + +use_sequence_packing: true + +dataset: + tokenizer_name: ${model_tag} + micro_batch_size: 8 + val_micro_batch_size: 128 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + ss3_classification: true + load_dataset_kwargs: + path: "parquet" + split: null + data_files: + train: "data/porter6_train_dataset_55k.parquet" + validation: "data/porter6_val_dataset_2024_692.parquet" + +# WandB config +wandb_init_args: + name: "esm2_t48_15B_UR50D_lora" + project: "esm2_lora" + mode: "online" + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 50 + num_training_steps: 1_000 + +checkpoint: + ckpt_dir: "checkpoints/nvidia_esm2_t48_15B_UR50D" # pragma: allowlist secret + save_final_model: true + +logger: + frequency: 1 + +lora: + r: 8 + alpha: 16 + target_modules: + - "layernorm_qkv" + # For 'facebook/esm2*' use: + # - "query" + # - "key" + # - "value" diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults.yaml new file mode 100644 index 0000000000..4dc314e59c --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults.yaml @@ -0,0 +1,55 @@ +# Training config +model_tag: ??? # E.g., nvidia/esm2_t6_8M_UR50D, facebook/esm2_t6_8M_UR50D, or a local path (e.g ./example_8m_checkpoint) +num_train_steps: ??? +validation_interval: 50 + +# Whether to wrap the model in torch.compile. Note, this is currently not supported with mfsdp (BIONEMO-2977). +# We leave this off by default since we don't see much of a performance improvement with TE layers. +use_torch_compile: false + +use_sequence_packing: false + +dataset: + tokenizer_name: ${model_tag} + micro_batch_size: ??? + val_micro_batch_size: 64 + num_workers: 1 + max_seq_length: 1024 + stride: 16 + seed: 42 + ss3_classification: true + load_dataset_kwargs: + path: "nvidia/esm2_uniref_pretraining_data" + split: "train" + streaming: True + +# WandB config +wandb_init_args: + name: ??? 
+ project: null + +# Optimizer config +adamw_kwargs: + lr: 4e-4 + fused: true + betas: [0.9, 0.98] + eps: 1e-8 + weight_decay: 0.01 + +# Learning rate scheduler config +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +# Checkpoint config +checkpoint: + ckpt_dir: ??? + save_final_model: true + +logger: + frequency: 100 + +lora: + r: 8 + alpha: 16 + target_modules: ??? diff --git a/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults_infer.yaml b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults_infer.yaml new file mode 100644 index 0000000000..fb036f36e7 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/hydra_config/defaults_infer.yaml @@ -0,0 +1,12 @@ +model_tag: ??? +base_model_config_dir: ??? +peft_model_config_dir: ${base_model_config_dir} + +input_file: data/input_infer.fasta +output_file: null + +inference: + batch_size: 4 # tune based on GPU memory + max_seq_length: 1024 + stride: 16 + infer_overflowing_aas: true diff --git a/bionemo-recipes/recipes/esm2_peft_te/infer.py b/bionemo-recipes/recipes/esm2_peft_te/infer.py new file mode 100644 index 0000000000..1732bd87c6 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/infer.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
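The _batched_inference helper that follows leans on the tokenizer's overflow handling: sequences longer than max_seq_length are split into overlapping windows, and overflow_to_sample_mapping records which input record each window came from. A minimal sketch of that behaviour, assuming the nvidia/esm2_t48_15B_UR50D tokenizer loaded by this script ships a fast tokenizer (required for the stride/overflow options) and can be fetched from the Hugging Face Hub:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("nvidia/esm2_t48_15B_UR50D")

    # Toy sequence, long enough to overflow a small window size.
    long_seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ" * 8

    enc = tokenizer(
        [long_seq],
        max_length=64,                   # plays the role of inference.max_seq_length
        truncation=True,
        stride=16,                       # overlap carried between consecutive windows
        return_overflowing_tokens=True,
        padding=True,
        return_tensors="pt",
    )

    # One row of input_ids per window; overflow_to_sample_mapping[k] gives the index of the
    # original sequence that window k was cut from, which is exactly what _batched_inference
    # uses to stitch predictions back onto the input records.
    print(enc["input_ids"].shape)
    print(enc["overflow_to_sample_mapping"])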
+
+from pathlib import Path
+
+import hydra
+import torch
+from omegaconf import DictConfig
+from peft import PeftModel
+from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
+
+from utils import format_output_rows, load_input, write_output
+
+
+def _batched_inference(
+    model,
+    tokenizer,
+    records,
+    batch_size: int,
+    max_seq_length: int,
+    stride: int,
+    infer_overflowing_aas: bool,
+    device: str = "cuda",
+) -> tuple[list[str], list[int]]:
+    id2label = model.config.id2label
+
+    predictions = []
+    sequences_to_sample_mapping = []
+
+    for i in range(0, len(records), batch_size):
+        batch = records[i : i + batch_size]
+        sequences = [r["sequence"] for r in batch]
+
+        inputs = tokenizer(
+            sequences,
+            max_length=max_seq_length,
+            truncation=True,
+            stride=stride,
+            return_overflowing_tokens=infer_overflowing_aas,
+            return_tensors="pt",
+            padding=True,
+        )
+
+        overflow_map = inputs.pop("overflow_to_sample_mapping")
+        num_samples = len(inputs["input_ids"])
+
+        # Inner batching over the tokenizer outputs (one record may expand into several windows).
+        for j in range(0, num_samples, batch_size):
+            sub_inputs = {k: v[j : j + batch_size].to(device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = model(**sub_inputs)
+
+            preds = outputs.logits.argmax(dim=-1)
+
+            for k, (pred, input_ids) in enumerate(zip(preds, sub_inputs["input_ids"])):
+                length = (input_ids != tokenizer.pad_token_id).sum().item()
+                labels = "".join(id2label[i.item()] for i in pred[:length])
+
+                predictions.append(labels)
+
+                # Map back to the original record index.
+                original_idx = i + overflow_map[j + k].item()
+                sequences_to_sample_mapping.append(original_idx)
+
+    return predictions, sequences_to_sample_mapping
+
+
+@hydra.main(config_path="hydra_config", config_name="L0_sanity_infer", version_base="1.2")
+def main(args: DictConfig):
+    """Infer using a PEFT ESM-2 model.
+
+    This script can be run once ESM-2 has been PEFT fine-tuned and the adapters have
+    been checkpointed. For reference, an example has been provided in the './checkpoints' directory.
+    """
+    # Ideally we would like to load the PEFT model directly by doing:
+    # >>> model = AutoPeftModelForTokenClassification.from_pretrained("", trust_remote_code=True)
+    #
+    # However, the from_pretrained() function has a positional argument named 'config', which prevents us from passing
+    # a different model config to the base_model. Thus, we first build the base model and then load the PEFT adapters.
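Before the two-step load below, it can help to see what the adapter directory actually contains: only the LoRA weights plus an adapter_config.json describing how they attach to the base model. A short, illustrative check, using the directory name from L0_sanity_infer.yaml:

    from peft import PeftConfig

    peft_config = PeftConfig.from_pretrained("example_nv_esm2_t6_8M_UR50D_peft_checkpoint")
    print(peft_config.peft_type)                  # e.g. LORA
    print(peft_config.base_model_name_or_path)    # the checkpoint the adapters were trained against
    print(getattr(peft_config, "target_modules", None))
    print(getattr(peft_config, "r", None), getattr(peft_config, "lora_alpha", None))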
+ + # Load the custom config + config = AutoConfig.from_pretrained(args.base_model_config_dir, trust_remote_code=True) + + # Load base model with the custom config + base_model = AutoModelForTokenClassification.from_pretrained( + args.model_tag, # original model tag + config=config, + trust_remote_code=True, + ) + + # Load PEFT adapters on top + peft_model = PeftModel.from_pretrained(base_model, args.peft_model_config_dir) + peft_model = peft_model.to("cuda").eval() + + tokenizer = AutoTokenizer.from_pretrained("nvidia/esm2_t48_15B_UR50D") + + records = load_input(Path(args.input_file)) + + predictions, sequences_to_sample_mapping = _batched_inference( + peft_model, + tokenizer, + records, + **args.inference, + ) + + if args.output_file: + write_output(records, predictions, sequences_to_sample_mapping, Path(args.output_file)) + + header, rows = format_output_rows(records, predictions, sequences_to_sample_mapping) + + print("---------------") + print("\t".join(header)) + for row in rows: + print("\t".join(row)) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_peft_te/perf_logger.py b/bionemo-recipes/recipes/esm2_peft_te/perf_logger.py new file mode 100644 index 0000000000..7828a4938a --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/perf_logger.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import time + +import torch +import torchmetrics +import torchmetrics.text +import wandb +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm +from transformers.modeling_outputs import MaskedLMOutput + +from distributed_config import DistributedConfig + + +logger = logging.getLogger(__name__) + + +class PerfLogger: + """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training. + + Args: + dist_config: The distributed configuration. + args: The arguments. + + Attributes: + min_loss: The minimum loss seen so far. 
+ """ + + def __init__(self, dist_config: DistributedConfig, args: DictConfig): + """Initialize the logger.""" + self._dist_config = dist_config + self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) + + self.min_loss = float("inf") + + self.logging_frequency = args.logger.frequency + # Track whether to collect memory stats (disabled by default for max performance) + + metrics_dict = { + "train/loss": torchmetrics.MeanMetric(), + "train/grad_norm": torchmetrics.MeanMetric(), + "train/learning_rate": torchmetrics.MeanMetric(), + "train/num_tokens_per_gpu": torchmetrics.MeanMetric(), + "train/total_num_tokens": torchmetrics.SumMetric(), + "train/num_unpadded_tokens": torchmetrics.MeanMetric(), + "train/step_time": torchmetrics.MeanMetric(), + "train/tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/unpadded_tokens_per_second_per_gpu": torchmetrics.MeanMetric(), + "train/total_unpadded_tokens_per_batch": torchmetrics.SumMetric(), + "train/perplexity": torchmetrics.text.Perplexity(ignore_index=-100), + "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(), + "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(), + "val/loss": torchmetrics.MeanMetric(), + "val/accuracy": torchmetrics.MeanMetric(), + } + + self.metrics = torchmetrics.MetricCollection(metrics_dict) + # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. + self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.previous_step_time = time.perf_counter() + self.train_start_time = None + self.train_end_time = None + + if self._dist_config.is_main_process(): + # Log the entire args object to wandb for experiment tracking and reproducibility. + wandb.init(**args.wandb_init_args, config=self._run_config) + self._progress_bar = tqdm(total=args.num_train_steps, file=sys.stderr, desc="Training") + + def log_train_end_time(self): + """Log when the train loop for a batch ends.""" + self.train_end_time = time.perf_counter() + return + + def log_train_start_time(self): + """Log when the train loop for a batch starts.""" + self.train_start_time = time.perf_counter() + return + + def log_step( + self, + step: int, + batch: dict[str, torch.Tensor], + outputs: MaskedLMOutput, + grad_norm: float, + lr: float, + val_loss: float | None = None, + val_acc: float | None = None, + ): + """Log a step to the logger and wandb. + + Args: + step: The step number. + batch: The batch of data for the step. + outputs: The outputs of the step. + grad_norm: The gradient norm of the step. + lr: The learning rate of the step. + val_loss: The validation loss of the step (if calculated) + val_acc: The validation accuracy of the step (if calculated) + """ + num_tokens = batch["input_ids"].numel() + # 1 is the padding token for ESM-2. 
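For reference, the padded versus unpadded token accounting used just below can be reproduced on a toy batch (token id 1 is the ESM-2 pad token, as the comment above notes; the batch values here are hypothetical):

    import torch

    input_ids = torch.tensor([[0, 5, 6, 7, 2, 1, 1, 1]])    # toy batch; 1 = <pad>
    num_tokens = input_ids.numel()                           # 8, padding included
    num_unpadded_tokens = (input_ids != 1).sum().item()      # 5, padding excluded

    # Throughput metrics divide these counts by the measured step time, so raw and
    # effective tokens/sec/GPU can be compared directly.
    step_time = 0.05  # seconds, hypothetical
    print(num_tokens / step_time, num_unpadded_tokens / step_time)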
+ num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() + + self.min_loss = min(self.min_loss, outputs.loss.item()) + step_time = self.train_end_time - self.train_start_time + + self.metrics["train/loss"].update(outputs.loss) + self.metrics["train/num_tokens_per_gpu"].update(num_tokens) + self.metrics["train/total_num_tokens"].update(num_tokens) + self.metrics["train/num_unpadded_tokens"].update(num_unpadded_tokens) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency) + + if val_loss is not None: + self.metrics["val/loss"].update(val_loss) + self.metrics["val/accuracy"].update(val_acc) + + # Handle sequence packing for torchmetrics calculation. + if outputs.logits.dim() < 3: + outputs.logits = outputs.logits.unsqueeze(0) + + self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) + + if step % self.logging_frequency == 0 and step > 0: + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) + + if self._dist_config.is_main_process(): + train_metrics = {k: v for k, v in metrics.items() if k.startswith("train/")} + val_metrics = {k: v for k, v in metrics.items() if k.startswith("val/")} + + wandb.log(train_metrics, step=step) + + if val_loss is not None and len(val_metrics) > 0: + wandb.log(val_metrics, step=step) + + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": outputs.loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + + def finish(self): + """Finish the logger and close the progress bar.""" + if not self._dist_config.is_main_process(): + return + + wandb.finish() + self._progress_bar.close() diff --git a/bionemo-recipes/recipes/esm2_peft_te/requirements.txt b/bionemo-recipes/recipes/esm2_peft_te/requirements.txt index 8810ca2003..2fcdb0e4e6 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/requirements.txt +++ b/bionemo-recipes/recipes/esm2_peft_te/requirements.txt @@ -1,7 +1,9 @@ datasets -peft +hydra-core +peft @ git+https://github.com/balvisio/peft.git@dev/ba/support-te-lora torch torchao!=0.14.0 +torchmetrics tqdm transformer_engine[pytorch] transformers diff --git a/bionemo-recipes/recipes/esm2_peft_te/scheduler.py b/bionemo-recipes/recipes/esm2_peft_te/scheduler.py new file mode 120000 index 0000000000..02e1991a43 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/scheduler.py @@ -0,0 +1 @@ +../esm2_native_te/scheduler.py \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py b/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py index e9d2af6ae7..c7270c5721 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/conftest.py @@ -16,6 +16,14 @@ import sys from pathlib import Path +import pytest + 
sys.path.append(Path(__file__).parent.parent.as_posix()) sys.path.append(Path(__file__).parent.as_posix()) + + +@pytest.fixture +def recipe_path() -> Path: + """Return the root directory of the recipe.""" + return Path(__file__).parent.parent diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.csv b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.csv new file mode 100644 index 0000000000..016b6ebfea --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.csv @@ -0,0 +1,2 @@ +Sequence,Secondary_structure +LAGVSERTIDPKQNFYMHWC,HIGEBST~HIGEBST~HIGE diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.parquet b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.parquet new file mode 100644 index 0000000000..51661a8b18 Binary files /dev/null and b/bionemo-recipes/recipes/esm2_peft_te/tests/peft_test_dataset.parquet differ diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py b/bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py new file mode 100644 index 0000000000..f46d4c008a --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import sys +from pathlib import Path + +import pytest +import torch + + +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA") + + +def run_infer_cmd(cmd, recipe_path): + """Run an inference command and check for errors.""" + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + cwd=str(recipe_path), + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}") + + +@requires_cuda +def test_infer_runs(recipe_path): + """Test that the infer script runs with default config.""" + output_path = Path(recipe_path) / "preds.csv" + + run_infer_cmd( + [ + sys.executable, + "infer.py", + "--config-name", + "L0_sanity_infer", + ], + recipe_path, + ) + + if output_path.exists(): + output_path.unlink() diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py index eb782e2552..eb0d4960b7 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora.py @@ -14,21 +14,62 @@ # limitations under the License. 
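The tiny peft_test_dataset.csv added above stores its labels as a secondary-structure string (HIGEBST~...), and the ss3_classification flag in the hydra configs selects whether those characters are collapsed to 3 classes or kept as 8, using the SS3/SS8 maps that utils.py defines later in this diff. A small worked example of that mapping ('~' is treated as coil/unstructured):

    # Copied from the SS3/SS8 maps in utils.py for illustration.
    SS3_LABEL2ID = {"H": 0, "I": 0, "G": 0, "E": 1, "B": 1, "S": 2, "T": 2, "~": 2, "C": 2, "L": 2}
    SS8_LABEL2ID = {"H": 0, "I": 1, "G": 2, "E": 3, "B": 4, "S": 5, "T": 6, "~": 7, "C": 7, "L": 7}

    structure = "HIGEBST~HIGEBST~HIGE"  # the Secondary_structure column of the test row
    print([SS3_LABEL2ID[c] for c in structure])  # 3-class labels (helix / strand / coil)
    print([SS8_LABEL2ID[c] for c in structure])  # 8-class labels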
import torch -from transformers import BatchEncoding +from hydra import compose, initialize_config_dir -from train_lora import create_dataloader, train_lora +from train_lora_ddp import main as main_ddp -def test_create_dataloader(): - dataloader = create_dataloader(use_sanity_dataset=True) - for batch in dataloader: - assert isinstance(batch, BatchEncoding) - assert isinstance(batch["input_ids"], torch.Tensor) - assert isinstance(batch["labels"], torch.Tensor) - break +def test_sanity_convergence_ddp(tmp_path, recipe_path): + """Test that the main function can be invoked wrapping the model in DDP.""" + # Run the training script with Hydra configuration overrides + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + ], + ) -def test_train_lora(): - dataloader = create_dataloader(use_sanity_dataset=True) - loss = train_lora(dataloader) - assert loss < 1.5 + final_loss = main_ddp(sanity_config) + assert final_loss < 3.0, f"Final loss {final_loss} is too high" + + +def test_sanity_convergence_ddp_non_streaming_dataset(tmp_path, recipe_path): + """Test that the training script works with a non-streaming dataset.""" + + # Run the training script with Hydra configuration overrides + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "dataset.load_dataset_kwargs.streaming=False", + ], + ) + + final_loss = main_ddp(sanity_config) + assert final_loss < 3.0, f"Final loss {final_loss} is too high" + + +def test_sanity_ddp_thd(tmp_path, monkeypatch, recipe_path): + if torch.cuda.get_device_capability() == (12, 0): + # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default, + # but it's missing this THD implementation. + monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + + # For DDP, we only check that the script can run successfully with THD, not convergence. + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "use_sequence_packing=true", + "num_train_steps=4", + ], + ) + + main_ddp(sanity_config) diff --git a/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora_two_gpus.py b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora_two_gpus.py new file mode 100644 index 0000000000..c822eeefa5 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/tests/test_train_lora_two_gpus.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# These tests don't check convergence, they just check that the training script runs successfully on multiple GPUs. + +import subprocess + +import pytest +import torch + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +def run_train_cmd(cmd, recipe_path): + """Run a training command and check for errors.""" + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + cwd=str(recipe_path), + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}") + + +@requires_multi_gpu +def test_multi_gpu_train_te_ddp(tmp_path, recipe_path): + # Run 'accelerate launch train.py' as a subprocess + run_train_cmd( + [ + "torchrun", + "--nproc_per_node", + "2", + "--standalone", + "train_lora_ddp.py", + "--config-name", + "L0_sanity", + "num_train_steps=4", + ], + recipe_path, + ) diff --git a/bionemo-recipes/recipes/esm2_peft_te/train_lora.py b/bionemo-recipes/recipes/esm2_peft_te/train_lora.py deleted file mode 100644 index 3e3e2b03db..0000000000 --- a/bionemo-recipes/recipes/esm2_peft_te/train_lora.py +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Demonstration of LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. - -Still needs: - - [ ] Hydra config management. - - [ ] THD / sequence packing. - - [ ] DDP / Multi-node training. - - [ ] FP8 tests. - - [ ] Perf / wandb logging. -""" - -import peft -import torch -from datasets import load_dataset -from tqdm import tqdm -from transformers import ( - AutoConfig, - AutoModelForTokenClassification, - AutoTokenizer, - DataCollatorForTokenClassification, -) - - -def create_dataloader(use_sanity_dataset: bool = False) -> torch.utils.data.DataLoader: - """Create a dataloader for the secondary structure dataset.""" - if use_sanity_dataset: - # 5000 row sanity dataset. - ss_dataset = load_dataset("parquet", data_files="peft_sanity_dataset.parquet", split="train", streaming=True) - else: - # Full-scale source dataset. 
- ss_dataset = load_dataset("lamm-mit/protein_secondary_structure_from_PDB", split="train", streaming=True) - - ss_token_map = {"H": 0, "E": 1, "I": 2, "S": 3, "T": 4, "C": 5, "B": 6, "G": 7, "~": -100} - - tokenizer = AutoTokenizer.from_pretrained("example_8m_checkpoint") - tokenize_args = { - "max_length": 128, - "truncation": True, - "stride": 16, # TODO: figure this out later - "return_overflowing_tokens": True, - "return_offsets_mapping": True, - } - - def tokenize(example): - """Tokenize both the input protein sequence and the secondary structure labels.""" - result = tokenizer(example["Sequence"], **tokenize_args) - - # While we can use the rust-based tokenizer for the protein sequence, we manually encode the secondary structure - # labels. Our goal is to return a list of integer labels with the same shape as the input_ids. - labels = [] - for batch_idx in range(len(result["input_ids"])): - sequence_labels = [] - - # This array maps the possibly-chunked result["input_ids"] to the original sequence. Because of - # `return_overflowing_tokens`, each input sequence may be split into multiple input rows. - offsets = result["offset_mapping"][batch_idx] - - # This gets the original secondary structure sequence for the current chunk. - ss_sequence = example["Secondary_structure"][result["overflow_to_sample_mapping"][batch_idx]] - - for offset_start, offset_end in offsets: - if offset_start == offset_end: - sequence_labels.append(-100) # Start and end of the sequence tokens can be ignored. - elif offset_end == offset_start + 1: # All tokens are single-character. - ss_char = ss_sequence[offset_start] - ss_label_value = ss_token_map[ss_char] # Encode the secondary structure character - sequence_labels.append(ss_label_value) - else: - raise ValueError(f"Invalid offset: {offset_start} {offset_end}") - - labels.append(sequence_labels) - - return {"input_ids": result["input_ids"], "labels": labels} - - tokenized_dataset = ss_dataset.map( - tokenize, - batched=True, - remove_columns=[col for col in ss_dataset.features if col not in ["input_ids", "labels"]], - ) - - collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding="max_length", max_length=1024) - dataloader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=16, collate_fn=collator) - - return dataloader - - -def train_lora(dataloader: torch.utils.data.DataLoader) -> float: - """Training loop for LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. - - Args: - dataloader: DataLoader for the secondary structure dataset. - - Returns: - Final loss value. - """ - # For testing, we don't want to depend on loading pre-trained weights. - config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True) - config.num_labels = 8 - model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True) - - # Alternatively, we'd want to load an actual pre-trained checkpoint. - # model = AutoModelForTokenClassification.from_pretrained( - # "example_8m_checkpoint", num_labels=8, trust_remote_code=True, dtype="bfloat16" - # ) - - peft_config = peft.LoraConfig( - task_type=peft.TaskType.TOKEN_CLS, - inference_mode=False, - r=16, - lora_alpha=16, - # target_modules=["layernorm_qkv"], # TODO: figure out if this could work? - target_parameters=["layernorm_qkv.weight"], - bias="none", - ) - - peft_model = peft.get_peft_model(model, peft_config) - peft_model.to("cuda", dtype=torch.bfloat16) - - # Create optimizer. 
- optimizer = torch.optim.AdamW(peft_model.parameters(), lr=1e-3, weight_decay=0.01) - - # Training loop. - step = 0 - with tqdm(dataloader, desc="Training") as progress_bar: - for batch in progress_bar: - batch = {k: v.to("cuda") for k, v in batch.items()} # noqa PLW2901 - output = peft_model(**batch) - loss = output.loss - loss.backward() - progress_bar.set_postfix({"loss": loss.item()}) - - # Step optimizer. - optimizer.step() - optimizer.zero_grad() - - step += 1 - if step >= 100: - return loss.item() - - -if __name__ == "__main__": - dataloader = create_dataloader(use_sanity_dataset=True) - train_lora(dataloader) diff --git a/bionemo-recipes/recipes/esm2_peft_te/train_lora_convnet.py b/bionemo-recipes/recipes/esm2_peft_te/train_lora_convnet.py new file mode 100644 index 0000000000..7168ad00de --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/train_lora_convnet.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Demonstration of LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT.""" + +import logging +from pathlib import Path + +import hydra +import peft +import torch +from modeling_esm_te import NVEsmForConvTokenClassification +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from transformers import ( + AutoConfig, +) + +from checkpoint import save_final_model_ddp +from dataset import create_dataloader +from distributed_config import DistributedConfig +from perf_logger import PerfLogger +from scheduler import get_linear_schedule_with_warmup +from utils import ( + SS3_ID2LABEL, + SS3_LABEL2ID, + SS8_ID2LABEL, + SS8_LABEL2ID, + compute_accuracy, + get_parameter_names_with_lora, +) + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") +def main(args: DictConfig) -> float: + """Training loop for LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. + + Args: + args: Configuration arguments from hydra. + + Returns: + Final loss value. + """ + # Initialize the distributed configuration, including creating the distributed process group. + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + train_dataloader, val_dataloader, train_dataset_or_sampler = create_dataloader( + distributed_config=dist_config, + use_sequence_packing=args.use_sequence_packing, + **OmegaConf.to_container(args.dataset, resolve=True), + ) + + perform_validation = val_dataloader is not None + + # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2 + # and MFSDP. 
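A standalone sketch of the device-mesh plus DDP wiring set up at this point, with a placeholder model and rank handling pulled from the torchrun environment variables; it is illustrative only and mirrors, rather than reproduces, the recipe's DistributedConfig conventions:

    import os

    import torch
    from torch.distributed.device_mesh import init_device_mesh


    def wrap_in_ddp(model: torch.nn.Module) -> torch.nn.parallel.DistributedDataParallel:
        # Launched under torchrun, so LOCAL_RANK / WORLD_SIZE are set by the launcher.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))

        device = torch.device(f"cuda:{local_rank}")
        torch.distributed.init_process_group(backend="nccl", device_id=device)
        torch.cuda.set_device(local_rank)

        # A 1-D mesh over all ranks, named "ddp" so the same call sites can later grow into
        # hybrid meshes (FSDP2 / megatron-fsdp) without being rewritten.
        mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("ddp",))

        model = model.to(device)
        return torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            device_mesh=mesh["ddp"],
        )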
+ device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("ddp",)) + + config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True) + if args.use_sequence_packing: + config.attn_input_format = "thd" + + if args.dataset["ss3_classification"]: + config.id2label = SS3_ID2LABEL + config.label2id = SS3_LABEL2ID + else: + config.id2label = SS8_ID2LABEL + config.label2id = SS8_LABEL2ID + + if args.use_pretrained: + model = NVEsmForConvTokenClassification.from_pretrained( + args.model_tag, config=config, trust_remote_code=True, dtype="bfloat16" + ) + else: + model = NVEsmForConvTokenClassification.from_config(config, trust_remote_code=True) + + peft_config = peft.LoraConfig( + task_type=peft.TaskType.TOKEN_CLS, + inference_mode=False, + r=args.lora.r, + lora_alpha=args.lora.alpha, + target_modules=list(args.lora.target_modules), + bias="none", + ) + + peft_model = peft.get_peft_model(model, peft_config) + peft_model.to(device=device, dtype=torch.bfloat16) + + print("----- PEFT Model --------") + peft_model.print_trainable_parameters() + + # Create optimizer. + decay_parameters = get_parameter_names_with_lora(peft_model) + optimizer_grouped_parameters = [ + { + "params": [p for n, p in peft_model.named_parameters() if (n in decay_parameters and p.requires_grad)], + "weight_decay": args.adamw_kwargs.weight_decay, + }, + { + "params": [p for n, p in peft_model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **args.adamw_kwargs) + + scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + peft_model = torch.nn.parallel.DistributedDataParallel( + peft_model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + device_mesh=device_mesh["ddp"], + find_unused_parameters=True, + ) + + if args.use_torch_compile: + # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. + peft_model = torch.compile(peft_model) + + perf_logger = PerfLogger(dist_config, args) + + # Training loop. + step = 0 + epoch = 0 + while step < args.num_train_steps: + for batch in train_dataloader: + perf_logger.log_train_start_time() + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901 + + output = peft_model(**batch) + loss = output.loss + loss.backward() + + # Compute and clip gradient norms. + total_norm = torch.nn.utils.clip_grad_norm_(peft_model.parameters(), max_norm=1.0).item() + + # Step optimizer. 
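The decay/no-decay split assembled next keeps weight decay off biases, normalization weights, and the LoRA A/B matrices. A simplified, self-contained version of that grouping (a regex sketch, not the exact get_parameter_names_with_lora helper defined in utils.py):

    import re

    import torch


    def split_decay_groups(model: torch.nn.Module, weight_decay: float = 0.01) -> list[dict]:
        # Parameters matching any of these name patterns get zero weight decay.
        no_decay_patterns = [r"bias", r"layernorm", r"rmsnorm", r"\.norm", r"\.lora_[AB]\."]

        def is_no_decay(name: str) -> bool:
            return any(re.search(p, name, flags=re.IGNORECASE) for p in no_decay_patterns)

        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            (no_decay if is_no_decay(name) else decay).append(param)

        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]


    # Usage: optimizer = torch.optim.AdamW(split_decay_groups(peft_model), lr=4e-4)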
+ optimizer.step() + scheduler.step() + optimizer.zero_grad() + + step += 1 + + perf_logger.log_train_end_time() + # Validation + avg_val_loss = None + avg_val_acc = None + if perform_validation and step % args.validation_interval == 0: + peft_model.eval() + val_loss_total = 0.0 + val_correct_total = 0 + val_tokens_total = 0 + val_steps = 0 + with torch.no_grad(): + for val_batch in val_dataloader: + val_batch = { # noqa: PLW2901 + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in val_batch.items() + } + val_output = peft_model(**val_batch) + + # Loss + val_loss_total += val_output.loss.item() + + # Accuracy + logits = val_output.logits + labels = val_batch["labels"] + correct, total = compute_accuracy(logits, labels) + val_correct_total += correct + val_tokens_total += total + + val_steps += 1 + + avg_val_loss = val_loss_total / val_steps + avg_val_acc = val_correct_total / val_tokens_total if val_tokens_total > 0 else 0.0 + print(f"\nStep: {step}: Validation Loss = {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}\n") + peft_model.train() + + perf_logger.log_step( + step=step, + batch=batch, + outputs=output, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + val_loss=avg_val_loss, + val_acc=avg_val_acc, + ) + + if step >= args.num_train_steps: + break + + # Dataloader exhausted, incrementing epoch + epoch += 1 + train_dataset_or_sampler.set_epoch(epoch) + + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None + + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_ddp( + model=peft_model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + perf_logger.finish() + torch.distributed.destroy_process_group() + + return perf_logger.min_loss + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_peft_te/train_lora_ddp.py b/bionemo-recipes/recipes/esm2_peft_te/train_lora_ddp.py new file mode 100644 index 0000000000..62fe36ab05 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_peft_te/train_lora_ddp.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
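train_lora_ddp.py below loads the classifier through AutoModelForTokenClassification instead of the convolutional head used by train_lora_convnet.py, but both accept either the TE-accelerated nvidia/esm2* checkpoints or the original facebook/esm2* ones. The two families need different lora.target_modules, because the TE attention block fuses the QKV projection into a single layernorm_qkv module while the Hugging Face ESM implementation exposes separate query/key/value linears, as the comments in the hydra configs above note. A hedged sketch of picking the targets by inspecting module names (the helper is hypothetical, for illustration only):

    import peft
    import torch


    def lora_config_for(model: torch.nn.Module, r: int = 8, alpha: int = 16) -> peft.LoraConfig:
        # Look at the leaf module names to decide which attention layout the model uses.
        module_names = {name.rsplit(".", 1)[-1] for name, _ in model.named_modules()}

        if "layernorm_qkv" in module_names:    # TransformerEngine fused attention (nvidia/esm2*)
            target_modules = ["layernorm_qkv"]
        else:                                  # Hugging Face ESM attention (facebook/esm2*)
            target_modules = ["query", "key", "value"]

        return peft.LoraConfig(
            task_type=peft.TaskType.TOKEN_CLS,
            inference_mode=False,
            r=r,
            lora_alpha=alpha,
            target_modules=target_modules,
            bias="none",
        )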
+ + +"""Demonstration of LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT using DDP.""" + +import logging +from pathlib import Path + +import hydra +import peft +import torch +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from transformers import ( + AutoConfig, + AutoModelForTokenClassification, +) + +from checkpoint import save_final_model_ddp +from dataset import create_dataloader +from distributed_config import DistributedConfig +from perf_logger import PerfLogger +from scheduler import get_linear_schedule_with_warmup +from utils import ( + SS3_ID2LABEL, + SS3_LABEL2ID, + SS8_ID2LABEL, + SS8_LABEL2ID, + compute_accuracy, + get_parameter_names_with_lora, +) + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") +def main(args: DictConfig) -> float: + """Training loop for LoRA fine-tuning of ESM-2 with Transformer Engine and PEFT. + + Args: + args: Configuration arguments from hydra. + + Returns: + Final loss value. + """ + # Initialize the distributed configuration, including creating the distributed process group. + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + train_dataloader, val_dataloader, train_dataset_or_sampler = create_dataloader( + distributed_config=dist_config, + use_sequence_packing=args.use_sequence_packing, + **OmegaConf.to_container(args.dataset, resolve=True), + ) + + perform_validation = val_dataloader is not None + + # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2 + # and MFSDP. + device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("ddp",)) + + # For testing, we don't want to depend on loading pre-trained weights. + config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True) + if args.use_sequence_packing: + config.attn_input_format = "thd" + + if args.dataset["ss3_classification"]: + config.id2label = SS3_ID2LABEL + config.label2id = SS3_LABEL2ID + else: + config.id2label = SS8_ID2LABEL + config.label2id = SS8_LABEL2ID + + if args.use_pretrained: + model = AutoModelForTokenClassification.from_pretrained( + args.model_tag, config=config, trust_remote_code=True, dtype="bfloat16" + ) + else: + model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True) + + peft_config = peft.LoraConfig( + task_type=peft.TaskType.TOKEN_CLS, + inference_mode=False, + r=args.lora.r, + lora_alpha=args.lora.alpha, + target_modules=list(args.lora.target_modules), + bias="none", + ) + + peft_model = peft.get_peft_model(model, peft_config) + peft_model.to(device=device, dtype=torch.bfloat16) + + print("----- PEFT Model --------") + peft_model.print_trainable_parameters() + + # Create optimizer. 
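The learning-rate schedule built a few lines below comes from scheduler.py, which this recipe symlinks from esm2_native_te. As a rough sketch of the standard linear warmup-then-decay shape that get_linear_schedule_with_warmup implements (not necessarily the recipe's exact code):

    import torch


    def linear_warmup_then_decay(optimizer, num_warmup_steps: int, num_training_steps: int):
        def lr_lambda(step: int) -> float:
            if step < num_warmup_steps:
                return step / max(1, num_warmup_steps)
            remaining = num_training_steps - step
            return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


    # With the L1 configs above (num_warmup_steps=50, num_training_steps=1_000), the LR ramps
    # linearly to its peak over the first 50 steps and then decays linearly toward zero;
    # scheduler.step() is called once per optimizer step in the training loop below.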
+ decay_parameters = get_parameter_names_with_lora(peft_model) + optimizer_grouped_parameters = [ + { + "params": [p for n, p in peft_model.named_parameters() if (n in decay_parameters and p.requires_grad)], + "weight_decay": args.adamw_kwargs.weight_decay, + }, + { + "params": [p for n, p in peft_model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **args.adamw_kwargs) + + scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + peft_model = torch.nn.parallel.DistributedDataParallel( + peft_model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + device_mesh=device_mesh["ddp"], + find_unused_parameters=True, + ) + + if args.use_torch_compile: + # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. + peft_model = torch.compile(peft_model) + + perf_logger = PerfLogger(dist_config, args) + + # Training loop. + step = 0 + epoch = 0 + while step < args.num_train_steps: + for batch in train_dataloader: + perf_logger.log_train_start_time() + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901 + + output = peft_model(**batch) + loss = output.loss + loss.backward() + + # Compute and clip gradient norms. + total_norm = torch.nn.utils.clip_grad_norm_(peft_model.parameters(), max_norm=1.0).item() + + # Step optimizer. + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + step += 1 + + perf_logger.log_train_end_time() + # Validation + avg_val_loss = None + avg_val_acc = None + if perform_validation and step % args.validation_interval == 0: + peft_model.eval() + val_loss_total = 0.0 + val_correct_total = 0 + val_tokens_total = 0 + val_steps = 0 + with torch.no_grad(): + for val_batch in val_dataloader: + val_batch = { # noqa: PLW2901 + k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in val_batch.items() + } + val_output = peft_model(**val_batch) + + # Loss + val_loss_total += val_output.loss.item() + + # Accuracy + logits = val_output.logits + labels = val_batch["labels"] + correct, total = compute_accuracy(logits, labels) + val_correct_total += correct + val_tokens_total += total + + val_steps += 1 + + avg_val_loss = val_loss_total / val_steps + avg_val_acc = val_correct_total / val_tokens_total if val_tokens_total > 0 else 0.0 + print(f"\nStep: {step}: Validation Loss = {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}\n") + peft_model.train() + + perf_logger.log_step( + step=step, + batch=batch, + outputs=output, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + val_loss=avg_val_loss, + val_acc=avg_val_acc, + ) + + if step >= args.num_train_steps: + break + + # Dataloader exhausted, incrementing epoch + epoch += 1 + train_dataset_or_sampler.set_epoch(epoch) + + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None + + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_ddp( + model=peft_model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + perf_logger.finish() + torch.distributed.destroy_process_group() + + return perf_logger.min_loss + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/esm2_peft_te/utils.py b/bionemo-recipes/recipes/esm2_peft_te/utils.py new file mode 100644 index 0000000000..7296659100 --- /dev/null +++ 
b/bionemo-recipes/recipes/esm2_peft_te/utils.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +from collections import defaultdict +from pathlib import Path + +import torch +from transformers.trainer_pt_utils import get_parameter_names + + +SS3_ID2LABEL = {0: "H", 1: "E", 2: "C"} + +SS3_LABEL2ID = { + "H": 0, + "I": 0, + "G": 0, + "E": 1, + "B": 1, + "S": 2, + "T": 2, + "~": 2, + "C": 2, + "L": 2, +} # '~' denotes coil / unstructured + +SS8_ID2LABEL = {0: "H", 1: "I", 2: "G", 3: "E", 4: "B", 5: "S", 6: "T", 7: "C"} + +SS8_LABEL2ID = { + "H": 0, + "I": 1, + "G": 2, + "E": 3, + "B": 4, + "S": 5, + "T": 6, + "~": 7, + "C": 7, + "L": 7, +} # '~' denotes coil / unstructured + + +def compute_accuracy(preds, labels, ignore_index=-100) -> tuple[int, int]: + """Calculate the accuracy.""" + preds_labels = torch.argmax(preds, dim=-1) + mask = labels != ignore_index + correct = (preds_labels == labels) & mask + + return correct.sum().item(), mask.sum().item() + + +def get_parameter_names_with_lora(model): + """Get layers with non-zero weight decay. + + This function reuses the Transformers' library function + to list all the layers that should have weight decay. + """ + forbidden_name_patterns = [ + r"bias", + r"layernorm", + r"rmsnorm", + r"(?:^|\.)norm(?:$|\.)", + r"_norm(?:$|\.)", + r"\.lora_[AB]\.", + ] + + decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm], forbidden_name_patterns) + + return decay_parameters + + +def load_fasta(path: Path) -> list[dict]: + """Read FASTA file and return input sequences.""" + records = [] + seq, pdb_id = [], None + + with open(path) as f: + for raw_line in f: + line = raw_line.strip() + if line.startswith(">"): + if seq: + records.append({"pdb_id": pdb_id, "sequence": "".join(seq)}) + pdb_id = line[1:] or None + seq = [] + else: + seq.append(line) + + if seq: + records.append({"pdb_id": pdb_id, "sequence": "".join(seq)}) + + return records + + +def load_csv(path: Path) -> list[dict]: + """Read input CSV file for inference. + + It is assumed that the input CSV file contains: + - Optional column named 'pdb_id' of the sequence. + - Aminoacid sequence. + """ + with open(path) as f: + reader = csv.DictReader(f) + has_pdb_id = "pdb_id" in reader.fieldnames + + return [ + { + "pdb_id": row["pdb_id"] if has_pdb_id else None, + "sequence": row["sequence"], + } + for row in reader + ] + + +def load_input(path: Path) -> list[dict]: + """Read the input sequences from FASTA or CSV file.""" + suffix = path.suffix.lower() + + if suffix == ".csv": + return load_csv(path) + elif suffix in {".fa", ".fasta", ".faa"}: + return load_fasta(path) + else: + raise ValueError(f"Unsupported input format: {suffix}") + + +def format_output_rows(records, predictions, sequences_to_sample_mapping): + """Format the output into CSV-type lines. 
+ + Returns: + header: list[str] + rows: list[tuple[str, str]] + """ + has_pdb_id = any(r.get("pdb_id") for r in records) + header = ["pdb_id", "prediction"] if has_pdb_id else ["id", "prediction"] + + counts = defaultdict(int) + rows = [] + + for pred, orig_idx in zip(predictions, sequences_to_sample_mapping): + counts[orig_idx] += 1 + suffix = counts[orig_idx] + + base = records[orig_idx]["pdb_id"] if has_pdb_id else str(orig_idx) + + out_id = base if suffix == 1 else f"{base}_{suffix}" + rows.append((out_id, pred)) + + return header, rows + + +def write_output(records, predictions, sequences_to_sample_mapping: list[int], output_path: Path): + """Write the predictions to an output file.""" + header, rows = format_output_rows(records, predictions, sequences_to_sample_mapping) + + with open(output_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(header) + writer.writerows(rows) diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 5d391d5fb0..00f350d8b5 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -33,6 +33,7 @@ "bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py": [ "bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py", "bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py", + "bionemo-recipes/recipes/esm2_peft_te/example_nv_esm2_t6_8M_UR50D_peft_checkpoint/esm_nv.py", "bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py", ], "bionemo-recipes/models/esm2/src/esm/collator.py": [