Implementing Simple LLM Inference with Rust

I stumbled upon the Large Model and Artificial Intelligence System Training Camp organized by Tsinghua University on Bilibili and signed up without hesitation. My plan is to use the time back home over the Spring Festival to consolidate my theoretical knowledge of LLM inference through practice. A VPN failure at school also happens to be blocking my research work, so it is a good opportunity to organize my study notes.

As for Rust, I tried to get started twice during my junior year (a certain bible-level textbook scared me off), but this time I switched to a dual-track approach of rustlings plus the official documentation and finally got over the entry barrier (though not much further than that).

Llama Architecture Analysis#

Starting from the core component breakdown, revisiting classic architectural design:

[Figure: Llama architecture overview]

Layer Normalization (RMS Norm)#

Llama adopts a Pre-Norm architecture, applying normalization to each sublayer's input. Compared with traditional Layer Norm:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

RMSNorm achieves computational optimization by removing mean centering:

$$\bar{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i,\quad \text{where}~~ \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2 + \epsilon}$$

where $g_i$ is the learnable scaling parameter (gamma), and $\epsilon$ prevents division by zero.

Note: all operators initially only need to implement FP32 support; generics and macros can later be used to support other formats for inference.
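
As a concrete reference, a minimal FP32 sketch of such an RMS Norm operator over flat slices might look like the following; the function name and slice-based signature are my own assumptions, not necessarily the exercise's actual interface:

```rust
/// RMS Norm over the last dimension: y_i = (x_i / RMS(x)) * g_i.
/// Minimal FP32 sketch; the signature is an assumption, not the camp's exact API.
fn rms_norm(y: &mut [f32], x: &[f32], gamma: &[f32], eps: f32) {
    let d = gamma.len();
    assert!(x.len() == y.len() && x.len() % d == 0);
    // Treat the input as shape (..., d) and normalize each row of length d.
    for (x_row, y_row) in x.chunks_exact(d).zip(y.chunks_exact_mut(d)) {
        let mean_sq = x_row.iter().map(|v| v * v).sum::<f32>() / d as f32;
        let rms = (mean_sq + eps).sqrt();
        for ((y_i, &x_i), &g_i) in y_row.iter_mut().zip(x_row).zip(gamma) {
            *y_i = x_i / rms * g_i;
        }
    }
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let gamma = [1.0_f32; 4];
    let mut y = [0.0_f32; 4];
    rms_norm(&mut y, &x, &gamma, 1e-6);
    println!("{:?}", y);
}
```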

Rotary Position Encoding (RoPE)#

  • Absolute position encoding: built from learnable embedding vectors or fixed trigonometric functions (bounded by max_length) and added directly to the word vectors; each position's encoding is essentially independent;
  • Relative position encoding: learns the distance between pairs of tokens and requires modifying the attention computation (e.g., T5 uses a bias embedding matrix of relative distances added to the attention scores computed from Q and K); it extrapolates to arbitrary sequence lengths, but inference is slower and it is hard to combine with KV-Cache;
  • Rotary Positional Embedding (RoPE): combines the advantages of absolute and relative position encoding; in the two-dimensional case the formula is:
$$\underbrace{f_{\{q,k\}}(\mathbf{x}_m, m)}_{\text{RoPE output}} = \underbrace{ \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} }_{\text{Rotation matrix } R(m\theta)} \underbrace{ \begin{pmatrix} W_{\{q,k\}}^{(11)} & W_{\{q,k\}}^{(12)} \\ W_{\{q,k\}}^{(21)} & W_{\{q,k\}}^{(22)} \end{pmatrix} }_{\text{Linear transformation } W_{\{q,k\}}} \underbrace{ \begin{pmatrix} x_m^{(1)} \\ x_m^{(2)} \end{pmatrix} }_{\text{Input vector } \mathbf{x}_m}$$

Before the rotation matrix is applied, a linear transformation first produces the Query and Key (preserving rotational invariance); the rotation matrix then rotates the word vector by an angle of $m\theta$, where $m$ is the token's absolute position in the sentence:

[Figure: 2D rotation of the query/key vector by the position-dependent angle mθ]

In the more general form, the vector is split into multiple 2D blocks (the vector dimension $d$ is assumed to be even), and a different rotation angle is applied to each pair of dimensions:

$$f_{\{q,k\}}(\mathbf{x}_m, m) = \mathbf{R}_{\Theta,m}^d \mathbf{W}_{\{q,k\}} \mathbf{x}_m$$
$$\mathbf{R}_{\Theta,m}^d = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}$$

To optimize computational efficiency, the above computation is equivalent to:

$$\mathbf{R}_{\Theta,m}^d \mathbf{x} = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$

The implementation of this operator is already provided in the repository.
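
Purely as a reading aid, here is a standalone sketch of the interleaved-pair formulation above for a single head vector; the pairing convention, the in-place signature, and the theta_base argument are my assumptions and may differ from the repository's layout:

```rust
/// Rotary position embedding on one head vector of even length d:
/// each adjacent pair (x_{2i}, x_{2i+1}) is rotated by the angle m * theta_i.
/// Sketch under an assumed layout; the repository's pairing/stride may differ.
fn rope_in_place(x: &mut [f32], m: usize, theta_base: f32) {
    let d = x.len();
    assert!(d % 2 == 0);
    for i in 0..d / 2 {
        // theta_i = theta_base^(-2i/d), as in the original RoPE formulation.
        let theta = theta_base.powf(-2.0 * i as f32 / d as f32);
        let (sin, cos) = ((m as f32) * theta).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * cos - b * sin;
        x[2 * i + 1] = a * sin + b * cos;
    }
}
```

For Llama-family models, theta_base is typically 10000.0.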

FFN (MLP)#

MLP structure (src/model.rs):
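
The gated FFN boils down to W_down · (SiLU(W_gate · x) ⊙ (W_up · x)). Below is a rough sketch over flat FP32 slices; the struct, field names, and matvec helper are hypothetical, since the real src/model.rs works with the project's own Tensor type and operators:

```rust
/// Llama-style gated FFN over a single token vector x (length d_model):
/// out = W_down · ( SiLU(W_gate · x) ⊙ (W_up · x) )
/// Rough FP32 sketch; names and layout are assumptions, not the project's API.
struct MlpWeights {
    w_gate: Vec<f32>, // (d_ff, d_model), row-major
    w_up: Vec<f32>,   // (d_ff, d_model), row-major
    w_down: Vec<f32>, // (d_model, d_ff), row-major
}

fn matvec(mat: &[f32], v: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    (0..rows)
        .map(|r| (0..cols).map(|c| mat[r * cols + c] * v[c]).sum::<f32>())
        .collect()
}

fn mlp_forward(x: &[f32], w: &MlpWeights, d_model: usize, d_ff: usize) -> Vec<f32> {
    let gate = matvec(&w.w_gate, x, d_ff, d_model);
    let up = matvec(&w.w_up, x, d_ff, d_model);
    // SwiGLU: SiLU(gate) ⊙ up, elementwise.
    let hidden: Vec<f32> = gate
        .iter()
        .zip(&up)
        .map(|(&g, &u)| g / (1.0 + (-g).exp()) * u)
        .collect();
    matvec(&w.w_down, &hidden, d_model, d_ff)
}
```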

Operators#

The basic computational operations involved in the entire process are shown in the figure:

[Figure: basic operators used in the forward pass]

MatMul (src/operators.rs):
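
A naive version of this operator, assuming C = A · Bᵀ with B stored row-major; the exercise's actual operator (for example, with alpha/beta scaling) may have a different signature:

```rust
/// C (m x n) = A (m x k) @ B^T, where B is stored row-major with shape (n, k).
/// Naive FP32 sketch; the real operator's signature/scaling may differ.
fn matmul_transb(c: &mut [f32], a: &[f32], b: &[f32], m: usize, k: usize, n: usize) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), n * k);
    assert_eq!(c.len(), m * n);
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of A with row j of B (i.e. column j of B^T).
            c[i * n + j] = (0..k).map(|p| a[i * k + p] * b[j * k + p]).sum::<f32>();
        }
    }
}
```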

SwiGLU (src/operators.rs):

$\text{SiLU}(x) \cdot \text{gate}$, where $\text{SiLU}(x) = x \cdot \text{sigmoid}(x)$
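
A minimal elementwise sketch; which buffer plays the role of the gate (and the argument order) is an assumption:

```rust
/// SwiGLU: y_i = SiLU(x_i) * y_i, with SiLU(x) = x * sigmoid(x).
/// Minimal FP32 sketch; the argument roles are assumptions.
fn swiglu(y: &mut [f32], x: &[f32]) {
    assert_eq!(y.len(), x.len());
    for (y_i, &x_i) in y.iter_mut().zip(x) {
        let silu = x_i / (1.0 + (-x_i).exp()); // x * sigmoid(x)
        *y_i *= silu;
    }
}
```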

Reading from safetensors#

(src/params.rs)
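
A plausible sketch of pulling an FP32 tensor out of a .safetensors file with the safetensors crate (assumed as a dependency); the file path and tensor name below are illustrative, not the project's actual names:

```rust
// Sketch using the `safetensors` crate (add `safetensors` to Cargo.toml);
// assumes the file stores FP32 little-endian data.
use safetensors::SafeTensors;
use std::fs;

fn load_f32_tensor(path: &str, name: &str) -> Vec<f32> {
    let bytes = fs::read(path).expect("failed to read safetensors file");
    let st = SafeTensors::deserialize(&bytes).expect("invalid safetensors file");
    let view = st.tensor(name).expect("tensor not found");
    // Reinterpret the raw bytes as f32 values.
    view.data()
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect()
}

fn main() {
    // Hypothetical file/tensor names, for illustration only.
    let gamma = load_f32_tensor("model.safetensors", "model.norm.weight");
    println!("loaded {} values", gamma.len());
}
```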

KV-Cache Mechanism#

During FP16 inference (each parameter occupies 2 bytes), the memory occupied by the KV cache is:

$$2 \times 2 \times n\_layer \times seq\_len \times n\_head \times head\_dim \times batch$$

As the sequence length increases, recomputing the K and V of all previous tokens at every step becomes increasingly expensive; with the cache, each generation step only needs the incremental computation for the newest token:

[Figure: KV-Cache reuse during incremental decoding]

The corresponding configuration parameters are use_cache: true and max_position_embeddings: 4096; the latter determines the model's maximum context window.
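
A toy sketch of how a per-layer cache could be laid out; the struct, field names, and flat-Vec layout are my own assumptions for illustration:

```rust
/// Per-layer KV cache: appends one row of K and one row of V
/// (n_kv_head * head_dim values each) per generated token,
/// so past tokens are never recomputed.
/// Struct and field names are assumptions, not the project's actual types.
struct LayerKvCache {
    k: Vec<f32>, // (tokens_so_far, n_kv_head * head_dim), row-major
    v: Vec<f32>,
    kv_dim: usize, // n_kv_head * head_dim
}

impl LayerKvCache {
    fn new(kv_dim: usize, max_seq_len: usize) -> Self {
        Self {
            k: Vec::with_capacity(max_seq_len * kv_dim),
            v: Vec::with_capacity(max_seq_len * kv_dim),
            kv_dim,
        }
    }

    /// Append the newest token's K and V; older entries are reused as-is.
    fn append(&mut self, k_new: &[f32], v_new: &[f32]) {
        assert_eq!(k_new.len(), self.kv_dim);
        assert_eq!(v_new.len(), self.kv_dim);
        self.k.extend_from_slice(k_new);
        self.v.extend_from_slice(v_new);
    }

    fn seq_len(&self) -> usize {
        self.k.len() / self.kv_dim
    }
}
```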

Inference Optimization Plan#

Grouped Query Attention (GQA)#

Memory is optimized by letting multiple groups of Q heads share a single group of K/V heads (the number of query heads must be divisible by the number of KV heads):
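
The core of GQA is just an index mapping from query head to shared KV head, with the attention math itself unchanged; a minimal sketch, assuming query heads are grouped contiguously:

```rust
/// GQA head mapping: query head q_head reads the K/V of kv head q_head / group_size,
/// where group_size = n_q_heads / n_kv_heads (must divide evenly).
/// Sketch of the indexing only; the attention computation is omitted.
fn kv_head_for(q_head: usize, n_q_heads: usize, n_kv_heads: usize) -> usize {
    assert_eq!(n_q_heads % n_kv_heads, 0, "head counts must be divisible");
    let group_size = n_q_heads / n_kv_heads;
    q_head / group_size
}

fn main() {
    // e.g. 8 query heads sharing 2 KV heads -> groups of 4.
    for q in 0..8 {
        println!("q head {q} -> kv head {}", kv_head_for(q, 8, 2));
    }
}
```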

Batching, session management, dialogue rollback, and similar features are not implemented for now due to limited energy; I will only summarize them at an introductory level later.

Mixed Precision Inference#

The commonly used formats are FP16 and BF16: the former has the same precision (mantissa width) as TF32, while the latter has the same range as FP32 but lower precision. TF32 combines FP32's range with FP16's precision, but requires Tensor Core hardware support.

Note that the following precision-sensitive operations must be kept in FP32 (see the sketch after this list):

  1. The trigonometric calculations in RoPE
  2. The squaring and square-root accumulation in RMS Norm
  3. The exponentials in Softmax (for numerical stability)
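
A typical pattern is to keep storage in FP16 but upcast to FP32 around these steps. Below is a sketch using the half crate (assumed as a dependency); the softmax shown is illustrative and not the project's actual code path:

```rust
// Sketch of FP16 storage with FP32 computation around a sensitive step,
// using the `half` crate (add `half` to Cargo.toml).
use half::f16;

/// Softmax over FP16 inputs, computed in FP32 for stability, stored back as FP16.
fn softmax_f16(x: &mut [f16]) {
    // Upcast, subtract the max (stability), exponentiate and normalize in FP32.
    let vals: Vec<f32> = x.iter().map(|v| v.to_f32()).collect();
    let max = vals.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = vals.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    for (out, e) in x.iter_mut().zip(&exps) {
        *out = f16::from_f32(e / sum);
    }
}
```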

Quantized Inference#

Quantization is achieved by mapping values to lower-precision data types (a minimal INT8 sketch follows the list below), which brings:

  • Computational acceleration (requires hardware support for low-precision instructions)
  • Memory savings (FP16 needs only 50% of FP32's space; INT8 and INT4 go even further)
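
As a concrete example of such a mapping, a symmetric per-tensor INT8 scheme uses a single scale factor. The scheme choice here is mine; real deployments usually add per-channel or group-wise scales and zero points:

```rust
/// Symmetric per-tensor INT8 quantization: q = round(x / scale), x ≈ q * scale,
/// with scale = max|x| / 127. One possible mapping, chosen for illustration.
fn quantize_i8(x: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = x.iter().fold(0.0_f32, |m, v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = x.iter().map(|v| (v / scale).round() as i8).collect();
    (q, scale)
}

fn dequantize_i8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let x = [0.5_f32, -1.25, 3.0, 0.0];
    let (q, scale) = quantize_i8(&x);
    println!("quantized: {:?}, scale: {}, restored: {:?}", q, scale, dequantize_i8(&q, scale));
}
```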

Advanced Optimization Directions#

  1. MoE Architecture: replacing the MLP with dynamically routed experts
  2. Distributed Inference:
    • Data parallelism: each device holds a full copy of the parameters; improves throughput
    • Pipeline parallelism: splits the model across layers; sensitive to communication overhead
    • Tensor parallelism: shards parameters across devices; requires AllGather/AllReduce synchronization

Summary#

Through hands-on practice, this training camp tied together the technical threads of LLM inference optimization, and it is worth continuing to push forward when time permits. I have had a lot on my plate recently, so I will wrap things up here 🥲
