Simple Intro to Conditional GANs with TorchFusion and PyTorch


Humans are very good at recognizing things and at creating new things. For a long time, we have worked on teaching computers to emulate the human ability to recognize things, but the ability to create new things eluded artificial intelligence systems. That changed in 2014, when Ian Goodfellow invented Generative Adversarial Networks (GANs). In this post, we go through a basic overview of Generative Adversarial Networks and use them to generate images of specific digits.

Overview of Generative Adversarial Networks

Imagine you are an artist trying to draw a picture of Obama so realistic that it fools a judge into thinking it is a real picture. The first time you try, the judge easily detects that your picture is fake; you try again and again until the judge is fooled into thinking the picture is real. Generative Adversarial Networks work the same way. They consist of two models:
A Generator that draws images and a Discriminator that attempts to distinguish between real images and the images drawn by the generator.
In a sense, the two compete with each other: the generator is trained to fool the discriminator, while the discriminator is trained to correctly tell apart which images are real and which are generated. Ideally, the generator eventually becomes so good that the discriminator can no longer tell real and generated images apart.
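The competition described above is usually written as a minimax game. In Goodfellow's original formulation, D(x) is the discriminator's estimated probability that an image x is real, G(z) is the image the generator draws from a random noise vector z, p_data is the distribution of real images and p_z is the noise distribution. The discriminator tries to maximize this value while the generator tries to minimize it:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The first term rewards the discriminator for scoring real images close to 1; the second rewards it for scoring generated images close to 0, which is exactly what the generator works against.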
Below are samples created by a GAN Generator.
GANs fall into two general classes: unconditional GANs, which generate images of random classes, and conditional GANs, which generate images of a specified class. In this tutorial, we use conditional GANs, since they let us specify what we want to generate.
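One common way to make a GAN conditional is to feed the class label to the generator alongside its random noise. The sketch below, which is an illustration and not TorchFusion's code (its networks may condition differently, for example through embeddings or projection), one-hot encodes the digit labels and concatenates them with a 100-dimensional noise vector. The helper name `conditional_input` and the noise size are assumptions.

```python
import torch
import torch.nn.functional as F


def conditional_input(labels: torch.Tensor, latent_size: int = 100, num_classes: int = 10) -> torch.Tensor:
    """Build generator inputs by joining random noise with one-hot class labels.

    Hypothetical helper for illustration; not part of TorchFusion.
    """
    # Random noise gives each sample its variety (stroke width, slant, ...).
    z = torch.randn(labels.size(0), latent_size)
    # The one-hot label tells the generator which digit to draw.
    y = F.one_hot(labels, num_classes=num_classes).float()
    # Shape: (batch, latent_size + num_classes)
    return torch.cat([z, y], dim=1)


labels = torch.tensor([6, 6, 3])
inputs = conditional_input(labels)
print(inputs.shape)  # torch.Size([3, 110])
```

The same label is also given to the discriminator, so it judges not just "is this real?" but "is this a real image of this digit?".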

Tools Setup

Training GANs is usually complicated, but thanks to TorchFusion, a research framework built on PyTorch, the process is simple and straightforward.

Install TorchFusion via PyPI

pip3 install torchfusion

Install PyTorch

If you don’t have PyTorch installed already, head over to pytorch.org for the latest PyTorch install binaries.
Now you are fully set up!
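Before moving on, you can confirm that both packages are importable with a short check like the one below. It is a convenience sketch: `check_setup` is a hypothetical helper, not part of PyTorch or TorchFusion.

```python
import importlib.util


def check_setup(packages=("torch", "torchfusion")):
    """Map each package name to whether it can be imported in this environment."""
    # find_spec returns None when a top-level package is not installed.
    return {name: importlib.util.find_spec(name) is not None for name in packages}


status = check_setup()
for name, ok in status.items():
    print(f"{name}: {'installed' if ok else 'missing'}")

if status["torch"]:
    import torch
    # Training on a GPU is much faster but not required.
    print("CUDA available:", torch.cuda.is_available())
```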
Next, import a couple of classes.
Define the generator and discriminator networks.
In the above, we specify the resolution of the images to be generated as 1 x 32 x 32.
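The post's own code for this step is not included here. As a stand-in, here is a minimal sketch in plain PyTorch of a DCGAN-style conditional generator and discriminator for 1 x 32 x 32 images. The layer sizes, the 100-dimensional noise vector and the one-hot label conditioning are assumptions, not TorchFusion's architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed sizes: 1 x 32 x 32 images, 10 digit classes, 100-dimensional noise.
NUM_CLASSES = 10
LATENT_SIZE = 100


class Generator(nn.Module):
    """DCGAN-style conditional generator: (noise, label) -> 1 x 32 x 32 image."""

    def __init__(self, latent_size=LATENT_SIZE, num_classes=NUM_CLASSES):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(
            # (latent + classes) x 1 x 1 -> 256 x 4 x 4
            nn.ConvTranspose2d(latent_size + num_classes, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256), nn.ReLU(True),
            # -> 128 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(True),
            # -> 64 x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(True),
            # -> 1 x 32 x 32, pixel values in [-1, 1]
            nn.ConvTranspose2d(64, 1, 4, 2, 1), nn.Tanh(),
        )

    def forward(self, z, labels):
        # Condition on the digit by appending its one-hot label to the noise.
        y = F.one_hot(labels, self.num_classes).float()
        x = torch.cat([z, y], dim=1)
        return self.net(x[:, :, None, None])


class Discriminator(nn.Module):
    """Conditional discriminator: (image, label) -> one raw real/fake logit."""

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(
            # (1 + classes) x 32 x 32 -> 64 x 16 x 16
            nn.Conv2d(1 + num_classes, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            # -> 128 x 8 x 8
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            # -> 256 x 4 x 4
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True),
            # -> 1 x 1 x 1 (no sigmoid; pair with a logits-based loss)
            nn.Conv2d(256, 1, 4, 1, 0),
        )

    def forward(self, images, labels):
        # Turn each one-hot label into constant feature maps stacked onto the image.
        y = F.one_hot(labels, self.num_classes).float()
        y = y[:, :, None, None].expand(-1, -1, images.size(2), images.size(3))
        return self.net(torch.cat([images, y], dim=1)).view(-1)


G, D = Generator(), Discriminator()
z = torch.randn(8, LATENT_SIZE)
labels = torch.randint(0, NUM_CLASSES, (8,))
fake = G(z, labels)
print(fake.shape)             # torch.Size([8, 1, 32, 32])
print(D(fake, labels).shape)  # torch.Size([8])
```

Each transposed convolution doubles the spatial size (4, 8, 16, 32) while each strided convolution in the discriminator halves it, which is how the 1 x 32 x 32 resolution is fixed by the architecture.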
Set up the optimizers for both the generator and discriminator models.
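A minimal sketch of this step with PyTorch's Adam optimizer follows. The learning rate of 0.0002 and betas of (0.5, 0.999) are the values popularized by DCGAN and are an assumption here, not necessarily the post's settings. The two small `nn.Sequential` models are placeholders so the snippet runs on its own; in practice you would pass the parameters of your real generator and discriminator.

```python
import torch.nn as nn
from torch.optim import Adam

# Placeholder networks (hypothetical) so the snippet is self-contained.
G = nn.Sequential(nn.Linear(110, 32 * 32), nn.Tanh())
D = nn.Sequential(nn.Linear(32 * 32 + 10, 1))

# One optimizer per network: each model is updated only by its own loss.
g_optim = Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optim = Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
```

Keeping the optimizers separate is what lets training alternate between a discriminator update and a generator update without one model's step touching the other's weights.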
Now we need to load a dataset of real images; the generator will learn to draw samples that resemble them. In this case, we use MNIST.
Below, we create a Learner. TorchFusion has various learners, each specialized for a different purpose.
And now, we can call the train function to train both models.
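Conceptually, each training iteration alternates a discriminator update and a generator update. The sketch below shows one such step in plain PyTorch. It uses the standard non-saturating binary cross-entropy GAN loss, swapped in for illustration; TorchFusion's learners implement their own loss variants, and this is not their code. `train_step`, `TinyG` and `TinyD` are hypothetical names, and the tiny networks exist only to make the snippet runnable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_step(G, D, g_optim, d_optim, real, labels, latent_size=100):
    """One conditional GAN update; assumes D returns raw logits."""
    batch = real.size(0)
    ones = torch.ones(batch, device=real.device)
    zeros = torch.zeros(batch, device=real.device)

    # 1) Discriminator: push real images toward 1 and generated ones toward 0.
    z = torch.randn(batch, latent_size, device=real.device)
    fake = G(z, labels)
    d_loss = (F.binary_cross_entropy_with_logits(D(real, labels), ones)
              + F.binary_cross_entropy_with_logits(D(fake.detach(), labels), zeros))
    d_optim.zero_grad()
    d_loss.backward()
    d_optim.step()

    # 2) Generator: try to make the (just updated) discriminator output 1.
    g_loss = F.binary_cross_entropy_with_logits(D(fake, labels), ones)
    g_optim.zero_grad()
    g_loss.backward()
    g_optim.step()
    return d_loss.item(), g_loss.item()


class TinyG(nn.Module):
    def __init__(self, latent_size=100, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.fc = nn.Linear(latent_size + num_classes, 32 * 32)

    def forward(self, z, labels):
        y = F.one_hot(labels, self.num_classes).float()
        return torch.tanh(self.fc(torch.cat([z, y], 1))).view(-1, 1, 32, 32)


class TinyD(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.fc = nn.Linear(32 * 32 + num_classes, 1)

    def forward(self, x, labels):
        y = F.one_hot(labels, self.num_classes).float()
        return self.fc(torch.cat([x.flatten(1), y], 1)).view(-1)


G, D = TinyG(), TinyD()
g_optim = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_optim = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
real = torch.rand(16, 1, 32, 32) * 2 - 1  # stand-in for a batch of MNIST images
labels = torch.randint(0, 10, (16,))
d_loss, g_loss = train_step(G, D, g_optim, d_optim, real, labels)
```

Detaching `fake` in the discriminator step keeps that update from flowing into the generator; the generator then gets its own gradient through a fresh discriminator pass.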
By setting save_outputs_interval to 500, the learner displays sample generated outputs after every 500 batch iterations.
Here is the full code
After just 20 epochs of training, this generates the image below:
Now to the most exciting part: using your trained model, you can easily generate new images of specific digits.
In the code below, we generate a new image of the digit 6; you can specify any digit from 0 to 9.
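In plain PyTorch, generating a chosen digit amounts to pairing fresh noise with a label tensor filled with that digit. The sketch below assumes a conditional generator called as `G(z, labels)` with Tanh outputs in [-1, 1] and 100-dimensional noise; `generate_digit` and `ToyG` are hypothetical names, and the checkpoint path in the comments is only an example.

```python
import torch
import torch.nn as nn


def generate_digit(G, digit, num_images=1, latent_size=100, num_classes=10):
    """Return `num_images` images of `digit` from a conditional generator, scaled to [0, 1]."""
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be between 0 and {num_classes - 1}, got {digit}")
    G.eval()  # inference mode: BatchNorm uses running statistics
    z = torch.randn(num_images, latent_size)
    labels = torch.full((num_images,), digit, dtype=torch.long)
    with torch.no_grad():
        images = G(z, labels)
    return (images + 1) / 2  # map Tanh range [-1, 1] to [0, 1]


class ToyG(nn.Module):
    """Hypothetical stand-in; use your trained generator instead, e.g.
    G.load_state_dict(torch.load("generator.pth"))  # example path
    """

    def forward(self, z, labels):
        return torch.zeros(z.size(0), 1, 32, 32)


six = generate_digit(ToyG(), 6)
print(six.shape)  # torch.Size([1, 1, 32, 32])
# To write it to disk: torchvision.utils.save_image(six, "six.png")
```

Changing the noise while keeping the label fixed gives different handwriting styles of the same digit.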
Generative Adversarial Networks are an exciting field of research, and TorchFusion makes working with them very simple through well-optimized implementations of popular GAN algorithms.
TorchFusion is developed and maintained by Moses Olafenwa and me (The AI Commons Team) as part of our efforts to democratize artificial intelligence and make it accessible to every person and organization on the planet.
The official TorchFusion repository is https://github.com/johnolafenwa/TorchFusion
Tutorials and documentation for TorchFusion are available at https://torchfusion.readthedocs.io
You can always reach me on Twitter at @johnolafenwa

This content was originally published here.