This project implements a class-conditional diffusion model (DDPM) trained on the CIFAR-10 dataset, together with an interactive Streamlit chat interface powered by Google Gemini.
Users can request images of CIFAR-10 classes in natural language, and the system generates them using a trained diffusion model.
- Conditional DDPM trained on CIFAR-10 (32×32 RGB images)
- U-Net with attention and sinusoidal time embeddings
- EMA (Exponential Moving Average) for stable sampling
- Streamlit chat UI
- Gemini function calling to trigger image generation
- Supports all 10 CIFAR-10 classes
Supported classes:
airplane, automobile, bird, cat, deer,
dog, frog, horse, ship, truck

A trained model checkpoint is available and can be downloaded from Kaggle:
Conditional Diffusion Pretrained Model
After downloading, create a folder named `models` in the project root (if it doesn't exist) and place the checkpoint there:

```bash
mkdir models
mv checkpoint.pt models/
```

This ensures that the Streamlit app (`main.py`) can correctly load the model for inference.
```
conditional-diffusion/
├── main.py   # Streamlit app + Gemini integration
└── model.py  # Diffusion model, UNet, EMA, training utilities
```
Training (offline)

- A conditional DDPM is trained on CIFAR-10.
- A U-Net predicts the noise $\epsilon$ given (x_t, t, class_label); see the training-step sketch after this list.
- EMA weights are maintained for higher-quality sampling.
- A checkpoint (`checkpoint.pt`) is saved.
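To make the objective concrete, here is a minimal sketch of one training step, assuming a linear beta schedule over 1000 steps and a noise-prediction model with signature `model(x_t, t, y)` (an assumption; this is not the exact code in `model.py`):

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)           # linear beta schedule (assumed endpoints)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # cumulative products of alphas

def training_step(model, x0, y):
    """x0: clean images in [-1, 1]; y: integer class labels."""
    t = torch.randint(0, T, (x0.size(0),), device=x0.device)
    eps = torch.randn_like(x0)
    a_bar = alpha_bars.to(x0.device)[t].view(-1, 1, 1, 1)
    # Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps
    eps_pred = model(x_t, t, y)  # the U-Net predicts the added noise
    return F.mse_loss(eps_pred, eps)
```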
Inference (online)
- The Streamlit app loads the EMA model from a checkpoint.
- User chats with a Gemini-powered assistant.
- If the user requests a valid CIFAR-10 class, Gemini calls `generate_cifar_image(label="cat")` (sketched after this list).
- The diffusion model generates and displays the image.
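The wiring between the chat and the sampler might look roughly like this. This is a minimal sketch using the google-genai SDK's automatic function calling; the model name and the stub body are assumptions, not the code in `main.py`:

```python
import os
from google import genai
from google.genai import types

client = genai.Client(api_key=os.getenv("API_KEY"))

def generate_cifar_image(label: str) -> str:
    """Generate a CIFAR-10 image of the given class.

    Stub for illustration; the real app would run the diffusion
    sampler and display the result in Streamlit.
    """
    return f"generated an image of a {label}"

# Automatic function calling: the SDK derives a schema from the
# function signature and lets the model decide when to call it.
response = client.models.generate_content(
    model="gemini-2.0-flash",  # model name is an assumption
    contents="Generate a cat",
    config=types.GenerateContentConfig(tools=[generate_cifar_image]),
)
print(response.text)
```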
```bash
git clone https://github.com/b14ucky/conditional-diffusion.git
cd conditional-diffusion
```

Create and activate a virtual environment:

```bash
python -m venv venv
source venv/bin/activate   # Linux / macOS
venv\Scripts\activate      # Windows
```

Install the dependencies:

```bash
pip install torch torchvision streamlit matplotlib tqdm python-dotenv google-genai
```

Make sure your PyTorch installation matches your CUDA setup if you're using a GPU.
Create a .env file in the project root:
```
API_KEY=your_google_gemini_api_key
```

This key is required for Gemini chat and function calling. The API key can be obtained here.
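The app can then read the key with python-dotenv, for example:

```python
import os
from dotenv import load_dotenv

load_dotenv()                   # reads .env from the project root
api_key = os.getenv("API_KEY")  # the key defined above
```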
```bash
streamlit run main.py
```

Then open your browser at http://localhost:8501.

Examples of valid prompts:
- "Generate a cat"
- "Show me a ship"
- "I want an image of an airplane"
Examples of invalid prompts:
- "Generate a dragon"
- "Make a 4K portrait of a person"
If the requested object is not part of CIFAR-10, the assistant will politely refuse.
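One simple way to enforce this is to validate the requested label against the class list before sampling. This is a hypothetical helper, not necessarily how `main.py` implements the check:

```python
CIFAR10_CLASSES = {
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
}

def is_valid_label(label: str) -> bool:
    """True only for the ten classes the model was trained on."""
    return label.strip().lower() in CIFAR10_CLASSES
```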
- U-Net with:
  - Residual blocks
  - Group normalization
  - Multi-head self-attention at selected resolutions
  - Sinusoidal time embeddings
  - Class embeddings injected into residual blocks
- DDPM with linear $\beta$ schedule
- 1000 diffusion steps
- Reverse process implemented manually (sketched after this list)
- EMA decay: 0.9999
- EMA weights used for sampling
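The manual reverse process amounts to DDPM ancestral sampling. A minimal sketch, again assuming a model with signature `model(x_t, t, y)` and the same assumed schedule endpoints:

```python
import torch

@torch.no_grad()
def sample(model, y, T=1000, device="cuda"):
    """Generate 32x32 RGB images conditioned on integer labels y."""
    betas = torch.linspace(1e-4, 0.02, T, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(y.size(0), 3, 32, 32, device=device)  # start from pure noise
    for t in reversed(range(T)):
        t_batch = torch.full((y.size(0),), t, device=device, dtype=torch.long)
        eps = model(x, t_batch, y)
        # Posterior mean: subtract the predicted noise contribution
        mean = (x - betas[t] / (1 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)
        else:
            x = mean
    return x.clamp(-1, 1)
```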
You can train the model yourself using `ModelTrainer` in `model.py`.

Example:

```python
from model import ModelTrainer

trainer = ModelTrainer(
    batch_size=64,
    time_steps=1000,
    lr=2e-5,
)

trainer.train(
    n_epochs=75,
    checkpoint_output_path="checkpoint.pt",
)
```

Training CIFAR-10 diffusion models is compute-intensive. A GPU is strongly recommended.
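The trainer maintains EMA weights with decay 0.9999 (see the architecture notes above). A generic version of the update rule, not necessarily how the repo's `EMA` wrapper is implemented:

```python
import copy
import torch

class SimpleEMA:
    """Exponential moving average of model parameters (generic sketch)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
```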
- `UNet` – Conditional U-Net backbone
- `ResBlock` – Residual blocks with label conditioning
- `Attention` – Multi-head self-attention
- `SinusoidalEmbeddings` – Time-step embeddings (see the sketch below)
- `DDPMScheduler` – Noise schedule
- `EMA` – Exponential Moving Average wrapper
- `LabelEncoder` – Maps class names → label tensors
- `ModelTrainer` – Training loop and checkpointing
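Sinusoidal time embeddings follow the standard transformer-style construction; a generic sketch (the repo's `SinusoidalEmbeddings` module may differ in details):

```python
import math
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Map integer timesteps (shape [B]) to embeddings (shape [B, dim], dim even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([args.sin(), args.cos()], dim=-1)
```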
- Image resolution fixed to 32×32
- Only CIFAR-10 classes supported
- Sampling is relatively slow (pure PyTorch DDPM)
- Not intended for photorealistic generation
This project is licensed under the MIT License.