banner
Nagi-ovo

Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)
github

Rustを使用してシンプルなLLM推論を実装する

B 站で偶然、清华大学主催の大模型与人工智能系统训练营を見つけ、即座に申し込みました。春節の帰省期間を利用して、実践を通じて LLM Inference の理論知識を強化する予定です。ちょうど学校の VPN が故障して研究ができないため、学習ノートを整理することにしました。

Rust 言語については、大学 3 年生の時に 2 度入門を試みました(ある聖書教材の警告がありました)。今回は戦略を変えて、rustlings公式文書の二本立てで学習し、ようやく入門の壁を突破しました(ただし、限られた範囲内でですが)。

Llama アーキテクチャ解析#

コアコンポーネントの分解から始め、クラシックなアーキテクチャデザインを再確認します:

Screenshot 2025-01-29 at 13.51.47

層归一化(RMS Norm)#

Llama は Pre-Norm アーキテクチャを採用しており、各層の入力前に正規化操作を実行します。従来の Layer Norm と比較して:

x^i=xiμσ2+ϵ\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}

RMSNorm は平均中心化を除去することで計算を最適化しています:

aˉi=aiRMS(a)gi,where  RMS(a)=1ni=1nai2+ϵ\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i,\quad \text{where}~~ \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2 + \epsilon}

ここでgig_iは学習可能なスケーリングパラメータ gamma、ϵ\epsilonはゼロ除算エラーを防ぎます。

注:すべての演算子は初めに FP32 サポートを実装する必要があり、後で generics と macro を使用して他のフォーマットの推論をサポートできます。

pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let feature_dim = shape[shape.len() - 1];  // 数学的表記のd
    
    // 処理する特徴ベクトルの数を計算
    let num_features = shape[0..shape.len()-1].iter().product();
    
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();

    // 各特徴ベクトルを独立して処理
    for i in 0..num_features {
        let offset = i * feature_dim;
        
        // 1. Σ(x_i²)を計算
        let sum_squares: f32 = (0..feature_dim)
            .map(|j| {
                let x_ij = _x[offset + j];
                x_ij * x_ij
            })
            .sum();
            
        // 2. RMS(x) = sqrt(1/d * Σ(x_i²) + ε)を計算
        let rms = f32::sqrt(sum_squares / feature_dim as f32 + epsilon);
        
        // 3. 正規化とスケーリングを適用:y_i = (w_i * x_i) / RMS(x)
        for j in 0..feature_dim {
            let idx = offset + j;
            _y[idx] = (_w[j] * _x[idx]) / rms;
        }
    }
}

旋转位置编码 RoPE#

  • 絶対位置エンコーディング:max_length に影響を与える学習可能な埋め込みベクトルや三角関数を使用して構築し、単語ベクトルに直接加算します。各位置エンコーディングは基本的に相互に独立しています。
  • 相対位置エンコーディング:トークン間の距離を学習し、注意計算を変更する必要があります(T5 モデルではバイアス埋め込み行列を使用して相対距離を表現し、self-attention の Q、K 行列に加算します)。任意の長さのシーケンスに拡張可能ですが、推論速度は遅く、KV-Cache を利用しにくいです。
  • Rotary Positional Embedding(RoPE)は、絶対位置エンコーディングと相対位置エンコーディングの利点を兼ね備えており、2 次元の場合の公式は以下の通りです:
f{q,k}(xm,m)RoPE出力=(cosmθsinmθsinmθcosmθ)回転行列 R(mθ)(W{q,k}(11)W{q,k}(12)W{q,k}(21)W{q,k}(22))線形変換行列 W{q,k}(xm(1)xm(2))入力ベクトル xm\underbrace{f_{\{q,k\}}(\mathbf{x}_m, m)}_{\text{RoPE出力}} = \underbrace{ \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} }_{\text{回転行列 } R(m\theta)} \underbrace{ \begin{pmatrix} W_{\{q,k\}}^{(11)} & W_{\{q,k\}}^{(12)} \\ W_{\{q,k\}}^{(21)} & W_{\{q,k\}}^{(22)} \end{pmatrix} }_{\text{線形変換行列 } W_{\{q,k\}}} \underbrace{ \begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix} }_{\text{入力ベクトル } \mathbf{x}_m}

回転行列の作用の前に、線形変換を適用して Query & Key を取得し、回転不変性を保持します。その後、回転行列を通過し、単語ベクトルをmθm\theta度(mmはそのトークンの文中の絶対位置)回転させます:

image

image

より一般的な形式は、ベクトルを複数の 2D ブロックに分割し(デフォルトのベクトル次元は偶数)、各次元ペアに異なる回転角度を適用します:

f{q,k}(xm,m)=RΘ,mdW{q,k}xmf_{\{q,k\}}(\mathbf{x}_m, m) = \mathbf{R}_{\Theta,m}^d \mathbf{W}_{\{q,k\}} \mathbf{x}_m
RΘ,md=(cosmθ1sinmθ10000sinmθ1cosmθ1000000cosmθ2sinmθ20000sinmθ2cosmθ2000000cosmθd/2sinmθd/20000sinmθd/2cosmθd/2)\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}

計算効率を最適化するため、上記の計算は次のように等価です:

RΘ,mdx=(x1x2x3x4xd1xd)(cosmθ1cosmθ1cosmθ2cosmθ2cosmθd/2cosmθd/2)+(x2x1x4x3xd1xd)(sinmθ1sinmθ1sinmθ2sinmθ2sinmθd/2sinmθd/2)\mathbf{R}_{\Theta,m}^d \mathbf{x} = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}

Screenshot 2025-01-29 at 15.29.32

ここでの実装はリポジトリに既に提供されています:

// RoPE: Rotary Positional Embedding
pub fn rope(y: &mut Tensor<f32>, start_pos: usize, theta: f32) {
    let shape = y.shape();
    assert!(shape.len() == 3);
    let seq_len = shape[0];
    let n_heads = shape[1];
    let d = shape[2];
    let data = unsafe { y.data_mut() };
    for tok in 0..seq_len {
        let pos = start_pos + tok;
        for head in 0..n_heads {
            // 次元ペアを遍歴
            for i in 0..d / 2 {
                let a = data[tok * n_heads * d + head * d + i];
                let b = data[tok * n_heads * d + head * d + i + d / 2];
                let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
                let (sin, cos) = freq.sin_cos();
                // 回転行列を適用
                data[tok * n_heads * d + head * d + i] = a * cos - b * sin;
                data[tok * n_heads * d + head * d + i + d / 2] = b * cos + a * sin;
            }
        }
    }
}

FFN(MLP)#

image

MLP 構造(src/model.rs):

fn mlp(
    residual: &mut Tensor<f32>,
    hidden_states: &mut Tensor<f32>,
    gate: &mut Tensor<f32>,
    up: &mut Tensor<f32>,
    w_up: &Tensor<f32>,
    w_down: &Tensor<f32>,
    w_gate: &Tensor<f32>,
    rms_w: &Tensor<f32>,
    eps: f32,
) {
    // 1. hidden = rms_norm(residual)
    OP::rms_norm(hidden_states, residual, rms_w, eps);

    // 2. gate = hidden @ gate_weight.T
    OP::matmul_transb(gate, 0., hidden_states, w_gate, 1.0);

    // 3. up = hidden @ up_weight.T
    OP::matmul_transb(up, 0., hidden_states, w_up, 1.0);

    // 4. act = gate * sigmoid(gate) * up (SwiGLU)
    OP::swiglu(up, gate);

    // 5. residual = residual + up @ down_weight.T
    OP::matmul_transb(residual, 1.0, up, w_down, 1.0);
}

算子#

全体のプロセスで関与する基本計算操作は以下の通りです:

image

MatMul(scr/operators.rs):

// C = beta * C + alpha * A @ B^T
// ヒント:Bの明示的な転置を行う必要はありません
pub fn matmul_transb(c: &mut Tensor<f32>, beta: f32, a: &Tensor<f32>, b: &Tensor<f32>, alpha: f32) {
    // C_xy = beta * C_xy + alpha * Σ(A_xk * B_yk)  k=1..inner_dim
    let shape_a = a.shape();
    let shape_b = b.shape();
    let (a_rows, b_rows) = (shape_a[0], shape_b[0]);
    let inner = shape_a[1];

    let _c = unsafe { c.data_mut() };
    let _a = a.data();
    let _b = b.data();

    // 1. スライスを使用してCをスケーリング
    _c.iter_mut().for_each(|val| *val *= beta);

    // 2. スライスを使用して行列積を計算
    for x in 0..a_rows {
        // Aの現在の行を取得
        let a_row = &_a[x * inner..(x + 1) * inner];
        
        for y in 0..b_rows {
            // Bの現在の行を取得(B^Tの列に相当)
            let b_row = &_b[y * inner..(y + 1) * inner];
            
            let sum: f32 = a_row.iter()
                .zip(b_row.iter())
                .map(|(&a, &b)| a * b)
                .sum();

            _c[x * b_rows + y] += alpha * sum;
        }
    }
    
}

SwiGLU(scr/operators.rs)

SiLU(x)gateSiLU(x) * gate, ここで SiLU(x)=xsigmoid(x)SiLU(x) = x * sigmoid(x)

// ヒント:これは要素ごとの操作です
pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
    let len = y.size();
    assert!(len == x.size());

    let _y = unsafe { y.data_mut() };
    let _x = x.data();

    for i in 0..len {
        _y[i] = _y[i] * _x[i] / (1. + f32::exp(-_x[i]));
    }
}

safetensor からの読み取り#

(src/params.rs)

use crate::config::LlamaConfigJson;
use crate::tensor::Tensor;
use num_traits::Num;
use safetensors::SafeTensors;

pub struct LLamaParams<T: Num> {
    // token_idから埋め込みルックアップテーブル
    pub embedding_table: Tensor<T>, // (vocab_size, dim)
    // デコーダ層
    pub rms_att_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub wq: Vec<Tensor<T>>,        // (n_heads * head_size, hidden_size) x layers
    pub wk: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wv: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wo: Vec<Tensor<T>>,        // (hidden_size, n_heads * head_size) x layers
    // ffn層
    pub rms_ffn_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub w_up: Vec<Tensor<T>>,      // (intermediate_size, hidden_size) x layers
    pub w_gate: Vec<Tensor<T>>,    // (intermediate_size, hidden_size) x layers
    pub w_down: Vec<Tensor<T>>,    // (hidden_size, intermediate_size) x layers
    // 出力
    pub rms_out_w: Tensor<T>, // (hidden_size, )
    pub lm_head: Tensor<T>,   // (vocab_size, dim)
}

trait GetTensorFromSafeTensors<P: Num> {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<P>, &'static str>;
}

impl GetTensorFromSafeTensors<f32> for f32 {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<f32>, &'static str> {
        let tensor_view = tensors.tensor(name).map_err(|e| {
            assert!(matches!(e, safetensors::SafeTensorError::TensorNotFound(_)));
            "Tensor not found"
        })?;
        
        let data = match tensor_view.dtype() {
            safetensors::Dtype::F32 => tensor_view.data()
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            _ => return Err("Unsupported data type"),
        };
        
        Ok(Tensor::new(data, &tensor_view.shape().to_vec()))
    }
}

macro_rules! get_tensor_vec {
    ($tensors:expr, $pattern:literal, $layers:expr) => {{
        (0..$layers)
            .map(|i| {
                let name = format!($pattern, i);
                f32::get_tensor_from($tensors, &name).unwrap()
            })
            .collect::<Vec<_>>()
    }};
}

impl LLamaParams<f32> {
    pub fn from_safetensors(safetensor: &SafeTensors, config: &LlamaConfigJson) -> Self {
        let embedding_table = if config.tie_word_embeddings {
            f32::get_tensor_from(safetensor, "lm_head.weight").unwrap()
        } else {
            f32::get_tensor_from(safetensor, "model.embed_tokens.weight").unwrap()
        };

        let n_layers = config.num_hidden_layers;
        
        Self {
            embedding_table,
            rms_att_w: get_tensor_vec!(safetensor, "model.layers.{}.input_layernorm.weight", n_layers),
            wq: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.q_proj.weight", n_layers),
            wk: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.k_proj.weight", n_layers),
            wv: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.v_proj.weight", n_layers),
            wo: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.o_proj.weight", n_layers),
            rms_ffn_w: get_tensor_vec!(safetensor, "model.layers.{}.post_attention_layernorm.weight", n_layers),
            w_up: get_tensor_vec!(safetensor, "model.layers.{}.mlp.up_proj.weight", n_layers),
            w_gate: get_tensor_vec!(safetensor, "model.layers.{}.mlp.gate_proj.weight", n_layers),
            w_down: get_tensor_vec!(safetensor, "model.layers.{}.mlp.down_proj.weight", n_layers),
            rms_out_w: f32::get_tensor_from(safetensor, "model.norm.weight").unwrap(),
            lm_head: f32::get_tensor_from(safetensor, "lm_head.weight").unwrap(),
        }
    }
}

KV-Cache メカニズム#

fp16(パラメータは 2 バイト)の推論時 KV キャッシュの計算:

2×2×n_layer×seq_len×n_head×head_dim×batch2\times 2 \times n\_layer\times seq\_{len}\times n\_head\times head\_dim \times batch

シーケンスの長さが増加するにつれて、性能は指数的に低下します(最新のトークン生成の増分計算に注意):

Screenshot 2025-01-29 at 17.08.44

対応するパラメータuse_cache: truemax_position_embeddings: 4096、後者はモデルの最大コンテキストウィンドウを決定します。

推論最適化案#

分组查询注意力(GQA)#

複数の Q を共有する単一の KV を通じてメモリを最適化します(ヘッド数の割り切り関係を満たす必要があります):

fn self_attention(
    hidden_states: &mut Tensor<f32>, // (seq, n_kv_h * n_groups * dqkv)
    att_scores: &mut Tensor<f32>,    // (n_kv_h, n_groups, seq, total_seq)
    q: &Tensor<f32>,                 // (seq, n_kv_h * n_groups * dqkv)
    k: &Tensor<f32>,                 // (total_seq, n_kv_h * dqkv)
    v: &Tensor<f32>,                 // (total_seq, n_kv_h * dqkv)
    n_kv_h: usize,
    n_groups: usize,
    seq_len: usize,
    total_seq_len: usize,
    dqkv: usize,
) {
    assert!(k.shape()[0] >= total_seq_len && v.shape()[0] >= total_seq_len);
    assert!(q.shape()[0] == seq_len && q.shape()[1] == n_kv_h * n_groups && q.shape()[2] == dqkv);
    let _a = unsafe { att_scores.data_mut() };
    let _q = q.data();
    let _k = k.data();
    let _v = v.data();
    let sqrt = (dqkv as f32).sqrt();

    // attn_score = Q @ K^T / sqrt(dim)
    for q in 0..n_kv_h * n_groups {
        for seq in 0..seq_len {
            for t_seq in 0..total_seq_len {
                let mut sum = 0.0;
                for d in 0..dqkv {
                    sum += _q[seq * n_kv_h * n_groups * dqkv + q * dqkv + d]
                        * _k[t_seq * n_kv_h * dqkv + q / n_groups * dqkv + d];
                }
                _a[q * seq_len * total_seq_len + seq * total_seq_len + t_seq] = sum / sqrt;
            }
        }
    }

    // attn = softmax(score)
    OP::masked_softmax(att_scores);

    // x = attn @ V
    let _a = att_scores.data();
    let _h = unsafe { hidden_states.data_mut() };
    for q in 0..n_kv_h * n_groups {
        for seq in 0..seq_len {
            for d in 0..dqkv {
                let mut sum = 0.0;
                for t_seq in 0..total_seq_len {
                    sum += _a[q * seq_len * total_seq_len + seq * total_seq_len + t_seq]
                        * _v[d + q / n_groups * dqkv + t_seq * n_kv_h * dqkv];
                }
                _h[seq * n_kv_h * n_groups * dqkv + q * dqkv + d] = sum;
            }
        }
    }
}

バッチ処理技術、セッション管理、対話の巻き戻しなどは、現在は実装する余裕がないため、後で簡単にまとめます。

混合精度推論#

一般的なものには FP16、BF16 があり、前者は精度が TF32 と同じで、後者は表現範囲が FP32 と同じですが、精度は低くなります。TF32 は FP32 の表現範囲と FP16 の精度を兼ね備えていますが、テンソルコアハードウェアのサポートが必要です。

Screenshot 2025-01-30 at 21.26.35

注意:以下の精度に敏感な操作は FP32 を保持する必要があります:

  1. RoPE の三角関数計算
  2. RMS Norm の指数演算
  3. Softmax の数値安定性

量子化推論#

データ型のマッピングを通じて実現します:

  • 計算の加速(低精度命令のハードウェアサポートが必要)
  • メモリ最適化(fp16 は fp32 の 50% のスペースしか必要とせず、さらには int8、int4 なども可能です)

Screenshot 2025-01-30 at 21.52.25

進化的最適化方向#

  1. MoE アーキテクチャ:動的ルーティングの MLP 最適化
  2. 分散推論
    • データ並列:パラメータを全コピーし、スループットを向上
    • パイプライン並列:層間で分割し、通信オーバーヘッドに敏感
    • テンソル並列:パラメータを分散ストレージし、AllGather/AllReduce の同期が必要

Screenshot 2025-01-30 at 22.47.10

まとめ#

今回のトレーニングキャンプでは、ハードコアな実践を通じて LLM 推論最適化の技術的な流れを理解しました。時間があるときにさらに進めて実装を続ける価値があります。最近は多忙なため、ここで一旦締めくくります 🥲

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。