diff --git a/flame/models/parallelize_fla.py b/flame/models/parallelize_fla.py index 37178af..141fe59 100644 --- a/flame/models/parallelize_fla.py +++ b/flame/models/parallelize_fla.py @@ -399,9 +399,6 @@ def apply_compile(model: nn.Module): lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True) model.register_module(lm_head_key, lm_head) - logger.info("Compiling the entire model with torch.compile") - model = torch.compile(model) - def apply_fsdp( model: nn.Module,