Skip to content

Trouble reproducing CIFAR-10 RIN results #3

@AlienKevin

Description

@AlienKevin

Hi @nicolas-dufour, thanks for releasing this awesome repo with a detailed README and configs! I'm trying to reproduce RIN on CIFAR-10 with your code and got a similar Inception Score of 10.184 (vs. 10.3 originally reported). Manual inspection of the generated images shows decent quality. However, the FID of 3.367 is worse than the 1.81 reported in the paper. I saw your discussion with the RIN authors, where you confirmed a successful reproduction in the end. What did I miss? Thanks.

Training command: `python cad/train.py overrides=cifar10_rin`

Evaluation command: `python cad/test.py overrides=cifar10_rin computer.devices=1 logger.offline=False logger.project=RIN`

Results:
Inception: 10.184
FID: 3.367
Accuracy: 0.953
Precision: 0.675
Recall: 0.643
Density: 0.956
Coverage: 0.829

Link to wandb trace: https://wandb.ai/kevinxli/RIN/runs/msgj3d04
Link to wandb evaluation results: https://wandb.ai/kevinxli/RIN/runs/u8v6hpyt
Auto-generated config:

model:
  # Optimizer: LAMB with lr 3e-3 and weight decay 0.01.
  optimizer:
    optim:
      _target_: utils.optimizers.Lamb
      lr: 0.003
      betas: [0.9, 0.999]
      weight_decay: 0.01
    exclude_ln_and_biases_from_weight_decay: true

  # Warmup (10k steps) followed by cosine decay; the decay horizon is tied to
  # trainer.max_steps. NOTE(review): the meaning of `rate` is defined in
  # utils.lr_scheduler, which is not visible here.
  lr_scheduler:
    _partial_: true
    _target_: utils.lr_scheduler.WarmupCosineDecayLR
    warmup_steps: 10000
    total_steps: ${trainer.max_steps}
    rate: 0.8

  # Class-conditional RIN backbone: 3 blocks x 2 processing layers over
  # 128 latents of width 512, with a 256-wide data (interface) dimension.
  network:
    _target_: cad.models.networks.rin.RINClassCond
    data_size: ${data.data_resolution}
    data_dim: 256
    num_input_channels: 3
    num_latents: 128
    latents_dim: 512
    label_dim: ${data.label_dim}
    num_cond_tokens: ${data.num_cond_tokens}
    num_processing_layers: 2
    num_blocks: 3
    # NOTE(review): spelled `path_size`, not `patch_size`; presumably it must
    # match the RINClassCond constructor argument, so it is left unchanged.
    path_size: 2
    read_write_heads: 8
    compute_heads: 16
    latent_mlp_multiplier: 4
    data_mlp_multiplier: 2
    rw_dropout: 0.0
    compute_dropout: 0.1
    rw_stochastic_depth: 0
    compute_stochastic_depth: 0.1
    time_scaling: 1000.0
    noise_embedding_type: positional
    data_positional_embedding_type: learned
    weight_init: xavier_uniform
    bias_init: zeros
    use_cond_token: true
    use_biases: true
    concat_cond_token_to_latents: true
    use_cond_rin_block: false
    use_16_bits_layer_norm: false

  # Training uses a sigmoid noise schedule; sampling uses a separate cosine one.
  train_noise_scheduler:
    _target_: cad.models.schedulers.SigmoidScheduler
    start: -3
    end: 3
    tau: 0.9
    clip_min: 1.0e-09
  inference_noise_scheduler:
    _target_: cad.models.schedulers.CosineSchedulerSimple
    ns: 0.0002
    ds: 0.00025

  # Latent sizes are copied from the network section via interpolation.
  preconditioning:
    _target_: cad.models.preconditioning.DDPMPrecond
    num_latents: ${model.network.num_latents}
    latents_dim: ${model.network.latents_dim}

  # Batch keys consumed by the model: `image` feeds x_0, `label` feeds the
  # conditioning.
  data_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: image
    output_key_root: x_0
  cond_preprocessing:
    _target_: cad.models.preprocessing.PrecomputedPreconditioning
    input_key: label
    output_key_root: label
    drop_labels: false
  postprocessing:
    _partial_: true
    _target_: utils.image_processing.remap_image_torch

  # DDPM loss with self-conditioning rate 0.9. cond_drop_rate is 0.0, so
  # presumably labels are never dropped during training.
  loss:
    _partial_: true
    _target_: cad.models.losses.DDPMLoss
    self_cond_rate: 0.9
    cond_drop_rate: 0.0
    conditioning_key: ${model.cond_preprocessing.output_key_root}
    resample_by_coherence: false
    sample_random_when_drop: false

  # Validation samples with DDIM (250 steps); testing uses DDPM (1000 steps).
  # Both read model.cfg_rate (0.0 below).
  val_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddim.ddim_sampler
    num_steps: 250
    cfg_rate: ${model.cfg_rate}
  test_sampler:
    _partial_: true
    _target_: cad.models.samplers.ddpm.ddpm_sampler
    num_steps: 1000
    cfg_rate: ${model.cfg_rate}
  uncond_conditioning:
    _target_: cad.utils.misc.dummy_value_loader
    value: 0.0

  vae_embedding_name_mean: null
  return_image: true
  name: RIN
  # EMA decay 0.9999, starting from step 0.
  ema_decay: 0.9999
  start_ema_step: 0
  cfg_rate: 0.0
  channel_wise_normalisation: false
computer:
  # Hardware/runtime settings, interpolated into data.datamodule and trainer.
  devices: 1
  num_workers: 10
  progress_bar_refresh_rate: 2
  sync_batchnorm: false
  accelerator: gpu
  # Quoted because it starts with digits; the parsed value is the same string.
  precision: '16-mixed'
  strategy: auto
  num_nodes: 1
  # Read by checkpoints.gpu_type.
  eval_gpu_type: v100
data:
  # Training transforms: ToTensor, random horizontal flip (p=0.5), then
  # Normalize with mean/std 0.5.
  train_aug:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: torchvision.transforms.RandomHorizontalFlip
        p: 0.5
      - _target_: torchvision.transforms.Normalize
        mean: 0.5
        std: 0.5
  # Evaluation transforms: the same pipeline without the flip.
  val_aug:
    _target_: torchvision.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: torchvision.transforms.Normalize
        mean: 0.5
        std: 0.5

  # Dataset descriptors referenced elsewhere via ${data.*}.
  name: CIFAR-10
  type: class_conditional
  img_resolution: 32
  data_resolution: 32
  label_dim: 10
  num_cond_tokens: 1
  full_batch_size: 256
  in_channels: 3
  out_channels: 3

  # torchvision CIFAR10 train/test splits, instantiated as partials
  # (`_partial_: true`); labels go through the one-hot target_transform.
  train_instance:
    _partial_: true
    _target_: torchvision.datasets.CIFAR10
    root: ${data_dir}
    download: true
    train: true
    transform: ${data.train_aug}
    target_transform: ${data.target_transform}
  val_instance:
    _partial_: true
    _target_: torchvision.datasets.CIFAR10
    root: ${data_dir}
    download: true
    train: false
    transform: ${data.val_aug}
    target_transform: ${data.target_transform}
  target_transform:
    _target_: utils.one_hot_transform.OneHotTransform
    num_classes: ${data.label_dim}

  collate_fn:
    _target_: data.datamodule.collate_to_dict
    keys: [image, label]
  train_dataset: ${data.train_instance}
  val_dataset: ${data.val_instance}

  # Data module wiring; worker, node and device counts come from `computer`.
  datamodule:
    _target_: data.datamodule.ImageDataModule
    train_dataset: ${data.train_dataset}
    val_dataset: ${data.val_dataset}
    full_batch_size: ${data.full_batch_size}
    collate_fn: ${data.collate_fn}
    num_workers: ${computer.num_workers}
    num_nodes: ${computer.num_nodes}
    num_devices: ${computer.devices}
trainer:
  _target_: pytorch_lightning.Trainer
  # 150k training steps. With check_val_every_n_epoch null, Lightning counts
  # val_check_interval in training batches across epoch boundaries.
  max_steps: 150000
  val_check_interval: 5000
  check_val_every_n_epoch: null
  # Hardware settings mirror the `computer` section.
  devices: ${computer.devices}
  accelerator: ${computer.accelerator}
  strategy: ${computer.strategy}
  # Metrics are logged on every step.
  log_every_n_steps: 1
  num_nodes: ${computer.num_nodes}
  precision: ${computer.precision}
logger:
  # Weights & Biases logger; run files are stored under <root_dir>/cad/wandb.
  _target_: pytorch_lightning.loggers.WandbLogger
  save_dir: ${root_dir}/cad/wandb
  name: ${experiment_name}
  project: RIN
  # Model checkpoints are not uploaded to W&B.
  log_model: false
  offline: false
checkpoints:
  _target_: callbacks.checkpoint_and_validate.ModelCheckpointValidate
  gpu_type: ${computer.eval_gpu_type}
  # Validation from this callback is switched off both on and off cluster.
  validate_when_not_on_cluster: false
  validate_when_on_cluster: false
  # NOTE(review): presumably the split used as the metric reference — confirm.
  eval_set: train
  validate_conditional: true
  validate_unconditional: false
  validate_per_class_metrics: true
  # Sample shape (channels, height, width). This stays in block style because
  # `${...}` contains flow indicators and cannot sit unquoted in a flow list.
  shape:
    - ${model.network.num_input_channels}
    - ${data.data_resolution}
    - ${data.data_resolution}
  num_classes: ${data.label_dim}
  dataset_name: ${data.name}
  # Checkpoints are named step_<step> (auto_insert_metric_name is off), every
  # 10k training steps. save_top_k -1 keeps all of them, plus a `last` copy.
  dirpath: ${root_dir}/cad/checkpoints/${experiment_name}
  filename: step_{step}
  monitor: val/loss_ema
  save_last: true
  save_top_k: -1
  enable_version_counter: false
  every_n_train_steps: 10000
  auto_insert_metric_name: false
progress_bar:
  # TQDM progress bar; its refresh rate comes from the `computer` section.
  _target_: pytorch_lightning.callbacks.TQDMProgressBar
  refresh_rate: ${computer.progress_bar_refresh_rate}
# Dataset location, relative to root_dir.
data_dir: ${root_dir}/cad/datasets
# Resolved through Hydra's `hydra:` resolver from runtime.cwd.
root_dir: ${hydra:runtime.cwd}
experiment_name_suffix: base
# Composed as <data.name>_<model.name>_<suffix>; used for the W&B run name
# and the checkpoint directory.
experiment_name: ${data.name}_${model.name}_${experiment_name_suffix}

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions