KV cache and Grouped Query Attention

KV Cache: In the note below, I first describe how inference is done if we simply run the model without a KV cache, and then describe how the KV cache removes redundant operations. We don't use a KV cache during training because the data for every position in each sequence is already available: we don't need to compute the loss one token at a time and instead process whole sequences in batches. During inference, on the other hand, we generally run a single batch with some sequence and keep appending the next predicted token to that sequence one at a time....

January 18, 2025 · 11 min · CohleM
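A minimal sketch of the idea, assuming a single layer with a single attention head, toy 2-dimensional weights, and pure Python lists instead of tensors. `KVCache`, `step_with_cache` and `step_without_cache` are hypothetical names. The newest token only attends to itself and earlier tokens, so the keys and values of past tokens never change and can be reused instead of recomputed at every decoding step.

```python
import math

def matvec(W, x):
    # W is a list of rows; returns the matrix-vector product W @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # single-head scaled dot-product attention of one query over the given keys/values
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

class KVCache:
    """Hypothetical minimal cache: one list of keys and one of values (one layer, one head)."""
    def __init__(self):
        self.keys, self.values = [], []

def step_with_cache(x_new, cache, Wq, Wk, Wv):
    # project only the newest token; K/V of earlier tokens come straight from the cache
    cache.keys.append(matvec(Wk, x_new))
    cache.values.append(matvec(Wv, x_new))
    return attend(matvec(Wq, x_new), cache.keys, cache.values)

def step_without_cache(prefix, Wq, Wk, Wv):
    # recompute K and V for every token in the prefix at every decoding step
    keys = [matvec(Wk, x) for x in prefix]
    values = [matvec(Wv, x) for x in prefix]
    return attend(matvec(Wq, prefix[-1]), keys, values)

# toy projection weights and token embeddings (d_model = 2), chosen arbitrarily
Wq = [[1.0, 0.5], [0.0, 1.0]]
Wk = [[0.5, 0.0], [1.0, 1.0]]
Wv = [[1.0, -1.0], [0.5, 2.0]]
tokens = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [2.0, -1.0]]

cache = KVCache()
cached_out, full_out = [], []
for t in range(len(tokens)):
    # generate one token at a time, as during inference
    cached_out.append(step_with_cache(tokens[t], cache, Wq, Wk, Wv))
    full_out.append(step_without_cache(tokens[: t + 1], Wq, Wk, Wv))
```

Both paths produce the same attention output for the newest token. With the cache, each step runs only one K and one V projection instead of one per token in the prefix.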

RMSNorm

Recap of LayerNorm: Let's first recap why LayerNorm was used. We needed to balance the distribution of inputs (internal covariate shift), i.e. we want inputs to be roughly Gaussian (mean 0, std 1); if this isn't maintained, the gradients can end up zeroed out. The outputs of some blocks (e.g. a transformer block) may be very large or very small, which leads to either exploding or vanishing gradients, so for stable training we need those outputs to stay in a stable range....

January 15, 2025 · 1 min · CohleM
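A minimal sketch, assuming a single feature vector as a plain Python list. It puts LayerNorm (subtract the mean, divide by the standard deviation, then a learnable scale and shift) next to RMSNorm (no mean subtraction, only division by the root mean square). The function names and the `eps` default are illustrative choices, not taken from any particular library.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    # normalize to mean 0 and (roughly) variance 1, then apply scale gamma and shift beta
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n  # biased (population) variance
    gamma = [1.0] * n if gamma is None else gamma
    beta = [0.0] * n if beta is None else beta
    return [g * (xi - mean) / math.sqrt(var + eps) + b
            for g, xi, b in zip(gamma, x, beta)]

def rms_norm(x, gain=None, eps=1e-5):
    # RMSNorm skips the mean subtraction and only rescales by the root mean square
    n = len(x)
    rms = math.sqrt(sum(xi * xi for xi in x) / n + eps)
    gain = [1.0] * n if gain is None else gain
    return [g * xi / rms for g, xi in zip(gain, x)]

# activations with a large offset, like the very large values a block might emit
h = [1000.0, 1002.0, 998.0, 1004.0]
ln = layer_norm(h)   # centered around 0 with unit variance
rn = rms_norm(h)     # unit mean square, but still centered around 1 here
```

The example shows the practical difference. Both bring the values back to a stable range, but only LayerNorm re-centers them around zero.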

RoPE

Recap of Absolute PE: We previously used absolute positional embeddings in our GPT-2 model. Disadvantages: there is no notion of relative position between tokens, and it doesn't work for sequences longer than the context length the model was trained with, because we run out of position embeddings for positions beyond that context length. RoPE prerequisites: this is how we rotate a point by an angle θ in two-dimensional space, and this is all we need for RoPE....

January 15, 2025 · 7 min · CohleM
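A minimal sketch of the 2D rotation and why it gives relative position information. It assumes a single 2D pair rotated by `position * theta` with an arbitrary `theta = 0.1`. Full RoPE instead splits each head's vector into many such pairs, each with its own frequency. `rope_2d` is a hypothetical name. Rotating the query by angle mθ and the key by angle nθ leaves their dot product depending only on the offset m − n.

```python
import math

def rotate(point, angle):
    # 2D rotation: [x', y'] = [[cos, -sin], [sin, cos]] @ [x, y]
    x, y = point
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)

def rope_2d(vec, position, theta=0.1):
    # hypothetical single-pair RoPE: rotate the vector by position * theta
    return rotate(vec, position * theta)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (1.0, 2.0), (0.5, -1.0)
# query at position 5, key at position 3 vs. both shifted by 7 positions
score_a = dot(rope_2d(q, 5), rope_2d(k, 3))
score_b = dot(rope_2d(q, 12), rope_2d(k, 10))
# score_a and score_b match (up to float error): only the offset 5 - 3 = 12 - 10 matters
```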

GPUs

GPU physical structure: Let's first understand the structure of a GPU. Inside the GPU is a chip, the GA102 (the chip depends on the architecture; this one is for the Ampere architecture), built from 28.3 billion transistors (semiconductor devices that can switch or amplify electrical signals), with most of its area covered by processing cores. The processing cores are divided into seven Graphics Processing Clusters (GPCs), and each GPC contains 12 Streaming Multiprocessors (SMs). Inside each SM there are 4 processing blocks (often loosely called warps, although a warp is strictly a group of 32 threads) and 1 ray-tracing core, and each processing block contains 32 CUDA cores and 1 Tensor Core....

January 8, 2025 · 6 min · CohleM
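Multiplying out the hierarchy described in this summary gives the per-chip totals. This is a quick arithmetic check that assumes the full GA102 die with every unit enabled. Shipping cards can disable some SMs; the RTX 3090, for example, enables 82 of the 84.

```python
# counts per level of the GA102 hierarchy as described above (full die)
GPCS = 7
SMS_PER_GPC = 12
BLOCKS_PER_SM = 4            # the four sub-units (SM partitions) inside each SM
CUDA_CORES_PER_BLOCK = 32
TENSOR_CORES_PER_BLOCK = 1
RT_CORES_PER_SM = 1

sms = GPCS * SMS_PER_GPC
cuda_cores = sms * BLOCKS_PER_SM * CUDA_CORES_PER_BLOCK
tensor_cores = sms * BLOCKS_PER_SM * TENSOR_CORES_PER_BLOCK
rt_cores = sms * RT_CORES_PER_SM
print(sms, cuda_cores, tensor_cores, rt_cores)  # 84 10752 336 84
```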

DDP and gradient sync

When we have enough resources, we want to train our neural networks in parallel. The way to do this is to train the same NN on different data (different batches of data) on each GPU in parallel. For instance, with 8× A100s, we run a different batch of data on each A100. The way to do this in PyTorch is DDP (DistributedDataParallel; take a look at their docs)...

January 3, 2025 · 6 min · CohleM
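A minimal simulation of the gradient sync, assuming a toy one-parameter model `y = w * x` with a mean-squared-error loss and four equal-sized shards standing in for four GPUs. No PyTorch calls are involved. `local_gradient` and `all_reduce_mean` are hypothetical helpers. The second one stands in for the all-reduce that DDP performs during `backward()` to average gradients across processes, so every replica applies the same update.

```python
def local_gradient(w, shard):
    # gradient of the mean squared error mean((w*x - y)^2) w.r.t. w on one shard
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(grads):
    # stand-in for the all-reduce DDP runs during backward: every rank ends up with the average
    avg = sum(grads) / len(grads)
    return [avg] * len(grads)

# hypothetical toy data sampled from the line y = 3x + 1
data = [(float(i), 3.0 * i + 1.0) for i in range(16)]
world_size = 4                                              # e.g. 4 GPUs
shards = [data[r::world_size] for r in range(world_size)]   # disjoint, equal-sized batch per rank

w = 0.5
local = [local_gradient(w, s) for s in shards]   # each "GPU" sees different data
synced = all_reduce_mean(local)                  # gradient sync: all ranks now agree
full = local_gradient(w, data)                   # gradient a single GPU would get on all 16 samples
```

Because the shards are equal-sized and the loss is a mean, the averaged gradient equals the full-batch gradient. Training on 4 GPUs therefore behaves like training on one GPU with a 4× larger batch.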