Simple Generative Adversarial Network (GAN) with Keras

01 Feb 2018


This is a simple example to illustrate the basic idea behind Generative Adversarial Networks (GANs).

I tried to eliminate all the bells and whistles from the most common implementations I found online, so that it's easier to focus on the main idea.

As a consequence, the results are not optimised and are not as nice as they could be with more complex layers and architectures.

For the TL;DR people the repo is here and the full code is at the end of the page.

The goal is to train a neural network (NN) to generate new images that are indistinguishable from the ones contained in the database (MNIST in this case), without making identical copies. The NN learns the “essence” of the images and is then able to create new ones starting from a random array of numbers.

The main idea is to have two separate neural networks, a generator and a discriminator, locked in a competition with each other.

The generator creates new images that are as similar as possible to the pictures in the database. The discriminator tries to tell whether the pictures it sees are originals or synthetic ones. So how does it work? Let's start from the building blocks.

This is the generator:

def generator(self):

    model = Sequential()
    model.add(Dense(256, input_shape=(100,)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(self.WIDTH * self.HEIGHT * self.CHANNELS, activation='tanh'))

    model.add(Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS)))

    return model

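As you can see, this is as simple as possible: three fully connected hidden layers, each followed by a LeakyReLU activation and batch normalisation, then a tanh output layer that is reshaped into an image.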
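To get a feel for the size of this network, the number of weights can be worked out by hand. The sketch below is plain Python, not Keras. The helper names `dense_params` and `batchnorm_params` are made up for this illustration, and it assumes one bias per Dense unit and four values per BatchNormalization feature (gamma, beta, moving mean, moving variance).

```python
def dense_params(n_in, n_out):
    """Weights plus one bias per output unit of a fully connected layer."""
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    """gamma, beta, moving mean and moving variance for each feature."""
    return 4 * n_features


WIDTH, HEIGHT, CHANNELS = 28, 28, 1
layer_sizes = [100, 256, 512, 1024]  # noise vector, then the three hidden layers

total = 0
for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    total += dense_params(n_in, n_out) + batchnorm_params(n_out)

# Output layer: one tanh unit per pixel, no batch normalisation after it.
total += dense_params(layer_sizes[-1], WIDTH * HEIGHT * CHANNELS)

print(total)  # 1493520
```

So even this bare-bones generator has roughly 1.5 million parameters, most of them in the last two Dense layers.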

This is the discriminator:

def discriminator(self):

    model = Sequential()
    model.add(Flatten(input_shape=self.SHAPE))
    model.add(Dense(self.WIDTH * self.HEIGHT * self.CHANNELS))
    model.add(LeakyReLU(alpha=0.2))
    # // keeps the layer size an integer under Python 3
    model.add(Dense((self.WIDTH * self.HEIGHT * self.CHANNELS) // 2))
    model.add(LeakyReLU(alpha=0.2))

    # single sigmoid output: probability that the input image is real
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    return model

The discriminator is as simple as the generator.
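The two activations used in the discriminator are easy to write out by hand. This is a plain-Python sketch of the formulas, not the Keras implementation; `leaky_relu` and `sigmoid` are names I made up for the illustration.

```python
import math


def leaky_relu(x, alpha=0.2):
    """Pass positive inputs through, scale negative inputs by alpha."""
    return x if x > 0 else alpha * x


def sigmoid(x):
    """Squash any real number into (0, 1), read here as P(image is real)."""
    return 1.0 / (1.0 + math.exp(-x))


print(leaky_relu(3.0))   # 3.0
print(leaky_relu(-2.0))  # -0.4
print(sigmoid(0.0))      # 0.5
```

The small negative slope of LeakyReLU means a unit never goes completely silent, and the sigmoid at the end turns the last layer's output into a number that can be compared with the 0/1 labels.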

I first create instances of the generator and the discriminator:

def __init__(self, width=28, height=28, channels=1):

    self.WIDTH = width
    self.HEIGHT = height
    self.CHANNELS = channels

    self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)

    self.OPTIMIZER = Adam(lr=0.0002, decay=8e-9)

    self.noise_gen = np.random.normal(0, 1, (100,))

    self.G = self.generator()
    self.G.compile(loss='binary_crossentropy',
                   optimizer=self.OPTIMIZER)

    self.D = self.discriminator()
    self.D.compile(loss='binary_crossentropy',
                   optimizer=self.OPTIMIZER, metrics=['accuracy'])

Now the tricky part. Let's create the adversarial model (stacked_G_D in the code), which is just the generator followed by the discriminator. Notice that the weights of the discriminator have been frozen, so when I train this model the discriminator layers are left unchanged and simply relay the gradient back to the generator:

def stacked_G_D(self):
    self.D.trainable = False ## this freezes the weights, keep
                             ## reading to understand why this is
                             ## necessary

    model = Sequential()
    model.add(self.G)
    model.add(self.D)

    return model

### I need to compile the stacked model after setting trainable = False,
### so that the discriminator weights are frozen inside the stacked model.
### self.D, which was compiled earlier, still trains normally on its own.

self.stacked_G_D = self.stacked_G_D()
self.stacked_G_D.compile(loss='binary_crossentropy',
                         optimizer=self.OPTIMIZER)
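To see what freezing does to the gradient, here is a toy version in plain Python, with the whole generator reduced to one weight `w_g` and the whole discriminator to one weight `w_d`. This is a hand-written sketch under those assumptions, not how Keras does it internally, and all the names are made up for the illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def stacked_loss(w_g, w_d, z):
    """Cross-entropy of D(G(z)) against the target label 1 ("real")."""
    return -math.log(sigmoid(w_d * w_g * z))


w_g, w_d = 0.0, 1.0      # one generator weight, one (frozen) discriminator weight
z, lr = 1.0, 0.1         # a single noise sample and a learning rate

loss_before = stacked_loss(w_g, w_d, z)

# Backpropagate through the discriminator, but only update the generator.
# d(loss)/d(w_g) = -(1 - sigmoid(w_d * w_g * z)) * w_d * z
grad_w_g = -(1.0 - sigmoid(w_d * w_g * z)) * w_d * z
w_g -= lr * grad_w_g     # w_d is deliberately left untouched

loss_after = stacked_loss(w_g, w_d, z)
print(loss_before > loss_after)  # True
print(w_d)                       # 1.0
```

The gradient still depends on `w_d`, so the discriminator takes part in the computation, but only the generator weight moves.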

Now let's start the training:

def train(self, X_train, epochs=20000, batch=32, save_interval=200):

  half_batch = batch // 2  # integer division, needed for slicing and shapes

  for cnt in range(epochs):

    ## train discriminator

    random_index = np.random.randint(0, len(X_train) - half_batch)

    legit_images = X_train[random_index : random_index + half_batch].reshape(half_batch, self.WIDTH, self.HEIGHT, self.CHANNELS)

    gen_noise = np.random.normal(0, 1, (half_batch, 100))
    syntetic_images = self.G.predict(gen_noise)

    x_combined_batch = np.concatenate((legit_images, syntetic_images))

    ## labels: 1 for the real images, 0 for the synthetic ones
    y_combined_batch = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))))

    d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)

This is normal batch training of the discriminator.
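One detail matters if you run this loop under Python 3: values used as slice bounds or array shapes have to be integers, and `/` always returns a float there, so the half batch should be computed with `//`. A minimal sketch in plain Python, with a list standing in for X_train:

```python
batch = 32
images = list(range(100))  # stand-in for X_train

half_float = batch / 2     # 16.0, a float under Python 3
half_int = batch // 2      # 16, an int

try:
    images[0:half_float]
    slicing_with_float_works = True
except TypeError:
    slicing_with_float_works = False

print(slicing_with_float_works)   # False
print(len(images[0:half_int]))    # 16
```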

Now let's train the generator through the adversarial model:

    ## train generator

    noise = np.random.normal(0, 1, (batch, 100))
    y_mislabled = np.ones((batch, 1))

    g_loss = self.stacked_G_D.train_on_batch(noise, y_mislabled)

Notice the key point: I don't train the generator directly. I train it indirectly, through the adversarial model.

I pass noise to the adversarial model and label every output as if it were an image taken from the database, when in fact it comes from the generator.

The discriminator, which has just been trained on real and synthetic images, will tend to classify the synthetic images as fake. Since the labels say "real", every such disagreement counts as an error, and the loss function turns it into a high loss value.
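To put numbers on this, here is binary cross-entropy for a single sample written out in plain Python. It is a sketch of the textbook formula (the Keras version also clips the predictions for numerical stability), and the two discriminator outputs, 0.1 and 0.9, are invented for the illustration.

```python
import math


def binary_crossentropy(target, prediction):
    """Loss for one sample; prediction is the discriminator's sigmoid output."""
    return -(target * math.log(prediction)
             + (1 - target) * math.log(1 - prediction))


target = 1  # every generated image is labelled "real" on purpose

loss_when_caught = binary_crossentropy(target, 0.1)   # D says "10% real"
loss_when_fooled = binary_crossentropy(target, 0.9)   # D says "90% real"

print(round(loss_when_caught, 3))  # 2.303
print(round(loss_when_fooled, 3))  # 0.105
```

The loss is large when the discriminator spots the fake and small when it is fooled, so lowering this loss pushes the generator towards images the discriminator accepts.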

So here is where backpropagation kicks in. Because the parameters of the discriminator are frozen in this model, backpropagation will not affect them. Instead, it will affect the parameters of the generator.

So minimising the loss of the adversarial model really means making the generated images as similar as possible to something the discriminator will recognise as real. This is quite ingenious in my book!

At the end of the training phase, you want the loss value of the adversarial model to be small and the error of the discriminator to be high, ideally to the point where it does no better than guessing, which means it is no longer able to tell the difference.

At the end of my training phase, my discriminator loss was around 0.73. Considering that we are feeding it 50% real and 50% synthetic images, it means it was sometimes not able to recognise the fake images. That is quite a good result, considering this example is not optimised for quality at all. To know the exact percentage I could have printed the accuracy, which the discriminator is already compiled with (it is the second value returned by train_on_batch). Note to self.
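For reference, it helps to know what loss a discriminator that has completely given up would get. The sketch below is plain Python with made-up numbers: a discriminator that answers 0.5 for every image gets a binary cross-entropy of ln 2, about 0.693, on a half real, half synthetic batch. The second part shows how an accuracy figure comes out of thresholding the outputs at 0.5; the `labels` and `outputs` lists are invented for illustration and are not results from my run.

```python
import math

# A discriminator that has given up answers 0.5 ("could be either") every time.
clueless_prediction = 0.5

loss_on_real = -math.log(clueless_prediction)        # target 1
loss_on_fake = -math.log(1 - clueless_prediction)    # target 0
baseline_loss = (loss_on_real + loss_on_fake) / 2    # half real, half synthetic

print(round(baseline_loss, 3))  # 0.693

# Accuracy on a hypothetical batch: threshold the sigmoid outputs at 0.5.
labels = [1, 1, 0, 0]
outputs = [0.8, 0.4, 0.3, 0.6]
accuracy = sum((p > 0.5) == bool(y) for p, y in zip(outputs, labels)) / len(labels)

print(accuracy)  # 0.5
```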

Notice that it's possible to get much better results using more complex architectures for the generator and the discriminator.

Here is the full code:

import sys

import numpy as np

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam, RMSprop

import matplotlib.pyplot as plt
plt.switch_backend('agg')


class GAN(object):
    def __init__(self, width=28, height=28, channels=1):

        self.WIDTH = width
        self.HEIGHT = height
        self.CHANNELS = channels

        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)

        self.OPTIMIZER = Adam(lr=0.0002, decay=8e-9)

        self.noise_gen = np.random.normal(0, 1, (100,))

        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

        self.D = self.discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

        # compiled after D.trainable is set to False, so D is frozen in here
        self.stacked_G_D = self.stacked_G_D()

        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    def generator(self):

        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.WIDTH * self.HEIGHT * self.CHANNELS, activation='tanh'))

        model.add(Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS)))

        return model

    def discriminator(self):

        model = Sequential()
        model.add(Flatten(input_shape=self.SHAPE))
        model.add(Dense(self.WIDTH * self.HEIGHT * self.CHANNELS))
        model.add(LeakyReLU(alpha=0.2))
        # // keeps the layer size an integer under Python 3
        model.add(Dense((self.WIDTH * self.HEIGHT * self.CHANNELS) // 2))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model

    def stacked_G_D(self):
        self.D.trainable = False

        model = Sequential()
        model.add(self.G)
        model.add(self.D)

        return model

    def train(self, X_train, epochs=20000, batch=32, save_interval=200):

        half_batch = batch // 2  # integer division, needed for slicing and shapes

        for cnt in range(epochs):

            ## train discriminator
            random_index = np.random.randint(0, len(X_train) - half_batch)
            legit_images = X_train[random_index : random_index + half_batch].reshape(half_batch, self.WIDTH, self.HEIGHT, self.CHANNELS)

            gen_noise = np.random.normal(0, 1, (half_batch, 100))
            syntetic_images = self.G.predict(gen_noise)

            # labels: 1 for the real images, 0 for the synthetic ones
            x_combined_batch = np.concatenate((legit_images, syntetic_images))
            y_combined_batch = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))))

            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)

            # train generator

            noise = np.random.normal(0, 1, (batch, 100))
            y_mislabled = np.ones((batch, 1))

            g_loss = self.stacked_G_D.train_on_batch(noise, y_mislabled)

            # d_loss is [loss, accuracy] because D is compiled with metrics=['accuracy']
            print('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))

            if cnt % save_interval == 0:
                self.plot_images(save2file=True, step=cnt)

    def plot_images(self, save2file=False, samples=16, step=0):
        # the ./images directory must already exist
        filename = "./images/mnist_%d.png" % step
        noise = np.random.normal(0, 1, (samples, 100))

        images = self.G.predict(noise)

        plt.figure(figsize=(10, 10))

        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.HEIGHT, self.WIDTH])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()

        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()

if __name__ == '__main__':
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    gan = GAN()
    gan.train(X_train)
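
A last remark on the rescaling step in the main block: the generator ends with a tanh, whose outputs lie between -1 and 1, so the real images are mapped to that same range before training. The mapping and its inverse can be checked in isolation with plain Python. The function names below are made up for the illustration; the inverse is not used in the code above, it is only what you would need to get back to 0-255 pixel values.

```python
def to_model_range(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], as done before training."""
    return (pixel - 127.5) / 127.5


def to_pixel_range(value):
    """Inverse mapping, from the tanh range [-1, 1] back to [0, 255]."""
    return value * 127.5 + 127.5


print(to_model_range(0))      # -1.0
print(to_model_range(255))    # 1.0
print(to_pixel_range(0.0))    # 127.5
```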