When we have enough resources, we want to train our neural networks in parallel. The way to do this is to train the same network on different data (a different batch of data) on each GPU, in parallel. For instance, if we have 8x A100s, we run 8 different batches of data, one on each A100 GPU.

The way to do this in PyTorch is DistributedDataParallel (DDP); take a look at the official docs.

The important thing to understand is that when we train the network on different GPUs, each GPU computes its own gradient. Those gradients are averaged across all the GPUs, the averaged gradient is deposited back on each GPU, and only then do we take the descent step.
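
Here's a minimal sketch of what that looks like in code (a toy model and random data, assuming the script is launched with torchrun with one process per GPU; see the PyTorch DDP docs for the real setup):

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")              # one process per GPU, set up by torchrun
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Linear(10, 5).to(rank)              # toy model, just for illustration
model = DDP(model, device_ids=[rank])                # DDP averages gradients across GPUs during backward()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(100):
    Xb = torch.randn(32, 10, device=rank)            # in practice: this GPU's own shard of the data
    Yb = torch.randint(0, 5, (32,), device=rank)
    loss = F.cross_entropy(model(Xb), Yb)
    optimizer.zero_grad()
    loss.backward()                                  # gradients get all-reduced (averaged) here
    optimizer.step()                                 # every GPU takes the identical descent step

dist.destroy_process_group()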

Why do we average the gradients like this?

So that the weights stay consistent across GPUs, and the update is mathematically equivalent to training with 8x the batch size on a single GPU.

At first I was confused about how this could be equivalent to the gradient we would get by training on a single GPU with 8x the batch size.

Here’s a simple mathematical formula.

FIGURE1 DDP
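
(In case the figure doesn't come through, here is my rough reconstruction of the two equations it refers to: equation 1 is the gradient each GPU computes, a mean over its own batch of size n, and equation 2 is the gradient of a single batch of size 8n, which is exactly the average of the eight per-GPU gradients.)

g_k = \frac{1}{n} \sum_{i \in \text{batch}_k} \nabla_\theta \ell_i \qquad \text{(equation 1, one per GPU)}

\frac{1}{8} \sum_{k=1}^{8} g_k = \frac{1}{8n} \sum_{i=1}^{8n} \nabla_\theta \ell_i \qquad \text{(equation 2, the 8n-sized batch)}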

But let's look at our manual backpropagation, which will help us understand this better.

Forward pass

n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializing many of these parameters in non-standard ways
# because sometimes initializing with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True



batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y


# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
  t.retain_grad()
loss.backward()
loss
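
(The cmp helper used in the checks below isn't shown in this snippet; it's the little comparison utility from the notebook, which looks roughly like this: it compares a manually computed gradient dt against PyTorch's autograd result t.grad.)

def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()          # bit-for-bit exact match?
  app = torch.allclose(dt, t.grad)             # approximately equal?
  maxdiff = (dt - t.grad).abs().max().item()   # largest absolute difference
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')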

Backprop

# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one

# -----------------
# YOUR CODE HERE :)
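# cross entropy backward (logprobs -> probs -> counts -> norm_logits -> logits)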
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0*(1/logprobs.shape[0]) # 1 <=== look here

dprobs = (1/probs)*dlogprobs # 2
dcounts_sum_inv = (dprobs*counts).sum(1, keepdim = True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = -1.0*((counts_sum)**(-2.0))*dcounts_sum_inv
dcounts += torch.ones_like(counts_sum)*dcounts_sum
dnorm_logits = norm_logits.exp()*dcounts
dlogit_maxes = (-1.0*dnorm_logits).sum(1,keepdim=True)
dlogits = (1.0*dnorm_logits)

dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1])*dlogit_maxes
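# Linear layer 2 backward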
db2 = (dlogits*torch.ones_like(logits)).sum(0)
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
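# tanh non-linearity backward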
dhpreact = dh*(1-h**(2))
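# BatchNorm layer backward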
dbnbias = (dhpreact*torch.ones_like(bnraw)).sum(0, keepdim= True)
dbngain = (dhpreact*bnraw*torch.ones_like(bnraw)).sum(0, keepdim=True)
dbnraw = dhpreact*bngain*torch.ones_like(bnraw)
dbnvar_inv = (dbnraw* (torch.ones_like(bndiff) * bndiff)).sum(0, keepdim=True)
dbndiff = (dbnraw* (torch.ones_like(bndiff) * bnvar_inv))
dbnvar = dbnvar_inv* (-0.5)*(((bnvar + 1e-5))**(-1.5))
dbndiff2 = (1.0/(n-1) )*torch.ones_like(bndiff2) * dbnvar
dbndiff += dbndiff2*2*(bndiff)
dhprebn = dbndiff*1.0
dbnmeani = (torch.ones_like(hprebn)*-1.0*dbndiff).sum(0, keepdim = True)
dhprebn += torch.ones_like(hprebn)*(1/n)*dbnmeani
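# Linear layer 1 backward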
db1 = (torch.ones_like(dhprebn)*dhprebn).sum(0)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
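# embedding backward: un-concatenate into emb, then scatter into the lookup table C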
demb = dembcat.view(emb.shape[0],emb.shape[1],emb.shape[2])
dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        dC[Xb[i,j]] += demb[i,j]
#         print(demb[i,j].shape)
# -----------------

cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)

cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

Let's look at the last line of code in the forward pass:

loss = -logprobs[range(n), Yb].mean()

What we do here is simply cross entropy: normalize the logits with a softmax and pluck out the log probability at the target token's index.

We calculate its derivative in the second line of code in the backward pass, like this:

dlogprobs[range(n), Yb] = -1.0*(1/logprobs.shape[0]) # 1 <=== look here

What we're doing is accounting for the averaging over the batch: the same value, -1/n, is deposited into each element at [range(n), Yb].

Why 1/n? By the chain rule, dloss/dlogprobs = dloss/dloss (which is 1) x the local derivative of the loss with respect to logprobs[range(n), Yb].

Since the elements at [range(n), Yb] are averaged, each of them contributes (1/total_number) x itself to the loss (with a minus sign in front).

So the local derivative of (1/total_number) x itself with respect to itself is 1/total_number, and with the minus sign it is -1/total_number that gets deposited into the elements of dlogprobs[range(n), Yb].
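
Writing that out as a formula, with the loss as defined in the forward pass:

\text{loss} = -\frac{1}{n} \sum_{i=1}^{n} \text{logprobs}[i, Y_b[i]] \quad \Rightarrow \quad \frac{\partial\, \text{loss}}{\partial\, \text{logprobs}[i, Y_b[i]]} = -\frac{1}{n}

and the derivative with respect to every other entry of logprobs is 0, which is exactly what the code above constructs.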

Let's stop right here: this matches the first equation in FIGURE1, and if we do this same operation on a single GPU with a batch 8 times larger, we get the second equation in FIGURE1. Looking at the code, the only thing that changes is the value of n.

dlogprobs[range(8*n), Yb] = -1.0*(1/logprobs.shape[0]) # 1 <=== look here, logprobs.shape[0] is now 8*n, so each element gets -1/(8n)

So it's safe to say that training on multiple GPUs and averaging their gradients is the same as training on a single GPU with an 8-times-larger batch.
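
As a quick sanity check (my own toy example, not part of the original notebook), we can verify numerically that averaging per-GPU gradients equals the gradient of one big batch, as long as the loss is a mean over the batch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(10, 5, requires_grad=True)
X = torch.randn(8 * 32, 10)                      # pretend these are 8 "GPU" shards of batch_size 32
Y = torch.randint(0, 5, (8 * 32,))

# gradient of one 8x-larger batch
loss_big = F.cross_entropy(X @ W, Y)
(grad_big,) = torch.autograd.grad(loss_big, [W])

# average of the 8 per-shard gradients (what DDP's all-reduce computes)
grads = []
for k in range(8):
    xb, yb = X[k*32:(k+1)*32], Y[k*32:(k+1)*32]
    loss_k = F.cross_entropy(xb @ W, yb)
    (g,) = torch.autograd.grad(loss_k, [W])
    grads.append(g)
grad_avg = torch.stack(grads).mean(0)

print(torch.allclose(grad_big, grad_avg, atol=1e-6))  # True, up to floating point error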