In his Deep Learning book, Chollet builds a GAN over a bunch of frog images from CIFAR. It’s a great, succinct bit of code, so I’m not going to copy-paste it all here. Instead, I’ll comment on the things that stand out as non-obvious and important, section by section.
Building the Generator
Everything starts off by defining the generator_input relative to some latent_dim, the dimension of the latent space that we’re going to use to generate noise. He defines this to be 32; I’ve seen it as high as 100-200 for larger image applications.
The first hidden layer is a large Dense layer with one unit for every entry of a (smaller height, smaller width, lots of filters) volume. This learns to generate pixel-level representations from our latent space. We reshape its output into an actual image-shaped tensor before doing any convolution.
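A minimal Keras sketch of that first step, assuming TensorFlow’s bundled Keras. Only latent_dim = 32 comes from the text; the 16 x 16 x 128 volume is an illustrative assumption, and activations are left out to keep the focus on the Dense-then-Reshape pattern.

```python
from tensorflow import keras

layers = keras.layers

latent_dim = 32                          # value quoted in the text
height, width, channels = 16, 16, 128    # assumed "smaller" volume, for illustration only

generator_input = keras.Input(shape=(latent_dim,))
# One Dense unit per entry of the small feature volume
x = layers.Dense(height * width * channels)(generator_input)
# Reshape the flat vector into an image-shaped tensor before any convolution
x = layers.Reshape((height, width, channels))(x)

stem = keras.Model(generator_input, x)
print(stem.output_shape)
```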
Each layer of our Generator (and later the Discriminator) uses a LeakyReLU() as its activation function. This is because sparsity (the many, many zeros that a plain ReLU() produces for negative inputs) is an undesirable trait in a GAN: fewer non-zero activations and gradients means fewer levers the Generator can pull to fool the Discriminator.
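To make the sparsity point concrete, here is a small plain-Python comparison of the two gradients. The slope of 0.2 on the negative side is an assumed illustrative value, not one taken from the text.

```python
def relu_grad(x):
    """Gradient of ReLU: exactly zero for every negative input."""
    return 1.0 if x > 0 else 0.0


def leaky_relu_grad(x, slope=0.2):
    """Gradient of LeakyReLU: a small non-zero slope (assumed 0.2) for negative inputs."""
    return 1.0 if x > 0 else slope


inputs = [-2.0, -0.5, 0.5, 2.0]
print([relu_grad(x) for x in inputs])        # [0.0, 0.0, 1.0, 1.0]
print([leaky_relu_grad(x) for x in inputs])  # [0.2, 0.2, 1.0, 1.0]
```

With the leaky version, negative pre-activations still pass some gradient back, so no unit goes completely silent.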
The Conv2DTranspose() layer threw me for a loop.
The basic idea of the layer is to upsample our small, random images into something the same size as our real data.
Particularly, this YouTube video provided a nice visual representation of what the reverse of a Convolution operation accomplishes.
Moreover, this repo has some great graphics to cement the intuition for Transpose Convolution– particularly when it comes to understanding what “stride” means in this context.
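As a quick sanity check on the upsampling, the sketch below pushes an assumed 16 x 16 feature map through a single strided Conv2DTranspose and inspects the output shape. The channel count and kernel size are illustrative assumptions; the 32 x 32 target matches CIFAR’s image size.

```python
from tensorflow import keras

# Assumed input: a 16x16 feature map with 8 channels
small = keras.Input(shape=(16, 16, 8))
# strides=2 with padding="same" doubles height and width
up = keras.layers.Conv2DTranspose(8, kernel_size=4, strides=2, padding="same")(small)

upsampler = keras.Model(small, up)
print(upsampler.output_shape)
```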
strides=2 means that you insert a blank space between each pixel in the original image. This (with
padding=same) is how we achieve upsampling from
One more note: Chollet has a great nugget of wisdom with respect to strides for GANs:
In generated images, it is common to see “checkerboard artifacts” caused by unequal coverage of the pixel space in the generator. To fix this, we use a kernel size that is divisible by the stride size, whenever we use a strided Conv2DTranspose or Conv2D in both the generator and discriminator.
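The “unequal coverage” idea can be checked with a little counting. This plain-Python sketch is a one-dimensional simplification of my own, not code from the book: it counts how many kernel taps land on each output position of a strided transposed convolution.

```python
def coverage(n_in, kernel, stride):
    """Count the kernel taps hitting each output position of a 1-D transposed convolution."""
    n_out = (n_in - 1) * stride + kernel  # full, uncropped output length
    counts = [0] * n_out
    for i in range(n_in):          # every input pixel...
        for k in range(kernel):    # ...stamps the whole kernel into the output
            counts[i * stride + k] += 1
    return counts


print(coverage(4, kernel=3, stride=2))  # [1, 1, 2, 1, 2, 1, 2, 1, 1]
print(coverage(4, kernel=4, stride=2))  # [1, 1, 2, 2, 2, 2, 2, 2, 1, 1]
```

With a kernel of 3 and a stride of 2, interior positions alternate between two taps and one, which is the checkerboard pattern in one dimension. With a kernel of 4, every interior position gets the same two taps.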
Building the Discriminator
As with the Generator, we avoid sparsity by utilizing
LeakyReLU() layers. We also introduce randomness to the learning by tossing a healthy amount of
Dropout() before the final
Dense() layer, used to do the
Building the GAN
One tricky thing to note is the order that we compile our
The general pseudocode of training is:
- Train the Discriminator by itself
- Pass those gradients back to the Generator
- Train the Generator
So you’ll notice that when constructing the networks that we use to build the GAN, he compiles the Discriminator, but not the Generator.
This is because the model.compile() step freezes the network definition (particularly all of the layer.trainable values) as-is.
Only after the Discriminator is compiled do we turn around and say discriminator.trainable = False. Then we build the GAN by piping the noise into the Generator and the Generator into the Discriminator. Only after we specify that the overall model takes noise from the latent space and outputs the prediction on an image do we compile the network in its entirety.
This compiled object will have both models connected and talking to one another via gradients, but only the Generator parameters will be trainable.
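The ordering can be sketched as follows. The two networks here are deliberately tiny stand-ins with assumed layer sizes and optimizer choices, not Chollet’s architectures. The final check only confirms which weights the combined model treats as trainable; exactly how compile() snapshots the trainable flags may differ between Keras versions.

```python
from tensorflow import keras

layers = keras.layers
latent_dim = 32  # value quoted in the text

# Tiny stand-in Generator: latent vector -> 32x32x3 image (layer sizes are assumptions)
g_in = keras.Input(shape=(latent_dim,))
x = layers.Dense(16 * 16 * 8)(g_in)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 8))(x)
x = layers.Conv2DTranspose(8, 4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
g_out = layers.Conv2D(3, 3, padding="same", activation="tanh")(x)
generator = keras.Model(g_in, g_out)

# Tiny stand-in Discriminator: image -> probability
d_in = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(8, 4, strides=2, padding="same")(d_in)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.4)(x)
d_out = layers.Dense(1, activation="sigmoid")(x)
discriminator = keras.Model(d_in, d_out)

# 1. Compile the Discriminator while its weights are still trainable
discriminator.compile(optimizer="rmsprop", loss="binary_crossentropy")

# 2. Only then mark it non-trainable for use inside the combined model
discriminator.trainable = False

# 3. Pipe noise -> Generator -> Discriminator and compile the whole thing
gan_in = keras.Input(shape=(latent_dim,))
gan_out = discriminator(generator(gan_in))
gan = keras.Model(gan_in, gan_out)
gan.compile(optimizer="rmsprop", loss="binary_crossentropy")

# Only the Generator's weights show up as trainable in the combined model
print(len(gan.trainable_weights), len(generator.trainable_weights))
```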
He does a lot of data prep here, particularly regarding:
- Loading the dataset
- Handling batch size
That’s all pretty straightforward and not specific to GANs.
Train the Discriminator
Before we can train our Discriminator to sort real from fake images, we need some fake images, so we:
- Sample random noise from the latent space
- Use our Generator to make fake images via
Where we’ll diverge (as most examples I’ve seen do) is in not concatenating the real and fake images into one list. Keeping them separate lets us better monitor the loss and accuracy for both Networks over both groups.
Then we’ll leverage the same “0.05 * normal_noise” trick that he uses, which puts a small amount of random noise on the labels. This keeps our Discriminator from learning TOO quickly.
Then we train the Discriminator with discriminator.train_on_batch() (which only goes through one training step).
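Here is a rough NumPy sketch of preparing the two separate batches. label_batch is a hypothetical helper of my own. It assumes the 0.05 * noise term is added to the labels, and that real images are labelled 1 and fakes 0. Random arrays stand in for the CIFAR frogs and for the Generator’s output.

```python
import numpy as np


def label_batch(images, target, noise_scale=0.05, rng=None):
    """Pair a batch of images with noisy labels centred on `target` (1 = real, 0 = fake)."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.full((len(images), 1), target, dtype="float32")
    # Assumption: the small noise term is added to the labels
    labels = labels + noise_scale * rng.normal(size=labels.shape).astype("float32")
    return images, labels


rng = np.random.default_rng(0)
real_images = rng.random((8, 32, 32, 3)).astype("float32")  # stand-in for real frogs
fake_images = rng.random((8, 32, 32, 3)).astype("float32")  # stand-in for generated images

real_batch = label_batch(real_images, 1.0, rng=rng)
fake_batch = label_batch(fake_images, 0.0, rng=rng)

# With a compiled Keras Discriminator, each group would get its own step, e.g.
#   d_loss_real = discriminator.train_on_batch(*real_batch)
#   d_loss_fake = discriminator.train_on_batch(*fake_batch)
print(real_batch[1].shape, fake_batch[1].shape)
```

Keeping the two calls apart is what makes it possible to log a loss for the real group and a loss for the fake group independently.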
Train the Generator
Now, we’re going to generate all fake images, lie and say they’re real, and send them over to the Discriminator.
We do this with gan.train_on_batch(), passing in latent noise and a bunch of True’s. Because the Generator is hooked into the Discriminator, all of the gradients from prediction error flow backward and inform the Generator. Recall that every single image sent over was fake, but we said it was real. Let’s consider both cases:
- The Discriminator says “This is real”: It matches the label we sent over, reinforcing this kind of image being sent in the future
- The Discriminator says “This is fake”: Whatever the Generator thought it’d learned about the underlying data distribution, pr_data, was clearly not representative of how we get our images. Try something else.
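The two cases can be put in numbers with a small plain-Python sketch. It assumes a binary cross-entropy loss and that “real” is encoded as 1, matching the “bunch of True’s” above. generator_loss is a hypothetical helper for a single prediction, not part of Keras.

```python
import math


def generator_loss(d_output, eps=1e-7):
    """Binary cross-entropy against the label 1 ("real") for one Discriminator output."""
    p = min(max(d_output, eps), 1 - eps)  # clip to avoid log(0)
    return -math.log(p)


fooled = generator_loss(0.9)      # Discriminator says "this is real"
not_fooled = generator_loss(0.1)  # Discriminator says "this is fake"

print(round(fooled, 3))      # 0.105
print(round(not_fooled, 3))  # 2.303
```

A fooled Discriminator yields a small loss and therefore a gentle update, while a confident “fake” verdict yields a large loss and a strong push toward something different.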
In either case, this is a win-win for our Generator… assuming that there’s some healthy mix of fooling and not fooling the Discriminator.
And therein lies the rub.