
Implementing Simple LLM Inference in Rust

Feb 7, 2025 · 40 min read · Tags: LLM, Rust, mlsys

I stumbled upon the Large Model and AI System Training Camp hosted by Tsinghua University on Bilibili and signed up immediately. I planned to use the Spring Festival holiday to consolidate my theoretical knowledge of LLM Inference through practice. Coincidentally, the school VPN was down, preventing me from doing research, so it was the perfect time to organize my study notes.

As for the Rust language, I had twice tried to get started with it during my junior year (scared off both times by a certain “bible” of a textbook). This time I changed my strategy and took a dual-track approach, combining rustlings with the official documentation, and finally broke through the entry barrier (though only just).

Llama Architecture Analysis

Let's start with a breakdown of the core components, revisiting this classic architecture design:

Llama Architecture

RMS Norm

Llama uses a Pre-Norm architecture, normalizing the input of each layer rather than its output. Compare this with traditional Layer Norm:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

RMSNorm achieves computational optimization by removing mean centering:

$$\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i,\quad \text{where}~~ \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2 + \epsilon}$$

where $g_i$ is the learnable scaling parameter (gamma) and $\epsilon$ prevents division-by-zero errors.

Note: All operators initially only need to implement FP32 support. Later, generics and macros can be used to support inference in other formats.

pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let feature_dim = shape[shape.len() - 1];  // d in mathematical notation
    
    // Calculate number of feature vectors to process
    let num_features = shape[0..shape.len()-1].iter().product();
    
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();
 
    // Process each feature vector independently
    for i in 0..num_features {
        let offset = i * feature_dim;
        
        // 1. Calculate Σ(x_i²)
        let sum_squares: f32 = (0..feature_dim)
            .map(|j| {
                let x_ij = _x[offset + j];
                x_ij * x_ij
            })
            .sum();
            
        // 2. Calculate RMS(x) = sqrt(1/d * Σ(x_i²) + ε)
        let rms = f32::sqrt(sum_squares / feature_dim as f32 + epsilon);
        
        // 3. Apply normalization and scaling: y_i = (w_i * x_i) / RMS(x)
        for j in 0..feature_dim {
            let idx = offset + j;
            _y[idx] = (_w[j] * _x[idx]) / rms;
        }
    }
}
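
As a sketch of the note above about generics (assuming Tensor<T> exposes the same shape()/data()/data_mut() API as Tensor<f32>; this is not the project's actual generic design), the operator could be written over num_traits::Float roughly like this:

use num_traits::Float;

// Hedged sketch: a generic RMS Norm over any floating-point element type.
pub fn rms_norm_generic<T: Float>(y: &mut Tensor<T>, x: &Tensor<T>, w: &Tensor<T>, epsilon: T) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let d = shape[shape.len() - 1];
    let n: usize = shape[..shape.len() - 1].iter().product();

    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();

    for i in 0..n {
        let off = i * d;
        // Accumulate Σ(x_j²); for f16 inputs one would typically accumulate in f32 instead.
        let sum_sq = (0..d).fold(T::zero(), |acc, j| acc + _x[off + j] * _x[off + j]);
        let rms = (sum_sq / T::from(d).unwrap() + epsilon).sqrt();
        for j in 0..d {
            _y[off + j] = _w[j] * _x[off + j] / rms;
        }
    }
}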

Rotary Positional Embedding (RoPE)

  • Absolute Positional Encoding: typically built either from learnable embedding vectors (bounded by max_length) or from trigonometric functions, and added directly to the word vectors. Each position's encoding is generally independent of the others.
  • Relative Positional Encoding: learns the distance between pairs of tokens and requires modifying the attention computation (e.g., T5 represents relative distance with a bias embedding that is added to the Q·K attention scores). It extends to arbitrary sequence lengths, but inference is slower and it is hard to make use of the KV-Cache.
  • Rotary Positional Embedding (RoPE) combines the advantages of absolute and relative positional encoding. In the 2D case the formula is:
$$\underbrace{f_{\{q,k\}}(\mathbf{x}_m, m)}_{\text{RoPE Output}} = \underbrace{ \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} }_{\text{Rotation Matrix } R(m\theta)} \underbrace{ \begin{pmatrix} W_{\{q,k\}}^{(11)} & W_{\{q,k\}}^{(12)} \\ W_{\{q,k\}}^{(21)} & W_{\{q,k\}}^{(22)} \end{pmatrix} }_{\text{Linear Transformation Matrix } W_{\{q,k\}}} \underbrace{ \begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix} }_{\text{Input Vector } \mathbf{x}_m}$$

Before the rotation matrix is applied, a linear transformation first produces the Query and Key vectors, so the rotation acts on Q/K rather than on the raw embeddings. The rotation matrix then rotates each word vector by an angle of $m\theta$, where $m$ is the token's absolute position in the sentence:

RoPE Rotation Diagram 1

RoPE Rotation Diagram 2

A more general form is given below: the vector is split into groups of 2D chunks (the dimension is assumed to be even), and a different rotation angle is applied to each pair of dimensions:

$$f_{\{q,k\}}(\mathbf{x}_m, m) = \mathbf{R}_{\Theta,m}^d \mathbf{W}_{\{q,k\}} \mathbf{x}_m$$
$$\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}$$

To optimize computational efficiency, the above calculation is equivalent to:

$$\mathbf{R}_{\Theta,m}^d \mathbf{x} = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$

RoPE Optimization

Here is the implementation from the repository (note that it pairs dimension $i$ with dimension $i + d/2$, rather than adjacent dimensions as in the formula above):

// RoPE: Rotary Positional Embedding
pub fn rope(y: &mut Tensor<f32>, start_pos: usize, theta: f32) {
    let shape = y.shape();
    assert!(shape.len() == 3);
    let seq_len = shape[0];
    let n_heads = shape[1];
    let d = shape[2];
    let data = unsafe { y.data_mut() };
    for tok in 0..seq_len {
        let pos = start_pos + tok;
        for head in 0..n_heads {
            // Traverse dimension pairs
            for i in 0..d / 2 {
                let a = data[tok * n_heads * d + head * d + i];
                let b = data[tok * n_heads * d + head * d + i + d / 2];
                let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
                let (sin, cos) = freq.sin_cos();
                // Apply rotation matrix
                data[tok * n_heads * d + head * d + i] = a * cos - b * sin;
                data[tok * n_heads * d + head * d + i + d / 2] = b * cos + a * sin;
            }
        }
    }
}

FFN (MLP)

MLP Structure Diagram

MLP Structure (src/model.rs):

fn mlp(
    residual: &mut Tensor<f32>,
    hidden_states: &mut Tensor<f32>,
    gate: &mut Tensor<f32>,
    up: &mut Tensor<f32>,
    w_up: &Tensor<f32>,
    w_down: &Tensor<f32>,
    w_gate: &Tensor<f32>,
    rms_w: &Tensor<f32>,
    eps: f32,
) {
    // 1. hidden = rms_norm(residual)
    OP::rms_norm(hidden_states, residual, rms_w, eps);
 
    // 2. gate = hidden @ gate_weight.T
    OP::matmul_transb(gate, 0., hidden_states, w_gate, 1.0);
 
    // 3. up = hidden @ up_weight.T
    OP::matmul_transb(up, 0., hidden_states, w_up, 1.0);
 
    // 4. act = gate * sigmoid(gate) * up (SwiGLU)
    OP::swiglu(up, gate);
 
    // 5. residual = residual + up @ down_weight.T
    OP::matmul_transb(residual, 1.0, up, w_down, 1.0);
}

Operators

The basic computational operations involved in the entire process are shown in the figure:

Operators Overview

MatMul (src/operators.rs):

// C = beta * C + alpha * A @ B^T
// hint: You don't need to do an explicit transpose of B
pub fn matmul_transb(c: &mut Tensor<f32>, beta: f32, a: &Tensor<f32>, b: &Tensor<f32>, alpha: f32) {
    // C_xy = beta * C_xy + alpha * Σ(A_xk * B_yk)  k=1..inner_dim
    let shape_a = a.shape();
    let shape_b = b.shape();
    let (a_rows, b_rows) = (shape_a[0], shape_b[0]);
    let inner = shape_a[1];
 
    let _c = unsafe { c.data_mut() };
    let _a = a.data();
    let _b = b.data();
 
    // 1. using slice to scale C
    _c.iter_mut().for_each(|val| *val *= beta);
 
    // 2. using slice to matmul
    for x in 0..a_rows {
        // get current row of A
        let a_row = &_a[x * inner..(x + 1) * inner];
        
        for y in 0..b_rows {
            // get current row of B (equivalent to column of B^T)
            let b_row = &_b[y * inner..(y + 1) * inner];
            
            let sum: f32 = a_row.iter()
                .zip(b_row.iter())
                .map(|(&a, &b)| a * b)
                .sum();
 
            _c[x * b_rows + y] += alpha * sum;
        }
    }
    
}
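
A quick hand-computed sanity check (assuming the repository's Tensor::new(data, &shape) constructor and data() accessor behave as used elsewhere in this post):

#[test]
fn matmul_transb_small_case() {
    // A is 2x3, B is 2x3, so C = A @ B^T is 2x2.
    let a = Tensor::new(vec![1.0f32, 2., 3., 4., 5., 6.], &vec![2, 3]);
    let b = Tensor::new(vec![1.0f32, 0., 0., 0., 1., 0.], &vec![2, 3]);
    let mut c = Tensor::new(vec![0.0f32; 4], &vec![2, 2]);
    matmul_transb(&mut c, 0., &a, &b, 1.0);
    // This B^T simply picks out the first two columns of A: [[1, 2], [4, 5]]
    assert_eq!(c.data(), &[1.0f32, 2., 4., 5.][..]);
}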

SwiGLU (src/operators.rs):

Element-wise, the operator computes $y \leftarrow \text{SiLU}(x) \times y$, where $\text{SiLU}(x) = x \times \text{sigmoid}(x)$; in the MLP above, $x$ is the gate projection and $y$ is the up projection.

// hint: this is an element-wise operation
pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
    let len = y.size();
    assert!(len == x.size());
 
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
 
    for i in 0..len {
        _y[i] = _y[i] * _x[i] / (1. + f32::exp(-_x[i]));
    }
}

Reading from SafeTensors

(src/params.rs)

use crate::config::LlamaConfigJson;
use crate::tensor::Tensor;
use num_traits::Num;
use safetensors::SafeTensors;
 
pub struct LLamaParams<T: Num> {
    // token_id to embedding lookup table
    pub embedding_table: Tensor<T>, // (vocab_size, dim)
    // decoder layer
    pub rms_att_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub wq: Vec<Tensor<T>>,        // (n_heads * head_size, hidden_size) x layers
    pub wk: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wv: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wo: Vec<Tensor<T>>,        // (hidden_size, n_heads * head_size) x layers
    // ffn layer
    pub rms_ffn_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub w_up: Vec<Tensor<T>>,      // (intermediate_size, hidden_size) x layers
    pub w_gate: Vec<Tensor<T>>,    // (intermediate_size, hidden_size) x layers
    pub w_down: Vec<Tensor<T>>,    // (hidden_size, intermediate_size) x layers
    // output
    pub rms_out_w: Tensor<T>, // (hidden_size, )
    pub lm_head: Tensor<T>,   // (vocab_size, dim)
}
 
trait GetTensorFromSafeTensors<P: Num> {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<P>, &'static str>;
}
 
impl GetTensorFromSafeTensors<f32> for f32 {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<f32>, &'static str> {
        let tensor_view = tensors.tensor(name).map_err(|e| {
            assert!(matches!(e, safetensors::SafeTensorError::TensorNotFound(_)));
            "Tensor not found"
        })?;
        
        let data = match tensor_view.dtype() {
            safetensors::Dtype::F32 => tensor_view.data()
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            _ => return Err("Unsupported data type"),
        };
        
        Ok(Tensor::new(data, &tensor_view.shape().to_vec()))
    }
}
 
macro_rules! get_tensor_vec {
    ($tensors:expr, $pattern:literal, $layers:expr) => {{
        (0..$layers)
            .map(|i| {
                let name = format!($pattern, i);
                f32::get_tensor_from($tensors, &name).unwrap()
            })
            .collect::<Vec<_>>()
    }};
}
 
impl LLamaParams<f32> {
    pub fn from_safetensors(safetensor: &SafeTensors, config: &LlamaConfigJson) -> Self {
        let embedding_table = if config.tie_word_embeddings {
            f32::get_tensor_from(safetensor, "lm_head.weight").unwrap()
        } else {
            f32::get_tensor_from(safetensor, "model.embed_tokens.weight").unwrap()
        };
 
        let n_layers = config.num_hidden_layers;
        
        Self {
            embedding_table,
            rms_att_w: get_tensor_vec!(safetensor, "model.layers.{}.input_layernorm.weight", n_layers),
            wq: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.q_proj.weight", n_layers),
            wk: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.k_proj.weight", n_layers),
            wv: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.v_proj.weight", n_layers),
            wo: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.o_proj.weight", n_layers),
            rms_ffn_w: get_tensor_vec!(safetensor, "model.layers.{}.post_attention_layernorm.weight", n_layers),
            w_up: get_tensor_vec!(safetensor, "model.layers.{}.mlp.up_proj.weight", n_layers),
            w_gate: get_tensor_vec!(safetensor, "model.layers.{}.mlp.gate_proj.weight", n_layers),
            w_down: get_tensor_vec!(safetensor, "model.layers.{}.mlp.down_proj.weight", n_layers),
            rms_out_w: f32::get_tensor_from(safetensor, "model.norm.weight").unwrap(),
            lm_head: f32::get_tensor_from(safetensor, "lm_head.weight").unwrap(),
        }
    }
}
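
A hedged usage sketch of the loader above; the file names, the serde_json dependency, and LlamaConfigJson implementing Deserialize are assumptions, not confirmed details of the project:

// Hypothetical loader: read config.json and model.safetensors from a model directory.
fn load_model(dir: &std::path::Path) -> LLamaParams<f32> {
    // Assumes LlamaConfigJson derives serde::Deserialize.
    let config: LlamaConfigJson =
        serde_json::from_slice(&std::fs::read(dir.join("config.json")).unwrap()).unwrap();
    let bytes = std::fs::read(dir.join("model.safetensors")).unwrap();
    let tensors = SafeTensors::deserialize(&bytes).unwrap();
    LLamaParams::from_safetensors(&tensors, &config)
}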

KV-Cache Mechanism

KV cache memory usage for FP16 inference (each value takes 2 bytes):

$$2 \times 2 \times n\_layer \times seq\_len \times n\_head \times head\_dim \times batch$$

As the sequence length grows, the cache footprint and the per-token attention cost grow with it (note that, with the cache in place, each newly generated token only requires incremental computation):

KV-Cache Performance

The corresponding config parameters are use_cache: true and max_position_embeddings: 4096; the latter determines the model's maximum context window.
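
A back-of-the-envelope sketch of the formula above (the config values below are illustrative, roughly Llama-2-7B-like, and not taken from the project):

// Hedged sketch: estimate KV cache size in bytes for FP16 (2 bytes per value).
// With GQA, n_head here would be the number of KV heads rather than query heads.
fn kv_cache_bytes(n_layer: usize, seq_len: usize, n_head: usize, head_dim: usize, batch: usize) -> usize {
    let bytes_per_value = 2; // fp16
    2 * bytes_per_value * n_layer * seq_len * n_head * head_dim * batch // leading 2 = K and V
}

fn main() {
    // Illustrative values: 32 layers, 32 heads, head_dim 128, 4096-token context, batch 1
    let bytes = kv_cache_bytes(32, 4096, 32, 128, 1);
    println!("KV cache ≈ {:.2} GiB", bytes as f64 / (1u64 << 30) as f64); // ≈ 2.00 GiB
}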

Inference Optimization Schemes

Grouped Query Attention (GQA)

GQA saves memory by sharing a single group of K/V heads across multiple groups of Q heads (this requires the number of query heads to be divisible by the number of KV heads). A minimal index-mapping sketch is shown below:
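
As a hedged illustration (not code from the repository), the query-head-to-KV-head mapping can be expressed as:

// Minimal sketch of GQA head grouping: each group of query heads reads the same KV head.
fn kv_head_for(q_head: usize, n_heads: usize, n_kv_heads: usize) -> usize {
    assert!(n_heads % n_kv_heads == 0); // divisibility requirement
    q_head / (n_heads / n_kv_heads)
}

fn main() {
    // Illustrative: 8 query heads sharing 2 KV heads -> q 0..=3 use kv 0, q 4..=7 use kv 1
    for q in 0..8 {
        println!("q_head {q} -> kv_head {}", kv_head_for(q, 8, 2));
    }
}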

Batching, session management, dialogue rollback, and similar techniques are not implemented for now due to limited time and energy; I will only summarize them later for learning purposes.

Mixed Precision Inference

FP16 and BF16 are the most commonly used formats. The former has the same precision as TF32, while the latter has the same range as FP32 but lower precision. TF32 combines the range of FP32 with the precision of FP16, but requires Tensor Core hardware support.

Mixed Precision Comparison

Note that the following precision-sensitive operations typically need to stay in FP32 (see the sketch after this list):

  1. Trigonometric calculations in RoPE
  2. Square-root and accumulation operations in RMS Norm
  3. Softmax numerical stability
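
As a minimal sketch of this idea (assuming the half crate for FP16 storage, which is not necessarily a dependency of the project), activations can be stored in FP16 while the RMS Norm accumulation stays in FP32:

use half::f16;

// Hedged sketch: fp16 storage, fp32 accumulation for the precision-sensitive reduction.
fn rms_fp16(x: &[f16], epsilon: f32) -> f32 {
    let sum_sq: f32 = x.iter().map(|v| { let f = v.to_f32(); f * f }).sum();
    (sum_sq / x.len() as f32 + epsilon).sqrt()
}

fn main() {
    let x: Vec<f16> = [0.5f32, -1.25, 2.0, 0.75].iter().map(|&v| f16::from_f32(v)).collect();
    println!("RMS = {}", rms_fp16(&x, 1e-6));
}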

Quantized Inference

Quantization works by mapping values to lower-precision data types, bringing two main benefits (a minimal sketch follows below):

  • Computation acceleration (requires hardware support for low-precision instructions)
  • VRAM savings (FP16 needs only 50% of the FP32 footprint, and INT8 or INT4 go even lower…)

Quantized Inference Principle
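
A minimal sketch of symmetric INT8 quantization, illustrating the data-type mapping (this scheme is just an example, not something taken from the project):

// Hedged sketch: symmetric per-tensor INT8 quantization and dequantization.
fn quantize_i8(x: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = x.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = x.iter().map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8).collect();
    (q, scale)
}

fn dequantize_i8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let x = [0.1f32, -0.5, 0.9, 0.0];
    let (q, scale) = quantize_i8(&x);
    println!("{:?} -> {:?} (scale = {scale})", x, dequantize_i8(&q, scale));
}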

Advanced Optimization Directions

  1. MoE Architecture: Dynamic routing MLP optimization
  2. Distributed Inference:
    • Data Parallelism: Full parameter copy, increasing throughput
    • Pipeline Parallelism: Inter-layer splitting, sensitive to communication overhead
    • Tensor Parallelism: Distributed parameter storage, requires AllGather/AllReduce synchronization

Distributed Inference Architecture

Summary

This training camp connected the technical landscape of LLM inference optimization through hard-core, hands-on practice. It will be well worth continuing to push the implementation forward when I have time. There has been too much going on recently, so I will just wrap things up hastily here 🥲
