Nagi-ovo

Homepage: [nagi.fun](nagi.fun)

LoRA in PyTorch

This article summarizes what I learned from the GitHub repository hkproj/pytorch-lora.

I have used the peft library for LoRA fine-tuning many times and understand the general principles, but I had never implemented it myself, so this course really resonated with me. Classic ADHD: I can't rest until I've properly digested a piece of knowledge.


Fine-Tuning

Target: Pre-trained model
Purpose: Learn from a dataset specific to a domain or task, so that the model is better suited to that application
Difficulty: Full-parameter fine-tuning is computationally expensive, the model weights and optimizer states need a lot of memory, checkpoints take up a lot of disk space, and switching between multiple fine-tuned models is inconvenient.

LoRA

LoRA (Low-Rank Adaptation) is a PEFT (Parameter-Efficient Fine-Tuning) method.

One of the core ideas behind LoRA is that, during fine-tuning, many weights in the original weight matrix W may not be directly relevant to the specific task. LoRA therefore assumes that the weight update can be approximated by a low-rank matrix, so adjusting only a small number of parameters is enough to adapt the model to the new task.

What is rank?

Just as the RGB primary colors can be combined to produce most colors, the linearly independent columns (or rows) of a matrix span its column (or row) space. The primary colors can be seen as the "basis vectors" of color space, and the rank of a matrix is the number of basis vectors of its column (or row) space. The higher the rank, the richer the "colors" (vectors) the matrix can express.
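A quick way to see this in PyTorch is torch.linalg.matrix_rank. In the sketch below (the matrices are made up for illustration), the third column is the sum of the first two, so it contributes no new "color" and the rank is 2; swapping in an independent column brings the rank back to 3.

```python
import torch

# The third column is column 1 + column 2, so it adds no new direction
M = torch.tensor([[1., 0., 1.],
                  [0., 1., 1.],
                  [1., 1., 2.]])
print(torch.linalg.matrix_rank(M).item())       # 2

# Replacing the dependent column with an independent one restores full rank
M_full = torch.tensor([[1., 0., 0.],
                       [0., 1., 0.],
                       [1., 1., 1.]])
print(torch.linalg.matrix_rank(M_full).item())  # 3
```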

Just as we can approximate a color image using grayscale (reducing the color dimension), low-rank approximation can be used to compress matrix information.
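As a rough illustration of this compression (a sketch, assuming a random 50 × 40 matrix as a stand-in for real weights), keeping only the r largest singular values gives the best rank-r approximation in the Frobenius norm (the Eckart-Young theorem). The relative error shrinks as r grows, while the number of stored values grows linearly in r.

```python
import torch

torch.manual_seed(0)
W = torch.randn(50, 40)  # a random (almost surely full-rank) matrix
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

errors = {}
for r in (1, 5, 20, 40):
    # Keep only the r largest singular values: the best rank-r approximation
    W_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]
    errors[r] = (torch.linalg.norm(W - W_r) / torch.linalg.norm(W)).item()
    stored = r * (50 + 40 + 1)  # numbers kept for U_r, S_r and V_r
    print(f"rank {r:2d}: relative error {errors[r]:.3f}, stored {stored} of {W.numel()} numbers")
```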

Motivation and Principle

See the original paper: LoRA: Low-Rank Adaptation of Large Language Models

  1. Low-rank structure of pre-trained models: Pre-trained language models have a low "intrinsic dimension", meaning they can still learn effectively even after a random projection into a smaller subspace. This suggests that fine-tuning does not need to update every parameter independently (biases aside): many parameters can effectively be expressed as combinations of others, i.e. the model is "rank-deficient".

  2. Low-rank update assumption: Based on this finding, the authors hypothesize that the weight updates also have a low rank. During training, the pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ is frozen, and the update is represented as the product of two low-rank matrices, $\Delta W = BA$ with $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, where B and A are trainable and the rank $r \ll \min(d, k)$.

  3. Formula derivation: The updated weight matrix is $W_0 + \Delta W$, which is used in forward propagation, so the model's output becomes $h = W_0 x + \Delta W x = W_0 x + BAx$. Here $W_0$ is frozen and not updated, while A and B receive gradient updates during backpropagation.
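The formula can be written as a small wrapper module. This is a minimal sketch, not the article's implementation (which uses PyTorch parametrizations further down): LoRALinear is a hypothetical name, the 0.01 scale on A's Gaussian initialization is an arbitrary choice, and the α/r scaling from the paper is left out for brevity. Because B starts at zero, the wrapped layer initially produces exactly the frozen layer's output.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Hypothetical wrapper: h = W0 x + B A x, with W0 (and the bias) frozen."""
    def __init__(self, base: nn.Linear, r: int = 2):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                      # freeze W0 and bias
        d, k = base.weight.shape                         # W0 is d x k (out x in)
        self.A = nn.Parameter(torch.randn(r, k) * 0.01)  # r x k, Gaussian init
        self.B = nn.Parameter(torch.zeros(d, r))         # d x r, zero so BA = 0 at start

    def forward(self, x):
        # For a batch of row vectors, (x A^T) B^T equals (B A x) without forming BA
        return self.base(x) + (x @ self.A.T) @ self.B.T

torch.manual_seed(0)
base = nn.Linear(5, 3)
lora = LoRALinear(base, r=2)
x = torch.randn(4, 5)
print(torch.allclose(lora(x), base(x)))                            # True at initialization
print([n for n, p in lora.named_parameters() if p.requires_grad])  # ['A', 'B']
```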

[Figure: LoRA reparameterization diagram]

Parameter Count Calculation

  • The original weight matrix W has $d \times k$ parameters. Here, let $d = 1000$, $k = 5000$, so the parameter count is $5{,}000{,}000$.

  • With LoRA, the additional parameters come from matrices A and B. Their parameter count is $p = (d \times r) + (r \times k)$.

    Generally, $r$ is chosen to be very small; here we take $r = 1$, so:

$$p = (1000 \times 1) + (1 \times 5000) = 1000 + 5000 = 6000$$

This reduces the number of trainable parameters by 99.88%, greatly lowering the computational cost of fine-tuning and the storage cost, and making it much easier to switch between models (only the two low-rank matrices need to be reloaded).
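The same arithmetic in Python, extended to a few larger ranks to show how the savings shrink as r grows (the extra rank values are only for illustration):

```python
d, k = 1000, 5000
full = d * k  # parameters in W

counts = {}
for r in (1, 8, 64):
    counts[r] = d * r + r * k  # B is d x r, A is r x k
    print(f"r={r:2d}: {counts[r]:,} LoRA parameters, {1 - counts[r] / full:.2%} fewer than W")
```

For r = 1 this reproduces the 6,000 parameters and the 99.88% reduction computed above.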

SVD

The basic idea of LoRA described above is to represent the update to a large weight matrix of the original model by the product of two low-rank matrices. SVD (Singular Value Decomposition) is one of the most commonly used matrix decomposition methods; it factors a matrix into the product of three matrices:

$$W = U \Sigma V^T$$

import torch
import numpy as np
_ = torch.manual_seed(0)

d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)

W_rank = np.linalg.matrix_rank(W)
print(f'Rank of W: {W_rank}')

By multiplying matrices of size $10 \times 2$ and $2 \times 10$, we obtain a $10 \times 10$ matrix W. Since it is the product of two rank-2 matrices, W has rank at most 2.

# Perform SVD on W (W = UxSxV^T)
U, S, V = torch.svd(W)

# For rank-r factorization, keep only the first r singular values (and corresponding columns of U and V)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()  # Transpose V_r to get the right dimensions

# Compute B = U_r * S_r and A = V_r
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

torch.svd(W): Performs Singular Value Decomposition (SVD) on W, returning U, S, and V such that $W = U \cdot \operatorname{diag}(S) \cdot V^T$.

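  • U: An orthogonal matrix whose columns are the left singular vectors of W. By default (some=True), torch.svd returns the reduced decomposition, so U has shape $d \times \min(d, k)$; here $d = k = 10$, so it is $10 \times 10$.
  • S: A 1-D tensor of length $\min(d, k)$ containing the singular values of W in descending order (the diagonal of $\Sigma$). Because W has rank 2, only the first two are non-zero, up to floating-point error.
  • V: An orthogonal matrix whose columns are the right singular vectors of W, with shape $k \times \min(d, k)$. Note that torch.svd returns V itself, not $V^T$.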
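In recent PyTorch versions torch.svd is deprecated in favor of torch.linalg.svd, which returns $V^T$ (called Vh) instead of V. A sketch of the same rank-2 factorization with the newer API, assuming the same setup as above (d = k = 10, rank 2):

```python
import torch

torch.manual_seed(0)
d, k, r = 10, 10, 2
W = torch.randn(d, r) @ torch.randn(r, k)

# torch.linalg.svd returns V already transposed (Vh = V^T)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)

B = U[:, :r] @ torch.diag(S[:r])  # d x r
A = Vh[:r, :]                     # r x k, no extra transpose needed
print(B.shape, A.shape)           # torch.Size([10, 2]) torch.Size([2, 10])
print(torch.allclose(B @ A, W, atol=1e-4))  # two singular values suffice for a rank-2 W
```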

Retaining the first r singular values for low-rank approximation:

U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # Obtain the diagonal matrix of singular values
V_r = V[:, :W_rank].t()

Calculating the low-rank approximation:

B = U_r @ S_r
A = V_r

To check that the factorization preserves the output, we compare two ways of computing it:
  • y: The result computed with the original matrix W. Multiplying the $d \times k$ matrix $W$ by the vector $x$ costs $O(d \cdot k)$: each of the $d$ rows needs $k$ multiplications.
  • y': The result computed with the low-rank factors $B$ and $A$ obtained from the decomposition.
    1. First compute $A \cdot x$, where $A$ is an $r \times k$ matrix and $x$ is a $k \times 1$ vector.
      • The computational cost is $O(r \cdot k)$.
    2. Then compute $B \cdot (A \cdot x)$, where $B$ is a $d \times r$ matrix and $A \cdot x$ has size $r \times 1$.
      • The computational cost is $O(d \cdot r)$.

# Generate random bias and input
bias = torch.randn(d)
x = torch.randn(k)  # x needs k entries to match the k columns of W

# Compute y = Wx + bias
y = W @ x + bias

# Compute y' = B(Ax) + bias: multiplying by A first avoids forming the d x k product B @ A
y_prime = B @ (A @ x) + bias

# Check if the two results are approximately equal
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y and y' are approximately equal.")
else:
    print("y and y' are not equal.")

  • Directly using $W$: computing $W \cdot x$ costs $O(d \cdot k)$.
  • Using the factors: computing $B \cdot (A \cdot x)$ costs $O(r \cdot k) + O(d \cdot r)$, i.e. $O(r \cdot (k + d))$.

[Screenshot: output of the code above]

For $d = k = 10$ and $r = 2$: $10 \times 10 = 100$ vs $2 \times (10 + 10) = 40$.
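The order of the multiplications matters for this saving. In the sketch below (sizes chosen for illustration, float64 to keep round-off negligible), (B @ A) @ x first materializes the full d × k matrix, which costs more than W @ x, whereas B @ (A @ x) never forms it. Both give the same result; the multiplication counts are computed from the shapes, not measured.

```python
import torch

torch.manual_seed(0)
d, k, r = 1000, 5000, 8
B = torch.randn(d, r, dtype=torch.float64)
A = torch.randn(r, k, dtype=torch.float64)
x = torch.randn(k, dtype=torch.float64)

y_full = (B @ A) @ x  # forms the d x k matrix BA first
y_low = B @ (A @ x)   # only ever creates an r-vector and a d-vector

# Multiplication counts from the shapes (not a benchmark)
cost_full = d * r * k + d * k  # BA, then (BA) x
cost_low = r * k + d * r       # A x, then B (A x)

print(torch.allclose(y_full, y_low))     # True
print(f"{cost_full:,} vs {cost_low:,}")  # 45,000,000 vs 48,000
```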

Note, however, that LoRA does not actually perform an SVD: it learns the low-rank matrices A and B directly through training, so the weight update adapts dynamically to the task.

LoRA Fine-Tuning for Classification Tasks

Suppose a classifier trained on the MNIST handwritten digit dataset recognizes one particular digit poorly, and we want to fine-tune it to do better on that digit.

To highlight the role of LoRA, we will use an overly complex model that far exceeds the task requirements.

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create an overly expensive neural network to classify MNIST digits
# Daddy got money, so I don't care about efficiency
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)  # flatten each 28x28 image into a 784-vector
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)

We can first observe the current model's parameter count.

[Screenshot: parameter count of the original model]
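For reference, the count can be reproduced from the layer sizes. This sketch rebuilds linear layers with the same shapes as RichBoyNet (assuming the default hidden sizes 1000 and 2000) and also defines total_parameters_original, the value the parameter-count comparison later in the article asserts against.

```python
import torch.nn as nn

# Same layer sizes as RichBoyNet with hidden_size_1=1000, hidden_size_2=2000
layers = [nn.Linear(28 * 28, 1000), nn.Linear(1000, 2000), nn.Linear(2000, 10)]

total_parameters_original = 0
for index, layer in enumerate(layers):
    n = layer.weight.nelement() + layer.bias.nelement()
    total_parameters_original += n
    print(f"Layer {index + 1}: W: {tuple(layer.weight.shape)} + B: {tuple(layer.bias.shape)} = {n:,}")
print(f"Total number of parameters: {total_parameters_original:,}")  # 2,807,010
```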

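We train for one epoch and then save a copy of the original weights, so we can later prove that LoRA fine-tuning does not alter them.

train(train_loader, net, epochs=1)
original_weights = {name: param.clone().detach() for name, param in net.named_parameters()}  # snapshot for later checks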
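The article does not show the train function. Below is a hypothetical sketch consistent with how it is called here (a loader, a model, epochs, and an optional total_iterations_limit); the Adam optimizer, learning rate, and cross-entropy loss are assumptions. It only hands parameters with requires_grad=True to the optimizer, so the same function also works for the LoRA stage after freezing. The usage lines use synthetic MNIST-shaped tensors, not the real dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train(train_loader, net, epochs=1, total_iterations_limit=None, lr=1e-3, device="cpu"):
    """Hypothetical training loop; returns the number of optimizer steps taken."""
    loss_fn = nn.CrossEntropyLoss()
    # Only parameters that still require gradients (e.g. the LoRA ones after freezing)
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=lr)
    net.train()
    iterations = 0
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(net(x), y)
            loss.backward()
            optimizer.step()
            iterations += 1
            # Stop early once the requested number of batches has been processed
            if total_iterations_limit is not None and iterations >= total_iterations_limit:
                return iterations
    return iterations

# Usage on synthetic data shaped like MNIST (64 samples, 8 batches per epoch)
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=8)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
print(train(loader, model, epochs=2, total_iterations_limit=5))  # 5
```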

Let's test to see which digit is recognized poorly:

[Screenshot: test results with the number of misclassifications per digit]
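The test helper is not shown either. One way to find the weakest digit, sketched here with a hypothetical count_errors_per_digit function, is to count misclassifications per true label with torch.bincount. The usage below feeds synthetic data to a deliberately broken model that always predicts 0, just to make the counts easy to check.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def count_errors_per_digit(net, loader, num_classes=10):
    """Hypothetical helper: accuracy plus the number of misclassified samples per true label."""
    net.eval()
    wrong = torch.zeros(num_classes, dtype=torch.long)
    correct = total = 0
    for x, y in loader:
        pred = net(x).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
        # Count the true labels of the misclassified samples
        wrong += torch.bincount(y[pred != y], minlength=num_classes)
    return correct / total, wrong

# Synthetic data (each digit appears twice) and a model that always predicts class 0
data = TensorDataset(torch.randn(20, 784), torch.arange(20) % 10)
always_zero = nn.Linear(784, 10)
nn.init.zeros_(always_zero.weight)
with torch.no_grad():
    always_zero.bias.copy_(torch.tensor([1.0] + [0.0] * 9))
acc, wrong = count_errors_per_digit(always_zero, DataLoader(data, batch_size=5))
print(acc, wrong.tolist())  # 0.1 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2]
```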

We can choose 9 for fine-tuning.

Define LoRA Parameterization

Here, the forward function receives the original weights original_weights and returns a new weight matrix that includes the LoRA adaptation term. When the model performs forward propagation, the linear layer will use this new weight matrix.

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A initialized as a Gaussian distribution, B initialized to zero, ensuring that at the start of training ∆W = BA is zero
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # In section 4.1 of the paper: scaling factor α/r simplifies hyperparameter tuning, α is set to the r value tried first
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

Here we initialize the $A$ matrix from a normal distribution and the $B$ matrix to zero, so $\Delta W = BA$ is zero at the start of training and the model initially behaves exactly like the pre-trained one. The scaling factor $\frac{\alpha}{r}$ reduces the need to retune hyperparameters such as the learning rate when the rank $r$ changes.
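A quick check of what this initialization implies, using plain tensors instead of the module (the sizes are arbitrary): at step 0 the effective weight equals $W_0$, and since $\partial L / \partial A = \frac{\alpha}{r} B^T \, \partial L / \partial W$ with $B = 0$, A receives a zero gradient on the first step while B does not.

```python
import torch

torch.manual_seed(0)
out_features, in_features, rank, alpha = 6, 4, 2, 2
scale = alpha / rank

W0 = torch.randn(out_features, in_features)              # frozen pre-trained weight
A = torch.randn(rank, in_features, requires_grad=True)   # Gaussian init
B = torch.zeros(out_features, rank, requires_grad=True)  # zero init

W = W0 + (B @ A) * scale
print(torch.equal(W, W0))  # True: Delta W = BA is zero at initialization

x = torch.randn(3, in_features)
loss = (x @ W.T).pow(2).sum()  # any loss works for this check
loss.backward()
print(A.grad.abs().sum().item())      # 0.0: no gradient reaches A while B is zero
print(B.grad.abs().sum().item() > 0)  # True: B gets a non-zero gradient
```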

Apply LoRA Parameterization

PyTorch provides a parametrization mechanism (see the official documentation for PyTorch Parametrizations) that applies a custom transformation to a parameter without changing the model's structure. When a parameter such as weight is parametrized, PyTorch moves the original tensor to module.parametrizations.weight.original, and every access to module.weight returns the output of the parametrization's forward function applied to that original tensor.
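A minimal illustration of the mechanism, using the Symmetric parametrization from the official PyTorch parametrizations tutorial rather than LoRA: after registration the unconstrained tensor lives in parametrizations.weight.original, and layer.weight is recomputed from it on every access.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    # From the PyTorch parametrizations tutorial: maps any square matrix X to a symmetric one
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

torch.manual_seed(0)
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

print(parametrize.is_parametrized(layer, "weight"))  # True
original = layer.parametrizations.weight.original    # the stored, unconstrained tensor
W = layer.weight                                     # recomputed via Symmetric.forward on each access
print(torch.allclose(W, W.T))                        # True
print(torch.equal(W, Symmetric()(original)))         # True: weight is just forward(original)
```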

Here we use the parametrize.register_parametrization function to parameterize the weights of the linear layers, applying LoRA to the linear layers of the model:

import torch.nn.utils.parametrize as parametrize

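def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add parameterization to the weight matrix, ignoring bias.
    # nn.Linear stores its weight as (out_features, in_features), so "features_in" below is
    # actually the output size; LoRAParametrization shapes lora_B as (features_in, rank) and
    # lora_A as (rank, features_out), so lora_B @ lora_A has exactly the weight's shape.
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )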

After the registration below:
  • The original weights are moved to net.linear1.parametrizations.weight.original.
  • Every access to net.linear1.weight actually computes the weight through the forward function of the LoRA parametrization.

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
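Toggling enabled keeps the LoRA matrices separate from $W_0$, which is convenient for switching. For deployment, the update can instead be folded into a plain weight with torch.nn.utils.parametrize.remove_parametrizations(..., leave_parametrized=True), so inference needs no extra matrix multiplication. The sketch below shows this on a single layer; TinyLoRA is a hypothetical, trimmed-down stand-in for LoRAParametrization, and the random values written into lora_B only simulate a trained update.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class TinyLoRA(nn.Module):
    """Hypothetical minimal LoRA parametrization: W -> W + (B @ A) * scale."""
    def __init__(self, out_features, in_features, rank=1, alpha=1):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, W):
        return W + (self.lora_B @ self.lora_A) * self.scale if self.enabled else W

torch.manual_seed(0)
layer = nn.Linear(4, 3)
out_f, in_f = layer.weight.shape  # nn.Linear stores weight as (out_features, in_features)
parametrize.register_parametrization(layer, "weight", TinyLoRA(out_f, in_f, rank=2))

with torch.no_grad():
    layer.parametrizations.weight[0].lora_B.normal_()  # pretend training moved B away from zero

merged = layer.weight.detach().clone()  # W0 + BA * scale
x = torch.randn(5, 4)
y_lora = layer(x)

# Bake the LoRA update into a plain weight and drop the parametrization
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
print(parametrize.is_parametrized(layer))           # False
print(torch.allclose(layer.weight, merged))         # True
print(torch.allclose(layer(x), y_lora, atol=1e-6))  # True
```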

Parameter Count Comparison

Calculate the changes in model parameters after introducing LoRA:

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameter count must match the original network
# (total_parameters_original is the count taken before the LoRA parametrizations were registered)
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

[Screenshot: output of the parameter count comparison]

As we can see, LoRA introduces only a very small number of parameters (an increase of about 0.242%), yet this is enough to fine-tune the model effectively.
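The percentage can be checked by hand. With rank 1, each layer adds out_features + in_features LoRA parameters; the sketch below assumes RichBoyNet's default layer shapes:

```python
rank = 1
# (out_features, in_features) of RichBoyNet's three nn.Linear layers with default hidden sizes
shapes = [(1000, 28 * 28), (2000, 1000), (10, 2000)]

lora_params = sum(out_f * rank + rank * in_f for out_f, in_f in shapes)  # lora_B + lora_A
original_params = sum(out_f * in_f + out_f for out_f, in_f in shapes)    # weights + biases
increase = lora_params / original_params * 100

print(f"{lora_params:,} LoRA parameters on top of {original_params:,}")  # 6,794 on top of 2,807,010
print(f"Parameters increment: {increase:.3f}%")                          # 0.242%
```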

Freeze Non-LoRA Parameters

During fine-tuning, we only want to adjust the parameters introduced by LoRA while keeping the original model weights unchanged. Therefore, we need to freeze all non-LoRA parameters.

# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

[Screenshot: list of frozen non-LoRA parameters]
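The name-based filter works because of where the parametrization puts things. On a single parametrized layer (AddLowRank is a hypothetical stand-in for LoRAParametrization), the frozen weight shows up as parametrizations.weight.original and the LoRA matrices as parametrizations.weight.0.lora_A and lora_B, so checking for 'lora' in the name separates them:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddLowRank(nn.Module):
    """Hypothetical stand-in for LoRAParametrization: W -> W + B @ A."""
    def __init__(self, out_features, in_features, rank=1):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, W):
        return W + self.lora_B @ self.lora_A

layer = nn.Linear(4, 3)  # weight shape (3, 4)
parametrize.register_parametrization(layer, "weight", AddLowRank(3, 4))

names = []
for name, param in layer.named_parameters():
    if "lora" not in name:
        param.requires_grad = False  # freeze the bias and the original weight
    names.append(name)
    print(name, tuple(param.shape), param.requires_grad)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 7: lora_A has 1 * 4 values, lora_B has 3 * 1
```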

Select Target Dataset

Since we want to improve the model's recognition of the digit 9, we will only select samples of the digit 9 from the MNIST dataset for fine-tuning.

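from torchvision import datasets, transforms

# Assumed preprocessing: the usual MNIST ToTensor + mean/std normalization
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Keep only samples of digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]

# Create data loader
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)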
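Overwriting mnist_trainset.data and .targets works, but it mutates the dataset object. An alternative, sketched here on synthetic MNIST-shaped tensors so it runs without a download, is torch.utils.data.Subset with the indices of the 9s; with the real dataset one would pass mnist_trainset and the indices computed from mnist_trainset.targets == 9.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in for MNIST: 100 fake 28x28 images with labels 0-9
torch.manual_seed(0)
images = torch.randn(100, 1, 28, 28)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

# Indices of the samples whose label is 9; the underlying dataset is left untouched
nine_idx = torch.nonzero(labels == 9, as_tuple=True)[0].tolist()
digit_9_set = Subset(dataset, nine_idx)

loader = DataLoader(digit_9_set, batch_size=10, shuffle=True)
batch_x, batch_y = next(iter(loader))
print(len(digit_9_set), tuple(batch_x.shape), batch_y.unique().tolist())  # 10 (10, 1, 28, 28) [9]
```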

Fine-Tune the Model

We fine-tune the model using only the data of digit 9 while keeping the original weights frozen. To save time, we will only train for 100 batches.

# Fine-tune the model, training only 100 batches
train(train_loader, net, epochs=1, total_iterations_limit=100)

Verify Original Weights Are Unchanged

Verify once more that the original weights have not changed after fine-tuning.

assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Test Model Performance

After enabling LoRA, we test the model's performance on the test set, comparing it with the original model:

[Screenshot: test results with LoRA enabled and disabled]

After enabling LoRA, the number of misclassified 9s dropped significantly, from 124 errors with LoRA disabled to 14 errors. The overall accuracy (88.7%) is lower than with LoRA disabled, so the gain on digit 9 comes at some cost to the other digits, which is expected when the fine-tuning data contains only 9s. In other words, LoRA fine-tuning refocused the model on recognizing the digit 9, while the frozen original weights remain available: disabling LoRA restores the original model.
