I stumbled upon the Large Model and Artificial Intelligence System Training Camp organized by Tsinghua University while browsing Bilibili and signed up on the spot. My plan is to use the time back home over the Spring Festival to consolidate my theoretical knowledge of LLM inference through practice; since the school VPN happens to be down and research is on hold, it is also a good opportunity to organize my study notes.
As for the Rust language, I tried to get started twice during my junior year (a certain "bible" of a textbook scared me off both times). This time I switched to a dual-track approach of rustlings plus the official documentation and finally broke through the entry barrier (though I am still only just past it).
Llama Architecture Analysis#
Let's start by breaking down the core components and revisiting the classic architectural design.
Layer Normalization (RMS Norm)#
Llama adopts a Pre-Norm architecture, normalizing the input of each sub-layer. Compared to traditional Layer Norm, RMSNorm saves computation by dropping the mean-centering step:

$$y_i = \frac{w_i \, x_i}{\mathrm{RMS}(x)}, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}$$

where $w$ is the learnable scaling parameter (gamma) and $\varepsilon$ prevents division by zero.
Note: All operators only need to support FP32 at first; generics and macros can be introduced later to support other formats for inference (a generic sketch follows the FP32 implementation below).
pub fn rms_norm(y: &mut Tensor<f32>, x: &Tensor<f32>, w: &Tensor<f32>, epsilon: f32) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let feature_dim = shape[shape.len() - 1]; // d in mathematical notation
    // Calculate number of feature vectors to process
    let num_features = shape[0..shape.len() - 1].iter().product();
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();
    // Process each feature vector independently
    for i in 0..num_features {
        let offset = i * feature_dim;
        // 1. Calculate Σ(x_i²)
        let sum_squares: f32 = (0..feature_dim)
            .map(|j| {
                let x_ij = _x[offset + j];
                x_ij * x_ij
            })
            .sum();
        // 2. Calculate RMS(x) = sqrt(1/d * Σ(x_i²) + ε)
        let rms = f32::sqrt(sum_squares / feature_dim as f32 + epsilon);
        // 3. Apply normalization and scaling: y_i = (w_i * x_i) / RMS(x)
        for j in 0..feature_dim {
            let idx = offset + j;
            _y[idx] = (_w[j] * _x[idx]) / rms;
        }
    }
}
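As noted above, the FP32 operator can later be generalized. Below is a minimal sketch of what a generic version might look like, assuming the element type implements num_traits::Float and the Tensor API stays the same; the name rms_norm_generic is hypothetical and not part of the course skeleton:

use num_traits::Float;

// Hypothetical generic RMSNorm; same algorithm as the FP32 version above.
pub fn rms_norm_generic<T: Float>(y: &mut Tensor<T>, x: &Tensor<T>, w: &Tensor<T>, epsilon: T) {
    let shape = x.shape();
    assert!(shape == y.shape());
    let d = shape[shape.len() - 1];
    let n: usize = shape[..shape.len() - 1].iter().product();
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    let _w = w.data();
    for i in 0..n {
        let off = i * d;
        // Σ x_i² accumulated in T; for half precision one would accumulate in f32 instead
        let mut sum_sq = T::zero();
        for j in 0..d {
            sum_sq = sum_sq + _x[off + j] * _x[off + j];
        }
        let rms = (sum_sq / T::from(d).unwrap() + epsilon).sqrt();
        for j in 0..d {
            _y[off + j] = _w[j] * _x[off + j] / rms;
        }
    }
}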
Rotary Position Encoding (RoPE)#
- Absolute position encoding: constructed from learnable embedding vectors (limited by max_length) or trigonometric functions, added directly to the word vectors; each position's encoding is essentially independent;
- Relative position encoding: learns the distance between a pair of tokens and requires modifying the attention calculation (e.g., the T5 model uses a bias embedding matrix of relative distances, added to the attention scores computed from Q and K); it extends to sequences of arbitrary length, but inference is slower and the KV-Cache is hard to exploit;
- Rotary Positional Embedding (RoPE) combines the advantages of absolute and relative position encodings. In the two-dimensional case the formula is:

$$f(\boldsymbol{q}, m) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \end{pmatrix}$$

Before the rotation matrix acts, a linear transformation produces the Query & Key vectors so that rotational invariance is preserved in their inner product. The effect of the rotation matrix is then to rotate the word vector by the angle $m\theta$, where $m$ is the absolute position of the token in the sentence.

In the more general form, the vector is split into groups of 2D blocks (the vector dimension $d$ is assumed even; this implementation pairs dimension $i$ with $i + d/2$), and a different rotation angle is applied to each pair:

$$\begin{pmatrix} x'_i \\ x'_{i+d/2} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_i \\ x_{i+d/2} \end{pmatrix}, \qquad \theta_i = \theta_{\text{base}}^{-2i/d}, \quad i = 0, \dots, d/2 - 1$$

To optimize computational efficiency, the above computation is equivalent to the element-wise form:

$$x'_i = x_i \cos m\theta_i - x_{i+d/2} \sin m\theta_i, \qquad x'_{i+d/2} = x_{i+d/2} \cos m\theta_i + x_i \sin m\theta_i$$
The implementation here is already provided in the repository:
// RoPE: Rotary Positional Embedding
pub fn rope(y: &mut Tensor<f32>, start_pos: usize, theta: f32) {
    let shape = y.shape();
    assert!(shape.len() == 3);
    let seq_len = shape[0];
    let n_heads = shape[1];
    let d = shape[2];
    let data = unsafe { y.data_mut() };
    for tok in 0..seq_len {
        let pos = start_pos + tok;
        for head in 0..n_heads {
            let base = tok * n_heads * d + head * d;
            // Iterate over dimension pairs (i, i + d/2)
            for i in 0..d / 2 {
                let a = data[base + i];
                let b = data[base + i + d / 2];
                let freq = pos as f32 / theta.powf((i * 2) as f32 / d as f32);
                let (sin, cos) = freq.sin_cos();
                // Apply rotation matrix
                data[base + i] = a * cos - b * sin;
                data[base + i + d / 2] = b * cos + a * sin;
            }
        }
    }
}
FFN (MLP)#
MLP structure (src/model.rs):
fn mlp(
    residual: &mut Tensor<f32>,
    hidden_states: &mut Tensor<f32>,
    gate: &mut Tensor<f32>,
    up: &mut Tensor<f32>,
    w_up: &Tensor<f32>,
    w_down: &Tensor<f32>,
    w_gate: &Tensor<f32>,
    rms_w: &Tensor<f32>,
    eps: f32,
) {
    // 1. hidden = rms_norm(residual)
    OP::rms_norm(hidden_states, residual, rms_w, eps);
    // 2. gate = hidden @ gate_weight.T
    OP::matmul_transb(gate, 0., hidden_states, w_gate, 1.0);
    // 3. up = hidden @ up_weight.T
    OP::matmul_transb(up, 0., hidden_states, w_up, 1.0);
    // 4. act = up * gate * sigmoid(gate) (SwiGLU), result stored in `up`
    OP::swiglu(up, gate);
    // 5. residual = residual + act @ down_weight.T
    OP::matmul_transb(residual, 1.0, up, w_down, 1.0);
}
Operators#
The basic computational operations involved in the entire process are shown in the figure:
MatMul (src/operators.rs):
// C = beta * C + alpha * A @ B^T
// hint: You don't need to do an explicit transpose of B
pub fn matmul_transb(c: &mut Tensor<f32>, beta: f32, a: &Tensor<f32>, b: &Tensor<f32>, alpha: f32) {
    // C_xy = beta * C_xy + alpha * Σ_k (A_xk * B_yk), k = 1..inner_dim
    let shape_a = a.shape();
    let shape_b = b.shape();
    let (a_rows, b_rows) = (shape_a[0], shape_b[0]);
    let inner = shape_a[1];
    let _c = unsafe { c.data_mut() };
    let _a = a.data();
    let _b = b.data();
    // 1. scale C by beta
    _c.iter_mut().for_each(|val| *val *= beta);
    // 2. accumulate alpha * A @ B^T using row slices
    for x in 0..a_rows {
        // current row of A
        let a_row = &_a[x * inner..(x + 1) * inner];
        for y in 0..b_rows {
            // current row of B (equivalent to a column of B^T)
            let b_row = &_b[y * inner..(y + 1) * inner];
            let sum: f32 = a_row.iter()
                .zip(b_row.iter())
                .map(|(&a, &b)| a * b)
                .sum();
            _c[x * b_rows + y] += alpha * sum;
        }
    }
}
SwiGLU (src/operators.rs):

$$y = \mathrm{silu}(x) \times y, \quad \text{where} \quad \mathrm{silu}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$
// hint: this is an element-wise operation
pub fn swiglu(y: &mut Tensor<f32>, x: &Tensor<f32>) {
    let len = y.size();
    assert!(len == x.size());
    let _y = unsafe { y.data_mut() };
    let _x = x.data();
    for i in 0..len {
        // y_i = y_i * silu(x_i) = y_i * x_i * sigmoid(x_i)
        _y[i] = _y[i] * _x[i] / (1. + f32::exp(-_x[i]));
    }
}
Reading from safetensor#
(src/params.rs)
use crate::config::LlamaConfigJson;
use crate::tensor::Tensor;
use num_traits::Num;
use safetensors::SafeTensors;

pub struct LLamaParams<T: Num> {
    // token_id to embedding lookup table
    pub embedding_table: Tensor<T>, // (vocab_size, dim)
    // decoder layers
    pub rms_att_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub wq: Vec<Tensor<T>>,        // (n_heads * head_size, hidden_size) x layers
    pub wk: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wv: Vec<Tensor<T>>,        // (n_kv_heads * head_size, hidden_size) x layers
    pub wo: Vec<Tensor<T>>,        // (hidden_size, n_heads * head_size) x layers
    // ffn layers
    pub rms_ffn_w: Vec<Tensor<T>>, // (hidden_size, ) x layers
    pub w_up: Vec<Tensor<T>>,      // (intermediate_size, hidden_size) x layers
    pub w_gate: Vec<Tensor<T>>,    // (intermediate_size, hidden_size) x layers
    pub w_down: Vec<Tensor<T>>,    // (hidden_size, intermediate_size) x layers
    // output
    pub rms_out_w: Tensor<T>, // (hidden_size, )
    pub lm_head: Tensor<T>,   // (vocab_size, dim)
}

trait GetTensorFromSafeTensors<P: Num> {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<P>, &'static str>;
}

impl GetTensorFromSafeTensors<f32> for f32 {
    fn get_tensor_from(tensors: &SafeTensors, name: &str) -> Result<Tensor<f32>, &'static str> {
        let tensor_view = tensors.tensor(name).map_err(|e| {
            assert!(matches!(e, safetensors::SafeTensorError::TensorNotFound(_)));
            "Tensor not found"
        })?;
        let data = match tensor_view.dtype() {
            safetensors::Dtype::F32 => tensor_view
                .data()
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            _ => return Err("Unsupported data type"),
        };
        Ok(Tensor::new(data, &tensor_view.shape().to_vec()))
    }
}

macro_rules! get_tensor_vec {
    ($tensors:expr, $pattern:literal, $layers:expr) => {{
        (0..$layers)
            .map(|i| {
                let name = format!($pattern, i);
                f32::get_tensor_from($tensors, &name).unwrap()
            })
            .collect::<Vec<_>>()
    }};
}

impl LLamaParams<f32> {
    pub fn from_safetensors(safetensor: &SafeTensors, config: &LlamaConfigJson) -> Self {
        let embedding_table = if config.tie_word_embeddings {
            f32::get_tensor_from(safetensor, "lm_head.weight").unwrap()
        } else {
            f32::get_tensor_from(safetensor, "model.embed_tokens.weight").unwrap()
        };
        let n_layers = config.num_hidden_layers;
        Self {
            embedding_table,
            rms_att_w: get_tensor_vec!(safetensor, "model.layers.{}.input_layernorm.weight", n_layers),
            wq: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.q_proj.weight", n_layers),
            wk: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.k_proj.weight", n_layers),
            wv: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.v_proj.weight", n_layers),
            wo: get_tensor_vec!(safetensor, "model.layers.{}.self_attn.o_proj.weight", n_layers),
            rms_ffn_w: get_tensor_vec!(safetensor, "model.layers.{}.post_attention_layernorm.weight", n_layers),
            w_up: get_tensor_vec!(safetensor, "model.layers.{}.mlp.up_proj.weight", n_layers),
            w_gate: get_tensor_vec!(safetensor, "model.layers.{}.mlp.gate_proj.weight", n_layers),
            w_down: get_tensor_vec!(safetensor, "model.layers.{}.mlp.down_proj.weight", n_layers),
            rms_out_w: f32::get_tensor_from(safetensor, "model.norm.weight").unwrap(),
            lm_head: f32::get_tensor_from(safetensor, "lm_head.weight").unwrap(),
        }
    }
}
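For completeness, here is a hypothetical usage sketch showing how the SafeTensors object might be obtained and fed into from_safetensors. It assumes LlamaConfigJson derives serde::Deserialize, that serde_json is available, and a conventional file layout (config.json + model.safetensors); the helper name load_params is illustrative:

use std::fs;

// Illustrative loader: read config.json and model.safetensors from a model directory.
fn load_params(model_dir: &std::path::Path) -> LLamaParams<f32> {
    // Assumes LlamaConfigJson implements serde::Deserialize
    let config: LlamaConfigJson =
        serde_json::from_slice(&fs::read(model_dir.join("config.json")).unwrap()).unwrap();
    // SafeTensors::deserialize parses the raw file bytes without copying tensor data
    let bytes = fs::read(model_dir.join("model.safetensors")).unwrap();
    let safetensors = SafeTensors::deserialize(&bytes).unwrap();
    LLamaParams::from_safetensors(&safetensors, &config)
}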
KV-Cache Mechanism#
During FP16 inference (each value occupies 2 bytes), the memory occupied by the KV cache is:

$$\text{KV cache} = 2 \times 2\,\text{bytes} \times n_{\text{layers}} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{batch\_size}$$

As the sequence length increases this footprint keeps growing; without caching, each generation step would recompute K and V for the entire context, so the attention cost grows quadratically, whereas with the cache only the latest token's K and V are computed incrementally.

The corresponding config parameters are use_cache: true and max_position_embeddings: 4096; the latter determines the model's maximum context window.
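A minimal sketch of the caching idea (the struct and method names here are hypothetical, not necessarily the repository's actual kvcache implementation): each layer keeps growable K/V buffers, and each decoding step only appends the rows of the newly generated token(s):

// Hypothetical per-layer KV cache, shapes follow the convention (total_seq, n_kv_h * dqkv).
pub struct KvCache {
    kv_dim: usize,    // n_kv_h * dqkv
    k: Vec<Vec<f32>>, // per layer: flattened (total_seq, kv_dim)
    v: Vec<Vec<f32>>,
}

impl KvCache {
    pub fn new(n_layers: usize, kv_dim: usize) -> Self {
        Self { kv_dim, k: vec![Vec::new(); n_layers], v: vec![Vec::new(); n_layers] }
    }

    /// Append the K/V rows computed for the newest token(s) of one layer.
    pub fn append(&mut self, layer: usize, k_new: &[f32], v_new: &[f32]) {
        assert!(k_new.len() % self.kv_dim == 0 && k_new.len() == v_new.len());
        self.k[layer].extend_from_slice(k_new);
        self.v[layer].extend_from_slice(v_new);
    }

    /// Cached K of a layer, viewed as (total_seq, kv_dim).
    pub fn k(&self, layer: usize) -> &[f32] { &self.k[layer] }
    pub fn v(&self, layer: usize) -> &[f32] { &self.v[layer] }

    /// Number of token positions cached so far (total_seq).
    pub fn seq_len(&self) -> usize { self.k[0].len() / self.kv_dim }
}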
Inference Optimization Plan#
Group Query Attention (GQA)#
Memory optimization is achieved by letting multiple groups of Q heads share a single group of K/V heads (the number of query heads must be divisible by the number of KV heads):
fn self_attention(
    hidden_states: &mut Tensor<f32>, // (seq, n_kv_h * n_groups * dqkv)
    att_scores: &mut Tensor<f32>,    // (n_kv_h, n_groups, seq, total_seq)
    q: &Tensor<f32>,                 // (seq, n_kv_h * n_groups * dqkv)
    k: &Tensor<f32>,                 // (total_seq, n_kv_h * dqkv)
    v: &Tensor<f32>,                 // (total_seq, n_kv_h * dqkv)
    n_kv_h: usize,
    n_groups: usize,
    seq_len: usize,
    total_seq_len: usize,
    dqkv: usize,
) {
    assert!(k.shape()[0] >= total_seq_len && v.shape()[0] >= total_seq_len);
    assert!(q.shape()[0] == seq_len && q.shape()[1] == n_kv_h * n_groups && q.shape()[2] == dqkv);
    let _a = unsafe { att_scores.data_mut() };
    let _q = q.data();
    let _k = k.data();
    let _v = v.data();
    let sqrt = (dqkv as f32).sqrt();
    // attn_score = Q @ K^T / sqrt(dim)
    // (the loop index `q` below is a Q-head index; it shadows the tensor `q`, whose data was captured as `_q`)
    for q in 0..n_kv_h * n_groups {
        for seq in 0..seq_len {
            for t_seq in 0..total_seq_len {
                let mut sum = 0.0;
                for d in 0..dqkv {
                    // Q head `q` reads the shared KV head `q / n_groups`
                    sum += _q[seq * n_kv_h * n_groups * dqkv + q * dqkv + d]
                        * _k[t_seq * n_kv_h * dqkv + q / n_groups * dqkv + d];
                }
                _a[q * seq_len * total_seq_len + seq * total_seq_len + t_seq] = sum / sqrt;
            }
        }
    }
    // attn = softmax(score)
    OP::masked_softmax(att_scores);
    // x = attn @ V
    let _a = att_scores.data();
    let _h = unsafe { hidden_states.data_mut() };
    for q in 0..n_kv_h * n_groups {
        for seq in 0..seq_len {
            for d in 0..dqkv {
                let mut sum = 0.0;
                for t_seq in 0..total_seq_len {
                    sum += _a[q * seq_len * total_seq_len + seq * total_seq_len + t_seq]
                        * _v[d + q / n_groups * dqkv + t_seq * n_kv_h * dqkv];
                }
                _h[seq * n_kv_h * n_groups * dqkv + q * dqkv + d] = sum;
            }
        }
    }
}
Batching, session management, dialogue rollback, and similar features are not implemented for now due to limited time and energy; I will only summarize them at a conceptual level later.
Mixed Precision Inference#
Commonly used formats are FP16 and BF16: the former has the same precision (10-bit mantissa) as TF32, while the latter has the same range as FP32 but lower precision. TF32 combines the range of FP32 with the precision of FP16, but requires Tensor Core hardware support.
Note that the following precision-sensitive operations should remain in FP32 (a small sketch follows this list):
- The trigonometric calculations in RoPE
- The sum-of-squares accumulation and square root in RMS Norm
- The exponentials in Softmax (for numerical stability)
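As a rough sketch of mixed-precision storage with FP32 accumulation (assuming the half crate as a dependency; the function name rms_norm_f16 is hypothetical): weights and activations live in f16, while the sum of squares and the square root are computed in f32:

use half::f16;

// Hypothetical mixed-precision RMSNorm over flat slices: storage in f16,
// reduction and square root in f32, per the precision-sensitive list above.
pub fn rms_norm_f16(y: &mut [f16], x: &[f16], w: &[f16], d: usize, epsilon: f32) {
    assert!(x.len() == y.len() && x.len() % d == 0 && w.len() == d);
    for (x_row, y_row) in x.chunks_exact(d).zip(y.chunks_exact_mut(d)) {
        // Accumulate Σ x² in f32 to avoid f16 overflow / precision loss
        let sum_sq: f32 = x_row.iter().map(|&v| {
            let v = v.to_f32();
            v * v
        }).sum();
        let rms = (sum_sq / d as f32 + epsilon).sqrt();
        for j in 0..d {
            y_row[j] = f16::from_f32(w[j].to_f32() * x_row[j].to_f32() / rms);
        }
    }
}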
Quantized Inference#
Achieved by mapping values to lower-precision data types (a small sketch follows this list), which brings:
- Computational acceleration (requires hardware support for low-precision instructions)
- Memory savings (FP16 needs only 50% of the FP32 footprint; INT8 and INT4 go even lower)
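A minimal sketch of per-tensor symmetric INT8 quantization (illustrative only, not the course code): each f32 value is mapped to an i8 through a scale derived from the maximum absolute value, and dequantized by multiplying the scale back:

// Illustrative per-tensor symmetric INT8 quantization.
pub fn quantize_i8(x: &[f32]) -> (Vec<i8>, f32) {
    // The scale maps the largest |x| onto the i8 range [-127, 127]
    let max_abs = x.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = x.iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

pub fn dequantize_i8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}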
Advanced Optimization Directions#
- MoE Architecture: Dynamic routing MLP optimization
- Distributed Inference:
- Data parallelism: full parameter copy, improving throughput
- Pipeline parallelism: inter-layer splitting, communication overhead sensitive
- Tensor parallelism: distributed parameter storage, requires AllGather/AllReduce synchronization
Summary#
This training camp has solidly tied together the technical threads of LLM inference optimization through hands-on practice, and it is worth continuing to push forward when time permits. A lot of things have come up recently, so I will wrap it up here 🥲