Hi @nicolas-dufour, thanks for releasing this awesome repo with a detailed README and configs! I'm trying to reproduce RIN on CIFAR-10 with your code and got a similar Inception Score of 10.184 (vs. the original 10.3). Manual inspection of the generated images shows decent quality. However, my FID of 3.367 is worse than the 1.81 reported in the paper. I saw your discussion with the RIN authors, where you confirmed a successful reproduction in the end. What did I miss? Thanks.
Training command: python cad/train.py overrides=cifar10_rin
Evaluation command: python cad/test.py overrides=cifar10_rin computer.devices=1 logger.offline=False logger.project=RIN
Results:
Inception: 10.184
FID: 3.367
Accuracy: 0.953
Precision: 0.675
Recall: 0.643
Density: 0.956
Coverage: 0.829
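To put the gap in numbers, here is a small sketch in plain Python. The helper name `gap` is just for illustration; the values are the ones quoted in this post (my run vs. the reported figures).

```python
def gap(reproduced: float, reported: float) -> tuple[float, float]:
    """Return (absolute difference, difference relative to the reported value)."""
    absolute = reproduced - reported
    relative = absolute / reported
    return absolute, relative


# FID: lower is better, so a positive gap means my run is worse.
fid_abs, fid_rel = gap(reproduced=3.367, reported=1.81)

# Inception Score: higher is better, so a negative gap means my run is worse.
is_abs, is_rel = gap(reproduced=10.184, reported=10.3)

print(f"FID gap: {fid_abs:+.3f} ({fid_rel:+.1%})")
print(f"IS gap:  {is_abs:+.3f} ({is_rel:+.1%})")
```

So the Inception Score is within about 1% of the reported value, while the FID is off by a much larger relative margin, which is why I suspect a difference in the evaluation setup rather than in training.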
Link to wandb trace: https://wandb.ai/kevinxli/RIN/runs/msgj3d04
Link to wandb evaluation results: https://wandb.ai/kevinxli/RIN/runs/u8v6hpyt
Auto generated config:
model:
  optimizer:
    optim:
      _target_: utils.optimizers.Lamb
      lr: 0.003
      betas:
      - 0.9
      - 0.999
      weight_decay: 0.01
    exclude_ln_and_biases_from_weight_decay: true
  lr_scheduler:
    _partial_: true
    _target_: utils.lr_scheduler.WarmupCosineDecayLR
    warmup_steps: 10000
    total_steps: ${trainer.max_steps}
    rate: 0.8
  network:
    _target_: cad.models.networks.rin.RINClassCond
    data_size: ${data.data_resolution}
    data_dim: 256
    num_input_channels: 3
    num_latents: 128
    latents_dim: 512
    label_dim: ${data.label_dim}
    num_cond_tokens: ${data.num_cond_tokens}
    num_processing_layers: 2
    num_blocks: 3
    path_size: 2
    read_write_heads: 8
    compute_heads: 16
    latent_mlp_multiplier: 4
    data_mlp_multiplier: 2
    rw_dropout: 0.0
    compute_dropout: 0.1
    rw_stochastic_depth: 0
    compute_stochastic_depth: 0.1
    time_scaling: 1000.0
    noise_embedding_type: positional
    data_positional_embedding_type: learned
    weight_init: xavier_uniform
    bias_init: zeros
    use_cond_token: true
    use_biases: true
    concat_cond_token_to_latents: true
    use_cond_rin_block: false
    use_16_bits_layer_norm: false
  train_noise_scheduler:
    _target_: cad.models.schedulers.SigmoidScheduler
    start: -3
    end: 3
    tau: 0.9
    clip_min: 1.0e-09
  inference_noise_scheduler:
    _target_: cad.models.schedulers.CosineSchedulerSimple
    ns: 0.0002
    ds: 0.00025
  preconditioning:
    _target_: cad.models.preconditioning.DDPMPrecond
    num_latents: ${model.network.num_latents}
    latents_dim: ${model.network.latents_dim}
  data_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: image
    output_key_root: x_0
  cond_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: label
    output_key_root: label
    drop_labels: false
  postprocessing:
    _partial_: true
    _target_: utils.image_processing.remap_image_torch
  loss:
    _partial_: true
    _target_: cad.models.losses.DDPMLoss
    self_cond_rate: 0.9
    cond_drop_rate: 0.0
    conditioning_key: ${model.cond_preprocessing.output_key_root}
    resample_by_coherence: false
    sample_random_when_drop: false
  val_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddim.ddim_sampler
    num_steps: 250
    cfg_rate: ${model.cfg_rate}
  test_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddpm.ddpm_sampler
    num_steps: 1000
    cfg_rate: ${model.cfg_rate}
  uncond_conditioning:
    _target_: cad.utils.misc.dummy_value_loader
    value: 0.0
  vae_embedding_name_mean: null
  return_image: true
  name: RIN
  ema_decay: 0.9999
  start_ema_step: 0
  cfg_rate: 0.0
  channel_wise_normalisation: false
computer:
  devices: 1
  num_workers: 10
  progress_bar_refresh_rate: 2
  sync_batchnorm: false
  accelerator: gpu
  precision: 16-mixed
  strategy: auto
  num_nodes: 1
  eval_gpu_type: v100
data:
  train_aug:
    _target_: torchvision.transforms.Compose
    transforms:
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.RandomHorizontalFlip
      p: 0.5
    - _target_: torchvision.transforms.Normalize
      mean: 0.5
      std: 0.5
  val_aug:
    _target_: torchvision.transforms.Compose
    transforms:
    - _target_: torchvision.transforms.ToTensor
    - _target_: torchvision.transforms.Normalize
      mean: 0.5
      std: 0.5
  name: CIFAR-10
  type: class_conditional
  img_resolution: 32
  data_resolution: 32
  label_dim: 10
  num_cond_tokens: 1
  full_batch_size: 256
  in_channels: 3
  out_channels: 3
  train_instance:
    _partial_: true
    _target_: torchvision.datasets.CIFAR10
    root: ${data_dir}
    download: true
    train: true
    transform: ${data.train_aug}
    target_transform: ${data.target_transform}
  val_instance:
    _partial_: true
    _target_: torchvision.datasets.CIFAR10
    root: ${data_dir}
    download: true
    train: false
    transform: ${data.val_aug}
    target_transform: ${data.target_transform}
  target_transform:
    _target_: utils.one_hot_transform.OneHotTransform
    num_classes: ${data.label_dim}
  collate_fn:
    _target_: data.datamodule.collate_to_dict
    keys:
    - image
    - label
  train_dataset: ${data.train_instance}
  val_dataset: ${data.val_instance}
datamodule:
  _target_: data.datamodule.ImageDataModule
  train_dataset: ${data.train_dataset}
  val_dataset: ${data.val_dataset}
  full_batch_size: ${data.full_batch_size}
  collate_fn: ${data.collate_fn}
  num_workers: ${computer.num_workers}
  num_nodes: ${computer.num_nodes}
  num_devices: ${computer.devices}
trainer:
  _target_: pytorch_lightning.Trainer
  max_steps: 150000
  val_check_interval: 5000
  check_val_every_n_epoch: null
  devices: ${computer.devices}
  accelerator: ${computer.accelerator}
  strategy: ${computer.strategy}
  log_every_n_steps: 1
  num_nodes: ${computer.num_nodes}
  precision: ${computer.precision}
logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  save_dir: ${root_dir}/cad/wandb
  name: ${experiment_name}
  project: RIN
  log_model: false
  offline: false
checkpoints:
  _target_: callbacks.checkpoint_and_validate.ModelCheckpointValidate
  gpu_type: ${computer.eval_gpu_type}
  validate_when_not_on_cluster: false
  validate_when_on_cluster: false
  eval_set: train
  validate_conditional: true
  validate_unconditional: false
  validate_per_class_metrics: true
  shape:
  - ${model.network.num_input_channels}
  - ${data.data_resolution}
  - ${data.data_resolution}
  num_classes: ${data.label_dim}
  dataset_name: ${data.name}
  dirpath: ${root_dir}/cad/checkpoints/${experiment_name}
  filename: step_{step}
  monitor: val/loss_ema
  save_last: true
  save_top_k: -1
  enable_version_counter: false
  every_n_train_steps: 10000
  auto_insert_metric_name: false
progress_bar:
  _target_: pytorch_lightning.callbacks.TQDMProgressBar
  refresh_rate: ${computer.progress_bar_refresh_rate}
data_dir: ${root_dir}/cad/datasets
root_dir: ${hydra:runtime.cwd}
experiment_name_suffix: base
experiment_name: ${data.name}_${model.name}_${experiment_name_suffix}
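As a sanity check on the training budget this config implies, here is a rough sketch in plain Python. It assumes that `full_batch_size` is the global batch size per optimizer step (no gradient accumulation) and that training uses the standard 50,000-image CIFAR-10 training split; the variable names are only for illustration.

```python
# Values copied from the config above.
max_steps = 150_000        # trainer.max_steps
full_batch_size = 256      # data.full_batch_size (assumed to be the global batch per step)
warmup_steps = 10_000      # model.lr_scheduler.warmup_steps

# Size of the standard CIFAR-10 training split.
cifar10_train_images = 50_000

samples_seen = max_steps * full_batch_size
epochs = samples_seen / cifar10_train_images
warmup_fraction = warmup_steps / max_steps

print(f"samples seen:    {samples_seen:,}")
print(f"epochs:          {epochs:.0f}")
print(f"warmup fraction: {warmup_fraction:.1%}")
```

Under those assumptions the run sees 38.4M samples, i.e. 768 passes over the training set, with warmup covering roughly the first 6.7% of the steps. If the reference run used a different budget, that could be one place to look.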