© 2020-2026 Nagi-ovo
Implementing Simple LLM Inference in Rust

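I came across the Large Model and AI Systems Training Camp hosted by Tsinghua University on Bilibili and signed up right away. My plan was to use the Spring Festival trip home to consolidate my theoretical knowledge of LLM inference through hands-on practice. As it happened, the university VPN was down and I couldn't do research, so it was a good time to organize my study notes.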

February 7, 2025 · 40 min read
LLM · Rust · mlsys

Human-Crafted

Written directly by the author with no AI-generated sections.


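As for Rust, I tried to get into it twice during my junior year (warning: a certain "bible" of a textbook scared me off). This time I changed tactics and worked through rustlings alongside the official documentation, and finally got past the beginner barrier (though only just).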

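Llama Architecture Breakdown

Let's start by taking apart the core components and revisiting the classic architecture design:

Llama architecture diagram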

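Layer Normalization (RMS Norm)

Llama uses a Pre-Norm architecture, i.e., normalization is applied to the input of each sublayer. Compare this with the conventional Layer Norm: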

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

RMSNorm simplifies the computation by removing mean centering:

$$\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i,\quad \text{where}~~ \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2 + \epsilon}$$

where $g_i$ is the learnable scaling parameter (gamma) and $\epsilon$ prevents division by zero.
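To make the difference concrete, here is a minimal standalone sketch on plain slices. The names `layer_norm_slice` and `rms_norm_slice` are illustrative, not the repository's `Tensor` API, and LayerNorm's learnable affine terms are omitted. For a zero-mean input the two normalizations agree; for a shifted input LayerNorm removes the mean while RMSNorm only rescales.

```rust
/// LayerNorm without its learnable affine terms: (x_i - mean) / sqrt(var + eps).
fn layer_norm_slice(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let denom = (var + eps).sqrt();
    x.iter().map(|v| (v - mean) / denom).collect()
}

/// RMSNorm: x_i / RMS(x) * g_i, with RMS(x) = sqrt(mean(x^2) + eps); no mean subtraction.
fn rms_norm_slice(x: &[f32], g: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), g.len());
    let n = x.len() as f32;
    let rms = (x.iter().map(|v| v * v).sum::<f32>() / n + eps).sqrt();
    x.iter().zip(g).map(|(v, w)| v / rms * w).collect()
}

fn main() {
    let eps = 1e-6;
    let gain = [1.0f32, 1.0];
    // Zero-mean input: both normalizations give the same result.
    let centered = [1.0f32, -1.0];
    println!("{:?}", layer_norm_slice(&centered, eps));
    println!("{:?}", rms_norm_slice(&centered, &gain, eps));
    // Shifted input: LayerNorm removes the mean, RMSNorm only rescales.
    let shifted = [3.0f32, 4.0];
    println!("{:?}", layer_norm_slice(&shifted, eps));
    println!("{:?}", rms_norm_slice(&shifted, &gain, eps));
}
```

For `[3.0, 4.0]`, LayerNorm yields roughly `[-1.0, 1.0]`, while RMSNorm yields roughly `[0.849, 1.131]`, since $\text{RMS} = \sqrt{12.5} \approx 3.536$.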

Note: initially, all operators only need to support FP32; generics and macros can be used later to support inference in other formats.

```rust
pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let feature_dim = shape[shape.len() - 1]; // d in mathematical notation
    assert!(w.size() == feature_dim); // one gain g_i per feature

    // Number of feature vectors to process (product of all leading dimensions)
    let num_features: usize = shape[..shape.len() - 1].iter().product();

    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();

    // Process each feature vector independently
    for i in 0..num_features {
        let offset = i * feature_dim;

        // 1. Calculate Σ(x_i²)
        let sum_squares: f32 = (0..feature_dim)
            .map(|j| {
                let x_ij = _x[offset + j];
                x_ij * x_ij
            })
            .sum();

        // 2. Calculate RMS(x) = sqrt(1/d * Σ(x_i²) + ε)
        let rms = f32::sqrt(sum_squares / feature_dim as f32 + epsilon);

        // 3. Apply normalization and scaling: y_i = (w_i * x_i) / RMS(x)
        for j in 0..feature_dim {
            let idx = offset + j;
            _y[idx] = (_w[j] * _x[idx]) / rms;
        }
    }
}
```

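Rotary Positional Embedding (RoPE)

  • Absolute positional encoding: built from learnable embedding vectors (whose number caps max_length) or from trigonometric functions, and added directly to the token embeddings; each position's encoding is essentially independent of the others;
  • Relative positional encoding: learns the distance between a pair of tokens and requires modifying the attention computation (e.g., T5 represents relative distances with a learned bias embedding that is added to the self-attention scores computed from Q and K). It extends to sequences of arbitrary length, but inference is slower and it is hard to combine with the KV-Cache;
  • Rotary Positional Embedding (RoPE) combines the advantages of absolute and relative positional encoding. In the two-dimensional case the formula is: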
$$\underbrace{f_{\{q,k\}}(\mathbf{x}_m, m)}_{\text{RoPE output}} = \underbrace{\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}}_{\text{rotation matrix } R(m\theta)} \underbrace{\begin{pmatrix} W_{\{q,k\}}^{(11)} & W_{\{q,k\}}^{(12)} \\ W_{\{q,k\}}^{(21)} & W_{\{q,k\}}^{(22)} \end{pmatrix}}_{\text{linear transformation } W_{\{q,k\}}} \underbrace{\begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix}}_{\text{input vector } \mathbf{x}_m}$$

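Before the rotation, the linear transformation is applied to obtain the Query and Key, so the rotation acts on the projected vectors. The rotation matrix then rotates the vector by the angle $m\theta$, where $m$ is the token's absolute position in the sentence: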

RoPE rotation illustration 1

RoPE rotation illustration 2
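A small check of why this rotation captures relative position: after rotating the query by $m\theta$ and the key by $n\theta$, their dot product equals $q^\top R((n-m)\theta)\,k$, so it depends only on the offset $m-n$. The sketch below uses made-up 2D vectors and an arbitrary $\theta = 0.3$ (both assumptions for illustration); `rotate` and `dot` are hypothetical helper names.

```rust
/// Apply the 2D rotation matrix R(m * theta) to v = (v1, v2).
fn rotate(v: [f32; 2], m: usize, theta: f32) -> [f32; 2] {
    let (sin, cos) = (m as f32 * theta).sin_cos();
    [v[0] * cos - v[1] * sin, v[0] * sin + v[1] * cos]
}

fn dot(a: [f32; 2], b: [f32; 2]) -> f32 {
    a[0] * b[0] + a[1] * b[1]
}

fn main() {
    let theta = 0.3f32;
    // Made-up projected query W_q x and key W_k x.
    let q = [0.5f32, -1.2];
    let k = [2.0f32, 0.7];
    // Same offset m - n = 2 at different absolute positions gives the same score.
    let s1 = dot(rotate(q, 5, theta), rotate(k, 3, theta));
    let s2 = dot(rotate(q, 2, theta), rotate(k, 0, theta));
    // A different offset (m - n = 4) gives a different score.
    let s3 = dot(rotate(q, 4, theta), rotate(k, 0, theta));
    println!("(5,3): {s1:.5}  (2,0): {s2:.5}  (4,0): {s3:.5}");
}
```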

More generally, the vector is split into 2D blocks (assuming the vector dimension is even), and each pair of dimensions is rotated by a different angle:

$$f_{\{q,k\}}(\mathbf{x}_m, m) = \mathbf{R}_{\Theta,m}^d \mathbf{W}_{\{q,k\}} \mathbf{x}_m$$
$$\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}$$

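Because this matrix is sparse, the product can be computed more efficiently in the equivalent element-wise form ($\otimes$ denotes element-wise multiplication):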

$$\mathbf{R}_{\Theta,m}^d \mathbf{x} = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$

RoPE efficient computation
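The sketch below checks the element-wise form against an explicitly built block-diagonal matrix. It assumes the RoFormer frequency schedule $\theta_i = 10000^{-2(i-1)/d}$ (written 0-indexed in code as `base^(-2i/d)`) and pairs adjacent dimensions exactly as in the formulas above; `freqs`, `rope_elementwise` and `rope_matrix` are hypothetical names.

```rust
/// theta_i = base^(-2i/d) for i = 0..d/2 (0-indexed form of the RoFormer schedule).
fn freqs(d: usize, base: f32) -> Vec<f32> {
    (0..d / 2).map(|i| base.powf(-2.0 * i as f32 / d as f32)).collect()
}

/// Efficient form: x ⊗ cos + (-x_2, x_1, -x_4, x_3, ...) ⊗ sin, adjacent pairs.
fn rope_elementwise(x: &[f32], m: usize, base: f32) -> Vec<f32> {
    let d = x.len();
    assert!(d % 2 == 0, "RoPE needs an even dimension");
    let mut y = vec![0.0f32; d];
    for (i, th) in freqs(d, base).into_iter().enumerate() {
        let (sin, cos) = (m as f32 * th).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        y[2 * i] = a * cos - b * sin; // first element of the pair
        y[2 * i + 1] = b * cos + a * sin; // second element of the pair
    }
    y
}

/// Reference form: build R^d_{Θ,m} explicitly and do a dense mat-vec product.
fn rope_matrix(x: &[f32], m: usize, base: f32) -> Vec<f32> {
    let d = x.len();
    let mut r = vec![vec![0.0f32; d]; d];
    for (i, th) in freqs(d, base).into_iter().enumerate() {
        let (sin, cos) = (m as f32 * th).sin_cos();
        r[2 * i][2 * i] = cos;
        r[2 * i][2 * i + 1] = -sin;
        r[2 * i + 1][2 * i] = sin;
        r[2 * i + 1][2 * i + 1] = cos;
    }
    r.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

fn main() {
    let x = [0.3f32, -1.0, 2.0, 0.5, -0.7, 1.1, 0.0, 0.9];
    let fast = rope_elementwise(&x, 7, 10000.0);
    let slow = rope_matrix(&x, 7, 10000.0);
    let max_diff = fast
        .iter()
        .zip(&slow)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    println!("max |elementwise - matrix| = {max_diff:e}");
}
```

The element-wise version does $O(d)$ work per vector instead of the $O(d^2)$ dense product.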

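This implementation is already provided in the repository. Note that within each head it pairs dimension i with dimension i + d/2, rather than the adjacent pairs $(x_{2i-1}, x_{2i})$ used in the formulas above: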

```rust
// RoPE: Rotary Positional Embedding
// y: (seq_len, n_heads, head_dim), rotated in place; start_pos is the absolute
// position of the first token in y (e.g. the number of tokens already cached)
pub fn rope(y: &mut Tensor<f32>, start_pos: usize, theta: f32) {
    let shape = y.shape();
    assert!(shape.len() == 3);
    let seq_len = shape[0];
    let n_heads = shape[1];
    let d = shape[2];
    let data = unsafe { y.data_mut() };
    for tok in 0..seq_len {
        let pos = start_pos + tok;
        for head in 0..n_heads {
            // Iterate over dimension pairs: dimension i is paired with i + d/2
            for i in 0..d / 2 {
                let a = data[tok * n_heads * d + head * d + i];
                let b = data[tok * n_heads * d + head * d + i + d / 2];
                // Rotation angle m * theta_i, where theta_i = theta^(-2i/d)
                let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
                let (sin, cos) = freq.sin_cos();
                // Apply the 2x2 rotation matrix
                data[tok * n_heads * d + head * d + i] = a * cos - b * sin;
                data[tok * n_heads * d + head * d + i + d / 2] = b * cos + a * sin;
            }
        }
    }
}
```
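The repository's `rope` pairs dimension `i` with `i + d/2` rather than adjacent dimensions. The minimal sketch below (plain slices, hypothetical helper names) checks that the two layouts give the same result up to a fixed permutation of each head's dimensions. Since applying the same permutation to both $q$ and $k$ leaves their dot product unchanged, such a permutation can be folded into the rows of $W_q$ and $W_k$.

```rust
/// Repo-style layout: dimension i is paired with i + d/2.
fn rope_half(x: &[f32], pos: usize, theta: f32) -> Vec<f32> {
    let d = x.len();
    let mut y = x.to_vec();
    for i in 0..d / 2 {
        let (a, b) = (x[i], x[i + d / 2]);
        let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
        let (sin, cos) = freq.sin_cos();
        y[i] = a * cos - b * sin;
        y[i + d / 2] = b * cos + a * sin;
    }
    y
}

/// Formula-style layout: dimension 2i is paired with 2i + 1.
fn rope_interleaved(x: &[f32], pos: usize, theta: f32) -> Vec<f32> {
    let d = x.len();
    let mut y = x.to_vec();
    for i in 0..d / 2 {
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
        let (sin, cos) = freq.sin_cos();
        y[2 * i] = a * cos - b * sin;
        y[2 * i + 1] = b * cos + a * sin;
    }
    y
}

/// Reorder [x_0 .. x_{h-1}, x_h .. x_{d-1}] into [x_0, x_h, x_1, x_{h+1}, ...].
fn half_to_interleaved(x: &[f32]) -> Vec<f32> {
    let h = x.len() / 2;
    (0..x.len())
        .map(|j| if j % 2 == 0 { x[j / 2] } else { x[h + j / 2] })
        .collect()
}

fn main() {
    let x = [0.3f32, -1.0, 2.0, 0.5, -0.7, 1.1];
    // Rotating then permuting equals permuting then rotating.
    let lhs = half_to_interleaved(&rope_half(&x, 9, 10000.0));
    let rhs = rope_interleaved(&half_to_interleaved(&x), 9, 10000.0);
    println!("{lhs:?}\n{rhs:?}");
}
```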

FFN (MLP)

MLP structure diagram

MLP structure (src/model.rs):

```rust
fn mlp(
    residual: &mut Tensor<f32>,
    hidden_states: &mut Tensor<f32>,
    gate: &mut Tensor<f32>,
    up: &mut Tensor<f32>,
    w_up: &Tensor<f32>,
    w_down: &Tensor<f32>,
    w_gate: &Tensor<f32>,
    rms_w: &Tensor<f32>,
    eps: f32,
) {
    // 1. hidden = rms_norm(residual)
    OP::rms_norm(hidden_states, residual, rms_w, eps);

    // 2. gate = hidden @ gate_weight.T
    OP::matmul_transb(gate, 0., hidden_states, w_gate, 1.0);

    // 3. up = hidden @ up_weight.T
    OP::matmul_transb(up, 0., hidden_states, w_up, 1.0);

    // 4. act = gate * sigmoid(gate) * up (SwiGLU), written into `up`
    OP::swiglu(up, gate);

    // 5. residual = residual + up @ down_weight.T
    OP::matmul_transb(residual, 1.0, up, w_down, 1.0);
}
```

Operators

The basic compute operations involved in the whole pipeline are shown below:

Operator overview

MatMul (src/operators.rs):

```rust
// C = beta * C + alpha * A @ B^T
// hint: You don't need to do an explicit transpose of B
pub fn matmul_transb(c: &mut Tensor<f32>, beta: f32, a: &Tensor<f32>, b: &Tensor<f32>, alpha: f32) {
    // C_xy = beta * C_xy + alpha * Σ(A_xk * B_yk)  k=1..inner_dim
    let shape_a = a.shape();
    let shape_b = b.shape();
    let (a_rows, b_rows) = (shape_a[0], shape_b[0]);
    let inner = shape_a[1];
    // Expected shapes: A (a_rows, inner), B (b_rows, inner), C (a_rows, b_rows)
    assert!(shape_b[1] == inner);
    assert!(c.shape()[0] == a_rows && c.shape()[1] == b_rows);

    let _c = unsafe { c.data_mut() };
    let _a = a.data();
    let _b = b.data();

    // 1. Scale C by beta
    _c.iter_mut().for_each(|val| *val *= beta);

    // 2. Accumulate alpha * A @ B^T using row slices
    for x in 0..a_rows {
        // Current row of A
        let a_row = &_a[x * inner..(x + 1) * inner];

        for y in 0..b_rows {
            // Row y of B, i.e. column y of B^T (contiguous in memory)
            let b_row = &_b[y * inner..(y + 1) * inner];

            let sum: f32 = a_row.iter()
                .zip(b_row.iter())
                .map(|(&a, &b)| a * b)
                .sum();

            _c[x * b_rows + y] += alpha * sum;
        }
    }
}
```
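For reference, here is the same computation on plain row-major slices with the shapes passed explicitly, plus a 2x2 example. Storing $B$ as (n, k) means row `y` of $B$ is column `y` of $B^\top$ and is already contiguous, which is why no explicit transpose is needed. This is a standalone sketch with a hypothetical signature, not the repository's `Tensor`-based function.

```rust
/// C (m x n) = beta * C + alpha * A (m x k) @ B^T, with B stored row-major as (n x k).
#[allow(clippy::too_many_arguments)]
fn matmul_transb(c: &mut [f32], beta: f32, a: &[f32], b: &[f32], alpha: f32, m: usize, n: usize, k: usize) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), n * k);
    assert_eq!(c.len(), m * n);
    for x in 0..m {
        let a_row = &a[x * k..(x + 1) * k];
        for y in 0..n {
            // Row y of B is column y of B^T.
            let b_row = &b[y * k..(y + 1) * k];
            let dot: f32 = a_row.iter().zip(b_row).map(|(p, q)| p * q).sum();
            c[x * n + y] = beta * c[x * n + y] + alpha * dot;
        }
    }
}

fn main() {
    // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] (so B^T = [[5, 7], [6, 8]])
    let a = [1.0f32, 2.0, 3.0, 4.0];
    let b = [5.0f32, 6.0, 7.0, 8.0];
    let mut c = [0.0f32; 4];
    matmul_transb(&mut c, 0.0, &a, &b, 1.0, 2, 2, 2);
    println!("{c:?}"); // [17.0, 23.0, 39.0, 53.0]
}
```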

SwiGLU (src/operators.rs)

$y_i \leftarrow y_i \cdot \text{SiLU}(x_i)$, where $\text{SiLU}(x) = x \cdot \text{sigmoid}(x)$; in the MLP above, $x$ is the gate projection and $y$ is the up projection.

```rust
// hint: this is an element-wise operation
// y_i = y_i * SiLU(x_i) = y_i * x_i * sigmoid(x_i), computed in place
pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
    let len = y.size();
    assert!(len == x.size());

    let _y = unsafe { y.data_mut() };
    let _x = x.data();

    for i in 0..len {
        _y[i] = _y[i] * _x[i] / (1. + f32::exp(-_x[i]));
    }
}
```

Loading from safetensors

(src/params.rs)

```rust
use crate::config::LlamaConfigJson;
use crate::tensor::Tensor;
use num_traits::Num;
use safetensors::SafeTensors;

pub struct LLamaParams<T: Num> {
    // token_id to embedding lookup table
    pub embedding_table: Tensor<T>, // (vocab_size, dim)
    // decoder layer
    pub rms_att_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub wq: Vec<Tensor<T>>,        // (n_heads * head_size, hidden_size) x layers
    pub wk: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wv: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wo: Vec<Tensor<T>>,        // (hidden_size, n_heads * head_size) x layers
    // ffn layer
    pub rms_ffn_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub w_up: Vec<Tensor<T>>,      // (intermediate_size, hidden_size) x layers
    pub w_gate: Vec<Tensor<T>>,    // (intermediate_size, hidden_size) x layers
    pub w_down: Vec<Tensor<T>>,    // (hidden_size, intermediate_size) x layers
    // output
    pub rms_out_w: Tensor<T>, // (hidden_size, )
    pub lm_head: Tensor<T>,   // (vocab_size, dim)
}

trait GetTensorFromSafeTensors<P: Num> {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<P>, &'static str>;
}

impl GetTensorFromSafeTensors<f32> for f32 {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<f32>, &'static str> {
        let tensor_view = tensors.tensor(name).map_err(|e| {
            assert!(matches!(e, safetensors::SafeTensorError::TensorNotFound(_)));
            "Tensor not found"
        })?;

        let data = match tensor_view.dtype() {
            safetensors::Dtype::F32 => tensor_view.data()
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            _ => return Err("Unsupported data type"),
        };

        Ok(Tensor::new(data, &tensor_view.shape().to_vec()))
    }
}

macro_rules! get_tensor_vec {
    ($tensors:expr, $pattern:literal, $layers:expr) => {{
        (0..$layers)
            .map(|i| {
                let name = format!($pattern, i);
                f32::get_tensor_from($tensors, &name).unwrap()
            })
            .collect::<Vec<_>>()
    }};
}

impl LLamaParams<f32> {
    pub fn from_safetensors(safetensor: &SafeTensors, config: &LlamaConfigJson) -> Self {
        let embedding_table = if config.tie_word_embeddings {
            f32::get_tensor_from(safetensor, "lm_head.weight").unwrap()
        } else {
            f32::get_tensor_from(safetensor, "model.embed_tokens.weight").unwrap()
        };

        let n_layers = config.num_hidden_layers;

        Self {
            embedding_table,
            rms_att_w: get_tensor_vec!(safetensor, "model.layers.{}.input_layernorm.weight", n_layers),
            wq: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.q_proj.weight", n_layers),
            wk: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.k_proj.weight", n_layers),
            wv: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.v_proj.weight", n_layers),
            wo: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.o_proj.weight", n_layers),
            rms_ffn_w: get_tensor_vec!(safetensor, "model.layers.{}.post_attention_layernorm.weight", n_layers),
            w_up: get_tensor_vec!(safetensor, "model.layers.{}.mlp.up_proj.weight", n_layers),
            w_gate: get_tensor_vec!(safetensor, "model.layers.{}.mlp.gate_proj.weight", n_layers),
            w_down: get_tensor_vec!(safetensor, "model.layers.{}.mlp.down_proj.weight", n_layers),
            rms_out_w: f32::get_tensor_from(safetensor, "model.norm.weight").unwrap(),
            lm_head: f32::get_tensor_from(safetensor, "lm_head.weight").unwrap(),
        }
    }
}
```

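KV-Cache Mechanism

KV-cache memory usage for fp16 inference (2 bytes per element), where the first factor of 2 counts K and V and the second is the number of bytes per element: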

$$2 \times 2 \times n\_layer \times seq\_len \times n\_head \times head\_dim \times batch$$
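A quick calculator for this formula, as a sketch. It assumes a hypothetical configuration of 32 layers, 32 heads with `head_dim` 128, a 4096-token sequence and batch 1 (roughly the shape of a 7B Llama). With GQA, the head count in the formula becomes the number of KV heads, so the sketch takes `n_kv_head` as a parameter (equal to `n_head` for standard multi-head attention).

```rust
/// KV-cache bytes = 2 (K and V) * bytes per element * layers * tokens * KV heads * head_dim * batch.
fn kv_cache_bytes(
    bytes_per_elem: u64,
    n_layer: u64,
    seq_len: u64,
    n_kv_head: u64,
    head_dim: u64,
    batch: u64,
) -> u64 {
    2 * bytes_per_elem * n_layer * seq_len * n_kv_head * head_dim * batch
}

fn main() {
    // Hypothetical 7B-Llama-like shape in fp16: 32 layers, 32 heads, head_dim 128.
    let mha = kv_cache_bytes(2, 32, 4096, 32, 128, 1);
    // Same model if it used GQA with 8 KV heads.
    let gqa = kv_cache_bytes(2, 32, 4096, 8, 128, 1);
    let gib = (1u64 << 30) as f64;
    println!("MHA: {} GiB, GQA (8 KV heads): {} GiB", mha as f64 / gib, gqa as f64 / gib);
}
```

For this configuration the cache is 2,147,483,648 bytes, i.e. 2 GiB per sequence; with 8 KV heads it drops to 0.5 GiB.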

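As the sequence length grows, performance degrades rapidly (note the incremental computation: generating the newest token only requires computing that token's K and V):

KV-Cache performance

The corresponding config fields are use_cache: true and max_position_embeddings: 4096; the latter sets the model's maximum context window.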
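A minimal single-head sketch of how a decode step uses the cache: only the newest token's q, k, v are computed, k and v are appended, and the query attends over every cached position. `KvCache` and `step` are hypothetical names; batching, multiple heads and layers are left out, and no causal mask is needed because the cache only ever holds past and current tokens.

```rust
/// Single-head KV cache: one key and one value vector per cached position.
struct KvCache {
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl KvCache {
    fn new() -> Self {
        Self { keys: Vec::new(), values: Vec::new() }
    }

    /// One decode step: cache the new token's k/v, then attend from its q over all positions.
    fn step(&mut self, q: &[f32], k: Vec<f32>, v: Vec<f32>) -> Vec<f32> {
        self.keys.push(k);
        self.values.push(v);
        let scale = 1.0 / (q.len() as f32).sqrt();
        // Scaled dot-product scores against every cached key.
        let scores: Vec<f32> = self
            .keys
            .iter()
            .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Numerically stable softmax over cached positions.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        // Weighted sum of cached values.
        let mut out = vec![0.0f32; self.values[0].len()];
        for (w, v) in exps.iter().zip(&self.values) {
            for (o, x) in out.iter_mut().zip(v) {
                *o += w / total * x;
            }
        }
        out
    }
}

fn main() {
    let mut cache = KvCache::new();
    // Three decode steps with made-up 2-dimensional q/k/v for each new token.
    let steps = [
        ([1.0f32, 0.0], [1.0f32, 0.0], [10.0f32, 0.0]),
        ([0.0, 1.0], [0.0, 1.0], [0.0, 20.0]),
        ([1.0, 1.0], [1.0, 1.0], [5.0, 5.0]),
    ];
    for (t, (q, k, v)) in steps.iter().enumerate() {
        let out = cache.step(q, k.to_vec(), v.to_vec());
        println!("step {t}: cached = {}, out = {out:?}", cache.keys.len());
    }
}
```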

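Inference Optimization Techniques

Grouped-Query Attention (GQA)

Memory is saved by letting several query heads share a single KV head (the number of query heads must be divisible by the number of KV heads).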
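Under GQA, each KV head serves a contiguous group of `n_q_heads / n_kv_heads` query heads. The sketch below assumes that consecutive-group layout (each KV head repeated in order for its group) and a hypothetical config of 8 query heads and 2 KV heads; `kv_head_for` is an illustrative helper name.

```rust
/// Index of the KV head used by query head `q_head` under GQA,
/// assuming consecutive groups of n_q_heads / n_kv_heads query heads.
fn kv_head_for(q_head: usize, n_q_heads: usize, n_kv_heads: usize) -> usize {
    assert!(n_q_heads % n_kv_heads == 0, "n_q_heads must be divisible by n_kv_heads");
    let group_size = n_q_heads / n_kv_heads;
    q_head / group_size
}

fn main() {
    // Hypothetical config: 8 query heads sharing 2 KV heads (group size 4).
    for h in 0..8 {
        println!("q head {h} -> kv head {}", kv_head_for(h, 8, 2));
    }
}
```

With `n_kv_heads == n_q_heads` this reduces to standard multi-head attention, and with a single KV head it becomes multi-query attention.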

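I don't have the bandwidth to implement batching, session management, conversation rollback and the like for now, so the rest is only a brief overview.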

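Mixed-Precision Inference

The common formats are FP16 and BF16. FP16 (5 exponent bits, 10 mantissa bits) has the same precision as TF32, while BF16 (8 exponent bits, 7 mantissa bits) has the same range as FP32 but lower precision. TF32 (8 exponent bits, 10 mantissa bits) combines FP32's range with FP16's precision, but requires Tensor Core hardware support.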
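BF16 is just the upper 16 bits of an FP32 value, which is why it keeps FP32's range. The sketch below converts with round-to-nearest-even; it is a standalone illustration with hypothetical function names, not a particular library's API. By contrast, FP16's largest finite value is 65504, so a value like `1e38` would overflow in FP16 but survives BF16 with reduced precision.

```rust
/// f32 -> bf16 bits with round-to-nearest-even (keep the top 16 bits).
fn f32_to_bf16(x: f32) -> u16 {
    if x.is_nan() {
        return 0x7FC0; // canonical quiet NaN
    }
    let bits = x.to_bits();
    // Add 0x7FFF, plus 1 if the kept part is odd, so exact ties round to even.
    let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
    ((bits + rounding_bias) >> 16) as u16
}

/// bf16 bits -> f32: place the 16 bits in the upper half, zero the rest.
fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}

fn main() {
    for x in [1.0f32, std::f32::consts::PI, 1e-3, 65504.0, 1e38] {
        let h = f32_to_bf16(x);
        println!("{x:e} -> 0x{h:04X} -> {:e}", bf16_to_f32(h));
    }
}
```

For example, π becomes 3.140625 in BF16: only 7 explicit mantissa bits survive, which is the precision lost in exchange for FP32's range.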

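Mixed-precision comparison

Note that the following precision-sensitive operations should stay in FP32:

  1. The trigonometric functions in RoPE
  2. The squaring and square root in RMS Norm
  3. Softmax (for numerical stability)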
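As an example of the Softmax point: subtracting the maximum before exponentiating avoids overflow without changing the result, because the common factor $e^{-\max}$ cancels. Even in FP32, `exp(1000.0)` overflows to infinity, and in FP16 overflow already starts a little above $x \approx 11.09$ (since $\ln 65504 \approx 11.09$). A minimal sketch with illustrative function names:

```rust
/// Softmax with max subtraction: exp(x_i - max) <= 1, so nothing overflows.
fn softmax_stable(x: &mut [f32]) {
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    for v in x.iter_mut() {
        *v /= sum;
    }
}

/// Naive softmax for comparison: exp(1000.0) = inf, and inf / inf = NaN.
fn softmax_naive(x: &mut [f32]) {
    let sum: f32 = x.iter().map(|v| v.exp()).sum();
    for v in x.iter_mut() {
        *v = v.exp() / sum;
    }
}

fn main() {
    let mut stable = [1000.0f32, 1000.0];
    let mut naive = stable; // arrays are Copy, so this is an independent copy
    softmax_stable(&mut stable);
    softmax_naive(&mut naive);
    println!("stable: {stable:?}, naive: {naive:?}");
}
```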

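Quantized Inference

It works by mapping values to lower-precision data types, which provides:

  • Faster computation (requires hardware support for low-precision instructions)
  • Lower memory usage (fp16 needs only 50% of fp32's space, and int8 or int4 even less…)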
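One common scheme is symmetric per-tensor int8 quantization: choose $s = \max|x| / 127$, store $q = \text{round}(x/s)$ as `i8`, and dequantize with $q \cdot s$. The sketch below illustrates that scheme only; per-channel or group-wise scales, zero points and int4 packing are omitted, and the function names are hypothetical.

```rust
/// Symmetric per-tensor int8 quantization: returns (q, scale) with x ≈ q * scale.
fn quantize_int8(x: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // Avoid dividing by zero for an all-zero tensor.
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let q = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Dequantize back to f32; the error per element is at most about scale / 2.
fn dequantize_int8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let x = [-1.0f32, 0.25, 0.6, 1.0];
    let (q, scale) = quantize_int8(&x);
    let back = dequantize_int8(&q, scale);
    println!("q = {q:?}, scale = {scale}, dequantized = {back:?}");
}
```

Each value now takes 1 byte instead of 4, at the cost of a rounding error bounded by roughly half the scale.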

Quantized inference principle

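Further Optimization Directions

  1. MoE architecture: an MLP optimization based on dynamic routing
  2. Distributed inference:
    • Data parallelism: each device holds a full copy of the parameters, increasing throughput
    • Pipeline parallelism: layers are split across devices, sensitive to communication overhead
    • Tensor parallelism: parameters are stored across devices, requiring AllGather/AllReduce synchronization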
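As an illustration of the dynamic routing in item 1, here is a sketch of top-k gating as used in many MoE MLPs: the router produces one logit per expert, the top k experts are kept, and their softmax weights are renormalized to sum to 1; the layer output would then be the weighted sum of those experts' MLP outputs. The function name and the choice of renormalized top-k are assumptions, since routing schemes differ between models.

```rust
/// Top-k MoE routing: returns (expert index, weight) pairs whose weights sum to 1.
/// The full softmax denominator cancels after renormalizing over the kept experts,
/// so unnormalized exp(logit - max) values are enough.
fn top_k_route(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let probs: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let mut idx: Vec<usize> = (0..logits.len()).collect();
    // Sort expert indices by descending score (panics if a logit is NaN).
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
    idx.truncate(k);
    let kept: f32 = idx.iter().map(|&i| probs[i]).sum();
    idx.into_iter().map(|i| (i, probs[i] / kept)).collect()
}

fn main() {
    // Made-up router logits for 4 experts; keep the top 2.
    let routes = top_k_route(&[1.0, 3.0, 2.0, 0.0], 2);
    for (expert, weight) in &routes {
        println!("expert {expert}: weight {weight:.4}");
    }
}
```

Only the selected experts run for a given token, which is how MoE grows total parameters without growing per-token compute proportionally.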

Distributed inference architecture

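Summary

Through hands-on practice, this bootcamp tied together the main technical threads of LLM inference optimization, and it is well worth continuing the implementation when I have time. I've had a ton going on lately, so I'll wrap things up here somewhat hastily 🥲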

Article Info · Human-Crafted
Title: Implementing Simple LLM Inference in Rust
Author: Nagi-ovo
Last Updated: No edits yet

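For commercial reproduction, please contact the site owner for permission; for non-commercial reproduction, please credit the source and include a link to this article.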

You may copy, distribute, and adapt this article, but derivative works must use the same license. This article is licensed under CC BY-NC-SA 4.0.
