MexMa is a micro‑LLM designed to train and run on modest GPUs (e.g., GTX 1650 Ti, 4 GB VRAM) and small cloud GPUs. It includes simple training/inference scripts, a basic config system, and scaffolding for SFT, distillation, and RAG.
- ~24M params, decoder‑only (RMSNorm, RoPE), context 512 (CLI override supported), vocab 16k.
- Train on 300M clean tokens; then do a short SFT pass for instruction following.
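As a rough sanity check on the ~24M figure, a decoder-only transformer's size is approximately the tied embedding (vocab × d_model) plus about 12 × d_model² per layer; the width/depth below are illustrative assumptions, not the actual values in `configs/mexma-24m.yaml`.

```python
# Back-of-the-envelope parameter count for a tied-embedding decoder-only model.
# d_model and n_layers are illustrative guesses chosen to land near ~24M;
# the real dimensions live in configs/mexma-24m.yaml.
vocab, d_model, n_layers = 16_000, 384, 10
embedding = vocab * d_model          # token embedding, shared with the output head
per_layer = 12 * d_model ** 2        # ≈ 4·d² attention + 8·d² MLP (4× expansion)
print(f"{(embedding + n_layers * per_layer) / 1e6:.1f}M params")  # ≈ 23.8M
```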
- `configs/`: YAML configs (see `configs/mexma-24m.yaml`)
- `src/mexma/`: package (model, train loop, data utils, tokenizer)
- `scripts/`: CLI entry points (train, generate, data prep, RAG, SFT helpers)
- `data/`: raw/processed text (created locally)
- `tokenizer/`: tokenizer artifacts (`tokenizer.json`)
- `outputs/`: checkpoints and logs
- `docs/`: checklist, playbook, model card
- Distillation/SFT: `scripts/build_distillation.py`, `scripts/format_sft.py`
- RAG: `scripts/rag_index.py`, `scripts/rag_generate.py`
- Tokenizer/data: `scripts/train_tokenizer.py`, `scripts/prepare_data.py`, `scripts/count_tokens.py`
- Inference/export: `scripts/generate.py`, `scripts/export_model_safetensors.py`
- Optional API: `scripts/api_server.py`, `scripts/generate_key.py`
Install
```bash
pip install -U pip
pip install -r requirements.txt
pip install -e .
```

Prepare data (example)
```bash
# put raw .txt into data/raw/ or sample a slice (FineWeb example)
python scripts/sample_fineweb.py --name CC-MAIN-2025-26 --tokens 200000000 --out data/raw/fineweb_200m.txt
python scripts/prepare_data.py --src data/raw --out data
```
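The exact behavior of `scripts/prepare_data.py` isn't documented here; the sketch below shows the kind of raw → train/val split it presumably performs (the `data/train` layout is taken from the commands that follow, everything else is an assumption).

```python
# Illustrative sketch only: clean raw .txt files and split them into data/train and
# data/val. The real scripts/prepare_data.py likely adds deduplication/normalization.
from pathlib import Path

src, out = Path("data/raw"), Path("data")
(out / "train").mkdir(parents=True, exist_ok=True)
(out / "val").mkdir(parents=True, exist_ok=True)

for p in src.glob("*.txt"):
    lines = [ln.strip() for ln in p.read_text(encoding="utf-8").splitlines() if ln.strip()]
    cut = max(1, int(len(lines) * 0.99))  # hold out ~1% of lines for validation
    (out / "train" / p.name).write_text("\n".join(lines[:cut]), encoding="utf-8")
    (out / "val" / p.name).write_text("\n".join(lines[cut:]), encoding="utf-8")
```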
Train tokenizer (16k byte‑level BPE)

```bash
python scripts/train_tokenizer.py --corpus data/train --out tokenizer --vocab 16000
```
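For reference, training a 16k byte‑level BPE with the Hugging Face `tokenizers` library looks roughly like the sketch below; option names and special tokens in `scripts/train_tokenizer.py` may differ.

```python
# Minimal byte-level BPE training sketch (Hugging Face `tokenizers`); the special
# token below is an assumption, not necessarily what scripts/train_tokenizer.py uses.
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer

files = [str(p) for p in Path("data/train").rglob("*.txt")]
tok = ByteLevelBPETokenizer()
tok.train(files=files, vocab_size=16_000, min_frequency=2,
          special_tokens=["<|endoftext|>"])
Path("tokenizer").mkdir(exist_ok=True)
tok.save("tokenizer/tokenizer.json")
```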
Train / resume

```bash
# resume if a checkpoint exists; else omit --resume
python scripts/train_lm.py --config configs/mexma-24m.yaml --tokenizer tokenizer/tokenizer.json \
  --resume outputs/mexma-24m/step-XXXXX.pt --max_steps 300000
```
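What `--resume` restores isn't spelled out above; the round trip below is a plain-PyTorch sketch, and the checkpoint keys (`model`, `optimizer`, `step`) are assumptions rather than the trainer's actual format.

```python
# Checkpoint save/restore sketch with assumed keys; not the exact format written
# by scripts/train_lm.py. The Linear layer stands in for the real MexMa model.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 12345},
           "demo-step.pt")

ckpt = torch.load("demo-step.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_step = ckpt["step"]  # training would continue from this step counter
```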
Generate

```bash
python scripts/generate.py --ckpt outputs/mexma-24m/step-XXXXX.pt \
  --tokenizer tokenizer/tokenizer.json --prompt "In one sentence, what is MexMa?"
```
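Under the hood, generation for a causal LM of this size is a short autoregressive loop; the sketch below shows a generic top-k sampler (flag names and defaults in `scripts/generate.py` are not documented here and may differ).

```python
# Generic top-k sampling loop for a causal LM; `model` is assumed to return
# [batch, seq, vocab] logits. Defaults here are illustrative, not MexMa's.
import torch

@torch.no_grad()
def sample(model, input_ids, max_new_tokens=64, temperature=0.8, top_k=50):
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(ids[:, -512:])[:, -1, :] / temperature  # keep within the 512 context
        topk = torch.topk(logits, top_k)
        probs = torch.softmax(topk.values, dim=-1)
        next_id = topk.indices.gather(-1, torch.multinomial(probs, 1))
        ids = torch.cat([ids, next_id], dim=-1)
    return ids
```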
```bash
# filter/format teacher outputs → SFT JSONL → flat txt for trainer
python scripts/build_distillation.py --input data/distill/raw_teacher.jsonl --out data/sft/distilled.jsonl
python scripts/format_sft.py --qa_jsonl data/sft/distilled.jsonl --out data/sft/chat.jsonl
python - <<'PY'
import json, pathlib

out = pathlib.Path('data/raw')
out.mkdir(parents=True, exist_ok=True)
with open('data/sft/chat.jsonl', 'r', encoding='utf-8') as f, open(out / 'sft.txt', 'w', encoding='utf-8') as w:
    for ln in f:
        m = json.loads(ln)['messages']
        # keep the last user turn and the last assistant turn of each chat record
        u = [x['content'] for x in m if x['role'] == 'user'][-1]
        a = [x['content'] for x in m if x['role'] == 'assistant'][-1]
        w.write(f"User: {u}\nAssistant: {a}\n\n")
PY
```
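The heredoc above assumes `data/sft/chat.jsonl` holds one chat record per line in a role/content `messages` layout; a hypothetical example record:

```python
# Hypothetical example of one chat.jsonl line; the exact schema is defined by
# scripts/format_sft.py and may include more fields.
import json

record = {"messages": [
    {"role": "user", "content": "In one sentence, what is MexMa?"},
    {"role": "assistant", "content": "MexMa is a ~24M-parameter micro-LLM for modest GPUs."},
]}
print(json.dumps(record))  # → one line of data/sft/chat.jsonl
```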
```bash
# short SFT run
python scripts/train_lm.py --config configs/mexma-24m.yaml --tokenizer tokenizer/tokenizer.json \
  --resume outputs/mexma-24m/step-XXXXX.pt --train_glob "data/raw/sft.txt" --val_glob "data/raw/sft.txt" \
  --seq_len 256 --max_steps 20000
```
```bash
python scripts/rag_index.py --glob "data/train/**/*.txt" --out outputs/rag/index.json
python scripts/rag_generate.py --ckpt outputs/mexma-24m/step-XXXXX.pt \
  --tokenizer tokenizer/tokenizer.json --index outputs/rag/index.json --prompt "Summarize MexMa."
```
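The index format and scoring used by `scripts/rag_index.py` / `scripts/rag_generate.py` aren't documented here; the sketch below shows one simple way a JSON chunk index plus a term-overlap retriever could work, purely as an illustration.

```python
# Illustrative chunk index + term-overlap retrieval; the real index.json layout
# and ranking in the MexMa scripts may be entirely different.
import glob, json, os, re
from collections import Counter

def tokenize(text):
    return re.findall(r"[a-z0-9]+", text.lower())

chunks = []
for path in glob.glob("data/train/**/*.txt", recursive=True):
    text = open(path, encoding="utf-8").read()
    chunks += [text[i:i + 800] for i in range(0, len(text), 800)]  # fixed-size chunks

os.makedirs("outputs/rag", exist_ok=True)
json.dump({"chunks": chunks}, open("outputs/rag/index.json", "w", encoding="utf-8"))

def retrieve(query, k=3):
    q = Counter(tokenize(query))
    ranked = sorted(chunks, key=lambda c: sum((Counter(tokenize(c)) & q).values()), reverse=True)
    return ranked[:k]  # retrieved chunks are prepended to the prompt before generation
```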
- Checkpoint cadence in config: `checkpoint_every`, `eval_every` (e.g., 1000).
- Tokens processed ≈ steps × batch_size × seq_len (worked example after this list).
- Keep tokenizer fixed once pretraining starts.
- Native Windows works; WSL2/Ubuntu simplifies CUDA setup.
- PyTorch 2.x; AMP/SDPA used by default; gradient checkpointing available.
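To make the tokens-per-step rule of thumb above concrete (the batch size here is an assumption for illustration; use the value from your config):

```python
# tokens ≈ steps × batch_size × seq_len; batch_size=2 is only an illustrative guess
steps, batch_size, seq_len = 300_000, 2, 512
print(f"≈ {steps * batch_size * seq_len / 1e6:.0f}M tokens")  # ≈ 307M, near the 300M target
```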
- Base training + inference working; scaffolding for SFT, distillation, RAG included.