LoRA (Low-Rank Adaptation)
The main idea is to approximate the change in weights dW with a product of low-rank matrices.
E.g.: usually the weight update adds the change in weights dW to the original weight matrix W, where dW is obtained through backpropagation. If W is 512 x 512, dW also has 512 x 512 = 262,144 parameters.
In LoRA, we approximate that dW by breaking it down into two low-rank matrices B @ A, where B is of size 512 x r and A is of size r x 512, with r much smaller than 512.
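To make the savings concrete, a quick back-of-the-envelope calculation (r = 8 here is just an illustrative choice):
d = 512
r = 8
full_params = d * d          # dW has 262,144 entries
lora_params = d * r + r * d  # B has 4,096 and A has 4,096, so 8,192 in total
print(full_params, lora_params, lora_params / full_params)  # 262144 8192 0.03125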
Previously, if the forward pass was
out = X @ W
we change it to
out = X @ W + X @ B @ A
We freeze all the other parameters (W in this case), compute gradients only for B and A, and update only these weights.
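To fix the idea, here is a minimal autograd-based sketch of such a layer (the class name LoRALinear and the rank/scale arguments are my own naming for this sketch; the rest of this post implements the same thing with manual backpropagation instead):
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # hypothetical wrapper: freeze an existing nn.Linear and learn only B and A
    def __init__(self, linear, rank=8, scale=2.0):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze W (and bias)
        self.B = nn.Parameter(torch.zeros(linear.out_features, rank))  # zeros, so B @ A starts at 0
        self.A = nn.Parameter(torch.randn(rank, linear.in_features))
        self.scale = scale

    def forward(self, x):
        # out = X @ W.T + X @ (scale * B @ A).T, and only B and A receive gradients
        return self.linear(x) + self.scale * (x @ self.A.T @ self.B.T)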
First, let's implement a toy example that approximates the sine function. We do the backpropagation manually, so that it's easier to see exactly what gets updated in our LoRA implementation.
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Generate synthetic data (sine wave)
X = np.linspace(0, 2*np.pi, 100).reshape(-1, 1)
y = np.sin(X)
# Convert to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)
# Define a small MLP
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 20, bias=True),   # 1 input → 20 hidden
            nn.ReLU(),
            nn.Linear(20, 1, bias=False)   # 20 hidden → 1 output
        )

    def forward(self, x):
        return self.layers(x)
# Initialize the model and the loss
model = MLP()
criterion = nn.MSELoss()  # not used below: the MSE loss is computed manually in the loop
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # no optimizer: weights are updated manually
# Training loop
epochs = 5000
for epoch in range(epochs):
    # optimizer.zero_grad()  # no optimizer here: gradients are cleared manually below
    # Instead of this forward pass
    # outputs = model(X_tensor)
    # implement the forward pass manually
    o1 = X_tensor @ model.layers[0].weight.T
    o2 = o1 + model.layers[0].bias
    o3 = model.layers[1](o2)  # ReLU layer
    outputs = o3 @ model.layers[2].weight.T
    diff = outputs - y_tensor
    squared_diff = diff**2
    o_seven = squared_diff.sum(0)
    loss = o_seven / len(squared_diff)
    # clear all the gradients
    for p in model.parameters():
        p.requires_grad = True
        p.grad = None
    # keep .grad on the intermediate tensors so we can compare against our manual gradients
    for p in [loss, o_seven, squared_diff, diff, outputs, o3, o2, o1, X_tensor]:
        if not p.requires_grad:
            p.requires_grad = True
        p.retain_grad()
    loss.backward()
    ### Manual backpropagation
    dL = torch.tensor(1.0)
    do_seven = dL * 1 / len(squared_diff)
    dsquared_diff = do_seven * torch.ones_like(squared_diff)
    ddiff = dsquared_diff * 2 * diff
    doutputs = ddiff * 1
    do3 = doutputs @ model.layers[2].weight
    dl2w = o3.T @ doutputs
    mo2 = o2 > 0  # ReLU mask: gradient flows only where o2 > 0
    do2 = do3 * mo2
    do1 = do2
    dl0bias = do2.sum(0)
    dl0w = X_tensor.T @ do1
    # # cmp('dL', dL, loss)
    # cmp('do_seven', do_seven, o_seven)
    # cmp('dsquared_diff', dsquared_diff, squared_diff)
    # cmp('ddiff', ddiff, diff)
    # cmp('doutputs', doutputs, outputs)
    # cmp('do3', do3, o3)
    # cmp('dl2w', dl2w.T, model.layers[2].weight)
    # cmp('do2', do2, o2)
    # cmp('dl0bias', dl0bias, model.layers[0].bias)
    # cmp('do1', do1, o1)
    # cmp('dl0w', dl0w.T, model.layers[0].weight)
    # gradient descent step with the manually computed gradients
    with torch.no_grad():
        lr = 0.01
        model.layers[0].weight.data -= lr * dl0w.T
        model.layers[0].bias.data -= lr * dl0bias
        model.layers[2].weight.data -= lr * dl2w.T
        # optimizer.step()
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Plot results
with torch.no_grad():
    predictions = model(X_tensor).numpy()
plt.scatter(X, y, label='True')
plt.scatter(X, predictions, label='Predicted', color='red')
plt.legend()
plt.show()
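The commented-out cmp(...) calls above assume a small helper (never defined in this post) that compares each manually derived gradient with the one autograd stored in .grad thanks to retain_grad(). A minimal sketch of such a helper might look like this:
def cmp(name, manual_grad, tensor):
    # compare a manually computed gradient with the .grad that loss.backward() filled in
    exact = bool(torch.all(manual_grad == tensor.grad))
    approx = torch.allclose(manual_grad, tensor.grad)
    maxdiff = (manual_grad - tensor.grad).abs().max().item()
    print(f'{name:15s} | exact: {exact} | approx: {approx} | maxdiff: {maxdiff:.2e}')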
Now, we implement LoRA. Here we:
- construct parameters A and B so that the resulting matrix B @ A matches the size of layer 0’s weight matrix,
- modify the forward pass by adding the contribution of lora_w,
- keep the rest of the forward pass exactly as it was,
- don’t compute gradients for the original weights and don’t update them, which is essentially freezing them,
- use the intermediate gradients to find the gradients for B and A, and update only those weights (see the note right after this list),
- that’s it!!!!
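One step worth spelling out before the code: since lora_w = scale * B @ A and lora_o1 = X_tensor @ lora_w, the chain rule gives
dlora_w = X_tensor.T @ dlora_o1
dB = scale * dlora_w @ A.T
dA = scale * B.T @ dlora_w
and these are exactly the dB and dA lines inside the training loop below.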
# Training loop
epochs = 100000
# let's train only the LoRA parameters for layer 0's weight
d, k = model.layers[0].weight.T.data.shape  # (1, 20)
r = 8
B = nn.Parameter(torch.zeros((d, r)))  # zeros, so the LoRA update B @ A starts at zero
A = nn.Parameter(torch.randn((r, k)))
scale = 2
for epoch in range(epochs):
    # manual forward pass, as before
    o1 = X_tensor @ model.layers[0].weight.T
    ## add the LoRA part for layer 0's weight here
    lora_w = scale * B @ A  # size of B@A matches model.layers[0].weight.T, i.e. (1, 20)
    lora_o1 = X_tensor @ lora_w
    h = o1 + lora_o1
    o2 = h + model.layers[0].bias
    o3 = model.layers[1](o2)  # ReLU layer
    outputs = o3 @ model.layers[2].weight.T
    diff = outputs - y_tensor
    squared_diff = diff**2
    o_seven = squared_diff.sum(0)
    loss = o_seven / len(squared_diff)
    # freeze all the model parameters; only A and B will get gradients
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None
    A.grad = None
    B.grad = None
    # keep .grad on the intermediate tensors so we can compare against our manual gradients
    for p in [loss, o_seven, squared_diff, diff, outputs, o3, o2, o1, X_tensor, h, lora_o1, lora_w]:
        if not p.requires_grad:
            p.requires_grad = True
        p.retain_grad()
    loss.backward()
    ### Manual backpropagation
    dL = torch.tensor(1.0)
    do_seven = dL * 1 / len(squared_diff)
    dsquared_diff = do_seven * torch.ones_like(squared_diff)
    ddiff = dsquared_diff * 2 * diff
    doutputs = ddiff * 1
    do3 = doutputs @ model.layers[2].weight
    # layer 2's weight is frozen, so we no longer need its gradient
    # dl2w = o3.T @ doutputs
    mo2 = o2 > 0  # ReLU mask
    do2 = do3 * mo2
    dh = do2
    do1 = dh
    dlora_o1 = dh
    dlora_w = X_tensor.T @ dlora_o1
    dB = scale * dlora_w @ A.T
    dA = scale * B.T @ dlora_w
    # and layer 0's weight and bias are frozen too
    # dl0bias = do2.sum(0)
    # dl0w = X_tensor.T @ do1
    # cmp('dL', dL, loss)
    # cmp('do_seven', do_seven, o_seven)
    # cmp('dsquared_diff', dsquared_diff, squared_diff)
    # cmp('ddiff', ddiff, diff)
    # cmp('doutputs', doutputs, outputs)
    # cmp('do3', do3, o3)
    # # cmp('dl2w', dl2w.T, model.layers[2].weight)
    # cmp('do2', do2, o2)
    # cmp('dh', dh, h)
    # cmp('do1', do1, o1)
    # cmp('dlora_o1', dlora_o1, lora_o1)
    # cmp('dlora_w', dlora_w, lora_w)
    # cmp('dB', dB, B)
    # cmp('dA', dA, A)
    # # cmp('dl0w', dl0w.T, model.layers[0].weight)
    # update only the LoRA parameters
    with torch.no_grad():
        lr = 0.001
        A.data -= lr * dA
        B.data -= lr * dB
        # optimizer.step()
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Plot results (model(X_tensor) alone would ignore B and A, so add the LoRA term here as well)
with torch.no_grad():
    h = X_tensor @ model.layers[0].weight.T + X_tensor @ (scale * B @ A) + model.layers[0].bias
    predictions = (model.layers[1](h) @ model.layers[2].weight.T).numpy()
plt.scatter(X, y, label='True')
plt.scatter(X, predictions, label='Predicted', color='red')
plt.legend()
plt.show()
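A nice property of LoRA (and the reason the plot above adds the LoRA term by hand) is that once training is done, the low-rank update can be folded back into the frozen weight, so the plain model forward picks it up with no extra inference cost. A small sketch, reusing the tensors from above:
# merge the LoRA update into the frozen weight; nn.Linear stores weight as
# (out_features, in_features), hence the transpose of scale * B @ A
with torch.no_grad():
    model.layers[0].weight.data += (scale * B @ A).T
    merged_predictions = model(X_tensor).numpy()  # now matches the manual LoRA forward above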