Create a Rust implementation of Stable Diffusion that demonstrates text-to-image generation using only ndarray for tensor operations. Focus on forward inference only.
- Text Prompt → CLIP Encoder → Text Embedding
- Noise + Text Embedding → Diffusion Process (reverse) → Latent Image
- Latent Image → VAE Decoder → Final Image
- ndarray - Tensor operations
- ndarray-linalg - Linear algebra (matrix operations)
- rand / rand_distr - Random sampling for diffusion
- serde + serde_json - Weight deserialization
- image - Output image generation (PNG/JPG)
- tokenizers - Text tokenization for CLIP
- ndarray-stats - Statistical operations
- half - BF16 (bfloat16) support for reduced-precision compute
- memmap2 - Memory-mapped file I/O
- safetensors - SafeTensors format support
- reqwest + tokio - Async HTTP for downloads
- indicatif - Progress bars for downloads
- src/main.rs - Entry point with CLI commands
- src/types.rs - Type definitions and constants
- src/weights.rs - Weight-loading infrastructure
- src/clip.rs - Text encoder module (stub)
- src/diffusion.rs - Diffusion module (stub)
- src/vae.rs - VAE decoder module (stub)
- src/utils.rs - Utility functions (stub)
src/
├── main.rs - Entry point, CLI demo
├── weights.rs - Weight loading and parsing
├── clip.rs - Text encoder (CLIP model)
├── diffusion.rs - Diffusion process and inference
├── vae.rs - VAE decoder
├── utils.rs - Common utilities (tensor operations, normalization)
└── types.rs - Type definitions and constants
- Infrastructure for downloading from the Hugging Face Hub
- CLI command: cargo run -- download
- Error handling with helpful messages
- Progress bars for downloads
- Working download (requires HF_TOKEN authentication)
- Documentation in WEIGHTS.md
- Weights stored locally in the ./weights/ directory
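- Precision: use BF16 (bfloat16) as the default
- BF16 type alias in types.rs
- Documentation of benefits: 50% memory savings versus FP32, and the same 8-bit exponent range as FP32, which makes overflow and underflow much less likely than with FP16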
- Memory-Mapped File Loading (Complete)
- memmap2 dependency added
- Memory mapping implementation in load_from_safetensors()
- Lazy loading support
- Documentation of ~80% memory savings
- docs/memory_layout.md with a detailed explanation
- ArrayView for Zero-Copy Access
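- WeightMatrix<'a> = ArrayView<'a, bf16, IxDyn> type alias (zero-copy applies only when the file itself stores BF16; F32 checkpoints need an f32 view or a conversion pass)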
- Documentation of memory layout
- Example code in examples/mmap_arrayview.rs
- Safety guarantees documented
- Structure: WeightStore struct created
- Load actual weights and validate shapes
- CLIP encoder: 197 tensors, 4 key tensors validated
- UNet denoiser: 686 tensors, architecture verified
- VAE decoder: 248 tensors, output projection verified
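The BF16 default and the zero-copy loading described above both come down to byte-level operations. The sketch below shows them with only the standard library, as an illustration rather than the project's code: `f32_to_bf16_bits` keeps the upper 16 bits of an `f32` with round-to-nearest-even (the `half` crate's `bf16::from_f32` covers the same need), and `f32_view` reinterprets an aligned little-endian byte slice, such as part of an mmap, as `&[f32]` without copying. All function names here are hypothetical.

```rust
// Illustrative sketch; f32_to_bf16_bits, bf16_bits_to_f32 and f32_view are hypothetical names.

/// Round an f32 to bf16 (returned as raw bits) using round-to-nearest-even.
pub fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep sign/exponent and force a quiet-NaN bit so truncation cannot yield infinity
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, then truncate: exact ties go to the even neighbor
    let lsb = (bits >> 16) & 1;
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Widen bf16 bits to f32; exact, because bf16 is the upper half of an f32.
pub fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

/// View little-endian f32 data as &[f32] without copying.
/// Returns None if the slice is misaligned or its length is not a multiple of 4.
#[cfg(target_endian = "little")]
pub fn f32_view(bytes: &[u8]) -> Option<&[f32]> {
    // SAFETY: every 4-byte pattern is a valid f32, and align_to only yields whole aligned elements
    let (prefix, floats, suffix) = unsafe { bytes.align_to::<f32>() };
    if prefix.is_empty() && suffix.is_empty() {
        Some(floats)
    } else {
        None
    }
}
```

A BF16 view over BF16 data would follow the same pattern with `u16` elements.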
Model Structure:
- Total Parameters: 123.06 Million (469.44 MiB in F32)
- Precision: F32 (float32 for weights, matching safetensors storage)
- Vocabulary Size: 49,408 tokens
- Embedding Dimension: 768
- Sequence Length: 77 (max tokens)
- Transformer Layers: 12 (indexed 0-11)
- Attention Heads: 12 heads × 64 dims/head = 768
- MLP Expansion: 4× (768 → 3072 → 768)
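The parameter total above can be reproduced from the listed dimensions. This arithmetic sketch (hypothetical helper names) counts the two embedding tables, 12 pre-norm layers whose Q/K/V/output and MLP projections all carry biases, and the final layer norm.

```rust
// Illustrative sketch; clip_text_params and clip_text_tensors are hypothetical names.

/// Trainable parameters of a CLIP-style text encoder with biased linear layers.
pub fn clip_text_params(vocab: usize, seq: usize, d: usize, layers: usize, mlp: usize) -> usize {
    let embeddings = vocab * d + seq * d;               // token table + position table
    let layer_norm = 2 * d;                             // gamma + beta
    let attention = 4 * (d * d + d);                    // Q, K, V, output projections (weight + bias)
    let feed_forward = (mlp * d + mlp) + (d * mlp + d); // fc1 (d -> mlp) + fc2 (mlp -> d)
    let per_layer = 2 * layer_norm + attention + feed_forward;
    embeddings + layers * per_layer + layer_norm        // + final layer norm
}

/// Weight tensors (weights and biases counted separately): 16 per layer,
/// plus 2 embedding tables and the final norm's 2 tensors.
pub fn clip_text_tensors(layers: usize) -> usize {
    2 + layers * 16 + 2
}
```

This gives 123,060,480 parameters, and 123,060,480 × 4 bytes ≈ 469.44 MiB (about 492.2 MB in decimal units). The 196 weight tensors are one short of the 197 found in the file; the extra entry is presumably the non-trainable `position_ids` buffer that Hugging Face checkpoints store.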
Architecture Layers:
- Input: Text string or token IDs (1-77 tokens)
  - Pad/truncate to exactly 77 tokens
- Token Embedding [49408, 768]
  - Lookup layer: token_id → 768-dim vector
  - Maps vocabulary to embedding space
- Positional Embedding [77, 768]
  - Learned position embeddings for each position
  - Added element-wise to token embeddings
- Transformer Blocks × 12 (layers 0-11). Each block contains:
  - Layer Norm 1 [768]
  - Self-Attention (multi-head, with a causal mask so each token attends only to itself and earlier tokens)
    - Query proj [768, 768]
    - Key proj [768, 768]
    - Value proj [768, 768]
    - Output proj [768, 768]
  - Layer Norm 2 [768]
  - MLP Feed-Forward
    - Linear 1 (expand): [3072, 768]
    - GELU activation (the SD v1.x CLIP configuration specifies the "quick_gelu" variant, x · sigmoid(1.702x))
    - Linear 2 (project): [768, 3072]
  - Weight shapes follow PyTorch's [out_features, in_features] convention
- Final Layer Norm [768]
  - Normalizes output of last transformer block
- Output: (77, 768) text embedding
  - 77 token positions × 768 dimensions
  - Used as conditioning for diffusion model
Data Flow Example:
Input text: "a beautiful sunset over the ocean"
↓
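Tokenizer (BPE, plus start/end tokens): <|startoftext|> → 49406, "a" → 320, "beautiful" → 1283, ..., <|endoftext|> → 49407
↓
Token IDs (77 tokens, padded): [49406, 320, 1283, ..., 49407, pad, ..., pad] (the SD v1.x tokenizer pads with 49407)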
↓
Token Embedding lookup: (77,) → (77, 768)
↓
+ Position Embedding: (77, 768) + (77, 768) → (77, 768)
↓
Transformer Block × 12:
x = x + SelfAttention(LayerNorm1(x)) → x = x + MLP(LayerNorm2(x))   (pre-norm residuals)
(77, 768) → (77, 768) for each layer
↓
Final LayerNorm: (77, 768) → (77, 768)
↓
Output: Text conditioning (77, 768)
- Ready for use in diffusion model's cross-attention
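The pad/truncate step in the flow above can be sketched as follows, assuming BPE ids have already been produced (for example by the tokenizers crate). 49406 and 49407 are CLIP's start and end tokens. `pad_id` is a parameter because the padding id depends on the tokenizer configuration. `pad_tokens` and the constants are hypothetical names.

```rust
// Illustrative sketch; pad_tokens, MAX_LEN, BOS and EOS are hypothetical names.
pub const MAX_LEN: usize = 77;
pub const BOS: u32 = 49406; // <|startoftext|>
pub const EOS: u32 = 49407; // <|endoftext|>

/// Wrap BPE ids in BOS/EOS, truncating so both fit in 77 slots, then pad the rest.
pub fn pad_tokens(ids: &[u32], pad_id: u32) -> [u32; MAX_LEN] {
    let mut out = [pad_id; MAX_LEN];
    out[0] = BOS;
    let n = ids.len().min(MAX_LEN - 2); // leave room for BOS and EOS
    out[1..1 + n].copy_from_slice(&ids[..n]);
    out[1 + n] = EOS;
    out
}
```

Truncating before appending EOS keeps the end token present even for long prompts.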
Key Implementation Checkpoints:
- ✓ Tensor shapes verified from actual weights
- ✓ Attention mechanism: 12 heads, QKV projections
- ✓ MLP structure: 4× expansion ratio
- ✓ Normalization: Pre-norm architecture (norm before operations)
- Ready to implement in Phase 3.2
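The attention checkpoint above (12 heads of 64 dimensions) can be illustrated with plain slices instead of ndarray. This sketch, with the hypothetical name `causal_mha`, takes already-projected Q, K and V in row-major (seq, d) layout, splits the columns into heads, and applies the causal mask used by the CLIP text encoder, under which position i attends only to positions 0..=i. The output projection and residual add are left to the caller.

```rust
// Illustrative sketch; causal_mha is a hypothetical name, not the project's API.

/// Multi-head scaled dot-product self-attention with a causal mask.
/// q, k, v: row-major (seq, d) after the Q/K/V projections; d must be divisible by `heads`.
/// Returns the concatenated head outputs (seq, d), before the output projection.
pub fn causal_mha(q: &[f32], k: &[f32], v: &[f32], seq: usize, d: usize, heads: usize) -> Vec<f32> {
    assert_eq!(d % heads, 0, "d must be divisible by heads");
    let hd = d / heads;
    let scale = 1.0 / (hd as f32).sqrt();
    let mut out = vec![0.0f32; seq * d];
    let mut scores = vec![0.0f32; seq];
    for h in 0..heads {
        let off = h * hd; // first column of this head
        for i in 0..seq {
            // Scaled scores against positions 0..=i only (later positions are masked out)
            let mut max = f32::NEG_INFINITY;
            for j in 0..=i {
                let mut dot = 0.0f32;
                for c in 0..hd {
                    dot += q[i * d + off + c] * k[j * d + off + c];
                }
                scores[j] = dot * scale;
                max = max.max(scores[j]);
            }
            // Numerically stable softmax over the unmasked positions
            let mut sum = 0.0f32;
            for s in &mut scores[..=i] {
                *s = (*s - max).exp();
                sum += *s;
            }
            // Weighted sum of value rows
            for j in 0..=i {
                let w = scores[j] / sum;
                for c in 0..hd {
                    out[i * d + off + c] += w * v[j * d + off + c];
                }
            }
        }
    }
    out
}
```

Because of the mask, the first row of the output is always exactly the first value row, which makes a convenient sanity check.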
Full CLIP Encoder Implementation:
- ClipEncoder struct with all weight matrices:
  - Token embeddings [49408, 768]
  - Position embeddings [77, 768]
  - 12 TransformerLayer blocks
  - Final layer normalization weights/biases
- TransformerLayer struct with:
  - Pre-norm architecture (norm before operations)
  - Multi-head self-attention (12 heads, 64 dims each)
  - MLP with GELU activation (768 → 3072 → 768)
  - Residual connections
- Complete forward pass implementation:
  - Tokenization: Text → 77 tokens (simplified word-based for now)
  - Token embedding lookup: Token IDs → (77, 768)
  - Position embedding addition: Add positional information
  - Transformer block application: 12 layers of attention + MLP
  - Final layer normalization: Normalize output
  - Output: (77, 768) text embeddings ready for diffusion
- Matrix operations:
  - Custom matrix multiplication (no CBLAS dependency)
  - matmul(): standard A @ B multiplication
  - matmul_transpose(): A @ B^T, for weight matrices stored as [out_features, in_features]
  - Tensor operations without external linear algebra libraries
- Helper functions:
  - GELU activation: 0.5 * x * (1 + tanh(√(2/π)(x + 0.044715*x³)))
  - Softmax: Numerically stable (subtract max before exp)
  - Layer norm: (x - mean) / √(variance + eps) * γ + β
  - Tokenize: Mock tokenization (placeholder for tokenizers crate)
  - Tensor loading: SafeTensors deserialization
- Test command: cargo run --release -- clip-test
- Validation:
  - Loads actual weights (~470 MiB CLIP model)
  - Generates (77, 768) embeddings
  - Tested with sample prompts:
    - "a cat on a beach" → (77, 768)
    - "a beautiful sunset over the ocean" → (77, 768)
    - "dog" → (77, 768)
  - Output values fall in a plausible range: [-27.98, 32.89]
Next step: Integrate actual tokenizers crate for production-quality tokenization
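The helper functions listed in the CLIP implementation above can be sketched on plain `f32` slices as follows (hypothetical names, not the project's code). Two GELU variants are shown: the tanh approximation from the list, and `quick_gelu`, x · sigmoid(1.702x), which is the activation named in the SD v1.x CLIP text encoder configuration, so a reference-matching encoder would use the latter. The layer norm uses the biased variance, as PyTorch's LayerNorm does; CLIP's configured epsilon is 1e-5.

```rust
// Illustrative sketch; all function names are hypothetical.

/// tanh-approximated GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
pub fn gelu_tanh(x: f32) -> f32 {
    let c = (2.0f32 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// "quick GELU": x * sigmoid(1.702 x)
pub fn quick_gelu(x: f32) -> f32 {
    x / (1.0 + (-1.702 * x).exp())
}

/// In-place softmax; subtracting the max first keeps exp() from overflowing.
pub fn softmax(xs: &mut [f32]) {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for x in xs.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in xs.iter_mut() {
        *x /= sum;
    }
}

/// In-place LayerNorm over one row: (x - mean) / sqrt(var + eps) * gamma + beta.
pub fn layer_norm(x: &mut [f32], gamma: &[f32], beta: &[f32], eps: f32) {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n; // biased variance
    let inv = 1.0 / (var + eps).sqrt();
    for (i, v) in x.iter_mut().enumerate() {
        *v = (*v - mean) * inv * gamma[i] + beta[i];
    }
}
```

A softmax over `[1000.0, 1000.0]` returns `[0.5, 0.5]` here, whereas exponentiating without the max shift would overflow to infinity.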
Created comprehensive guide: UNDERSTANDING_DIFFUSION.md
Forward Process Explained:
- Progressive noise addition: x_t = √(ᾱ_t) * x_0 + √(1 - ᾱ_t) * ε
- 1000-step noise schedule from clean image to pure noise
- Mathematical foundation for understanding reverse inference
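The closed form above lets any x_t be produced directly from x_0 without simulating the intermediate steps. A minimal sketch, with the hypothetical name `q_sample` and plain slices instead of ndarray:

```rust
// Illustrative sketch; q_sample is a hypothetical name.

/// Forward (noising) process in closed form:
/// x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
pub fn q_sample(x0: &[f32], eps: &[f32], alpha_bar_t: f32) -> Vec<f32> {
    let signal = alpha_bar_t.sqrt();
    let noise = (1.0 - alpha_bar_t).sqrt();
    x0.iter().zip(eps).map(|(x, e)| signal * x + noise * e).collect()
}
```

At ᾱ_t = 1 the output is the clean input and at ᾱ_t = 0 it is pure noise, matching the two ends of the schedule.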
Noise Schedules:
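- Linear schedule: β_t = β_min + (β_max - β_min) * t / (T - 1), for t = 0, ..., T - 1 with T = 1000
- Simple, used in the DDPM paper
- β_min = 0.0001, β_max = 0.02
- Cosine schedule: ᾱ_t = (cos(π * t / 2000))², the Improved DDPM cosine schedule with its small offset s set to 0
- Smoother transitions: information is destroyed more gradually than with the linear schedule
- Not the schedule Stable Diffusion v1.x uses: its scheduler is configured as "scaled_linear", with √β_t evenly spaced from √0.00085 to √0.012 over 1000 steps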
Key Variables:
- α_t: How much original signal to keep at step t
- β_t: How much new noise to add at step t
- ᾱ_t: Cumulative product of α values (α_1 * α_2 * ... * α_t)
- σ_t²: Posterior variance for reverse sampling
Understanding forward diffusion enables understanding reverse inference:
Forward (Noising):
x_0 (clean image) → add noise → x_1 → add noise → ... → x_1000 (pure noise)
Reverse (Denoising):
x_1000 (pure noise) → predict & remove noise → x_999 → ... → x_0 (clean image)
UNet learns: "Given noisy image at step t, what noise was added?"
CLIP Embedding → Noise Schedule → Inference Pipeline:
- User provides text prompt
- CLIP encoder converts it to a (77, 768) conditioning tensor
- Noise schedule pre-computed for 1000 timesteps
- UNet uses both to iteratively denoise latent
- VAE decoder converts latent to RGB image
Mathematical Foundation for Phase 5:
- Denoising formula: x_{t-1} = (1/√α_t) * (x_t - (β_t/√(1-ᾱ_t)) * ε_pred) + σ_t * z
- Classifier-free guidance: ε_guided = ε_uncond + scale * (ε_text - ε_uncond)
- Text conditioning: UNet receives CLIP embedding at each step
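Classifier-free guidance combines two UNet predictions per step: one conditioned on the prompt embedding and one on the embedding of an empty prompt. A sketch of the combination step (hypothetical `cfg_combine`) follows. A scale of 1.0 reduces to the conditional prediction alone, and 7.5 is a common default for SD v1.x (for example in diffusers).

```rust
// Illustrative sketch; cfg_combine is a hypothetical name.

/// Classifier-free guidance: eps = eps_uncond + scale * (eps_text - eps_uncond).
pub fn cfg_combine(eps_uncond: &[f32], eps_text: &[f32], scale: f32) -> Vec<f32> {
    eps_uncond
        .iter()
        .zip(eps_text)
        .map(|(u, t)| u + scale * (t - u))
        .collect()
}
```

In practice both predictions are often computed in one batched UNet call (batch size 2) to amortize the cost.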
See UNDERSTANDING_DIFFUSION.md for:
- Complete pipeline visualization
- Multi-head attention mechanics
- MLP feed-forward explanation
- Noise schedule mathematics
- Reverse process algorithm
- Phase dependencies and data flow
NoiseSchedule Struct Implemented:
- Linear schedule: β from 0.0001 to 0.02 (DDPM paper)
- Cosine schedule: Smoother transitions (Improved DDPM; Stable Diffusion v1.x uses a scaled-linear schedule instead)
- Pre-computed for 1000 timesteps
- Arrays: betas, alphas, alphas_cumprod (ᾱ_t)
- Derived values: √(ᾱ_t), √(1 - ᾱ_t), posterior_variance (σ_t²)
Formulas Implemented:
- α_t = 1 - β_t (signal retention)
- ᾱ_t = ∏(α_i) for i=1..t (cumulative product)
- σ_t² = (1 - ᾱ_{t-1}) / (1 - ᾱ_t) * β_t (posterior variance)
Test Command:
cargo run --release -- noise-test
- Displays schedules at key timesteps (1, 10, 100, 500, 750, 999)
- Validates mathematical correctness
Example Output:
Linear Schedule Step 500: β=0.010060, α=0.989940, ᾱ=0.077797
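Cosine Schedule Step 500: β=0.026632, α=0.973368, ᾱ=0.000056
Note: the cosine formula above gives ᾱ_500 = cos²(π/4) = 0.5 and β_500 ≈ 0.0031, so this row does not match it and points to a bug in the cosine implementation.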
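A `NoiseSchedule` along the lines described above can be sketched as follows. It is an illustration under stated assumptions, not the project's struct: timesteps are 0-indexed, and the linear betas use `linspace(β_min, β_max, T)`, which reproduces the linear Step 500 row above (β ≈ 0.010060, ᾱ ≈ 0.077797). It also includes the "scaled_linear" schedule from the SD v1.x scheduler configuration and the document's simplified cosine formula. `f64` keeps the 1000-term cumulative product accurate.

```rust
// Illustrative sketch; NoiseSchedule fields, linspace and cosine_alpha_bar are hypothetical names.
pub struct NoiseSchedule {
    pub betas: Vec<f64>,
    pub alphas_cumprod: Vec<f64>,     // ᾱ_t
    pub posterior_variance: Vec<f64>, // σ_t²
}

impl NoiseSchedule {
    /// Derive ᾱ_t and σ_t² from betas indexed 0..T-1.
    pub fn from_betas(betas: Vec<f64>) -> Self {
        let mut alphas_cumprod = Vec::with_capacity(betas.len());
        let mut prod = 1.0f64;
        for &b in &betas {
            prod *= 1.0 - b; // α_t = 1 - β_t
            alphas_cumprod.push(prod);
        }
        // σ_t² = (1 - ᾱ_{t-1}) / (1 - ᾱ_t) * β_t, taking ᾱ_{-1} = 1
        let posterior_variance = (0..betas.len())
            .map(|t| {
                let prev = if t == 0 { 1.0 } else { alphas_cumprod[t - 1] };
                (1.0 - prev) / (1.0 - alphas_cumprod[t]) * betas[t]
            })
            .collect();
        Self { betas, alphas_cumprod, posterior_variance }
    }

    /// DDPM linear schedule: β = linspace(start, end, n).
    pub fn linear(start: f64, end: f64, n: usize) -> Self {
        Self::from_betas(linspace(start, end, n))
    }

    /// SD v1.x "scaled_linear": β = linspace(√start, √end, n)².
    pub fn scaled_linear(start: f64, end: f64, n: usize) -> Self {
        let roots = linspace(start.sqrt(), end.sqrt(), n);
        Self::from_betas(roots.into_iter().map(|r| r * r).collect())
    }
}

/// n evenly spaced values from a to b inclusive (n >= 2).
fn linspace(a: f64, b: f64, n: usize) -> Vec<f64> {
    (0..n).map(|i| a + (b - a) * i as f64 / (n - 1) as f64).collect()
}

/// The simplified cosine formula from the text: ᾱ_t = cos²(π t / (2T)).
pub fn cosine_alpha_bar(t: f64, total: f64) -> f64 {
    (std::f64::consts::PI * t / (2.0 * total)).cos().powi(2)
}
```

For SD v1.x the call would be `NoiseSchedule::scaled_linear(0.00085, 0.012, 1000)`.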
UNetDenoiser Implementation:
- TimestepEmbedding: Sinusoidal positional encoding (320 dims), followed by an MLP that maps it to 1280 dims
- ResidualBlock: Feature transformation with time integration
- CrossAttentionBlock: Multi-head attention for text conditioning
- UNetDenoiser struct: Main architecture coordinator
- load_from_file(): Weight loading interface
- predict_noise(): Full forward pass skeleton with shape validation
Components Implemented:
- Timestep Embedding (sinusoidal encoding)
  - Formula: sin(t / 10000^(2i/d)) and cos(t / 10000^(2i/d))
  - Output: a 320-dim sinusoidal vector, projected by a two-layer MLP (SiLU in between) to a (1280,) vector capturing time at multiple scales
  - Pre-computed for all 1000 timesteps
- Residual Blocks (feature transformation)
  - Structure: GroupNorm → SiLU → Conv → add projected time embedding → GroupNorm → SiLU → Conv → add residual (1×1 conv shortcut when channel counts differ)
  - Pre-norm architecture for gradient flow
  - Time embedding integration via broadcast over spatial positions
- Cross-Attention Blocks (text conditioning)
  - Query from latent features (4096, 320), i.e. 64×64 positions at the first resolution level
  - Key/Value from text embedding (77, 768)
  - Multi-head attention (8 heads) for semantic alignment
  - Returns conditioned features (4096, 320)
- UNet Architecture
  - 4 downsampling blocks, each with 2 residual blocks
  - Cross-attention in the first 3 down blocks and the last 3 up blocks
  - Bottleneck (mid block) with attention
  - Upsampling with skip connections
  - Structure matches Stable Diffusion v1.5
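The sinusoidal half of the timestep embedding can be sketched as below (hypothetical `timestep_embedding`). It assumes the SD v1.x UNet configuration as used by diffusers: 320 dimensions, frequencies 10000^(-i/160), and cosine terms placed before sine terms (`flip_sin_to_cos = true`). The result would then pass through the two linear layers (320 → 1280 → 1280, SiLU in between) to produce the (1280,) vector.

```rust
// Illustrative sketch; timestep_embedding is a hypothetical name.

/// Sinusoidal embedding of timestep `t` with `dim` entries (320 for SD v1.x).
/// Layout: [cos(t*f_0), ..., cos(t*f_{h-1}), sin(t*f_0), ..., sin(t*f_{h-1})],
/// where h = dim / 2 and f_i = 10000^(-i / h).
pub fn timestep_embedding(t: f32, dim: usize) -> Vec<f32> {
    let half = dim / 2;
    let mut emb = vec![0.0f32; dim];
    for i in 0..half {
        let freq = (-(10000f32.ln()) * i as f32 / half as f32).exp();
        let arg = t * freq;
        emb[i] = arg.cos(); // cosine first (flip_sin_to_cos = true)
        emb[half + i] = arg.sin();
    }
    emb
}
```

Since the embedding depends only on t, precomputing it for all 1000 timesteps costs 1000 × 320 floats.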
Tensor Organization:
- 686 total tensors expected
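- Top level: 10 tensors (weight + bias for conv_in, the two timestep-embedding linear layers, conv_norm_out, and conv_out)
- Down, mid, and up blocks (residual, attention, and resampling layers): the remaining 676 tensors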
File Size Validation:
- Expected: ~3.4 GB
- Checks file exists and validates size before loading
Next Steps for Full UNet:
- Parse 686 tensors from safetensors file
- Implement 2D convolution operations
- Implement group normalization
- Connect weight tensors to forward pass
- Test with actual CLIP embeddings
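The two missing building blocks named above, 2D convolution and group normalization, can be sketched naively on a single (C, H, W) image stored row-major (hypothetical names; stride 1 only, no batching). Weights follow the PyTorch checkpoint layout (C_out, C_in, K, K). These loops favor clarity over speed.

```rust
// Illustrative sketch; conv2d and group_norm are hypothetical names.

/// Naive 2D convolution: stride 1, zero padding `pad`, one image in (C, H, W) row-major layout.
/// `weight` is (c_out, c_in, k, k) as stored in PyTorch checkpoints; `bias` is (c_out,).
pub fn conv2d(x: &[f32], c_in: usize, h: usize, w: usize, weight: &[f32], bias: &[f32],
              c_out: usize, k: usize, pad: usize) -> Vec<f32> {
    let (oh, ow) = (h + 2 * pad + 1 - k, w + 2 * pad + 1 - k);
    let mut out = vec![0.0f32; c_out * oh * ow];
    for co in 0..c_out {
        for oy in 0..oh {
            for ox in 0..ow {
                let mut acc = bias[co];
                for ci in 0..c_in {
                    for ky in 0..k {
                        for kx in 0..k {
                            // Input coordinates; taps that land in the zero padding are skipped
                            let iy = (oy + ky) as isize - pad as isize;
                            let ix = (ox + kx) as isize - pad as isize;
                            if iy < 0 || ix < 0 || iy >= h as isize || ix >= w as isize {
                                continue;
                            }
                            let xv = x[(ci * h + iy as usize) * w + ix as usize];
                            acc += xv * weight[((co * c_in + ci) * k + ky) * k + kx];
                        }
                    }
                }
                out[(co * oh + oy) * ow + ox] = acc;
            }
        }
    }
    out
}

/// GroupNorm on (C, H, W): mean/variance per group of C / groups channels,
/// then a per-channel affine transform (SD v1.x uses 32 groups).
pub fn group_norm(x: &mut [f32], c: usize, hw: usize, groups: usize,
                  gamma: &[f32], beta: &[f32], eps: f32) {
    let cpg = c / groups; // channels per group
    for g in 0..groups {
        let span = g * cpg * hw..(g + 1) * cpg * hw;
        let n = span.len() as f32;
        let mean = x[span.clone()].iter().sum::<f32>() / n;
        let var = x[span.clone()].iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        let inv = 1.0 / (var + eps).sqrt();
        for ch in g * cpg..(g + 1) * cpg {
            for v in &mut x[ch * hw..(ch + 1) * hw] {
                *v = (*v - mean) * inv * gamma[ch] + beta[ch];
            }
        }
    }
}
```

A 3×3 kernel with only its center set to 1 and padding 1 returns the input unchanged, which is a quick check that the indexing is right.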
DiffusionPipeline Implementation:
- new(): Linear schedule variant
- with_cosine_schedule(): Cosine schedule variant
- sample(): Main sampling algorithm
  - Takes initial noise, text embedding, num_steps
  - Iterates from t = 1000 down to 1; with num_steps < 1000 only a strided subset of timesteps is visited, which needs a DDIM-style update rather than the single-step DDPM formula
  - Calls UNet for noise prediction
  - Applies the denoise_step() formula
- denoise_step(): Denoising formula
  - x_{t-1} = (1/√α_t) * (x_t - (β_t/√(1-ᾱ_t)) * ε_pred) + σ_t * z
  - Adds stochastic noise for variety (optional)
  - Gaussian noise generation for posterior variance
Algorithm (DDPM ancestral sampling; DDIM replaces steps b and c with a deterministic update that can skip timesteps):
1. Start with x_1000 ~ N(0, I)
2. For t = 1000 down to 1:
   a. Predict noise: ε_pred = UNet(x_t, t, text_embedding)
   b. Compute the mean: μ_t = (1/√α_t) * (x_t - (β_t/√(1-ᾱ_t)) * ε_pred)
   c. If t > 1: x_{t-1} = μ_t + σ_t * z, with z ~ N(0, I) and σ_t² = posterior_variance[t]; otherwise x_0 = μ_t
3. Output: x_0 (clean latent)
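Steps 2b and 2c can be sketched as a single update function (hypothetical `ddpm_step`) that takes its scalars from precomputed schedule arrays. It implements the ancestral DDPM update for consecutive timesteps only; sampling with fewer steps than the training schedule would need a strided scheduler such as DDIM instead. The caller supplies `z` as standard normal noise, or zeros at the final step.

```rust
// Illustrative sketch; ddpm_step is a hypothetical name.

/// One DDPM ancestral step:
/// x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - abar_t) * eps_pred) + sigma_t * z
pub fn ddpm_step(x_t: &[f32], eps_pred: &[f32], z: &[f32],
                 beta_t: f32, alpha_bar_t: f32, posterior_var_t: f32) -> Vec<f32> {
    let alpha_t = 1.0 - beta_t;
    let coef = beta_t / (1.0 - alpha_bar_t).sqrt();
    let inv_sqrt_alpha = 1.0 / alpha_t.sqrt();
    let sigma = posterior_var_t.sqrt();
    x_t.iter()
        .zip(eps_pred)
        .zip(z)
        .map(|((x, e), n)| inv_sqrt_alpha * (x - coef * e) + sigma * n)
        .collect()
}
```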
To complete Phase 5:
- Load UNet weights from safetensors file (686 tensors)
- Implement UNet forward pass with all components
- Integrate timestep embedding (sinusoidal + MLP)
- Add cross-attention for text conditioning
- Test with CLIP embeddings from Phase 3
- Optional: Implement classifier-free guidance
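- Input: Latent representation (4, 64, 64), divided by the SD v1.x VAE scaling factor 0.18215 before decoding
- Output: RGB image (3, 512, 512) - scales up 8x
- Components:
- Upsampling blocks (nearest neighbor or transpose conv; the SD VAE uses nearest-neighbor upsampling followed by a 3×3 conv)
- Residual blocks with convolutions
- Final projection to 3 channels
- Output activation: none in the SD VAE; the decoder output lies roughly in [-1, 1], so map it with (x + 1) / 2 and clamp to [0, 1]
- Implement basic convolution operations using ndarray
- Or use simplified upsampling: nearest-neighbor, bilinear, or simple conv
- Normalization layers: GroupNorm (32 groups), i.e. layer-norm-style statistics computed over groups of channels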
- Load VAE decoder weights and apply sequentially
- Ensure output is normalized to [0, 1]
- Convert to u8 for image format
- Handle color space (likely already RGB)
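Two steps at the edges of the VAE decoder can be sketched as follows (hypothetical helper names): SD v1.x latents are divided by the VAE scaling factor 0.18215 before decoding, and the decoder's output, roughly in [-1, 1] and in planar (C, H, W) layout, is mapped to [0, 1], clamped, and interleaved into RGB bytes. Interleaved bytes are the layout `image::RgbImage::from_raw` expects.

```rust
// Illustrative sketch; VAE_SCALING_FACTOR, unscale_latents and to_rgb8 are hypothetical names.

/// SD v1.x latents were scaled by 0.18215 during training; undo that before decoding.
pub const VAE_SCALING_FACTOR: f32 = 0.18215;

pub fn unscale_latents(latents: &mut [f32]) {
    for v in latents.iter_mut() {
        *v /= VAE_SCALING_FACTOR;
    }
}

/// Map planar (3, H, W) decoder output from [-1, 1] to interleaved RGB u8 (H, W, 3).
pub fn to_rgb8(img: &[f32], h: usize, w: usize) -> Vec<u8> {
    let mut out = vec![0u8; h * w * 3];
    for c in 0..3 {
        for i in 0..h * w {
            let v = ((img[c * h * w + i] + 1.0) / 2.0).clamp(0.0, 1.0);
            out[i * 3 + c] = (v * 255.0).round() as u8;
        }
    }
    out
}
```

Clamping before the `as u8` cast matters: out-of-range decoder values would otherwise saturate silently rather than map to the intended 0 or 255 in a controlled way.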
- Help message
- cargo run -- download command structure
- cargo run -- test command structure
- cargo run -- generate command (not implemented)
cargo run --release -- --prompt "a cat on a beach" --steps 50 --output out.png
- Parse CLI arguments: prompt, num_diffusion_steps, seed, output_path
- Load weights (CLIP, UNet, VAE) from disk
- Tokenize and encode text with CLIP → embedding (1, 77, 768)
- Initialize latent noise: random (1, 4, 64, 64)
- Run diffusion inference loop with CLIP embedding for conditioning
- Apply VAE decoder to get RGB image
- Save image as PNG/JPG
- Print timing information
- Primary: Use BF16 precision design
- CLI flag to switch between BF16 and FP32: --precision bf16|fp32
- Seed parameter for reproducibility
- Diffusion step parameter (10-50 steps)
- Model size selection (full vs distilled)
- Weight Loading: Verify weights load correctly with expected shapes
- CLIP Encoding: Test text embedding output and verify reasonableness
- Diffusion Sampling: Verify noise schedule computation
- Image Generation: Generate test image and verify output shape
- End-to-End: Full pipeline produces valid PNG image
- Print tensor shapes at each stage
- Visualize intermediate latents (save as images)
- Compare outputs with known implementations
- Test with simple prompts first
- Matrix Operations: ndarray doesn't have all deep-learning optimizations (no GPU)
- Mitigation: Focus on correctness first, optimize later
- Convolutions: May need to implement manually or use simplified version
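- Simpler approach: Use 1x1 convs or fully connected layers (only viable for a toy model; the pretrained SD weights require the full 3×3 convolutions)
- Attention Computation: Can be memory-intensive; at a 64×64 latent, one self-attention score matrix is 4096 × 4096 ≈ 16.8M entries (about 67 MB in F32) per head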
- Mitigation: Process in chunks or use lower precision
- Large Weights File: May need to manage memory carefully
- Mitigation: Stream weights if needed, use safetensors format
- Multi-threading with rayon for parallel operations
- Consider using ndarray-linalg for optimized matrix ops
- Profile and identify bottlenecks
- Optional: Create simplified model for faster demo
- Keep layers and operations pure functions where possible
- Create helper functions for common ops (matrix mul, activation, norm)
- Use type aliases for clarity: type Tensor = Array<bf16, IxDyn>; (or make it generic over precision)
- Document expected tensor shapes in function signatures
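For the seed parameter and the random (1, 4, 64, 64) latent initialization listed above, the essential property is that a given seed always produces the same standard-normal latents. The project lists rand/rand_distr for this (for example a seeded `StdRng` sampling `StandardNormal`). The std-only sketch below uses SplitMix64 plus the Box-Muller transform instead so that it is self-contained; `SplitMix64` and `init_latents` are hypothetical names. Neither approach reproduces PyTorch's generator, so the same seed will not give the same image as other implementations.

```rust
// Illustrative sketch; SplitMix64 and init_latents are hypothetical names.

/// Small deterministic PRNG (SplitMix64): the same seed yields the same sequence.
pub struct SplitMix64(pub u64);

impl SplitMix64 {
    pub fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform sample in (0, 1]; never 0, so ln() in Box-Muller stays finite.
    fn next_unit(&mut self) -> f64 {
        ((self.next_u64() >> 11) + 1) as f64 / (1u64 << 53) as f64
    }
}

/// `len` standard-normal values via Box-Muller, e.g. len = 1 * 4 * 64 * 64 for a 512x512 image.
pub fn init_latents(seed: u64, len: usize) -> Vec<f32> {
    let mut rng = SplitMix64(seed);
    let mut out = Vec::with_capacity(len);
    while out.len() < len {
        let (u1, u2) = (rng.next_unit(), rng.next_unit());
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f64::consts::PI * u2;
        out.push((r * theta.cos()) as f32);
        if out.len() < len {
            out.push((r * theta.sin()) as f32);
        }
    }
    out
}
```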
Status: Framework defined, not started
Priority: High (enables future optimization)
Effort: Medium (2-3 days)
Benefits: Zero-copy tensor construction, full memory control, fewer dependencies
Implement safetensors format parsing by hand instead of using the external safetensors crate:
- Why: Educational, optimizable, and reduces dependencies
- Format Analysis:
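- Header: 8 bytes (little-endian u64) giving the length N of the JSON header
- Metadata: N bytes of JSON mapping each tensor name to its dtype, shape, and data_offsets ([begin, end) byte range relative to the start of the data section), plus an optional __metadata__ entry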
- Data: Raw tensor bytes in specified order
- Simple format makes hand-parsing feasible
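- Implementation (a sketch: SafeTensorsHeader and Tensor are project types defined elsewhere, and Result is assumed to be an alias such as anyhow::Result):
```rust
use std::collections::HashMap;

// Parse safetensors format
fn parse_safetensors(mmap: &[u8]) -> Result<HashMap<String, Tensor>> {
    // First 8 bytes: little-endian u64 giving the JSON header length
    let header_size = u64::from_le_bytes(mmap[0..8].try_into()?) as usize;
    let header_json = std::str::from_utf8(&mmap[8..8 + header_size])?;
    let metadata: SafeTensorsHeader = serde_json::from_str(header_json)?;
    // Tensor bytes start right after the header; data_offsets are relative to this slice
    let data = &mmap[8 + header_size..];
    // Map tensor data from mmap using offsets (metadata + data)
    todo!()
}
```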
- Benefits:
- Full control over memory layout and tensor buffer management
- Can optimize for specific access patterns (sequential vs random access)
- Reduce dependencies and binary size
- Better integration with ndarray (direct buffer wrapping without copies)
- Safer: validate header bounds, dtypes, and offsets explicitly at load time
- Testing: Compare output with official safetensors crate on sample files
- Future: Enable zero-copy tensor construction directly from mmap
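The format analysis above translates into a small, bounds-checked splitter. This std-only sketch (hypothetical `split_safetensors`) reads the 8-byte little-endian length and returns the JSON header as `&str` and the data section as a borrowed slice. Truncated files are rejected with an error rather than a panic. Parsing the JSON into tensor entries would still use serde_json as in the snippet above, and each tensor's `data_offsets` would index into the returned data slice.

```rust
// Illustrative sketch; split_safetensors is a hypothetical name.

/// Split a safetensors buffer into (JSON header, data section) without copying.
/// Layout: 8-byte little-endian u64 N, N bytes of UTF-8 JSON, then raw tensor bytes.
pub fn split_safetensors(buf: &[u8]) -> Result<(&str, &[u8]), String> {
    if buf.len() < 8 {
        return Err("file shorter than the 8-byte length prefix".to_string());
    }
    let mut len_bytes = [0u8; 8];
    len_bytes.copy_from_slice(&buf[..8]);
    let n = usize::try_from(u64::from_le_bytes(len_bytes))
        .map_err(|_| "header length does not fit in usize".to_string())?;
    // checked_add guards against a corrupt length that would overflow
    let end = 8usize
        .checked_add(n)
        .filter(|&e| e <= buf.len())
        .ok_or_else(|| "header length exceeds file size".to_string())?;
    let json = std::str::from_utf8(&buf[8..end]).map_err(|e| e.to_string())?;
    Ok((json, &buf[end..]))
}
```

Because both returned slices borrow from `buf`, passing an mmap'd file keeps the whole path zero-copy.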
Status: Framework defined, not started
Priority: Medium (for CPU performance)
Effort: Medium (3-4 days)
Expected benefit: 5-10x speedup on capable hardware
Optimize matrix operations to use Intel AVX-512 with native BF16 support:
- Why: AVX-512-BF16 provides native hardware acceleration for bfloat16 operations
- Implementation:
- Create unsafe SIMD blocks targeting target_feature = "avx512bf16"
- Implement hand-optimized matrix multiplication kernels
- Use std::arch::x86_64 intrinsics (for example _mm512_dpbf16_ps) or inline assembly for AVX-512 operations; the packed_simd crate is unmaintained
- Implement cache-friendly tile-based matrix multiplication
- Expected Benefit: 5-10x speedup on capable hardware compared to generic ndarray ops
- Compatibility: Add runtime CPU feature detection (is_x86_feature_detected!); fall back to generic ndarray on unsupported CPUs
- Testing: Provide a benchmark suite (cargo bench) comparing AVX-512 vs generic implementations
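The cache-tiling idea above is independent of SIMD, so it can be shown with scalar code first. This sketch (hypothetical `matmul_tiled`, row-major f32) walks A, B and C in `tile`-sized blocks so each block is reused while it is still in cache. An AVX-512-BF16 kernel would replace the innermost loop. The tile size is a tuning parameter, not a recommended value.

```rust
// Illustrative sketch; matmul_tiled is a hypothetical name.

/// Cache-blocked C = A (m x k) * B (k x n), all row-major f32.
pub fn matmul_tiled(a: &[f32], b: &[f32], m: usize, k: usize, n: usize, tile: usize) -> Vec<f32> {
    assert!(tile > 0 && a.len() == m * k && b.len() == k * n);
    let mut c = vec![0.0f32; m * n];
    for i0 in (0..m).step_by(tile) {
        for p0 in (0..k).step_by(tile) {
            for j0 in (0..n).step_by(tile) {
                // Work on one tile of A, B and C at a time so they stay in cache
                for i in i0..(i0 + tile).min(m) {
                    for p in p0..(p0 + tile).min(k) {
                        let aip = a[i * k + p];
                        // Innermost loop is contiguous in both B and C: the SIMD candidate
                        for j in j0..(j0 + tile).min(n) {
                            c[i * n + j] += aip * b[p * n + j];
                        }
                    }
                }
            }
        }
    }
    c
}
```

The i-p-j loop order keeps the innermost accesses to B and C sequential, which is what makes vectorizing that loop straightforward.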
Status: Framework defined, not started
Priority: Medium (for GPU acceleration)
Effort: High (5-7 days)
Expected benefits: 10-50x speedup on discrete GPUs, WebAssembly support
Port compute-heavy operations to WebGPU for cross-platform GPU acceleration:
- Why: WebGPU enables deployment on web and provides portable GPU compute
- Architecture:
- Create an optional wgpu backend module alongside ndarray
- Implement key kernels in WGSL (WebGPU Shading Language):
- Matrix multiplication (GEMM)
- Softmax for attention
- Convolution operations
- Element-wise and row-wise operations (GELU, layer norm)
- Maintain the ndarray backend for CPU fallback
- Implementation Path:
- Start with the most expensive ops: UNet forward pass and attention
- Implement buffer management and GPU memory pooling
- Create an abstraction layer to swap between GPU/CPU backends
- Use the wgpu crate for WebGPU + Vulkan/Metal/DX12 support
- Expected Benefit: 10-50x speedup on discrete GPUs
- Deployment: Can be compiled to WebAssembly and run in browser
- Testing: Verify GPU outputs match CPU results (account for precision differences)
- Use feature flags in Cargo.toml (wgpu must also be declared under [dependencies] with optional = true):
```toml
[features]
avx512-bf16 = []         # Enable AVX-512-BF16 optimizations
wgpu-backend = ["wgpu"]  # Enable WebGPU/GPU compute
```
- Compile with: cargo build --release --features "avx512-bf16" or cargo build --target wasm32-unknown-unknown --features "wgpu-backend"
- Add benchmarks: cargo bench to measure performance of different backends
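The backend abstraction layer and the GPU-vs-CPU testing step described above could take the shape of a trait with a CPU reference implementation and a tolerance check. This is a hypothetical sketch: `Backend`, `CpuBackend` and `allclose` are illustrative names, and a wgpu-backed type gated behind the `wgpu-backend` feature would implement the same trait. Tolerances need to be looser when either side computes in BF16, which keeps only 8 significant bits (roughly 2 to 3 decimal digits).

```rust
// Illustrative sketch; Backend, CpuBackend and allclose are hypothetical names.

/// Minimal interface shared by CPU and (future) GPU kernels.
pub trait Backend {
    fn name(&self) -> &'static str;
    /// C = A (m x k) * B (k x n), row-major.
    fn matmul(&self, a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32>;
}

/// Reference CPU implementation used as ground truth in tests.
pub struct CpuBackend;

impl Backend for CpuBackend {
    fn name(&self) -> &'static str {
        "cpu"
    }

    fn matmul(&self, a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut c = vec![0.0f32; m * n];
        for i in 0..m {
            for p in 0..k {
                for j in 0..n {
                    c[i * n + j] += a[i * k + p] * b[p * n + j];
                }
            }
        }
        c
    }
}

/// Mixed tolerance check: |x - y| <= atol + rtol * |y| for every element.
pub fn allclose(x: &[f32], y: &[f32], rtol: f32, atol: f32) -> bool {
    x.len() == y.len() && x.iter().zip(y).all(|(a, b)| (a - b).abs() <= atol + rtol * b.abs())
}
```

Taking `&self` in the trait lets a GPU implementation hold its device and queue handles while the pipeline code stays backend-agnostic through `Box<dyn Backend>`.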