Implementing Simple LLM Inference with Rust

I stumbled upon the Large Model and Artificial Intelligence System Training Camp organized by Tsinghua University on Bilibili and signed up on the spot. The plan: use the time back home over the Spring Festival to consolidate my theoretical knowledge of LLM inference through practice. With the school VPN down and research stalled anyway, it was also the perfect opportunity to organize my study notes.

As for Rust, I tried to get started twice during my junior year (and was scared off both times by a certain "bible"-tier textbook). This time I switched to a dual-track approach of rustlings plus the official documentation, and finally broke through the entry barrier (though not much further than that).

Llama Architecture Analysis#

Let's start by breaking down the core components and revisiting this classic architectural design:

[Figure: Llama architecture overview]

Layer Normalization (RMS Norm)#

Llama adopts a Pre-Norm design: each sub-layer's input is normalized before the sub-layer acts on it. Compared to traditional Layer Norm:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

RMSNorm achieves computational optimization by removing mean centering:

$$\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i,\quad \text{where}~~ \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2 + \epsilon}$$

where $g_i$ is the learnable scaling parameter (gamma), and $\epsilon$ prevents division by zero.

Note: All operators initially only need to implement FP32 support, and later can use generics and macros to support other formats for inference.

pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let feature_dim = shape[shape.len() - 1];  // d in mathematical notation
    
    // Calculate number of feature vectors to process
    let num_features = shape[0..shape.len()-1].iter().product();
    
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();

    // Process each feature vector independently
    for i in 0..num_features {
        let offset = i * feature_dim;
        
        // 1. Calculate Σ(x_i²)
        let sum_squares: f32 = (0..feature_dim)
            .map(|j| {
                let x_ij = _x[offset + j];
                x_ij * x_ij
            })
            .sum();
            
        // 2. Calculate RMS(x) = sqrt(1/d * Σ(x_i²) + ε)
        let rms = f32::sqrt(sum_squares / feature_dim as f32 + epsilon);
        
        // 3. Apply normalization and scaling: y_i = (w_i * x_i) / RMS(x)
        for j in 0..feature_dim {
            let idx = offset + j;
            _y[idx] = (_w[j] * _x[idx]) / rms;
        }
    }
}
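
Regarding the note above: once other element types are needed, the same routine can be written generically. Below is a minimal sketch of my own (not the course scaffold), assuming the project's Tensor<T> exposes the same shape() / data() / data_mut() accessors for other element types and that num_traits is a dependency:

use num_traits::Float;
use std::iter::Sum;

// Generic RMS Norm; the FP32 version above is the special case T = f32.
pub fn rms_norm_generic<T: Float + Sum>(y: &mut Tensor<T>, x: &Tensor<T>, w: &Tensor<T>, epsilon: T) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let d = shape[shape.len() - 1];
    let n: usize = shape[0..shape.len() - 1].iter().product();

    let _y = unsafe { y.data_mut() };
    let (_x, _w) = (x.data(), w.data());

    for i in 0..n {
        let off = i * d;
        // Σ(x_j²), then rms = sqrt(mean + ε)
        let sum_squares: T = (0..d).map(|j| _x[off + j] * _x[off + j]).sum();
        let rms = (sum_squares / T::from(d).unwrap() + epsilon).sqrt();
        for j in 0..d {
            _y[off + j] = _w[j] * _x[off + j] / rms;
        }
    }
}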

Rotary Position Encoding (RoPE)#

  • Absolute position encoding: built from learnable embedding vectors or fixed trigonometric functions (and therefore tied to a max_length), added directly to the word embeddings; each position's encoding is essentially independent of the others;
  • Relative position encoding: learns the distance between pairs of tokens and requires modifying the attention computation (e.g., T5 adds a learned relative-distance bias to the self-attention Q·K scores); it extends to sequences of arbitrary length, but inference is slower and it is hard to combine with a KV-Cache;
  • Rotary Positional Embedding (RoPE) combines the advantages of both absolute and relative position encodings; for the two-dimensional case the formula is:
$$\underbrace{f_{\{q,k\}}(\mathbf{x}_m, m)}_{\text{RoPE output}} = \underbrace{ \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} }_{\text{Rotation matrix } R(m\theta)} \underbrace{ \begin{pmatrix} W_{\{q,k\}}^{(11)} & W_{\{q,k\}}^{(12)} \\ W_{\{q,k\}}^{(21)} & W_{\{q,k\}}^{(22)} \end{pmatrix} }_{\text{Linear transformation matrix } W_{\{q,k\}}} \underbrace{ \begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix} }_{\text{Input vector } \mathbf{x}_m}$$

Before the rotation matrix acts, a linear transformation is applied first to produce the Query and Key (keeping the rotational structure intact); the rotation matrix then rotates the resulting word vector by an angle of $m\theta$, where $m$ is the token's absolute position in the sentence:
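
This is also why the encoding behaves relatively (the standard RoPE argument, included here for completeness): 2D rotation matrices compose, so the dot product between a query rotated by $m\theta$ and a key rotated by $n\theta$ depends only on the offset $n - m$:

$$\langle R(m\theta)\,\mathbf{q},\; R(n\theta)\,\mathbf{k} \rangle = \mathbf{q}^\top R(m\theta)^\top R(n\theta)\,\mathbf{k} = \mathbf{q}^\top R\big((n-m)\theta\big)\,\mathbf{k}$$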

[Figure: rotating the 2D word vector by $m\theta$]

In the more general form, the vector is split into multiple 2D blocks (the vector dimension is assumed to be even), and a different rotation angle is applied to each pair of dimensions:

$$f_{\{q,k\}}(\mathbf{x}_m, m) = \mathbf{R}_{\Theta,m}^d \mathbf{W}_{\{q,k\}} \mathbf{x}_m$$

$$\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}$$

To optimize computational efficiency, the above computation is equivalent to:

$$\mathbf{R}_{\Theta,m}^d \mathbf{x} = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$

[Figure: element-wise implementation of the RoPE rotation]

The implementation here is already provided in the repository:

// RoPE: Rotary Positional Embedding
pub fn rope(y: &mut Tensor<f32>, start_pos: usize, theta: f32) {
    let shape = y.shape();
    assert!(shape.len() == 3);
    let seq_len = shape[0];
    let n_heads = shape[1];
    let d = shape[2];
    let data = unsafe { y.data_mut() };
    for tok in 0..seq_len {
        let pos = start_pos + tok;
        for head in 0..n_heads {
            // Pair dimension i with i + d/2 (rotate-half layout), unlike the interleaved (2i, 2i+1) pairs in the formula above
            for i in 0..d / 2 {
                let a = data[tok * n_heads * d + head * d + i];
                let b = data[tok * n_heads * d + head * d + i + d / 2];
                let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
                let (sin, cos) = freq.sin_cos();
                // Apply rotation matrix
                data[tok * n_heads * d + head * d + i] = a * cos - b * sin;
                data[tok * n_heads * d + head * d + i + d / 2] = b * cos + a * sin;
            }
        }
    }
}
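
A quick sanity check of my own (assuming the Tensor::new(Vec<T>, &Vec<usize>) constructor used later in this post, and that the test sits next to the operator definitions): with start_pos = 0, position 0 gets angle 0, so the first token must pass through unchanged while later tokens are rotated.

#[test]
fn rope_keeps_position_zero_unchanged() {
    // shape: (seq_len = 2, n_heads = 1, d = 4)
    let mut y = Tensor::new(
        vec![1.0, 2.0, 3.0, 4.0,   // token at position 0
             1.0, 2.0, 3.0, 4.0],  // token at position 1
        &vec![2, 1, 4],
    );
    rope(&mut y, 0, 1e4);
    let data = y.data();
    assert_eq!(&data[0..4], &[1.0, 2.0, 3.0, 4.0]); // angle 0 => identity rotation
    assert!(data[4..8] != [1.0, 2.0, 3.0, 4.0]);    // nonzero position => rotated
}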

FFN (MLP)#

[Figure: Llama MLP (SwiGLU FFN) structure]

MLP structure (src/model.rs):

fn mlp(
    residual: &mut Tensor<f32>,
    hidden_states: &mut Tensor<f32>,
    gate: &mut Tensor<f32>,
    up: &mut Tensor<f32>,
    w_up: &Tensor<f32>,
    w_down: &Tensor<f32>,
    w_gate: &Tensor<f32>,
    rms_w: &Tensor<f32>,
    eps: f32,
) {
    // 1. hidden = rms_norm(residual)
    OP::rms_norm(hidden_states, residual, rms_w, eps);

    // 2. gate = hidden @ gate_weight.T
    OP::matmul_transb(gate, 0., hidden_states, w_gate, 1.0);

    // 3. up = hidden @ up_weight.T
    OP::matmul_transb(up, 0., hidden_states, w_up, 1.0);

    // 4. act = gate * sigmoid(gate) * up (SwiGLU)
    OP::swiglu(up, gate);

    // 5. residual = residual + up @ down_weight.T
    OP::matmul_transb(residual, 1.0, up, w_down, 1.0);
}

Operators#

The basic computational operations involved in the entire process are shown in the figure:

[Figure: basic operators in the forward pass]

MatMul (src/operators.rs):

// C = beta * C + alpha * A @ B^T
// hint: You don't need to do an explicit transpose of B
pub fn matmul_transb(c: &mut Tensor<f32>, beta: f32, a: &Tensor<f32>, b: &Tensor<f32>, alpha: f32) {
    // C_xy = beta * C_xy + alpha * Σ(A_xk * B_yk)  k=1..inner_dim
    let shape_a = a.shape();
    let shape_b = b.shape();
    let (a_rows, b_rows) = (shape_a[0], shape_b[0]);
    let inner = shape_a[1];

    let _c = unsafe { c.data_mut() };
    let _a = a.data();
    let _b = b.data();

    // 1. using slice to scale C
    _c.iter_mut().for_each(|val| *val *= beta);

    // 2. using slice to matmul
    for x in 0..a_rows {
        // get current row of A
        let a_row = &_a[x * inner..(x + 1) * inner];
        
        for y in 0..b_rows {
            // get current row of B (equivalent to column of B^T)
            let b_row = &_b[y * inner..(y + 1) * inner];
            
            let sum: f32 = a_row.iter()
                .zip(b_row.iter())
                .map(|(&a, &b)| a * b)
                .sum();

            _c[x * b_rows + y] += alpha * sum;
        }
    }
    
}
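
A small check of my own (not from the course test suite), again assuming the Tensor::new constructor: with $B$ equal to the identity matrix, $B^\top$ is also the identity, so $C$ should simply reproduce $A$.

#[test]
fn matmul_transb_identity() {
    let a = Tensor::new(vec![1., 2., 3., 4.], &vec![2, 2]);
    let b = Tensor::new(vec![1., 0., 0., 1.], &vec![2, 2]); // identity, so B^T = I
    let mut c = Tensor::new(vec![0.; 4], &vec![2, 2]);
    matmul_transb(&mut c, 0.0, &a, &b, 1.0);
    assert_eq!(c.data(), &[1., 2., 3., 4.]); // C = 0 * C + 1 * (A @ I)
}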

SwiGLU (src/operators.rs)

Computes $y \leftarrow \text{SiLU}(x) \odot y$, where $\text{SiLU}(x) = x \cdot \text{sigmoid}(x)$; in the MLP above, $x$ is the gate projection and $y$ is the up projection.

// hint: this is an element-wise operation
pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
    let len = y.size();
    assert!(len == x.size());

    let _y = unsafe { y.data_mut() };
    let _x = x.data();

    for i in 0..len {
        _y[i] = _y[i] * _x[i] / (1. + f32::exp(-_x[i]));
    }
}
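
A quick numeric check of my own (same Tensor::new assumption as before): with $x = 1$ and $y = 2$, $\text{sigmoid}(1) \approx 0.73106$, so the result should be $2 \times 1 \times 0.73106 \approx 1.46212$.

#[test]
fn swiglu_single_element() {
    let x = Tensor::new(vec![1.0], &vec![1]);
    let mut y = Tensor::new(vec![2.0], &vec![1]);
    swiglu(&mut y, &x);
    assert!((y.data()[0] - 1.4621172).abs() < 1e-5);
}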

Reading from safetensor#

(src/params.rs)

use crate::config::LlamaConfigJson;
use crate::tensor::Tensor;
use num_traits::Num;
use safetensors::SafeTensors;

pub struct LLamaParams<T: Num> {
    // token_id to embedding lookup table
    pub embedding_table: Tensor<T>, // (vocab_size, dim)
    // decoder layer
    pub rms_att_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub wq: Vec<Tensor<T>>,        // (n_heads * head_size, hidden_size) x layers
    pub wk: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wv: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wo: Vec<Tensor<T>>,        // (hidden_size, n_heads * head_size) x layers
    // ffn layer
    pub rms_ffn_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub w_up: Vec<Tensor<T>>,      // (intermediate_size, hidden_size) x layers
    pub w_gate: Vec<Tensor<T>>,    // (intermediate_size, hidden_size) x layers
    pub w_down: Vec<Tensor<T>>,    // (hidden_size, intermediate_size) x layers
    // output
    pub rms_out_w: Tensor<T>, // (hidden_size, )
    pub lm_head: Tensor<T>,   // (vocab_size, dim)
}

trait GetTensorFromSafeTensors<P: Num> {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<P>, &'static str>;
}

impl GetTensorFromSafeTensors<f32> for f32 {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<f32>, &'static str> {
        let tensor_view = tensors.tensor(name).map_err(|e| {
            assert!(matches!(e, safetensors::SafeTensorError::TensorNotFound(_)));
            "Tensor not found"
        })?;
        
        let data = match tensor_view.dtype() {
            safetensors::Dtype::F32 => tensor_view.data()
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            _ => return Err("Unsupported data type"),
        };
        
        Ok(Tensor::new(data, &tensor_view.shape().to_vec()))
    }
}

macro_rules! get_tensor_vec {
    ($tensors:expr, $pattern:literal, $layers:expr) => {{
        (0..$layers)
            .map(|i| {
                let name = format!($pattern, i);
                f32::get_tensor_from($tensors, &name).unwrap()
            })
            .collect::<Vec<_>>()
    }};
}

impl LLamaParams<f32> {
    pub fn from_safetensors(safetensor: &SafeTensors, config: &LlamaConfigJson) -> Self {
        let embedding_table = if config.tie_word_embeddings {
            f32::get_tensor_from(safetensor, "lm_head.weight").unwrap()
        } else {
            f32::get_tensor_from(safetensor, "model.embed_tokens.weight").unwrap()
        };

        let n_layers = config.num_hidden_layers;
        
        Self {
            embedding_table,
            rms_att_w: get_tensor_vec!(safetensor, "model.layers.{}.input_layernorm.weight", n_layers),
            wq: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.q_proj.weight", n_layers),
            wk: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.k_proj.weight", n_layers),
            wv: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.v_proj.weight", n_layers),
            wo: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.o_proj.weight", n_layers),
            rms_ffn_w: get_tensor_vec!(safetensor, "model.layers.{}.post_attention_layernorm.weight", n_layers),
            w_up: get_tensor_vec!(safetensor, "model.layers.{}.mlp.up_proj.weight", n_layers),
            w_gate: get_tensor_vec!(safetensor, "model.layers.{}.mlp.gate_proj.weight", n_layers),
            w_down: get_tensor_vec!(safetensor, "model.layers.{}.mlp.down_proj.weight", n_layers),
            rms_out_w: f32::get_tensor_from(safetensor, "model.norm.weight").unwrap(),
            lm_head: f32::get_tensor_from(safetensor, "lm_head.weight").unwrap(),
        }
    }
}
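
A hedged sketch of how these pieces might be wired together (the file names, the use of serde_json, and the assumption that LlamaConfigJson derives serde::Deserialize are mine; the course scaffold may differ):

use std::fs;
use std::path::Path;

// Read config.json + model.safetensors from a model directory and build the parameters.
pub fn load_llama(model_dir: &Path) -> (LlamaConfigJson, LLamaParams<f32>) {
    let config: LlamaConfigJson =
        serde_json::from_slice(&fs::read(model_dir.join("config.json")).unwrap()).unwrap();
    let weights = fs::read(model_dir.join("model.safetensors")).unwrap();
    let safetensors = SafeTensors::deserialize(&weights).unwrap();
    let params = LLamaParams::from_safetensors(&safetensors, &config);
    (config, params)
}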

KV-Cache Mechanism#

During FP16 inference (each parameter occupies 2 bytes), the memory taken up by the KV cache is

$$2 \times 2 \times n\_layer \times seq\_len \times n\_head \times head\_dim \times batch$$

bytes: one factor of 2 for K and V, the other for the 2-byte element size. As the sequence length grows, this footprint keeps climbing and generation slows down accordingly; the point of the cache is that each decoding step only computes and appends the newest token's K/V instead of recomputing the whole history:

[Figure: incremental decoding with a KV cache]

The corresponding config entries are use_cache: true and max_position_embeddings: 4096; the latter determines the model's maximum context window.
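
The snippet below is only a minimal sketch of my own (names and layout invented, independent of whatever cache type the project actually uses) to make the mechanism concrete: the cache holds K/V rows for every previous position, and each decoding step appends just the newest token's rows.

// Per-layer KV cache: preallocate up to max_position_embeddings rows.
pub struct LayerKVCache {
    k: Vec<f32>,     // (max_seq_len, n_kv_heads * head_dim), valid up to `len` rows
    v: Vec<f32>,
    kv_dim: usize,   // n_kv_heads * head_dim
    len: usize,      // number of cached token positions
}

impl LayerKVCache {
    pub fn new(max_seq_len: usize, kv_dim: usize) -> Self {
        Self {
            k: vec![0.0; max_seq_len * kv_dim],
            v: vec![0.0; max_seq_len * kv_dim],
            kv_dim,
            len: 0,
        }
    }

    // Append the K/V rows of newly generated token(s); attention then reads the whole
    // cached prefix via k()/v() instead of recomputing it.
    pub fn append(&mut self, k_new: &[f32], v_new: &[f32]) {
        let start = self.len * self.kv_dim;
        self.k[start..start + k_new.len()].copy_from_slice(k_new);
        self.v[start..start + v_new.len()].copy_from_slice(v_new);
        self.len += k_new.len() / self.kv_dim;
    }

    pub fn k(&self) -> &[f32] { &self.k[..self.len * self.kv_dim] }
    pub fn v(&self) -> &[f32] { &self.v[..self.len * self.kv_dim] }
}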

Inference Optimization Plan#

Group Query Attention (GQA)#

Memory is saved by sharing one group of K/V heads across several groups of Q heads (the number of query heads must be divisible by the number of KV heads):

fn self_attention(
    hidden_states: &mut Tensor<f32>, // (seq, n_kv_h * n_groups * dqkv)
    att_scores: &mut Tensor<f32>,    // (n_kv_h, n_groups, seq, total_seq)
    q: &Tensor<f32>,                 // (seq, n_kv_h * n_groups * dqkv)
    k: &Tensor<f32>,                 // (total_seq, n_kv_h * dqkv)
    v: &Tensor<f32>,                 // (total_seq, n_kv_h * dqkv)
    n_kv_h: usize,
    n_groups: usize,
    seq_len: usize,
    total_seq_len: usize,
    dqkv: usize,
) {
    assert!(k.shape()[0] >= total_seq_len && v.shape()[0] >= total_seq_len);
    assert!(q.shape()[0] == seq_len && q.shape()[1] == n_kv_h * n_groups && q.shape()[2] == dqkv);
    let _a = unsafe { att_scores.data_mut() };
    let _q = q.data();
    let _k = k.data();
    let _v = v.data();
    let sqrt = (dqkv as f32).sqrt();

    // attn_score = Q @ K^T / sqrt(dim)
    for q in 0..n_kv_h * n_groups {
        for seq in 0..seq_len {
            for t_seq in 0..total_seq_len {
                let mut sum = 0.0;
                for d in 0..dqkv {
                    sum += _q[seq * n_kv_h * n_groups * dqkv + q * dqkv + d]
                        * _k[t_seq * n_kv_h * dqkv + q / n_groups * dqkv + d];
                }
                _a[q * seq_len * total_seq_len + seq * total_seq_len + t_seq] = sum / sqrt;
            }
        }
    }

    // attn = softmax(score)
    OP::masked_softmax(att_scores);

    // x = attn @ V
    let _a = att_scores.data();
    let _h = unsafe { hidden_states.data_mut() };
    for q in 0..n_kv_h * n_groups {
        for seq in 0..seq_len {
            for d in 0..dqkv {
                let mut sum = 0.0;
                for t_seq in 0..total_seq_len {
                    sum += _a[q * seq_len * total_seq_len + seq * total_seq_len + t_seq]
                        * _v[d + q / n_groups * dqkv + t_seq * n_kv_h * dqkv];
                }
                _h[seq * n_kv_h * n_groups * dqkv + q * dqkv + d] = sum;
            }
        }
    }
}

Batching, session management, dialogue rollback, and similar features are not implemented here for lack of time; I will only cover them later as a conceptual overview.

Mixed Precision Inference#

The commonly used formats are FP16 and BF16: FP16 has the same precision (mantissa width) as TF32, while BF16 has the same range as FP32 but lower precision. TF32 combines FP32's range with FP16's precision, but requires Tensor Core hardware support.

[Figure: FP32 / TF32 / FP16 / BF16 formats]

Note that the following precision-sensitive operations must retain FP32 (a small sketch of the idea follows the list):

  1. Trigonometric calculations of RoPE
  2. Exponential operations of RMS Norm
  3. Numerical stability of Softmax
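
To make this concrete, here is a sketch using the half crate (an assumption on my part; the course code stays in FP32 at this stage): weights and activations live in f16, but the sensitive accumulation inside RMS Norm is done in f32.

use half::f16;

// RMS Norm over f16 data, keeping the accumulation and the square root in f32.
fn rms_norm_f16(y: &mut [f16], x: &[f16], w: &[f16], eps: f32) {
    let d = x.len();
    // accumulate Σ(x_i²) in f32, not f16
    let sum_squares: f32 = x.iter().map(|&v| { let v = v.to_f32(); v * v }).sum();
    let rms = (sum_squares / d as f32 + eps).sqrt();
    for i in 0..d {
        y[i] = f16::from_f32(w[i].to_f32() * x[i].to_f32() / rms);
    }
}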

Quantized Inference#

Quantization maps values onto lower-precision data types (a toy sketch follows the list), which brings:

  • Computational acceleration (requires hardware support for low-precision instructions)
  • Memory optimization (FP16 needs only 50% of FP32's space; INT8 and INT4 shrink it further)
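
As a toy illustration of the mapping (my own sketch, not part of the course scaffold): symmetric per-tensor INT8 quantization keeps a single f32 scale and maps $[-\max|w|, \max|w|]$ onto $[-127, 127]$.

// Quantize a weight slice to int8 with one symmetric scale, and back again.
fn quantize_i8(w: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = w.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = w.iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

fn dequantize_i8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&qi| qi as f32 * scale).collect()
}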

[Figure: quantization data-type mapping]

Advanced Optimization Directions#

  1. MoE Architecture: Dynamic routing MLP optimization
  2. Distributed Inference:
    • Data parallelism: full parameter copy, improving throughput
    • Pipeline parallelism: inter-layer splitting, communication overhead sensitive
    • Tensor parallelism: distributed parameter storage, requires AllGather/AllReduce synchronization

[Figure: distributed inference parallelism strategies]

Summary#

Through hands-on practice, this training camp has tied together the technical threads of LLM inference optimization, and it is worth pushing further when time permits. There has been a lot going on lately, so I will wrap it up here 🥲
