Transformers with Linear Attention enable fast and parallel training. Moreover, they can be formulated as Recurrent Neural Networks (RNNs) for efficient linear-time inference. While extensively evaluated in causal sequence modeling, they have yet to be extended to the bi-directional setting. We introduce the LION framework, establishing new theoretical foundations for Linear Transformers in bi-directional sequence modeling. LION constructs a bi-directional RNN equivalent to full Linear Attention, extending the benefits of Linear Transformers, parallel training and efficient inference, to the bi-directional setting.
Existing memory-efficient bi-directional models require more than 2x the training time of a Transformer. Our Linear Attention framework retains memory-efficient inference while matching the training speed of a Transformer.
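The core idea is that full (non-causal) Linear Attention can be computed exactly by two linear recurrences, one per direction. The snippet below is a minimal sketch of that equivalence in plain PyTorch, using toy shapes and omitting normalization, decay masks, feature maps, and multiple heads; it illustrates the idea rather than reproducing the repository's implementation.

```python
import torch

def full_linear_attention(q, k, v):
    """Non-causal (bi-directional) linear attention in its parallel form.
    q and k are assumed to already have a kernel feature map applied;
    normalization and masks are omitted to keep the sketch minimal."""
    return (q @ k.transpose(-1, -2)) @ v

def bidirectional_rnn_linear_attention(q, k, v):
    """The same output computed as two linear recurrences (one per direction),
    which is the RNN view that LION builds on."""
    L, d = q.shape
    y = torch.zeros(L, v.shape[-1])

    S = torch.zeros(d, v.shape[-1])          # forward state: sum_{j<=i} k_j v_j^T
    for i in range(L):
        S = S + torch.outer(k[i], v[i])
        y[i] = q[i] @ S

    S = torch.zeros(d, v.shape[-1])          # backward state: sum_{j>i} k_j v_j^T
    for i in reversed(range(L)):
        y[i] = y[i] + q[i] @ S               # add before updating so token i is counted once
        S = S + torch.outer(k[i], v[i])
    return y

torch.manual_seed(0)
q, k, v = torch.randn(3, 8, 16).unbind(0)    # toy shapes: sequence length 8, dim 16
assert torch.allclose(full_linear_attention(q, k, v),
                      bidirectional_rnn_linear_attention(q, k, v), atol=1e-4)
```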
| Task | 🦁-🔥 | 🦁-D | 🦁-S | Hydra | Vim |
|---|---|---|---|---|---|
| Vision | | | | | |
| MLM | | | | | N.A. |
Using LION, we cast three Linear Transformers to their bi-directional form:
- LION-🔥, the bi-directional variant corresponding to the Linear Transformer.
- LION-D, extending RetNet.
- LION-S, a Linear Transformer with a stable selective mask inspired by the selectivity of SSMs like Mamba🐍.
By replacing the attention block with LION (-️🔥, -D, -S), we achieve performance on bi-directional tasks that is comparable to Transformers and State-Space Models (SSMs) while improving training speed.
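The three variants differ in the gate applied to the recurrent state. The sketch below shows one directional, un-normalized pass to make this concrete; the constant 0.9 decay and the sigmoid gate are illustrative stand-ins, not the exact parameterizations of RetNet-style decay or the selective mask in the paper.

```python
import torch

def gated_forward_pass(q, k, v, lam):
    """One directional recurrence with a per-step gate lam[i] on the state:
      lam[i] = 1            -> LION-🔥 (plain linear attention)
      lam[i] = constant < 1 -> LION-D  (fixed decay, RetNet-style)
      lam[i] = f(x_i)       -> LION-S  (input-dependent, selective gate)"""
    L = q.shape[0]
    S = torch.zeros(q.shape[-1], v.shape[-1])
    out = torch.zeros(L, v.shape[-1])
    for i in range(L):
        S = lam[i] * S + torch.outer(k[i], v[i])
        out[i] = q[i] @ S
    return out

L, d = 8, 16
x = torch.randn(L, d)                                   # hypothetical token features
q, k, v = torch.randn(3, L, d).unbind(0)
y_lit = gated_forward_pass(q, k, v, torch.ones(L))                          # LION-🔥
y_dec = gated_forward_pass(q, k, v, torch.full((L,), 0.9))                  # LION-D (illustrative decay)
y_sel = gated_forward_pass(q, k, v, torch.sigmoid(x @ torch.randn(d)))      # LION-S (illustrative gate)
```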
This repository provides the code for the LION model, covering image classification and masked language modeling (MLM). Our image classification setup is adapted from DeiT, and the MLM implementation builds on M2-BERT.
Setup: Please follow the instructions from the DeiT library to configure the environment.
Within the Image Classification folder, you’ll find models_lion.py, which contains the implementations of LION-🔥, LION-D, and LION-S in three formats: attention, recurrent, and chunk-based. We also provide curves.py for processing image patches in LION-D and LION-S, enhancing the spatial representation as discussed in our paper (these variants are denoted LION-D/S♮).
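As a rough illustration of what a patch-ordering curve does, the snippet below reorders a 14x14 patch grid with a row-wise zigzag ("snake") scan. This is only one example of a non-raster ordering; the exact curves used by LION-D/S♮ are defined in curves.py.

```python
import torch

def snake_order(h: int, w: int) -> torch.Tensor:
    """Indices of a row-wise zigzag ('snake') scan over an h x w patch grid.
    Illustrative only; not necessarily the curve implemented in curves.py."""
    idx = torch.arange(h * w).reshape(h, w)
    idx[1::2] = idx[1::2].flip(-1)        # reverse every other row
    return idx.reshape(-1)

patches = torch.randn(1, 14 * 14, 192)    # (batch, num_patches, embed_dim), 224/16 = 14
order = snake_order(14, 14)
reordered = patches[:, order, :]          # process patches in the new order
```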
Below is an example of how to run LION-D for image classification from scratch, followed by a command that demonstrates LION-S♮ training using “curves” and altered patch orders:
```bash
# Example 1: Train LION-D from scratch
python -m torch.distributed.launch --nproc_per_node=4 --use_env main_lion.py \
--model lion_base_patch16_224 \
--batch-size 256 \
--data-path /datapath \
--output_dir /outputpath
```

```bash
# Example 2: Train LION-S (or LION-D) with curves and patch-order changes
python -m torch.distributed.launch --nproc_per_node=4 --use_env main_lion.py \
--model lion_base_patch16_224 \
--batch-size 256 \
--data-path /datapath \
--output_dir /outputpath \
--mask_type Selective \
--order S \
--format Attention
```

Inside models_lion.py, there are 3 sizes defined as:
- LION in base scale (86M) with an image size of 224, called `lion_base_patch16_224`
- LION in small scale (22M) with an image size of 224, called `lion_small_patch16_224`
- LION in tiny scale (5M) with an image size of 224, called `lion_tiny_patch16_224`
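Since the image-classification code is adapted from DeiT, the models can presumably also be instantiated programmatically through timm, as DeiT models are. A hypothetical usage sketch, assuming models_lion.py registers its models with timm:

```python
import torch
import timm

import models_lion  # noqa: F401  (assumed to register the LION models with timm, as DeiT does)

# Hypothetical usage: instantiate the base model and run a dummy image through it.
model = timm.create_model("lion_base_patch16_224", pretrained=False, num_classes=1000)
x = torch.randn(1, 3, 224, 224)           # one 224x224 RGB image
logits = model(x)
print(logits.shape)                       # expected: torch.Size([1, 1000])
```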
Below are some of the key arguments you can customize when training LION-based models:
- `pos_emb`: enables fixed positional embeddings (as in ViT); default `False`. To set to `True`, pass `--pos_emb`.
- `cls_tok`: uses an independent classification token if set to `True`; otherwise, classification is based on the average pooling of all tokens; default `False`. To set to `True`, pass `--cls_tok`.
- `mask_type`: defines how masking or gating is applied. Supported options include `Lit`, `Decay`, and `Selective`, which correspond to LION-🔥, LION-D, and LION-S respectively. Example usage: `--mask_type Decay`
- `order`: specifies the order in which image patches are processed. Options include `Normal` (default order) and `S` (special ordering). Example usage: `--order S`
- `format`: controls the internal representation of the sequence. Valid options are `Attention` (standard attention-like format), `RNN` (recurrent-like format), and `Chunk` (chunk-based approach; see the sketch below). Example usage: `--format Attention`
- `chunk_size`: an integer that sets the size of chunks when using chunk-based processing. Example usage: `--chunk_size 64`
By combining these arguments, you can experiment with different positional embeddings, classification tokens, patch orders, and masking mechanisms to adapt the LION model to your specific tasks and preferences.
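To make the `format` and `chunk_size` options concrete, below is a minimal sketch of chunk-based linear attention for a single (causal) direction: attention is computed inside each chunk while an RNN-style state carries information across chunks, matching the fully recurrent result. It is an un-normalized toy without decay, not the repository's kernels; the bi-directional case runs such a pass in both directions.

```python
import torch

def chunked_linear_attention(q, k, v, chunk_size: int = 4):
    """Chunk-wise causal linear attention: intra-chunk attention plus an
    inter-chunk state, equivalent to the fully recurrent computation."""
    L, d = q.shape
    y = torch.zeros(L, v.shape[-1])
    S = torch.zeros(d, v.shape[-1])                 # state carried across chunks
    for s in range(0, L, chunk_size):
        qc, kc, vc = q[s:s + chunk_size], k[s:s + chunk_size], v[s:s + chunk_size]
        intra = torch.tril(qc @ kc.T) @ vc          # attention within the chunk
        inter = qc @ S                              # contribution of previous chunks
        y[s:s + chunk_size] = intra + inter
        S = S + kc.T @ vc                           # fold the chunk into the state
    return y

torch.manual_seed(0)
q, k, v = torch.randn(3, 16, 8).unbind(0)
reference = torch.tril(q @ k.T) @ v                 # parallel causal form
assert torch.allclose(chunked_linear_attention(q, k, v), reference, atol=1e-4)
```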
Notes:
- Choose any desired size (e.g., `lion_base_patch16_224`, `lion_small_patch16_224`, or `lion_tiny_patch16_224`).
- By changing `mask_type`, you get the different LION variants (LION-🔥, LION-D, or LION-S).
- Determine the internal representation format with `format` (e.g., `Attention` for training, `RNN` or `Chunk` for inference).
- Adjust `nproc_per_node`, `batch-size`, `data-path`, and `output_dir` according to your hardware setup and dataset location.
- The additional flags (`order`, `pos_emb`, `cls_tok`) control specific training variations (e.g., changing the patch order to "S", adding positional embeddings, or using a classification token).
- As our codebase extends DeiT, you can distill a RegNet teacher into a LION model by following the same distillation commands used for DeiT; just swap in the LION model name, with no additional modifications needed.
Below are the results on Image Classification with ImageNet-1K for LION models vs benchmarks.
| Model | #Param | ImageNet Top-1 Acc. | Train. time |
|---|---|---|---|
| | 86M | | |
| | 86M | | |
| | 104M | | |
| | 98M | | |
| | 86M | | |
| | 86M | | |
| | 86M | | |
| | 86M | | |
| | 86M | | |
To install the dependencies for the MLM experiments, run `pip install -r requirements.txt`.
To download the C4 dataset, run `src/convert_dataset.py`.
Select the corresponding config `.yaml` file from the `yamls/pretrain` folder. Make sure to modify the `data_local` path to match the location of the C4 dataset on your machine.
To pretrain the model, run main.py with the desired configuration file. For example, to pretrain a Lion-Lit-Large model, run
```bash
composer main.py yamls/pretrain/lion-lit-large.yaml
```

Select the corresponding config `.yaml` file from the `yamls/finetune-glue` folder. Make sure to modify the `starting_checkpoint_load_path` to match the location of the checkpoint of the pretraining run you want to finetune.
To finetune the model, run glue.py with the desired configuration file. For example, to finetune a Lion-Lit-Large model on GLUE, run
```bash
python3 glue.py yamls/finetune-glue/lion-lit-large.yaml
```

Below are the results on MLM with the C4 dataset for LION models vs benchmarks.
| Model | MLM Acc. | GLUE | Train. time |
|---|---|---|---|
| BERT | | | |
| Hydra | | | |
| LION-🔥 | | | |
| LION-D | | | |
| LION-S | | | |
