code for Structured Word Embedding for Low Memory Neural Network Language Model
This is the code repo for basis embedding, which reduces model size and memory consumption. It is built on the pytorch/examples repo on GitHub.
basis embedding related arguments:
- --basis <0>: number of bases used to decompose the embedding matrix; 0 is normal mode
- --num_clusters: number of clusters for the whole vocabulary
- --load_input_embedding: path of the pre-trained embedding matrix for the input embedding
- --load_output_embedding: path of the pre-trained embedding matrix for the output embedding
misc options:
- -c or --config: path of the configuration file; it overrides the argument parser's default values and is overridden by command-line options
- --train: train the model, or just evaluate an existing model
- --dict <None>: use the vocabulary file if specified, otherwise use the words in train.txt
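The precedence described for the configuration file (parser defaults, then the config file, then command-line options) can be sketched with argparse. This is a minimal illustration, not the repo's main.py: the `parse_with_config` helper, the dict standing in for a parsed .conf file, and the default values are all assumptions made for the example.

```python
import argparse


def build_parser():
    # Default values here are placeholders, not the repo's real defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", default=None)
    parser.add_argument("--basis", type=int, default=0)
    parser.add_argument("--num_clusters", type=int, default=400)
    parser.add_argument("--lr", type=float, default=20.0)
    return parser


def parse_with_config(argv, config_values):
    """Hypothetical helper: config_values stands in for a parsed .conf file."""
    parser = build_parser()
    # Step 1: values from the config file replace the parser's defaults.
    parser.set_defaults(**config_values)
    # Step 2: options given explicitly on the command line win over both.
    return parser.parse_args(argv)


config = {"basis": 8, "num_clusters": 384}
args = parse_with_config(["--basis", "4"], config)
print(args.basis, args.num_clusters, args.lr)  # 4 384 20.0
```

Here `--basis` comes from the command line, `--num_clusters` from the config values, and `--lr` from the parser default.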
python main.py -c config/default.conf # train a cross-entropy baseline
python main.py -c config/ptb_basis_tied.conf # basis embedding initialized via tied embedding on PTB
During training, if a keyboard interrupt (Ctrl-C) is received, training is stopped and the current model is evaluated against the test dataset.
The main.py script accepts the following arguments:
optional arguments:
-h, --help show this help message and exit
-c, --config PATH preset configurations to load
--data DATA location of the data corpus
--model MODEL type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)
--emsize EMSIZE size of word embeddings
--nhid NHID number of hidden units per layer
--nlayers NLAYERS number of layers
--lr LR initial learning rate
--clip CLIP gradient clipping
--epochs EPOCHS upper epoch limit
--batch-size N batch size
--dropout DROPOUT dropout applied to layers (0 = no dropout)
--tied tie the word embedding and softmax weights
--seed SEED random seed
--cuda use CUDA
--log-interval N report interval
--save SAVE path to save the final model
... more from previous basis embedding related parameters
The basis decoder now uses a Triton kernel on CUDA devices for the codebook
decode step in the output BasisLinear module. The dense centroid projection
still uses torch.bmm; Triton replaces the expensive gather/sum decode and its
backward scatter. CPU and non-Triton environments fall back to the original
PyTorch implementation.
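The two stages named above, a dense centroid projection followed by a gather/sum codebook decode, can be sketched in a few lines. The version below is a NumPy re-expression for illustration only, not the repo's BasisLinear or its Triton kernel. It assumes the output weight is split into B bases, each with a K-entry codebook, and that every word stores one centroid index per basis; the function name, argument names, and shapes are hypothetical.

```python
import numpy as np


def basis_logits(hidden, centroids, codes):
    """Logits of a basis-decomposed output layer (illustrative sketch).

    hidden:    (T, H) float array of hidden states
    centroids: (B, K, D) float array, one K-entry codebook per basis, H == B * D
    codes:     (V, B) int array, the centroid index of each word in each basis
    """
    T, H = hidden.shape
    B, K, D = centroids.shape
    assert H == B * D, "hidden size must split evenly across the bases"

    # Dense centroid projection: one small batched matmul per basis.
    h = hidden.reshape(T, B, D).transpose(1, 0, 2)        # (B, T, D)
    scores = np.matmul(h, centroids.transpose(0, 2, 1))   # (B, T, K)

    # Codebook decode: gather each word's centroid score in every basis, then sum.
    idx = codes.T[:, None, :]                              # (B, 1, V)
    gathered = np.take_along_axis(scores, idx, axis=2)     # (B, T, V)
    return gathered.sum(axis=0)                            # (T, V)


rng = np.random.default_rng(0)
T, H, V, B, K = 4, 16, 50, 8, 6
hidden = rng.standard_normal((T, H))
centroids = rng.standard_normal((B, K, H // B))
codes = rng.integers(0, K, size=(V, B))

# Reference: rebuild the full V x H weight from the codebooks and use a dense matmul.
full_weight = centroids[np.arange(B), codes].reshape(V, H)
dense = hidden @ full_weight.T
print(np.allclose(basis_logits(hidden, centroids, codes), dense))  # True
```

Under these assumptions the projection costs on the order of T x K x H multiply-adds instead of T x V x H, which is why the decode step rather than the matmul becomes the part worth optimizing.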
Benchmark command:
python benchmarks/benchmark_basis_vs_linear.py \
--tokens 128 \
--hidden-sizes 2048 4096 \
--vocab-sizes 50000 100000 150000 200000 \
--num-basis 8 \
--clusters 384 \
--dtype float16
Measured on an NVIDIA H20 with fp16 tensors:
| Hidden | Vocab | Full linear fwd (ms) | Basis linear fwd (ms) | Fwd speedup | Full linear fwd+bwd (ms) | Basis linear fwd+bwd (ms) | Total speedup |
|---|---|---|---|---|---|---|---|
| 2048 | 50k | 0.206 | 0.071 | 2.90x | 0.620 | 0.351 | 1.77x |
| 2048 | 100k | 0.384 | 0.087 | 4.39x | 1.191 | 0.441 | 2.70x |
| 2048 | 150k | 0.571 | 0.128 | 4.47x | 1.760 | 0.452 | 3.89x |
| 2048 | 200k | 0.757 | 0.160 | 4.73x | 2.326 | 0.599 | 3.88x |
| 4096 | 50k | 0.399 | 0.068 | 5.83x | 1.192 | 0.353 | 3.37x |
| 4096 | 100k | 0.759 | 0.088 | 8.62x | 2.310 | 0.350 | 6.60x |
| 4096 | 150k | 1.125 | 0.126 | 8.91x | 3.426 | 0.442 | 7.76x |
| 4096 | 200k | 1.503 | 0.161 | 9.34x | 4.567 | 0.563 | 8.11x |
Profiler summary for hidden=4096, vocab=100k:
- Full linear spent about 2.29 ms of CUDA time, dominated by the large GEMMs for x @ W.T and the full V x H weight gradient.
- Basis linear spent about 0.27 ms of CUDA time: roughly 74 us in the Triton decode forward kernel, 155 us in the Triton decode backward scatter kernel, and 13 us in the bias-gradient kernel. The centroid bmm work is small compared with the dense full-linear GEMMs.
End-to-end PTB smoke test:
python main.py -c /tmp/ptb_full_test.conf
python main.py -c /tmp/ptb_basis_test.conf
Both runs used emsize=200, nhid=200, nlayers=2, batch size 20, seed
1111, and two training epochs on the full PTB train split. The full-linear
baseline trained both epochs with the dense output layer. The basis-linear run
trained one epoch with the dense output layer, then enabled an 8-basis,
384-cluster output decoder for epoch 2.
| Decoder | Epoch 1 valid PPL | Epoch 2 valid PPL | Test PPL |
|---|---|---|---|
| Full linear | 209.21 | 176.28 | 170.86 |
| Basis linear output | 211.59 | 183.45 | 177.29 |
The short basis run is within about 3.8% test PPL of the full-linear baseline. This is a smoke test rather than a tuned PTB result; longer training and hyperparameter tuning should be used for final quality numbers.
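The 3.8% figure follows directly from the test PPL column. As a small worked check, the snippet below recomputes it and also expresses the gap as a per-token loss difference, using the standard relation that perplexity is the exponential of the mean cross-entropy in nats. The variable names are only for illustration.

```python
import math

# Test perplexities from the smoke-test table.
full_ppl, basis_ppl = 170.86, 177.29

# Relative gap in perplexity.
rel_gap = (basis_ppl - full_ppl) / full_ppl

# Perplexity = exp(mean cross-entropy), so the loss gap is a difference of logs.
nll_gap = math.log(basis_ppl) - math.log(full_ppl)

print(f"relative test PPL gap: {rel_gap:.1%}")  # 3.8%
print(f"per-token loss gap: {nll_gap:.4f} nats")
```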
- main.py: the entry file; it parses the parameters, defines the models, and feeds the data to the model
- model.py: defines the input embedding and LSTM layer
- basis_loss.py: contains a basis linear module that takes the LSTM hidden state as input and outputs the loss value
- basis/: core part of the basis embedding module
- utils.py: performs product quantization on the pre-trained embedding
- data.py: data pre-processing
- .th/.th.decoder: the pre-trained embedding matrix
- .conf: sample configuration files
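For readers unfamiliar with the product quantization step attributed to utils.py, here is a minimal NumPy sketch of the general technique. It is not the repo's implementation: the `kmeans` and `product_quantize` functions, the plain Lloyd iterations, and the toy sizes are assumptions chosen to keep the example short.

```python
import numpy as np


def kmeans(x, k, iters=10, seed=0):
    """Plain Lloyd's k-means; returns (centroids, assignments)."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        # Assign each row to its nearest centroid (squared Euclidean distance).
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        assign = dists.argmin(axis=1)
        # Move each non-empty centroid to the mean of its assigned rows.
        for j in range(k):
            members = x[assign == j]
            if len(members):
                centroids[j] = members.mean(axis=0)
    # Final assignment, consistent with the returned centroids.
    dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return centroids, dists.argmin(axis=1)


def product_quantize(embedding, num_basis, num_clusters):
    """Split each embedding row into num_basis blocks and cluster every block."""
    V, H = embedding.shape
    assert H % num_basis == 0, "embedding size must split evenly across the bases"
    blocks = embedding.reshape(V, num_basis, H // num_basis)
    centroids, codes = [], []
    for b in range(num_basis):
        c, a = kmeans(blocks[:, b, :], num_clusters, seed=b)
        centroids.append(c)
        codes.append(a)
    # centroids: (B, K, D) codebooks; codes: (V, B) centroid indices per word.
    return np.stack(centroids), np.stack(codes, axis=1)


rng = np.random.default_rng(0)
embedding = rng.standard_normal((1000, 16))
centroids, codes = product_quantize(embedding, num_basis=4, num_clusters=8)

# Rebuild an approximate embedding by looking up each word's centroids.
approx = centroids[np.arange(4), codes].reshape(embedding.shape)
print(centroids.shape, codes.shape, approx.shape)
```

With this layout the model stores the small codebooks plus one integer index per word and basis instead of the full V x H matrix.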