#!/usr/bin/env python3
"""
Simple example showing minimal usage of ScratchGPT to train on Darwin's "On the Origin of Species"

This script demonstrates:
1. Downloading training data from Project Gutenberg
2. Setting up a basic configuration
3. Training a small transformer model
4. Basic text generation

Usage:
    python simple.py
"""

import subprocess
import sys
import tempfile
from pathlib import Path

import torch
from torch.optim import AdamW

# Import ScratchGPT components
from scratchgpt import (
    CharTokenizer,
    FileDataSource,
    ScratchGPTArchitecture,
    ScratchGPTConfig,
    ScratchGPTTraining,
    Trainer,
    TransformerLanguageModel,
)


def download_darwin_text(data_file: Path) -> None:
    """Download Darwin's 'On the Origin of Species' if not already present."""
    if data_file.exists():
        print(f"✅ Data file already exists: {data_file}")
        return

    print("📥 Downloading 'On the Origin of Species' by Charles Darwin...")
    url = "https://www.gutenberg.org/files/1228/1228-0.txt"

    try:
        # Use curl to download the file
        _ = subprocess.run(
            ["curl", "-s", url, "-o", str(data_file)],
            check=True,
            capture_output=True,
            text=True,
        )
        print(f"✅ Downloaded data to: {data_file}")
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to download data: {e}")
        print("Please install curl or manually download the file from:")
        print(url)
        sys.exit(1)
    except FileNotFoundError:
        print("❌ curl not found. Please install curl or manually download:")
        print(f"  curl -s {url} > {data_file}")
        sys.exit(1)
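

# If curl isn't installed, a stdlib-only fallback is possible. This is a
# sketch using urllib.request from the Python standard library; the helper
# name is ours, not part of ScratchGPT.
def download_with_urllib(data_file: Path, url: str) -> None:
    """Fallback download using only the Python standard library."""
    import urllib.request  # local import to keep the script's top imports unchanged

    # urlretrieve streams the response body straight to disk
    urllib.request.urlretrieve(url, str(data_file))
    print(f"✅ Downloaded data to: {data_file}")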


def create_simple_config() -> ScratchGPTConfig:
    """Create a minimal configuration suitable for quick training."""
    # Small architecture for quick training on CPU/small GPU
    architecture = ScratchGPTArchitecture(
        block_size=128,      # Smaller context window
        embedding_size=256,  # Smaller embeddings
        num_heads=8,         # Fewer attention heads
        num_blocks=4,        # Fewer transformer blocks
        # vocab_size will be set based on the tokenizer
    )

    # Training config optimized for quick results
    training = ScratchGPTTraining(
        max_epochs=20,       # Fewer epochs for a quick demo
        learning_rate=3e-4,  # Standard learning rate
        batch_size=32,       # Reasonable batch size
        dropout_rate=0.1,    # Light dropout
        random_seed=42,      # Reproducible results
    )

    return ScratchGPTConfig(
        architecture=architecture,
        training=training,
    )
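

# Rough parameter-count estimate for a GPT-style model: the embedding table
# contributes vocab_size * embedding_size weights, and each transformer block
# roughly 12 * embedding_size^2 (attention projections plus the MLP). This is
# a back-of-the-envelope sketch, not ScratchGPT's exact accounting; compare it
# with the true count printed from model.parameters() in main() below.
def estimate_param_count(vocab_size: int, embedding_size: int, num_blocks: int) -> int:
    """Approximate number of weights in a small GPT-style transformer."""
    embedding_params = vocab_size * embedding_size
    block_params = num_blocks * 12 * embedding_size**2
    return embedding_params + block_params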


def prepare_text_for_tokenizer(data_file: Path) -> str:
    """Read the text file for tokenization."""
    print(f"Reading text from: {data_file}")

    with open(data_file, "r", encoding="utf-8") as f:
        text = f.read()

    print(f"Text length: {len(text):,} characters")
    return text


def main():
    print("ScratchGPT Simple Training Example")
    print("=" * 50)

    # Use a temporary directory that auto-cleans when done
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        data_file = tmp_path / "darwin_origin_species.txt"
        experiment_dir = tmp_path / "darwin_experiment"

        # Step 1: Download data
        download_darwin_text(data_file)

        # Step 2: Prepare text and create tokenizer
        text = prepare_text_for_tokenizer(data_file)
        print("Creating character-level tokenizer...")
        tokenizer = CharTokenizer(text=text)
        print(f"Vocabulary size: {tokenizer.vocab_size}")
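
        # Quick sanity check (our addition): a character-level tokenizer built
        # from this text should round-trip any string drawn from its vocabulary
        # losslessly. encode/decode are the same calls used for generation below.
        sample = "Natural selection"
        roundtrip = tokenizer.decode(tokenizer.encode(sample))
        print(f"Tokenizer round-trip OK: {roundtrip == sample}")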

        # Alternative: Use a pre-trained tokenizer like GPT-2
        # This requires: pip install 'scratchgpt[hf-tokenizers]'
        #
        # from scratchgpt import HuggingFaceTokenizer
        # tokenizer = HuggingFaceTokenizer.from_hub("gpt2")
        # print(f"Vocabulary size: {tokenizer.vocab_size}")  # ~50,257 tokens
        #
        # Trade-offs:
        # - CharTokenizer: small vocab (~100 chars), learns from scratch, simple
        # - GPT-2 tokenizer: large vocab (~50K tokens), pre-trained, better text quality
        # - The GPT-2 tokenizer will likely generate more coherent text but requires more memory

        # Step 3: Create configuration
        config = create_simple_config()
        config.architecture.vocab_size = tokenizer.vocab_size
        print(f"Model configuration: {config.architecture.embedding_size}D embeddings, "
              f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads")

        # Step 4: Set up model and training
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
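
        # Note: on Apple silicon you could also try PyTorch's MPS backend, e.g.:
        #   if torch.backends.mps.is_available():
        #       device = torch.device("mps")
        # This is an optional tweak, assuming the Trainer handles MPS the same
        # way it handles CUDA; the script runs fine on CPU, just more slowly.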

        model = TransformerLanguageModel(config)
        print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        optimizer = AdamW(model.parameters(), lr=config.training.learning_rate)
        data_source = FileDataSource(data_file)
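
        # AdamW also exposes weight decay and betas if you want to experiment.
        # The values below are common choices for GPT-style training, not
        # ScratchGPT defaults:
        #   optimizer = AdamW(
        #       model.parameters(),
        #       lr=config.training.learning_rate,
        #       betas=(0.9, 0.95),
        #       weight_decay=0.01,
        #   )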

        # Step 5: Create trainer and start training
        trainer = Trainer(
            model=model,
            config=config.training,
            optimizer=optimizer,
            experiment_path=experiment_dir,
            device=device,
        )

        print("\nStarting training...")
        trainer.train(data=data_source, tokenizer=tokenizer)

        # Step 6: Simple text generation demo
        print("\nTesting text generation:")
        model.eval()

        test_prompts = [
            "Natural selection",
            "The origin of species",
            "Darwin observed",
        ]

        for prompt in test_prompts:
            print(f"\nPrompt: '{prompt}'")
            context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

            with torch.no_grad():
                generated = model.generate(context, max_new_tokens=100)
                result = tokenizer.decode(generated[0].tolist())
                print(f"Generated: {result}")
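
        # To keep the trained weights beyond the temporary directory, save a
        # checkpoint to a persistent path before the with-block exits and the
        # directory is cleaned up. This is standard PyTorch; the target path
        # here is just an example:
        #   torch.save(model.state_dict(), Path("darwin_model.pt"))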

    print("\nTraining complete! All temporary files were automatically cleaned up.")
    print("Run the script again to start fresh.")


if __name__ == "__main__":
    main()