Implementation of a discrete RNN for text generation, written using only JAX.
The model is trained on Montaigne's Essays to capture the structure of 16th-century French and perform character-level prediction.
Training needs a lot of epochs (at least 500,000), but with JAX, and especially jit and lax.scan, it is very fast.
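To illustrate why jit and lax.scan help, here is a minimal sketch of a character-level RNN forward pass and loss. It is not the code in src/model.py: the sizes `VOCAB` and `HIDDEN`, the parameter names, and the functions `init_params`, `forward` and `loss_fn` are all assumptions made up for this example. The point is that lax.scan replaces the Python loop over time steps, so jit compiles the whole sequence into a single program.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical sizes, chosen only for this sketch.
VOCAB, HIDDEN = 64, 32


def init_params(key):
    """Small random weights for a vanilla (Elman-style) RNN."""
    k1, k2, k3 = jax.random.split(key, 3)
    scale = 0.1
    return {
        "W_xh": scale * jax.random.normal(k1, (VOCAB, HIDDEN)),
        "W_hh": scale * jax.random.normal(k2, (HIDDEN, HIDDEN)),
        "W_hy": scale * jax.random.normal(k3, (HIDDEN, VOCAB)),
        "b_h": jnp.zeros(HIDDEN),
        "b_y": jnp.zeros(VOCAB),
    }


@jax.jit
def forward(params, tokens):
    """Return next-character logits for every position in `tokens`."""

    def step(h, tok):
        # Selecting row `tok` of W_xh is the same as multiplying
        # a one-hot character vector by W_xh.
        h = jnp.tanh(params["W_xh"][tok] + h @ params["W_hh"] + params["b_h"])
        logits = h @ params["W_hy"] + params["b_y"]
        return h, logits  # (carry to next step, per-step output)

    h0 = jnp.zeros(HIDDEN)
    _, logits = lax.scan(step, h0, tokens)  # loop over time, compiled once
    return logits


def loss_fn(params, tokens):
    """Mean cross-entropy of predicting tokens[t + 1] from tokens[: t + 1]."""
    logits = forward(params, tokens[:-1])
    log_probs = jax.nn.log_softmax(logits)
    targets = tokens[1:]
    return -jnp.mean(log_probs[jnp.arange(targets.shape[0]), targets])


params = init_params(jax.random.PRNGKey(0))
tokens = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])  # toy character ids
logits = forward(params, tokens)
loss = loss_fn(params, tokens)
grads = jax.grad(loss_fn)(params, tokens)
```

Because the time loop lives inside lax.scan, compilation time does not grow with sequence length, and jax.grad differentiates through the whole scan to give one gradient per parameter.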
- data/essais_montaigne.txt: Training data
- src/model.py: Sampling and RNN maths
- src/train.py: Optimisation
- src/utils.py: Token processor
- main.py: Training and weights saving
- inference.py: Inference
pip install -r requirement.txt