Research pipeline for evaluating synthetic data generation via fine-tuned diffusion models on class-imbalanced image classification. Compares ResNet and Vision Transformer (ViT) performance across three scenarios:
- Imbalanced datasets (real data only)
- Traditional oversampling (duplicated minority samples)
- Synthetic augmentation (real + diffusion-generated minority samples)
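The three scenarios differ only in how the minority classes are filled up. The sketch below assumes each dataset is a list of `(image_path, label)` pairs and that generated images are already available per class. `build_scenarios` and this data layout are hypothetical illustrations, not the project's actual code.

```python
import random
from collections import Counter
from typing import Dict, List, Tuple

# Hypothetical sample representation: (image_path, class_label)
Sample = Tuple[str, str]


def build_scenarios(
    imbalanced: List[Sample],
    synthetic: Dict[str, List[str]],
    seed: int = 42,
) -> Dict[str, List[Sample]]:
    """Return the three training sets: imbalanced, oversampled, synthetic."""
    rng = random.Random(seed)
    counts = Counter(label for _, label in imbalanced)
    target = max(counts.values())  # size of the largest (majority) class

    # Traditional oversampling: duplicate random minority samples up to `target`
    oversampled = list(imbalanced)
    for label, n in counts.items():
        pool = [s for s in imbalanced if s[1] == label]
        oversampled.extend(rng.choice(pool) for _ in range(target - n))

    # Synthetic augmentation: add generated images instead of duplicates
    # (a class stays short if fewer generated images than needed are supplied)
    augmented = list(imbalanced)
    for label, n in counts.items():
        generated = synthetic.get(label, [])[: target - n]
        augmented.extend((path, label) for path in generated)

    return {
        "imbalanced": list(imbalanced),
        "oversampled": oversampled,
        "synthetic": augmented,
    }
```

In this sketch both augmented versions end up with the same per-class counts, which keeps the comparison about image content rather than sample count.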
9-stage experimental pipeline:
- Dataset Preparation: AdImageNet, CIFAR-FS, and Plant Pathology datasets
- Imbalance Induction: Systematic minority class sample removal
- Diffusion Fine-tuning: Fine-tune models on minority samples
- Synthetic Generation: Generate minority samples to match majority counts
- Dataset Construction: Create three dataset versions
- Classifier Training: Train ResNet and ViT on each version
- Evaluation: Accuracy, F1, and confusion matrices
- Analysis: Performance comparison and quality assessment
- Documentation: Reproducible results with detailed logging
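Imbalance induction (stage 2) and the synthetic sample budget (stage 4) can be sketched as follows, assuming `minority_classes` and `imbalance_ratio` from `src/config.py` mean what their descriptions suggest: randomly keep a fraction of each minority class, then generate enough images to reach the majority count. The function names, the random-subsampling strategy, and the keep-at-least-one rule are assumptions.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Tuple

# Hypothetical sample representation: (image_path, class_label)
Sample = Tuple[str, str]


def induce_imbalance(
    samples: List[Sample],
    minority_classes: Iterable[str],
    imbalance_ratio: float,
    seed: int = 42,
) -> List[Sample]:
    """Keep only `imbalance_ratio` of each minority class (at least one sample)."""
    rng = random.Random(seed)
    minority = set(minority_classes)
    by_class: Dict[str, List[Sample]] = defaultdict(list)
    for sample in samples:
        by_class[sample[1]].append(sample)

    kept: List[Sample] = []
    for label in sorted(by_class):  # sorted for a reproducible order
        items = by_class[label]
        if label in minority:
            k = max(1, int(len(items) * imbalance_ratio))
            items = rng.sample(items, k)
        kept.extend(items)
    return kept


def synthetic_quota(samples: List[Sample]) -> Dict[str, int]:
    """Images to generate per class so every class matches the majority count."""
    counts = Counter(label for _, label in samples)
    target = max(counts.values())
    return {label: target - n for label, n in counts.items() if n < target}
```

For a balanced dataset with 10 images each of classes `a` and `b`, `induce_imbalance(data, ["b"], 0.2)` keeps 2 `b` images, and `synthetic_quota` on the result asks for 8 generated `b` images.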
Requirements:
- Python 3.8+
- GPU recommended (falls back to CPU automatically)
- 8 GB+ RAM
- 20 GB+ disk space
Installation:
git clone <repository-url>
cd pavic-augmentation
pip install -r requirements.txt
Usage:
# Full pipeline
python main.py
# Specific datasets
python main.py --datasets ad-imagenet cifar-fs
# Skip stages for testing
python main.py --skip-diffusion
python main.py --skip-training --skip-evaluation
# Custom experiment
python main.py --experiment-name "my_experiment" --seed 123
Command-line options:
- --experiment-name: Custom experiment name
- --datasets: Datasets to run (ad-imagenet, cifar-fs, plant-pathology)
- --skip-diffusion: Skip diffusion fine-tuning and synthetic generation
- --skip-training: Skip classifier training
- --skip-evaluation: Skip model evaluation
- --seed: Random seed (default: 42)
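These options map naturally onto `argparse`. The sketch below shows how such a parser and the resulting stage selection could look; it is not the actual `main.py`. The default for `--datasets` (all three) and the internal stage names in `STAGES` are hypothetical.

```python
import argparse
from typing import List, Optional

DATASETS = ["ad-imagenet", "cifar-fs", "plant-pathology"]

# Hypothetical internal names for the 9 pipeline stages, in order
STAGES = [
    "dataset_preparation",
    "imbalance_induction",
    "diffusion_finetuning",
    "synthetic_generation",
    "dataset_construction",
    "classifier_training",
    "evaluation",
    "analysis",
    "documentation",
]


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Synthetic augmentation pipeline (sketch)")
    parser.add_argument("--experiment-name", default=None, help="Custom experiment name")
    parser.add_argument("--datasets", nargs="+", choices=DATASETS,
                        default=list(DATASETS), help="Datasets to run")
    parser.add_argument("--skip-diffusion", action="store_true",
                        help="Skip diffusion fine-tuning and synthetic generation")
    parser.add_argument("--skip-training", action="store_true",
                        help="Skip classifier training")
    parser.add_argument("--skip-evaluation", action="store_true",
                        help="Skip model evaluation")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args(argv)


def stages_to_run(args: argparse.Namespace) -> List[str]:
    """Drop the stages disabled by the --skip-* flags, keeping pipeline order."""
    skipped = set()
    if args.skip_diffusion:
        skipped |= {"diffusion_finetuning", "synthetic_generation"}
    if args.skip_training:
        skipped.add("classifier_training")
    if args.skip_evaluation:
        skipped.add("evaluation")
    return [stage for stage in STAGES if stage not in skipped]
```

For example, `stages_to_run(parse_args(["--skip-diffusion"]))` returns the seven remaining stages in their original order.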
Edit src/config.py to set the key parameters:
- datasets: Dataset selection
- minority_classes: Classes to turn into minority classes
- imbalance_ratio: Fraction of minority-class samples to keep
- diffusion_model: Base diffusion model
- num_epochs: Training epochs
- learning_rate: Learning rate
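One way to keep these parameters typed and validated is a dataclass. The field names follow the list above; the default values, the placeholder model identifier, and the validation rules are assumptions, so check `src/config.py` for the real ones.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExperimentConfig:
    # Field names mirror the documented parameters; defaults are placeholders
    datasets: List[str] = field(
        default_factory=lambda: ["ad-imagenet", "cifar-fs", "plant-pathology"]
    )
    minority_classes: List[str] = field(default_factory=list)
    imbalance_ratio: float = 0.1  # fraction of minority samples kept
    diffusion_model: str = "<base-diffusion-model-id>"  # hypothetical placeholder
    num_epochs: int = 10
    learning_rate: float = 1e-4

    def __post_init__(self) -> None:
        # Reject values that would make the experiment meaningless
        if not 0.0 < self.imbalance_ratio <= 1.0:
            raise ValueError("imbalance_ratio must be in (0, 1]")
        if self.num_epochs < 1:
            raise ValueError("num_epochs must be at least 1")
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
```

Validating in `__post_init__` means a bad value fails immediately when the config is built, rather than partway through a long diffusion or training run.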
Datasets:
- AdImageNet: 9,003 advertisement images (downloaded automatically)
- CIFAR-FS: place the FC100 pickle files in data/raw/cifar-fs/
- Plant Pathology: generated locally for testing
Outputs are saved to results/experiment_name/:
- Performance metrics (accuracy, F1, confusion matrices)
- Comparison visualizations and tables
- Class distribution analysis
- Synthetic image quality assessment
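The accuracy, F1, and confusion-matrix outputs can be computed without extra dependencies, as in the sketch below. Which F1 average the project reports is not stated here, so the sketch returns both per-class and macro F1 (an assumption); macro averaging weights every class equally and therefore reflects minority-class performance. Rows of the matrix are true classes, columns are predictions.

```python
from typing import Dict, List, Sequence


def confusion_matrix(y_true: Sequence[str], y_pred: Sequence[str],
                     labels: List[str]) -> List[List[int]]:
    """matrix[i][j] counts samples of true class labels[i] predicted as labels[j]."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def classification_metrics(y_true: Sequence[str], y_pred: Sequence[str],
                           labels: List[str]) -> Dict[str, object]:
    cm = confusion_matrix(y_true, y_pred, labels)
    n = len(labels)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(n))

    f1: Dict[str, float] = {}
    for i, label in enumerate(labels):
        tp = cm[i][i]
        fp = sum(cm[r][i] for r in range(n)) - tp  # predicted as label, but wrong
        fn = sum(cm[i]) - tp                        # true label, predicted otherwise
        denom = 2 * tp + fp + fn
        f1[label] = 2 * tp / denom if denom else 0.0

    return {
        "accuracy": correct / total if total else 0.0,
        "f1_per_class": f1,
        "macro_f1": sum(f1.values()) / n if n else 0.0,
        "confusion_matrix": cm,
    }
```

For `y_true = ["a", "a", "b", "b"]` and `y_pred = ["a", "b", "b", "b"]`, accuracy is 0.75 while the per-class F1 scores differ (2/3 for `a`, 0.8 for `b`), which is the kind of gap a single accuracy number hides.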
Hardware:
- RTX 4070 (12 GB VRAM): detected and optimized automatically
- Other GPUs: detected automatically; 8 GB+ VRAM recommended
- CPU: batch size and number of epochs are adjusted automatically
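Hardware-dependent settings could be chosen roughly as follows. The sketch uses PyTorch's `torch.cuda.is_available()` and `torch.cuda.get_device_properties()`, but the VRAM thresholds, batch sizes, and epoch counts are hypothetical; the project's own detection logic may differ. The threshold sits at 11 GiB because a 12 GB card such as the RTX 4070 may report slightly less than 12 GiB of total memory.

```python
from typing import Dict, Optional, Union


def detect_vram_gb() -> Optional[float]:
    """Total memory of CUDA device 0 in GiB, or None if no CUDA GPU/PyTorch is available."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_properties(0).total_memory / 1024**3


def choose_training_settings(vram_gb: Optional[float]) -> Dict[str, Union[str, int]]:
    """Map available VRAM to device, batch size, and epochs (hypothetical values)."""
    if vram_gb is None:
        # CPU fallback: smaller batches and fewer epochs keep runtimes manageable
        return {"device": "cpu", "batch_size": 8, "num_epochs": 2}
    if vram_gb >= 11:  # 12 GB class cards such as the RTX 4070
        return {"device": "cuda", "batch_size": 32, "num_epochs": 10}
    if vram_gb >= 8:
        return {"device": "cuda", "batch_size": 16, "num_epochs": 10}
    # Below the recommended 8 GB: shrink batches further
    return {"device": "cuda", "batch_size": 8, "num_epochs": 10}
```

Typical use is `settings = choose_training_settings(detect_vram_gb())`, which keeps detection and policy separate so the thresholds can be tuned without touching the PyTorch calls.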
# Quick check that your RTX 4070 is detected
python quick_gpu_check.py
# Full GPU setup (if needed)
python check_gpu.py
Troubleshooting:
- GPU not detected: run python check_gpu.py to install the CUDA build of PyTorch
- Import errors: run pip install --upgrade diffusers transformers accelerate
- Memory issues: settings are optimized automatically for the detected hardware
- Testing: use --skip-diffusion for faster runs