Commit a405c0d

Add minimal training example with Darwin's Origin of Species

Creates examples/simple.py demonstrating core ScratchGPT usage: auto-downloads the training text, trains a small model with CharTokenizer, shows basic text generation, and uses temporary directories for clean execution.

1 parent 1823476 commit a405c0d

1 file changed: 184 additions, 0 deletions

examples/simple.py

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
Simple example showing minimal usage of ScratchGPT to train on Darwin's "On the Origin of Species".

This script demonstrates:
1. Downloading training data from Project Gutenberg
2. Setting up a basic configuration
3. Training a small transformer model
4. Basic text generation

Usage:
    python simple.py
"""

import subprocess
import sys
import tempfile
from pathlib import Path

import torch
from torch.optim import AdamW

# Import ScratchGPT components
from scratchgpt import (
    CharTokenizer,
    FileDataSource,
    ScratchGPTArchitecture,
    ScratchGPTConfig,
    ScratchGPTTraining,
    Trainer,
    TransformerLanguageModel,
)


def download_darwin_text(data_file: Path) -> None:
    """Download Darwin's 'On the Origin of Species' if not already present."""
    if data_file.exists():
        print(f"✅ Data file already exists: {data_file}")
        return

    print("📥 Downloading 'On the Origin of Species' by Charles Darwin...")
    url = "https://www.gutenberg.org/files/1228/1228-0.txt"

    try:
        # Use curl to download the file
        _ = subprocess.run(
            ["curl", "-s", url, "-o", str(data_file)],
            check=True,
            capture_output=True,
            text=True,
        )
        print(f"✅ Downloaded data to: {data_file}")
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to download data: {e}")
        print("Please install curl or manually download the file from:")
        print(url)
        sys.exit(1)
    except FileNotFoundError:
        print("❌ curl not found. Please install curl or manually download:")
        print(f"    curl -s {url} > {data_file}")
        sys.exit(1)
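
# Alternative to curl inside download_darwin_text(): fetch with Python's
# standard library instead. A minimal sketch (error handling and retries
# omitted; some servers may reject the default urllib User-Agent):
#
#     from urllib.request import urlretrieve
#     urlretrieve(url, str(data_file))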


def create_simple_config() -> ScratchGPTConfig:
    """Create a minimal configuration suitable for quick training."""
    # Small architecture for quick training on CPU/small GPU
    architecture = ScratchGPTArchitecture(
        block_size=128,      # Smaller context window
        embedding_size=256,  # Smaller embeddings
        num_heads=8,         # Fewer attention heads
        num_blocks=4,        # Fewer transformer blocks
        # vocab_size will be set based on the tokenizer
    )

    # Training config optimized for quick results
    training = ScratchGPTTraining(
        max_epochs=20,       # Fewer epochs for a quick demo
        learning_rate=3e-4,  # Standard learning rate
        batch_size=32,       # Reasonable batch size
        dropout_rate=0.1,    # Light dropout
        random_seed=42,      # Reproducible results
    )

    return ScratchGPTConfig(
        architecture=architecture,
        training=training,
    )


def prepare_text_for_tokenizer(data_file: Path) -> str:
    """Read the text file for tokenization."""
    print(f"Reading text from: {data_file}")

    with open(data_file, "r", encoding="utf-8") as f:
        text = f.read()

    print(f"Text length: {len(text):,} characters")
    return text


def main() -> None:
    print("ScratchGPT Simple Training Example")
    print("=" * 50)

    # Use a temporary directory that auto-cleans when done
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        data_file = tmp_path / "darwin_origin_species.txt"
        experiment_dir = tmp_path / "darwin_experiment"

        # Step 1: Download data
        download_darwin_text(data_file)

        # Step 2: Prepare text and create tokenizer
        text = prepare_text_for_tokenizer(data_file)
        print("Creating character-level tokenizer...")
        tokenizer = CharTokenizer(text=text)
        print(f"Vocabulary size: {tokenizer.vocab_size}")

        # Alternative: use a pre-trained tokenizer like GPT-2's.
        # This requires: pip install 'scratchgpt[hf-tokenizers]'
        #
        # from scratchgpt import HuggingFaceTokenizer
        # tokenizer = HuggingFaceTokenizer.from_hub("gpt2")
        # print(f"Vocabulary size: {tokenizer.vocab_size}")  # ~50,257 tokens
        #
        # Trade-offs:
        # - CharTokenizer: small vocab (~100 chars), learns from scratch, simple
        # - GPT-2 tokenizer: large vocab (~50K tokens), pre-trained, better text quality
        # - The GPT-2 tokenizer will likely generate more coherent text but requires more memory

        # Step 3: Create configuration
        config = create_simple_config()
        config.architecture.vocab_size = tokenizer.vocab_size
        print(
            f"Model configuration: {config.architecture.embedding_size}D embeddings, "
            f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads"
        )

        # Step 4: Set up model and training
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        model = TransformerLanguageModel(config)
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        optimizer = AdamW(model.parameters(), lr=config.training.learning_rate)
        data_source = FileDataSource(data_file)

        # Step 5: Create trainer and start training
        trainer = Trainer(
            model=model,
            config=config.training,
            optimizer=optimizer,
            experiment_path=experiment_dir,
            device=device,
        )

        print("\nStarting training...")
        trainer.train(data=data_source, tokenizer=tokenizer)
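
        # Note: 20 epochs over the full book can be slow on CPU. If the demo
        # takes too long, lower max_epochs in create_simple_config() above.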

        # Step 6: Simple text generation demo
        print("\nTesting text generation:")
        model.eval()

        test_prompts = [
            "Natural selection",
            "The origin of species",
            "Darwin observed",
        ]

        for prompt in test_prompts:
            print(f"\nPrompt: '{prompt}'")
            context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

            with torch.no_grad():
                generated = model.generate(context, max_new_tokens=100)
            result = tokenizer.decode(generated[0].tolist())
            print(f"Generated: {result}")

        print("\nTraining complete! All temporary files automatically cleaned up.")
        print("Run the script again to start fresh.")


if __name__ == "__main__":
    main()