Problem

Consider a simple MLP that takes the concatenated embeddings of 3 characters as input and predicts the next character.

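# A simple MLP
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # used for the histograms below

# vocab_size, block_size, Xtr and Ytr are assumed to come from the earlier data-preparation code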
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) 
b1 = torch.randn(n_hidden,                        generator=g) 
W2 = torch.randn((n_hidden, vocab_size),          generator=g) 
b2 = torch.randn(vocab_size,                      generator=g)

# BatchNorm parameters (not used in the forward pass below)
bngain = torch.ones((1, n_hidden))
bnbias = torch.zeros((1, n_hidden))
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))

parameters = [C, W1, b1, W2, b2] # b1 is used in the forward pass, so it must be trained too
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
    p.requires_grad = True

# same optimization as last time
max_steps = 200000
batch_size = 32
lossi = []

for i in range(max_steps):
  
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

    # forward pass
    emb = C[Xb] # embed the characters into vectors
    embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
    # Linear layer
    hpreact = embcat @ W1 + b1 # hidden layer pre-activation
    # Non-linearity
    h = torch.tanh(hpreact) # hidden layer
    logits = h @ W2 + b2 # output layer
    loss = F.cross_entropy(logits, Yb) # loss function

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    if i % 10000 == 0: # print every once in a while
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())
  

When we train this simple MLP for 200,000 iterations, the printed loss looks like this:

     0/ 200000: 27.8817
  10000/ 200000: 2.5633
  20000/ 200000: 2.6522
  30000/ 200000: 2.8065
  40000/ 200000: 2.1546
  50000/ 200000: 2.7555
  60000/ 200000: 2.4661
  70000/ 200000: 2.0084
  80000/ 200000: 2.3762
  90000/ 200000: 2.2308
 100000/ 200000: 2.0540
 110000/ 200000: 2.3655
 120000/ 200000: 1.8583
 130000/ 200000: 2.4840
 140000/ 200000: 2.4164
 150000/ 200000: 2.1783
 160000/ 200000: 2.0387
 170000/ 200000: 1.8343
 180000/ 200000: 2.1532
 190000/ 200000: 1.9804

We can see that the loss in the first iteration is 27.8817 and that it drops drastically right after. The problem is that the initial loss is far too big, and we can verify this. At initialization we have no reason to prefer any character, so we would want the network to assign equal probability to all 27 of them. The probability of the correct next character is then 1/27, and the negative log likelihood (our loss function) is:

- torch.tensor(1/27.0).log()
>> tensor(3.2958)

This is approximately the loss we should see at initialization, but we get 27.8817 instead. The network is confidently wrong at the start, so the first iterations are wasted on shrinking the logits rather than on learning anything useful.
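As a quick sanity check, here is a minimal sketch comparing the loss for uniform logits with the loss for a confidently wrong prediction. The logit value 30.0 and the class indices are arbitrary choices for illustration, not values from the network above.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 27  # 27 possible characters, as in the text

# All-zero logits give a uniform softmax, so the loss equals log(27).
uniform_logits = torch.zeros(1, vocab_size)
target = torch.tensor([0])
uniform_loss = F.cross_entropy(uniform_logits, target).item()

# A large logit on the wrong class gives a much larger loss.
wrong_logits = torch.zeros(1, vocab_size)
wrong_logits[0, 1] = 30.0  # very confident about class 1, but the target is class 0
wrong_loss = F.cross_entropy(wrong_logits, target).item()

print(f"uniform: {uniform_loss:.4f}, confidently wrong: {wrong_loss:.4f}")
```

The uniform case reproduces the log(27) baseline, while a single large wrong logit is enough to push the loss into the same range as the 27.8817 we observed.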

Why is our loss too big initially?

To find out, let’s look at the weights that produce our logits, which is the last step before the loss is calculated.

logits = h @ W2 + b2 # output layer

Let’s take a look at the distribution of these weights at initialization.

plt.hist(W2.flatten().detach(), bins=50)
plt.show()

fig1

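As you can see, the weights are spread from about -3 to 3, and this causes the problem. With 200 hidden units feeding each logit, weights this large produce logits with extreme values, so the softmax is confidently wrong. We want the logits to be close to 0 at initialization so that the predicted probabilities are roughly uniform.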
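To illustrate how the spread of W2 carries over to the logits, here is a small sketch. It uses random ±1 values as a stand-in for saturated tanh outputs; that stand-in, the seed, and the batch size of 32 are assumptions made only for this example.

```python
import torch

g = torch.Generator().manual_seed(0)  # arbitrary seed for this sketch
n_hidden, vocab_size = 200, 27        # sizes taken from the text

# Stand-in for saturated tanh outputs: every value is -1 or +1.
h = torch.sign(torch.randn((32, n_hidden), generator=g))

W2_big = torch.randn((n_hidden, vocab_size), generator=g)
W2_small = W2_big * 0.01

logits_big = h @ W2_big      # each logit sums 200 unit-variance terms
logits_small = h @ W2_small  # same logits, scaled down by 100

print(logits_big.std().item(), logits_small.std().item())
```

Each logit is a sum of 200 terms, so with unit-variance weights its std is on the order of sqrt(200), about 14. Scaling W2 by 0.01 scales the logits by the same factor.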

Let’s initialize W2 close to 0 and see how our loss improves.

W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.01

The distribution becomes fig2

Now most of the values are around 0. Let’s see our loss.

    0/ 200000: 3.8073
  10000/ 200000: 2.1428
  20000/ 200000: 2.4846
  30000/ 200000: 2.6018
  40000/ 200000: 2.0154
  50000/ 200000: 2.4055
  60000/ 200000: 2.3731
  70000/ 200000: 2.1023
  80000/ 200000: 2.2878
  90000/ 200000: 2.0695
 100000/ 200000: 1.8733
 110000/ 200000: 2.2128
 120000/ 200000: 1.8982
 130000/ 200000: 2.3203
 140000/ 200000: 2.2108
 150000/ 200000: 2.1378
 160000/ 200000: 1.8270
 170000/ 200000: 1.7928
 180000/ 200000: 1.9601
 190000/ 200000: 1.8350

You can see how our initial loss improves (3.8073 instead of 27.8817). Both versions of W2 are normally distributed around 0, but now the values are concentrated tightly around 0 instead of reaching out to about -3 and 3, which is what made our initial loss explode.

Similarly, let’s look at the output of our tanh activation. fig3

As you can see, most of our values lie at -1 and 1. Why is that?

As you might remember, tanh works like this: if x lies near 0, we get expressive non-linear outputs, but when x is extreme, say abs(x) > 2, the output is squashed and ends up very close to -1 or 1.

fig4

Let’s see what inputs to tanh result in most values being -1 and 1. fig5

As you can see, the histogram of the inputs to our tanh function, i.e. hpreact, is not concentrated around 0 but spreads out between roughly -15 and 15, which pushes the output of tanh to -1 and 1. The same holds for other saturating activation functions such as sigmoid: if the input is not close to 0 and is widely spread, the outputs get squashed, i.e. most of them end up at the extremes.
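The saturation can be measured directly. The sketch below uses unit-variance random inputs as a stand-in for the real embeddings, and the 0.99 threshold for "saturated" is an arbitrary choice for illustration.

```python
import torch

g = torch.Generator().manual_seed(0)  # arbitrary seed for this sketch
fan_in, n_hidden = 30, 200            # n_embd * block_size = 10 * 3

embcat = torch.randn((1000, fan_in), generator=g)  # stand-in for the embeddings
W1 = torch.randn((fan_in, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g)

hpreact = embcat @ W1 + b1  # pre-activations are far wider than unit std
h = torch.tanh(hpreact)

# Fraction of activations that sit in the flat tails of tanh.
saturated = (h.abs() > 0.99).float().mean().item()
print(f"hpreact std: {hpreact.std().item():.2f}, saturated fraction: {saturated:.2f}")
```

With unscaled weights, well over half of the activations land in the flat regions of tanh.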

So why are activations of -1 and 1 a problem here?

Let’s look at how the gradient is calculated for the tanh function. fig6

As you can see, t is the tanh activation, and the local gradient, 1 - t**2, depends only on t.

So if an activation is -1 or 1, self.grad gets a contribution of 0, and the gradient stops at this point and does not propagate further back.

If most of the activations are -1 or 1, most gradients are 0 and our NN will barely learn.
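The vanishing gradient can be checked with autograd. This is a minimal sketch with hand-picked input values, comparing PyTorch's gradient against the closed form 1 - t**2.

```python
import torch

# Inputs ranging from the active region (0) to deep saturation (15).
x = torch.tensor([0.0, 1.0, 5.0, 15.0], requires_grad=True)
t = torch.tanh(x)
t.sum().backward()  # d(sum)/dx is just the local tanh gradient

# Autograd's result matches the closed form 1 - t**2.
local_grad = 1 - t.detach() ** 2
print(x.grad)
print(local_grad)
```

The gradient is 1 at x = 0 and shrinks towards 0 as the input moves into the saturated region.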

NOTE

  • For a simple NN like ours, learning can still happen even if the weights are not initialized very well. In a much bigger NN the impact can be far worse, up to no learning at all, if the weights are not properly initialized.

Solution ?

The solution is to initialize our weights in such a way that the distribution of the activations is preserved from layer to layer, i.e. it keeps 0 mean and unit std. fig7

As you can see, x has 0 mean and unit std, but y does not. y takes on more extreme values, which will result in vanishing gradients later on, as shown in the previous steps. So we want y to keep the same distribution as x.
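The effect can be reproduced in a few lines. The shapes below (1000 samples, 10 inputs, 200 outputs) are assumptions for illustration and may differ from those used in the figure.

```python
import torch

g = torch.Generator().manual_seed(0)  # arbitrary seed for this sketch
fan_in = 10

x = torch.randn((1000, fan_in), generator=g)
w = torch.randn((fan_in, 200), generator=g)

y_raw = x @ w                     # std grows to about sqrt(fan_in)
y_scaled = x @ (w / fan_in**0.5)  # std stays close to 1

print(x.std().item(), y_raw.std().item(), y_scaled.std().item())
```

Dividing the weights by the square root of the number of inputs keeps the std of y close to that of x.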

Kaiming Init

A simple multiplication of the weights by 0.01 results in a better initialization and in good activations. But how do we get the value (0.01) that we multiply our weights with? The proper scale can be determined using Kaiming init.

The value to multiply by is given by the formula below. fig8

Different activations have different gains, and for fan_mode we use fan_in, the input dimension of our weight matrix.
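Written out, and assuming the figure shows the same rule that PyTorch documents for Kaiming normal initialization, the formula and the reasoning behind it look like this. The variance identity assumes independent, zero-mean inputs and weights.

```latex
% Kaiming (He) normal initialization: std of the weights
\sigma_w = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

% Why the square root of fan_in appears, for y_j = \sum_{i=1}^{\text{fan\_in}} x_i w_{ij}:
\operatorname{Var}(y_j) = \text{fan\_in} \cdot \operatorname{Var}(x_i) \cdot \operatorname{Var}(w_{ij})
```

Choosing Var(w) = 1/fan_in therefore keeps Var(y) equal to Var(x), and the gain compensates for the contraction introduced by the non-linearity.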

For tanh, gain = 5/3 and fan_in = n_embd * block_size, so we can scale our weights like this:

W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3) / ( (n_embd * block_size)**(0.5))

A precise initialization is not required: we can also simply multiply our weight matrices by 1/((n_embd * block_size)**(0.5)).

This initialization helps preserve our distribution property (0 mean and unit std).
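As a rough check of the scaled initialization, the sketch below looks up the tanh gain with torch.nn.init.calculate_gain and measures how many activations still saturate. The unit-variance random inputs, the seed, and the 0.99 threshold are assumptions for illustration.

```python
import torch

g = torch.Generator().manual_seed(0)  # arbitrary seed for this sketch
n_embd, block_size, n_hidden = 10, 3, 200  # block_size = 3 characters, as in the text
fan_in = n_embd * block_size

gain = torch.nn.init.calculate_gain('tanh')  # PyTorch's recommended gain for tanh
W1 = torch.randn((fan_in, n_hidden), generator=g) * gain / fan_in**0.5

x = torch.randn((1000, fan_in), generator=g)  # unit-variance stand-in inputs
hpreact = x @ W1
h = torch.tanh(hpreact)

# Fraction of activations in the flat tails of tanh after scaling.
saturated = (h.abs() > 0.99).float().mean().item()
print(f"gain: {gain:.4f}, hpreact std: {hpreact.std().item():.2f}, saturated: {saturated:.3f}")
```

The pre-activations now have a std close to the gain, and only a small fraction of the tanh outputs are saturated, compared with the majority under the unscaled initialization.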