Train and compare an MoE model (4 experts, top-2 gating) against a Dense baseline on InstructPix2Pix image editing.
Training scripts:

| File | Description |
|---|---|
| `train_sd15_moe_pix2pix.py` | MoE training (4 experts, top-2) |
| `train_sd15_dense_pix2pix.py` | Dense baseline (no MoE) |
SLURM job scripts:

| File | Description |
|---|---|
| `train_sd15_moe_pix2pix.slurm` | MoE training (2 GPUs, 8 hrs) |
| `train_sd15_dense_pix2pix.slurm` | Dense training (2 GPUs, 8 hrs) |
| `resume_sd15_moe.slurm` | Resume training from checkpoint |
| `test_quick.slurm` | Quick test (no dataset needed) |
| `test_full.slurm` | Full test with dataset |
Utilities:

| File | Description |
|---|---|
| `inference_sd15.py` | Inference (auto-detects MoE/Dense) |
| `test_sd15_training.py` | Test script for sanity checks |
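Each checkpoint stores a `model_type` field (see "Checkpoint contents" below), so the MoE/Dense auto-detection in `inference_sd15.py` can be as simple as the sketch below. The function name is illustrative, not the script's actual code:

```python
import torch

def detect_model_type(checkpoint_path: str) -> str:
    """Illustrative sketch: read the model type recorded in a checkpoint."""
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    return ckpt.get("model_type", "dense")  # 'moe' or 'dense'; assume dense if absent
```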
Setup:

```bash
mkdir -p ~/moe-minimal
cd ~/moe-minimal
# Upload all .py and .slurm files
```

Run the quick test (no dataset needed):

```bash
sbatch test_quick.slurm
```

Expected output:

```
✓ MoE application test passed!
✓ Forward pass test passed!
✓ Training step test passed!
✓ Checkpoint save/load test passed!
ALL QUICK TESTS PASSED! ✓
```
Run the full test with the dataset:

```bash
sbatch test_full.slurm
```

Launch training:

```bash
sbatch train_sd15_moe_pix2pix.slurm
sbatch train_sd15_dense_pix2pix.slurm
```

Resume from a checkpoint:

```bash
# Find checkpoint
ls /scratch/khasti/sd15-moe-pix2pix/run-*/latest_checkpoint.pt

# Resume
sbatch resume_sd15_moe.slurm /scratch/khasti/sd15-moe-pix2pix/run-xxx/latest_checkpoint.pt
```

Run inference on an interactive GPU node:

```bash
srun --time=0:30:00 -G a100:1 --pty bash -c '
source ~/miniconda3/bin/activate sd-moe && \
cd ~/moe-minimal && \
python inference_sd15.py \
    --checkpoint /scratch/khasti/sd15-moe-pix2pix/run-xxx/best_checkpoint.pt \
    --input_image test_input.jpg \
    --prompt "make it sunset" \
    --output test_output.jpg'
```

Architecture comparison:

| | MoE | Dense |
|---|---|---|
| FFN blocks | 4 experts, top-2 gating | Standard FFN |
| Params | ~4x FFN params | Baseline |
| FLOPs/token | ~2x FFN (top-2 active) | Baseline |
| Training script | `train_sd15_moe_pix2pix.py` | `train_sd15_dense_pix2pix.py` |
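For intuition, here is a minimal sketch of a top-2 gated MoE FFN in PyTorch. The class and variable names are assumptions, not the code in `train_sd15_moe_pix2pix.py`, but it shows why parameters grow ~4x while active FLOPs only grow ~2x: all 4 experts hold weights, yet each token passes through just 2 of them.

```python
import torch
import torch.nn as nn

class TopKMoEFFN(nn.Module):
    """Illustrative top-2 mixture-of-experts FFN (not the repo's exact code)."""
    def __init__(self, dim, hidden_dim, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts)  # router: one logit per expert
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(),
                          nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        ])

    def forward(self, x):  # x: (batch, tokens, dim)
        logits = self.gate(x)                            # (B, T, num_experts)
        weights, idx = logits.topk(self.top_k, dim=-1)   # best 2 experts per token
        weights = weights.softmax(dim=-1)                # renormalize over the top-2
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[..., k] == e                  # tokens routed to expert e in slot k
                if mask.any():
                    out[mask] += weights[..., k][mask].unsqueeze(-1) * expert(x[mask])
        return out
```

With `num_experts=4, top_k=2`, the expert weight count is 4x a single FFN, but each token runs only 2 expert forward passes, matching the table above.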
Training configuration:

| Setting | Value |
|---|---|
| Base model | SD 1.5 |
| Dataset | /scratch/khasti/datasets/instructpix2pix |
| Resolution | 256×256 |
| Batch size | 4 per GPU |
| GPUs | 2x A100 |
| Time | 8 hours per job |
MoE settings:

| Setting | Value |
|---|---|
| Experts | 4 |
| Top-K | 2 |
| Aux loss weight | 0.01 |
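The aux loss weight scales a router load-balancing term. The exact formulation in the training script isn't shown here; a common Switch-Transformer-style version looks like this sketch (names assumed):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, top_k_indices, num_experts=4):
    """Sketch of a Switch-style aux loss that pushes routing toward uniform.

    router_logits: (num_tokens, num_experts); top_k_indices: (num_tokens, top_k)
    """
    probs = router_logits.softmax(dim=-1)
    # Fraction of routed token-slots actually assigned to each expert
    one_hot = F.one_hot(top_k_indices, num_experts).float()
    load = one_hot.sum(dim=(0, 1)) / top_k_indices.numel()
    # Mean router probability per expert
    importance = probs.mean(dim=0)
    return num_experts * (load * importance).sum()  # equals 1.0 at perfect balance
```

In the total loss this term would be multiplied by 0.01, nudging the router toward even expert utilization without overwhelming the diffusion loss.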
Training saves:

- `latest_checkpoint.pt` — every epoch (for resuming)
- `best_checkpoint.pt` — best validation loss (for inference)
- `checkpoint_step_XXX.pt` — periodic saves
Each checkpoint contains:

```python
{
    'unet_state_dict': ...,       # Model weights
    'optimizer_state_dict': ...,  # Optimizer state
    'scheduler_state_dict': ...,  # LR scheduler
    'epoch': ...,
    'global_step': ...,
    'best_loss': ...,
    'model_type': ...,            # 'moe' or 'dense'
}
```
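A sketch of restoring training state from this layout (an assumed helper, not the repo's exact resume code):

```python
import torch

def restore(checkpoint_path, unet, optimizer, scheduler):
    """Reload model, optimizer, and scheduler state to continue a run."""
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    unet.load_state_dict(ckpt["unet_state_dict"])            # model weights
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])  # optimizer moments
    scheduler.load_state_dict(ckpt["scheduler_state_dict"])  # LR schedule position
    return ckpt["epoch"], ckpt["global_step"], ckpt["best_loss"]
```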
Run both MoE and Dense training, then compare:

| Metric | How to measure |
|---|---|
| Training loss | Check `best_loss` in the checkpoint |
| Training time | Wall clock time from SLURM |
| GPU memory | `nvidia-smi` during training |
| Inference quality | Visual comparison of outputs |
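For the loss comparison, you can pull `best_loss` straight out of each run's best checkpoint. The dense output directory below is an assumption based on the MoE path used earlier; substitute your actual `run-*` folders:

```python
import torch

paths = {
    "moe": "/scratch/khasti/sd15-moe-pix2pix/run-xxx/best_checkpoint.pt",
    "dense": "/scratch/khasti/sd15-dense-pix2pix/run-xxx/best_checkpoint.pt",  # assumed path
}
for name, path in paths.items():
    ckpt = torch.load(path, map_location="cpu")  # CPU is fine for reading metadata
    print(f"{name}: best_loss={ckpt['best_loss']:.4f} (epoch {ckpt['epoch']})")
```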
Troubleshooting:

If the dataset is missing, verify that it exists:

```bash
ls /scratch/khasti/datasets/instructpix2pix/
```

If training runs out of GPU memory, reduce the batch size in the SLURM script:

```bash
BATCH_SIZE=2
```

If NCCL times out during multi-GPU training, add to the SLURM script:

```bash
export NCCL_TIMEOUT=1800
```

If resuming fails, make sure the checkpoint path is correct:

```bash
ls /scratch/khasti/sd15-*/run-*/latest_checkpoint.pt
```