3 changes: 3 additions & 0 deletions .github/workflows/unit-tests-recipes.yml
@@ -155,6 +155,9 @@ jobs:
        with:
          sparse-checkout: "${{ matrix.recipe.dir }}"
          sparse-checkout-cone-mode: false
      - name: Include symlink targets for esm2_peft_te
        if: ${{ matrix.recipe.dir == 'bionemo-recipes/recipes/esm2_peft_te' }}
        run: git -c safe.directory=/__w/bionemo-framework/bionemo-framework sparse-checkout add bionemo-recipes/recipes/esm2_native_te

      - name: Cache Hugging Face models
        uses: actions/cache@v4
2 changes: 1 addition & 1 deletion bionemo-recipes/models/esm2/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
"jinja2",
"megatron-fsdp",
"omegaconf",
"peft",
"peft @ git+https://github.com/balvisio/peft.git@dev/ba/support-te-lora",
"pytest",
"torch",
"torchao!=0.14.0",
78 changes: 78 additions & 0 deletions bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py
@@ -679,3 +679,81 @@ def forward(
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class NVConvNetHead(nn.Module):
    """Convolution based head for token classification."""

    def __init__(self, config: NVEsmConfig):
        """Initialize the NVConvNetHead."""
        super().__init__()
        self.conv_head = torch.nn.Sequential(
            torch.nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=7, padding=3),
            torch.nn.ReLU(),
            torch.nn.Dropout(config.hidden_dropout_prob),
            torch.nn.Conv1d(config.hidden_size // 2, config.num_labels, kernel_size=7, padding=3),
        )

    def forward(self, features, **kwargs):
        """Forward pass for the convolutional token classification head."""
        return self.conv_head(features).transpose(1, 2)


class NVEsmForConvTokenClassification(NVEsmPreTrainedModel):
    """Adds a convolutional classification head to the model."""

    def __init__(self, config):
"""Initialize NVEsmForTokenClassification."""
        super().__init__(config)
        self.num_labels = config.num_labels

        self.esm = NVEsmModel(config, add_pooling_layer=False)
        self.classifier = NVConvNetHead(config)

        self.init_weights()
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenClassifierOutput:
        """Forward pass for the token classification head.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        if outputs[0].dim() == 3:
            sequence_output = outputs[0]
        else:
            sequence_output = outputs[0].unsqueeze(0)

        sequence_output = sequence_output.transpose(1, 2)

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(logits.device)
            loss = loss_fct(logits.reshape(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
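Note (not part of this diff): a minimal usage sketch of the new NVEsmForConvTokenClassification class, assuming a local TE-converted ESM-2 checkpoint. The checkpoint path, label count, and input sequence below are illustrative placeholders, not values from this PR.

# Hypothetical usage sketch; the checkpoint path, num_labels, and inputs are placeholders.
import torch
from transformers import AutoTokenizer

from esm.modeling_esm_te import NVEsmForConvTokenClassification

checkpoint = "path/to/esm2-te-checkpoint"  # assumed local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# num_labels flows through the config into the final Conv1d of NVConvNetHead.
model = NVEsmForConvTokenClassification.from_pretrained(checkpoint, num_labels=3)

batch = tokenizer(["MKTAYIAKQRQISFVKSHFSRQ"], return_tensors="pt")
labels = torch.zeros_like(batch["input_ids"])  # dummy per-token labels in [0, num_labels)

outputs = model(**batch, labels=labels)
print(outputs.logits.shape)  # (batch_size, sequence_length, num_labels)
print(outputs.loss)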
9 changes: 5 additions & 4 deletions bionemo-recipes/models/esm2/tests/test_peft.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import peft
import pytest
import torch

from esm.modeling_esm_te import NVEsmForMaskedLM
@@ -58,7 +59,6 @@ def test_lora_model_forward_pass(te_model_checkpoint, input_data):
    assert outputs.loss is not None


@pytest.mark.xfail(reason="BIONEMO-3136: LoRA model initializes with warnings because of TE layers.")
def test_lora_model_raises_no_warnings(te_model_checkpoint):
    model = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)

@@ -71,13 +71,14 @@ def test_lora_model_raises_no_warnings(te_model_checkpoint):
        bias="none",
    )

    with pytest.warns(UserWarning) as record:
    with warnings.catch_warnings(record=True) as record:
        # Cause all warnings to be triggered (default behavior may ignore some)
        warnings.simplefilter("always")
        peft.get_peft_model(model, peft_config)

    assert len(record) == 0


@pytest.mark.xfail(reason="BIONEMO-3136: LoRA model initialization fails with target_modules because of TE layers.")
def test_lora_model_with_target_modules(te_model_checkpoint):
    model = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)

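For reference (not part of this diff), a standalone sketch of the no-warnings assertion pattern the updated test adopts: warnings.catch_warnings(record=True) records emitted warnings without requiring any to be raised, whereas pytest.warns(UserWarning) fails when nothing warns. The function under test below is a hypothetical stand-in for peft.get_peft_model.

import warnings


def quiet_operation() -> int:
    """Placeholder for the call under test (peft.get_peft_model in the real test)."""
    return 1 + 1


def test_no_warnings_emitted():
    with warnings.catch_warnings(record=True) as record:
        # Force all warnings to be recorded instead of relying on default filters.
        warnings.simplefilter("always")
        quiet_operation()

    assert len(record) == 0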
9 changes: 9 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/checkpoint.py
@@ -194,6 +194,15 @@ def save_final_model_ddp(
    underlying_model: transformers.PreTrainedModel = model.module if hasattr(model, "module") else model  # type: ignore

    os.makedirs(save_directory, exist_ok=True)
    # If we are saving a PEFT model, we also save the base_model config.
    # This allows for a streamlined reload of the PEFT model without having to manually reconstruct the
    # config of the base_model.
    # For example:
    # >>> config = AutoConfig.from_pretrained(<save_directory>)
    # >>> base_model = AutoModelForTokenClassification.from_pretrained(<model.tag>, config=config)
    # >>> peft_model = PeftModel.from_pretrained(base_model, <save_directory>)
    if hasattr(underlying_model, "peft_config"):
        underlying_model.config.save_pretrained(save_directory)
    underlying_model.save_pretrained(save_directory, state_dict=underlying_model.state_dict(), safe_serialization=True)
    logger.info(f"Saved final DDP model to {save_directory}")

12 changes: 12 additions & 0 deletions bionemo-recipes/recipes/esm2_peft_te/Dockerfile
@@ -0,0 +1,12 @@
FROM nvcr.io/nvidia/pytorch:25.12-py3

RUN --mount=type=cache,target=/root/.cache/pip \
    --mount=type=bind,source=esm2_peft_te/requirements.txt,target=/requirements.txt \
    PIP_CONSTRAINT= pip install -r /requirements.txt

WORKDIR /workspace/bionemo-recipes/recipes/esm2_peft_te
COPY esm2_peft_te/ /workspace/bionemo-recipes/recipes/esm2_peft_te
COPY esm2_native_te/checkpoint.py /workspace/bionemo-recipes/recipes/esm2_native_te/checkpoint.py
COPY esm2_native_te/collator.py /workspace/bionemo-recipes/recipes/esm2_native_te/collator.py
COPY esm2_native_te/distributed_config.py /workspace/bionemo-recipes/recipes/esm2_native_te/distributed_config.py
COPY esm2_native_te/scheduler.py /workspace/bionemo-recipes/recipes/esm2_native_te/scheduler.py