Nagi-ovo

The Evolution of LLMs (Part 4): WaveNet - The Convolutional Revolution of Sequence Models

The source code repository for this section is here.

In the previous sections, we built a character-level language model using a multilayer perceptron. Now it’s time to make its architecture more complex. First, we want the model to take more than the current 3 characters of context as input. Second, we don’t want to squash all of them into a single hidden layer in one step, because that loses too much information. The result is a deeper model similar to WaveNet.

WaveNet

Published by DeepMind in 2016, WaveNet is essentially a language model, except that instead of predicting character-level or word-level sequences, it predicts audio sequences. Fundamentally, the modeling setup is the same: both are autoregressive models that try to predict the next element of a sequence (the next character for us, the next audio sample for WaveNet).

[Figure: the WaveNet paper's tree-like stack of dilated causal convolution layers]

The paper uses this tree-like hierarchical structure for prediction, and in this section we will implement it.

nn.Module

We first encapsulate the code from the previous section into classes that mimic the API of PyTorch's nn.Module. This lets us treat modules like "Linear", "BatchNorm1d", and "Tanh" as LEGO blocks that we can stack to build a neural network:

import torch

class Linear:
  
  def __init__(self, fan_in, fan_out, bias=True):
    # Scale by 1/sqrt(fan_in) so the output keeps roughly unit variance at initialization
    self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
    self.bias = torch.zeros(fan_out) if bias else None
  
  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out
  
  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])

The Linear module performs a matrix multiplication (plus an optional bias) during the forward pass.

class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # Parameters trained using backpropagation
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # Buffers, updated with a running "momentum update" rather than by backpropagation
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
  
  def __call__(self, x):
    # Calculate forward pass
    if self.training:
      xmean = x.mean(0, keepdim=True) # Batch mean
      xvar = x.var(0, keepdim=True) # Batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # Normalize to zero mean and unit variance
    self.out = self.gamma * xhat + self.beta
    # Update buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

Notes on BatchNorm1d:

  1. It maintains a running mean and variance that are updated outside of backpropagation (via the momentum update).
  2. self.training = True: batch norm behaves differently during training and evaluation, so it needs a flag to track which mode it is in.
  3. It couples the computation across the examples in a batch in order to control the statistics of the activations, with the aim of reducing internal covariate shift.
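PyTorch's own nn.BatchNorm1d, whose API the class above mimics, shows the same train/eval split. Below is a minimal sketch with arbitrary random data. In training mode, one forward pass normalizes with the batch statistics and moves the running mean 10% of the way toward the batch mean (momentum=0.1). In eval mode, the stored running statistics are used instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 5) * 3 + 2  # a batch of 32 examples with 5 features (arbitrary data)

bn = nn.BatchNorm1d(5, momentum=0.1)  # running_mean starts at 0, running_var at 1

bn.train()
out_train = bn(x)  # normalized with this batch's own mean and variance
# One momentum update: running_mean = 0.9 * 0 + 0.1 * batch_mean
print(bn.running_mean)

bn.eval()
out_eval = bn(x)  # normalized with the running statistics instead
print(torch.allclose(out_train, out_eval))  # the two modes disagree on the same input
```

PyTorch updates running_var with the unbiased batch variance, which matches the default behavior of x.var() used in our class.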

class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []

Instead of passing a local generator g as in the previous setup, we now set a global random seed:

torch.manual_seed(42);

The following should look familiar, including the embedding table C and our stack of layers:

n_embd = 10 # Dimension of character embedding vectors
n_hidden = 200 # Number of neurons in the hidden layer of the MLP

C = torch.randn((vocab_size, n_embd))
layers = [
  Linear(n_embd * block_size, n_hidden, bias=False), # no bias: the BatchNorm1d that follows has its own beta
  BatchNorm1d(n_hidden),
  Tanh(),
  Linear(n_hidden, vocab_size),
]

# Initialize parameters
with torch.no_grad():
  layers[-1].weight *= 0.1 # Scale down the last layer (the output layer) to reduce the model's initial confidence in predictions

parameters = [C] + [p for layer in layers for p in layer.parameters()]
# The list comprehension above is equivalent to:
# parameters = [C]
# for layer in layers:
#   for p in layer.parameters():
#     parameters.append(p)

print(sum(p.nelement() for p in parameters)) # Total number of parameters
for p in parameters:
  p.requires_grad = True

We leave the optimization loop unchanged for now. The loss curve still fluctuates a lot because a batch size of 32 is small: the loss measured on each batch is a very noisy estimate.

[Figure: training loss per step, with large step-to-step fluctuations]

During the evaluation phase, we need to set the training flag of all layers to False (currently only affecting the batch norm layers):

# Set layers to evaluation mode
for layer in layers:
  layer.training = False

First, let's fix the loss plot.

lossi is the list of losses recorded at every training step. Instead of plotting every value, we can average blocks of consecutive values to get a more representative curve.

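Let’s review how torch.view() works:

[Figure: a torch.view() example in a notebook]

In the example, the call is equivalent to view(5, -1): passing -1 for a dimension tells PyTorch to infer its size from the total number of elements.

This makes it convenient to reshape a flat list of values into rows, so we can average each block of 1000 consecutive steps: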

torch.tensor(lossi).view(-1, 1000).mean(1)
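Here is a small sketch of the smoothing step. It uses synthetic losses as a stand-in for lossi, so the numbers are made up and are not training results. Note that view(-1, 1000) only works when the number of recorded steps is a multiple of 1000.

```python
import torch

# Stand-in for lossi: 5,000 noisy per-step losses (synthetic data, not real training logs)
torch.manual_seed(0)
steps = torch.arange(5000, dtype=torch.float32)
lossi = (2.5 - 0.5 * steps / 5000 + 0.3 * torch.randn(5000)).tolist()

# view(-1, 1000): 1000 columns, and -1 lets PyTorch infer the number of rows (here 5)
smoothed = torch.tensor(lossi).view(-1, 1000).mean(1)
print(smoothed.shape)  # torch.Size([5])
print(smoothed)        # one averaged loss per block of 1000 consecutive steps
```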

[Figure: the loss curve averaged over blocks of 1000 steps]

Now it looks much better. We can clearly see the drop where the learning rate is decayed, after which the loss settles into a local minimum.

Next, we will also convert the original Embedding and Flattening operations shown below into modules:

emb = C[Xb] # (batch, block_size, n_embd)
x = emb.view(emb.shape[0], -1) # concatenate the embeddings: (batch, block_size * n_embd)

class Embedding:
  
  def __init__(self, num_embeddings, embedding_dim):
    self.weight = torch.randn((num_embeddings, embedding_dim))
    # Now C becomes the weight of the embedding
    
  def __call__(self, IX):
    self.out = self.weight[IX]
    return self.out
  
  def parameters(self):
    return [self.weight]


class Flatten:
    
  def __call__(self, x):
    self.out = x.view(x.shape[0], -1)
    return self.out
  
  def parameters(self):
    return []

In PyTorch, there is also a concept of containers, which is essentially a way to organize layers into lists or dictionaries. One of them is called Sequential, which primarily serves to pass the given input sequentially through all layers:

class Sequential:
  
  def __init__(self, layers):
    self.layers = layers
  
  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    self.out = x
    return self.out
  
  def parameters(self):
    # Get parameters from all layers and flatten them into a list.
    return [p for layer in self.layers for p in layer.parameters()]

We now have the notion of a model:

model = Sequential([
  Embedding(vocab_size, n_embd),
  Flatten(),
  Linear(n_embd * block_size, n_hidden, bias=False),
  BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, vocab_size),
])

parameters = model.parameters()
print(sum(p.nelement() for p in parameters)) # Total number of parameters
for p in parameters:
  p.requires_grad = True

With this, the rest of the code simplifies further:

# forward pass (training loop)
logits = model(Xb)
loss = F.cross_entropy(logits, Yb) # loss function

# evaluate the loss
logits = model(x)
loss = F.cross_entropy(logits, y)

# sample from the model
  # forward pass the neural net 
  logits = model(torch.tensor([context]))
  probs = F.softmax(logits, dim=1)

Implementing a Layered Structure

We do not want to squash all the information into a single layer in one step, as the current model does. Instead, we want to fuse it progressively as it moves deeper into the network. This is how WaveNet predicts the next element of a sequence: neighboring characters are first fused into two-character (bigram) representations, these are then fused into four-character representations, and so on.

[Figure: the WaveNet paper's tree-like stack of dilated causal convolution layers]

This figure from the WaveNet paper visualizes a stack of "dilated causal convolution layers". We don't need to worry about the specifics yet. Just focus on the core idea of progressive fusion.

First, we increase the context to 8 input characters, which we will then process in a tree structure:

# block_size = 3
# train 2.0677597522735596; val 2.1055991649627686
block_size = 8

Simply expanding the context length already improves performance:

[Figure: loss after increasing block_size to 8]

To clarify what we are doing, let’s inspect the tensor shapes as they pass through each layer:

[Figure: output shape of each layer for a batch of 4 examples]

We feed in a batch of 4 randomly selected examples, so the input to the model has shape 4x8 (block_size=8).

  1. After the first layer (embedding), the output is 4x8x10: the embedding table holds a learnable 10-dimensional vector for each character.
  2. After the second layer (flatten), the shape becomes 4x80, as mentioned earlier. This layer stretches the 10-dimensional embeddings of the 8 characters into one long row, like a concatenation.
  3. The third layer (linear) maps these 80 numbers to 200 channels through a matrix multiplication.
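The same shape trace can be sketched with PyTorch's built-in counterparts of our modules (nn.Embedding, nn.Flatten, nn.Linear, nn.BatchNorm1d, nn.Tanh). The vocabulary size of 27 (26 letters plus '.') is an assumption for illustration. The inputs are random character indices.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
vocab_size, n_embd, n_hidden, block_size = 27, 10, 200, 8  # vocab_size=27 is an assumption

# The flat model from above, rebuilt from torch.nn modules
model = nn.Sequential(
    nn.Embedding(vocab_size, n_embd),
    nn.Flatten(),  # flattens every dimension after the batch dimension
    nn.Linear(n_embd * block_size, n_hidden, bias=False),
    nn.BatchNorm1d(n_hidden),
    nn.Tanh(),
    nn.Linear(n_hidden, vocab_size),
)

Xb = torch.randint(0, vocab_size, (4, block_size))  # 4 random examples of 8 character indices
x = Xb
shapes = []
for layer in model:
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(f"{layer.__class__.__name__:12s} {tuple(x.shape)}")
```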

To summarize what the Embedding layer does, this answer explains it very well:
1. It converts a sparse matrix (one-hot vectors) into a dense matrix through a linear transformation, which in practice is implemented as a table lookup.
2. The dense matrix represents every word with N features. Its entries are, in effect, coefficients relating each word to each feature, so it can capture many relationships between words.
3. These coefficients are the embedding layer's learned weights, which are continuously updated and optimized during backpropagation.
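Point 1 can be checked directly: indexing the embedding table gives the same result as multiplying one-hot (sparse) rows by it. Below is a minimal sketch with an assumed vocabulary of 27 and random weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_embd = 27, 10             # assumed sizes
C = torch.randn(vocab_size, n_embd)     # embedding table
ix = torch.tensor([5, 13, 13, 1])       # a few character indices

lookup = C[ix]                                           # what Embedding does: index rows
one_hot = F.one_hot(ix, num_classes=vocab_size).float()  # sparse one-hot rows, shape (4, 27)
via_matmul = one_hot @ C                                 # linear transformation of the one-hot rows

print(torch.allclose(lookup, via_matmul))  # True
```

The lookup is simply a cheaper way to compute the same product, which is why Embedding indexes rows rather than multiplying.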

The linear layer accepts input X during the forward pass, multiplies it by the weights, and optionally adds a bias:

def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5 # note: kaiming init
    self.bias = torch.zeros(fan_out) if bias else None

Here, the weights are two-dimensional, and the bias is one-dimensional.

Based on the input and output shapes, the internal structure of this linear layer looks like this:

(torch.randn(4, 80) @ torch.randn(80, 200) + torch.randn(200)).shape

The output is 4x200, and the bias added here follows broadcasting semantics.

In addition, PyTorch's matrix multiplication operator is very powerful. It accepts higher-dimensional tensors and multiplies along the last dimension of the left operand (against the 2D weight matrix), treating all other dimensions as batch dimensions.

[Figure: matrix multiplication of a higher-dimensional tensor in a notebook]
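The following quick sketch uses arbitrary sizes to show this batch-dimension behavior. The extra leading dimensions are carried through unchanged, and the result matches what you get by flattening them, multiplying, and reshaping back.

```python
import torch

x = torch.randn(4, 5, 80)   # (batch, extra dim, features)
W = torch.randn(80, 200)
b = torch.randn(200)

out = x @ W + b             # matmul acts on the last dim; (4, 5) are batch dims; b broadcasts
print(out.shape)            # torch.Size([4, 5, 200])

# Same result as flattening the batch dims, multiplying, and reshaping back
ref = (x.view(-1, 80) @ W + b).view(4, 5, 200)
print(torch.allclose(out, ref, atol=1e-4))
```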

This batching behavior is very useful for what we want to do next: processing groups in parallel. We do not want to feed all 80 numbers in at once. Instead, we want the first layer to fuse two characters at a time, so each group it sees contains only 20 numbers (2 characters x 10 dimensions), as shown below:

# (1 2) (3 4) (5 6) (7 8)

(torch.randn(4, 4, 20) @ torch.randn(20, 200) + torch.randn(200)).shape

This gives four bigram groups per example, each made of two concatenated 10-dimensional embeddings. The linear layer processes all of them in parallel, producing a 4x4x200 output.

To build such a structure, Python's slice syntax with a step makes it easy to extract the even and odd positions of a sequence:

[Figure: even/odd slicing with ::2 and 1::2]

e = torch.randn(4, 8, 10)
torch.cat([e[:, ::2, :], e[:, 1::2, :]], dim=2)
# torch.Size([4, 4, 20])

This explicitly extracts the characters at even and odd positions (two 4x4x10 tensors) and concatenates them along the last dimension.

[Figure: shape of the concatenated tensor]

The powerful view() can do the equivalent work in a single call: e.view(4, 4, 20).
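Here is a short check with random data. It confirms that the view() call gives exactly the same tensor as the explicit even/odd slicing and concatenation.

```python
import torch

e = torch.randn(4, 8, 10)  # (batch, 8 characters, 10-dim embeddings)

explicit = torch.cat([e[:, ::2, :], e[:, 1::2, :]], dim=2)  # pair characters (0,1), (2,3), ...
viewed = e.view(4, 4, 20)  # reinterpret the same memory: each row of 20 = two consecutive embeddings

print(explicit.shape, viewed.shape)   # both torch.Size([4, 4, 20])
print(torch.equal(explicit, viewed))  # True
```

view() also avoids copying data, since it only reinterprets the existing contiguous memory.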

Now let's generalize our Flatten layer into FlattenConsecutive. Its constructor takes n, the number of consecutive elements we want to concatenate in the last dimension. The layer then flattens every n consecutive elements into the last dimension.

class FlattenConsecutive:
  
  def __init__(self, n):
    self.n = n
    
  def __call__(self, x):
    B, T, C = x.shape
    x = x.view(B, T//self.n, C*self.n)
    if x.shape[1] == 1:
      x = x.squeeze(1)
    self.out = x
    return self.out
  
  def parameters(self):
    return []

  • B: batch size, the number of samples in the batch.
  • T: time steps, the number of elements in the sequence, i.e. the sequence length.
  • C: channels (features), the number of features at each time step.

  1. Input tensor: x is a three-dimensional tensor of shape (B, T, C).

  2. Flattening: x.view(B, T//self.n, C*self.n) merges consecutive time steps, where self.n is the number of time steps to merge. Every n consecutive time steps become one wider feature vector, so the time dimension shrinks by a factor of n while the feature dimension grows by a factor of n. The new shape is (B, T//n, C*n), and each new time step contains the information of n original time steps. This assumes T is divisible by n; otherwise view() raises an error.

  3. Removing a length-1 time dimension: if only one time step remains (x.shape[1] == 1), x.squeeze(1) removes that dimension. This produces the same 2D (B, C*n) output that the original Flatten layer gave.
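To see how stacking FlattenConsecutive(2) builds the tree, here is a sketch that applies the same view logic three times to an 8-character input. flatten_consecutive is a hypothetical function that mirrors the class above. Random matrices followed by tanh stand in for the Linear/BatchNorm1d/Tanh layers, and n_hidden = 68 matches the value used below.

```python
import torch

def flatten_consecutive(x, n):
    """Hypothetical function mirroring FlattenConsecutive.__call__ above."""
    B, T, C = x.shape
    x = x.view(B, T // n, C * n)  # merge every n consecutive time steps
    if x.shape[1] == 1:
        x = x.squeeze(1)          # drop a length-1 time dimension
    return x

torch.manual_seed(0)
n_embd, n_hidden = 10, 68                # sizes assumed for illustration
x = torch.randn(4, 8, n_embd)            # output of the Embedding layer: (B=4, T=8, C=10)

shapes = []
for c_in in (n_embd, n_hidden, n_hidden):
    x = flatten_consecutive(x, 2)                            # fuse pairs of neighbors
    shapes.append(tuple(x.shape))
    W = torch.randn(2 * c_in, n_hidden) / (2 * c_in) ** 0.5  # stand-in for a Linear layer
    x = torch.tanh(x @ W)
    shapes.append(tuple(x.shape))

print(shapes)
```

Each level halves the time dimension. After the third level, the squeeze leaves a plain 2D (batch, features) tensor, ready for the output layer.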


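After these changes, we check the shapes of the intermediate layers:

[Figure: output shape of each layer in the hierarchical model]

This reveals a problem with batch norm. It should keep a mean and variance for each of the 68 channels only. With a 3D input of shape (32, 4, 68), however, the current implementation averages over dimension 0 alone and produces statistics of shape (1, 4, 68), i.e. a separate mean and variance for each of the 4 positions. So we change BatchNorm1d to reduce over both the batch and time dimensions: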

class BatchNorm1d:
  
  def __call__(self, x):
    # Calculate the forward pass
    if self.training:
      if x.ndim == 2:
        dim = 0
      elif x.ndim == 3:
        dim = (0,1) # torch.mean() can accept a tuple, meaning multiple dimensions for dim
        
      xmean = x.mean(dim, keepdim=True) # Batch mean
      xvar = x.var(dim, keepdim=True) # Batch variance

Now running_mean.shape is [1, 1, 68].
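Here is a sketch of why the reduction dimensions matter, using random data in the (batch, time, channels) layout used here. For comparison, PyTorch's own nn.BatchNorm1d expects channels in dimension 1 (inputs of shape (N, C) or (N, C, L)), so it would not apply directly to this layout.

```python
import torch

x = torch.randn(32, 4, 68)  # (batch, time, channels) as seen by the hierarchical model's BatchNorm

per_position = x.mean(0, keepdim=True)           # old behavior: reduce over batch only
per_channel = x.mean((0, 1), keepdim=True)       # new behavior: reduce over batch and time
var_per_channel = x.var((0, 1), keepdim=True)    # the variance accepts a tuple of dims too

print(per_position.shape)  # torch.Size([1, 4, 68]): separate stats for each of the 4 positions
print(per_channel.shape)   # torch.Size([1, 1, 68]): one mean per channel
```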

Expanding the Neural Network

With the completion of the above improvements, we can further enhance performance by increasing the size of the network.

n_embd = 24    # Dimension of embedding vectors
n_hidden = 128 # Number of neurons in the hidden layer of the MLP 

The total number of parameters has now reached 76,579, and the loss has dropped below 2.0:

[Figure: final loss of the larger model]
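Where does 76,579 come from? Here is a back-of-the-envelope sketch. It assumes vocab_size = 27 and a layout of three levels of FlattenConsecutive(2), Linear(bias=False), BatchNorm1d, and Tanh, followed by an output Linear. The text does not list the layers, so this layout is an assumption, but it reproduces the reported count.

```python
vocab_size, n_embd, n_hidden = 27, 24, 128  # vocab_size=27 is an assumption

embedding = vocab_size * n_embd
# Three hidden levels, each: FlattenConsecutive(2), Linear(bias=False), BatchNorm1d, Tanh
hidden = 0
for fan_in in (2 * n_embd, 2 * n_hidden, 2 * n_hidden):
    hidden += fan_in * n_hidden   # Linear weights (no bias: BatchNorm's beta plays that role)
    hidden += 2 * n_hidden        # BatchNorm gamma and beta
output = n_hidden * vocab_size + vocab_size  # final Linear with bias

total = embedding + hidden + output
print(total)  # 76579
```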

Training now takes noticeably longer. Although performance has improved, we still don't know the right settings for hyperparameters such as the learning rate. So far we have only been tweaking them while watching the training loss.

Convolution

In this section, we implemented the main architecture of WaveNet. We have not yet implemented the specific forward pass of its layers, which uses a more complex layer (the gated activation unit) together with residual connections and skip connections.

[Figure: the WaveNet paper's residual block, with gated activation units, residual and skip connections]
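For reference, here is a sketch of one such layer, based on the WaveNet paper's description. It uses a gated activation unit z = tanh(W_f * x) ⊙ σ(W_g * x), where * is a dilated causal convolution, followed by 1x1 convolutions that feed a residual connection and a skip connection. This is not part of this series' code. The function name, weight shapes, and left-padding scheme are assumptions for illustration, and the weights are random.

```python
import torch
import torch.nn.functional as F

def gated_residual_block(x, w_filter, w_gate, w_res, w_skip, dilation):
    """One WaveNet-style layer (sketch). x: (B, C, T).
    w_filter, w_gate: (C, C, 2) dilated conv kernels; w_res, w_skip: (C, C, 1) 1x1 conv kernels."""
    # Pad on the left only, so output t depends on inputs t - dilation and t (causal)
    x_pad = F.pad(x, (dilation, 0))
    filt = torch.tanh(F.conv1d(x_pad, w_filter, dilation=dilation))
    gate = torch.sigmoid(F.conv1d(x_pad, w_gate, dilation=dilation))
    z = filt * gate                      # gated activation unit
    skip = F.conv1d(z, w_skip)           # 1x1 conv feeding the skip-connection sum
    residual = F.conv1d(z, w_res) + x    # 1x1 conv plus the residual connection
    return residual, skip

torch.manual_seed(0)
B, C, T, d = 2, 4, 10, 2                 # illustrative sizes
ws = [torch.randn(C, C, 2) * 0.5, torch.randn(C, C, 2) * 0.5,   # w_filter, w_gate
      torch.randn(C, C, 1) * 0.5, torch.randn(C, C, 1) * 0.5]   # w_res, w_skip
x = torch.randn(B, C, T)
res, skip = gated_residual_block(x, *ws, dilation=d)
print(res.shape, skip.shape)             # both torch.Size([2, 4, 10])
```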

Here, we briefly look at how our tree structure relates to the convolutional neural network used in the WaveNet paper.

Essentially, convolution is used here for efficiency. It lets us slide the model over the input sequence, so the for-loop over positions (the kernel sliding across the input and computing at each step) runs inside CUDA kernels instead of in Python.

[Figure: the WaveNet paper's tree-like stack of dilated causal convolution layers]

We only implemented a single black tree from the diagram, which produces one output. Convolution lets you slide this black structure over the whole input sequence and compute all the orange outputs at once, like a linear filter.

The efficiency gains come from two sources:

  1. The for-loop runs inside CUDA kernels;
  2. Intermediate values are reused: a white node in the second layer is the left child of one white node in the third layer and the right child of another, so it is computed once and used twice.
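This relationship can be checked numerically. The sketch below makes some simplifications (random weights, tanh only, no BatchNorm) and expresses each level of the tree as a kernel-size-2 convolution with dilation 1, 2, and 4. Run on a sequence longer than 8, F.conv1d computes the tree's output for every 8-step window in one call, and each intermediate value is computed once and shared by neighboring windows. The names tree, as_conv_weight, dilated_stack, and W1 to W3 are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H = 2, 10, 16  # batch, embedding size, hidden size (illustrative)

# Hypothetical weights for the three "fuse two children" levels of the tree (no BatchNorm here)
W1 = torch.randn(2 * C, H) / (2 * C) ** 0.5
W2 = torch.randn(2 * H, H) / (2 * H) ** 0.5
W3 = torch.randn(2 * H, H) / (2 * H) ** 0.5

def tree(x):
    """One tree over a window of 8 steps, x: (B, 8, C) -> (B, H), fusing neighbors pairwise."""
    for W in (W1, W2, W3):
        b, t, c = x.shape
        x = torch.tanh(x.reshape(b, t // 2, 2 * c) @ W)  # same merge as FlattenConsecutive(2)
    return x.squeeze(1)

def as_conv_weight(W, c_in):
    """Linear weight (2*c_in, c_out) -> conv1d kernel (c_out, c_in, 2)."""
    return W.view(2, c_in, -1).permute(2, 1, 0).contiguous()

def dilated_stack(x):
    """The same tree slid over every window at once, x: (B, T, C) -> (B, T - 7, H)."""
    h = x.transpose(1, 2)  # conv1d expects (B, channels, time)
    for W, c_in, dilation in ((W1, C, 1), (W2, H, 2), (W3, H, 4)):
        h = torch.tanh(F.conv1d(h, as_conv_weight(W, c_in), dilation=dilation))
    return h.transpose(1, 2)

x = torch.randn(B, 12, C)   # a sequence of 12 steps
out = dilated_stack(x)      # 12 - 7 = 5 windows of 8 steps
print(out.shape)            # torch.Size([2, 5, 16])
```

Each out[:, t] should equal tree(x[:, t:t + 8]) up to floating-point error, because the convolutions slide the same tree across the sequence.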

Summary

With this section, we have effectively unlocked the torch.nn module, and from now on we will use it to implement our models.

Looking back, much of the time in this section went into getting the shape of each layer right. That is why Andrej often debugged shapes in a Jupyter Notebook and, once satisfied, copied the code over to VSCode.
