banner
Nagi-ovo

Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)
github

PyTorchにおけるLoRA

本文是对 GitHub - hkproj/pytorch-lora学習のまとめです。

以前に多くの回 peft ライブラリの LoRA 微調整を使用したことがあり、原理は大体理解していましたが、実際に実装したことはありませんでした。そのため、このコースの内容は私にとって非常に響くものでした。ADHD 经典不消化掉知识就难受


ファインチューニング#

対象:事前学習モデル
目的:基礎の上に特定の分野やタスクのデータセットを学習させ、特定のアプリケーションシーンにより適応させる
難点:全パラメータ微調整計算コストが高く、モデルの重み、オプティマイザの状態メモリ要求が高い、チェックポイントハードディスクストレージ量が大きい、複数の微調整モデルの切り替えが不便

LoRA#

LoRA(Low-Rank Adaptation)は PEFT(Parameter-Efficient Fine-Tuning)の一つの方法で、後者は効率的なパラメータ微調整です。

LoRA の背後にある核心的な考え方の一つは、原始的な重み行列 W の多くの重みが微調整の過程で特定の微調整タスクに直接関連しない可能性があるということです。したがって、LoRA は重みの更新が低ランク行列によって近似できると仮定しており、少量のパラメータ調整で新しいタスクに適応できるとしています。

ランクとは?#

RGB 三原色がほとんどの色を組み合わせることができるように、行列の列(または行)ベクトルの線形独立ベクトルは、その行列の列(または行)空間を生成できます。三原色は色空間の「基底ベクトル」と見なすことができ、行列のランクはその列(または行)空間の基底ベクトルの数を示します。ランクが高いほど、行列が表現できる「色」(ベクトル)は豊かになります。

グレースケールを使用してカラー画像を近似できるように(色の次元を減らす)、低ランク近似は行列情報を圧縮するために使用できます。

動機と原理#

詳細は元の論文を参照してください:LoRA: Low-Rank Adaptation of Large Language Models

  1. 事前学習モデルの低ランク構造:事前学習された言語モデルは **「固有次元」が低い **(Intrinsic Dimension)ため、より小さな部分空間でランダム投影を行っても、効果的に学習できます。これは、モデルが微調整時にすべてのパラメータを完全に更新する必要がなく(バイアスも考慮しない)、多くのパラメータは他のパラメータの組み合わせで表現できることを示しています。モデルは「ランク欠損」の特性を持っています。

  2. 低ランク更新仮定:この発見に基づいて、著者は重みの更新も低ランク特性を持つと仮定しています。訓練過程で、事前学習された重み行列 W0 は固定され、更新行列 ΔW は二つの低ランク行列の積 BA として表されます。ここで BA は訓練可能な行列であり、ランク rdk よりもはるかに小さいです。

  3. 公式の導出:重み行列の更新は W0+ΔW_0 + \Delta として表され、前方伝播に使用され、モデルの出力は h=W0x+BAxh = W_0x + BAx となります。ここで、W0 は固定されて更新されず、AB は逆伝播で勾配更新に参加します。

Pasted image 20240929021236

パラメータ量の計算#

  • 原始重み行列 Wd×kd \times k のパラメータを持っています。ここで d=1000k=5000d = 1000,k = 5000 とすると、パラメータ量は 5,000,0005,000,000 です。

  • LoRA を使用した後、追加されたパラメータは行列 AB から来ます。これらのパラメータ量は:p=(d×r)+(r×k)p = (d \times r) + (r \times k)

    一般に rr は非常に小さい値を取ります。ここでは r=1r = 1 とすると:

p=(1000×1)+(1×5000)=1000+5000=6000p = (1000 \times 1) + (1 \times 5000) = 1000 + 5000 = 6000

このようにして、パラメータ量は 99.88% 大幅に減少し、微調整の計算コスト、ストレージコスト、モデル間の切り替えの難易度を大幅に低下させます(二つの低ランク行列を再読み込みするだけで済みます)。

SVD#

上記で述べたように、LoRA の基本的な考え方は、原始モデルの大規模なパラメータ行列を二つの低ランク行列を導入することによって表現することです。一方、SVD(特異値分解)は最も一般的な行列分解方法の一つであり、行列を三つの部分行列に分解できます:

W=UΣVTW = U \Sigma V^T
import torch
import numpy as np
_ = torch.manual_seed(0)

d, k = 10, 10
W_rank = 2
W = torch.randn(d,W_rank) @ torch.randn(W_rank,k)

W_rank = np.linalg.matrix_rank(W) print(f'Rank of W: {W_rank}')
print(f"{W_rank=}")

行列の乗算 10×210\times22×102\times10 行列を掛け合わせて、10×1010 \times 10 の行列 W を得ます。二つのランクが 2 の行列を掛け合わせているため、最終的な行列 W のランクは最大で 2 です。

# W に対して SVD を実行する (W = UxSxV^T)
U, S, V = torch.svd(W)

# ランク-r の因子分解のため、最初の r の特異値(および U と V の対応する列)だけを保持します
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = V[:, :W_rank].t()  # V_r を転置して正しい次元を得ます

# B = U_r * S_r と A = V_r を計算します
B = U_r @ S_r
A = V_r
print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

torch.svd(W):行列 W に対して ** 特異値分解(SVD)** を行い、三つの行列 USV を得ます。これらは W=USVTW = U \cdot S \cdot V^T を満たします。

  • U:正規直交行列で、その列は W の左特異ベクトルであり、次元は d×dd \times d です。
  • S:ベクトル(対角行列の対角線上の非ゼロの特異値)で、W の特異値を含み、次元は dd です。
  • V:正規直交行列で、その列は W の右特異ベクトルであり、次元は k×kk \times k です。

最初の r 個の特異値を保持して低ランク近似を行います:

U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank]) # 特異値の対角行列を得ます
V_r = V[:, :W_rank].t()

低ランク近似を計算します:

B = U_r @ S_r
A = V_r
  • y:原始行列 W を使用して計算された結果。行列 WW とベクトル xx の掛け算の計算量は O(dk)O(d \cdot k) です。なぜなら、各行の計算には kk 回の乗算が必要で、合計 dd 行あるため、計算の複雑度は O(dk)O(d \cdot k) です。
  • y':低ランク分解後に再構成された行列 BAB \cdot A の計算結果。
    1. まず AxA \cdot x を計算します。ここで AAr×kr \times k 行列で、xxk×1k \times 1 ベクトルです。
      • 計算量は O(rk)O(r \cdot k) です。
    2. 次に B(Ax)B \cdot (A \cdot x) を計算します。ここで BBdrd \cdot r 行列で、AxA \cdot x のサイズは r×1r \times 1 です。
      • 計算量は O(dr)O(d \cdot r) です。
# ランダムなバイアスと入力を生成します
bias = torch.randn(d)
x = torch.randn(d)

# y = Wx + bias を計算します
y = W @ x + bias

# y' = (B*A)x + bias を計算します
y_prime = (B @ A) @ x + bias

# 二つの結果がほぼ等しいか確認します
if torch.allclose(y, y_prime, rtol=1e-05, atol=1e-08):
    print("y と y' はほぼ等しいです。")
else:
    print("y と y' は等しくありません。")
  • 直接 WW を使用する場合WxW \cdot x の計算の複雑度は O(dk)O(d \cdot k) です。
  • BAB \cdot A を使用する場合(BA)x(B \cdot A) \cdot x の総複雑度は O(rk)+O(dr)O(r \cdot k) + O(d \cdot r) で、すなわち O(r(k+d))O(r \cdot (k + d)) です。

Screenshot 2024-09-29 at 20.49.00

10×1010\times10 vs 2×(10+10)2\times(10+10)

ただし、LoRA は厳密な SVD ではなく、訓練可能な低ランク行列 A と B を導入することによって重み行列の動的適応を実現します。

LoRA 分類タスク微調整#

MNIST 手書き数字データセットの分類タスクにおいて、ある数字の認識効果が悪いため、微調整を行いたいと考えています。

LoRA の効果を強調するために、ここではタスクの要求を大幅に超える複雑なモデルを定義します。

# MNIST の数字を分類するために過剰に高価なニューラルネットワークを作成します
# お金持ちのパパがいるので、効率は気にしません
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)

現在のモデルのパラメータ量を観察できます。

Screenshot 2024-09-29 at 21.39.43

1 回のトレーニングを行い、元の重みを保存して、後で LoRA 微調整が元の重みに影響を与えないことを証明します。

train(train_loader, net, epochs=1)

どの数字の認識が悪いかをテストしてみましょう:

Screenshot 2024-09-29 at 22.09.57

後で 9 を選んで微調整します。

LoRA パラメータ化の定義#

ここでの forward 関数は元の重み original_weights を受け取り、LoRA 適応項が追加された新しい重み行列を返します。モデルが前方伝播を行うと、線形層はこの新しい重み行列を使用します。

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Aをガウス分布で初期化し、Bをゼロで初期化して、訓練開始時に∆W = BAがゼロになるようにします
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # 論文4.1に基づき、スケール因子α/rでハイパーパラメータの調整を簡略化し、αを最初に試すr値に設定します
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # W + (B*A)*scale を返します
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

ここでは、AA 行列を正規分布で初期化し、BB 行列をゼロで初期化しています。これにより、初期の ΔW\Delta W はゼロになります。スケール因子 αr\frac{\alpha}{r} は、異なるランク rr の下で学習率の安定性を保つのに役立ちます。

LoRA パラメータ化の適用#

PyTorch はパラメータ化メカニズムを提供しており(詳細は PyTorch Parametrizations 方法の公式文書 を参照)、モデルの元の構造を変更することなくパラメータにカスタム変換を適用できます。特定のパラメータ(例えば weight)にパラメータ化を行うと、PyTorch は元のパラメータを特別な位置に移動し、パラメータ化関数を通じて新しいパラメータを生成します。

ここでは、parametrize.register_parametrization 関数を使用して線形層の重みをパラメータ化し、LoRA をモデルの線形層に適用します:

import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # バイアスを無視して、重み行列にのみパラメータ化を追加します
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
  • 元の重みは net.linear1.parametrizations.weight.original に移動されます。
  • net.linear1.weight を呼び出すたびに、実際には LoRA パラメータ化の forward 関数によって計算されたものが得られます。
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

パラメータ量の比較#

LoRA 導入後のモデルパラメータの変化を計算します:

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# 非LoRA パラメータのカウントは元のネットワークと一致する必要があります
assert total_parameters_non_lora == total_parameters_original
print(f'元のパラメータの総数: {total_parameters_non_lora:,}')
print(f'元のパラメータ + LoRA の総数: {total_parameters_lora + total_parameters_non_lora:,}')
print(f'LoRA によって導入されたパラメータ: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'パラメータの増加: {parameters_incremment:.3f}%')

Screenshot 2024-09-30 at 01.11.02

LoRA はわずかに少量のパラメータ(約 0.242% 増加)を導入するだけで、モデルの効果的な微調整を実現できることがわかります。

非 LoRA パラメータの凍結#

微調整の過程で、LoRA が導入したパラメータのみを調整し、元のモデルの重みを変更しないようにしたいと考えています。したがって、すべての非 LoRA パラメータを凍結する必要があります。

# 非LoRA パラメータを凍結します
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'非LoRA パラメータ {name} を凍結します')
        param.requires_grad = False

Screenshot 2024-09-30 at 01.14.18

目標データセットの選択#

数字 9 の認識効果を向上させたいので、MNIST データセットから数字 9 のサンプルのみを選択して微調整を行います。

# 数字 9 のサンプルのみを保持します
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]

# データローダーを作成します
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

モデルの微調整#

元の重みを凍結した状態で、数字 9 のデータのみを使用してモデルを微調整します。時間を節約するために、100 バッチのみをトレーニングします。

# モデルを微調整します。100 バッチのみをトレーニングします
train(train_loader, net, epochs=1, total_iterations_limit=100)

元の重みが変更されていないことの確認#

再度、微調整後に元の重みが変更されていないことを確認します。

assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# 新しい linear1.weight は LoRA パラメータ化の「forward」関数によって得られます
# 元の重みは net.linear1.parametrizations.weight.original に移動されています
# 詳細はこちら: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# LoRA を無効にすると、linear1.weight は元のものになります
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])
# LoRA を有効にしてテストします
enable_disable_lora(enabled=True)
test()

モデル性能のテスト#

LoRA を有効にした後、テストセットでのモデルの性能を元のモデルと比較します:

Screenshot 2024-09-30 at 01.20.30

LoRA を有効にした後、モデルの数字 9 に対する誤認識回数が大幅に減少し、LoRA を無効にしたときの 124 回の誤りから 14 回に減少しました。全体の正確性(88.7%)は LoRA を無効にしたときに比べて若干低下しましたが、特定のカテゴリ(数字 9)での性能は大幅に改善されました。LoRA の微調整を通じて、モデルは数字 9 の認識能力を向上させることに集中し、他のカテゴリの性能を大幅に変更することなく実現しました。

参考資料#

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。