Calculation of FLOPs
- multiply-accumulate cost: 2 FLOPs, i.e. 1 for the multiplication and 1 for the accumulation (addition)
- if we multiply two matrices of sizes (a x b) and (b x c), each of the (a x c) output elements needs b multiply-adds, so the total is 2 x b x (a x c) FLOPs (see the sketch below)
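As a quick sanity check of this rule, here is a minimal sketch; the helper name `matmul_flops` is mine and does not appear in the code later in this post.

```python
# Minimal sketch of the matmul FLOPs rule (helper name is illustrative only).
def matmul_flops(a: int, b: int, c: int) -> int:
    # (a x b) @ (b x c): each of the a*c output elements needs b multiply-adds = 2*b FLOPs
    return 2 * b * (a * c)

# e.g. a (4 x 3) @ (3 x 5) matmul costs 2 * 3 * (4 * 5) = 120 FLOPs
assert matmul_flops(4, 3, 5) == 120
```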
Embedding lookup
We initially have the tokens as a (seq_len, vocab_size) one-hot representation, and the embedding lookup matrix is (vocab_size, d_model), so the lookup takes
FLOPs = 2 × vocab_size × (seq_len × d_model)
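To make this concrete, here is a quick numeric check using the example configuration from the code further below (a sketch; the variable names are mine):

```python
# Embedding lookup as a (seq_len x vocab_size) @ (vocab_size x d_model) matmul,
# using the example configuration from the code below.
seq_len, vocab_size, d_model = 512, 50_000, 640
embedding_flops = 2 * vocab_size * (seq_len * d_model)
print(f"{embedding_flops:,}")  # 32,768,000,000
```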
Attention
Q, K, V projections X @ Wq, X @ Wk, X @ Wv: each costs 2 x (seq_len x d_model x key_size x num_heads), so 3x that for all three
Attention logits Q @ K.T: 2 x (seq_len x seq_len x key_size x num_heads)
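With the same example configuration as the code below, these two terms come out to (a sketch; variable names are mine):

```python
# Per-layer projection and attention-logit FLOPs with the example configuration.
seq_len, d_model, key_size, num_heads = 512, 640, 64, 10

# Three projections X @ W: each is (seq_len x d_model) @ (d_model x key_size*num_heads)
qkv_proj_flops = 3 * 2 * seq_len * d_model * (key_size * num_heads)   # 1,258,291,200

# Q @ K.T per head: (seq_len x key_size) @ (key_size x seq_len), summed over heads
attn_logits_flops = 2 * seq_len * seq_len * (key_size * num_heads)    # 335,544,320
```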
Softmax
- 1 FLOP for the exponential (e^x) per element.
- seq_len - 1 additions to sum each row, which works out to roughly 1 FLOP per element.
- 1 FLOP for the division by the row sum, so the total is 3 x num_heads x seq_len x seq_len (see the small check below)
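A quick numeric check with the same example configuration (a sketch; the variable name is mine):

```python
# Softmax cost: exp (1) + amortized row sum (~1) + divide (1) = ~3 FLOPs per attention-matrix element.
seq_len, num_heads = 512, 10
softmax_flops = 3 * num_heads * seq_len * seq_len
print(f"{softmax_flops:,}")  # 7,864,320
```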
Softmax @ V reductions (applying the attention weights to the values): 2 × seq_len × seq_len × (key_size × num_heads)
Final linear (output projection): 2 × seq_len × (key_size × num_heads) × d_model
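Summing the five attention terms above for one layer with the example configuration gives (a sketch; names are mine):

```python
# Total attention FLOPs for one layer with the example configuration.
seq_len, d_model, key_size, num_heads = 512, 640, 64, 10
qk_dim = key_size * num_heads  # 640

attention_per_layer = (
    3 * 2 * seq_len * d_model * qk_dim      # Q, K, V projections
    + 2 * seq_len * seq_len * qk_dim        # Q @ K.T logits
    + 3 * num_heads * seq_len * seq_len     # softmax
    + 2 * seq_len * seq_len * qk_dim        # attention weights @ V
    + 2 * seq_len * qk_dim * d_model        # final linear
)
print(f"{attention_per_layer:,}")  # 2,356,674,560
```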
Dense Block (per layer): 2 × seq_len × (d_model × ffw_size + ffw_size × d_model), i.e. two matmuls, d_model → ffw_size and ffw_size → d_model (ignoring the FLOPs of the activation function here)
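For the example configuration this evaluates to (a sketch; the variable name matches the code below):

```python
# Dense block FLOPs per layer: two matmuls, d_model -> ffw_size and ffw_size -> d_model.
seq_len, d_model, ffw_size = 512, 640, 2560
dense_block_flops = 2 * seq_len * (d_model * ffw_size + ffw_size * d_model)
print(f"{dense_block_flops:,}")  # 3,355,443,200
```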
Final Logits: 2 × seq_len × d_model × vocab_size
So the total forward-pass FLOPs: embeddings + num_layers × (total_attention + dense_block) + logits
The backward pass takes roughly 2× the FLOPs of the forward pass (each forward matmul turns into two backward matmuls of the same size: one for the input gradients and one for the weight gradients).
def calculate_transformer_flops(
    seq_len: int,
    vocab_size: int,
    d_model: int,
    key_size: int,
    num_heads: int,
    ffw_size: int,
    num_layers: int,
) -> int:
    """
    Calculate FLOPs for one forward + backward pass of a transformer model.

    Args:
        seq_len: Sequence length
        vocab_size: Vocabulary size
        d_model: Model dimension
        key_size: Key dimension per head
        num_heads: Number of attention heads
        ffw_size: Feed-forward layer size
        num_layers: Number of transformer layers

    Returns:
        Total FLOPs for one forward + backward pass over a single sequence
    """
    # Embeddings: one-hot (seq_len x vocab_size) @ (vocab_size x d_model)
    embedding_flops = 2 * seq_len * vocab_size * d_model

    # Single attention layer
    key_query_value_proj = 2 * 3 * seq_len * d_model * (key_size * num_heads)  # X @ Wq, X @ Wk, X @ Wv
    key_query_logits = 2 * seq_len * seq_len * (key_size * num_heads)  # Q @ K.T
    softmax_ops = 3 * num_heads * seq_len * seq_len  # exp + row sum + divide
    softmax_query_reduction = 2 * seq_len * seq_len * (key_size * num_heads)  # attention weights @ V
    final_linear = 2 * seq_len * (key_size * num_heads) * d_model  # output projection
    total_attention_flops = (
        key_query_value_proj
        + key_query_logits
        + softmax_ops
        + softmax_query_reduction
        + final_linear
    )

    # Single dense block: two matmuls, d_model -> ffw_size and ffw_size -> d_model
    dense_block_flops = 2 * seq_len * (d_model * ffw_size + d_model * ffw_size)

    # Final logits
    final_logits_flops = 2 * seq_len * d_model * vocab_size

    # Total forward pass
    total_forward_pass = (
        embedding_flops
        + num_layers * (total_attention_flops + dense_block_flops)
        + final_logits_flops
    )

    # Backward pass is approximately 2x the forward pass
    total_backward_pass = 2 * total_forward_pass

    # Total forward + backward
    total_flops = total_forward_pass + total_backward_pass
    return total_flops


# Example usage
params = {
    "seq_len": 512,
    "vocab_size": 50000,
    "d_model": 640,
    "key_size": 64,
    "num_heads": 10,
    "ffw_size": 2560,
    "num_layers": 10,
}
flops = calculate_transformer_flops(**params)
print(flops)
So this is the number of FLOPs required by our model per step with a batch of one sequence.
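To estimate the cost of a full training step, you would scale this by the number of sequences in the batch. A minimal sketch, continuing from the `flops` and `params` defined above and assuming a hypothetical batch_size:

```python
# Scale the single-sequence estimate to a full training step (batch_size is an assumed value).
batch_size = 32
flops_per_step = flops * batch_size
tokens_per_step = params["seq_len"] * batch_size
print(f"{flops_per_step:.3e} FLOPs per step over {tokens_per_step:,} tokens")
```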