
# SD1.5 MoE vs Dense for InstructPix2Pix

Train and compare a Mixture-of-Experts (MoE) model (4 experts, top-2 gating) against a Dense baseline on the InstructPix2Pix image-editing task.


## Files

### Training Scripts

| File | Description |
|---|---|
| `train_sd15_moe_pix2pix.py` | MoE training (4 experts, top-2) |
| `train_sd15_dense_pix2pix.py` | Dense baseline (no MoE) |

### SLURM Jobs

| File | Description |
|---|---|
| `train_sd15_moe_pix2pix.slurm` | MoE training (2 GPUs, 8 hrs) |
| `train_sd15_dense_pix2pix.slurm` | Dense training (2 GPUs, 8 hrs) |
| `resume_sd15_moe.slurm` | Resume training from checkpoint |
| `test_quick.slurm` | Quick test (no dataset needed) |
| `test_full.slurm` | Full test with dataset |

### Inference & Testing

| File | Description |
|---|---|
| `inference_sd15.py` | Inference (auto-detects MoE/Dense) |
| `test_sd15_training.py` | Test script for sanity checks |

## Quick Start

### 1. Copy Files to Cluster

```bash
mkdir -p ~/moe-minimal
cd ~/moe-minimal
# Upload all .py and .slurm files
```

### 2. Run Quick Test (No Dataset)

```bash
sbatch test_quick.slurm
```

Expected output:

```
✓ MoE application test passed!
✓ Forward pass test passed!
✓ Training step test passed!
✓ Checkpoint save/load test passed!
ALL QUICK TESTS PASSED! ✓
```

### 3. Run Full Test (With Dataset)

```bash
sbatch test_full.slurm
```

### 4. Train MoE Model

```bash
sbatch train_sd15_moe_pix2pix.slurm
```

### 5. Train Dense Baseline

```bash
sbatch train_sd15_dense_pix2pix.slurm
```

### 6. Resume After Time Runs Out

```bash
# Find the latest checkpoint
ls /scratch/khasti/sd15-moe-pix2pix/run-*/latest_checkpoint.pt

# Resume from it
sbatch resume_sd15_moe.slurm /scratch/khasti/sd15-moe-pix2pix/run-xxx/latest_checkpoint.pt
```

### 7. Inference

```bash
srun --time=0:30:00 -G a100:1 --pty bash -c '
source ~/miniconda3/bin/activate sd-moe && \
cd ~/moe-minimal && \
python inference_sd15.py \
    --checkpoint /scratch/khasti/sd15-moe-pix2pix/run-xxx/best_checkpoint.pt \
    --input_image test_input.jpg \
    --prompt "make it sunset" \
    --output test_output.jpg'
```
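`inference_sd15.py` auto-detects MoE vs Dense from the checkpoint. A plausible detection step looks like the sketch below; the `model_type` field matches the checkpoint schema described under Checkpoints, while the `'experts'` key-name fallback is an assumption about how the repo names MoE weights:

```python
import torch

def detect_model_type(checkpoint_path):
    """Guess MoE vs Dense from a checkpoint file (illustrative sketch).

    Prefers the explicit 'model_type' field; falls back to scanning the
    state dict for expert weights (the 'experts' key name is an assumption).
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    if 'model_type' in ckpt:
        return ckpt['model_type']
    is_moe = any('experts' in key for key in ckpt['unet_state_dict'])
    return 'moe' if is_moe else 'dense'
```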

## Architecture Comparison

| | MoE | Dense |
|---|---|---|
| FFN blocks | 4 experts, top-2 gating | Standard FFN |
| Params | ~4x FFN params | Baseline |
| FLOPs/token | ~2x FFN (top-2 active) | Baseline |
| Training script | `train_sd15_moe_pix2pix.py` | `train_sd15_dense_pix2pix.py` |
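The table above can be illustrated with a minimal top-2 MoE FFN block in PyTorch. This is a sketch, not the exact layer from `train_sd15_moe_pix2pix.py`: 4 experts hold ~4x the FFN parameters, but each token runs through only its top-2 experts, so per-token FLOPs are ~2x a single FFN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEFFN(nn.Module):
    """Minimal top-2 mixture-of-experts feed-forward block (sketch)."""

    def __init__(self, dim, hidden_dim, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts)   # router
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(),
                          nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x):                          # x: (tokens, dim)
        logits = self.gate(x)                      # (tokens, num_experts)
        weights, idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)       # renormalize over chosen experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):                # each routing slot
            for e, expert in enumerate(self.experts):
                mask = idx[:, k] == e              # tokens sent to expert e
                if mask.any():
                    out[mask] += weights[mask, k, None] * expert(x[mask])
        return out

x = torch.randn(8, 64)
y = MoEFFN(dim=64, hidden_dim=256)(x)
print(y.shape)  # torch.Size([8, 64])
```

The Dense baseline simply replaces this block with one `Linear → GELU → Linear` FFN and no router.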

## Configuration

### Training Config

| Setting | Value |
|---|---|
| Base model | SD 1.5 |
| Dataset | `/scratch/khasti/datasets/instructpix2pix` |
| Resolution | 256 |
| Batch size | 4 per GPU |
| GPUs | 2x A100 |
| Time | 8 hours per job |

### MoE Config

| Setting | Value |
|---|---|
| Experts | 4 |
| Top-K | 2 |
| Aux loss weight | 0.01 |
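The aux loss weight (0.01) typically scales a load-balancing term added to the diffusion loss so the router doesn't collapse onto one expert. The Switch-Transformer-style formulation below is a sketch under that assumption; the exact term in the training script may differ:

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(gate_logits, top_k=2):
    """Switch-style auxiliary loss encouraging uniform expert usage (sketch).

    gate_logits: (tokens, num_experts) raw router outputs.
    Returns num_experts * sum_e(frac_tokens_e * frac_probs_e), which is
    minimized when both token assignments and router probabilities are uniform.
    """
    num_experts = gate_logits.shape[-1]
    probs = F.softmax(gate_logits, dim=-1)
    topk_idx = gate_logits.topk(top_k, dim=-1).indices
    # fraction of routing slots assigned to each expert
    mask = F.one_hot(topk_idx, num_experts).sum(dim=1).float()  # (tokens, E)
    frac_tokens = mask.mean(dim=0) / top_k
    frac_probs = probs.mean(dim=0)
    return num_experts * (frac_tokens * frac_probs).sum()

gate_logits = torch.randn(128, 4)
aux = 0.01 * load_balancing_loss(gate_logits)  # aux loss weight = 0.01
# total_loss = diffusion_loss + aux
```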

## Checkpoints

Training saves:

- `latest_checkpoint.pt` — every epoch (for resuming)
- `best_checkpoint.pt` — best validation loss (for inference)
- `checkpoint_step_XXX.pt` — periodic saves

Checkpoint contains:

```python
{
    'unet_state_dict': ...,      # Model weights
    'optimizer_state_dict': ..., # Optimizer state
    'scheduler_state_dict': ..., # LR scheduler
    'epoch': ...,
    'global_step': ...,
    'best_loss': ...,
    'model_type': ...,           # 'moe' or 'dense'
}
```
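Resuming from this dict might look like the following sketch. The field names match the schema above; the model, optimizer, and scheduler are assumed to have been constructed already, as in the training script:

```python
import torch

def load_checkpoint(path, unet, optimizer, scheduler):
    """Restore training state from a checkpoint dict (sketch)."""
    ckpt = torch.load(path, map_location='cpu')
    unet.load_state_dict(ckpt['unet_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    print(f"Resumed {ckpt['model_type']} run at epoch {ckpt['epoch']}, "
          f"step {ckpt['global_step']} (best loss {ckpt['best_loss']:.4f})")
    return ckpt['epoch'], ckpt['global_step'], ckpt['best_loss']
```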

## Metrics to Compare

Run both MoE and Dense training, then compare:

| Metric | How to measure |
|---|---|
| Training loss | Check `best_loss` in checkpoint |
| Training time | Wall-clock time from SLURM |
| GPU memory | `nvidia-smi` during training |
| Inference quality | Visual comparison of outputs |
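For the training-loss comparison, `best_loss` can be read straight from each run's `best_checkpoint.pt` without building the model. A small sketch (the commented paths are examples; substitute your actual `run-xxx` directories):

```python
import torch

def best_loss(path):
    """Read the best_loss field from a checkpoint on CPU (sketch)."""
    return torch.load(path, map_location='cpu')['best_loss']

# Example usage (adjust run-xxx to your run directories):
# moe = best_loss('/scratch/khasti/sd15-moe-pix2pix/run-xxx/best_checkpoint.pt')
# dense = best_loss('/scratch/khasti/sd15-dense-pix2pix/run-xxx/best_checkpoint.pt')
# print(f"MoE: {moe:.4f}  Dense: {dense:.4f}")
```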

## Troubleshooting

### Dataset not found

Check that the dataset directory exists:

```bash
ls /scratch/khasti/datasets/instructpix2pix/
```

### Out of memory

Reduce the batch size in the SLURM script:

```bash
BATCH_SIZE=2
```

### NCCL timeout

Add to the SLURM script:

```bash
export NCCL_TIMEOUT=1800
```

### Resume fails

Make sure the checkpoint path is correct:

```bash
ls /scratch/khasti/sd15-*/run-*/latest_checkpoint.pt
```
