
Enabling JAX as backend for the GAN training step #8

@mbarbetti


Starting from the v0.2.0 release, PIDGAN is compatible with the new multi-backend Keras 3.

Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of either JAX, TensorFlow, or PyTorch, and that unlocks brand new large-scale model training and deployment capabilities.

At the moment, GAN models can only be trained with the TensorFlow backend. For example, looking at lines 173-183 of the Keras 3-based GAN class, we have:

```python
def train_step(self, *args, **kwargs):
    if keras.backend.backend() == "tensorflow":
        return self._tf_train_step(*args, **kwargs)
    elif keras.backend.backend() == "torch":
        raise NotImplementedError("`train_step()` not implemented for the PyTorch backend")
    elif keras.backend.backend() == "jax":
        raise NotImplementedError("`train_step()` not implemented for the Jax backend")
```
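To give a rough idea of what the `"jax"` branch has to compute, here is a minimal sketch in pure JAX (no Keras) of one alternating GAN update: first the discriminator, then the generator against the updated discriminator. Unlike the TensorFlow path, which updates variables in place, JAX code is functional: parameters go in as a pytree and updated parameters come out. This matches the stateless contract that Keras 3 uses for its JAX backend, where `train_step` receives the model state together with the data and returns `(logs, state)`. The linear models, the binary cross-entropy losses, the plain SGD update and the names `gan_train_step`, `init_state` and `LEARNING_RATE` are hypothetical placeholders, not the PIDGAN implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical learning rate for the plain SGD update used in this sketch
LEARNING_RATE = 1e-2


def generator(g_params, z):
    # Toy linear generator: G(z) = z @ W + b
    return z @ g_params["w"] + g_params["b"]


def discriminator_logits(d_params, x):
    # Toy linear discriminator returning one logit per sample, shape (batch,)
    return (x @ d_params["w"] + d_params["b"]).squeeze(-1)


def d_loss_fn(d_params, g_params, x_real, z):
    # Binary cross-entropy with logits: real samples -> 1, fake samples -> 0
    x_fake = generator(g_params, z)
    real_logits = discriminator_logits(d_params, x_real)
    fake_logits = discriminator_logits(d_params, x_fake)
    return jnp.mean(jax.nn.softplus(-real_logits)) + jnp.mean(jax.nn.softplus(fake_logits))


def g_loss_fn(g_params, d_params, z):
    # Non-saturating generator loss: push fake samples towards label 1
    fake_logits = discriminator_logits(d_params, generator(g_params, z))
    return jnp.mean(jax.nn.softplus(-fake_logits))


@jax.jit
def gan_train_step(state, x_real, key):
    g_params, d_params = state
    key_d, key_g = jax.random.split(key)
    batch_size = x_real.shape[0]
    latent_dim = g_params["w"].shape[0]

    # 1) Discriminator update (gradients w.r.t. d_params only)
    z_d = jax.random.normal(key_d, (batch_size, latent_dim))
    d_loss, d_grads = jax.value_and_grad(d_loss_fn)(d_params, g_params, x_real, z_d)
    d_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, d_params, d_grads)

    # 2) Generator update against the freshly updated discriminator
    z_g = jax.random.normal(key_g, (batch_size, latent_dim))
    g_loss, g_grads = jax.value_and_grad(g_loss_fn)(g_params, d_params, z_g)
    g_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, g_params, g_grads)

    # Return logs and the new state instead of mutating anything
    logs = {"d_loss": d_loss, "g_loss": g_loss}
    return logs, (g_params, d_params)


def init_state(key, latent_dim=4, data_dim=2):
    kg, kd = jax.random.split(key)
    g_params = {"w": 0.1 * jax.random.normal(kg, (latent_dim, data_dim)), "b": jnp.zeros(data_dim)}
    d_params = {"w": 0.1 * jax.random.normal(kd, (data_dim, 1)), "b": jnp.zeros(1)}
    return g_params, d_params


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    key, init_key, data_key = jax.random.split(key, 3)
    state = init_state(init_key)
    x_real = 2.0 + jax.random.normal(data_key, (64, 2))
    for step in range(5):
        key, step_key = jax.random.split(key)
        logs, state = gan_train_step(state, x_real, step_key)
        print(step, float(logs["d_loss"]), float(logs["g_loss"]))
```

In an actual Keras 3 port, the hand-written SGD update would presumably be replaced by each optimizer's `stateless_apply()` and the forward passes by the models' `stateless_call()`, with the generator and discriminator variables packed into the `state` tuple.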

The goal of this issue is to implement `train_step()` for the JAX backend as well. In addition to the "plain" training step, the Lipschitz regularization functions should also be adapted to work with the JAX backend.
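For the Lipschitz regularization, the nested `tf.GradientTape` typically used in TensorFlow maps onto composed `jax.grad` calls. The sketch below assumes a WGAN-GP-style two-sided gradient penalty evaluated on random interpolations between real and generated samples; the names `gradient_penalty` and `linear_critic` and the `lipschitz_target` parameter are hypothetical, and the actual PIDGAN regularizers may use a different form.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(critic_fn, critic_params, x_real, x_fake, key, lipschitz_target=1.0):
    batch_size = x_real.shape[0]
    # One interpolation coefficient per sample, broadcast over the feature axes
    eps = jax.random.uniform(key, (batch_size,) + (1,) * (x_real.ndim - 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake

    # Per-sample input gradients via the "sum trick": valid as long as the critic
    # treats samples independently (e.g. no batch normalization across the batch)
    grads = jax.grad(lambda x: jnp.sum(critic_fn(critic_params, x)))(x_hat)
    grads = grads.reshape(batch_size, -1)

    # Small constant keeps the square root differentiable at zero gradient
    grad_norms = jnp.sqrt(jnp.sum(grads**2, axis=-1) + 1e-12)
    return jnp.mean((grad_norms - lipschitz_target) ** 2)


def linear_critic(params, x):
    # Toy critic f(x) = x @ w, whose input gradient is w for every sample
    return x @ params["w"]


if __name__ == "__main__":
    params = {"w": jnp.array([3.0, 4.0])}
    x_real = jnp.ones((8, 2))
    x_fake = jnp.zeros((8, 2))
    key = jax.random.PRNGKey(0)
    # ||grad_x f|| = ||w|| = 5 everywhere, so the penalty is (5 - 1)^2 = 16 up to rounding
    print(gradient_penalty(linear_critic, params, x_real, x_fake, key))
    # Second-order differentiation: gradient of the penalty w.r.t. the critic weights
    print(jax.grad(lambda p: gradient_penalty(linear_critic, p, x_real, x_fake, key))(params))
```

Because the penalty is itself differentiable with respect to `critic_params`, it can be added to the critic loss inside the function passed to `jax.value_and_grad`, so no separate tape is needed.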

Metadata

Assignees: No one assigned

Labels: enhancement, help wanted, python
