# Neural network simplified part 3: learning

16 August 2020

What differentiates Artificial Intelligence (AI) from other computer programs is its ability to learn: its behavior is decided not solely by programmers but also by its own experience. This is part of why AI can outperform us at many tasks. In this article, I will try to show you how neural networks (NNs) learn.

# What it means to train a neural network (NN)

Let’s recall the first article, where I presented a simple NN that classifies emails as spam or not spam. These kinds of NNs take in some numbers and output others. Here, the input represents an email in some way (style, spelling, grammar, …) and the output represents how strongly the network thinks this email is spam.

In the first article, I also explained how NNs make decisions. To summarize, neurons are linked and each link has an associated weight. The weight of each link decides how strongly the information flows from one neuron to the next. Hence, the weights completely define how the network behaves.
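To make this concrete, here is a minimal sketch of a single neuron. The sigmoid activation, the feature values and the weights are all assumptions chosen for illustration; they are not taken from the first article.

```python
import math

def sigmoid(x):
    # Squashes any number into the range (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

def neuron_output(inputs, weights, bias=0.0):
    # Each input is scaled by the weight of its link, then everything is summed
    total = sum(x * w for x, w in zip(inputs, weights)) + bias
    return sigmoid(total)

# Hypothetical features describing one email (made-up numbers)
features = [0.9, 0.2, 0.4]

# Same inputs, different weights: only the first link changes strength
weak = neuron_output(features, [0.1, 0.1, 0.1])
strong = neuron_output(features, [3.0, 0.1, 0.1])

print(weak, strong)
```

With the heavier weight on the first link, the first feature dominates the sum and pushes the output closer to 1. Changing the weights changes the behavior, even though the inputs are identical.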

To train an NN is to find an arrangement of weights, where some neurons are more strongly linked than others, such that the network makes good decisions. Note that many different arrangements of weights can lead to the same performance due to the internal symmetries of NNs.
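One such symmetry can be sketched in a few lines: swapping two hidden neurons, together with their outgoing links, gives a different arrangement of weights but the same output. The tiny network below (two inputs, two hidden neurons, one output, sigmoid activations) and all its numbers are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def tiny_network(inputs, hidden_weights, output_weights):
    # One list of incoming weights per hidden neuron
    hidden = [
        sigmoid(sum(w * x for w, x in zip(ws, inputs)))
        for ws in hidden_weights
    ]
    # One outgoing weight per hidden neuron
    return sigmoid(sum(v * h for v, h in zip(output_weights, hidden)))

inputs = [0.9, 0.2]
hidden_weights = [[0.5, -1.0], [2.0, 0.3]]
output_weights = [1.5, -0.7]

# Swap the two hidden neurons and their outgoing weights
swapped_hidden = [hidden_weights[1], hidden_weights[0]]
swapped_output = [output_weights[1], output_weights[0]]

original = tiny_network(inputs, hidden_weights, output_weights)
swapped = tiny_network(inputs, swapped_hidden, swapped_output)

print(original, swapped)
```

The two weight arrangements are different lists of numbers, yet the network computes the same thing in both cases.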

# Supervised learning

Supervised learning is one way to train NNs. It means we use labelled data to train our network. In this case, imagine we have 10,000 emails manually labelled by someone as spam or not spam. We present this data to the network so it can learn to differentiate the two kinds of email.

In practice, we usually split the data into multiple sets. A training set is used to train the network and a test set is used to test its performance. Since the network is trained on the training set, it has already seen that data and will usually perform better on it. Keeping a separate test set ensures the network is tested on data it has never seen, which gives a more honest measure of its performance.
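As a rough sketch, splitting the 10,000 labelled emails could look like this. The function name `train_test_split`, the 80/20 ratio and the placeholder data are assumptions for illustration.

```python
import random

def train_test_split(examples, test_fraction=0.2, seed=0):
    # Shuffle a copy so the split does not depend on the original order
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    # The first n_test examples become the test set, the rest the training set
    return shuffled[n_test:], shuffled[:n_test]

# Placeholder data: (email text, is_spam) pairs with made-up labels
emails = [(f"email {i}", i % 2 == 0) for i in range(10000)]

train_set, test_set = train_test_split(emails)
print(len(train_set), len(test_set))
```

The network would only ever be trained on `train_set`; `test_set` is kept aside until it is time to measure performance.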

# The influence of weights on the error

What we want is to reduce the overall error the network makes. Here, this means reducing the number of misclassified emails. Imagine a machine that takes in the NN’s current weights and outputs the error for a given training set; this machine is called the cost function. Hence, what we want is to minimize the cost function by choosing a better set of weights.
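Here is a minimal sketch of such a machine. It assumes a one-neuron model with a sigmoid output and uses the fraction of misclassified emails as the error; real networks usually use smoother error measures, and the training data here is made up.

```python
import math

def predict(weights, features):
    # One-neuron model: weighted sum followed by a sigmoid
    z = sum(w * x for w, x in zip(weights, features))
    return 1.0 / (1.0 + math.exp(-z))

def cost(weights, training_set):
    # Fraction of emails the model gets wrong with these weights
    errors = 0
    for features, is_spam in training_set:
        predicted_spam = predict(weights, features) >= 0.5
        if predicted_spam != is_spam:
            errors += 1
    return errors / len(training_set)

# Hypothetical training set: ([constant 1.0, spam-likeness score], label)
training_set = [
    ([1.0, 0.9], True),
    ([1.0, 0.8], True),
    ([1.0, 0.1], False),
    ([1.0, 0.2], False),
]

bad = cost([0.0, -1.0], training_set)    # a poor choice of weights
good = cost([-5.0, 10.0], training_set)  # a better choice of weights
print(bad, good)
```

The same training set gives a different error for each set of weights, which is exactly what the cost function captures.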

The weights are a set of numbers. Let’s first imagine there is only one. At any time we can increase or decrease this weight, and this will affect the error on the training set. Hence, there exists for this weight at least one value that leads to the lowest error (marked as a green X on the graph below). In practice, however, we do not have one weight but often millions. The idea is the same: there is an error surface with a lowest point, and the goal is to find it.
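With a single weight, we can sketch this directly by trying a range of values and keeping the one with the lowest error. The error curve below is invented for illustration; it is not measured on a real network.

```python
def error_for_weight(w):
    # Hypothetical error curve with its lowest point at w = 2
    return (w - 2.0) ** 2 + 1.0

# Try weights from -5.0 to 5.0 in steps of 0.1
candidates = [i / 10 for i in range(-50, 51)]
best_weight = min(candidates, key=error_for_weight)

print(best_weight, error_for_weight(best_weight))
```

Scanning like this is fine for one weight, but the number of combinations to try explodes as soon as there are many weights.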

While in theory we can draw the cost function, in reality we only know the error for the weight configurations we have actually tested. This means we do not know what the cost function looks like as a whole. Hence, we do not know right away what the best weight configuration is. Thus, we need training protocols to look for a weight configuration that leads to good performance.

# Training protocols

A naive approach is to test many randomly chosen weight configurations and pick the best, but this is painfully slow and inefficient. A better approach is gradient descent. I won’t get into the details, but here the magic of mathematics gives us a nice tool: the gradient. The gradient of a function points in the direction in which it increases fastest locally, so its opposite gives us the direction of steepest local descent. Thus, there is no need to probe many directions by trial and error. We now have a nice way to reduce the error: take small steps against the gradient, which is fast to compute and much more efficient. If the steps are small enough, following the gradient downhill typically leads us to a local minimum of the cost function, though not necessarily the lowest point overall.
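Here is a minimal sketch of gradient descent on a single weight. The cost function, its derivative, the starting point and the learning rate are all assumptions chosen for illustration; a real network computes the gradient over millions of weights at once.

```python
def cost(w):
    # Hypothetical cost curve with its minimum at w = 2
    return (w - 2.0) ** 2 + 1.0

def gradient(w):
    # Derivative of the cost with respect to w
    return 2.0 * (w - 2.0)

def gradient_descent(start, learning_rate=0.1, steps=100):
    w = start
    for _ in range(steps):
        # Step against the gradient, i.e. downhill
        w -= learning_rate * gradient(w)
    return w

w_final = gradient_descent(start=-4.0)
print(w_final, cost(w_final))
```

Each step moves the weight a little in the direction that lowers the cost, so no random guessing is needed. On this simple curve the local minimum is also the global one, which is not the case in general.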