Prerequisite Knowledge: Familiarity with basic concepts of Python, calculus, and statistics, as well as the previous micrograd and makemore series courses (optional).
Objective: To understand and appreciate how GPT works.
Resources you might need:
Colab Notebook link
A detailed note seen on Twitter, better than what I wrote
ChatGPT#
ChatGPT launched at the end of 2022, and since then large language models (LLMs) such as GPT-4 and Claude 3 have become part of many people's daily lives. They are all probabilistic systems: given the same prompt, they can produce different answers. Like the language models we implemented earlier, models such as GPT model sequences of words, characters, or more generally symbols (tokens), and they learn how words in English tend to follow one another. From the model's perspective, our prompt is the beginning of a sequence, and its task is to complete that sequence.
So, what is the neural network that models these sequences of words?
Transformer#
In 2017, the landmark paper “Attention is All You Need” proposed the Transformer architecture; the T in GPT (Generative Pre-trained Transformer) refers to it. Although the original paper was about machine translation, it went on to influence the entire AI field: with only minor modifications, the same architecture has been applied to a wide range of AI tasks, and it forms the core of ChatGPT.
Of course, the goal of this section is not to train ChatGPT; that is an industrial-scale project involving massive datasets, pre-training, and fine-tuning. What we aim to do is train a Transformer-based language model which, like before, will be a character-level language model.
Building the Model#
Dataset#
We will use a small toy dataset called “Tiny Shakespeare,” a favorite of Andrej's. It is simply a concatenation of all of Shakespeare's works, about 1 MB in size. One difference from ChatGPT is that our model produces output one character at a time, whereas ChatGPT works with tokens, which are closer to "word chunks"; we will come back to this later.
```python
# We always start from a dataset: download Tiny Shakespeare (the ! runs a shell command in Colab/Jupyter)
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# Read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
```
Tokenize#
```python
chars = sorted(list(set(text)))  # set() gives the unique characters; convert to a list so it can be sorted
vocab_size = len(chars)
print(''.join(chars))  # join into a single string
print(vocab_size)
# Output (sorted by code point; the first character is the newline '\n', so the first printed line is empty):
#
#  !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
# 65
```
This is similar to the character table we built in previous sections. Next we need a way to tokenize the input, which means converting the raw text (a string) into a sequence of integers according to some vocabulary. For our character-level model, this simply means mapping each individual character to an integer.
If you've seen the content from previous sections, this part of the code should feel quite familiar, as it resembles the "Creating Lookup Table and Character Mapping" in Bigram.
```python
# Create a mapping from characters to integers and back
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]           # encoder: takes a string, outputs a list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # decoder: takes a list of integers, outputs a string
print(encode("hii there"))          # [46, 47, 47, 1, 58, 46, 43, 56, 43]
print(decode(encode("hii there")))  # hii there
```
Here we built the encoder and decoder together; they translate between strings and integers at the character level. This is a very simple tokenization scheme, and many others have been proposed: Google's SentencePiece splits text into subword units, a common practice, and OpenAI's tiktoken uses byte-pair encoding (BPE).
With tiktoken's GPT-2 encoding, whose vocabulary contains 50,257 tokens, the same string "hii there" needs only 3 integers, compared with 9 for our character-level scheme. The trade-off is a much larger vocabulary in exchange for much shorter sequences.
```python
# Now encode the entire text dataset and store it in a torch.tensor.
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])
```
The first 1,000 characters we looked at earlier appear to the GPT as a sequence of integers.
In that sequence, 0 is the newline character and 1 is a space.
At this point, the entire dataset is represented as one long sequence of integers.
Train/Validation Split#
```python
# Split the data into training and validation sets to measure how much the model overfits
n = int(0.9*len(data))  # the first 90% is used for training, the rest for validation
train_data = data[:n]
val_data = data[n:]
```
The validation set tells us how much the model overfits: we do not want the model to memorize Shakespeare's works verbatim, but to produce new text in a Shakespeare-like style.
Chunks & Batches#
Note that we never feed the entire text into the Transformer at once, which would be computationally prohibitive. Instead, we train on chunks: small blocks randomly sampled from the training set.
Chunk Processing#
block_size sets the fixed length of each chunk fed to the model during training, i.e., the maximum context length.
```python
block_size = 8
train_data[:block_size+1]       # a chunk of block_size + 1 = 9 characters
x = train_data[:block_size]     # inputs
y = train_data[1:block_size+1]  # targets: the inputs shifted by one position
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")
```
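A single chunk of block_size + 1 characters therefore packs block_size training examples, one for each context length from 1 to block_size, and the model learns to predict the next character (or token) from all the characters before it.
Training on all of these context lengths, rather than only the full one, also gets the Transformer used to seeing anything from a single character up to block_size characters of context. This matters at generation time, when we start sampling from very little context.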
Batch Processing#
To take advantage of the parallelism GPUs excel at, we also train in batches: several chunks are stacked into a single tensor and processed simultaneously, each completely independently of the others.
The batch size is the number of independent sequences the Transformer processes in each forward and backward pass.
```python
torch.manual_seed(1337)  # for reproducible sampling
batch_size = 4  # number of independent sequences processed in parallel
block_size = 8  # maximum context length for predictions

# This plays a role similar to a PyTorch DataLoader
def get_batch(split):
    # Generate a small batch of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))  # batch_size random chunk offsets
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```
torch.stack stacks a sequence of tensors along a new dimension; all of the tensors must have the same shape.
The inputs xb have shape 4x8: each row is a chunk sampled from the training set. The targets yb have the same shape and are used at the end of the model to compute the loss.
```python
for b in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")
```
Printing every (context, target) pair makes the relationship between the two arrays clear: the 4x8 batch contains 32 independent training examples.
Bigram#
In the Makemore series, we delved into and implemented the bigram language model, and now we will quickly re-implement it using the PyTorch Module.
Model Building#
```python
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # Each token directly reads the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensors of integers
        logits = self.token_embedding_table(idx)  # (Batch=4, Time=8, Channel=65)
        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) targets, so flatten B and T
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is the (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # Get the predictions
            logits, loss = self(idx)
            # Focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append the sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
```
The embedding layer should also be familiar: passing in index 24, for example, retrieves row 24 of the embedding table.
In the Makemore series we learned a good way to measure the loss: the negative log-likelihood, which PyTorch implements, including the softmax step, as F.cross_entropy. Intuitively, the logits should put high probability (high confidence) on the correct next character and very low probability on every other dimension. If the initial predictions were perfectly uniform, we would expect a loss of about $-\ln(1/65) = \ln 65 \approx 4.17$; the actual initial loss is somewhat higher, because the randomly initialized logits are not uniform, so the model is confidently wrong about some characters.
```python
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)  # torch.Size([32, 65]): flattened to (B*T, C) because targets were given
print(loss)
# Start generating from a single token 0, which is the newline character
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))
```
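generate takes the (B, T) array of indices representing the current context and extends it to (B, T+1), (B, T+2), and so on: it keeps generating along the time dimension for every sequence in the batch. For a bigram model it is wasteful to feed in the whole history, since only the last token is used, but we keep the function general because later models will use more of it.
Because the model is untrained, the generated text is completely random.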
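To sanity-check the 4.17 estimate, we can compare F.cross_entropy on perfectly uniform logits with that on random Gaussian logits. This is a minimal sketch: the random targets stand in for real data and N = 1000 is an arbitrary sample size. Gaussian logits are a reasonable stand-in for the untrained model because nn.Embedding initializes its weights from a standard normal distribution.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65  # size of the Tiny Shakespeare character vocabulary
N = 1000         # number of (prediction, target) pairs; arbitrary for this check

torch.manual_seed(0)
targets = torch.randint(0, vocab_size, (N,))

# Perfectly uniform predictions: all logits equal, so every character gets p = 1/65
uniform_logits = torch.zeros(N, vocab_size)
uniform_loss = F.cross_entropy(uniform_logits, targets).item()

# Random, non-uniform logits, similar to what an untrained embedding table produces
random_logits = torch.randn(N, vocab_size)
random_loss = F.cross_entropy(random_logits, targets).item()

print(f"ln(65)         = {math.log(vocab_size):.4f}")
print(f"uniform logits = {uniform_loss:.4f}")  # matches ln(65)
print(f"random logits  = {random_loss:.4f}")   # higher than ln(65)
```

The uniform case reproduces ln 65 exactly, while any non-uniform guess that is uncorrelated with the targets can only do worse on average, which is why the untrained model starts above 4.17.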
Model Training#
Now let's train the model. Whereas the Makemore series used stochastic gradient descent (SGD), here we use AdamW, a more advanced and widely used optimizer.
```python
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
```
The optimizer's basic function is to obtain gradients and update parameters based on those gradients.
```python
batch_size = 32  # use a larger batch size (get_batch reads this global)
for steps in range(100):  # increase the number of steps for better results
    # Sample a batch of data
    xb, yb = get_batch('train')
    # Evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)  # clear the gradients from the previous step
    loss.backward()  # backpropagation
    # Update the parameters; plain SGD would do p = p - lr * grad (like our earlier manual loop),
    # while AdamW also uses running averages of the gradients and weight decay
    optimizer.step()
print(loss.item())
```
We can see that our optimization is working, and the loss is decreasing.
By training for more steps, the loss eventually reaches about 2.48. Rerunning the sampling code from above now gives noticeably better results.
The output starts to have a text-like shape, but it is still far from Shakespeare.
No matter how long we train, this model cannot do much better, because its assumption is very simple: it predicts the next token from the previous token alone, so the tokens never talk to each other. That is why we turn to the Transformer.
Transformer#
If you have an NVIDIA GPU with CUDA, you can speed up training:
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```
With this setting, the code needs a few changes so that data loading, computation, and sampling all happen on device (the GPU): the model and each batch have to be moved there with .to(device). For the details, see Andrej's lecture repository, where bigram.py is our starting point.
In addition, the model now switches between a training phase and an evaluation phase (model.train() and model.eval()). For now this changes nothing, since the model consists of a single nn.Embedding layer with no dropout or batch norm layers, but it is good practice, because such layers behave differently during training and inference.
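To see where model.train() and model.eval() come in, here is a minimal sketch of a loss-estimation helper in the spirit of the one in Andrej's training script. The name estimate_loss, its signature, and the TinyBigram model and random data used to exercise it are illustrative assumptions, not the repository's exact code. Averaging over many batches gives a much less noisy estimate than the loss of a single batch, and torch.no_grad() avoids storing activations for a backward pass we never run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()  # evaluation only: skip building the autograd graph
def estimate_loss(model, get_batch, eval_iters=200):
    """Average the loss over eval_iters random batches for each split."""
    out = {}
    model.eval()  # put layers such as dropout / batch norm into inference mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()  # switch back to training mode
    return out

# --- Usage with hypothetical stand-ins for the real model and data ---
class TinyBigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.table(idx)
        B, T, C = logits.shape
        return logits, F.cross_entropy(logits.view(B*T, C), targets.view(B*T))

torch.manual_seed(1337)
fake_data = torch.randint(0, 65, (1000,))  # random token ids instead of encoded text

def toy_get_batch(split, batch_size=4, block_size=8):
    # For brevity, both splits sample from the same fake data
    ix = torch.randint(len(fake_data) - block_size, (batch_size,))
    x = torch.stack([fake_data[i:i+block_size] for i in ix])
    y = torch.stack([fake_data[i+1:i+block_size+1] for i in ix])
    return x, y

model = TinyBigram(65)
print(estimate_loss(model, toy_get_batch, eval_iters=10))
print(model.training)  # True: estimate_loss restored training mode
```

In a training loop, such a helper would typically be called every eval_interval steps to print the train and validation losses side by side.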
Self-Attention#
Before diving into the Transformer, let's get comfortable with a mathematical trick at the heart of its self-attention, using a simple toy example.
```python
torch.manual_seed(1337)
B,T,C = 4,8,2  # batch, time, channels
x = torch.randn(B,T,C)
# x.shape = torch.Size([4, 8, 2])
```
We want these currently independent tokens to communicate, but in a specific way. The token at position 5, for example, should not communicate with tokens 6, 7, and 8, because those are future tokens in the sequence; it may only communicate with tokens 4, 3, 2, and 1. In other words, information only flows from the previous context to the current time step, and is used to predict what comes next.
So, what is the simplest way for tokens to communicate with each other?
The answer is surprisingly simple: at each position, average the channel vectors of all previous tokens and the current one, producing a feature vector that summarizes the history up to that point. As you might guess, this form of interaction is very weak and throws away a lot of information, such as the order in which the tokens appear.
v1. Loop#
```python
# We want xbow[b,t] = mean_{i<=t} x[b,i]  ("bow" = bag of words)
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]  # (t+1, C)
        xbow[b,t] = torch.mean(xprev, 0)
```
Printing x[0] and xbow[0] side by side makes the behavior clear: the first row is unchanged, the second row is the average of the first two rows, and so on.
This method is computationally inefficient; we can use matrix multiplication to accomplish this task more efficiently.
v2. Matrix Multiplication#
```python
# A toy example of how matrix multiplication can perform "weighted aggregation"
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
```
This is plain matrix multiplication: each element of c is the dot product of a row of a with a column of b. For example, $c_{11} = \sum_{k=1}^{3} a_{1k} b_{k1} = b_{11} + b_{21} + b_{31}$, because every entry of a is 1. Every row of c is therefore the column-wise sum of b.
To sum only the current and previous rows, we replace a with a lower triangular matrix of ones, built with torch.tril. Then the first row of c is just the first row of b, the second row is the sum of the first two rows of b, and the third row is the sum of all three:
```python
a = torch.tril(torch.ones(3, 3))
```
Because all the nonzero elements of a are 1, this computes sums. To turn it into an average (a weighted aggregation), we normalize each row of a so that its elements sum to 1:
```python
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)  # keepdim=True keeps shape (3, 1) so the division broadcasts row-wise
```
Now each row of a sums to 1, and c is the average of the corresponding previous rows in b.
Now let's apply this more efficient method to our toy x:
```python
# Version 2: weighted aggregation via matrix multiplication
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)  # each row sums to 1; wei plays the role of a above
# wei is (T, T); PyTorch broadcasts it across the batch dimension of x
xbow2 = wei @ x  # (T, T) @ (B, T, C) ----> (B, T, C)
print(torch.allclose(xbow, xbow2))  # True: both versions agree within floating-point tolerance
```
To summarize the trick: batch matrix multiplication with a T×T weight matrix performs a weighted aggregation. Because the weights are lower triangular, the token at position t only receives information from positions up to and including t.
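Written out for one sequence (positions numbered from 1 here, whereas the code counts from 0), version 2 computes:

$$
\mathrm{xbow}_t = \sum_{i=1}^{T} w_{ti}\, x_i,
\qquad
w_{ti} =
\begin{cases}
\dfrac{1}{t}, & i \le t \\[4pt]
0, & i > t
\end{cases}
$$

Row t of the weight matrix has t nonzero entries equal to 1/t, which is exactly what dividing the lower-triangular matrix of ones by its row sums produces.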
v3. Softmax#
Additionally, we can use Softmax to implement the third version.
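One important API here is masked_fill (torch.masked_fill, or the Tensor.masked_fill method used below), which returns a copy of the input in which every position where the mask is True is replaced with a given value.
What happens if we then apply softmax to each row? As mentioned in previous chapters, softmax is a normalization: it exponentiates each entry and divides by the row sum. The -inf entries become exp(-inf) = 0, and the remaining zeros all get equal weight, so we recover exactly the averaging matrix from version 2. The advantage of this form is that wei no longer has to start at zero: later its entries will be data-dependent affinities, and the mask still guarantees that no token aggregates information from the future.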
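```python
# Version 3: use softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))  # affinities between tokens; all equal for now
wei = wei.masked_fill(tril == 0, float('-inf'))  # tokens from the future cannot communicate
wei = F.softmax(wei, dim=-1)  # normalize each row
xbow3 = wei @ x
print(torch.allclose(xbow, xbow3))  # True
```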
In addition to encoding the identity of tokens, we also need to encode the position of tokens:
```python
# Assumes the globals vocab_size, n_embd, block_size, and device are defined
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)  # token identity
        self.position_embedding_table = nn.Embedding(block_size, n_embd)  # token position
        self.lm_head = nn.Linear(n_embd, vocab_size)  # maps embeddings to logits

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensors of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C): pos_emb is broadcast across the batch
        logits = self.lm_head(x)  # (B,T,vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss
```
Here x holds the sum of the token embeddings and position embeddings. For now this does not help, because the model is still a bigram model: each prediction only looks at the current token, so it makes no difference where in the sequence that token sits. Position information only starts to matter once tokens communicate through the attention mechanism.
v4. Self-Attention#
What we have done so far is a simple uniform average, but different tokens will find different other tokens more or less interesting, and we want that to depend on the data. For example, a vowel might want to know which consonants in its past are relevant to it. This is the problem that self-attention solves.
Each token will emit two vectors, a query (what I want to find, what I am interested in) and a key (what information I contain, who is similar to me).
The way we obtain the affinity (weights) between these tokens in the sequence is essentially by taking the dot product between keys and queries. If a key matches or aligns very well with a query, the weight corresponding to that key will be high. Therefore, the model's attention will focus on the value corresponding to that key—meaning the model will pay more attention to that specific information (or token) rather than any other information in the sequence.
In other words, the attention mechanism allows the model to focus on the most relevant information by calculating the match between queries and keys and allocating weights accordingly, thus enhancing its ability to process and understand sequential data.
We also need a value (the information I will contribute to you if you find me interesting): the aggregation is performed not over x itself but over v, the result of passing x through a third linear layer.
Now let's implement this single-head attention mechanism:
```python
# Version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32  # batch, time, channels
x = torch.randn(B,T,C)

# A single head of self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
# Every token independently produces a key and a query
k = key(x)    # (B, T, 16)
q = query(x)  # (B, T, 16)
# Dot products of queries with keys give the affinities (weights)
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T,T))  # version 3 used uniform affinities here
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)   # (B, T, 16)
out = wei @ v  # (B, T, T) @ (B, T, 16) ---> (B, T, 16)
# out = wei @ x  # version 3 aggregated x directly
```
Printing wei[0] shows that the weights are no longer a uniform average but data-dependent: tokens with high affinity contribute more information to the current token in the weighted aggregation.
Attention Summary#
- Attention is a communication mechanism. It can be seen as nodes in a directed graph looking at each other and aggregating information through a weighted sum over all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors, which is why we need to encode the positions of the tokens.
- Examples along the batch dimension are processed completely independently and never communicate with each other.
- Attention itself does not require tokens to see only the past. In our implementation, masked_fill masks out future tokens; in an "encoder" attention block you would simply delete that line and let all tokens communicate with each other. What we implement here is called a "decoder" block because of its triangular mask, which suits autoregressive settings such as language modeling. In general, attention supports arbitrary connections between nodes.
- "Self-attention" means that the keys and values are produced from the same source as the queries (x). In "cross-attention," the queries still come from x, but the keys and values come from some other, external source (such as an encoder module).
- "Scaled" attention additionally divides wei by $\sqrt{\text{head\_size}}$. The reason: when q and k are unit Gaussian (mean 0, variance 1), the variance of the raw dot products in wei turns out to be on the order of head_size (16 in our implementation). With this normalization, the variance of wei is back to about 1.
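A quick way to check the variance claim is to draw unit-Gaussian queries and keys and compare the variance of the raw and the scaled dot products. This sketch uses random tensors in place of the actual key and query projections; the shapes follow the toy example above.

```python
import torch

torch.manual_seed(1337)
B, T, head_size = 4, 8, 16

# Unit-Gaussian queries and keys, as at initialization
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)

wei_raw = q @ k.transpose(-2, -1)                       # (B, T, T)
wei_scaled = q @ k.transpose(-2, -1) * head_size**-0.5  # divided by sqrt(head_size)

print(f"var(q) = {q.var().item():.3f}, var(k) = {k.var().item():.3f}")  # both close to 1
print(f"var(wei) unscaled = {wei_raw.var().item():.3f}")    # on the order of head_size
print(f"var(wei) scaled   = {wei_scaled.var().item():.3f}")  # close to 1
```

Each raw entry is a sum of head_size products of independent unit-variance numbers, so its variance grows linearly with head_size; dividing by the square root undoes exactly that growth.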
Why is this step important?
```python
# Our wei then goes through softmax
wei = F.softmax(wei, dim=-1)
```
One property of softmax is that when its inputs have large magnitudes, it converges toward a one-hot vector: the largest entry approaches 1 and all others approach 0. Multiplying the same vector of logits by 8, for example, makes the softmax output much sharper and closer to one-hot (one element 1, all others 0). At initialization this would make the attention far too peaky, with each node aggregating information from essentially a single other node.
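The sharpening effect is easy to reproduce. The logits below are an arbitrary example vector; multiplying them by 8 mimics what happens when un-scaled dot products have a large variance.

```python
import torch

logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])  # arbitrary example scores

diffuse = torch.softmax(logits, dim=-1)    # largest weight is about 0.29
sharp = torch.softmax(logits * 8, dim=-1)  # largest weight is about 0.80
print(diffuse)
print(sharp)
```

The ranking of the entries is unchanged; only the temperature changes, and larger scales push almost all of the probability mass onto the maximum.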
Therefore, "scaled" attention controls the variance at initialization by dividing wei by $\sqrt{\text{head\_size}}$ (that is, multiplying it by $1/\sqrt{\text{head\_size}}$). When the queries and keys have unit variance, wei then has unit variance too, so the softmax stays diffuse instead of saturating. A saturated softmax has vanishing gradients, so the scaling keeps training stable, especially early on.
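Putting the pieces together, one head computes the scaled dot-product attention defined in the paper, where $Q$, $K$, $V$ are the matrices of queries, keys, and values (q, k, v in the code) and $d_k$ is the head size. In our decoder, the causal mask additionally sets the entries of $QK^\top$ above the diagonal to $-\infty$ before the softmax.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$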
Code Implementation#
```python
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        # The key/query/value projections usually have no bias
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask, stored as a buffer rather than a trainable parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)    # (B,T,hs) where hs = head_size
        q = self.query(x)  # (B,T,hs)
        # Compute attention scores ("affinities"), scaled by 1/sqrt(head_size) as described above
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        # Perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,hs)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
```
In the constructor, tril is not a parameter of the module (it is never trained). In PyTorch terminology it is a buffer, so it has to be registered on the nn.Module with register_buffer.
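The difference between parameters and buffers is easy to inspect. The module below is a hypothetical minimal example (the name MaskHolder and its 4x4 layer are made up for illustration) holding one trainable layer and a causal-mask buffer registered the same way as in Head.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Hypothetical minimal module: one trainable layer plus a causal-mask buffer."""

    def __init__(self, block_size):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)  # a regular (trainable) parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

holder = MaskHolder(block_size=8)
param_names = [name for name, _ in holder.named_parameters()]
buffer_names = [name for name, _ in holder.named_buffers()]
print(param_names)                          # ['proj.weight']: the optimizer never sees tril
print(buffer_names)                         # ['tril']
print(sorted(holder.state_dict().keys()))   # ['proj.weight', 'tril']: saved and loaded with the model
```

Buffers also follow the module through .to(device) calls, which is why the mask ends up on the GPU together with the weights.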
Multi-Head Attention#
Next comes a part of the paper we have not implemented yet: multi-head attention, which applies several attention heads in parallel and concatenates their results.
Code Implementation#
In PyTorch, we can easily achieve this by creating multiple heads.
```python
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # Keep the heads in a list, run them in parallel, then concatenate the outputs
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)  # concatenate along the channel dimension
```
Now we have 4 parallel communication channels instead of one, and each channel is correspondingly smaller: with an embedding dimension of 32 and 4 heads, each head performs 8-dimensional self-attention, and concatenating the 4 outputs gives 32 dimensions again, matching the original embedding. This is somewhat similar to group convolution: instead of one large convolution, we perform the convolution in groups.
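To check the shape bookkeeping without the surrounding globals, here is a self-contained sketch. It assumes a variant of Head and MultiHeadAttention that takes n_embd and block_size as constructor arguments instead of reading global variables; the computation inside each head is the same masked, scaled attention as above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One head of causal self-attention; sizes are passed in explicitly."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)           # each (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5            # (B, T, T), scaled
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask
        wei = F.softmax(wei, dim=-1)
        return wei @ v                                                 # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """num_heads heads in parallel, outputs concatenated along the channels."""

    def __init__(self, num_heads, n_embd, head_size, block_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(n_embd, head_size, block_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

torch.manual_seed(1337)
B, T, n_embd, num_heads, block_size = 4, 8, 32, 4, 8
x = torch.randn(B, T, n_embd)
mha = MultiHeadAttention(num_heads, n_embd, n_embd // num_heads, block_size)
out = mha(x)
print(mha.heads[0](x).shape)  # torch.Size([4, 8, 8]): one 8-dimensional head
print(out.shape)              # torch.Size([4, 8, 32]): 4 heads x 8 channels = 32
```

Because every head applies the causal mask, changing the last token of x leaves the outputs at all earlier positions unchanged.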
```python
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # 4 heads of 8-dimensional self-attention (our MultiHeadAttention class, not torch.nn's)
        self.sa_head = MultiHeadAttention(4, n_embd//4)
        self.lm_head = nn.Linear(n_embd, vocab_size)
    # In forward, the summed embeddings pass through self.sa_head before self.lm_head
```
Blocks#
The diagram in the paper shows the full network. We will not implement the encoder or the cross-attention that connects it to the decoder. Besides attention, the decoder contains a feedforward part; together they form a block, and the block is repeated N times.
Feedforward Network#
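This feedforward part is just a simple MLP, defined in the paper as $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$: a linear layer, a ReLU, and another linear layer, applied to each position separately.
Note that in the paper the input and output dimensions are 512 and the inner dimension of the feedforward network is 2048, so the inner channel size is 4 times the embedding size.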
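```python
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            # Keeps n_embd channels so blocks can be stacked; the paper's 4x inner size
            # needs a second linear layer that projects back down to n_embd
            nn.Linear(n_embd, n_embd),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)
```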
```python
class Block(nn.Module):
    """ Transformer block: communication followed by computation
    Communication: multi-head attention
    Computation: a feedforward network applied to every token independently
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we want (4 in the model below)
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)  # communication
        self.ffwd = FeedFoward(n_embd)  # computation

    def forward(self, x):
        x = self.sa(x)
        x = self.ffwd(x)
        return x
```
```python
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensors of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        ...
```
When we decode samples from this model, the results have not improved much. The reason is that we now have a fairly deep neural network, which suffers from optimization problems; to address them, we borrow two techniques used in the Transformer paper.
Both techniques make it possible to greatly increase the depth of the network while keeping it optimizable:
Residual Connections#
In the paper's block diagram, the arrows that bypass each sub-layer and feed into the "Add" nodes are residual connections, a concept introduced in the paper Deep Residual Learning for Image Recognition.
Andrej summarizes it as "You transform data, then add it back to the previous features." Taking this piece by piece:
- You transform data: in each layer of a deep network, the input is multiplied by weight matrices and passed through non-linear activation functions, learning an increasingly abstract representation of the data.
- But then you have a skip connection: in a traditional deep network, data flows strictly sequentially from one layer to the next. A residual connection breaks this pattern with a skip connection, which routes the input of a layer around one or more layers directly to a later point, so that earlier features reach later layers.
- With addition from the previous features: the skip connection is implemented as an element-wise addition of the skipped input and the output of the layers it bypasses. This ensures the original feature information can pass through the network without being "diluted" by the transformations in between.
The introduction of residual connections allows the network to more easily learn identity mappings, which is very useful for training very deep networks. In fact, residual connections allow gradients to flow directly through the network, helping to alleviate the problems of vanishing or exploding gradients, thus making training deep networks more feasible and efficient.
We saw in micrograd that an addition node passes the gradient arriving at its output unchanged to every one of its inputs, since the local derivative of a sum with respect to each input is 1. The residual pathway therefore acts as a gradient superhighway, carrying gradients from the loss straight back to the input. It is satisfying to see the earlier pieces connect like this.
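The gradient argument can be checked directly with autograd. This is a small sketch with an arbitrary 8-dimensional input and an nn.Linear standing in for the residual branch (attention or feedforward in the real model); the all-zero branch is a deliberately extreme case.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)

# Extreme case: a residual branch whose weights are all zero contributes nothing
branch = nn.Linear(8, 8)
nn.init.zeros_(branch.weight)
nn.init.zeros_(branch.bias)

out = x + branch(x)        # residual connection: identity path + transformed path
out.sum().backward()
grad_zero_branch = x.grad.clone()
print(grad_zero_branch)    # all ones: the addition passed the upstream gradient straight through

# With a randomly initialized branch, its gradient is simply added on top of the 1
branch2 = nn.Linear(8, 8)
x.grad = None
(x + branch2(x)).sum().backward()
expected = 1 + branch2.weight.sum(dim=0)  # identity path (1) + branch path (column sums of W)
print(torch.allclose(x.grad, expected.expand_as(x)))  # True
```

Even when the branch is useless or badly scaled, the identity path guarantees that a clean gradient reaches the earlier layers.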
```python
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # Keep the heads in a list, run them in parallel, then concatenate the outputs
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)  # projection back into the residual pathway

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # concatenate along the channel dimension
        out = self.proj(out)  # linear projection of the concatenated heads
        return out
```
```python
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # inner dimension is 4x, as in the paper
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),  # project back into the residual pathway
        )

    def forward(self, x):
        return self.net(x)
```
```python
class Block(nn.Module):
    """ Transformer block: communication followed by computation
    Communication: multi-head attention
    Computation: a feedforward network applied to every token independently
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we want
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)  # communication
        self.ffwd = FeedFoward(n_embd)  # computation

    def forward(self, x):
        x = x + self.sa(x)    # residual connection around attention
        x = x + self.ffwd(x)  # residual connection around the feedforward network
        return x
```
Layer Norm#
The second technique is the "Norm" in "Add & Norm": Layer Normalization.
Layer Norm is very similar to Batch Norm, which normalizes every neuron to zero mean and unit variance across the batch dimension. The only difference is that Layer Norm normalizes across the feature dimension instead: for each example in the network (in our case, each token), it computes the mean and standard deviation over all of that example's features and uses these statistics to normalize them. It then applies a learnable gain and bias, which start at 1 and 0.
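To make "normalize across the features" concrete, the sketch below normalizes each token's feature vector by hand and compares the result with torch.nn.LayerNorm at initialization (gain 1, bias 0). The eps value of 1e-5 matches nn.LayerNorm's default; the input tensor is arbitrary random data.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C) * 3 + 1  # arbitrary features with non-zero mean and non-unit variance

# Normalize each token's C features using that token's own statistics
eps = 1e-5                                         # same as nn.LayerNorm's default eps
mean = x.mean(dim=-1, keepdim=True)                # (B, T, 1)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, as LayerNorm uses
x_manual = (x - mean) / torch.sqrt(var + eps)

ln = nn.LayerNorm(C)  # learnable gain starts at 1 and bias at 0
x_torch = ln(x)

print(torch.allclose(x_manual, x_torch, atol=1e-5))  # True
print(x_torch.mean(dim=-1).abs().max().item())       # close to 0: every token has zero mean
```

Batch Norm would instead compute mean and var over dim=0 (the batch), which couples examples in the batch together; Layer Norm needs no batch statistics, so it behaves the same at training and inference time.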
Note that the design of the Transformer has changed remarkably little since the paper came out, but our implementation departs from the original in one respect. In the paper, Add & Norm is applied after each sub-layer (attention or feedforward); nowadays it is more common to apply Layer Norm before each sub-layer, which is called the pre-norm formulation.
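In formulas, with Sublayer standing for either the multi-head attention or the feedforward network, the two orderings differ as follows; the pre-norm version is the one used in this implementation.

$$
\begin{aligned}
\text{post-norm (original paper):}\quad & x \leftarrow \mathrm{LayerNorm}\big(x + \mathrm{Sublayer}(x)\big) \\
\text{pre-norm:}\quad & x \leftarrow x + \mathrm{Sublayer}\big(\mathrm{LayerNorm}(x)\big)
\end{aligned}
$$

In the pre-norm form, the path from the block's input to its output contains only the addition, so the residual pathway stays a clean gradient highway; a final Layer Norm is then applied once after the last block.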
Now we have a fairly complete Transformer (only the decoder).
```python
class Block(nn.Module):
    """ Transformer block: communication followed by computation
    Communication: multi-head attention
    Computation: a feedforward network applied to every token independently
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we want
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)  # communication
        self.ffwd = FeedFoward(n_embd)  # computation
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Pre-norm: normalize the input of each sub-layer and keep the residual connections
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
```
```python
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # Each token looks up its embedding vector in a table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        ''' With n_layer = 3 and n_head = 4, the two lines below are equivalent to:
        self.blocks = nn.Sequential(
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            Block(n_embd, n_head=4),
            nn.LayerNorm(n_embd),
        )
        '''
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])  # stack n_layer blocks
        self.ln_f = nn.LayerNorm(n_embd)  # final Layer Norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensors of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb     # (B,T,C), pos_emb broadcasts over the batch
        x = self.blocks(x)        # (B,T,C)
        x = self.ln_f(x)          # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) targets
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss

    # Note for generate(): crop the context to the last block_size tokens
    # (idx[:, -block_size:]) before calling self(...), because
    # position_embedding_table only has block_size rows.
```
Dropout#
Dropout, introduced in the paper Dropout: A Simple Way to Prevent Neural Networks from Overfitting, can be added right before each sublayer's output is added back to the residual pathway, and also to the attention weights inside each head. During training it randomly zeroes a fraction of the neurons' outputs on every forward-backward pass, so each iteration effectively trains a different random subnetwork. At inference time dropout is switched off and the full network is used, which acts like an ensemble of those subnetworks.
Here it is enough to know that dropout is a regularization technique.
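In PyTorch, `nn.Dropout` only drops activations in training mode, and it scales the surviving ones by 1/(1-p) so that the expected value of each activation is unchanged. In evaluation mode it does nothing, which is why the model should be switched to `model.eval()` when estimating the loss or generating. A small demonstration, using p=0.2 to match the `dropout` hyperparameter below:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.2)
x = torch.ones(10_000)

drop.train()                     # training mode: dropout is active
y = drop(x)
frac_zero = (y == 0).float().mean().item()
print(round(frac_zero, 2))       # close to 0.2
print(y[y != 0].unique())        # survivors are scaled by 1 / (1 - 0.2) = 1.25

drop.eval()                      # eval mode: dropout is a no-op
print(torch.equal(drop(x), x))   # True
```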
```python
import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters (runs on a Colab V100 GPU; too slow on a CPU unless you scale these down)
batch_size = 64
block_size = 256      # longer context: up to 256 previous characters are used to predict the next one
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3  # a bigger network needs a lower learning rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6            # each head is 384 / 6 = 64-dimensional
n_layer = 4           # 4 Blocks
dropout = 0.2         # each activation is zeroed with probability 0.2 on every training pass
```
```python
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        # People generally do not use a bias for key/query/value
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)  # Dropout on the attention weights

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B,T,hs)
        q = self.query(x)  # (B,T,hs)
        # Attention scores ("affinities"), scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B,T,hs) @ (B,hs,T) -> (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T)
        wei = F.softmax(wei, dim=-1)  # (B,T,T)
        wei = self.dropout(wei)
        # Weighted aggregation of the values
        v = self.value(x)  # (B,T,hs)
        out = wei @ v      # (B,T,T) @ (B,T,hs) -> (B,T,hs)
        return out
```
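```python
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)  # projection back into the residual pathway
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B,T,num_heads*head_size)
        out = self.dropout(self.proj(out))  # Dropout before rejoining the residual pathway
        return out
```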
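```python
class FeedFoward(nn.Module):
    """ a linear layer followed by a non-linearity, applied to each token independently """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # inner layer is 4x wider, as in the paper
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),  # projection back into the residual pathway
            nn.Dropout(dropout),            # Dropout
        )

    def forward(self, x):
        return self.net(x)
```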
The generated text is still gibberish, but it now clearly resembles Shakespeare, which is already quite impressive.
The encoder on the left side of the diagram and the cross-attention block circled in red on the right are not implemented here.
We only need the decoder because we generate text unconditionally: nothing constrains the output except the Shakespeare dataset the model was trained on, which is why the result is Shakespeare-like gibberish. The triangular mask in the attention mechanism is what gives the model the autoregressive property needed for language modeling.
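A stripped-down sketch of the triangular mask, using random scores in place of `q @ k^T` but the same `tril` / `masked_fill` / `softmax` steps as in `Head`, shows why this works: after masking, row t puts zero weight on every later position, so no token can see the future.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
T = 4
scores = torch.randn(T, T)                              # raw attention scores (query x key)
tril = torch.tril(torch.ones(T, T))                     # lower-triangular causal mask
scores = scores.masked_fill(tril == 0, float('-inf'))   # hide future positions
wei = F.softmax(scores, dim=-1)                         # each row sums to 1
print(wei)
# Row t has non-zero weights only in columns 0..t: token t cannot look ahead.
```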
The original paper used an encoder-decoder structure because it addressed machine translation: the model receives the tokens of a sentence in a foreign language (such as French) and is expected to decode a translation into English, as shown in the diagram below.
Here, the encoder receives the tokens of the French sentence and processes them with the same Transformer blocks as above, but without the triangular mask, so all tokens can communicate with each other freely. The decoder, which does the language modeling, is connected to the encoder's final output (the top of the left side of the diagram) through cross-attention: the queries come from the decoder, while the keys and values come from the encoder output. This conditions the decoding not only on the English tokens generated so far but also on the fully encoded French sentence. What we implemented is the decoder-only version.
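For comparison, a single head of cross-attention can be sketched as follows. The class name is hypothetical, and this module is not part of our decoder-only model. It differs from `Head` in two ways: keys and values are computed from the encoder output instead of from `x`, and there is no triangular mask, because the whole source sentence is available at once. The decoder's own self-attention would still keep its mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionHead(nn.Module):
    """One head of cross-attention (illustrative only).
    Queries come from the decoder; keys and values come from the encoder output."""

    def __init__(self, n_embd, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

    def forward(self, x_dec, enc_out):
        q = self.query(x_dec)                                # (B, T_dec, hs)
        k = self.key(enc_out)                                # (B, T_enc, hs)
        v = self.value(enc_out)                              # (B, T_enc, hs)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T_dec, T_enc)
        # No triangular mask: every decoder position may look at the whole source sentence
        wei = F.softmax(wei, dim=-1)
        return wei @ v                                       # (B, T_dec, hs)

x_dec = torch.randn(2, 7, 32)     # e.g. 7 English tokens generated so far
enc_out = torch.randn(2, 11, 32)  # e.g. 11 encoded French tokens
out = CrossAttentionHead(32, 16)(x_dec, enc_out)
print(out.shape)  # torch.Size([2, 7, 16])
```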
The code in this section corresponds to the karpathy/nanoGPT project, which likewise covers only the pre-training stage.
Returning to ChatGPT#
Training ChatGPT roughly consists of two phases: pre-training and fine-tuning.
Pre-training#
Pre-training trains the model on a large corpus of internet text to obtain a decoder-only Transformer. What we have just done is a tiny version of this pre-training step. One difference is that OpenAI uses a tokenizer, so the vocabulary consists of subword chunks rather than single characters. Under such a tokenizer our Shakespeare dataset comes to about 300,000 tokens, and we trained a model on the order of 10 million parameters on it, whereas GPT-3 has up to 175 billion parameters and was trained on 300 billion tokens.
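The parameter count can be checked directly from the layer shapes. This pure-arithmetic sketch (the helper name is hypothetical) assumes the 65-character Tiny Shakespeare vocabulary and the layers defined above, with no bias in key/query/value. With 6 layers it gives about 10.8 million parameters; with the `n_layer = 4` setting in the hyperparameters above it gives about 7.2 million. For an actual model, `sum(p.numel() for p in model.parameters())` returns the same number.

```python
def gpt_param_count(vocab_size, block_size, n_embd, n_head, n_layer):
    """Count the parameters of the BigramLanguageModel above (pure arithmetic)."""
    head_size = n_embd // n_head
    emb = vocab_size * n_embd + block_size * n_embd           # token + position tables
    attn = n_head * 3 * n_embd * head_size                    # key/query/value, no bias
    proj = (n_head * head_size) * n_embd + n_embd             # output projection + bias
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    lns = 2 * (2 * n_embd)                                    # ln1, ln2: gain + bias each
    per_block = attn + proj + ffwd + lns
    ln_f = 2 * n_embd                                         # final Layer Norm
    lm_head = n_embd * vocab_size + vocab_size                # weights + bias
    return emb + n_layer * per_block + ln_f + lm_head

print(gpt_param_count(65, 256, 384, 6, n_layer=6))  # 10788929
print(gpt_param_count(65, 256, 384, 6, n_layer=4))  # 7242305
```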
After this step you still cannot ask the model questions: it only generates text that resembles internet documents such as news articles. In other words, it is a document completer that simply continues sequences.
Fine-tuning#
This phase turns the model into a language model assistant. The first step collects thousands of examples formatted as "question: answer" and fine-tunes the model on them to align it with this format; fine-tuning is very sample-efficient for large models, so a few thousand examples are enough.
In the second step, human raters rank several of the model's responses to the same prompt, and these rankings are used to train a reward model.
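One common way to turn rankings into a training signal, used in OpenAI's InstructGPT paper, is a pairwise loss: for each pair of responses where raters preferred one, the reward model is pushed to score the preferred response higher. A minimal sketch of that loss (the function name is hypothetical, and the reward model that produces the scores is not shown):

```python
import torch
import torch.nn.functional as F

def pairwise_reward_loss(r_chosen, r_rejected):
    """Ranking loss: -log sigmoid(r_chosen - r_rejected), averaged over pairs.
    r_chosen, r_rejected: (N,) scalar rewards for the preferred / other response of N pairs."""
    return -F.logsigmoid(r_chosen - r_rejected).mean()

r_chosen = torch.tensor([2.0, 0.5, 1.0])
r_rejected = torch.tensor([-1.0, 0.5, 3.0])
print(pairwise_reward_loss(r_chosen, r_rejected))
```

The loss is log 2 when both responses get the same score and approaches 0 as the preferred response's score pulls ahead.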
The third step runs PPO (Proximal Policy Optimization, a policy-gradient reinforcement learning algorithm) against the reward model to fine-tune the model's sampling policy, turning it from a document completer into a question answerer.
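To give a feel for what PPO optimizes, here is a sketch of its clipped surrogate loss (function and argument names are hypothetical). In an RLHF setup the advantages would be derived from the reward model's scores; the value function and the KL penalty toward the fine-tuned model that InstructGPT also uses are left out.

```python
import torch

def ppo_clip_loss(logp_new, logp_old, advantages, eps=0.2):
    """PPO clipped surrogate objective, returned as a loss to minimize.
    logp_new / logp_old: log-probs of the sampled tokens under the current / sampling policy.
    advantages: how much better than expected each sample turned out."""
    ratio = torch.exp(logp_new - logp_old)                        # pi_new / pi_old
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages   # limit the policy change per update
    return -torch.min(unclipped, clipped).mean()

logp_old = torch.log(torch.tensor([0.5, 0.2, 0.1]))
logp_new = torch.log(torch.tensor([0.6, 0.1, 0.3]))
adv = torch.tensor([1.0, -0.5, 2.0])
print(ppo_clip_loss(logp_new, logp_old, adv))
```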
Of course, these parts are basically impossible for individuals to replicate; only large companies can do this.
For a detailed discussion of GPT, see The State of GPT, Andrej's comprehensive overview presented at Microsoft Build in May 2023.