diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py
index e1848f46a3..982ffd88df 100644
--- a/bionemo-recipes/recipes/esm2_native_te/collator.py
+++ b/bionemo-recipes/recipes/esm2_native_te/collator.py
@@ -237,6 +237,9 @@ def __call__(self, features, return_tensors=None):
             batch["input_ids"], special_tokens_mask=special_tokens_mask
         )
 
+        if self.pad_to_multiple_of is not None and self.pad_sequences_to_be_divisible_by is not None:
+            raise ValueError("pad_to_multiple_of and pad_sequences_to_be_divisible_by cannot be used together")
+
         if self.pad_to_multiple_of is not None:
             batch = self._pad_batch_to_multiple_of(batch)
 
diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
index fa8ad45599..1b246f99ae 100644
--- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
+++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
@@ -206,12 +206,14 @@ def forward(
                 "cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
             )
 
+        # Ensure that rotary embeddings are computed at a higher precision, outside the torch autocast context.
         with torch.autocast(device_type="cuda", enabled=False):
             if self.config.position_embedding_type == "rotary":
                 if self.config.attn_input_format == "bshd":
                     te_rope_emb = self.rotary_embeddings(max_seq_len=hidden_states.shape[1])
                 elif self.config.attn_input_format == "thd":
+                    # For THD inputs, size the rotary embeddings to the padded total token count when it is provided.
                     te_rope_emb = self.rotary_embeddings(
                         max_seq_len=kwargs["cu_seq_lens_q_padded"][-1]
                         if "cu_seq_lens_q_padded" in kwargs
@@ -222,7 +224,7 @@ def forward(
         for layer_module in self.layers:
             if kwargs.get("output_hidden_states", False):
                 all_hidden_states = (*all_hidden_states, hidden_states)
-
+
             hidden_states = layer_module(
                 hidden_states,
                 attention_mask,
@@ -235,7 +237,6 @@ def forward(
                 max_seqlen_kv=kwargs.get("max_length_k", None),
                 pad_between_seqs=kwargs.get("pad_between_seqs", None),
             )
-
         hidden_states = self.emb_layer_norm_after(hidden_states)
 
         if kwargs.get("output_hidden_states", False):
diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
index 3bc3635ba7..f8ce5bbeda 100644
--- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
@@ -22,6 +22,7 @@ dataset:
   path: "nvidia/esm2_uniref_pretraining_data"
   split: "train"
   streaming: True
+  pad_sequences_to_be_divisible_by: null
 
 # WandB config
 wandb_init_args:
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
index 293eafe84d..275b96c03a 100644
--- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -68,7 +68,7 @@ def main(args: DictConfig) -> float | None:
     )
 
     # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
-    config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16)
+    config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16)
     # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument.
     if args.use_sequence_packing:
         config.attn_input_format = "thd"
@@ -130,7 +130,10 @@ def main(args: DictConfig) -> float | None:
     step = start_step
     while step < args.num_train_steps:
         for batch in train_dataloader:
             batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa: PLW2901
+            # Tell the TE attention layers that the packed sequences may be separated by padding tokens.
+            batch["pad_between_seqs"] = True
+
             # Forward pass with mixed precision.
             with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
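
For illustration, a minimal sketch of how the mutual-exclusion check added to collator.py behaves. The class name SimplePadCollator, the pad_token_id field, and the per-sequence padding logic below are hypothetical stand-ins rather than the recipe's actual data collator; only the ValueError guard mirrors the diff above.

# Minimal, hypothetical sketch of the new mutual-exclusion check (not the recipe's real collator).
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SimplePadCollator:
    """Illustrative stand-in for the recipe's masked-LM collator."""

    pad_token_id: int = 0
    pad_to_multiple_of: Optional[int] = None
    pad_sequences_to_be_divisible_by: Optional[int] = None

    def __call__(self, sequences: List[List[int]]) -> List[List[int]]:
        # Mirror the guard added in collator.py: the two padding options are mutually exclusive.
        if self.pad_to_multiple_of is not None and self.pad_sequences_to_be_divisible_by is not None:
            raise ValueError("pad_to_multiple_of and pad_sequences_to_be_divisible_by cannot be used together")
        divisor = self.pad_to_multiple_of or self.pad_sequences_to_be_divisible_by
        if divisor is None:
            return sequences
        padded = []
        for seq in sequences:
            # Pad each sequence length up to the next multiple of the requested divisor.
            pad_len = (divisor - len(seq) % divisor) % divisor
            padded.append(seq + [self.pad_token_id] * pad_len)
        return padded


if __name__ == "__main__":
    # Lengths 3 and 10 are padded up to 8 and 16.
    collator = SimplePadCollator(pad_sequences_to_be_divisible_by=8)
    print([len(s) for s in collator([[4, 5, 6], list(range(10))])])

    # Setting both options raises immediately, matching the new check in the diff.
    try:
        SimplePadCollator(pad_to_multiple_of=8, pad_sequences_to_be_divisible_by=16)([[1, 2, 3]])
    except ValueError as err:
        print(err)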