```rust
// SD 1.x UNet settings. Rust has no field initializers, so each
// field's default value is noted in a comment.
pub struct UNet2DConditionModelConfig {
    pub center_input_sample: bool,            // false
    pub flip_sin_to_cos: bool,                // true
    pub freq_shift: f64,                      // 0.0
    pub layers_per_block: usize,              // 2
    pub downsample_padding: usize,            // 1
    pub mid_block_scale_factor: f64,          // 1.0
    pub norm_num_groups: usize,               // 32
    pub norm_eps: f64,                        // 1e-5
    pub cross_attention_dim: usize,           // 1280 in `Default`; must match the checkpoint (768 for SD 1.x)
    pub sliced_attention_size: Option<usize>, // None
    pub use_linear_projection: bool,          // false
}
```

```text
Down Blocks (downsample after every block except the last):
├─ Block 0: out_channels=320,  use_cross_attn=Some(1), attention_head_dim=8
├─ Block 1: out_channels=640,  use_cross_attn=Some(1), attention_head_dim=8
├─ Block 2: out_channels=1280, use_cross_attn=Some(1), attention_head_dim=8
└─ Block 3: out_channels=1280, use_cross_attn=None,    attention_head_dim=8 (no attention, no downsampling)
Mid Block:
└─ Single block with 1280 channels and cross-attention (ResNet → attention → ResNet)
Up Blocks (down blocks in reverse order; upsample after every block except the last):
├─ Block 0: out_channels=1280, use_cross_attn=None,    attention_head_dim=8 (no attention)
├─ Block 1: out_channels=1280, use_cross_attn=Some(1), attention_head_dim=8
├─ Block 2: out_channels=640,  use_cross_attn=Some(1), attention_head_dim=8
└─ Block 3: out_channels=320,  use_cross_attn=Some(1), attention_head_dim=8
```
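The up path is easiest to get right if it is derived from the down-block list rather than typed out by hand. A minimal sketch, assuming the up blocks are simply the down-block configs in reverse order (which is how the tree above reads); `BlockConfig`, `down_blocks`, `up_blocks` and `resamples` are hypothetical names for this illustration:

```rust
/// Per-block settings from the tree above (hypothetical type for this sketch).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BlockConfig {
    pub out_channels: usize,
    pub use_cross_attn: Option<usize>,
    pub attention_head_dim: usize,
}

/// SD 1.x down-block configs, in order.
pub fn down_blocks() -> Vec<BlockConfig> {
    let bc = |out_channels: usize, use_cross_attn: Option<usize>| BlockConfig {
        out_channels,
        use_cross_attn,
        attention_head_dim: 8,
    };
    vec![bc(320, Some(1)), bc(640, Some(1)), bc(1280, Some(1)), bc(1280, None)]
}

/// The up path walks the same configs in reverse order.
pub fn up_blocks(down: &[BlockConfig]) -> Vec<BlockConfig> {
    down.iter().rev().copied().collect()
}

/// Every block except the final one in its path changes resolution
/// (downsample on the way down, upsample on the way up).
pub fn resamples(index: usize, n_blocks: usize) -> bool {
    index + 1 < n_blocks
}
```

Deriving the up path this way keeps the attention-free 1280 block in front of the three cross-attention up blocks automatically.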
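- Time Embedding
  - Uses `Timesteps` projection (sinusoidal encoding, 320 dims)
  - Then `TimestepEmbedding` (Linear 320 → 1280, SiLU, Linear 1280 → 1280)
  - Output dimension: 1280 (4 × the first block's 320 channels; each ResNet block projects it to its own channel count)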
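The sinusoidal step can be sketched as below. This assumes the diffusers-style formula (max period 10000, frequencies `exp(-ln(10000) * i / (half - freq_shift))`, cos half first when `flip_sin_to_cos` is true); `timestep_embedding` is a hypothetical name, and odd `dim` is not handled.

```rust
/// Sinusoidal timestep embedding, assuming the diffusers-style formula.
/// Returns `dim` values (dim assumed even): [cos..., sin...] when
/// `flip_sin_to_cos` is true, [sin..., cos...] otherwise.
pub fn timestep_embedding(t: f64, dim: usize, flip_sin_to_cos: bool, freq_shift: f64) -> Vec<f64> {
    let half = dim / 2;
    let mut sin = Vec::with_capacity(dim);
    let mut cos = Vec::with_capacity(dim);
    for i in 0..half {
        // Geometric frequency ladder from 1 down towards 1/10000.
        let exponent = -(10000f64).ln() * i as f64 / (half as f64 - freq_shift);
        let arg = t * exponent.exp();
        sin.push(arg.sin());
        cos.push(arg.cos());
    }
    if flip_sin_to_cos {
        cos.extend(sin);
        cos
    } else {
        sin.extend(cos);
        sin
    }
}
```

With the config above (`flip_sin_to_cos = true`, `freq_shift = 0.0`) and `dim = 320`, the result is what `TimestepEmbedding` would then lift to 1280.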
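- Conv Layers
  - Input (`conv_in`): Conv2d with 4 → 320 channels
  - Output (`conv_out`): Conv2d with 320 → 4 channels, preceded by GroupNorm and SiLU
  - Kernel size: 3×3 with padding 1, so the 64×64 spatial size is preserved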
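A naive reference convolution is useful for checking these layers before optimizing. This sketch uses plain `Vec<f32>` in CHW layout instead of ndarray (to stay dependency-free), assumes stride 1, and assumes weights in PyTorch order `[out_c][in_c][k][k]`; `conv2d` is a hypothetical helper name.

```rust
/// Naive 2D convolution over one CHW image, stride 1, zero padding `pad`.
/// `weight` is [out_c][in_c][k][k] (PyTorch order), `bias` has out_c entries.
#[allow(clippy::too_many_arguments)]
pub fn conv2d(
    input: &[f32], in_c: usize, h: usize, w: usize,
    weight: &[f32], bias: &[f32], out_c: usize, k: usize, pad: usize,
) -> Vec<f32> {
    assert_eq!(input.len(), in_c * h * w);
    assert_eq!(weight.len(), out_c * in_c * k * k);
    let out_h = h + 2 * pad - k + 1;
    let out_w = w + 2 * pad - k + 1;
    let mut out = vec![0.0f32; out_c * out_h * out_w];
    for oc in 0..out_c {
        for oy in 0..out_h {
            for ox in 0..out_w {
                let mut acc = bias[oc];
                for ic in 0..in_c {
                    for ky in 0..k {
                        for kx in 0..k {
                            // Position in the padded input; skip taps that land in the zero border.
                            let (iy, ix) = (oy + ky, ox + kx);
                            if iy < pad || ix < pad || iy - pad >= h || ix - pad >= w {
                                continue;
                            }
                            let x = input[(ic * h + iy - pad) * w + ix - pad];
                            acc += x * weight[((oc * in_c + ic) * k + ky) * k + kx];
                        }
                    }
                }
                out[(oc * out_h + oy) * out_w + ox] = acc;
            }
        }
    }
    out
}
```

For `conv_in` this would be called with `in_c = 4, out_c = 320, k = 3, pad = 1` on a 64×64 latent, giving a 320×64×64 output.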
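- Group Normalization
  - `norm_num_groups`: 32
  - `norm_eps`: 1e-5 (default)
  - Applied before each convolution inside the ResNet blocks (GroupNorm → SiLU → Conv), at the input of each attention block, and before `conv_out`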
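GroupNorm normalizes each contiguous group of channels with its own mean and variance, then applies a per-channel scale and shift. A minimal sketch for one CHW image, assuming biased variance (as PyTorch uses) and that the channel count divides evenly by the group count; `group_norm` is a hypothetical helper name.

```rust
/// GroupNorm over one CHW image flattened as [c][hw].
/// Channels are split into `groups` contiguous groups; each group is normalized
/// with its own mean/variance, then scaled/shifted per channel by gamma/beta.
pub fn group_norm(x: &[f32], c: usize, hw: usize, groups: usize, eps: f32, gamma: &[f32], beta: &[f32]) -> Vec<f32> {
    assert_eq!(x.len(), c * hw);
    assert_eq!(c % groups, 0, "channels must divide evenly into groups");
    let per_group = c / groups;
    let n = (per_group * hw) as f32;
    let mut out = vec![0.0f32; x.len()];
    for g in 0..groups {
        let slice = &x[g * per_group * hw..(g + 1) * per_group * hw];
        let mean = slice.iter().sum::<f32>() / n;
        let var = slice.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        let inv_std = 1.0 / (var + eps).sqrt();
        for ch in g * per_group..(g + 1) * per_group {
            for i in 0..hw {
                let idx = ch * hw + i;
                out[idx] = (x[idx] - mean) * inv_std * gamma[ch] + beta[ch];
            }
        }
    }
    out
}
```

For the 320-channel blocks this means 32 groups of 10 channels each, with `eps = 1e-5`.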
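- Cross-Attention
  - Query dim: same as block channels (320/640/1280)
  - Context dim: `cross_attention_dim` ← ⚠️ 1280 in the config `Default`, but it must equal the text-encoder width the checkpoint was trained with (768 for SD 1.x, whose CLIP ViT-L/14 text encoder outputs 768)
  - Heads: 8
  - Head dim: channels / 8 (40/80/160)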
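The head split can be checked in isolation. This sketch assumes the queries (one row per spatial position, `n × c`) and the keys/values (one row per text token, `m × c`) have already been produced by the `to_q`/`to_k`/`to_v` projections, and it omits the final `to_out` linear; `multi_head_attention` is a hypothetical helper name.

```rust
/// Multi-head scaled dot-product attention on already-projected tensors.
/// q: n x c, k and v: m x c, all row-major. Head h uses columns
/// [h*d, (h+1)*d) with d = c / heads, i.e. "head dim = channels / heads".
pub fn multi_head_attention(q: &[f32], k: &[f32], v: &[f32], n: usize, m: usize, c: usize, heads: usize) -> Vec<f32> {
    assert_eq!(c % heads, 0);
    let d = c / heads;
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![0.0f32; n * c];
    let mut scores = vec![0.0f32; m];
    for h in 0..heads {
        let off = h * d;
        for i in 0..n {
            // Scaled dot product of query row i with every key row, within this head.
            for j in 0..m {
                let dot: f32 = (0..d).map(|t| q[i * c + off + t] * k[j * c + off + t]).sum();
                scores[j] = dot * scale;
            }
            // Numerically stable softmax over the m keys.
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut total = 0.0f32;
            for s in scores.iter_mut() {
                *s = (*s - max).exp();
                total += *s;
            }
            // Attention-weighted sum of the value rows.
            for t in 0..d {
                let mut acc = 0.0f32;
                for j in 0..m {
                    acc += scores[j] / total * v[j * c + off + t];
                }
                out[i * c + off + t] = acc;
            }
        }
    }
    out
}
```

For the first block this would run with `n = 64 * 64`, `m = 77`, `c = 320`, `heads = 8`, so each head works on 40 columns.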
- ✅ Convolutional layers (not fully connected!)
- ✅ Multiple downsampling/upsampling blocks
- ✅ Skip connections between down/up blocks
- ✅ Cross-attention for text conditioning
- ✅ Residual connections
- ✅ Group normalization in the convolutional/ResNet paths (the transformer blocks inside the attention layers use LayerNorm)
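- Cross-Attention Dimension Mismatch
  - We hardcoded: `context_dim: 768`
  - The config `Default` says: `context_dim: 1280`
  - Resolve this from the weights, not the default: the input width of each cross-attention key projection (`attn2.to_k`) is the context dim the UNet expects. If it is 768, as in SD 1.x checkpoints, the CLIP output (77, 768) feeds in directly and no 768 → 1280 projection is needed.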
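One way to settle the mismatch at load time is to read the context width off a loaded key-projection weight. This sketch assumes the shape comes from the safetensors header in PyTorch `nn.Linear` layout `[out_features, in_features]`, and that keys follow the diffusers naming (for example `down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight`); both function names are hypothetical.

```rust
/// Context width expected by a cross-attention key projection whose weight
/// has shape [out_features, in_features] (PyTorch Linear layout).
pub fn context_dim_from_to_k(shape: &[usize]) -> Option<usize> {
    match shape {
        [_out, inp] => Some(*inp),
        _ => None, // not a 2-D Linear weight
    }
}

/// Whether the text-encoder output (width `text_hidden`) would need reshaping
/// before it could be used as cross-attention context.
pub fn needs_context_projection(text_hidden: usize, to_k_shape: &[usize]) -> Option<bool> {
    context_dim_from_to_k(to_k_shape).map(|d| d != text_hidden)
}
```

With an SD 1.x checkpoint we would expect the first block's `to_k` weight to be `[320, 768]`, which makes `needs_context_projection(768, &[320, 768])` return `Some(false)`.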
- No Actual Convolutions
  - We use `ResidualBlock::forward()`, which returns its input unchanged
  - Should have actual 2D convolutions with the proper channel counts
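- No Real Skip Connections
  - No connection between corresponding down/up blocks
  - Skip connections carry each down-path activation to the matching up-path ResNet, which concatenates it along the channel axis; without them the up path loses the high-resolution detail recorded on the way down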
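The skip bookkeeping can be sketched as a stack: the down path pushes activations, and each up-path ResNet pops one and concatenates it with its current input. This assumes the diffusers/Candle layout (one push from `conv_in`, one per down ResNet, one per downsampler; `layers_per_block + 1` ResNets per up block, current features first in the concat); the helper names are hypothetical.

```rust
/// Concatenate two CHW tensors that share H*W along the channel axis.
/// CHW is channel-major, so this is just `a` followed by `b`.
pub fn concat_channels(a: &[f32], b: &[f32], hw: usize) -> Vec<f32> {
    assert!(hw > 0 && a.len() % hw == 0 && b.len() % hw == 0, "inputs must share H*W");
    let mut out = Vec::with_capacity(a.len() + b.len());
    out.extend_from_slice(a);
    out.extend_from_slice(b);
    out
}

/// Skips pushed by the down path: conv_in output, every ResNet output,
/// and every downsampler output (all blocks but the last downsample).
pub fn skip_count(n_blocks: usize, layers_per_block: usize) -> usize {
    1 + n_blocks * layers_per_block + (n_blocks - 1)
}

/// Skips popped by the up path: each up block has layers_per_block + 1 ResNets.
pub fn skips_consumed(n_blocks: usize, layers_per_block: usize) -> usize {
    n_blocks * (layers_per_block + 1)
}
```

For 4 blocks with `layers_per_block = 2`, both counts come out to 12, which is a quick sanity check that every pushed activation gets used.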
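- Incorrect Timestep Embedding
  - We create (1280,) embeddings
  - But don't project and broadcast them into the feature maps: each ResNet block should map the embedding to its own channel count and add that value at every spatial position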
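The per-block injection is SiLU, then a linear layer down to the block's channel count, then a per-channel add over H×W (a `(C,)` vector broadcast as `(C, 1, 1)`). A minimal sketch with plain slices, assuming PyTorch `Linear` weight layout `[out][in]`; all names are hypothetical.

```rust
/// SiLU: x * sigmoid(x).
pub fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// y = W x + b, with W stored row-major as [out][in] (PyTorch Linear layout).
pub fn linear(x: &[f32], w: &[f32], b: &[f32]) -> Vec<f32> {
    let in_f = x.len();
    assert_eq!(w.len(), b.len() * in_f);
    b.iter()
        .enumerate()
        .map(|(o, &bias)| bias + (0..in_f).map(|i| w[o * in_f + i] * x[i]).sum::<f32>())
        .collect()
}

/// Add t_proj[ch] to all H*W positions of channel ch in a CHW feature map.
pub fn add_time_bias(x: &mut [f32], t_proj: &[f32], hw: usize) {
    assert_eq!(x.len(), t_proj.len() * hw);
    for (ch, &bias) in t_proj.iter().enumerate() {
        for v in &mut x[ch * hw..(ch + 1) * hw] {
            *v += bias;
        }
    }
}
```

In a 320-channel ResNet block, `linear` would map the SiLU'd 1280-dim embedding to 320 values, and `add_time_bias` would add them after the first convolution.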
- No Proper Normalization
  - We don't use group norm (32 groups)
  - We don't apply it in the right places (before each ResNet convolution and at the input of each attention block)
- Cross-Attention Output Wrong
  - We return shape `(query_dim,)` per timestep
  - Should keep one output per spatial position, `(H*W, query_dim)`, so it can be reshaped back to `(query_dim, H, W)` and added to the residual
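Keeping the spatial layout comes down to a pair of transposes around the attention: flatten `(C, H, W)` into `H*W` tokens of `C` values, attend, then fold back and add the residual. A minimal sketch on flat slices; the helper names are hypothetical.

```rust
/// CHW -> (H*W) x C: one token (row) per spatial position.
pub fn chw_to_tokens(x: &[f32], c: usize, hw: usize) -> Vec<f32> {
    assert_eq!(x.len(), c * hw);
    let mut out = vec![0.0f32; c * hw];
    for ch in 0..c {
        for p in 0..hw {
            out[p * c + ch] = x[ch * hw + p];
        }
    }
    out
}

/// (H*W) x C -> CHW, the inverse of `chw_to_tokens`.
pub fn tokens_to_chw(t: &[f32], c: usize, hw: usize) -> Vec<f32> {
    assert_eq!(t.len(), c * hw);
    let mut out = vec![0.0f32; c * hw];
    for p in 0..hw {
        for ch in 0..c {
            out[ch * hw + p] = t[p * c + ch];
        }
    }
    out
}

/// Fold the attention output (as tokens) back to CHW and add it to the input.
pub fn add_attention_residual(x: &[f32], attn_tokens: &[f32], c: usize, hw: usize) -> Vec<f32> {
    let attn = tokens_to_chw(attn_tokens, c, hw);
    x.iter().zip(&attn).map(|(a, b)| a + b).collect()
}
```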
Priority 1: Fix Architecture
- Implement 2D convolution operations using ndarray
- Add group normalization
- Fix cross-attention dimensions
- Add skip connections
Priority 2: Load Weights
- Parse 686 UNet tensors from safetensors
- Map weights to conv layers, norms, attention
- Test weight shapes match architecture
Priority 3: Full Forward Pass
- Integrate all components
- Test against Candle reference outputs
| Dimension | Value | Notes |
|---|---|---|
| Batch | 1 | Single image |
| Latent Channels | 4 | Input/output latent channels |
| Latent H/W | 64×64 | 8× smaller per side than 512×512 (512 / 8 = 64) |
| Text Context Length | 77 | Fixed CLIP token length |
| Text Embed Dim (CLIP) | 768 | CLIP encoder output |
| Cross-Attn Dim | 1280 (config `Default`) | Must match the checkpoint; SD 1.x weights expect 768, i.e. the CLIP output as-is |
| Position | Channels | Heads | Head Dim |
|---|---|---|---|
| Entry | 320 | 8 | 40 |
| Mid | 1280 | 8 | 160 |
| Exit | 320 | 8 | 40 |
- Verify CLIP output is (77, 768)
- Check if CLIP needs projection to (77, 1280)
- Confirm 686 UNet tensors load correctly
- Test conv2d operations on sample data
- Compare intermediate activations with Candle
- Verify skip connections work
- Test group norm values