alexandre-martel/RNN-JAX-implementation

RNN-JAX from scratch

An implementation of a discrete RNN for text generation, written using only JAX.

The model is trained on Montaigne's Essays to capture the structures of 16th-century (Middle) French, with the goal of predicting the next character.
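Character prediction needs the text turned into integer ids and paired with the same sequence shifted by one position. The following is a minimal sketch of that idea in plain Python. The function names and the sample string are illustrative assumptions and are not taken from src/utils.py.

```python
# Hypothetical sketch of character-level encoding for next-character prediction.
# Names and sample text are illustrative; the repository's own code may differ.

def build_vocab(text):
    """Map each distinct character to an integer id (sorted for determinism)."""
    chars = sorted(set(text))
    char_to_id = {c: i for i, c in enumerate(chars)}
    id_to_char = {i: c for c, i in char_to_id.items()}
    return char_to_id, id_to_char


def encode(text, char_to_id):
    """Turn a string into a list of integer ids."""
    return [char_to_id[c] for c in text]


def decode(ids, id_to_char):
    """Turn a list of integer ids back into a string."""
    return "".join(id_to_char[i] for i in ids)


sample = "que sais-je"  # stand-in for the training corpus
char_to_id, id_to_char = build_vocab(sample)
ids = encode(sample, char_to_id)

# Inputs are every character but the last; targets are the sequence shifted by one,
# so the model learns to predict character t+1 from characters up to t.
inputs, targets = ids[:-1], ids[1:]
```

In a real run the vocabulary would be built once from the whole corpus and saved alongside the weights, so that inference decodes with the same mapping.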

Training requires many epochs (at least 500,000), but with JAX, and especially jit and lax.scan, it is very fast.
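To illustrate why jit and lax.scan help, here is a minimal sketch of a vanilla RNN forward pass and loss. lax.scan expresses the loop over time steps as a single compiled operation instead of a Python loop, and jit compiles the whole loss once so later calls reuse it. The parameter names (W_xh, W_hh, W_hy), the sizes, and the helper functions are assumptions made for this example, not the code in src/model.py.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_params(key, vocab_size, hidden_size):
    """Small random weights for a vanilla RNN (illustrative initialisation)."""
    k1, k2, k3 = jax.random.split(key, 3)
    scale = 0.01
    return {
        "W_xh": scale * jax.random.normal(k1, (vocab_size, hidden_size)),
        "W_hh": scale * jax.random.normal(k2, (hidden_size, hidden_size)),
        "W_hy": scale * jax.random.normal(k3, (hidden_size, vocab_size)),
        "b_h": jnp.zeros(hidden_size),
        "b_y": jnp.zeros(vocab_size),
    }


def forward(params, h0, token_ids):
    """Run the RNN over a sequence; returns (final hidden state, per-step logits)."""

    def step(h, token_id):
        # Selecting a row of W_xh is equivalent to multiplying by a one-hot vector.
        h_new = jnp.tanh(params["W_xh"][token_id] + h @ params["W_hh"] + params["b_h"])
        logits = h_new @ params["W_hy"] + params["b_y"]
        return h_new, logits

    return lax.scan(step, h0, token_ids)


@jax.jit
def loss_fn(params, h0, inputs, targets):
    """Mean cross-entropy of the next-character predictions."""
    _, logits = forward(params, h0, inputs)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)
    return -jnp.mean(picked)


vocab_size, hidden_size = 5, 8  # toy sizes
params = init_params(jax.random.PRNGKey(0), vocab_size, hidden_size)
h0 = jnp.zeros(hidden_size)
inputs = jnp.array([0, 1, 2, 3])
targets = jnp.array([1, 2, 3, 4])

h_final, logits = forward(params, h0, inputs)
loss = loss_fn(params, h0, inputs, targets)
grads = jax.grad(loss_fn)(params, h0, inputs, targets)
```

The gradients returned by jax.grad have the same dictionary structure as params, so a plain gradient-descent update can be written with jax.tree_util.tree_map over the two.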

Project structure

  • data/essais_montaigne.txt : Training data
  • src/model.py : Sampling and RNN maths
  • src/train.py : Optimisation
  • src/utils.py : Token processor
  • main.py : Training and saving the weights
  • inference.py : Inference
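The list above mentions sampling as part of src/model.py. As a rough illustration of what sampling a next character from the model's output scores can look like, the sketch below draws a token id from temperature-scaled logits with jax.random.categorical. The function name sample_next, the temperature parameter, and the logits are assumptions for this example and may not match the repository.

```python
import jax
import jax.numpy as jnp


def sample_next(key, logits, temperature=1.0):
    """Draw one token id from softmax(logits / temperature).

    Lower temperatures sharpen the distribution towards the highest score;
    higher temperatures flatten it and give more varied text.
    """
    return jax.random.categorical(key, logits / temperature)


logits = jnp.array([0.0, 0.0, 10.0, 0.0])  # made-up scores over a 4-character vocabulary

# JAX random functions are pure: split the key before each draw
# instead of reusing it.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
token_id = sample_next(subkey, logits, temperature=0.1)
```

With such a low temperature and one score far above the others, the draw lands on index 2 in practice. During generation, the sampled id would be fed back as the next input to the RNN step.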

Installation

pip install -r requirement.txt
