Challenges in Training Traditional GANs (Generative Adversarial Networks)

Marcus Jenkins

Traditional GANs consist of two networks, the discriminator and the generator.

The generator receives an input vector of random values that are drawn from the normal distribution. The discriminator is a binary classifier that identifies whether the image is a real image from the dataset, or a fake from the generator. It does this by assigning a high confidence (value close to 1) to the real images and a low confidence (value close to 0) to fake images.

For each training step, the generator creates a batch of images, and the discriminator outputs a confidence score to indicate how likely the images are to be real. The generator is then updated to increase the discriminator’s confidence for its images.

Likewise, in the same training step, the discriminator receives a batch of real images from the dataset and a batch of fake images created by the generator. The discriminator is updated to assign a higher confidence to the real images and a lower confidence to the fake images.

In this sense, the discriminator and the generator are competing against each other. The generator is trying to “fool” the discriminator, while the discriminator is trying not to be fooled.
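As a rough sketch of what each network is optimising, the snippet below computes both losses from a handful of confidence scores. It assumes the usual binary cross-entropy for the discriminator, and the generator loss log(1 - ŷ) used later in this article. The function names and the example scores are made up for illustration.

```python
import math


def discriminator_loss(real_scores, fake_scores):
    """Binary cross-entropy: low when real scores are near 1 and fake scores near 0."""
    real_term = sum(-math.log(s) for s in real_scores) / len(real_scores)
    fake_term = sum(-math.log(1.0 - s) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term


def generator_loss(fake_scores):
    """Mean of log(1 - y_hat): falls as the discriminator's confidence in the fakes rises."""
    return sum(math.log(1.0 - s) for s in fake_scores) / len(fake_scores)


# Hypothetical confidence scores from one training step.
real_scores = [0.90, 0.80, 0.95]
fake_scores = [0.10, 0.30, 0.20]

print("discriminator loss:", discriminator_loss(real_scores, fake_scores))
print("generator loss:", generator_loss(fake_scores))
```

The two losses pull the fake scores in opposite directions, which is the competition described above.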

However, in practice, traditional GANs are notoriously unstable to train. Initially, the discriminator often outperforms the generator, and, as training progresses, the generator improves and begins to challenge the discriminator. But if one network, either the generator or the discriminator, performs too well and the other cannot catch up, issues such as vanishing gradients or mode collapse can occur, hindering the overall training process.

1. Vanishing Discriminator Gradient

If the discriminator is allowed to strongly out-perform the generator, or the generator out-performs the discriminator too quickly, the gradient used to optimise the parameters in the generator will vanish (approach 0). This is attributed to the sigmoid function, σ(x), which converts the output of the discriminator network to a confidence score, bounded between 0 and 1.

The sigmoid function is defined with respect to the input x as:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

But why is this an issue? Well, let’s visualise the sigmoid function:

A graph showing the output of the sigmoid function across the values negative 10 to positive 10.

We can see that, as x decreases below around -5, σ(x) approaches 0, and as x increases above around 5, σ(x) approaches 1. Further changes in x beyond these values result in negligible change in σ(x). In other words, saturation occurs.
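To put numbers on the saturation, here is a small sketch that prints σ(x) and its slope σ(x)(1 - σ(x)) at a few points. The sample points are arbitrary and the helper names are my own.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, bounded between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_slope(x):
    """Derivative of the sigmoid: sigmoid(x) * (1 - sigmoid(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)


for x in (-10, -5, 0, 5, 10):
    print(f"x={x:>3}  sigmoid={sigmoid(x):.6f}  slope={sigmoid_slope(x):.6f}")
```

The slope peaks at 0.25 when x = 0 and is already below 0.01 by x = ±5.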

This means that if the discriminator becomes too good at identifying the generated images as fake, the feedback it provides to the generator becomes less useful. Similarly, if the generator is able to out-perform the discriminator too quickly, the feedback the discriminator uses to learn to reduce its confidence in fake images becomes less meaningful as well.

More specifically, since σ(x) changes very little with respect to the change in the discriminator’s output, x, when σ(x) is close to 0 or 1, the gradient of the loss with respect to that output becomes vanishingly small. This is known as the vanishing gradient problem.

Let’s try to understand this using an example where the discriminator is strongly out-performing the generator.

1.1. Demonstration

Let’s say our generator produces an image with the pixels x1, x2 … xn, which are fed to a simple discriminator network, D(x):

A discriminator network consisting of a simple neural network that outputs h(x), which is fed to a sigmoid activation function ŷ = sigmoid(h(x)).

The discriminator is dominating, and so produces a highly negative value of h(x) = -10, giving σ(h(x)) ≈ 0.000045 as the confidence of the image being real.

The generator’s loss for this image will be:

$$L_G = \log(1 - \hat{y})$$

Let’s calculate the gradient ∂LG/∂h(x). To do so we work backwards from ŷ using back-propagation via the chain rule:

$$\frac{\partial L_G}{\partial h(x)} = \frac{\partial L_G}{\partial \sigma(h(x))} \cdot \frac{\partial \sigma(h(x))}{\partial h(x)}$$

First we’ll calculate ∂LG/∂σ(h(x)) using the rule dloge(1 – a)/da = –1/(1 – a):

$$\frac{\partial L_G}{\partial \sigma(h(x))} = -\frac{1}{1 - \sigma(h(x))}$$

Then ∂σ(h(x))/∂h(x) using the rule dσ(a)/da = σ(a)(1 – σ(a)):

$$\frac{\partial \sigma(h(x))}{\partial h(x)} = \sigma(h(x))\,(1 - \sigma(h(x)))$$

Therefore:

$$\frac{\partial L_G}{\partial h(x)} = -\frac{1}{1 - \sigma(h(x))} \cdot \sigma(h(x))\,(1 - \sigma(h(x))) = -\sigma(h(x))$$

This means that the gradient of the generator’s loss function with respect to the discriminator network’s output, ∂LG/∂h(x), always has the same magnitude as the confidence σ(h(x)) itself. In this instance that magnitude is approximately 0.000045, which is very small.
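As a sanity check on this derivation, the sketch below compares the chain-rule result for log(1 - σ(h)), which works out to -σ(h), against a central finite difference at h = -10. The step size `eps` and the function names are my own choices.

```python
import math


def sigmoid(h):
    return 1.0 / (1.0 + math.exp(-h))


def generator_loss(h):
    """L_G = log(1 - sigmoid(h)) for a single discriminator output h."""
    return math.log(1.0 - sigmoid(h))


def numeric_gradient(h, eps=1e-4):
    """Central finite-difference estimate of dL_G/dh."""
    return (generator_loss(h + eps) - generator_loss(h - eps)) / (2.0 * eps)


h = -10.0
analytic = -sigmoid(h)  # chain-rule result: dL_G/dh = -sigmoid(h)
numeric = numeric_gradient(h)

print(f"analytic gradient: {analytic:.3e}")
print(f"numeric gradient:  {numeric:.3e}")
```

Both values should agree closely, and both are tiny when the discriminator is this confident that the image is fake.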

To understand the effect on the generator’s learning we need to visualise the full gradient flow from LG to the generator, G(z).

A simple generator network, G(z), demonstrated as a neural network, accepts input noise z and the output is attached to the discriminator network, D(z).
Highly simplified example of a GAN network.

To update the generator, ∂LG/∂θGi, where θGi is the i-th parameter in the generator, must be calculated. This is calculated using back-propagation from ∂LG/∂h(x) via the chain rule:

$$\frac{\partial L_G}{\partial \theta_{G_i}} = \frac{\partial L_G}{\partial h(x)} \cdot \frac{\partial h(x)}{\partial \theta_{G_i}}$$

Every term in this chain is multiplied by ∂LG/∂h(x), which has a magnitude of roughly 0.000045 here.

As we can see, a small ∂LG/∂h(x) causes the gradient of the loss with respect to each parameter in the generator, ∂LG/∂θGi, to be small also, due to multiplication by ∂LG/∂h(x). This means the generator won’t improve, since updates to the generator’s parameters via gradient descent will be minimal:

$$\theta_{G_i} \leftarrow \theta_{G_i} - \eta \, \frac{\partial L_G}{\partial \theta_{G_i}}$$

Here η denotes the learning rate, so each parameter θGi changes in proportion to ∂LG/∂θGi.
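To see how much this shrinks the update, here is a toy sketch with a one-parameter generator x = θ·z and a one-weight discriminator h = w·x + b. All values, including the learning rate of 0.01, are hypothetical. The only difference between the two runs is the bias b, which moves the discriminator’s output from h = 1 to h = -10.

```python
import math


def sigmoid(h):
    return 1.0 / (1.0 + math.exp(-h))


def generator_gradient(theta, z, w, b):
    """dL_G/dtheta for x = theta * z, h = w * x + b, L_G = log(1 - sigmoid(h))."""
    x = theta * z        # toy generator output
    h = w * x + b        # toy discriminator output before the sigmoid
    dL_dh = -sigmoid(h)  # chain rule applied to log(1 - sigmoid(h))
    dh_dx = w
    dx_dtheta = z
    return dL_dh * dh_dx * dx_dtheta


learning_rate = 0.01  # assumed value, for illustration only
theta, z, w = 0.5, 1.0, 2.0

balanced = generator_gradient(theta, z, w, b=0.0)     # h = 1
saturated = generator_gradient(theta, z, w, b=-11.0)  # h = -10

for name, grad in (("balanced", balanced), ("saturated", saturated)):
    print(f"{name:>9}: gradient={grad:.3e}  update={-learning_rate * grad:.3e}")
```

The generator parameter, noise and discriminator weight are identical in both runs, yet the saturated update is several orders of magnitude smaller.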

2. Mode Collapse

If the generator improves too quickly, or the discriminator is too weak, mode collapse can occur. Mode collapse is characterised by a breakdown in the diversity of images produced by the generator. The output can shrink to a limited set of images or, in the worst case, the same image, regardless of the input noise vector, z.

This happens because the generator essentially overfits to the discriminator. The discriminator learns a set of features that it uses to distinguish real images from fake ones. If the discriminator cannot keep up, the generator can learn to produce images that replicate a limited set of features that consistently fool the discriminator. Let’s imagine that there is an image X* = G(z) to which the discriminator assigns the highest confidence, and which therefore gives the minimum value of the loss LG.

$$L_G = \log(1 - \hat{y})$$

If the discriminator fails to compete with the generator, the generator will gradually converge to output images that mimic the optimal image X* for the full distribution of z.

Moreover, because the discriminator assigns a high likelihood of being real to these generated images, the generator’s loss is already very low. The generator therefore has little incentive to keep exploring and settles on the same set of images, or a single image.
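The collapse onto X* can be reproduced in a deliberately tiny toy, sketched below. Everything in it is an assumption made for illustration: the “images” are single numbers, the generator is G(z) = a·z + b, and the discriminator is frozen with a confidence that peaks at a made-up X* = 1.0. Because the discriminator never adapts, gradient descent on LG pushes every output towards X*.

```python
import math

X_STAR = 1.0  # hypothetical output that the frozen discriminator scores highest


def sigmoid(h):
    return 1.0 / (1.0 + math.exp(-h))


def discriminator_logit(x):
    """Frozen toy discriminator: the logit peaks at X_STAR and falls off around it."""
    return 2.0 - 0.5 * (x - X_STAR) ** 2


def train_generator(a, b, noise, learning_rate=0.1, steps=2000):
    """Gradient descent on L_G = log(1 - sigmoid(h)) for the generator G(z) = a * z + b."""
    for _ in range(steps):
        grad_a = 0.0
        grad_b = 0.0
        for z in noise:
            x = a * z + b
            # dL_G/dx = dL_G/dh * dh/dx = (-sigmoid(h)) * (-(x - X_STAR))
            dL_dx = sigmoid(discriminator_logit(x)) * (x - X_STAR)
            grad_a += dL_dx * z / len(noise)
            grad_b += dL_dx / len(noise)
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b


noise = [-1.0, -0.5, 0.0, 0.5, 1.0]
a, b = train_generator(a=1.0, b=0.0, noise=noise)
outputs = [a * z + b for z in noise]

print(f"a={a:.4f}, b={b:.4f}")
print("spread of outputs:", max(outputs) - min(outputs))
```

The generator starts with outputs spread across [-1, 1], but the weight on the noise, `a`, decays towards zero, so every z ends up mapped to roughly the same value.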

3. An Uncertain Discriminator

If the discriminator is too weak, it may lose its ability to distinguish between real and fake images with confidence. This is characterised by a confidence that hovers around 0.5, with little disparity between the scores for real and fake images. As a result, the discriminator fails to provide adequate feedback, akin to when the discriminator is dominating (Section 1).

As discussed in Section 1, the generator relies on the gradient ∂LG/∂θGi to improve itself. ∂LG/∂θGi describes the change in LG as each parameter, θGi, changes. As we know, LG is:

$$L_G = \log(1 - \hat{y})$$

If the discriminator always outputs values close to 0.5, however, no matter how we change θGi, the loss will always remain around -0.69:

$$L_G = \log(1 - 0.5) \approx -0.693$$

Let’s try to understand this in a bit more depth with a demonstration.

3.1. Demonstration

Imagine our generator produces an image that is given a confidence of 0.498, and so LG will be -0.68916. θG1 is then bumped up by 0.1. This gives an increased confidence of 0.4982, and so LG decreases to -0.68955. The gradient of LG with respect to θG1 is approximately:

$$\frac{\Delta L_G}{\Delta \theta_{G_1}} = \frac{-0.68955 - (-0.68916)}{0.1} \approx -0.004$$

As we can see, this gradient is vanishingly small, and so the generator will not be able to learn effectively.
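The arithmetic in this demonstration can be reproduced with a few lines, using the confidences and the parameter bump of 0.1 given above. The variable names are my own, and the result is a finite-difference estimate rather than an exact derivative.

```python
import math


def generator_loss(confidence):
    """L_G = log(1 - y_hat) for a single discriminator confidence."""
    return math.log(1.0 - confidence)


loss_before = generator_loss(0.498)   # confidence before the parameter bump
loss_after = generator_loss(0.4982)   # confidence after the parameter bump
delta_theta = 0.1

gradient = (loss_after - loss_before) / delta_theta

print(f"L_G before: {loss_before:.5f}")
print(f"L_G after:  {loss_after:.5f}")
print(f"finite-difference gradient: {gradient:.5f}")
```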

4. Final Words

In this article, we explored three key issues in training traditional GANs. Much subsequent research has focussed on addressing these issues to provide more stable training processes. One notable advancement is the Wasserstein GAN (WGAN), which I intend to cover in a future article.
