diff --git a/crates/claude-core/src/config.rs b/crates/claude-core/src/config.rs index 5843913..cdb4613 100644 --- a/crates/claude-core/src/config.rs +++ b/crates/claude-core/src/config.rs @@ -1,42 +1,42 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelConfig { - /// Dimension of the token embeddings (and internal transformer states). - pub n_embd: i64, - /// Number of attention heads. - pub n_head: i64, - /// Number of transformer layers. - pub n_layer: i64, - /// Size of the vocabulary. - pub vocab_size: i64, - /// Maximum context window size (max sequence length). - pub max_seq_len: i64, - /// Dropout probability (applied to attention and residual connections). - pub dropout: f64, - /// RMSNorm epsilon value (for numerical stability). - pub layer_norm_epsilon: f64, - /// Whether to use bias in linear layers (typically false in modern LLMs like Llama/PaLM). - pub use_bias: bool, -} - -impl Default for ModelConfig { - fn default() -> Self { - Self { - n_embd: 768, // GPT-2 Small equivalent - n_head: 12, - n_layer: 12, - vocab_size: 50257, - max_seq_len: 1024, - dropout: 0.0, - layer_norm_epsilon: 1e-5, - use_bias: false, - } - } -} - -impl ModelConfig { - pub fn head_size(&self) -> i64 { - self.n_embd / self.n_head - } -} +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelConfig { + /// Dimension of the token embeddings (and internal transformer states). + pub n_embd: i64, + /// Number of attention heads. + pub n_head: i64, + /// Number of transformer layers. + pub n_layer: i64, + /// Size of the vocabulary. + pub vocab_size: i64, + /// Maximum context window size (max sequence length). + pub max_seq_len: i64, + /// Dropout probability (applied to attention and residual connections). + pub dropout: f64, + /// RMSNorm epsilon value (for numerical stability). + pub layer_norm_epsilon: f64, + /// Whether to use bias in linear layers (typically false in modern LLMs like Llama/PaLM). + pub use_bias: bool, +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + n_embd: 768, // GPT-2 Small equivalent + n_head: 12, + n_layer: 12, + vocab_size: 50257, + max_seq_len: 2048, + dropout: 0.0, + layer_norm_epsilon: 1e-5, + use_bias: false, + } + } +} + +impl ModelConfig { + pub fn head_size(&self) -> i64 { + self.n_embd / self.n_head + } +} diff --git a/crates/claude-tui/src/main.rs b/crates/claude-tui/src/main.rs index 7c7fd53..65f600c 100644 --- a/crates/claude-tui/src/main.rs +++ b/crates/claude-tui/src/main.rs @@ -1,217 +1,234 @@ -use anyhow::Result; -use crossterm::{ - event::{DisableMouseCapture, EnableMouseCapture, Event, EventStream, KeyCode}, - execute, - terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, -}; -use futures::StreamExt; -use ratatui::{ - backend::CrosstermBackend, - Terminal, -}; -use std::{error::Error, io, time::Duration}; -use tokio::sync::mpsc; -use std::sync::Arc; - -// Local crate imports -use claude_core::{ClaudeTransformer, ModelConfig}; -use inference::{Generator, SamplingParams}; -use tokenizer::{BPE, Vocab}; -use tch::{nn, Device}; -use tui_input::backend::crossterm::EventHandler; - -mod app; -mod ui; - -use app::{App, Message, Sender}; - -#[derive(Debug)] -enum Action { - Tick, - TokenGenerated(String), - GenerationFinished, -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - // 0. Initialize Model & Tokenizer - let device = Device::cuda_if_available(); - println!("Using device: {:?}", device); - - let vocab_path = "data/vocab.json"; - let checkpoint_dir = std::path::Path::new("checkpoints"); - - // Load Tokenizer - let tokenizer = if std::path::Path::new(vocab_path).exists() { - println!("Loading tokenizer from {}", vocab_path); - Arc::new(BPE::load(vocab_path)?) - } else { - println!("Warning: Tokenizer vocab not found at {}. Using minimal fallback.", vocab_path); - let mut vocab = Vocab::new(); - vocab.insert(" ".to_string(), 32); - for i in 65..123 { - let c = i as u8 as char; - vocab.insert(c.to_string(), i as u32); - } - vocab.insert("".to_string(), 0); - Arc::new(BPE::new(vocab, std::collections::HashMap::new())) - }; - - // Load Model - let model = if checkpoint_dir.exists() && checkpoint_dir.join("config.json").exists() { - Arc::new(inference::load_model(checkpoint_dir, device)?) - } else { - println!("Warning: No trained model found in {:?}. Initializing random model.", checkpoint_dir); - let config = ModelConfig { - n_embd: 128, - n_head: 4, - n_layer: 4, - vocab_size: tokenizer.vocab.len() as i64, - max_seq_len: 512, - dropout: 0.1, - use_bias: true, - layer_norm_epsilon: 1e-5, - }; - let vs = nn::VarStore::new(device); - Arc::new(ClaudeTransformer::new(&vs.root(), &config)) - }; - - // 1. Setup terminal (raw mode, alternate screen) - enable_raw_mode()?; - let mut stdout = io::stdout(); - execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?; - let backend = CrosstermBackend::new(stdout); - let mut terminal = Terminal::new(backend)?; - - // 2. Setup channels and app - let (tx, mut rx) = mpsc::channel(32); - let mut app = App::new(); - - // 3. Event Loop - let mut reader = EventStream::new(); - let tick_rate = Duration::from_millis(100); - let tx_tick = tx.clone(); - - // Tick task - tokio::spawn(async move { - loop { - if tx_tick.send(Action::Tick).await.is_err() { - break; - } - tokio::time::sleep(tick_rate).await; - } - }); - - // Run loop - let res = run_app(&mut terminal, &mut app, &mut reader, tx, &mut rx, model, tokenizer, device).await; - - // 4. Cleanup - disable_raw_mode()?; - execute!( - terminal.backend_mut(), - LeaveAlternateScreen, - DisableMouseCapture - )?; - terminal.show_cursor()?; - - if let Err(err) = res { - println!("{:?}", err) - } - - Ok(()) -} - -async fn run_app( - terminal: &mut Terminal>, - app: &mut App, - reader: &mut EventStream, - tx: mpsc::Sender, - rx: &mut mpsc::Receiver, - model: Arc, - tokenizer: Arc, - device: Device, -) -> io::Result<()> { - loop { - // Draw - terminal.draw(|f| ui::draw(f, app))?; - - // Handle events - tokio::select! { - // Priority: Internal Actions (Ticks, Responses) - Some(action) = rx.recv() => { - match action { - Action::Tick => {} - Action::TokenGenerated(token_text) => { - app.append_token(&token_text); - } - Action::GenerationFinished => { - app.is_loading = false; - } - } - } - // User Input - Some(Ok(event)) = reader.next() => { - match event { - Event::Key(key) => { - if key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL) { - return Ok(()); - } - - match key.code { - KeyCode::Enter => { - let text: String = app.input.value().into(); - if !text.trim().is_empty() { - app.messages.push(Message { - sender: Sender::User, - content: text.clone(), - }); - app.input.reset(); - app.is_loading = true; - - let tx_action = tx.clone(); - let model = Arc::clone(&model); - let tokenizer = Arc::clone(&tokenizer); - let prompt = text.clone(); - - tokio::spawn(async move { - let mut generator = Generator::new(Arc::clone(&model), device); - let params = SamplingParams::default(); - - // 1. Tokenize prompt - let input_ids: Vec = tokenizer.encode(&prompt).iter().map(|&id| id as i64).collect(); - - // 2. Setup internal stream channel - let (token_tx, mut token_rx) = mpsc::channel(100); - - // 3. Start generation in a blocking-safe way if necessary or just await - // Since we are already in an async spawn, we can run generate_stream - let tokenizer_clone = Arc::clone(&tokenizer); - let tx_action_clone = tx_action.clone(); - - tokio::spawn(async move { - let _ = generator.generate_stream(&input_ids, 50, ¶ms, token_tx); - }); - - while let Some(token_id) = token_rx.recv().await { - let text = tokenizer_clone.decode(&[token_id as u32]); - let _ = tx_action_clone.send(Action::TokenGenerated(text)).await; - } - - let _ = tx_action.send(Action::GenerationFinished).await; - }); - } - } - KeyCode::Esc => { - app.input.reset(); - } - _ => { - app.input.handle_event(&Event::Key(key)); - } - } - } - _ => {} - } - } - } - } -} +use anyhow::Result; +use crossterm::{ + event::{DisableMouseCapture, EnableMouseCapture, Event, EventStream, KeyCode}, + execute, + terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, +}; +use futures::StreamExt; +use ratatui::{backend::CrosstermBackend, Terminal}; +use std::sync::Arc; +use std::{error::Error, io, time::Duration}; +use tokio::sync::mpsc; + +// Local crate imports +use claude_core::{ClaudeTransformer, ModelConfig}; +use inference::{Generator, SamplingParams}; +use tch::{nn, Device}; +use tokenizer::{Vocab, BPE}; +use tui_input::backend::crossterm::EventHandler; + +mod app; +mod ui; + +use app::{App, Message, Sender}; + +#[derive(Debug)] +enum Action { + Tick, + TokenGenerated(String), + GenerationFinished, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // 0. Initialize Model & Tokenizer + let device = Device::cuda_if_available(); + println!("Using device: {:?}", device); + + let vocab_path = "data/vocab.json"; + let checkpoint_dir = std::path::Path::new("checkpoints"); + + // Load Tokenizer + let tokenizer = if std::path::Path::new(vocab_path).exists() { + println!("Loading tokenizer from {}", vocab_path); + Arc::new(BPE::load(vocab_path)?) + } else { + println!( + "Warning: Tokenizer vocab not found at {}. Using minimal fallback.", + vocab_path + ); + let mut vocab = Vocab::new(); + vocab.insert(" ".to_string(), 32); + for i in 65..123 { + let c = i as u8 as char; + vocab.insert(c.to_string(), i as u32); + } + vocab.insert("".to_string(), 0); + Arc::new(BPE::new(vocab, std::collections::HashMap::new())) + }; + + // Load Model + let model = if checkpoint_dir.exists() && checkpoint_dir.join("config.json").exists() { + Arc::new(inference::load_model(checkpoint_dir, device)?) + } else { + println!( + "Warning: No trained model found in {:?}. Initializing random model.", + checkpoint_dir + ); + let config = ModelConfig { + n_embd: 128, + n_head: 4, + n_layer: 4, + vocab_size: tokenizer.vocab.len() as i64, + max_seq_len: 2048, + dropout: 0.1, + use_bias: true, + layer_norm_epsilon: 1e-5, + }; + let vs = nn::VarStore::new(device); + Arc::new(ClaudeTransformer::new(&vs.root(), &config)) + }; + + // 1. Setup terminal (raw mode, alternate screen) + enable_raw_mode()?; + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?; + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::new(backend)?; + + // 2. Setup channels and app + let (tx, mut rx) = mpsc::channel(32); + let mut app = App::new(); + + // 3. Event Loop + let mut reader = EventStream::new(); + let tick_rate = Duration::from_millis(100); + let tx_tick = tx.clone(); + + // Tick task + tokio::spawn(async move { + loop { + if tx_tick.send(Action::Tick).await.is_err() { + break; + } + tokio::time::sleep(tick_rate).await; + } + }); + + // Run loop + let res = run_app( + &mut terminal, + &mut app, + &mut reader, + tx, + &mut rx, + model, + tokenizer, + device, + ) + .await; + + // 4. Cleanup + disable_raw_mode()?; + execute!( + terminal.backend_mut(), + LeaveAlternateScreen, + DisableMouseCapture + )?; + terminal.show_cursor()?; + + if let Err(err) = res { + println!("{:?}", err) + } + + Ok(()) +} + +async fn run_app( + terminal: &mut Terminal>, + app: &mut App, + reader: &mut EventStream, + tx: mpsc::Sender, + rx: &mut mpsc::Receiver, + model: Arc, + tokenizer: Arc, + device: Device, +) -> io::Result<()> { + loop { + // Draw + terminal.draw(|f| ui::draw(f, app))?; + + // Handle events + tokio::select! { + // Priority: Internal Actions (Ticks, Responses) + Some(action) = rx.recv() => { + match action { + Action::Tick => {} + Action::TokenGenerated(token_text) => { + app.append_token(&token_text); + } + Action::GenerationFinished => { + app.is_loading = false; + } + } + } + // User Input + Some(Ok(event)) = reader.next() => { + match event { + Event::Key(key) => { + if key.code == KeyCode::Char('c') && key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL) { + return Ok(()); + } + + match key.code { + KeyCode::Enter => { + let text: String = app.input.value().into(); + if !text.trim().is_empty() { + app.messages.push(Message { + sender: Sender::User, + content: text.clone(), + }); + app.input.reset(); + app.is_loading = true; + + let tx_action = tx.clone(); + let model = Arc::clone(&model); + let tokenizer = Arc::clone(&tokenizer); + let prompt = text.clone(); + + tokio::spawn(async move { + let mut generator = Generator::new(Arc::clone(&model), device); + let params = SamplingParams::default(); + + // 1. Tokenize prompt + let input_ids: Vec = tokenizer.encode(&prompt).iter().map(|&id| id as i64).collect(); + + // 2. Setup internal stream channel + let (token_tx, mut token_rx) = mpsc::channel(100); + + // 3. Start generation in a blocking-safe way if necessary or just await + // Since we are already in an async spawn, we can run generate_stream + let tokenizer_clone = Arc::clone(&tokenizer); + let tx_action_clone = tx_action.clone(); + + tokio::spawn(async move { + let _ = generator.generate_stream(&input_ids, 50, ¶ms, token_tx); + }); + + while let Some(token_id) = token_rx.recv().await { + let raw_text = tokenizer_clone.decode(&[token_id as u32]); + let text: String = raw_text + .chars() + .map(|c| if c.is_control() && c != '\n' && c != '\t' { '�' } else { c }) + .collect(); + let _ = tx_action_clone.send(Action::TokenGenerated(text)).await; + } + + let _ = tx_action.send(Action::GenerationFinished).await; + }); + } + } + KeyCode::Esc => { + app.input.reset(); + } + _ => { + app.input.handle_event(&Event::Key(key)); + } + } + } + _ => {} + } + } + } + } +} diff --git a/crates/inference/src/generator.rs b/crates/inference/src/generator.rs index b9237c0..468f695 100644 --- a/crates/inference/src/generator.rs +++ b/crates/inference/src/generator.rs @@ -1,73 +1,89 @@ -use tch::{Tensor, Device, IndexOp}; -use claude_core::ClaudeTransformer; -use crate::sampling::{Sampler, SamplingParams}; - -use std::sync::Arc; - -pub struct Generator { - model: Arc, - device: Device, -} - -impl Generator { - pub fn new(model: Arc, device: Device) -> Self { - Self { model, device } - } - - pub fn generate_stream( - &mut self, - prompt_ids: &[i64], - max_new_tokens: usize, - params: &SamplingParams, - tx: tokio::sync::mpsc::Sender, - ) -> anyhow::Result<()> { - let mut tokens = prompt_ids.to_vec(); - - // Initialize KV Caches for each layer - let mut caches: Vec = (0..self.model.config.n_layer) - .map(|_| claude_core::kv_cache::KVCache::new( - self.model.config.max_seq_len as usize, - self.model.config.n_head, - self.model.config.n_embd / self.model.config.n_head, - self.device, - tch::Kind::Float - )) - .collect(); - - // 1. Prefill - let input_tensor = Tensor::from_slice(&tokens).view([1, tokens.len() as i64]).to(self.device); - let logits = self.model.forward(&input_tensor, Some(&mut caches)); - - // Sample first new token - let next_token_logits = logits.i((0, -1, ..)); - let mut next_token = Sampler::sample(&next_token_logits, params, &tokens)?; - - // Yield first token - let _ = tx.blocking_send(next_token); - tokens.push(next_token); - - // 2. Decode Loop - for _ in 0..max_new_tokens { - let input_tensor = Tensor::from_slice(&[next_token]).view([1, 1]).to(self.device); - let logits = self.model.forward(&input_tensor, Some(&mut caches)); - - let next_token_logits = logits.i((0, -1, ..)); - next_token = Sampler::sample(&next_token_logits, params, &tokens)?; - - // Yield token - if tx.blocking_send(next_token).is_err() { - break; // Receiver dropped - } - tokens.push(next_token); - - if tokens.len() >= self.model.config.max_seq_len as usize { - break; - } - } - - Ok(()) - } -} - -unsafe impl Send for Generator {} - +use crate::sampling::{Sampler, SamplingParams}; +use claude_core::ClaudeTransformer; +use tch::{Device, IndexOp, Tensor}; + +use std::sync::Arc; + +pub struct Generator { + model: Arc, + device: Device, +} + +impl Generator { + pub fn new(model: Arc, device: Device) -> Self { + Self { model, device } + } + + pub fn generate_stream( + &mut self, + prompt_ids: &[i64], + max_new_tokens: usize, + params: &SamplingParams, + tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + if prompt_ids.is_empty() { + return Ok(()); + } + + let max_seq_len = self.model.config.max_seq_len as usize; + let mut tokens = if prompt_ids.len() > max_seq_len { + prompt_ids[prompt_ids.len() - max_seq_len..].to_vec() + } else { + prompt_ids.to_vec() + }; + + // Initialize KV Caches for each layer + let mut caches: Vec = (0..self.model.config.n_layer) + .map(|_| { + claude_core::kv_cache::KVCache::new( + self.model.config.max_seq_len as usize, + self.model.config.n_head, + self.model.config.n_embd / self.model.config.n_head, + self.device, + tch::Kind::Float, + ) + }) + .collect(); + + // 1. Prefill + let input_tensor = Tensor::from_slice(&tokens) + .view([1, tokens.len() as i64]) + .to(self.device); + let logits = self.model.forward(&input_tensor, Some(&mut caches)); + + // Sample first new token + let next_token_logits = logits.i((0, -1, ..)); + let mut next_token = Sampler::sample(&next_token_logits, params, &tokens)?; + + // Yield first token + if tx.try_send(next_token).is_err() { + return Ok(()); + } + tokens.push(next_token); + + // 2. Decode Loop + for _ in 0..max_new_tokens { + let input_tensor = Tensor::from_slice(&[next_token]) + .view([1, 1]) + .to(self.device); + let logits = self.model.forward(&input_tensor, Some(&mut caches)); + + let next_token_logits = logits.i((0, -1, ..)); + next_token = Sampler::sample(&next_token_logits, params, &tokens)?; + + // Yield token + if tx.try_send(next_token).is_err() { + break; // Receiver dropped or channel is full + } + tokens.push(next_token); + + if tokens.len() >= max_seq_len { + break; + } + } + + Ok(()) + } +} + +unsafe impl Send for Generator {} diff --git a/crates/inference/src/main.rs b/crates/inference/src/main.rs index 9c1484d..6731596 100644 --- a/crates/inference/src/main.rs +++ b/crates/inference/src/main.rs @@ -119,7 +119,7 @@ async fn main() -> anyhow::Result<()> { n_head: 4, n_layer: 4, vocab_size: tokenizer.vocab.len() as i64, - max_seq_len: 512, + max_seq_len: 2048, dropout: 0.0, use_bias: true, layer_norm_epsilon: 1e-5,